Fix slicing and get_plain() in GemLite #2288

Merged · 15 commits · Jun 5, 2025
73 changes: 71 additions & 2 deletions test/dtypes/test_affine_quantized.py
@@ -371,12 +371,81 @@ def test_slice_gemlite(self, device, dtype):
# in_feature not divisible by 1024
# out_feature not divisible by 8
# to test slice + padding for int4 weight only quantization
dummy = nn.Linear(256, 512, dtype=dtype, device=device)
quantize_(dummy, GemliteUIntXWeightOnlyConfig())
in_features, out_features, group_size, bit_width = 256, 512, 64, 4
orig_shape = [out_features, in_features]
dummy = nn.Linear(
in_features, out_features, bias=False, dtype=dtype, device=device
)
quantize_(
dummy,
GemliteUIntXWeightOnlyConfig(bit_width=bit_width, group_size=group_size),
)
W_group_mode = dummy.weight.tensor_impl.gemlite_kwargs["meta_args"][10]

# make sure these run without error
_ = dummy.weight.narrow(0, 0, 64)
_ = dummy.weight.narrow(1, 0, 128)

# Dequant op
import gemlite

def dequant(input_layer, in_features, orig_shape):
int_data = input_layer.tensor_impl.packed_weight
scale = input_layer.tensor_impl.scale
zero_point = input_layer.tensor_impl.zero_point

W_q = (
gemlite.bitpack.unpack_over_rows(
int_data,
W_nbits=bit_width,
num_output_rows=in_features,
dtype=torch.uint8,
)
.T.contiguous()
.view([-1, group_size])
)

s = scale.t().contiguous().view(-1, 1)
z = zero_point.t().contiguous().view(-1, 1)

if W_group_mode == 4: # FMA
W_deq = (W_q * s + z).view(orig_shape)
else:
W_deq = ((W_q - z) * s).view(orig_shape)

return W_deq

W_r = dequant(dummy.weight, dummy.in_features, orig_shape)

# Slicing in half
for slice_axis, start, end in [
(0, 0, 256),
(0, 256, 256),
(1, 0, 128),
(1, 128, 128),
]:
layer_sliced = dummy.weight.narrow(slice_axis, start, end)

if slice_axis == 0:
num_rows, out_shape = (
dummy.in_features,
(orig_shape[0] // 2, orig_shape[1]),
)
else:
num_rows, out_shape = (
dummy.in_features // 2,
(orig_shape[0], orig_shape[1] // 2),
)

W_slice = dequant(layer_sliced, num_rows, out_shape)

W_slice_ref = (
W_r[start : start + end, :]
if slice_axis == 0
else W_r[:, start : start + end]
)
self.assertEqual((W_slice_ref - W_slice).abs().mean().item(), 0)

@common_utils.parametrize("device", ["cuda"])
@common_utils.parametrize("dtype", [torch.bfloat16])
def test_matmul(self, device, dtype):
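A side note on the `dequant` helper in the new test: GemLite's FMA group mode (`W_group_mode == 4`) stores scale and zero point so that dequantization is `W_q * s + z`, whereas the usual affine convention is `(W_q - z) * s`. The two agree when the affine zero point equals `-z / s`, which is the conversion the updated `get_plain()` below applies. A minimal sketch of that equivalence (variable names and sizes here are illustrative, not taken from the PR):

```python
import torch

torch.manual_seed(0)
W_q = torch.randint(0, 16, (8, 64)).float()  # 4-bit integer codes
s = torch.rand(8, 1) + 0.1                   # per-group scales, kept away from zero
z = torch.randn(8, 1)                        # FMA-style additive zero point

fma = W_q * s + z             # W_group_mode == 4 dequant path
z_affine = -z / s             # conversion used by get_plain()
affine = (W_q - z_affine) * s # standard affine dequant

assert torch.allclose(fma, affine, atol=1e-5)
```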
122 changes: 98 additions & 24 deletions torchao/dtypes/uintx/gemlite_layout.py
@@ -25,7 +25,6 @@
except:
gemlite = None


aten = torch.ops.aten


@@ -35,7 +34,12 @@ def _same_metadata(
) -> bool:
kwargs_match = len(self.gemlite_kwargs) == len(src.gemlite_kwargs)
for k, v in self.gemlite_kwargs.items():
if k != "scale_activations":
if k in [
"in_features",
"out_features",
"packing_bitwidth",
"elements_per_sample",
]:
kwargs_match = kwargs_match and (v == src.gemlite_kwargs[k])

return (
@@ -80,6 +84,7 @@ def get_gemlite_aqt_kwargs(
weight,
group_size=64,
bit_width=4,
packing_bitwidth=None,
use_hqq=True,
):
if gemlite is None:
@@ -99,6 +104,9 @@
assert group_size is None or bit_width != 8, (
"gemlite only works with group_size=None for bit_width=8"
)
assert packing_bitwidth in [8, 16, 32, None], (
f"Invalid packing bitwidth, got {packing_bitwidth}"
)

out_features, in_features = weight.shape
group_size = in_features if group_size is None else group_size
@@ -107,15 +115,17 @@
aqt_kwargs["_layout"] = GemlitePackedLayout(
group_size=group_size,
bit_width=bit_width,
packing_bitwidth=packing_bitwidth,
)
aqt_kwargs["use_hqq"] = use_hqq
return aqt_kwargs


@dataclass(frozen=True)
class GemlitePackedLayout(Layout):
group_size: Optional[int] = 64
group_size: Optional[int] = 128
bit_width: int = 4
packing_bitwidth: Optional[int] = None


@register_layout(GemlitePackedLayout)
@@ -191,24 +201,36 @@ def from_plain(

group_size, bit_width = _layout.group_size, _layout.bit_width
out_features, in_features = int_data.shape
packing_bitwidth = _layout.packing_bitwidth

if bit_width == 8 and group_size == in_features:
gemlite_linear = gemlite.helper.A16W8(device=int_data.device).from_weights(
int_data, scales=scale, bias=None
)
else:
gemlite_linear = gemlite.helper.A16Wn(device=int_data.device).from_weights(
gemlite_linear = gemlite.helper.A16Wn(
device=int_data.device, packing_bitwidth=packing_bitwidth
).from_weights(
int_data, scale, zero_point, bit_width, group_size, bias=None
)

meta_args = gemlite_linear.get_meta_args()
gemlite_kwargs = {
"in_features": in_features,
"out_features": out_features,
"meta_args": gemlite_linear.get_meta_args(),
"packing_bitwidth": packing_bitwidth,
"data_contiguous": gemlite_linear.data_contiguous,
"elements_per_sample": gemlite_linear.elements_per_sample,
"W_group_mode": gemlite_linear.W_group_mode,
"meta_args": meta_args,
}

packed_weight, scale, zero_point = gemlite_linear.get_tensor_args()
packed_weight = packed_weight.to(device)
if zero_point is None:
zero_point = torch.tensor(
[[]], device=packed_weight.device, dtype=torch.int32
)

return cls(packed_weight, scale, zero_point, gemlite_kwargs, _layout)

@@ -235,18 +257,39 @@ def _apply_fn_to_data(self, fn):
def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
device = self.packed_weight.device
int_data = (
gemlite.bitpack.unpack_over_rows(
self.packed_weight.cuda(),
W_nbits=self._layout.bit_width,
num_output_rows=self.gemlite_kwargs["out_features"],
dtype=torch.uint8,
(
gemlite.bitpack.unpack_over_rows(
self.packed_weight.cuda(),
W_nbits=self._layout.bit_width,
num_output_rows=self.gemlite_kwargs["in_features"],
dtype=torch.uint8,
)
)
.to(device)
.t()
.contiguous()
).to(device)
)

# Preserve col-row major layout
if self.gemlite_kwargs["data_contiguous"]:
int_data = int_data.contiguous()

# Handle FMA mode: W_q * s + z -> (W_q - z) * s
if self.gemlite_kwargs["W_group_mode"] == 4:
scale_min_val = 1e-8
scale = self.scale.clone().float()
scale[torch.logical_and(scale >= 0, scale.abs() <= scale_min_val)] = (
scale_min_val
)
scale[
torch.logical_and(scale < 0, scale.abs() <= scale_min_val)
] = -scale_min_val
zero_point = (-self.zero_point.float() / scale).clamp_(-100, 100)
zero_point = zero_point.to(self.scale.dtype)
else:
zero_point = self.zero_point

scale = self.scale.t().contiguous()
zero_point = self.zero_point.t().contiguous()
zero_point = zero_point.t().contiguous()

return int_data, scale, zero_point

@@ -274,30 +317,47 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
assert step == 1, "Only step == 1 is supported in slicing right now"

if dim in [0, 1]:
int_data, scale, zero_point = self.get_plain()
data_len = int_data.shape[dim]
# data in self is transposed, meaning forward() performs x @ W_deq not x @ W_deq.T
dim = 1 - dim
packed_weight = self.packed_weight
scale = self.scale
zero_point = self.zero_point

gemlite_kwargs = self.gemlite_kwargs.copy()
orig_shape = [
gemlite_kwargs["in_features"],
gemlite_kwargs["out_features"],
]
elements_per_sample = gemlite_kwargs["elements_per_sample"]
data_len = orig_shape[dim]
scale_len = scale.shape[dim]
ratio = data_len / scale_len
start_scale = int(start / ratio)
end_scale = int(end / ratio)

int_data = aten.slice.Tensor(int_data, dim, start, end, step)
# For packing only the K dimension. This should be flipped for N-dim packing.
div = elements_per_sample if dim == 0 else 1
packed_weight = aten.slice.Tensor(
packed_weight, dim, start // div, end // div, step
)

# Update in_features/out_features
gemlite_kwargs["in_features"] = (
packed_weight.shape[0] * elements_per_sample
)
gemlite_kwargs["out_features"] = packed_weight.shape[1]

scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step)
if zero_point is not None and zero_point.numel() > 0:
zero_point = aten.slice.Tensor(
zero_point, dim, start_scale, end_scale, step
)
else:
zero_point = None
# this is to handle padding
int_data, scale, zero_point = self._layout.post_process(
int_data, scale, zero_point, self.block_size
)

sliced = self.from_plain(
int_data, scale, zero_point, self._layout
) # Will be transposed again

sliced = GemliteAQTTensorImpl(
packed_weight, scale, zero_point, gemlite_kwargs, self._layout
)
return return_and_correct_aliasing(func, args, kwargs, sliced)

else:
@@ -308,10 +368,24 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
elif func is aten.copy_.default:
self = args[0]
src = args[1]

# Handle zero_point = None with symmetric quant
if self.zero_point is None:
self.zero_point = torch.tensor(
[[]], device=self.packed_weight.device, dtype=torch.int32
)

if src.zero_point is None:
src.zero_point = torch.tensor(
[[]], device=src.packed_weight.device, dtype=torch.int32
)

if _same_metadata(self, src):
self_tensors = self.__tensor_flatten__()[0]
for tensor_name in self_tensors:
getattr(self, tensor_name).copy_(getattr(src, tensor_name))
for key in self.gemlite_kwargs:
self.gemlite_kwargs[key] = src.gemlite_kwargs[key]
return
raise ValueError(
f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}"
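A note on the index arithmetic in the new slicing path: the packed weight is stored transposed and bit-packed along the K (in_features) dimension, so a narrow along in_features must divide its start/end by `elements_per_sample`, while a narrow along out_features maps one-to-one; scales and zero points are sliced with `ratio = data_len / scale_len`, which works out to `group_size` along K. A rough sketch of that arithmetic with assumed sizes (32-bit packing of 4-bit values is only an example; the real value comes from GemLite at pack time):

```python
# Illustrative sketch, not the PR's code. GemLite stores the weight transposed
# and packed along K, so the packed tensor has shape
# [in_features // elements_per_sample, out_features].
bit_width = 4
packing_bitwidth = 32                                # assumed 32-bit packing
elements_per_sample = packing_bitwidth // bit_width  # 8 values per packed word

in_features, out_features, group_size = 256, 512, 64
packed_shape = (in_features // elements_per_sample, out_features)  # (32, 512)

# Narrowing along in_features (packed dim 0): divide indices by elements_per_sample.
start, end = 0, 128
packed_start = start // elements_per_sample  # 0
packed_end = end // elements_per_sample      # 16

# Narrowing along out_features (packed dim 1): indices map one-to-one.
# Scales/zeros have one entry per group along K, so their slice indices are
# scaled by ratio = data_len / scale_len == group_size there.
scale_rows = in_features // group_size               # 4 groups along K
ratio = in_features / scale_rows                     # 64.0 == group_size
scale_start, scale_end = int(start / ratio), int(end / ratio)  # (0, 2)
```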
6 changes: 3 additions & 3 deletions torchao/quantization/autoquant.py
@@ -741,11 +741,11 @@ def from_float(cls, weight):
weight = weight.to(torch.float16)

bit_width = 4
packing_bitwidth = 32
contiguous = None
packing_bitwidth = None
use_hqq = True

aqt_kwargs = get_gemlite_aqt_kwargs(
weight, cls.group_size, bit_width, packing_bitwidth, contiguous, use_hqq
weight, cls.group_size, bit_width, packing_bitwidth, use_hqq
)
weight = to_affine_quantized_intx(weight, **aqt_kwargs)
input_quant_func = _to_float16
8 changes: 6 additions & 2 deletions torchao/quantization/quant_api.py
@@ -990,8 +990,9 @@ class GemliteUIntXWeightOnlyConfig(AOBaseConfig):
`set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values.
"""

group_size: Optional[int] = 64
group_size: Optional[int] = 128
bit_width: int = 4
packing_bitwidth: Optional[int] = None
set_inductor_config: bool = True


@@ -1005,6 +1006,7 @@ def _gemlite_uintx_weight_only_transform(
):
group_size = config.group_size
bit_width = config.bit_width
packing_bitwidth = config.packing_bitwidth
if config.set_inductor_config:
torchao.quantization.utils.recommended_inductor_config_setter()

@@ -1015,7 +1017,9 @@
use_hqq = True if bit_width == 4 else False
new_weight = to_affine_quantized_intx(
weight,
**get_gemlite_aqt_kwargs(weight, group_size, bit_width, use_hqq),
**get_gemlite_aqt_kwargs(
weight, group_size, bit_width, packing_bitwidth, use_hqq
),
)
module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
module.extra_repr = types.MethodType(_linear_extra_repr, module)
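For reference, the new `packing_bitwidth` option threads through the config like this. A minimal usage sketch mirroring the updated test; the sizes, dtype, and the import path for the config are assumptions rather than something taken verbatim from this PR:

```python
import torch
import torch.nn as nn
from torchao.quantization import quantize_
from torchao.quantization.quant_api import GemliteUIntXWeightOnlyConfig

# Arbitrary sizes for illustration; bias=False matches the updated test.
linear = nn.Linear(256, 512, bias=False, dtype=torch.float16, device="cuda")
quantize_(
    linear,
    # packing_bitwidth must be one of 8, 16, 32, or None (None lets GemLite choose).
    GemliteUIntXWeightOnlyConfig(bit_width=4, group_size=64, packing_bitwidth=32),
)

# Slicing and round-tripping through get_plain() are what this PR fixes.
half = linear.weight.narrow(0, 0, 256)
int_data, scale, zero_point = linear.weight.tensor_impl.get_plain()
```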