float8 moe training conversion API prototype #2275

Merged 1 commit on Jun 10, 2025
@@ -19,11 +19,11 @@
pytest.skip("Unsupported PyTorch version", allow_module_level=True)


-from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
+from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
triton_fp8_col_major_jagged_colwise_scales,
triton_fp8_row_major_jagged_rowwise_scales,
)
-from torchao.prototype.scaled_grouped_mm.utils import (
+from torchao.prototype.moe_training.utils import (
_is_column_major,
_to_2d_jagged_float8_tensor_colwise,
_to_2d_jagged_float8_tensor_rowwise,
@@ -26,7 +26,7 @@
from torchao.float8.float8_linear import matmul_with_hp_or_float8_args
from torchao.float8.float8_tensor import LinearMMConfig
from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
-from torchao.prototype.scaled_grouped_mm.scaled_grouped_mm import (
+from torchao.prototype.moe_training.scaled_grouped_mm import (
_scaled_grouped_mm,
)
from torchao.testing.utils import skip_if_rocm
140 changes: 140 additions & 0 deletions test/prototype/moe_training/test_training.py
@@ -0,0 +1,140 @@
import copy

import pytest
import torch
from torch import nn
from torch.nn import functional as F

# this feature requires CUDA and SM89+
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9):
pytest.skip(
"CUDA not available or compute capability < 8.9", allow_module_level=True
)

from torchao.float8.float8_utils import compute_error
from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
from torchao.quantization.quant_api import quantize_

# this test requires torchtitan
try:
from torchtitan.experiments.llama4.model.args import TransformerModelArgs
from torchtitan.experiments.llama4.model.moe import MoE
except ImportError:
import warnings

warnings.warn("torchtitan not installed, skipping MoE tests.")
pytest.skip(allow_module_level=True)


@pytest.mark.parametrize(
"target_fqns",
[
["experts"],
["does.not.exist"],
],
)
def test_moe_float8_training(target_fqns: list[str]):
model_args = TransformerModelArgs(
moe_enabled=True,
num_experts=8,
dim=256,
)
init_std = 0.02
device = torch.device("cuda")

# reference bf16 MoE
ref_model = MoE(model_args).to(torch.bfloat16).cuda()
torch.manual_seed(42)
ref_model.init_weights(init_std, device)

# target MoE for testing conversion
model = copy.deepcopy(ref_model)

# assert starting params are identical for both models
for param1, param2 in zip(model.parameters(), ref_model.parameters()):
assert torch.equal(param1, param2)

# convert MoE to float8 training
def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
for target_fqn in target_fqns:
if target_fqn in cur_fqn:
return True
return False

# quantize test model
config = MoETrainingConfig()
quantize_(model, config=config, filter_fn=moe_module_filter_fn)

# validate that only the experts were converted
_validate_model_conversion(
model,
target_fqns=target_fqns,
)

# inputs
batch, seq, dim = 8, 2048, 256
ref_x = torch.randn(
batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
)
x = ref_x.detach().clone().requires_grad_(True)

# forward pass
ref_out = ref_model(ref_x)
out = model(x)

# validate output
out_sqnr = compute_error(out, ref_out)
assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}."

# compute loss
labels = torch.ones_like(ref_out)
ref_loss = F.mse_loss(ref_out, labels)
out_loss = F.mse_loss(out, labels)

# backward pass
ref_loss.backward()
out_loss.backward()

# validate input gradient
input_grad_sqnr = compute_error(x.grad, ref_x.grad)
assert input_grad_sqnr.item() >= 30.0, (
f"SQNR must be >= 30.0, got {input_grad_sqnr.item()}."
)

# validate param gradients
for param1, param2 in zip(model.parameters(), ref_model.parameters()):
param_grad_sqnr = compute_error(param1.grad, param2.grad)
assert param_grad_sqnr.item() >= 25.0, (
f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}."
)


def _validate_model_conversion(
root_module: nn.Module,
target_fqns: list[str],
):
def _recursive_validate(
module: nn.Module,
cur_fqn: str,
):
is_allowed_module = cur_fqn in target_fqns

# check current module params
for param_name, param in module.named_parameters(recurse=False):
is_converted_type = isinstance(param, ScaledGroupedMMTensor)
if is_converted_type:
assert is_allowed_module, (
f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
)
if not is_allowed_module:
assert not is_converted_type, (
f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
)

# recursively check child modules
for child_name, child_module in module.named_children():
child_fqn = f"{cur_fqn}.{child_name}" if cur_fqn else child_name
_recursive_validate(child_module, child_fqn)

_recursive_validate(root_module, "")
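
For reference, the `compute_error` helper used in the test above reports a signal-to-quantization-noise ratio in decibels; a minimal sketch of the formula it is assumed to implement (higher is better; the test requires >= 30 dB for outputs and input grads, >= 25 dB for param grads):

```python
import torch

def sqnr_db(ref: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
    # SQNR in dB: 20 * log10(||ref|| / ||ref - actual||).
    # 30 dB roughly means the error norm is ~3% of the signal norm.
    return 20 * torch.log10(
        torch.linalg.norm(ref.float()) / torch.linalg.norm(ref.float() - actual.float())
    )
```
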
3 changes: 3 additions & 0 deletions torchao/prototype/moe_training/__init__.py
@@ -0,0 +1,3 @@
from torchao.prototype.moe_training.scaled_grouped_mm import _scaled_grouped_mm

__all__ = ["_scaled_grouped_mm"]
@@ -14,11 +14,11 @@
from tabulate import tabulate
from tqdm import tqdm

-from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
+from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
triton_fp8_col_major_jagged_colwise_scales,
triton_fp8_row_major_jagged_rowwise_scales,
)
-from torchao.prototype.scaled_grouped_mm.utils import (
+from torchao.prototype.moe_training.utils import (
_to_2d_jagged_float8_tensor_colwise,
_to_2d_jagged_float8_tensor_rowwise,
)
@@ -14,7 +14,7 @@
from tabulate import tabulate
from tqdm import tqdm

-from torchao.prototype.scaled_grouped_mm import _scaled_grouped_mm
+from torchao.prototype.moe_training import _scaled_grouped_mm

device = torch.device("cuda")

112 changes: 112 additions & 0 deletions torchao/prototype/moe_training/conversion_utils.py
@@ -0,0 +1,112 @@
from typing import Callable, Optional

from torch import nn

from torchao.core.config import AOBaseConfig
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
from torchao.quantization.transform_module import (
register_quantize_module_handler,
)


class MoETrainingConfig(AOBaseConfig):
"""
The MoETrainingConfig is specifically designed to be used on MoE models using
`torch._grouped_mm` to implement expert computation in token-choice routing,
where expert weights are implemented as 3D nn.Parameters with `num_experts` as
the leading dim.

MoETrainingConfig has a module handler registered to it which will
find all nn.Parameters whose parent module matches the module filter function,
and swap their data tensor with a ScaledGroupedMMTensor.

The ScaledGroupedMMTensor is a tensor subclass which overrides the
`torch._grouped_mm` op by dispatching to a differentiable scaled grouped mm,
which performs dynamic float8 rowwise quantization on scaled grouped GEMM
operands in both the forward and backward pass.

For all other ops, ScaledGroupedMMTensor behaves like a regular torch.Tensor.
"""

pass
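
As exercised in test_training.py above, conversion is driven through `quantize_` with a module filter function; a minimal usage sketch (the placeholder model below is illustrative, not part of this PR):

```python
import torch
from torch import nn

from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.quantization.quant_api import quantize_

# Placeholder model; in the test above this is torchtitan's llama4 MoE module,
# whose expert computation goes through torch._grouped_mm with 3D expert weights.
model = nn.Sequential(nn.Linear(256, 256)).to(torch.bfloat16)

def expert_filter(mod: nn.Module, fqn: str) -> bool:
    # Swap parameters only in modules whose FQN contains "experts",
    # mirroring the filter function used in the test.
    return "experts" in fqn

quantize_(model, config=MoETrainingConfig(), filter_fn=expert_filter)
```
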


@register_quantize_module_handler(MoETrainingConfig)
def _moe_training_transform(
module: nn.Module,
config: MoETrainingConfig,
) -> nn.Module:
"""
Swaps `torch.nn.Parameter` data tensor with a ScaledGroupedMMTensor.

Args:
module: Module to modify.
config: MoETrainingConfig which defines how to perform the MoE training transform.

Returns:
nn.Module: The modified module with swapped parameters.
"""
out = _swap_params(module)
return out


def _swap_params(
module: nn.Module,
*,
module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
) -> nn.Module:
"""
Recurses through the nn.Module, swapping the data tensor of
each nn.Parameter with a ScaledGroupedMMTensor. Only applies to modules
that pass the module_filter_fn, if specified.

Args:
module: Module to modify.
module_filter_fn: If specified, only the `torch.nn.Parameter`s whose parent
module passes the filter function will be swapped. The inputs to the
filter function are the module instance and the FQN.

Returns:
nn.Module: The modified module with swapped parameters.
"""
if isinstance(module, nn.Parameter) and (
module_filter_fn is None or module_filter_fn(module, "")
):
if len(list(module.children())) > 0:
raise AssertionError(
f"Does not support a root nn.Parameter with children: {module}"
)
if not isinstance(module.data, ScaledGroupedMMTensor):
new_data = ScaledGroupedMMTensor(module.data)
return nn.Parameter(new_data, requires_grad=module.requires_grad)
return module

root_module = module

def post_order_traversal(
module: nn.Module,
cur_fqn: Optional[str] = None,
parent_module: Optional[nn.Module] = None,
):
if cur_fqn is None:
cur_fqn = ""

for child_module_name, child_module in module.named_children():
if cur_fqn == "":
new_fqn = child_module_name
else:
new_fqn = f"{cur_fqn}.{child_module_name}"

post_order_traversal(child_module, new_fqn, module)

if module_filter_fn is None or module_filter_fn(module, cur_fqn):
for param_name, param in module.named_parameters(recurse=False):
if not isinstance(param.data, ScaledGroupedMMTensor):
new_param = nn.Parameter(
ScaledGroupedMMTensor(param), requires_grad=param.requires_grad
)
setattr(module, param_name, new_param)
print(f"Swapped {cur_fqn}.{param_name} to ScaledGroupedMMTensor")

post_order_traversal(root_module)
return root_module
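
The "dynamic float8 rowwise quantization" mentioned in the MoETrainingConfig docstring computes one scale per row of each GEMM operand; a minimal sketch of the idea (an illustration only, not the jagged per-expert Triton kernels used in this PR):

```python
import torch

def to_fp8_rowwise_sketch(x: torch.Tensor, eps: float = 1e-12):
    # One dynamic scale per row: map each row's absolute max onto the
    # float8 e4m3 max value, then saturate-cast. The returned reciprocal
    # scales are used to rescale the GEMM output back to high precision.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    amax = x.abs().amax(dim=-1, keepdim=True).float().clamp(min=eps)
    scale = fp8_max / amax
    x_fp8 = (x.float() * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return x_fp8, scale.reciprocal()

x = torch.randn(16, 64, dtype=torch.bfloat16)
x_fp8, inv_scale = to_fp8_rowwise_sketch(x)
x_approx = x_fp8.float() * inv_scale  # rough reconstruction of x
```
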
@@ -1,6 +1,6 @@
-from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
+from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
triton_fp8_col_major_jagged_colwise_scales as triton_fp8_col_major_jagged_colwise_scales,
)
-from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
+from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
triton_fp8_row_major_jagged_rowwise_scales as triton_fp8_row_major_jagged_rowwise_scales,
)
@@ -16,7 +16,7 @@
import triton
import triton.language as tl

-from torchao.prototype.scaled_grouped_mm.utils import _is_column_major
+from torchao.prototype.moe_training.utils import _is_column_major

EPS = 1e-12

@@ -10,11 +10,11 @@

from torchao.float8.config import ScalingGranularity
from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
-from torchao.prototype.scaled_grouped_mm.kernels import (
+from torchao.prototype.moe_training.kernels import (
triton_fp8_col_major_jagged_colwise_scales,
triton_fp8_row_major_jagged_rowwise_scales,
)
-from torchao.prototype.scaled_grouped_mm.utils import _is_column_major
+from torchao.prototype.moe_training.utils import _is_column_major


def _scaled_grouped_mm(
@@ -83,7 +83,10 @@ def forward(
assert not _is_column_major(A), "A must be row-major"

# Due to hardware requirements, the right operand in a scaled grouped GEMM must be column-major.
-assert _is_column_major(B_t), "B must be column-major"
+if not _is_column_major(B_t):
+# FSDP will complain if B_t (weights) is not contiguous, so we can't require B_t to be column-major.
+# TODO: figure out a better solution than transposing on each forward pass.
+B_t = B_t.transpose(-2, -1).contiguous().transpose(-2, -1)

# Convert high precision input tensor to float8, row-major for left operand of grouped GEMM.
# A shape: (M, K)
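
For context on the hunk above: the right operand of the scaled grouped GEMM must be column-major in its last two dims, which can be checked from strides. A minimal sketch of what `_is_column_major` is assumed to check, plus the transpose round-trip used to produce a column-major copy without permanently keeping the weights in that layout:

```python
import torch

def is_column_major_sketch(x: torch.Tensor) -> bool:
    # Assumption: the last two dims are column-major when elements within a
    # column are contiguous in memory, i.e. the row-dimension stride is 1.
    return x.stride(-2) == 1 and x.stride(-1) > 1

B_t = torch.randn(8, 64, 32)  # contiguous, i.e. row-major in the last two dims
# transpose -> materialize contiguously -> transpose back: same logical shape,
# but the underlying storage is now column-major in the last two dims.
B_t_col = B_t.transpose(-2, -1).contiguous().transpose(-2, -1)
assert not is_column_major_sketch(B_t) and is_column_major_sketch(B_t_col)
```
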
35 changes: 35 additions & 0 deletions torchao/prototype/moe_training/tensor.py
@@ -0,0 +1,35 @@
import torch

from torchao.prototype.moe_training import _scaled_grouped_mm


class ScaledGroupedMMTensor(torch.Tensor):
"""
ScaledGroupedMMTensor is a simple tensor subclass that wraps a regular tensor
and overrides the torch._grouped_mm op by dispatching to the
differentiable _scaled_grouped_mm autograd function.
"""

grouped_mm_func_name = "_grouped_mm"
offs_arg_name = "offs"

def __init__(self, data: torch.Tensor):
self._data = data

@classmethod
def __torch_function__(cls, func, types, args, kwargs={}):
if func.__name__ == cls.grouped_mm_func_name:
# Use torchao scaled grouped mm with dynamic quant for
# "2d x 3d with offsets" case (used for routed experts).
# Otherwise, fall back to regular grouped mm.
#
# TODO: support "3d x 3d without offsets" case, which is
# used for shared experts. This is basically the grouped_mm
# kernel handling a bmm.
A, B = args[0], args[1]
A_is_2d = A.dim() == 2
B_is_3d = B.dim() == 3
has_offs = kwargs.get(cls.offs_arg_name) is not None
if A_is_2d and B_is_3d and has_offs:
return _scaled_grouped_mm(*args, **kwargs)
return super().__torch_function__(func, types, args, kwargs)
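
A short sketch of how the subclass is expected to behave once parameters are swapped; the shapes and the meaning of `offs` below are illustrative assumptions rather than part of this diff:

```python
import torch

from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

# Wrap a 3D expert weight the way _swap_params does.
expert_weight = ScaledGroupedMMTensor(torch.randn(4, 64, 32))

# For every op other than the grouped mm, the wrapper falls through to the
# default torch.Tensor behaviour via super().__torch_function__.
assert isinstance(expert_weight.sum(), torch.Tensor)

# A call shaped like the one below (2D routed-token activations, 3D expert
# weights, cumulative per-expert row offsets) is what gets rerouted to the
# differentiable float8 _scaled_grouped_mm. Running it needs a CUDA SM89+
# device and a PyTorch build exposing torch._grouped_mm:
#   tokens = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")
#   offs = torch.tensor([32, 64, 96, 128], dtype=torch.int32, device="cuda")
#   out = torch._grouped_mm(tokens, expert_weight_on_cuda, offs=offs)
```
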
3 changes: 0 additions & 3 deletions torchao/prototype/scaled_grouped_mm/__init__.py

This file was deleted.
