Commit 123a4bc

float8 moe training conversion API prototype
stack-info: PR: #2275, branch: danielvegamyhre/stack/1
migrate to quantize and add test
work on moe training test
1 parent 83663b8 commit 123a4bc
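
For orientation, here is a minimal usage sketch of the conversion API added in this commit, condensed from the new test below; it assumes torchtitan is installed and a CUDA device is available:

import torch
from torch import nn

from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.quantization.quant_api import quantize_
from torchtitan.experiments.llama4.model.args import TransformerModelArgs
from torchtitan.experiments.llama4.model.moe import MoE

# Build a small llama4 MoE block and convert its expert params so that
# torch._grouped_mm calls dispatch to torchao's differentiable float8
# _scaled_grouped_mm (see the tensor subclass diff below).
model = MoE(TransformerModelArgs(moe_enabled=True, num_experts=8, dim=256))
model = model.to(torch.bfloat16).cuda()
model.init_weights(0.02, torch.device("cuda"))

def expert_filter(mod: nn.Module, cur_fqn: str) -> bool:
    return "experts" in cur_fqn  # only convert params under the experts module

quantize_(model, config=MoETrainingConfig(), filter_fn=expert_filter)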

14 files changed (+277, −15 lines)

test/prototype/scaled_grouped_mm/test_kernels.py renamed to test/prototype/moe_training/test_kernels.py

Lines changed: 2 additions & 2 deletions

@@ -19,11 +19,11 @@
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)


-from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
+from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_col_major_jagged_colwise_scales,
     triton_fp8_row_major_jagged_rowwise_scales,
 )
-from torchao.prototype.scaled_grouped_mm.utils import (
+from torchao.prototype.moe_training.utils import (
     _is_column_major,
     _to_2d_jagged_float8_tensor_colwise,
     _to_2d_jagged_float8_tensor_rowwise,

test/prototype/scaled_grouped_mm/test_scaled_grouped_mm.py renamed to test/prototype/moe_training/test_scaled_grouped_mm.py

Lines changed: 1 addition & 1 deletion

@@ -26,7 +26,7 @@
 from torchao.float8.float8_linear import matmul_with_hp_or_float8_args
 from torchao.float8.float8_tensor import LinearMMConfig
 from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
-from torchao.prototype.scaled_grouped_mm.scaled_grouped_mm import (
+from torchao.prototype.moe_training.scaled_grouped_mm import (
     _scaled_grouped_mm,
 )
 from torchao.testing.utils import skip_if_rocm
Lines changed: 130 additions & 0 deletions

@@ -0,0 +1,130 @@
+import pytest
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from torchao.float8.float8_utils import compute_error
+from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
+from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
+from torchao.quantization.quant_api import quantize_
+
+try:
+    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
+    from torchtitan.experiments.llama4.model.moe import MoE
+except ImportError:
+    import warnings
+
+    warnings.warn("torchtitan not installed, skipping MoE tests.")
+    pytest.skip(allow_module_level=True)
+
+
+@pytest.mark.parametrize(
+    "target_fqns",
+    [["experts"]],
+)
+def test_moe_float8_training(target_fqns: list[str]):
+    model_args = TransformerModelArgs(
+        moe_enabled=True,
+        num_experts=8,
+        dim=256,
+    )
+    init_std = 0.02
+    device = torch.device("cuda")
+
+    # reference bf16 MoE
+    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
+    torch.manual_seed(42)
+    ref_model.init_weights(init_std, device)
+
+    # target MoE for testing conversion
+    model = MoE(model_args).to(torch.bfloat16).cuda()
+    torch.manual_seed(42)
+    model.init_weights(init_std, device)
+
+    # assert starting params are identical for both models
+    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
+        assert torch.equal(param1, param2)
+
+    # convert MoE to float8 training
+    def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
+        for target_fqn in target_fqns:
+            if target_fqn in cur_fqn:
+                return True
+        return False
+
+    # quantize test model
+    config = MoETrainingConfig()
+    quantize_(model, config=config, filter_fn=moe_module_filter_fn)
+
+    # validate that only the experts were converted
+    _validate_model_conversion(
+        model,
+        target_fqns=target_fqns,
+    )
+
+    # inputs
+    batch, seq, dim = 8, 2048, 256
+    ref_x = torch.randn(
+        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
+    )
+    x = ref_x.detach().clone().requires_grad_(True)
+
+    # forward pass
+    ref_out = ref_model(ref_x)
+    out = model(x)
+
+    # validate output
+    out_sqnr = compute_error(out, ref_out)
+    assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}."
+
+    # compute loss
+    labels = torch.ones_like(ref_out)
+    ref_loss = F.mse_loss(ref_out, labels)
+    out_loss = F.mse_loss(out, labels)
+
+    # backward pass
+    ref_loss.backward()
+    out_loss.backward()
+
+    # validate input gradient
+    input_grad_sqnr = compute_error(x.grad, ref_x.grad)
+    assert input_grad_sqnr.item() >= 30.0, (
+        f"SQNR must be >= 30.0, got {input_grad_sqnr.item()}."
+    )
+
+    # validate param gradients
+    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
+        param_grad_sqnr = compute_error(param1.grad, param2.grad)
+        assert param_grad_sqnr.item() >= 25.0, (
+            f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}."
+        )
+
+
+def _validate_model_conversion(
+    root_module: nn.Module,
+    target_fqns: list[str],
+):
+    def _recursive_validate(
+        module: nn.Module,
+        cur_fqn: str,
+    ):
+        is_allowed_module = cur_fqn in target_fqns
+
+        # check current module params
+        for param_name, param in module.named_parameters(recurse=False):
+            is_converted_type = isinstance(param, ScaledGroupedMMTensor)
+            if is_converted_type:
+                assert is_allowed_module, (
+                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
+                )
+            if not is_allowed_module:
+                assert not is_converted_type, (
+                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
+                )
+
+        # recursively check child modules
+        for child_name, child_module in module.named_children():
+            child_fqn = f"{cur_fqn}.{child_name}" if cur_fqn else child_name
+            _recursive_validate(child_module, child_fqn)
+
+    _recursive_validate(root_module, "")
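
The SQNR thresholds in this test come from torchao's compute_error, which returns a signal-to-quantization-noise ratio in dB, roughly 20 * log10(norm(signal) / norm(error)). A standalone illustration (numbers are approximate and vary with the random draw):

import torch

from torchao.float8.float8_utils import compute_error

ref = torch.randn(1024)
approx = ref + 1e-3 * torch.randn(1024)  # ~0.1% relative error
print(compute_error(approx, ref))        # ~60 dB; the 30 dB threshold above allows ~3% error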
Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+from torchao.prototype.moe_training.scaled_grouped_mm import (
+    _scaled_grouped_mm as _scaled_grouped_mm,
+)
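
The `_scaled_grouped_mm as _scaled_grouped_mm` self-alias marks the symbol as an intentional public re-export, so the benchmarks and the tensor subclass below can import it from the package root:

from torchao.prototype.moe_training import _scaled_grouped_mm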

torchao/prototype/scaled_grouped_mm/benchmarks/benchmark_kernels.py renamed to torchao/prototype/moe_training/benchmarks/benchmark_kernels.py

Lines changed: 2 additions & 2 deletions

@@ -14,11 +14,11 @@
 from tabulate import tabulate
 from tqdm import tqdm

-from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
+from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_col_major_jagged_colwise_scales,
     triton_fp8_row_major_jagged_rowwise_scales,
 )
-from torchao.prototype.scaled_grouped_mm.utils import (
+from torchao.prototype.moe_training.utils import (
     _to_2d_jagged_float8_tensor_colwise,
     _to_2d_jagged_float8_tensor_rowwise,
 )

torchao/prototype/scaled_grouped_mm/benchmarks/benchmark_scaled_grouped_mm.py renamed to torchao/prototype/moe_training/benchmarks/benchmark_scaled_grouped_mm.py

Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,7 @@
 from tabulate import tabulate
 from tqdm import tqdm

-from torchao.prototype.scaled_grouped_mm import _scaled_grouped_mm
+from torchao.prototype.moe_training import _scaled_grouped_mm

 device = torch.device("cuda")
Lines changed: 94 additions & 0 deletions

@@ -0,0 +1,94 @@
+from typing import Callable, Optional
+
+from torch import nn
+
+from torchao.core.config import AOBaseConfig
+from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
+
+
+class MoETrainingConfig(AOBaseConfig):
+    pass
+
+
+@register_quantize_module_handler(MoETrainingConfig)
+def _moe_training_transform(
+    module: nn.Module,
+    config: MoETrainingConfig,
+) -> nn.Module:
+    """
+    Swaps `torch.nn.Parameter` data tensor with a ScaledGroupedMMTensor.
+
+    Args:
+        module: Module to modify.
+        config: MoETrainingConfig which defines how to perform the MoE training transform.
+
+    Returns:
+        nn.Module: The modified module with swapped parameters.
+    """
+    out = swap_params(module)
+    return out
+
+
+def swap_params(
+    module: nn.Module,
+    *,
+    module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
+) -> nn.Module:
+    """
+    Recurses through the nn.Module, swapping the data tensor of
+    each nn.Parameter with a ScaledGroupedMMTensor. Only applies if the module
+    passes the module_filter_fn, if specified.
+
+    Args:
+        module: Module to modify.
+        module_filter_fn: If specified, only the `torch.nn.Parameter` subclasses
+            that pass the filter function will be swapped. The inputs to the
+            filter function are the module instance and the FQN.
+
+    Returns:
+        nn.Module: The modified module with swapped parameters.
+    """
+    if isinstance(module, nn.Parameter) and (
+        module_filter_fn is None or module_filter_fn(module, "")
+    ):
+        if len(list(module.children())) > 0:
+            raise AssertionError(
+                f"Does not support a root nn.Parameter with children: {module}"
+            )
+        if not isinstance(module.data, ScaledGroupedMMTensor):
+            new_data = ScaledGroupedMMTensor(module.data)
+            return nn.Parameter(new_data, requires_grad=module.requires_grad)
+        return module
+
+    root_module = module
+
+    def post_order_traversal(
+        module: nn.Module,
+        cur_fqn: Optional[str] = None,
+        parent_module: Optional[nn.Module] = None,
+    ):
+        if cur_fqn is None:
+            cur_fqn = ""
+
+        for child_module_name, child_module in module.named_children():
+            if cur_fqn == "":
+                new_fqn = child_module_name
+            else:
+                new_fqn = f"{cur_fqn}.{child_module_name}"
+
+            post_order_traversal(child_module, new_fqn, module)
+
+        if module_filter_fn is None or module_filter_fn(module, cur_fqn):
+            for param_name, param in module.named_parameters(recurse=False):
+                if not isinstance(param.data, ScaledGroupedMMTensor):
+                    new_param = nn.Parameter(
+                        ScaledGroupedMMTensor(param), requires_grad=param.requires_grad
+                    )
+                    setattr(module, param_name, new_param)
+                    print(f"Swapped {cur_fqn}.{param_name} to ScaledGroupedMMTensor")
+
+    post_order_traversal(root_module)
+    return root_module
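
To make the filter semantics concrete, here is a toy illustration (not part of the commit) of the FQN strings a module_filter_fn sees; post_order_traversal builds the same dotted names that nn.Module.named_modules reports:

from torch import nn

# module_filter_fn is called with (module, cur_fqn) for every module in the tree.
model = nn.ModuleDict({"experts": nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4)])})
for fqn, _ in model.named_modules():
    print(repr(fqn))  # '', 'experts', 'experts.0', 'experts.1'

# So filter_fn=lambda mod, fqn: "experts" in fqn matches the experts subtree
# (the root "" is excluded), mirroring moe_module_filter_fn in the test above.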
Lines changed: 2 additions & 2 deletions

@@ -1,6 +1,6 @@
-from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
+from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_col_major_jagged_colwise_scales as triton_fp8_col_major_jagged_colwise_scales,
 )
-from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
+from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_row_major_jagged_rowwise_scales as triton_fp8_row_major_jagged_rowwise_scales,
 )

torchao/prototype/scaled_grouped_mm/kernels/jagged_float8_scales.py renamed to torchao/prototype/moe_training/kernels/jagged_float8_scales.py

Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@
 import triton
 import triton.language as tl

-from torchao.prototype.scaled_grouped_mm.utils import _is_column_major
+from torchao.prototype.moe_training.utils import _is_column_major

 EPS = 1e-12
torchao/prototype/scaled_grouped_mm/scaled_grouped_mm.py renamed to torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 6 additions & 3 deletions

@@ -10,11 +10,11 @@

 from torchao.float8.config import ScalingGranularity
 from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
-from torchao.prototype.scaled_grouped_mm.kernels import (
+from torchao.prototype.moe_training.kernels import (
     triton_fp8_col_major_jagged_colwise_scales,
     triton_fp8_row_major_jagged_rowwise_scales,
 )
-from torchao.prototype.scaled_grouped_mm.utils import _is_column_major
+from torchao.prototype.moe_training.utils import _is_column_major


 def _scaled_grouped_mm(

@@ -83,7 +83,10 @@ def forward(
         assert not _is_column_major(A), "A must be row-major"

         # Due to hardware requirements, the right operand in a scaled grouped GEMM must be column-major.
-        assert _is_column_major(B_t), "B must be column-major"
+        if not _is_column_major(B_t):
+            # FSDP will complain if B_t (weights) is not contiguous, so we can't require B_t to be column-major.
+            # TODO: figure out a better solution than transposing on each forward pass.
+            B_t = B_t.transpose(-2, -1).contiguous().transpose(-2, -1)

         # Convert high precision input tensor to float8, row-major for left operand of grouped GEMM.
         # A shape: (M, K)
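
The transpose-contiguous-transpose trick added above changes only the memory layout, not the values; a quick standalone check:

import torch

B_t = torch.randn(4, 8, 16)  # row-major: stride (128, 16, 1)
B_col = B_t.transpose(-2, -1).contiguous().transpose(-2, -1)

assert torch.equal(B_t, B_col)  # identical values
print(B_t.stride())    # (128, 16, 1)
print(B_col.stride())  # (128, 1, 8): stride(-2) == 1, i.e. column-major in the
                       # last two dims, which is what the grouped GEMM requires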
Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
+import torch
+
+from torchao.prototype.moe_training import _scaled_grouped_mm
+
+
+class ScaledGroupedMMTensor(torch.Tensor):
+    """
+    ScaledGroupedMMTensor is a simple tensor subclass that wraps a regular tensor
+    and overrides the torch._grouped_mm op by dispatching to the
+    differentiable _scaled_grouped_mm autograd function.
+    """
+
+    grouped_mm_func_name = "_grouped_mm"
+    offs_arg_name = "offs"
+
+    def __init__(self, data: torch.Tensor):
+        self._data = data
+
+    @classmethod
+    def __torch_function__(cls, func, types, args, kwargs={}):
+        if func.__name__ == cls.grouped_mm_func_name:
+            # Use torchao scaled grouped mm with dynamic quant for
+            # "2d x 3d with offsets" case (used for routed experts).
+            # Otherwise, fall back to regular grouped mm.
+            #
+            # TODO: support "3d x 3d without offsets" case, which is
+            # used for shared experts. This is basically the grouped_mm
+            # kernel handling a bmm.
+            A, B = args[0], args[1]
+            A_is_2d = A.dim() == 2
+            B_is_3d = B.dim() == 3
+            has_offs = kwargs.get(cls.offs_arg_name) is not None
+            if A_is_2d and B_is_3d and has_offs:
+                return _scaled_grouped_mm(*args, **kwargs)
+        return super().__torch_function__(func, types, args, kwargs)
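
A hypothetical end-to-end sketch of the interception, assuming a PyTorch build that exposes torch._grouped_mm and the offsets convention used for routed experts (int32 cumulative row boundaries per expert); the shapes here are illustrative only:

import torch

from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

# 4 experts, 32 tokens routed 8 per expert: A is (M, K) activations,
# B is (n_experts, K, N) expert weights, offs marks where each group ends.
A = torch.randn(32, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True)
B = ScaledGroupedMMTensor(torch.randn(4, 64, 128, dtype=torch.bfloat16, device="cuda"))
offs = torch.tensor([8, 16, 24, 32], dtype=torch.int32, device="cuda")

# Because B is a ScaledGroupedMMTensor, __torch_function__ sees the
# "2d x 3d with offs" case and reroutes to the float8 _scaled_grouped_mm.
out = torch._grouped_mm(A, B, offs=offs)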

torchao/prototype/scaled_grouped_mm/__init__.py

Lines changed: 0 additions & 3 deletions
This file was deleted.
