diff --git a/torchao/experimental/_linear_8bit_act_xbit_weight_layout.py b/torchao/experimental/_linear_8bit_act_xbit_weight_layout.py
new file mode 100644
index 0000000000..9b9b53d5aa
--- /dev/null
+++ b/torchao/experimental/_linear_8bit_act_xbit_weight_layout.py
@@ -0,0 +1,375 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from enum import auto, Enum
+
+import logging
+from typing import List, Optional, Tuple
+
+import torch
+from torch.utils._python_dispatch import return_and_correct_aliasing
+from torchao.dtypes.affine_quantized_tensor import (
+    AQTTensorImpl,
+    register_aqt_quantized_linear_dispatch,
+    register_layout,
+)
+from torchao.dtypes.utils import Layout
+from torchao.quantization.quant_primitives import (
+    MappingType,
+    ZeroPointDomain,
+)
+
+from torchao.quantization.quant_api import to_affine_quantized_intx
+
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.WARNING)
+
+import sys
+
+handler = logging.StreamHandler(sys.stdout)
+formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+handler.setFormatter(formatter)
+logger.addHandler(handler)
+
+
+class Target(Enum):
+    """Enum that indicates the backend target
+    """
+    NATIVE = auto()
+    FALLBACK = auto()
+
+def target_from_str(target: str) -> Target:
+    if target.lower() == "native":
+        return Target.NATIVE
+    elif target.lower() == "fallback":
+        return Target.FALLBACK
+    else:
+        raise ValueError(f"Invalid target: {target}")
+
+
+# This format is intended for use with int8 dynamic quantization
+class Linear8BitActXBitWeightLayout(Layout):
+    nbit: int
+    group_size: int
+
+    # The target platform for the layout, either 'native' or 'fallback'.
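+    # NATIVE packs the weights for the torchao C++ kernels (torch.ops.torchao);
+    # FALLBACK keeps plain int8 data and uses a slow dequantize-then-matmul path.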
+    target: Target
+
+    def __init__(
+        self,
+        nbit: int,
+        group_size: int,
+        target: str,
+    ):
+        assert nbit <= 7
+        self.nbit = nbit
+        self.group_size = group_size
+        self.target = target_from_str(target)
+
+    def extra_repr(self):
+        return f"nbit={self.nbit}, group_size={self.group_size}, target={self.target}"
+
+
+def _pack_weights_native(
+    int_data: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    layout: Layout,
+):
+    assert isinstance(layout, Linear8BitActXBitWeightLayout)
+    assert layout.target == Target.NATIVE
+    nbit = layout.nbit
+    group_size = layout.group_size
+    has_weight_zeros = zero_point is not None
+
+    if has_weight_zeros:
+        args = [
+            int_data.to(torch.int8),
+            scale.reshape(-1),
+            zero_point.reshape(-1).to(torch.int8),
+            torch.empty(0, group_size, dtype=torch.int8),
+        ]
+    else:
+        args = [
+            int_data.to(torch.int8),
+            scale.reshape(-1),
+            torch.empty(0, group_size, dtype=torch.int8),
+        ]
+
+    wzp_suffix = "" if has_weight_zeros else "0zp"
+    return getattr(torch.ops.torchao, f"_pack_8bit_act_{nbit}bit{wzp_suffix}_weight")(
+        *args
+    )
+
+
+@register_layout(Linear8BitActXBitWeightLayout)
+class Linear8BitActXBitWeightAQTTensorImpl(AQTTensorImpl):
+    def __new__(
+        cls,
+        packed_weight: torch.Tensor,
+        scale: Optional[torch.Tensor],
+        zero_point: Optional[torch.Tensor],
+        _layout: Layout,
+    ):
+        kwargs = {}
+        kwargs["device"] = packed_weight.device
+        kwargs["dtype"] = packed_weight.dtype
+        assert not packed_weight.requires_grad
+        kwargs["requires_grad"] = False
+        shape = packed_weight.shape
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(
+        self,
+        packed_weight: torch.Tensor,
+        scale: Optional[torch.Tensor],
+        zero_point: Optional[torch.Tensor],
+        _layout: Layout,
+    ):
+        assert isinstance(_layout, Linear8BitActXBitWeightLayout)
+
+        # In the native case, scale and zero_point information is inside
+        # the packed_weight
+        if _layout.target == Target.NATIVE:
+            assert scale is None
+            assert zero_point is None
+
+        self.packed_weight = packed_weight
+        self.scale = scale
+        self.zero_point = zero_point
+        self._layout = _layout
+
+    def __repr__(self):
+        layout = self.get_layout()
+        return f"{self.__class__.__name__}(packed_weight={str(self.packed_weight)}, scale={str(self.scale)}, zero_point={str(self.zero_point)}, layout={layout})"
+
+    def get_layout(self) -> Layout:
+        return self._layout
+
+    def get_plain(self) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+        if self.get_layout().target == Target.FALLBACK:
+            return self.packed_weight, self.scale, self.zero_point
+        raise NotImplementedError("get_plain is not supported for Linear8BitActXBitWeightAQTTensorImpl when target is not fallback")
+
+    @classmethod
+    def from_plain(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+        layout: Layout,
+    ):
+        assert isinstance(layout, Linear8BitActXBitWeightLayout)
+
+        try:
+            if layout.target == Target.NATIVE:
+                packed_weight = _pack_weights_native(
+                    int_data, scale, zero_point, layout
+                )
+                scale = None
+                zero_point = None
+                return cls(packed_weight, scale, zero_point, layout)
+        except Exception as e:
+            logger.warning(
+                f"A failure occurred when packing weights with Linear8BitActXBitWeightLayout.target={layout.target}: {e}\n"
+                + "Falling back to **slow** implementation Linear8BitActXBitWeightLayout.target=fallback."
+            )
+            layout.target = Target.FALLBACK
+
+        # Fallback
+        assert layout.target == Target.FALLBACK
+        packed_weight = int_data.to(torch.int8)
+        return cls(packed_weight, scale, zero_point, layout)
+
+    def _apply_fn_to_data(self, fn):
+        self.packed_weight = fn(self.packed_weight)
+        if self.scale is not None:
+            self.scale = fn(self.scale)
+
+        if self.zero_point is not None:
+            self.zero_point = fn(self.zero_point)
+        return self
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        kwargs = {} if kwargs is None else kwargs
+
+        if func is torch.ops.aten.detach.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+            )
+        if func is torch.ops.aten.clone.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+            )
+
+        raise NotImplementedError(
+            f"Linear8BitActXBitWeightAQTTensorImpl dispatch: attempting to run {func}, this is not supported"
+        )
+
+    def __tensor_flatten__(self):
+        if self.get_layout().target == Target.NATIVE:
+            return ["packed_weight"], [self.get_layout()]
+
+        # fallback
+        assert self.get_layout().target == Target.FALLBACK
+        if self.zero_point is None:
+            return ["packed_weight", "scale"], [self.get_layout()]
+        return ["packed_weight", "scale", "zero_point"], [self.get_layout()]
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
+    ):
+        packed_weight, scale, zero_point = (
+            tensor_data_dict["packed_weight"],
+            tensor_data_dict.get("scale", None),
+            tensor_data_dict.get("zero_point", None),
+        )
+        (layout,) = tensor_attributes
+        return cls(packed_weight, scale, zero_point, layout)
+
+
+def _linear_int8_dynamic_activation_intx_weight_check(
+    input_tensor, weight_tensor, bias
+):
+    layout = weight_tensor.tensor_impl.get_layout()
+    return isinstance(layout, Linear8BitActXBitWeightLayout) and bias is None
+
+
+def _linear_int8_dynamic_activation_intx_weight_fallback_impl(
+    input_tensor, weight_tensor, bias
+):
+    assert weight_tensor.tensor_impl.get_layout().target == Target.FALLBACK
+    assert bias is None
+
+    def _impl_2d(input_tensor, weight_tensor):
+        assert input_tensor.dim() == 2
+        assert weight_tensor.dim() == 2
+
+        weight_qvals = weight_tensor.tensor_impl.packed_weight.to(torch.int32)
+        weight_scales = weight_tensor.tensor_impl.scale
+        weight_zeros = weight_tensor.tensor_impl.zero_point
+        group_size = weight_tensor.tensor_impl.get_layout().group_size
+        has_weight_zeros = weight_zeros is not None
+        m, k = input_tensor.shape
+        n, k_ = weight_tensor.shape
+        assert k_ == k
+
+        weights_dequantized = weight_tensor.dequantize()
+
+        # Quantize and immediately dequantize the activations to emulate
+        # int8 dynamic activation quantization
+        activations_dequantized = to_affine_quantized_intx(
+            input_tensor,
+            mapping_type=MappingType.ASYMMETRIC,
+            block_size=(1, k),
+            target_dtype=torch.int32,
+            quant_min=-128,
+            quant_max=127,
+            eps=0.0,
+            zero_point_dtype=torch.int32,
+            preserve_zero=True,
+            zero_point_domain=ZeroPointDomain.INT,
+            use_hqq=False,
+        ).dequantize()
+
+        return torch.matmul(
+            activations_dequantized, weights_dequantized.transpose(1, 0)
+        )
+
+    if input_tensor.dim() == 2:
+        return _impl_2d(input_tensor, weight_tensor)
+
+    assert input_tensor.dim() >= 3
+    lead_shape = input_tensor.shape[0:-2]
+    m, k = input_tensor.shape[-2], input_tensor.shape[-1]
+    n, k_ = weight_tensor.shape
+    assert k_ == k
+
+    res = _impl_2d(input_tensor.reshape(-1, k), weight_tensor)
+    res = res.reshape(*lead_shape, m, n)
+
+    return res
+
+
+def _linear_int8_dynamic_activation_intx_weight_native_impl(
+    input_tensor, weight_tensor, bias
+):
+    assert weight_tensor.tensor_impl.get_layout().target == Target.NATIVE
+    assert bias is None
+
+    def _impl_2d(input_tensor, weight_tensor):
+        assert input_tensor.dim() == 2
+        assert weight_tensor.dim() == 2
+
+        m, k = input_tensor.shape
+        n, k_ = weight_tensor.shape
+        assert k_ == k
+        group_size = weight_tensor.tensor_impl.get_layout().group_size
+        packed_weight = weight_tensor.tensor_impl.packed_weight
+
+        # TODO(T200095131): convert self.n, self.k, self.group_size to
+        # int when supported by AOTI
+        args = (
+            input_tensor,
+            packed_weight,
+            torch.empty(0, group_size, dtype=torch.int8),
+            torch.empty(0, n, dtype=torch.int8),
+            torch.empty(0, k, dtype=torch.int8),
+        )
+
+        has_weight_zeros = (weight_tensor.zero_point_domain != ZeroPointDomain.NONE)
+
+        assert len(weight_tensor.block_size) == 2
+        assert weight_tensor.block_size[0] == 1
+        group_size = weight_tensor.block_size[1]
+        assert group_size == weight_tensor.tensor_impl.get_layout().group_size
+        nbit = weight_tensor.tensor_impl.get_layout().nbit
+
+        n, k = weight_tensor.shape
+        m, k_ = input_tensor.shape
+        assert k_ == k
+
+        packed_weight = weight_tensor.tensor_impl.packed_weight
+        wzp_suffix = "" if has_weight_zeros else "0zp"
+        return getattr(
+            torch.ops.torchao, f"_linear_8bit_act_{nbit}bit{wzp_suffix}_weight"
+        )(*args)
+
+    if input_tensor.dim() == 2:
+        return _impl_2d(input_tensor, weight_tensor)
+
+    assert input_tensor.dim() >= 3
+    lead_shape = input_tensor.shape[0:-2]
+    m, k = input_tensor.shape[-2], input_tensor.shape[-1]
+    n, k_ = weight_tensor.shape
+    assert k_ == k
+
+    res = _impl_2d(input_tensor.reshape(-1, k), weight_tensor)
+    res = res.reshape(*lead_shape, m, n)
+    return res
+
+
+def _linear_int8_dynamic_activation_intx_weight_impl(input_tensor, weight_tensor, bias):
+    target = weight_tensor.tensor_impl.get_layout().target
+    if target == Target.NATIVE:
+        return _linear_int8_dynamic_activation_intx_weight_native_impl(
+            input_tensor, weight_tensor, bias
+        )
+
+    if target == Target.FALLBACK:
+        return _linear_int8_dynamic_activation_intx_weight_fallback_impl(
+            input_tensor, weight_tensor, bias
+        )
+
+    assert False, f"Unknown target {target}"
+
+
+register_aqt_quantized_linear_dispatch(
+    _linear_int8_dynamic_activation_intx_weight_check,
+    _linear_int8_dynamic_activation_intx_weight_impl,
+)
diff --git a/torchao/experimental/docs/readme.md b/torchao/experimental/docs/readme.md
index e7e3cddf01..c1bfa5c32a 100644
--- a/torchao/experimental/docs/readme.md
+++ b/torchao/experimental/docs/readme.md
@@ -45,6 +45,31 @@ quantized_model = linear_quantizer.quantize(quantized_model)
 If you get stuck on the above steps, working examples for both linear and embedding are in torchao/experimental/tests/test_linear_8bit_act_xbit_weight_quantizer.py and torchao/experimental/tests/test_embedding_xbit_quantizer.py.
 For example, running `python tests/test_linear_8bit_act_xbit_weight_quantizer.py` loads the ops, creates a toy model, quantizes the model, and runs it in eager, compile, AOTI, and exports the model.
 
+### Subclass API
+
+For linear layers, you can also use the new tensor subclass API in torchao.
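+The example below assumes you have already built the experimental ops and can load the shared library from `cmake-out/lib` (adjust the `libtorchao_ops_aten` filename and path for your platform):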
+
+```python
+import torch
+torch.ops.load_library("cmake-out/lib/libtorchao_ops_aten.dylib") # make sure this path is correct on your machine
+
+my_model = Model()
+
+from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
+from torchao.quantization.quant_api import quantize_
+quantize_(
+    my_model,
+    int8_dynamic_activation_intx_weight(
+        group_size=256,
+        nbit=4,
+        has_weight_zeros=False,
+    ),
+)
+```
+
+If you get stuck, consult
+`tests/test_linear_int8_dynamic_activation_intx_weight_subclass.py`.
+
 ## Available in torchchat
 
 TorchAO experimental kernels are [available in torchchat](https://github.com/pytorch/torchchat/blob/main/docs/quantization.md#experimental-torchao-lowbit-kernels), PyTorch's solution for running LLMs locally. Torchchat integration uses similar steps to above.
diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py
index 18b18357bc..1c04305d31 100644
--- a/torchao/experimental/quant_api.py
+++ b/torchao/experimental/quant_api.py
@@ -463,3 +463,56 @@ def quantize(self, model: nn.Module) -> nn.Module:
             },
         )
         return model
+
+
+from torchao.experimental._linear_8bit_act_xbit_weight_layout import Linear8BitActXBitWeightLayout
+from torchao.quantization.quant_api import (
+    _get_linear_subclass_inserter,
+    MappingType,
+    to_affine_quantized_intx,
+    ZeroPointDomain,
+)
+
+
+def int8_dynamic_activation_intx_weight(
+    group_size: int = 128,
+    nbit: int = 4,
+    has_weight_zeros: bool = False,
+    target: str = "native",
+):
+
+    def apply(weight):
+        assert weight.shape[-1] % group_size == 0
+        assert weight.device == torch.device("cpu"), "Only CPU is supported"
+        use_hqq = False
+        layout = Linear8BitActXBitWeightLayout(
+            nbit=nbit, group_size=group_size, target=target
+        )
+        mapping_type = MappingType.ASYMMETRIC
+        eps = torch.finfo(torch.float32).eps
+        block_size = (1, group_size)
+        target_dtype = torch.int32
+        quant_min = -(1 << (nbit - 1))
+        quant_max = (1 << (nbit - 1)) - 1
+        zero_point_dtype = torch.int8
+        preserve_zero = has_weight_zeros
+        zero_point_domain = ZeroPointDomain.INT if has_weight_zeros else ZeroPointDomain.NONE
+        # Note: this works differently than other quantizers because the dynamic
+        # activation quantization is fused with the kernel/op (and static activation quantization
+        # is not supported).
+        return to_affine_quantized_intx(
+            weight,
+            mapping_type,
+            block_size,
+            target_dtype,
+            quant_min,
+            quant_max,
+            eps,
+            zero_point_dtype=zero_point_dtype,
+            preserve_zero=preserve_zero,
+            zero_point_domain=zero_point_domain,
+            _layout=layout,
+            use_hqq=use_hqq,
+        )
+
+    return _get_linear_subclass_inserter(apply)
diff --git a/torchao/experimental/tests/test_linear_int8_dynamic_activation_intx_weight_subclass.py b/torchao/experimental/tests/test_linear_int8_dynamic_activation_intx_weight_subclass.py
new file mode 100644
index 0000000000..44e63386ce
--- /dev/null
+++ b/torchao/experimental/tests/test_linear_int8_dynamic_activation_intx_weight_subclass.py
@@ -0,0 +1,180 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
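+
+# End-to-end tests for the int8-dynamic-activation / intx-weight subclass API:
+# they build the torchao ATen ops with CMake, load the shared library, and then
+# check numerical accuracy (native vs. fallback targets) as well as the
+# export, compile, and AOTI paths.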
+
+import copy
+
+import glob
+import os
+import subprocess
+
+import sys
+import tempfile
+import unittest
+
+import torch
+
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")))
+
+from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
+from torchao.quantization.quant_api import quantize_
+
+from torchao.utils import unwrap_tensor_subclass
+from torchao.experimental.quant_api import (
+    _Int8DynActIntxWeightQuantizedLinearFallback,
+)
+
+def cmake_build_torchao_ops(temp_build_dir):
+    from distutils.sysconfig import get_python_lib
+
+    print("Building torchao ops for ATen target")
+    cmake_prefix_path = get_python_lib()
+    dir_path = os.path.dirname(os.path.realpath(__file__))
+    subprocess.run(
+        [
+            "cmake",
+            "-DCMAKE_PREFIX_PATH=" + cmake_prefix_path,
+            "-DCMAKE_INSTALL_PREFIX=" + temp_build_dir.name,
+            "-S " + dir_path + "/../",
+            "-B " + temp_build_dir.name,
+        ]
+    )
+    subprocess.run(
+        [
+            "cmake",
+            "--build",
+            temp_build_dir.name,
+            "-j 16",
+            "--target install",
+            "--config Release",
+        ]
+    )
+
+
+temp_build_dir = tempfile.TemporaryDirectory()
+cmake_build_torchao_ops(temp_build_dir)
+libs = glob.glob(f"{temp_build_dir.name}/lib/libtorchao_ops_aten.*")
+libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
+assert len(libs) == 1
+torch.ops.load_library(libs[0])
+
+
+class TestInt8DynamicActivationIntxWeight(unittest.TestCase):
+    def test_accuracy(self):
+        group_size = 128
+        m = 1
+        n = 1071
+        k = 4096
+        activations = torch.randn(m, k, dtype=torch.float32)
+        model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])
+
+        for nbit in [1, 2, 3, 4, 5, 6, 7]:
+            for has_weight_zeros in [True, False]:
+                print(f"Testing nbit={nbit}, has_weight_zeros={has_weight_zeros}")
+                quantized_model = copy.deepcopy(model)
+                quantize_(
+                    quantized_model,
+                    int8_dynamic_activation_intx_weight(
+                        group_size=group_size,
+                        nbit=nbit,
+                        has_weight_zeros=has_weight_zeros,
+                    ),
+                )
+
+                quantized_model_reference = copy.deepcopy(model)
+                quantize_(
+                    quantized_model_reference,
+                    int8_dynamic_activation_intx_weight(
+                        group_size=group_size,
+                        nbit=nbit,
+                        has_weight_zeros=has_weight_zeros,
+                        target="fallback",
+                    ),
+                )
+
+                with torch.no_grad():
+                    result = quantized_model(activations)
+                    expected_result = quantized_model_reference(activations)
+
+                    # TODO: remove expected_result2 checks when we deprecate non-subclass API
+                    reference_impl = _Int8DynActIntxWeightQuantizedLinearFallback()
+                    reference_impl.quantize_and_pack_weights(
+                        model[0].weight, nbit, group_size, has_weight_zeros
+                    )
+                    expected_result2 = reference_impl(activations)
+
+                num_mismatch_at_low_tol = 0
+                num_mismatch_at_low_tol2 = 0
+                num_total = result.reshape(-1).shape[0]
+                for i in range(num_total):
+                    actual_val = result.reshape(-1)[i]
+                    expected_val = expected_result.reshape(-1)[i]
+                    expected_val2 = expected_result2.reshape(-1)[i]
+                    self.assertTrue(torch.allclose(actual_val, expected_val, atol=1e-6))
+                    if not torch.allclose(actual_val, expected_val):
+                        num_mismatch_at_low_tol += 1
+
+                    self.assertTrue(torch.allclose(expected_val, expected_val2, atol=1e-2, rtol=1e-1))
+                    if not torch.allclose(expected_val, expected_val2):
+                        num_mismatch_at_low_tol2 += 1
+
+                # Assert at most 5% of entries are not close at a low tolerance
+                self.assertTrue(num_mismatch_at_low_tol / num_total <= 0.05)
+                self.assertTrue(num_mismatch_at_low_tol2 / num_total <= 0.01)
+
+    def test_export_compile_aoti(self):
+        group_size = 32
+        m = 3
+        k0 = 512
+        k1 = 256
+        k2 = 128
+        k3 = 1024
+        nbit = 4
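+        # Each layer's input dim (k0, k1, k2) is a multiple of group_size,
+        # as required by int8_dynamic_activation_intx_weight.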
+        has_weight_zeros = True
+        layers = [
+            torch.nn.Linear(k0, k1, bias=False),
+            torch.nn.Linear(k1, k2, bias=False),
+            torch.nn.Linear(k2, k3, bias=False),
+        ]
+        model = torch.nn.Sequential(*layers)
+        activations = torch.randn(2, 1, m, k0, dtype=torch.float32)
+
+        print("Quantizing model")
+        quantize_(
+            model,
+            int8_dynamic_activation_intx_weight(
+                group_size=group_size,
+                nbit=nbit,
+                has_weight_zeros=has_weight_zeros,
+                target="native",
+            ),
+        )
+
+        unwrapped_model = copy.deepcopy(model)
+        unwrap_tensor_subclass(model)
+
+        print("Exporting quantized model")
+        exported = torch.export.export(model, (activations,))
+
+        print("Compiling quantized model")
+        compiled = torch.compile(unwrapped_model)
+        with torch.no_grad():
+            compiled(activations)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            print("Exporting quantized model with AOTI")
+            torch._export.aot_compile(
+                model,
+                (activations,),
+                options={"aot_inductor.output_path": f"{tmpdirname}/model.so"},
+            )
+
+            print("Running quantized model in AOTI")
+            fn = torch._export.aot_load(f"{tmpdirname}/model.so", "cpu")
+            fn(activations)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index dfd3bcaad8..0c1bfeffbd 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -60,9 +60,11 @@ class ZeroPointDomain(Enum):
 
     integer domain: quantized_val = (float_val / scale) (integer) + zero_point (integer)
     float domain: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale
+    none domain: quantized_val = (float_val / scale)
     """
     INT = auto()
     FLOAT = auto()
+    NONE = auto()
 
 class TorchAODType(Enum):
     """
@@ -344,6 +346,9 @@ def _quantize_affine_no_dtype_cast(
         quant = torch.clamp(
             torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max
         )
+    elif zero_point_domain == ZeroPointDomain.NONE.name:
+        assert zero_point is None, "zero_point should be None when zero_point_domain is NONE"
+        quant = torch.clamp(torch.round(input * (1.0 / scale)), quant_min, quant_max)
     elif zero_point_domain is None:
         # This case handles quantization for float8 we expect no zero point and no zero point domain
         assert zero_point is None, "zero_point should be None when zero_point_domain is None"
@@ -477,6 +482,10 @@ def _dequantize_affine_no_dtype_check(
             dequant = dequant - zero_point.to(torch.int32)
         dequant = dequant.to(output_dtype)
         dequant = dequant * scale
+    elif zero_point_domain == ZeroPointDomain.NONE.name:
+        assert zero_point is None, "zero_point should be None when zero_point_domain is NONE"
+        dequant = input.to(output_dtype)
+        dequant = dequant * scale
     elif zero_point_domain is None:
         # This case handles dequantization for float8 we expect no zero point and no zero point domain
         assert zero_point is None, "zero_point should be None when zero_point_domain is None"
@@ -813,15 +822,20 @@ def _choose_qparams_affine(
         assert mapping_type == MappingType.ASYMMETRIC.name
         scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
         scale = torch.clamp(scale, min=eps)
-        if preserve_zero:
-            zero_point = quant_min - torch.round(min_val_neg / scale)
-            zero_point = torch.clamp(zero_point, quant_min, quant_max)
+        if zero_point_domain == ZeroPointDomain.NONE.name:
+            zero_point = None
         else:
-            assert zero_point_domain == ZeroPointDomain.FLOAT.name, "if not preserve_zero, zero_point must be in FLOAT domain"
-            mid_point = (quant_max + quant_min + 1) / 2
-            zero_point = min_val_neg + scale * mid_point
-
-    return scale.to(dtype=scale_dtype), zero_point.to(dtype=zero_point_dtype)
+            if preserve_zero:
+                zero_point = quant_min - torch.round(min_val_neg / scale)
+                zero_point = torch.clamp(zero_point, quant_min, quant_max)
+            else:
+                assert zero_point_domain == ZeroPointDomain.FLOAT.name, "if not preserve_zero, zero_point must be in FLOAT domain"
+                mid_point = (quant_max + quant_min + 1) / 2
+                zero_point = min_val_neg + scale * mid_point
+
+    if zero_point is not None:
+        zero_point = zero_point.to(dtype=zero_point_dtype)
+    return scale.to(dtype=scale_dtype), zero_point
 
 
 # HQQ