Add support for fbgemm fp8 kernels #2276

Merged: 1 commit, Jun 5, 2025
7 changes: 7 additions & 0 deletions test/dtypes/test_affine_quantized.py
@@ -24,7 +24,9 @@
to_affine_quantized_intx,
to_affine_quantized_intx_static,
)
from torchao.float8.config import e4m3_dtype
from torchao.quantization import (
FbgemmConfig,
GemliteUIntXWeightOnlyConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationInt8WeightConfig,
@@ -45,6 +47,7 @@
is_fbcode,
is_ROCM,
is_sm_at_least_89,
is_sm_at_least_90,
)

is_cusparselt_available = (
@@ -99,6 +102,10 @@ def get_quantization_functions(
if is_sm_at_least_89():
base_functions.append(float8_weight_only())

if is_sm_at_least_90():
base_functions.append(FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16))
base_functions.append(FbgemmConfig(e4m3_dtype, e4m3_dtype, torch.bfloat16))

return base_functions


@@ -12,15 +12,20 @@
run_tests,
)

from torchao.float8.config import e4m3_dtype
from torchao.quantization import (
FbgemmConfig,
quantize_,
)
from torchao.quantization.utils import compute_error
from torchao.utils import is_sm_at_least_90
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_8,
is_sm_at_least_90,
)


class TestFbgemmInt4Tensor(TestCase):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
class TestFbgemmFp8Tensor(TestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
def test_linear(self):
@@ -30,10 +35,9 @@ def test_linear(self):
linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
original = linear(input)
config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
input_dtype=e4m3_dtype,
weight_dtype=e4m3_dtype,
output_dtype=torch.bfloat16,
block_size=(1, 128),
)
quantize_(linear, config)
quantized = linear(input)
@@ -18,15 +18,15 @@
)
from torchao.quantization.utils import compute_error
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_6,
TORCH_VERSION_AT_LEAST_2_8,
is_sm_at_least_90,
)


@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
class TestFbgemmInt4Tensor(TestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Need torch >= 2.6")
def test_linear(self):
dtype = torch.bfloat16
device = "cuda"
26 changes: 14 additions & 12 deletions torchao/_models/llama/generate.py
@@ -442,23 +442,25 @@ def ffn_or_attn_only(mod, fqn):
f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}"
)
quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq))
elif "fbgemm" in quantization:
elif "fbgemm" in quantization and "int4" in quantization:
from torchao.quantization import FbgemmConfig

_, precision, group_size = quantization.split("-")
group_size = int(group_size)
block_size = [1, group_size]
if precision == "int4":
quantize_(
model,
FbgemmConfig(
torch.bfloat16, torch.int4, torch.bfloat16, block_size
),
)
else:
raise NotImplementedError(
f"FbegemmConfig({precision=}) not supported yet"
)
assert precision == "int4", f"FbegemmConfig({precision=}) not supported yet"
quantize_(
model,
FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16, block_size),
)
elif "fbgemm" in quantization and "fp8" in quantization:
from torchao.float8.config import e4m3_dtype
from torchao.quantization import FbgemmConfig

quantize_(
model,
FbgemmConfig(e4m3_dtype, e4m3_dtype, torch.bfloat16),
)
elif "int4dq-" in quantization:
from torchao.dtypes import CutlassInt4PackedLayout

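For reference, a minimal sketch of what the two fbgemm branches above reduce to (hedged: the exact --quantization strings are an assumption; any value containing "fbgemm" plus "int4" or "fp8", for example "fbgemm-int4-128" or "fbgemm-fp8", selects the corresponding path, and running either requires an sm90+ GPU with fbgemm-gpu-genai installed):

import torch
from torchao.float8.config import e4m3_dtype
from torchao.quantization import FbgemmConfig, quantize_

m_int4 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
m_fp8 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")

# "fbgemm-int4-<group_size>": bf16 activations, int4 weights, group size taken from the flag
quantize_(m_int4, FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16, [1, 128]))

# "fbgemm-fp8": rowwise e4m3 activations and weights, bf16 output
quantize_(m_fp8, FbgemmConfig(e4m3_dtype, e4m3_dtype, torch.bfloat16))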
6 changes: 4 additions & 2 deletions torchao/dtypes/__init__.py
@@ -8,7 +8,8 @@
to_affine_quantized_intx,
to_affine_quantized_intx_static,
)
from .fbgemm_quantized_tensor import to_fbgemm_quantized
from .fbgemm_fp8_tensor import to_fbgemm_fp8
from .fbgemm_int4_tensor import to_fbgemm_int4
from .floatx import (
CutlassSemiSparseLayout,
Float8Layout,
@@ -62,5 +63,6 @@
"PackedLinearInt8DynamicActivationIntxWeightLayout",
"to_affine_quantized_packed_linear_int8_dynamic_activation_intx_weight",
"Int4XPULayout",
"to_fbgemm_quantized",
"to_fbgemm_int4",
"to_fbgemm_fp8",
]
154 changes: 154 additions & 0 deletions torchao/dtypes/fbgemm_fp8_tensor.py
@@ -0,0 +1,154 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.


from typing import Optional

import torch
from torch.utils._python_dispatch import return_and_correct_aliasing

from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
TorchAOBaseTensor,
)

__all__ = [
"to_fbgemm_fp8",
]

aten = torch.ops.aten


class FbgemmFp8Tensor(TorchAOBaseTensor):
tensor_data_attrs = ["float8_data", "scale", "activation_scale_ub"]
tensor_attributes = ["dtype"]

def __new__(cls, float8_data, scale, activation_scale_ub, dtype):
shape = float8_data.shape
kwargs = {}
kwargs["device"] = float8_data.device
kwargs["dtype"] = dtype
kwargs["requires_grad"] = False
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__(self, float8_data, scale, activation_scale_ub, dtype):
self.float8_data = float8_data
self.scale = scale
self.activation_scale_ub = activation_scale_ub

def __tensor_flatten__(self):
return self.tensor_data_attrs, [
getattr(self, attr) for attr in self.tensor_attributes
]

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
return cls(
*[tensor_data_dict[name] for name in cls.tensor_data_attrs],
*tensor_attributes,
)

def _apply_fn_to_data(self, fn):
return self.__class__(
*[fn(getattr(self, attr)) for attr in self.tensor_data_attrs],
*[getattr(self, attr) for attr in self.tensor_attributes],
)

def __repr__(self):
return (
f"{self.__class__.__name__}(weight={self.float8_data}, scale={self.scale}, "
f"activation_scale_ub={self.activation_scale_ub}, "
f"shape={self.shape}, device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
)

def _quantization_type(self):
return f"shape={self.shape}, activation_scale_ub={self.activation_scale_ub}, device={self.device}"

@classmethod
def from_float(
cls,
w: torch.Tensor,
activation_scale_ub: Optional[float] = None,
):
if activation_scale_ub is None:
activation_scale_ub = 1200.0

activation_scale_ub = torch.tensor(
[activation_scale_ub],
dtype=torch.float,
device=w.device,
)
wq, w_scale = torch.ops.triton.quantize_fp8_row(w)
# wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
dtype = w.dtype
del w
return FbgemmFp8Tensor(
wq,
w_scale,
activation_scale_ub=activation_scale_ub,
dtype=dtype,
)


implements = FbgemmFp8Tensor.implements


@implements([torch.nn.functional.linear, aten.linear.default])
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
args[0],
args[1],
args[2] if len(args) > 2 else None,
)
if not input_tensor.is_floating_point():
raise NotImplementedError(
f"{func} is not implemented for non floating point input"
)

orig_act_size = input_tensor.size()
orig_out_features = weight_tensor.shape[-2]

# not used
num_tokens = torch.empty([input_tensor.size(0)], device=input_tensor.device)
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
input_tensor, num_tokens, weight_tensor.activation_scale_ub
)
res = torch.ops.fbgemm.f8f8bf16_rowwise(
xq,
weight_tensor.float8_data,
x_scale,
weight_tensor.scale,
use_fast_accum=True,
)
res = res.reshape(*orig_act_size[:-1], orig_out_features)
if bias is not None:
res = res + bias

return res


@implements([aten.detach.default, aten.alias.default])
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)


@implements([aten.clone.default, aten.copy_.default])
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)


to_fbgemm_fp8 = FbgemmFp8Tensor.from_float


if TORCH_VERSION_AT_LEAST_2_5:
# Allow a model with FbgemmFp8Tensor weights to be loaded with `weights_only=True`
torch.serialization.add_safe_globals([FbgemmFp8Tensor])
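For illustration, a minimal end-to-end sketch using the new subclass directly (an assumption-laden example: it needs an sm90+ GPU and fbgemm-gpu-genai for the triton/fbgemm fp8 ops; the config path via quantize_, as exercised in the tests above, is the more typical entry point):

import torch
from torchao.dtypes import to_fbgemm_fp8

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
x = torch.randn(4, 128, dtype=torch.bfloat16, device="cuda")

# Replace the bf16 weight with a rowwise-fp8 FbgemmFp8Tensor (per-row scale
# plus the activation_scale_ub used when quantizing inputs at dispatch time).
linear.weight = torch.nn.Parameter(to_fbgemm_fp8(linear.weight), requires_grad=False)

# The linear forward now dispatches to quantize_fp8_per_row on the input
# followed by the fbgemm f8f8bf16_rowwise kernel, returning a bf16 result.
y = linear(x)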
@@ -11,10 +11,13 @@
import torch
from torch.utils._python_dispatch import return_and_correct_aliasing

from torchao.utils import TorchAOBaseTensor
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
TorchAOBaseTensor,
)

__all__ = [
"to_fbgemm_quantized",
"to_fbgemm_int4",
]

aten = torch.ops.aten
@@ -71,25 +74,22 @@ def __repr__(self):
f"shape={self.shape}, device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
)

def _quantization_type(self):
return f"shape={self.shape}, group_size={self.group_size}, device={self.device}"

@classmethod
def from_float(
cls,
w: torch.Tensor,
input_dtype: torch.dtype,
weight_dtype: torch.dtype,
output_dtype: torch.dtype,
block_size: List[int],
):
assert len(block_size) == w.ndim, (
f"Expecting the length of block_size to be equal to the dimension of the weight, got {block_size=} and {w.ndim=}"
)
group_size = block_size[-1]
if int4_row_quantize_zp is None:
raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0")

assert (input_dtype, weight_dtype, output_dtype) == (
torch.bfloat16,
torch.int4,
torch.bfloat16,
)
group_size = block_size[-1]

if w.ndim >= 3:
wq, scale, zero_point = zip(
@@ -138,9 +138,10 @@ def _(func, types, args, kwargs):
weight_tensor.scale,
weight_tensor.zero_point,
)
res = res.reshape(*orig_act_size[:-1], orig_out_features)
if bias is not None:
res = res + bias
return res.reshape(*orig_act_size[:-1], orig_out_features)
return res


@implements([aten.detach.default, aten.alias.default])
@@ -157,5 +158,9 @@ def _(func, types, args, kwargs):
)


# We can have `to_fbgemm_tensor` to dispatch to different Fbgemm tensors later
to_fbgemm_quantized = FbgemmInt4Tensor.from_float
to_fbgemm_int4 = FbgemmInt4Tensor.from_float


if TORCH_VERSION_AT_LEAST_2_5:
# Allow a model with FbgemmInt4Tensor weights to be loaded with `weights_only=True`
torch.serialization.add_safe_globals([FbgemmInt4Tensor])
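As a hedged sketch of what the add_safe_globals registration enables (assumes torch >= 2.5 plus the same sm90 / fbgemm-gpu-genai setup as above to produce the quantized weights in the first place; the file path is illustrative):

import torch
from torchao.float8.config import e4m3_dtype
from torchao.quantization import FbgemmConfig, quantize_

m = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantize_(m, FbgemmConfig(e4m3_dtype, e4m3_dtype, torch.bfloat16))
torch.save(m.state_dict(), "/tmp/fbgemm_fp8_checkpoint.pt")

# weights_only=True works because FbgemmFp8Tensor / FbgemmInt4Tensor are
# registered as safe globals when torchao is imported.
state_dict = torch.load("/tmp/fbgemm_fp8_checkpoint.pt", weights_only=True)

m2 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
m2.load_state_dict(state_dict, assign=True)  # assign=True keeps the subclass weights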