Commit f9a42d6

float8 moe training conversion API prototype

stack-info: PR: #2275, branch: danielvegamyhre/stack/1
migrate to quantize and add test
work on moe training test

1 parent 83663b8 commit f9a42d6

14 files changed (+283, -15 lines)
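In short, this commit renames the torchao.prototype.scaled_grouped_mm package to torchao.prototype.moe_training and adds a prototype conversion API that plugs into quantize_. A minimal usage sketch based on the new test below (the model here is an assumed placeholder, not part of the commit):

from torch import nn

from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.quantization.quant_api import quantize_

# Assumed: a bf16 MoE module whose routed-expert weights are consumed by
# torch._grouped_mm, e.g. torchtitan's llama4 MoE as used in the new test.
model: nn.Module = ...

def module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
    # Convert only modules whose fully qualified name contains "experts".
    return "experts" in cur_fqn

quantize_(model, config=MoETrainingConfig(), filter_fn=module_filter_fn)
# Expert parameters now wrap ScaledGroupedMMTensor, so grouped GEMMs in
# forward/backward run through the dynamic float8 _scaled_grouped_mm path.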

test/prototype/scaled_grouped_mm/test_kernels.py renamed to test/prototype/moe_training/test_kernels.py

Lines changed: 2 additions & 2 deletions
@@ -19,11 +19,11 @@
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)


-from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
+from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_col_major_jagged_colwise_scales,
     triton_fp8_row_major_jagged_rowwise_scales,
 )
-from torchao.prototype.scaled_grouped_mm.utils import (
+from torchao.prototype.moe_training.utils import (
     _is_column_major,
     _to_2d_jagged_float8_tensor_colwise,
     _to_2d_jagged_float8_tensor_rowwise,

test/prototype/scaled_grouped_mm/test_scaled_grouped_mm.py renamed to test/prototype/moe_training/test_scaled_grouped_mm.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@
 from torchao.float8.float8_linear import matmul_with_hp_or_float8_args
 from torchao.float8.float8_tensor import LinearMMConfig
 from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
-from torchao.prototype.scaled_grouped_mm.scaled_grouped_mm import (
+from torchao.prototype.moe_training.scaled_grouped_mm import (
     _scaled_grouped_mm,
 )
 from torchao.testing.utils import skip_if_rocm
New test file under test/prototype/moe_training/

Lines changed: 136 additions & 0 deletions

@@ -0,0 +1,136 @@
+import pytest
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from torchao.float8.float8_utils import compute_error
+from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
+from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
+from torchao.quantization.quant_api import quantize_
+from torchao.testing.utils import get_compute_capability
+
+try:
+    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
+    from torchtitan.experiments.llama4.model.moe import MoE
+except ImportError:
+    import warnings
+
+    warnings.warn("torchtitan not installed, skipping MoE tests.")
+    pytest.skip(allow_module_level=True)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(get_compute_capability() < 8.9, reason="Only supported on SM89+")
+@pytest.mark.parametrize(
+    "target_fqns",
+    [
+        ["experts"],
+        ["does.not.exist"],
+    ],
+)
+def test_moe_float8_training(target_fqns: list[str]):
+    model_args = TransformerModelArgs(
+        moe_enabled=True,
+        num_experts=8,
+        dim=256,
+    )
+    init_std = 0.02
+    device = torch.device("cuda")
+
+    # reference bf16 MoE
+    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
+    torch.manual_seed(42)
+    ref_model.init_weights(init_std, device)
+
+    # target MoE for testing conversion
+    model = MoE(model_args).to(torch.bfloat16).cuda()
+    torch.manual_seed(42)
+    model.init_weights(init_std, device)
+
+    # assert starting params are identical for both models
+    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
+        assert torch.equal(param1, param2)
+
+    # convert MoE to float8 training
+    def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
+        for target_fqn in target_fqns:
+            if target_fqn in cur_fqn:
+                return True
+        return False
+
+    # quantize test model
+    config = MoETrainingConfig()
+    quantize_(model, config=config, filter_fn=moe_module_filter_fn)
+
+    # validate that only the experts were converted
+    _validate_model_conversion(
+        model,
+        target_fqns=target_fqns,
+    )
+
+    # inputs
+    batch, seq, dim = 8, 2048, 256
+    ref_x = torch.randn(
+        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
+    )
+    x = ref_x.detach().clone().requires_grad_(True)
+
+    # forward pass
+    ref_out = ref_model(ref_x)
+    out = model(x)
+
+    # validate output
+    out_sqnr = compute_error(out, ref_out)
+    assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}."
+
+    # compute loss
+    labels = torch.ones_like(ref_out)
+    ref_loss = F.mse_loss(ref_out, labels)
+    out_loss = F.mse_loss(out, labels)
+
+    # backward pass
+    ref_loss.backward()
+    out_loss.backward()
+
+    # validate input gradient
+    input_grad_sqnr = compute_error(x.grad, ref_x.grad)
+    assert input_grad_sqnr.item() >= 30.0, (
+        f"SQNR must be >= 30.0, got {input_grad_sqnr.item()}."
+    )
+
+    # validate param gradients
+    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
+        param_grad_sqnr = compute_error(param1.grad, param2.grad)
+        assert param_grad_sqnr.item() >= 25.0, (
+            f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}."
+        )
+
+
+def _validate_model_conversion(
+    root_module: nn.Module,
+    target_fqns: list[str],
+):
+    def _recursive_validate(
+        module: nn.Module,
+        cur_fqn: str,
+    ):
+        is_allowed_module = cur_fqn in target_fqns
+
+        # check current module params
+        for param_name, param in module.named_parameters(recurse=False):
+            is_converted_type = isinstance(param, ScaledGroupedMMTensor)
+            if is_converted_type:
+                assert is_allowed_module, (
+                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
+                )
+            if not is_allowed_module:
+                assert not is_converted_type, (
+                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
+                )
+
+        # recursively check child modules
+        for child_name, child_module in module.named_children():
+            child_fqn = f"{cur_fqn}.{child_name}" if cur_fqn else child_name
+            _recursive_validate(child_module, child_fqn)
+
+    _recursive_validate(root_module, "")
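A note on the thresholds: compute_error returns an SQNR-style metric in dB, so the 30 dB output/input-grad bounds and the 25 dB param-grad bound pin how closely the float8 path tracks the bf16 reference. A sketch of the metric, assuming the standard definition (sqnr_db is a local stand-in; the exact torchao implementation may differ in detail):

import torch

def sqnr_db(ref: torch.Tensor, test: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB:
    # 20 * log10(||ref|| / ||ref - test||); every +20 dB means the
    # error norm is 10x smaller relative to the signal.
    return 20 * torch.log10(torch.norm(ref) / torch.norm(ref - test))

# 30 dB corresponds to an error norm of roughly 3% of the reference norm.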
torchao/prototype/moe_training/__init__.py (new file)

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+from torchao.prototype.moe_training.scaled_grouped_mm import (
+    _scaled_grouped_mm as _scaled_grouped_mm,
+)

torchao/prototype/scaled_grouped_mm/benchmarks/benchmark_kernels.py renamed to torchao/prototype/moe_training/benchmarks/benchmark_kernels.py

Lines changed: 2 additions & 2 deletions
@@ -14,11 +14,11 @@
 from tabulate import tabulate
 from tqdm import tqdm

-from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
+from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_col_major_jagged_colwise_scales,
     triton_fp8_row_major_jagged_rowwise_scales,
 )
-from torchao.prototype.scaled_grouped_mm.utils import (
+from torchao.prototype.moe_training.utils import (
     _to_2d_jagged_float8_tensor_colwise,
     _to_2d_jagged_float8_tensor_rowwise,
 )

torchao/prototype/scaled_grouped_mm/benchmarks/benchmark_scaled_grouped_mm.py renamed to torchao/prototype/moe_training/benchmarks/benchmark_scaled_grouped_mm.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@
 from tabulate import tabulate
 from tqdm import tqdm

-from torchao.prototype.scaled_grouped_mm import _scaled_grouped_mm
+from torchao.prototype.moe_training import _scaled_grouped_mm

 device = torch.device("cuda")
torchao/prototype/moe_training/conversion_utils.py (new file)

Lines changed: 94 additions & 0 deletions

@@ -0,0 +1,94 @@
+from typing import Callable, Optional
+
+from torch import nn
+
+from torchao.core.config import AOBaseConfig
+from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
+
+
+class MoETrainingConfig(AOBaseConfig):
+    pass
+
+
+@register_quantize_module_handler(MoETrainingConfig)
+def _moe_training_transform(
+    module: nn.Module,
+    config: MoETrainingConfig,
+) -> nn.Module:
+    """
+    Swaps each `torch.nn.Parameter` data tensor with a ScaledGroupedMMTensor.
+
+    Args:
+        module: Module to modify.
+        config: MoETrainingConfig which defines how to perform the MoE training transform.
+
+    Returns:
+        nn.Module: The modified module with swapped parameters.
+    """
+    out = swap_params(module)
+    return out
+
+
+def swap_params(
+    module: nn.Module,
+    *,
+    module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
+) -> nn.Module:
+    """
+    Recurses through the nn.Module, swapping the data tensor of
+    each nn.Parameter with a ScaledGroupedMMTensor. Only applies to modules
+    that pass the module_filter_fn, if specified.
+
+    Args:
+        module: Module to modify.
+        module_filter_fn: If specified, only the `torch.nn.Parameter` subclasses
+            that pass the filter function will be swapped. The inputs to the
+            filter function are the module instance and the FQN.
+
+    Returns:
+        nn.Module: The modified module with swapped parameters.
+    """
+    if isinstance(module, nn.Parameter) and (
+        module_filter_fn is None or module_filter_fn(module, "")
+    ):
+        if len(list(module.children())) > 0:
+            raise AssertionError(
+                f"Does not support a root nn.Parameter with children: {module}"
+            )
+        if not isinstance(module.data, ScaledGroupedMMTensor):
+            new_data = ScaledGroupedMMTensor(module.data)
+            return nn.Parameter(new_data, requires_grad=module.requires_grad)
+        return module
+
+    root_module = module
+
+    def post_order_traversal(
+        module: nn.Module,
+        cur_fqn: Optional[str] = None,
+        parent_module: Optional[nn.Module] = None,
+    ):
+        if cur_fqn is None:
+            cur_fqn = ""
+
+        for child_module_name, child_module in module.named_children():
+            if cur_fqn == "":
+                new_fqn = child_module_name
+            else:
+                new_fqn = f"{cur_fqn}.{child_module_name}"
+
+            post_order_traversal(child_module, new_fqn, module)
+
+        if module_filter_fn is None or module_filter_fn(module, cur_fqn):
+            for param_name, param in module.named_parameters(recurse=False):
+                if not isinstance(param.data, ScaledGroupedMMTensor):
+                    new_param = nn.Parameter(
+                        ScaledGroupedMMTensor(param), requires_grad=param.requires_grad
+                    )
+                    setattr(module, param_name, new_param)
+                    print(f"Swapped {cur_fqn}.{param_name} to ScaledGroupedMMTensor")
+
+    post_order_traversal(root_module)
+    return root_module
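For context on the plumbing: register_quantize_module_handler maps a config type to a transform function, and quantize_ walks the module tree, applying the registered transform to each module accepted by filter_fn. A minimal local sketch of that pattern (stand-in names only; the real registry lives in torchao's transform_module):

from typing import Callable

from torch import nn

_HANDLERS: dict[type, Callable] = {}

def register_handler(config_cls: type) -> Callable:
    # Analogous to register_quantize_module_handler: config type -> transform.
    def decorator(fn: Callable) -> Callable:
        _HANDLERS[config_cls] = fn
        return fn
    return decorator

def apply_transforms(model: nn.Module, config, filter_fn=None) -> nn.Module:
    # Analogous to quantize_: dispatch on the config's type for every
    # module that passes the filter.
    handler = _HANDLERS[type(config)]
    for fqn, mod in model.named_modules():
        if filter_fn is None or filter_fn(mod, fqn):
            handler(mod, config)
    return model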
torchao/prototype/scaled_grouped_mm/kernels/__init__.py renamed to torchao/prototype/moe_training/kernels/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -1,6 +1,6 @@
-from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
+from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_col_major_jagged_colwise_scales as triton_fp8_col_major_jagged_colwise_scales,
 )
-from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
+from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_row_major_jagged_rowwise_scales as triton_fp8_row_major_jagged_rowwise_scales,
 )

torchao/prototype/scaled_grouped_mm/kernels/jagged_float8_scales.py renamed to torchao/prototype/moe_training/kernels/jagged_float8_scales.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
 import triton
 import triton.language as tl

-from torchao.prototype.scaled_grouped_mm.utils import _is_column_major
+from torchao.prototype.moe_training.utils import _is_column_major

 EPS = 1e-12
torchao/prototype/scaled_grouped_mm/scaled_grouped_mm.py renamed to torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 6 additions & 3 deletions
@@ -10,11 +10,11 @@

 from torchao.float8.config import ScalingGranularity
 from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
-from torchao.prototype.scaled_grouped_mm.kernels import (
+from torchao.prototype.moe_training.kernels import (
     triton_fp8_col_major_jagged_colwise_scales,
     triton_fp8_row_major_jagged_rowwise_scales,
 )
-from torchao.prototype.scaled_grouped_mm.utils import _is_column_major
+from torchao.prototype.moe_training.utils import _is_column_major


 def _scaled_grouped_mm(
@@ -83,7 +83,10 @@ def forward(
         assert not _is_column_major(A), "A must be row-major"

         # Due to hardware requirements, the right operand in a scaled grouped GEMM must be column-major.
-        assert _is_column_major(B_t), "B must be column-major"
+        if not _is_column_major(B_t):
+            # FSDP requires B_t (the weights) to be contiguous, so we can't require B_t to be column-major.
+            # TODO: figure out a better solution than transposing on each forward pass.
+            B_t = B_t.transpose(-2, -1).contiguous().transpose(-2, -1)

         # Convert high precision input tensor to float8, row-major for left operand of grouped GEMM.
         # A shape: (M, K)
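The transpose-contiguous-transpose idiom above keeps B_t's logical shape while rewriting its memory layout to column-major in the last two dims, which the scaled grouped GEMM requires of its right operand. A standalone illustration (is_column_major is a local stand-in for the internal _is_column_major helper):

import torch

def is_column_major(x: torch.Tensor) -> bool:
    # Column-major in the last two dims: stride along dim -2 is 1.
    return x.stride(-2) == 1

B_t = torch.randn(8, 4096, 1024)   # row-major (contiguous)
print(is_column_major(B_t))        # False

# Same shape and values, column-major memory layout:
B_t_col = B_t.transpose(-2, -1).contiguous().transpose(-2, -1)
print(B_t_col.shape == B_t.shape)  # True
print(is_column_major(B_t_col))    # True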
torchao/prototype/moe_training/tensor.py (new file)

Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
+import torch
+
+from torchao.prototype.moe_training import _scaled_grouped_mm
+
+
+class ScaledGroupedMMTensor(torch.Tensor):
+    """
+    ScaledGroupedMMTensor is a simple tensor subclass that wraps a regular tensor
+    and overrides the torch._grouped_mm op by dispatching to the
+    differentiable _scaled_grouped_mm autograd function.
+    """
+
+    grouped_mm_func_name = "_grouped_mm"
+    offs_arg_name = "offs"
+
+    def __init__(self, data: torch.Tensor):
+        self._data = data
+
+    @classmethod
+    def __torch_function__(cls, func, types, args, kwargs={}):
+        if func.__name__ == cls.grouped_mm_func_name:
+            # Use torchao scaled grouped mm with dynamic quant for
+            # "2d x 3d with offsets" case (used for routed experts).
+            # Otherwise, fall back to regular grouped mm.
+            #
+            # TODO: support "3d x 3d without offsets" case, which is
+            # used for shared experts. This is basically the grouped_mm
+            # kernel handling a bmm.
+            A, B = args[0], args[1]
+            A_is_2d = A.dim() == 2
+            B_is_3d = B.dim() == 3
+            has_offs = kwargs.get(cls.offs_arg_name) is not None
+            if A_is_2d and B_is_3d and has_offs:
+                return _scaled_grouped_mm(*args, **kwargs)
+        return super().__torch_function__(func, types, args, kwargs)
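Putting it together: once a weight is wrapped in ScaledGroupedMMTensor, __torch_function__ intercepts torch._grouped_mm and reroutes the "2d x 3d with offs" case to the float8 path. A sketch of the intended dispatch (torch._grouped_mm is a private op present only in recent PyTorch builds; shapes and offsets here are illustrative):

import torch

from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

# Token activations (2d) and per-expert weights (3d), bf16 on CUDA.
A = torch.randn(16, 256, dtype=torch.bfloat16, device="cuda")
B = ScaledGroupedMMTensor(
    torch.randn(4, 256, 512, dtype=torch.bfloat16, device="cuda")
)

# offs marks the end row of each expert's token group in A.
offs = torch.tensor([4, 8, 12, 16], dtype=torch.int32, device="cuda")

# func.__name__ == "_grouped_mm", A is 2d, B is 3d, and offs is present,
# so dispatch goes to the differentiable _scaled_grouped_mm.
out = torch._grouped_mm(A, B, offs=offs)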

torchao/prototype/scaled_grouped_mm/__init__.py

Lines changed: 0 additions & 3 deletions
This file was deleted.
