Commit 0692c1f
float8 moe training conversion API prototype

stack-info: PR: #2275, branch: danielvegamyhre/stack/1
migrate to quantize and add test
work on moe training test

1 parent 83663b8 commit 0692c1f

14 files changed: +286 -15 lines changed
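In brief: this commit renames the torchao.prototype.scaled_grouped_mm package to torchao.prototype.moe_training and adds a quantize_-based conversion API for float8 MoE training. A minimal usage sketch, mirroring the new test below (MyMoEModel is a hypothetical stand-in for any MoE module whose expert weights feed torch._grouped_mm):

import torch
from torch import nn

from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.quantization.quant_api import quantize_

# Hypothetical MoE model; any module whose expert weights are consumed by
# torch._grouped_mm can be converted the same way.
model = MyMoEModel().to(torch.bfloat16).cuda()

# Only convert modules whose FQN contains "experts".
def filter_fn(mod: nn.Module, fqn: str) -> bool:
    return "experts" in fqn

# Swaps matching expert weight params to ScaledGroupedMMTensor in place, so
# later torch._grouped_mm calls dispatch to the float8 _scaled_grouped_mm path.
quantize_(model, config=MoETrainingConfig(), filter_fn=filter_fn)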

test/prototype/scaled_grouped_mm/test_kernels.py renamed to test/prototype/moe_training/test_kernels.py

Lines changed: 2 additions & 2 deletions
@@ -19,11 +19,11 @@
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
 
-from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
+from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_col_major_jagged_colwise_scales,
     triton_fp8_row_major_jagged_rowwise_scales,
 )
-from torchao.prototype.scaled_grouped_mm.utils import (
+from torchao.prototype.moe_training.utils import (
     _is_column_major,
     _to_2d_jagged_float8_tensor_colwise,
     _to_2d_jagged_float8_tensor_rowwise,

test/prototype/scaled_grouped_mm/test_scaled_grouped_mm.py renamed to test/prototype/moe_training/test_scaled_grouped_mm.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@
 from torchao.float8.float8_linear import matmul_with_hp_or_float8_args
 from torchao.float8.float8_tensor import LinearMMConfig
 from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
-from torchao.prototype.scaled_grouped_mm.scaled_grouped_mm import (
+from torchao.prototype.moe_training.scaled_grouped_mm import (
     _scaled_grouped_mm,
 )
 from torchao.testing.utils import skip_if_rocm
Lines changed: 139 additions & 0 deletions (new file)

@@ -0,0 +1,139 @@
import pytest
import torch
from torch import nn
from torch.nn import functional as F

from torchao.float8.float8_utils import compute_error
from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
from torchao.quantization.quant_api import quantize_
from torchao.testing.utils import get_compute_capability

try:
    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
    from torchtitan.experiments.llama4.model.moe import MoE
except ImportError:
    import warnings

    warnings.warn("torchtitan not installed, skipping MoE tests.")
    pytest.skip(allow_module_level=True)

if not torch.cuda.is_available() or get_compute_capability() < 8.9:
    pytest.skip(
        "CUDA not available or compute capability < 8.9", allow_module_level=True
    )


@pytest.mark.parametrize(
    "target_fqns",
    [
        ["experts"],
        ["does.not.exist"],
    ],
)
def test_moe_float8_training(target_fqns: list[str]):
    model_args = TransformerModelArgs(
        moe_enabled=True,
        num_experts=8,
        dim=256,
    )
    init_std = 0.02
    device = torch.device("cuda")

    # reference bf16 MoE
    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
    torch.manual_seed(42)
    ref_model.init_weights(init_std, device)

    # target MoE for testing conversion
    model = MoE(model_args).to(torch.bfloat16).cuda()
    torch.manual_seed(42)
    model.init_weights(init_std, device)

    # assert starting params are identical for both models
    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
        assert torch.equal(param1, param2)

    # convert MoE to float8 training
    def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
        for target_fqn in target_fqns:
            if target_fqn in cur_fqn:
                return True
        return False

    # quantize test model
    config = MoETrainingConfig()
    quantize_(model, config=config, filter_fn=moe_module_filter_fn)

    # validate that only the experts were converted
    _validate_model_conversion(
        model,
        target_fqns=target_fqns,
    )

    # inputs
    batch, seq, dim = 8, 2048, 256
    ref_x = torch.randn(
        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
    )
    x = ref_x.detach().clone().requires_grad_(True)

    # forward pass
    ref_out = ref_model(ref_x)
    out = model(x)

    # validate output
    out_sqnr = compute_error(out, ref_out)
    assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}."

    # compute loss
    labels = torch.ones_like(ref_out)
    ref_loss = F.mse_loss(ref_out, labels)
    out_loss = F.mse_loss(out, labels)

    # backward pass
    ref_loss.backward()
    out_loss.backward()

    # validate input gradient
    input_grad_sqnr = compute_error(x.grad, ref_x.grad)
    assert input_grad_sqnr.item() >= 30.0, (
        f"SQNR must be >= 30.0, got {input_grad_sqnr.item()}."
    )

    # validate param gradients
    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
        param_grad_sqnr = compute_error(param1.grad, param2.grad)
        assert param_grad_sqnr.item() >= 25.0, (
            f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}."
        )


def _validate_model_conversion(
    root_module: nn.Module,
    target_fqns: list[str],
):
    def _recursive_validate(
        module: nn.Module,
        cur_fqn: str,
    ):
        is_allowed_module = cur_fqn in target_fqns

        # check current module params
        for param_name, param in module.named_parameters(recurse=False):
            is_converted_type = isinstance(param, ScaledGroupedMMTensor)
            if is_converted_type:
                assert is_allowed_module, (
                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
                )
            if not is_allowed_module:
                assert not is_converted_type, (
                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
                )

        # recursively check child modules
        for child_name, child_module in module.named_children():
            child_fqn = f"{cur_fqn}.{child_name}" if cur_fqn else child_name
            _recursive_validate(child_module, child_fqn)

    _recursive_validate(root_module, "")
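For context on the thresholds above: compute_error returns a signal-to-quantization-noise ratio in dB, so 30 dB for outputs and input grads and 25 dB for param grads tolerate a modest float8 rounding error. A minimal sketch of the standard definition it implements (a sketch, not the torchao source):

import torch

def sqnr_db(ref: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    # SQNR in dB: 20 * log10(||ref|| / ||ref - out||); higher is better.
    return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - out))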
Lines changed: 3 additions & 0 deletions (new file, presumably torchao/prototype/moe_training/__init__.py)

@@ -0,0 +1,3 @@
from torchao.prototype.moe_training.scaled_grouped_mm import (
    _scaled_grouped_mm as _scaled_grouped_mm,
)
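The redundant `_scaled_grouped_mm as _scaled_grouped_mm` alias is the standard way to mark an import as an intentional public re-export, so linters and strict type checkers don't flag it as unused.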

torchao/prototype/scaled_grouped_mm/benchmarks/benchmark_kernels.py renamed to torchao/prototype/moe_training/benchmarks/benchmark_kernels.py

Lines changed: 2 additions & 2 deletions
@@ -14,11 +14,11 @@
 from tabulate import tabulate
 from tqdm import tqdm
 
-from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
+from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_col_major_jagged_colwise_scales,
     triton_fp8_row_major_jagged_rowwise_scales,
 )
-from torchao.prototype.scaled_grouped_mm.utils import (
+from torchao.prototype.moe_training.utils import (
     _to_2d_jagged_float8_tensor_colwise,
     _to_2d_jagged_float8_tensor_rowwise,
 )

torchao/prototype/scaled_grouped_mm/benchmarks/benchmark_scaled_grouped_mm.py renamed to torchao/prototype/moe_training/benchmarks/benchmark_scaled_grouped_mm.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@
 from tabulate import tabulate
 from tqdm import tqdm
 
-from torchao.prototype.scaled_grouped_mm import _scaled_grouped_mm
+from torchao.prototype.moe_training import _scaled_grouped_mm
 
 device = torch.device("cuda")
 
Lines changed: 94 additions & 0 deletions (new file, presumably torchao/prototype/moe_training/conversion_utils.py)

@@ -0,0 +1,94 @@
from typing import Callable, Optional

from torch import nn

from torchao.core.config import AOBaseConfig
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
from torchao.quantization.transform_module import (
    register_quantize_module_handler,
)


class MoETrainingConfig(AOBaseConfig):
    pass


@register_quantize_module_handler(MoETrainingConfig)
def _moe_training_transform(
    module: nn.Module,
    config: MoETrainingConfig,
) -> nn.Module:
    """
    Swaps the data tensor of each `torch.nn.Parameter` with a ScaledGroupedMMTensor.

    Args:
        module: Module to modify.
        config: MoETrainingConfig which defines how to perform the MoE training transform.

    Returns:
        nn.Module: The modified module with swapped parameters.
    """
    out = swap_params(module)
    return out


def swap_params(
    module: nn.Module,
    *,
    module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
) -> nn.Module:
    """
    Recurses through the nn.Module, swapping the data tensor of
    each nn.Parameter with a ScaledGroupedMMTensor. Only applies to modules
    that pass the module_filter_fn, if one is specified.

    Args:
        module: Module to modify.
        module_filter_fn: If specified, only the `torch.nn.Parameter` subclasses that
            pass the filter function will be swapped. The inputs to the
            filter function are the module instance and the FQN.

    Returns:
        nn.Module: The modified module with swapped parameters.
    """
    if isinstance(module, nn.Parameter) and (
        module_filter_fn is None or module_filter_fn(module, "")
    ):
        if len(list(module.children())) > 0:
            raise AssertionError(
                f"Does not support a root nn.Parameter with children: {module}"
            )
        if not isinstance(module.data, ScaledGroupedMMTensor):
            new_data = ScaledGroupedMMTensor(module.data)
            return nn.Parameter(new_data, requires_grad=module.requires_grad)
        return module

    root_module = module

    def post_order_traversal(
        module: nn.Module,
        cur_fqn: Optional[str] = None,
        parent_module: Optional[nn.Module] = None,
    ):
        if cur_fqn is None:
            cur_fqn = ""

        for child_module_name, child_module in module.named_children():
            if cur_fqn == "":
                new_fqn = child_module_name
            else:
                new_fqn = f"{cur_fqn}.{child_module_name}"

            post_order_traversal(child_module, new_fqn, module)

        if module_filter_fn is None or module_filter_fn(module, cur_fqn):
            for param_name, param in module.named_parameters(recurse=False):
                if not isinstance(param.data, ScaledGroupedMMTensor):
                    new_param = nn.Parameter(
                        ScaledGroupedMMTensor(param), requires_grad=param.requires_grad
                    )
                    setattr(module, param_name, new_param)
                    print(f"Swapped {cur_fqn}.{param_name} to ScaledGroupedMMTensor")

    post_order_traversal(root_module)
    return root_module
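Note that `_moe_training_transform` calls `swap_params` without a `module_filter_fn`: when the transform runs via `quantize_`, filtering has already happened, because `quantize_` only invokes the registered handler on modules that pass its `filter_fn` (as in the test above). The optional filter on `swap_params` is there for direct, standalone use.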
Lines changed: 2 additions & 2 deletions (presumably torchao/prototype/moe_training/kernels/__init__.py, renamed from torchao/prototype/scaled_grouped_mm/kernels/__init__.py)

@@ -1,6 +1,6 @@
-from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
+from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_col_major_jagged_colwise_scales as triton_fp8_col_major_jagged_colwise_scales,
 )
-from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
+from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_row_major_jagged_rowwise_scales as triton_fp8_row_major_jagged_rowwise_scales,
 )

torchao/prototype/scaled_grouped_mm/kernels/jagged_float8_scales.py renamed to torchao/prototype/moe_training/kernels/jagged_float8_scales.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
 import triton
 import triton.language as tl
 
-from torchao.prototype.scaled_grouped_mm.utils import _is_column_major
+from torchao.prototype.moe_training.utils import _is_column_major
 
 EPS = 1e-12
 

torchao/prototype/scaled_grouped_mm/scaled_grouped_mm.py renamed to torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 6 additions & 3 deletions
@@ -10,11 +10,11 @@
 
 from torchao.float8.config import ScalingGranularity
 from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
-from torchao.prototype.scaled_grouped_mm.kernels import (
+from torchao.prototype.moe_training.kernels import (
     triton_fp8_col_major_jagged_colwise_scales,
     triton_fp8_row_major_jagged_rowwise_scales,
 )
-from torchao.prototype.scaled_grouped_mm.utils import _is_column_major
+from torchao.prototype.moe_training.utils import _is_column_major
 
 
 def _scaled_grouped_mm(
@@ -83,7 +83,10 @@ def forward(
         assert not _is_column_major(A), "A must be row-major"
 
         # Due to hardware requirements, the right operand in a scaled grouped GEMM must be column-major.
-        assert _is_column_major(B_t), "B must be column-major"
+        if not _is_column_major(B_t):
+            # FSDP will complain if B_t (weights) is not contiguous, so we can't require B_t to be column-major.
+            # TODO: figure out a better solution than transposing for each forward pass.
+            B_t = B_t.transpose(-2, -1).contiguous().transpose(-2, -1)
 
         # Convert high precision input tensor to float8, row-major for left operand of grouped GEMM.
         # A shape: (M, K)
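The transpose-contiguous-transpose line is a standard layout trick: it keeps the logical shape and values but rewrites the underlying data column-major. A self-contained sketch (is_column_major here is a stand-in for the module's _is_column_major):

import torch

def is_column_major(x: torch.Tensor) -> bool:
    # Column-major in the trailing-2D sense: dim -2 has stride 1.
    return x.stride(-2) == 1

B_t = torch.randn(8, 4096, 1024)  # row-major: last dim has stride 1
B_t_col = B_t.transpose(-2, -1).contiguous().transpose(-2, -1)

assert B_t_col.shape == B_t.shape   # same logical shape
assert torch.equal(B_t_col, B_t)    # same values
assert is_column_major(B_t_col)     # column-major strides, as the GEMM requires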
Lines changed: 35 additions & 0 deletions (new file, presumably torchao/prototype/moe_training/tensor.py)

@@ -0,0 +1,35 @@
import torch

from torchao.prototype.moe_training import _scaled_grouped_mm


class ScaledGroupedMMTensor(torch.Tensor):
    """
    ScaledGroupedMMTensor is a simple tensor subclass that wraps a regular tensor
    and overrides the torch._grouped_mm op by dispatching to the
    differentiable _scaled_grouped_mm autograd function.
    """

    grouped_mm_func_name = "_grouped_mm"
    offs_arg_name = "offs"

    def __init__(self, data: torch.Tensor):
        self._data = data

    @classmethod
    def __torch_function__(cls, func, types, args, kwargs={}):
        if func.__name__ == cls.grouped_mm_func_name:
            # Use torchao scaled grouped mm with dynamic quant for
            # "2d x 3d with offsets" case (used for routed experts).
            # Otherwise, fall back to regular grouped mm.
            #
            # TODO: support "3d x 3d without offsets" case, which is
            # used for shared experts. This is basically the grouped_mm
            # kernel handling a bmm.
            A, B = args[0], args[1]
            A_is_2d = A.dim() == 2
            B_is_3d = B.dim() == 3
            has_offs = kwargs.get(cls.offs_arg_name) is not None
            if A_is_2d and B_is_3d and has_offs:
                return _scaled_grouped_mm(*args, **kwargs)
        return super().__torch_function__(func, types, args, kwargs)
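To illustrate the dispatch mechanics: once a weight is wrapped, any torch._grouped_mm call that sees it goes through __torch_function__ and, for the "2d x 3d with offsets" case, is rerouted to _scaled_grouped_mm. A rough sketch, not a tested recipe, assuming a PyTorch build that exposes torch._grouped_mm and fp8-friendly sizes (dims and group sizes multiples of 16):

import torch

from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

A = torch.randn(64, 64, dtype=torch.bfloat16, device="cuda")       # (M, K) token activations
B_t = torch.randn(4, 64, 32, dtype=torch.bfloat16, device="cuda")  # (E, K, N) expert weights
offs = torch.tensor([16, 32, 48, 64], dtype=torch.int32, device="cuda")  # end row of each group

# Wrap the weights the same way swap_params does; __torch_function__ fires
# when any argument is a ScaledGroupedMMTensor.
B_t = ScaledGroupedMMTensor(B_t)
out = torch._grouped_mm(A, B_t, offs=offs)  # dispatches to _scaled_grouped_mm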

torchao/prototype/scaled_grouped_mm/__init__.py

Lines changed: 0 additions & 3 deletions
This file was deleted.
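With this deletion the old import path is gone entirely: anything still importing from torchao.prototype.scaled_grouped_mm must move to torchao.prototype.moe_training.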
