diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index b74c5d2ecf..68b5f41438 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -24,7 +24,9 @@ to_affine_quantized_intx, to_affine_quantized_intx_static, ) +from torchao.float8.config import e4m3_dtype from torchao.quantization import ( + FbgemmConfig, GemliteUIntXWeightOnlyConfig, Int4WeightOnlyConfig, Int8DynamicActivationInt8WeightConfig, @@ -45,6 +47,7 @@ is_fbcode, is_ROCM, is_sm_at_least_89, + is_sm_at_least_90, ) is_cusparselt_available = ( @@ -99,6 +102,10 @@ def get_quantization_functions( if is_sm_at_least_89(): base_functions.append(float8_weight_only()) + if is_sm_at_least_90(): + base_functions.append(FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16)) + base_functions.append(FbgemmConfig(e4m3_dtype, e4m3_dtype, torch.bfloat16)) + return base_functions diff --git a/test/dtypes/test_fbgemm_quantized.py b/test/dtypes/test_fbgemm_fp8.py similarity index 77% rename from test/dtypes/test_fbgemm_quantized.py rename to test/dtypes/test_fbgemm_fp8.py index fe2573530c..d2f1e2d82a 100644 --- a/test/dtypes/test_fbgemm_quantized.py +++ b/test/dtypes/test_fbgemm_fp8.py @@ -12,15 +12,20 @@ run_tests, ) +from torchao.float8.config import e4m3_dtype from torchao.quantization import ( FbgemmConfig, quantize_, ) from torchao.quantization.utils import compute_error -from torchao.utils import is_sm_at_least_90 +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_8, + is_sm_at_least_90, +) -class TestFbgemmInt4Tensor(TestCase): +@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+") +class TestFbgemmFp8Tensor(TestCase): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+") def test_linear(self): @@ -30,10 +35,9 @@ def test_linear(self): linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) original = linear(input) config = FbgemmConfig( - input_dtype=torch.bfloat16, - weight_dtype=torch.int4, + input_dtype=e4m3_dtype, + weight_dtype=e4m3_dtype, output_dtype=torch.bfloat16, - block_size=(1, 128), ) quantize_(linear, config) quantized = linear(input) diff --git a/test/dtypes/test_fbgemm_quantized_tensor.py b/test/dtypes/test_fbgemm_int4.py similarity index 92% rename from test/dtypes/test_fbgemm_quantized_tensor.py rename to test/dtypes/test_fbgemm_int4.py index 51b68dd977..22fe5bc110 100644 --- a/test/dtypes/test_fbgemm_quantized_tensor.py +++ b/test/dtypes/test_fbgemm_int4.py @@ -18,15 +18,15 @@ ) from torchao.quantization.utils import compute_error from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_6, + TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_90, ) +@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+") class TestFbgemmInt4Tensor(TestCase): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+") - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Need torch >= 2.6") def test_linear(self): dtype = torch.bfloat16 device = "cuda" diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 40f70fe93e..756cfabb32 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -442,23 +442,25 @@ def ffn_or_attn_only(mod, fqn): f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}" ) quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq)) - elif "fbgemm" in quantization: + elif 
"fbgemm" in quantization and "int4" in quantization: from torchao.quantization import FbgemmConfig _, precision, group_size = quantization.split("-") group_size = int(group_size) block_size = [1, group_size] - if precision == "int4": - quantize_( - model, - FbgemmConfig( - torch.bfloat16, torch.int4, torch.bfloat16, block_size - ), - ) - else: - raise NotImplementedError( - f"FbegemmConfig({precision=}) not supported yet" - ) + assert precision == "int4", f"FbegemmConfig({precision=}) not supported yet" + quantize_( + model, + FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16, block_size), + ) + elif "fbgemm" in quantization and "fp8" in quantization: + from torchao.float8.config import e4m3_dtype + from torchao.quantization import FbgemmConfig + + quantize_( + model, + FbgemmConfig(e4m3_dtype, e4m3_dtype, torch.bfloat16), + ) elif "int4dq-" in quantization: from torchao.dtypes import CutlassInt4PackedLayout diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 1003491828..692d56ad31 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -8,7 +8,8 @@ to_affine_quantized_intx, to_affine_quantized_intx_static, ) -from .fbgemm_quantized_tensor import to_fbgemm_quantized +from .fbgemm_fp8_tensor import to_fbgemm_fp8 +from .fbgemm_int4_tensor import to_fbgemm_int4 from .floatx import ( CutlassSemiSparseLayout, Float8Layout, @@ -62,5 +63,6 @@ "PackedLinearInt8DynamicActivationIntxWeightLayout", "to_affine_quantized_packed_linear_int8_dynamic_activation_intx_weight", "Int4XPULayout", - "to_fbgemm_quantized", + "to_fbgemm_int4", + "to_fbgemm_fp8", ] diff --git a/torchao/dtypes/fbgemm_fp8_tensor.py b/torchao/dtypes/fbgemm_fp8_tensor.py new file mode 100644 index 0000000000..735c21c2ca --- /dev/null +++ b/torchao/dtypes/fbgemm_fp8_tensor.py @@ -0,0 +1,154 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
+ + +from typing import Optional + +import torch +from torch.utils._python_dispatch import return_and_correct_aliasing + +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + TorchAOBaseTensor, +) + +__all__ = [ + "to_fbgemm_fp8", +] + +aten = torch.ops.aten + + +class FbgemmFp8Tensor(TorchAOBaseTensor): + tensor_data_attrs = ["float8_data", "scale", "activation_scale_ub"] + tensor_attributes = ["dtype"] + + def __new__(cls, float8_data, scale, activation_scale_ub, dtype): + shape = float8_data.shape + kwargs = {} + kwargs["device"] = float8_data.device + kwargs["dtype"] = dtype + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__(self, float8_data, scale, activation_scale_ub, dtype): + self.float8_data = float8_data + self.scale = scale + self.activation_scale_ub = activation_scale_ub + + def __tensor_flatten__(self): + return self.tensor_data_attrs, [ + getattr(self, attr) for attr in self.tensor_attributes + ] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + return cls( + *[tensor_data_dict[name] for name in cls.tensor_data_attrs], + *tensor_attributes, + ) + + def _apply_fn_to_data(self, fn): + return self.__class__( + *[fn(getattr(self, attr)) for attr in self.tensor_data_attrs], + *[getattr(self, attr) for attr in self.tensor_attributes], + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(weight={self.float8_data}, scale={self.scale}, " + f"activation_scale_ub={self.activation_scale_ub}, " + f"shape={self.shape}, device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})" + ) + + def _quantization_type(self): + return f"shape={self.shape}, activation_scale_ub={self.activation_scale_ub}, device={self.device}" + + @classmethod + def from_float( + cls, + w: torch.Tensor, + activation_scale_ub: Optional[float] = None, + ): + if activation_scale_ub is None: + activation_scale_ub = 1200.0 + + activation_scale_ub = torch.tensor( + [activation_scale_ub], + dtype=torch.float, + device=w.device, + ) + wq, w_scale = torch.ops.triton.quantize_fp8_row(w) + # wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w) + dtype = w.dtype + del w + return FbgemmFp8Tensor( + wq, + w_scale, + activation_scale_ub=activation_scale_ub, + dtype=dtype, + ) + + +implements = FbgemmFp8Tensor.implements + + +@implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + if not input_tensor.is_floating_point(): + raise NotImplementedError( + f"{func} is not implemented for non floating point input" + ) + + orig_act_size = input_tensor.size() + orig_out_features = weight_tensor.shape[-2] + + # not used + num_tokens = torch.empty([input_tensor.size(0)], device=input_tensor.device) + xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row( + input_tensor, num_tokens, weight_tensor.activation_scale_ub + ) + res = torch.ops.fbgemm.f8f8bf16_rowwise( + xq, + weight_tensor.float8_data, + x_scale, + weight_tensor.scale, + use_fast_accum=True, + ) + res = res.reshape(*orig_act_size[:-1], orig_out_features) + if bias is not None: + res = res + bias + + return res + + +@implements([aten.detach.default, aten.alias.default]) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + 
+@implements([aten.clone.default, aten.copy_.default]) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + +to_fbgemm_fp8 = FbgemmFp8Tensor.from_float + + +if TORCH_VERSION_AT_LEAST_2_5: + # Allow a model with FbgemmFp8Tensor weights to be loaded with `weights_only=True` + torch.serialization.add_safe_globals([FbgemmFp8Tensor]) diff --git a/torchao/dtypes/fbgemm_quantized_tensor.py b/torchao/dtypes/fbgemm_int4_tensor.py similarity index 87% rename from torchao/dtypes/fbgemm_quantized_tensor.py rename to torchao/dtypes/fbgemm_int4_tensor.py index fd788a73a3..c2ab6246bf 100644 --- a/torchao/dtypes/fbgemm_quantized_tensor.py +++ b/torchao/dtypes/fbgemm_int4_tensor.py @@ -11,10 +11,13 @@ import torch from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.utils import TorchAOBaseTensor +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + TorchAOBaseTensor, +) __all__ = [ - "to_fbgemm_quantized", + "to_fbgemm_int4", ] aten = torch.ops.aten @@ -71,25 +74,22 @@ def __repr__(self): f"shape={self.shape}, device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})" ) + def _quantization_type(self): + return f"shape={self.shape}, group_size={self.group_size}, device={self.device}" + @classmethod def from_float( cls, w: torch.Tensor, - input_dtype: torch.dtype, - weight_dtype: torch.dtype, - output_dtype: torch.dtype, block_size: List[int], ): assert len(block_size) == w.ndim, ( f"Expecting the length of block_size to be equal to the dimension of the weight, got {block_size=} and {w.ndim=}" ) - group_size = block_size[-1] + if int4_row_quantize_zp is None: + raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0") - assert (input_dtype, weight_dtype, output_dtype) == ( - torch.bfloat16, - torch.int4, - torch.bfloat16, - ) + group_size = block_size[-1] if w.ndim >= 3: wq, scale, zero_point = zip( @@ -138,9 +138,10 @@ def _(func, types, args, kwargs): weight_tensor.scale, weight_tensor.zero_point, ) + res = res.reshape(*orig_act_size[:-1], orig_out_features) if bias is not None: res = res + bias - return res.reshape(*orig_act_size[:-1], orig_out_features) + return res @implements([aten.detach.default, aten.alias.default]) @@ -157,5 +158,9 @@ def _(func, types, args, kwargs): ) -# We can have `to_fbgemm_tensor` to dispatch to different Fbgemm tensors later -to_fbgemm_quantized = FbgemmInt4Tensor.from_float +to_fbgemm_int4 = FbgemmInt4Tensor.from_float + + +if TORCH_VERSION_AT_LEAST_2_5: + # Allow a model with FbgemmInt4Tensor weights to be loaded with `weights_only=True` + torch.serialization.add_safe_globals([FbgemmInt4Tensor]) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index ada19859bc..c8903cc77e 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -46,7 +46,8 @@ to_affine_quantized_floatx, to_affine_quantized_floatx_static, to_affine_quantized_intx, - to_fbgemm_quantized, + to_fbgemm_fp8, + to_fbgemm_int4, to_marlinqqq_quantized_intx, ) from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import ( @@ -537,6 +538,9 @@ def _quantization_type(weight: torch.Tensor): if isinstance(weight, LinearActivationQuantizedTensor): return f"{weight.__class__.__name__}(activation={weight.input_quant_func}, weight={_quantization_type(weight.original_weight_tensor)})" + if hasattr(weight, "_quantization_type"): + return 
f"{weight.__class__.__name__}({weight._quantization_type()})" + if type(weight) is torch.Tensor: return "not quantized" @@ -1981,7 +1985,8 @@ class FbgemmConfig(AOBaseConfig): input_dtype: torch.dtype weight_dtype: torch.dtype output_dtype: torch.dtype - block_size: List[int] + block_size: Optional[List[int]] = None + activation_scale_ub: Optional[float] = None @register_quantize_module_handler(FbgemmConfig) @@ -1998,22 +2003,31 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module: _SUPPORTED_DTYPES = { (torch.bfloat16, torch.int4, torch.bfloat16), + (e4m3_dtype, e4m3_dtype, torch.bfloat16), } if ( - config.input_dtype, - config.weight_dtype, - config.output_dtype, - ) in _SUPPORTED_DTYPES: - weight = to_fbgemm_quantized( + (config.input_dtype == torch.bfloat16) + and (config.weight_dtype == torch.int4) + and (config.output_dtype == torch.bfloat16) + ): + weight = to_fbgemm_int4( module.weight, - config.input_dtype, - config.weight_dtype, - config.output_dtype, config.block_size, ) module.weight = torch.nn.Parameter(weight, requires_grad=False) module.extra_repr = types.MethodType(_linear_extra_repr, module) + elif ( + (config.input_dtype == e4m3_dtype) + and (config.weight_dtype == e4m3_dtype) + and (config.output_dtype == torch.bfloat16) + ): + weight = to_fbgemm_fp8( + module.weight, + config.activation_scale_ub, + ) + module.weight = torch.nn.Parameter(weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) else: raise NotImplementedError( f"{config} is not supported. supported input, weight, output kernel dtypes are: {_SUPPORTED_DTYPES}"