
Commit 0c06aba

float8 moe training conversion API prototype
stack-info: PR: #2275, branch: danielvegamyhre/stack/1
migrate to quantize_ and add test
work on moe training test
1 parent 83663b8 commit 0c06aba

File tree

14 files changed: +280 −15 lines
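At a glance, the conversion API this commit adds is driven through torchao's `quantize_` entry point. A minimal usage sketch assembled from the new test below (the `ToyMoE` and `Experts` modules here are hypothetical stand-ins; the real test uses torchtitan's llama4 MoE block):

import torch
from torch import nn

from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.quantization.quant_api import quantize_

# Hypothetical stand-in for an MoE block; in the new test this is torchtitan's
# llama4 MoE, whose routed expert weights live under a submodule named "experts".
class Experts(nn.Module):
    def __init__(self, num_experts=8, dim=256):
        super().__init__()
        self.w1 = nn.Parameter(torch.randn(num_experts, dim, dim))

class ToyMoE(nn.Module):
    def __init__(self):
        super().__init__()
        self.experts = Experts()
        self.router = nn.Linear(256, 8)

model = ToyMoE()

# Swap only params under modules whose FQN contains "experts"; grouped-mm calls
# on them are then routed through the differentiable float8 _scaled_grouped_mm.
quantize_(
    model,
    config=MoETrainingConfig(),
    filter_fn=lambda mod, fqn: "experts" in fqn,
)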

test/prototype/scaled_grouped_mm/test_kernels.py renamed to test/prototype/moe_training/test_kernels.py

Lines changed: 2 additions & 2 deletions
@@ -19,11 +19,11 @@
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
 
-from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
+from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_col_major_jagged_colwise_scales,
     triton_fp8_row_major_jagged_rowwise_scales,
 )
-from torchao.prototype.scaled_grouped_mm.utils import (
+from torchao.prototype.moe_training.utils import (
     _is_column_major,
     _to_2d_jagged_float8_tensor_colwise,
     _to_2d_jagged_float8_tensor_rowwise,

test/prototype/scaled_grouped_mm/test_scaled_grouped_mm.py renamed to test/prototype/moe_training/test_scaled_grouped_mm.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@
 from torchao.float8.float8_linear import matmul_with_hp_or_float8_args
 from torchao.float8.float8_tensor import LinearMMConfig
 from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
-from torchao.prototype.scaled_grouped_mm.scaled_grouped_mm import (
+from torchao.prototype.moe_training.scaled_grouped_mm import (
     _scaled_grouped_mm,
 )
 from torchao.testing.utils import skip_if_rocm

New test file (under test/prototype/moe_training/)

Lines changed: 133 additions & 0 deletions
import pytest
import torch
from torch import nn
from torch.nn import functional as F

from torchao.float8.float8_utils import compute_error
from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
from torchao.quantization.quant_api import quantize_

try:
    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
    from torchtitan.experiments.llama4.model.moe import MoE
except ImportError:
    import warnings

    warnings.warn("torchtitan not installed, skipping MoE tests.")
    pytest.skip(allow_module_level=True)


@pytest.mark.parametrize(
    "target_fqns",
    [
        ["experts"],
        ["does.not.exist"],
    ],
)
def test_moe_float8_training(target_fqns: list[str]):
    model_args = TransformerModelArgs(
        moe_enabled=True,
        num_experts=8,
        dim=256,
    )
    init_std = 0.02
    device = torch.device("cuda")

    # reference bf16 MoE
    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
    torch.manual_seed(42)
    ref_model.init_weights(init_std, device)

    # target MoE for testing conversion
    model = MoE(model_args).to(torch.bfloat16).cuda()
    torch.manual_seed(42)
    model.init_weights(init_std, device)

    # assert starting params are identical for both models
    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
        assert torch.equal(param1, param2)

    # convert MoE to float8 training
    def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
        for target_fqn in target_fqns:
            if target_fqn in cur_fqn:
                return True
        return False

    # quantize test model
    config = MoETrainingConfig()
    quantize_(model, config=config, filter_fn=moe_module_filter_fn)

    # validate that only the experts were converted
    _validate_model_conversion(
        model,
        target_fqns=target_fqns,
    )

    # inputs
    batch, seq, dim = 8, 2048, 256
    ref_x = torch.randn(
        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
    )
    x = ref_x.detach().clone().requires_grad_(True)

    # forward pass
    ref_out = ref_model(ref_x)
    out = model(x)

    # validate output
    out_sqnr = compute_error(out, ref_out)
    assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}."

    # compute loss
    labels = torch.ones_like(ref_out)
    ref_loss = F.mse_loss(ref_out, labels)
    out_loss = F.mse_loss(out, labels)

    # backward pass
    ref_loss.backward()
    out_loss.backward()

    # validate input gradient
    input_grad_sqnr = compute_error(x.grad, ref_x.grad)
    assert input_grad_sqnr.item() >= 30.0, (
        f"SQNR must be >= 30.0, got {input_grad_sqnr.item()}."
    )

    # validate param gradients
    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
        param_grad_sqnr = compute_error(param1.grad, param2.grad)
        assert param_grad_sqnr.item() >= 25.0, (
            f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}."
        )


def _validate_model_conversion(
    root_module: nn.Module,
    target_fqns: list[str],
):
    def _recursive_validate(
        module: nn.Module,
        cur_fqn: str,
    ):
        is_allowed_module = cur_fqn in target_fqns

        # check current module params
        for param_name, param in module.named_parameters(recurse=False):
            is_converted_type = isinstance(param, ScaledGroupedMMTensor)
            if is_converted_type:
                assert is_allowed_module, (
                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
                )
            if not is_allowed_module:
                assert not is_converted_type, (
                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
                )

        # recursively check child modules
        for child_name, child_module in module.named_children():
            child_fqn = f"{cur_fqn}.{child_name}" if cur_fqn else child_name
            _recursive_validate(child_module, child_fqn)

    _recursive_validate(root_module, "")
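
For reference, the SQNR thresholds above are in dB: `compute_error` in torchao computes 20·log10(‖ref‖ / ‖ref − test‖). A standalone sketch of the metric (assuming that standard definition):

import torch

def sqnr_db(ref: torch.Tensor, test: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB: higher means closer to ref.
    return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - test))

ref = torch.randn(4096)
test = ref + 0.03 * torch.randn(4096)  # ~3% relative error
print(sqnr_db(ref, test))  # ~30 dB, i.e. error energy is ~0.1% of signal energy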

New file: torchao/prototype/moe_training/__init__.py

Lines changed: 3 additions & 0 deletions

from torchao.prototype.moe_training.scaled_grouped_mm import (
    _scaled_grouped_mm as _scaled_grouped_mm,
)

torchao/prototype/scaled_grouped_mm/benchmarks/benchmark_kernels.py renamed to torchao/prototype/moe_training/benchmarks/benchmark_kernels.py

Lines changed: 2 additions & 2 deletions
@@ -14,11 +14,11 @@
 from tabulate import tabulate
 from tqdm import tqdm
 
-from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
+from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_col_major_jagged_colwise_scales,
     triton_fp8_row_major_jagged_rowwise_scales,
 )
-from torchao.prototype.scaled_grouped_mm.utils import (
+from torchao.prototype.moe_training.utils import (
     _to_2d_jagged_float8_tensor_colwise,
     _to_2d_jagged_float8_tensor_rowwise,
 )

torchao/prototype/scaled_grouped_mm/benchmarks/benchmark_scaled_grouped_mm.py renamed to torchao/prototype/moe_training/benchmarks/benchmark_scaled_grouped_mm.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@
 from tabulate import tabulate
 from tqdm import tqdm
 
-from torchao.prototype.scaled_grouped_mm import _scaled_grouped_mm
+from torchao.prototype.moe_training import _scaled_grouped_mm
 
 device = torch.device("cuda")

New file: torchao/prototype/moe_training/conversion_utils.py

Lines changed: 94 additions & 0 deletions
from typing import Callable, Optional

from torch import nn

from torchao.core.config import AOBaseConfig
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
from torchao.quantization.transform_module import (
    register_quantize_module_handler,
)


class MoETrainingConfig(AOBaseConfig):
    pass


@register_quantize_module_handler(MoETrainingConfig)
def _moe_training_transform(
    module: nn.Module,
    config: MoETrainingConfig,
) -> nn.Module:
    """
    Swaps each `torch.nn.Parameter` data tensor with a ScaledGroupedMMTensor.

    Args:
        module: Module to modify.
        config: MoETrainingConfig which defines how to perform the MoE training transform.

    Returns:
        nn.Module: The modified module with swapped parameters.
    """
    out = swap_params(module)
    return out


def swap_params(
    module: nn.Module,
    *,
    module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
) -> nn.Module:
    """
    Recurses through the nn.Module, swapping the data tensor of each
    nn.Parameter with a ScaledGroupedMMTensor. Only applies to modules that
    pass the module_filter_fn, if specified.

    Args:
        module: Module to modify.
        module_filter_fn: If specified, only parameters on modules that pass
            the filter function will be swapped. The inputs to the filter
            function are the module instance and the FQN.

    Returns:
        nn.Module: The modified module with swapped parameters.
    """
    if isinstance(module, nn.Parameter) and (
        module_filter_fn is None or module_filter_fn(module, "")
    ):
        if len(list(module.children())) > 0:
            raise AssertionError(
                f"Does not support a root nn.Parameter with children: {module}"
            )
        if not isinstance(module.data, ScaledGroupedMMTensor):
            new_data = ScaledGroupedMMTensor(module.data)
            return nn.Parameter(new_data, requires_grad=module.requires_grad)
        return module

    root_module = module

    def post_order_traversal(
        module: nn.Module,
        cur_fqn: Optional[str] = None,
        parent_module: Optional[nn.Module] = None,
    ):
        if cur_fqn is None:
            cur_fqn = ""

        for child_module_name, child_module in module.named_children():
            if cur_fqn == "":
                new_fqn = child_module_name
            else:
                new_fqn = f"{cur_fqn}.{child_module_name}"

            post_order_traversal(child_module, new_fqn, module)

        if module_filter_fn is None or module_filter_fn(module, cur_fqn):
            for param_name, param in module.named_parameters(recurse=False):
                if not isinstance(param.data, ScaledGroupedMMTensor):
                    new_param = nn.Parameter(
                        ScaledGroupedMMTensor(param), requires_grad=param.requires_grad
                    )
                    setattr(module, param_name, new_param)
                    print(f"Swapped {cur_fqn}.{param_name} to ScaledGroupedMMTensor")

    post_order_traversal(root_module)
    return root_module
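
`swap_params` can also be invoked directly rather than through `quantize_`. A minimal sketch on a toy module (the `Experts` module is a hypothetical stand-in for an MoE experts submodule holding 3D expert weights):

import torch
from torch import nn

from torchao.prototype.moe_training.conversion_utils import swap_params
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

class Experts(nn.Module):  # hypothetical stand-in for an MoE experts module
    def __init__(self, num_experts=4, dim=16):
        super().__init__()
        self.w1 = nn.Parameter(torch.randn(num_experts, dim, dim))

root = nn.ModuleDict({"experts": Experts(), "router": nn.Linear(16, 4)})

# Only swap params on modules whose FQN contains "experts".
swap_params(root, module_filter_fn=lambda mod, fqn: "experts" in fqn)

print(type(root["experts"].w1))     # ScaledGroupedMMTensor (still a parameter)
print(type(root["router"].weight))  # untouched nn.Parameter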
torchao/prototype/scaled_grouped_mm/kernels/__init__.py renamed to torchao/prototype/moe_training/kernels/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -1,6 +1,6 @@
-from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
+from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_col_major_jagged_colwise_scales as triton_fp8_col_major_jagged_colwise_scales,
 )
-from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
+from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_row_major_jagged_rowwise_scales as triton_fp8_row_major_jagged_rowwise_scales,
 )

torchao/prototype/scaled_grouped_mm/kernels/jagged_float8_scales.py renamed to torchao/prototype/moe_training/kernels/jagged_float8_scales.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
 import triton
 import triton.language as tl
 
-from torchao.prototype.scaled_grouped_mm.utils import _is_column_major
+from torchao.prototype.moe_training.utils import _is_column_major
 
 EPS = 1e-12

torchao/prototype/scaled_grouped_mm/scaled_grouped_mm.py renamed to torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 6 additions & 3 deletions
@@ -10,11 +10,11 @@
 
 from torchao.float8.config import ScalingGranularity
 from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
-from torchao.prototype.scaled_grouped_mm.kernels import (
+from torchao.prototype.moe_training.kernels import (
     triton_fp8_col_major_jagged_colwise_scales,
     triton_fp8_row_major_jagged_rowwise_scales,
 )
-from torchao.prototype.scaled_grouped_mm.utils import _is_column_major
+from torchao.prototype.moe_training.utils import _is_column_major
 
 
 def _scaled_grouped_mm(
@@ -83,7 +83,10 @@ def forward(
         assert not _is_column_major(A), "A must be row-major"
 
         # Due to hardware requirements, the right operand in a scaled grouped GEMM must be column-major.
-        assert _is_column_major(B_t), "B must be column-major"
+        if not _is_column_major(B_t):
+            # FSDP will complain if B_t (weights) is not contiguous, so we can't require B_t to be column-major.
+            # TODO: figure out a better solution than transposing on each forward pass.
+            B_t = B_t.transpose(-2, -1).contiguous().transpose(-2, -1)
 
         # Convert high precision input tensor to float8, row-major for left operand of grouped GEMM.
         # A shape: (M, K)
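
The transpose/contiguous/transpose idiom above rewrites only the memory layout, not the logical shape: the result has stride 1 along dim -2, which is what `_is_column_major` checks for, while FSDP still sees the original row-major parameter. A standalone stride check illustrating the trick:

import torch

B_t = torch.randn(4, 16, 32)  # row-major 3D tensor
print(B_t.stride())           # (512, 32, 1): last dim is contiguous

# Same logical shape, but dim -2 now has stride 1, i.e. column-major layout.
B_t_col = B_t.transpose(-2, -1).contiguous().transpose(-2, -1)
print(B_t_col.stride())            # (512, 1, 16)
print(torch.equal(B_t, B_t_col))   # True: values unchanged, only layout differs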

New file: torchao/prototype/moe_training/tensor.py

Lines changed: 35 additions & 0 deletions
import torch

from torchao.prototype.moe_training import _scaled_grouped_mm


class ScaledGroupedMMTensor(torch.Tensor):
    """
    ScaledGroupedMMTensor is a simple tensor subclass that wraps a regular tensor
    and overrides the torch._grouped_mm op by dispatching to the
    differentiable _scaled_grouped_mm autograd function.
    """

    grouped_mm_func_name = "_grouped_mm"
    offs_arg_name = "offs"

    def __init__(self, data: torch.Tensor):
        self._data = data

    @classmethod
    def __torch_function__(cls, func, types, args, kwargs={}):
        if func.__name__ == cls.grouped_mm_func_name:
            # Use torchao scaled grouped mm with dynamic quant for
            # "2d x 3d with offsets" case (used for routed experts).
            # Otherwise, fall back to regular grouped mm.
            #
            # TODO: support "3d x 3d without offsets" case, which is
            # used for shared experts. This is basically the grouped_mm
            # kernel handling a bmm.
            A, B = args[0], args[1]
            A_is_2d = A.dim() == 2
            B_is_3d = B.dim() == 3
            has_offs = kwargs.get(cls.offs_arg_name) is not None
            if A_is_2d and B_is_3d and has_offs:
                return _scaled_grouped_mm(*args, **kwargs)
        return super().__torch_function__(func, types, args, kwargs)
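
To see the interception end to end: wrapping one operand in `ScaledGroupedMMTensor` is enough for `__torch_function__` to see the call and reroute it. A minimal sketch (assumes a PyTorch build that exposes the prototype `torch._grouped_mm` op and a GPU that supports it; shapes follow the "2d x 3d with offsets" routed-experts case):

import torch
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

num_experts, dim, out_dim, tokens = 4, 64, 128, 256
A = torch.randn(tokens, dim, dtype=torch.bfloat16, device="cuda")  # 2D input
B = ScaledGroupedMMTensor(
    torch.randn(num_experts, dim, out_dim, dtype=torch.bfloat16, device="cuda")
)  # 3D expert weights, wrapped so __torch_function__ sees the call

# Hypothetical even token split: cumulative end offsets per expert group.
offs = torch.tensor([64, 128, 192, 256], dtype=torch.int32, device="cuda")

# Because one operand is a ScaledGroupedMMTensor, this call is rerouted to the
# differentiable float8 _scaled_grouped_mm instead of the bf16 grouped mm.
out = torch._grouped_mm(A, B, offs=offs)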

torchao/prototype/scaled_grouped_mm/__init__.py

Lines changed: 0 additions & 3 deletions
This file was deleted.
