diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py
index 68b5f41438..bd5ed0c3b5 100644
--- a/test/dtypes/test_affine_quantized.py
+++ b/test/dtypes/test_affine_quantized.py
@@ -371,12 +371,81 @@ def test_slice_gemlite(self, device, dtype):
         # in_feature not divisible by 1024
         # out_feature not divisible by 8
         # to test slice + padding for int4 weight only quantization
-        dummy = nn.Linear(256, 512, dtype=dtype, device=device)
-        quantize_(dummy, GemliteUIntXWeightOnlyConfig())
+        in_features, out_features, group_size, bit_width = 256, 512, 64, 4
+        orig_shape = [out_features, in_features]
+        dummy = nn.Linear(
+            in_features, out_features, bias=False, dtype=dtype, device=device
+        )
+        quantize_(
+            dummy,
+            GemliteUIntXWeightOnlyConfig(bit_width=bit_width, group_size=group_size),
+        )
+        W_group_mode = dummy.weight.tensor_impl.gemlite_kwargs["meta_args"][10]
+
         # make sure these run without error
         _ = dummy.weight.narrow(0, 0, 64)
         _ = dummy.weight.narrow(1, 0, 128)
 
+        # Dequant op
+        import gemlite
+
+        def dequant(input_layer, in_features, orig_shape):
+            int_data = input_layer.tensor_impl.packed_weight
+            scale = input_layer.tensor_impl.scale
+            zero_point = input_layer.tensor_impl.zero_point
+
+            W_q = (
+                gemlite.bitpack.unpack_over_rows(
+                    int_data,
+                    W_nbits=bit_width,
+                    num_output_rows=in_features,
+                    dtype=torch.uint8,
+                )
+                .T.contiguous()
+                .view([-1, group_size])
+            )
+
+            s = scale.t().contiguous().view(-1, 1)
+            z = zero_point.t().contiguous().view(-1, 1)
+
+            if W_group_mode == 4:  # FMA
+                W_deq = (W_q * s + z).view(orig_shape)
+            else:
+                W_deq = ((W_q - z) * s).view(orig_shape)
+
+            return W_deq
+
+        W_r = dequant(dummy.weight, dummy.in_features, orig_shape)
+
+        # Slicing in half
+        for slice_axis, start, end in [
+            (0, 0, 256),
+            (0, 256, 256),
+            (1, 0, 128),
+            (1, 128, 128),
+        ]:
+            layer_sliced = dummy.weight.narrow(slice_axis, start, end)
+
+            if slice_axis == 0:
+                num_rows, out_shape = (
+                    dummy.in_features,
+                    (orig_shape[0] // 2, orig_shape[1]),
+                )
+            else:
+                num_rows, out_shape = (
+                    dummy.in_features // 2,
+                    (orig_shape[0], orig_shape[1] // 2),
+                )
+
+            W_slice = dequant(layer_sliced, num_rows, out_shape)
+
+            W_slice_ref = (
+                W_r[start : start + end, :]
+                if slice_axis == 0
+                else W_r[:, start : start + end]
+            )
+            self.assertEqual((W_slice_ref - W_slice).abs().mean().item(), 0)
+
     @common_utils.parametrize("device", ["cuda"])
     @common_utils.parametrize("dtype", [torch.bfloat16])
     def test_matmul(self, device, dtype):
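The `dequant` helper above and the `get_plain` change below both bridge gemlite's two group modes: the FMA form `W_q * s + z` (`W_group_mode == 4`) and the standard affine form `(W_q - z) * s`. A minimal sketch in plain PyTorch (independent of gemlite's packing) of why the two parameterizations are interchangeable; `z_fma` and `z_back` are illustrative names, not gemlite API:

```python
import torch

W_q = torch.randint(0, 16, (8, 4)).float()  # 4-bit codes, one group per row
s = torch.rand(8, 1) + 0.1                  # per-group scales, kept away from 0
z = torch.randint(0, 16, (8, 1)).float()    # per-group zero points

affine = (W_q - z) * s  # standard affine dequant
z_fma = -z * s          # fold the zero point into an additive term
fma = W_q * s + z_fma   # FMA form used when W_group_mode == 4

assert torch.allclose(affine, fma)

# Inverse map, as applied by the get_plain() change below: recover an affine
# zero point from the FMA term (the real code also clamps the result and
# guards against near-zero scales).
z_back = -z_fma / s
assert torch.allclose(z_back, z)
```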
diff --git a/torchao/dtypes/uintx/gemlite_layout.py b/torchao/dtypes/uintx/gemlite_layout.py
index 1c840f7ec4..eb06cf2a96 100644
--- a/torchao/dtypes/uintx/gemlite_layout.py
+++ b/torchao/dtypes/uintx/gemlite_layout.py
@@ -25,7 +25,6 @@
 except:
     gemlite = None
 
-
 aten = torch.ops.aten
 
 
@@ -35,7 +34,12 @@ def _same_metadata(
 ) -> bool:
     kwargs_match = len(self.gemlite_kwargs) == len(src.gemlite_kwargs)
     for k, v in self.gemlite_kwargs.items():
-        if k != "scale_activations":
+        if k in [
+            "in_features",
+            "out_features",
+            "packing_bitwidth",
+            "elements_per_sample",
+        ]:
             kwargs_match = kwargs_match and (v == src.gemlite_kwargs[k])
 
     return (
@@ -80,6 +84,7 @@ def get_gemlite_aqt_kwargs(
     weight,
     group_size=64,
     bit_width=4,
+    packing_bitwidth=None,
     use_hqq=True,
 ):
     if gemlite is None:
@@ -99,6 +104,9 @@ def get_gemlite_aqt_kwargs(
     assert group_size is None or bit_width != 8, (
         "gemlite only works with group_size=None for bit_width=8"
     )
+    assert packing_bitwidth in [8, 16, 32, None], (
+        f"Invalid packing bitwidth, got {packing_bitwidth}"
+    )
 
     out_features, in_features = weight.shape
     group_size = in_features if group_size is None else group_size
@@ -107,6 +115,7 @@ def get_gemlite_aqt_kwargs(
     aqt_kwargs["_layout"] = GemlitePackedLayout(
         group_size=group_size,
         bit_width=bit_width,
+        packing_bitwidth=packing_bitwidth,
     )
     aqt_kwargs["use_hqq"] = use_hqq
     return aqt_kwargs
@@ -114,8 +123,9 @@ def get_gemlite_aqt_kwargs(
 
 @dataclass(frozen=True)
 class GemlitePackedLayout(Layout):
-    group_size: Optional[int] = 64
+    group_size: Optional[int] = 128
     bit_width: int = 4
+    packing_bitwidth: Optional[int] = None
 
 
 @register_layout(GemlitePackedLayout)
@@ -191,24 +201,36 @@ def from_plain(
         group_size, bit_width = _layout.group_size, _layout.bit_width
         out_features, in_features = int_data.shape
+        packing_bitwidth = _layout.packing_bitwidth
 
         if bit_width == 8 and group_size == in_features:
             gemlite_linear = gemlite.helper.A16W8(device=int_data.device).from_weights(
                 int_data, scales=scale, bias=None
             )
         else:
-            gemlite_linear = gemlite.helper.A16Wn(device=int_data.device).from_weights(
+            gemlite_linear = gemlite.helper.A16Wn(
+                device=int_data.device, packing_bitwidth=packing_bitwidth
+            ).from_weights(
                 int_data, scale, zero_point, bit_width, group_size, bias=None
             )
 
+        meta_args = gemlite_linear.get_meta_args()
         gemlite_kwargs = {
             "in_features": in_features,
             "out_features": out_features,
-            "meta_args": gemlite_linear.get_meta_args(),
+            "packing_bitwidth": packing_bitwidth,
+            "data_contiguous": gemlite_linear.data_contiguous,
+            "elements_per_sample": gemlite_linear.elements_per_sample,
+            "W_group_mode": gemlite_linear.W_group_mode,
+            "meta_args": meta_args,
         }
 
         packed_weight, scale, zero_point = gemlite_linear.get_tensor_args()
         packed_weight = packed_weight.to(device)
+        if zero_point is None:
+            zero_point = torch.tensor(
+                [[]], device=packed_weight.device, dtype=torch.int32
+            )
 
         return cls(packed_weight, scale, zero_point, gemlite_kwargs, _layout)
@@ -235,18 +257,39 @@ def _apply_fn_to_data(self, fn):
     def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         device = self.packed_weight.device
         int_data = (
-            gemlite.bitpack.unpack_over_rows(
-                self.packed_weight.cuda(),
-                W_nbits=self._layout.bit_width,
-                num_output_rows=self.gemlite_kwargs["out_features"],
-                dtype=torch.uint8,
+            (
+                gemlite.bitpack.unpack_over_rows(
+                    self.packed_weight.cuda(),
+                    W_nbits=self._layout.bit_width,
+                    num_output_rows=self.gemlite_kwargs["in_features"],
+                    dtype=torch.uint8,
+                )
             )
+            .to(device)
             .t()
-            .contiguous()
-        ).to(device)
+        )
+
+        # Preserve col-row major layout
+        if self.gemlite_kwargs["data_contiguous"]:
+            int_data = int_data.contiguous()
+
+        # Handle FMA mode: W_q * s + z -> (W_q - z) * s
+        if self.gemlite_kwargs["W_group_mode"] == 4:
+            scale_min_val = 1e-8
+            scale = self.scale.clone().float()
+            scale[torch.logical_and(scale >= 0, scale.abs() <= scale_min_val)] = (
+                scale_min_val
+            )
+            scale[
+                torch.logical_and(scale < 0, scale.abs() <= scale_min_val)
+            ] = -scale_min_val
+            zero_point = (-self.zero_point.float() / scale).clamp_(-100, 100)
+            zero_point = zero_point.to(self.scale.dtype)
+        else:
+            zero_point = self.zero_point
 
         scale = self.scale.t().contiguous()
-        zero_point = self.zero_point.t().contiguous()
+        zero_point = zero_point.t().contiguous()
 
         return int_data, scale, zero_point
@@ -274,14 +317,36 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
             assert step == 1, "Only step == 1 is supported in slicing right now"
 
             if dim in [0, 1]:
-                int_data, scale, zero_point = self.get_plain()
-                data_len = int_data.shape[dim]
+                # data in self is transposed, meaning forward() performs x @ W_deq, not x @ W_deq.T
+                dim = 1 - dim
+                packed_weight = self.packed_weight
+                scale = self.scale
+                zero_point = self.zero_point
+
+                gemlite_kwargs = self.gemlite_kwargs.copy()
+                orig_shape = [
+                    gemlite_kwargs["in_features"],
+                    gemlite_kwargs["out_features"],
+                ]
+                elements_per_sample = gemlite_kwargs["elements_per_sample"]
+                data_len = orig_shape[dim]
                 scale_len = scale.shape[dim]
                 ratio = data_len / scale_len
                 start_scale = int(start / ratio)
                 end_scale = int(end / ratio)
 
-                int_data = aten.slice.Tensor(int_data, dim, start, end, step)
+                # Packing is only along the K dimension; flip this for N-dim packing.
+                div = elements_per_sample if dim == 0 else 1
+                packed_weight = aten.slice.Tensor(
+                    packed_weight, dim, start // div, end // div, step
+                )
+
+                # Update in_features/out_features
+                gemlite_kwargs["in_features"] = (
+                    packed_weight.shape[0] * elements_per_sample
+                )
+                gemlite_kwargs["out_features"] = packed_weight.shape[1]
+
                 scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step)
                 if zero_point is not None and zero_point.numel() > 0:
                     zero_point = aten.slice.Tensor(
@@ -289,15 +354,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                     )
                 else:
                     zero_point = None
-                # this is to handle padding
-                int_data, scale, zero_point = self._layout.post_process(
-                    int_data, scale, zero_point, self.block_size
-                )
 
-                sliced = self.from_plain(
-                    int_data, scale, zero_point, self._layout
-                )  # Will be transposed again
+                sliced = GemliteAQTTensorImpl(
+                    packed_weight, scale, zero_point, gemlite_kwargs, self._layout
+                )
                 return return_and_correct_aliasing(func, args, kwargs, sliced)
 
             else:
@@ -308,10 +368,24 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
         elif func is aten.copy_.default:
             self = args[0]
             src = args[1]
+
+            # Handle zero_point = None with symmetric quant
+            if self.zero_point is None:
+                self.zero_point = torch.tensor(
+                    [[]], device=self.packed_weight.device, dtype=torch.int32
+                )
+
+            if src.zero_point is None:
+                src.zero_point = torch.tensor(
+                    [[]], device=src.packed_weight.device, dtype=torch.int32
+                )
+
             if _same_metadata(self, src):
                 self_tensors = self.__tensor_flatten__()[0]
                 for tensor_name in self_tensors:
                     getattr(self, tensor_name).copy_(getattr(src, tensor_name))
+                for key in self.gemlite_kwargs:
+                    self.gemlite_kwargs[key] = src.gemlite_kwargs[key]
                 return
             raise ValueError(
                 f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}"
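The new slicing path above never unpacks the weight: each packed storage element holds `elements_per_sample` sub-byte values along K, so a slice `[start, end)` of the unpacked rows maps to `[start // elements_per_sample, end // elements_per_sample)` on the packed tensor. A toy sketch of that arithmetic, assuming a hypothetical packing of two 4-bit values per `uint8` (gemlite's real packing uses `packing_bitwidth`-sized words and its own element ordering):

```python
import torch

eps = 2  # elements_per_sample: two 4-bit values per packed uint8 element
W = torch.randint(0, 16, (8, 3), dtype=torch.uint8)  # unpacked codes, K=8 rows

# Toy row-wise packing: rows (0,1) -> packed row 0, rows (2,3) -> row 1, ...
packed = (W[0::2] << 4) | W[1::2]  # shape (4, 3)

# Slicing [4, 8) of the unpacked rows == slicing [4//eps, 8//eps) packed rows
start, end = 4, 8
ref = (W[start:end][0::2] << 4) | W[start:end][1::2]
sliced = packed[start // eps : end // eps]
assert torch.equal(sliced, ref)
```

This mapping only holds when `start` and `end` are multiples of `elements_per_sample`; the slices exercised by the test above satisfy this.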
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index 41ea588231..998204c8fe 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -741,11 +741,11 @@ def from_float(cls, weight):
             weight = weight.to(torch.float16)
 
         bit_width = 4
-        packing_bitwidth = 32
-        contiguous = None
+        packing_bitwidth = None
         use_hqq = True
+
         aqt_kwargs = get_gemlite_aqt_kwargs(
-            weight, cls.group_size, bit_width, packing_bitwidth, contiguous, use_hqq
+            weight, cls.group_size, bit_width, packing_bitwidth, use_hqq
         )
         weight = to_affine_quantized_intx(weight, **aqt_kwargs)
         input_quant_func = _to_float16
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 56229b0d27..be25b144a6 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -990,8 +990,9 @@ class GemliteUIntXWeightOnlyConfig(AOBaseConfig):
         `set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values.
     """
 
-    group_size: Optional[int] = 64
+    group_size: Optional[int] = 128
     bit_width: int = 4
+    packing_bitwidth: Optional[int] = None
     set_inductor_config: bool = True
 
 
@@ -1005,6 +1006,7 @@ def _gemlite_uintx_weight_only_transform(
 ):
     group_size = config.group_size
     bit_width = config.bit_width
+    packing_bitwidth = config.packing_bitwidth
     if config.set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()
 
@@ -1015,7 +1017,9 @@ def _gemlite_uintx_weight_only_transform(
     use_hqq = True if bit_width == 4 else False
     new_weight = to_affine_quantized_intx(
         weight,
-        **get_gemlite_aqt_kwargs(weight, group_size, bit_width, use_hqq),
+        **get_gemlite_aqt_kwargs(
+            weight, group_size, bit_width, packing_bitwidth, use_hqq
+        ),
     )
     module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
     module.extra_repr = types.MethodType(_linear_extra_repr, module)