
Commit eeabd9c

float8 moe training conversion API prototype
stack-info: PR: #2275, branch: danielvegamyhre/stack/1
migrate to quantize and add test
work on moe training test
1 parent d963a88 commit eeabd9c
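
This commit hooks MoE float8 training conversion into the `quantize_` API: a `MoETrainingConfig` handler swaps the data tensor of each selected module's `nn.Parameter` for a `ScaledGroupedMMTensor`, which then routes `torch._grouped_mm` calls through the differentiable `_scaled_grouped_mm`. A minimal usage sketch mirroring the test added below; the `ToyExperts` module and the filter FQN are illustrative placeholders, not part of this commit:

import torch
from torch import nn

from torchao.quantization.quant_api import quantize_
from torchao.prototype.scaled_grouped_mm.conversion_utils import MoETrainingConfig


# Hypothetical stand-in for an MoE experts module: 3D bf16 expert weights consumed by grouped GEMMs.
class ToyExperts(nn.Module):
    def __init__(self, num_experts: int, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Parameter(torch.randn(num_experts, dim, hidden_dim, dtype=torch.bfloat16))
        self.w2 = nn.Parameter(torch.randn(num_experts, hidden_dim, dim, dtype=torch.bfloat16))


model = nn.ModuleDict({"experts": ToyExperts(num_experts=2, dim=64, hidden_dim=128)}).cuda()


def moe_filter(mod: nn.Module, fqn: str) -> bool:
    # Convert only the experts submodule; the router/gate (if any) keeps plain bf16 weights.
    return "experts" in fqn


quantize_(model, config=MoETrainingConfig(), filter_fn=moe_filter)
# The experts' parameters now wrap ScaledGroupedMMTensor, so torch._grouped_mm calls that use
# them are dispatched to the differentiable float8 _scaled_grouped_mm during training.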

5 files changed: +313 −4 lines
Lines changed: 178 additions & 0 deletions
@@ -0,0 +1,178 @@
import pytest

import torch
from torch import nn
from torch.nn import functional as F

from torchao.quantization.quant_api import quantize_
from torchao.prototype.scaled_grouped_mm.conversion_utils import MoETrainingConfig
from torchao.float8.float8_utils import compute_error


# model definition from torchtitan:
# https://github.com/pytorch/torchtitan/blob/768cde131105bde624160029d808e94649faf0f4/torchtitan/experiments/llama4/model/moe.py#L14
class GroupedExperts(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        num_experts: int,
        use_grouped_mm: bool,
    ):
        super().__init__()
        self.num_experts = num_experts
        self.w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
        self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
        self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
        self.use_grouped_mm = use_grouped_mm
        self.init_weights()

    def forward(
        self,
        x: torch.Tensor,
        num_local_tokens_per_expert: torch.Tensor | list[int] | None = None,
    ) -> torch.Tensor:
        # TODO: keeping this for-loop implementation for comparison
        # and readability, will remove later
        if not self.use_grouped_mm:
            if num_local_tokens_per_expert is not None:
                # a tuple of tensors indexed by experts,
                # each with shape (tokens_per_expert(varying), dim)
                x = torch.split(
                    x,
                    split_size_or_sections=num_local_tokens_per_expert,
                    dim=0,
                )
                out_experts_splits = []
                for expert_idx, x_expert in enumerate(x):
                    w1, w2, w3 = (
                        self.w1[expert_idx],
                        self.w2[expert_idx],
                        self.w3[expert_idx],
                    )
                    h = F.silu(torch.matmul(x_expert, w1))
                    h = h * torch.matmul(x_expert, w3)
                    h = torch.matmul(h, w2)
                    # h shape (tokens_per_expert(varying), dim)
                    out_experts_splits.append(h)
                out = torch.cat(out_experts_splits, dim=0)
            else:
                # x shape (num_experts, tokens_per_expert, dim)
                h = F.silu(torch.bmm(x, self.w1))
                h = h * torch.bmm(x, self.w3)
                # out shape (num_experts, tokens_per_expert, dim)
                out = torch.bmm(h, self.w2)

            return out

        # grouped mm implementation
        if num_local_tokens_per_expert is not None:
            # https://github.com/pytorch/pytorch/pull/150374
            # NOTE: torch._grouped_mm requires bf16 dtypes
            # and shapes to be multiples of 8
            offsets = torch.cumsum(
                num_local_tokens_per_expert, dim=0, dtype=torch.int32
            )
            # grouped mm between a 2D tensor and a 3D tensor
            assert x.dim() == 2
        else:
            offsets = None
            # fall back to regular bmm between 3D tensors
            assert x.dim() == 3

        assert (
            x.dtype == self.w1.dtype == self.w2.dtype == self.w3.dtype == torch.bfloat16
        ), "torch._grouped_mm only supports bf16 dtypes"

        h = F.silu(torch._grouped_mm(x, self.w1, offs=offsets))
        h = h * torch._grouped_mm(x, self.w3, offs=offsets)
        out = torch._grouped_mm(h, self.w2, offs=offsets)

        return out

    def init_weights(self, init_std: float = 0.02):
        nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)
        nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std)
        nn.init.trunc_normal_(self.w3, mean=0.0, std=init_std)


class MoE(nn.Module):
    """Toy MoE for testing. Not a complete implementation."""

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        num_experts: int,
        use_grouped_mm: bool,
    ):
        super().__init__()
        self.gate = nn.Linear(dim, num_experts)
        self.experts = GroupedExperts(
            dim,
            hidden_dim,
            num_experts,
            use_grouped_mm,
        )
        self.init_weights()

    def forward(
        self, x: torch.Tensor, num_local_tokens_per_expert: torch.Tensor
    ) -> torch.Tensor:
        return self.experts(x, num_local_tokens_per_expert=num_local_tokens_per_expert)

    def init_weights(self, init_std: float = 0.02):
        nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std)


@pytest.mark.parametrize(
    "model_class,target_fqns",
    [
        # (MoE, ["experts"]),  # calling quantize_ on a higher-level module
        (GroupedExperts, [""]),  # calling quantize_ on the experts directly
    ],
)
def test_moe_float8_training(model_class: nn.Module, target_fqns: list[str]):
    batch, seq, dim = 1, 8192, 4096
    num_experts, top_k = 2, 1

    def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
        for target_fqn in target_fqns:
            if target_fqn in cur_fqn:
                return True
        return False

    # define MoE layer
    torch.manual_seed(42)
    model = model_class(
        dim=dim, hidden_dim=4 * dim, num_experts=num_experts, use_grouped_mm=True
    ).to(torch.bfloat16).cuda()
    torch.manual_seed(42)
    ref_model = model_class(
        dim=dim, hidden_dim=4 * dim, num_experts=num_experts, use_grouped_mm=True
    ).to(torch.bfloat16).cuda()
    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
        assert torch.equal(param1, param2)

    # convert MoE to float8 training
    config = MoETrainingConfig()
    quantize_(model, config=config, filter_fn=moe_module_filter_fn)

    # inputs (created directly on CUDA so they remain leaf tensors and .grad is populated)
    torch.manual_seed(42)
    x = torch.randn(
        batch * seq * top_k, dim, dtype=torch.bfloat16, device="cuda", requires_grad=True
    )
    torch.manual_seed(42)
    ref_x = torch.randn(
        batch * seq * top_k, dim, dtype=torch.bfloat16, device="cuda", requires_grad=True
    )

    # offsets
    num_tokens_per_expert = (batch * seq * top_k) // num_experts
    tokens_per_expert_tensor = (
        torch.tensor([num_tokens_per_expert], dtype=torch.int32).repeat(num_experts).cuda()
    )
    ref_tokens_per_expert_tensor = tokens_per_expert_tensor.clone()

    # forward pass
    out = model(x, num_local_tokens_per_expert=tokens_per_expert_tensor)
    ref_out = ref_model(ref_x, num_local_tokens_per_expert=ref_tokens_per_expert_tensor)

    # validate SQNR is acceptable.
    # a single fp8 gemm uses SQNR >= 25.0 for testing, so for a full MoE layer
    # we'll use a slightly lower threshold.
    out_sqnr = compute_error(out, ref_out)
    assert out_sqnr.item() >= 23.0, f"SQNR must be >= 23.0, got {out_sqnr.item()}."

    # backward pass
    out.sum().backward()
    ref_out.sum().backward()

    # validate input gradients
    assert torch.allclose(x.grad, ref_x.grad)

    # validate param gradients
    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
        assert torch.allclose(param1.grad, param2.grad)
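
For reference, a small worked example of how the per-expert token counts in the test above become the cumulative offsets consumed by `torch._grouped_mm` (values follow the test's shapes; illustrative only):

import torch

# 1 * 8192 * 1 tokens split evenly across 2 experts, as in the test above.
batch, seq, top_k, num_experts = 1, 8192, 1, 2
num_tokens_per_expert = (batch * seq * top_k) // num_experts  # 4096

tokens_per_expert = torch.tensor([num_tokens_per_expert], dtype=torch.int32).repeat(num_experts)
# tensor([4096, 4096], dtype=torch.int32)

# GroupedExperts.forward turns the counts into cumulative offsets for torch._grouped_mm:
offsets = torch.cumsum(tokens_per_expert, dim=0, dtype=torch.int32)
# tensor([4096, 8192], dtype=torch.int32)

# Expert i consumes rows [offsets[i-1], offsets[i]) of the flattened (tokens, dim) input:
# rows 0:4096 go to expert 0 and rows 4096:8192 to expert 1.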

test/prototype/scaled_grouped_mm/test_scaled_grouped_mm.py

Lines changed: 3 additions & 3 deletions
@@ -23,7 +23,7 @@
     Float8LinearConfig,
     Float8LinearRecipeName,
 )
-from torchao.float8.float8_linear import matmul_with_hp_or_float8_args
+from torchao.float8.float8_linear import _matmul_with_hp_or_float8_args
 from torchao.float8.float8_tensor import LinearMMConfig
 from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
 from torchao.prototype.scaled_grouped_mm.scaled_grouped_mm import (
@@ -183,7 +183,7 @@ def compute_reference_forward(
 
     # Validate each actual result group from the _scaled_grouped_mm is equal to:
     # 1. A manual _scaled_mm for the group.
-    # 2. A matmul_with_hp_or_float8_args for the group (which is differentiable, and thus used to validate gradients).
+    # 2. A _matmul_with_hp_or_float8_args for the group (which is differentiable, and thus used to validate gradients).
     outputs = []
     list1 = list(zip(A_list_fp8, B_t_fp8, A_scale_list, B_t_scales, result_list))
     list2 = list(zip(A_list, B_t, result_list))
@@ -199,7 +199,7 @@ def compute_reference_forward(
             use_fast_accum=float8_config.gemm_config_output.use_fast_accum,
         )
         a2, b2, result2 = list2[i]
-        ref_group_result2 = matmul_with_hp_or_float8_args.apply(
+        ref_group_result2 = _matmul_with_hp_or_float8_args.apply(
            a2,
            b2,
            LinearMMConfig(),
torchao/prototype/scaled_grouped_mm/conversion_utils.py

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
from typing import Callable, Optional

from torch import nn

from torchao.core.config import AOBaseConfig
from torchao.quantization.transform_module import (
    register_quantize_module_handler,
)
from torchao.prototype.scaled_grouped_mm.tensor import ScaledGroupedMMTensor


class MoETrainingConfig(AOBaseConfig):
    pass


@register_quantize_module_handler(MoETrainingConfig)
def _moe_training_transform(
    module: nn.Module,
    config: MoETrainingConfig,
) -> nn.Module:
    """
    Swaps the data tensor of each `torch.nn.Parameter` with a ScaledGroupedMMTensor.

    Args:
        module: Module to modify.
        config: MoETrainingConfig which defines how to perform the MoE training transform.

    Returns:
        nn.Module: The modified module with swapped parameters.
    """
    out = swap_params(module)
    return out


def swap_params(
    module: nn.Module,
    *,
    module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
) -> nn.Module:
    """
    Recurses through the nn.Module, swapping the data tensor of each
    nn.Parameter with a ScaledGroupedMMTensor. Only applies to modules
    that pass the module_filter_fn, if specified.

    Args:
        module: Module to modify.
        module_filter_fn: If specified, only the `torch.nn.Parameter` instances of
            modules that pass the filter function will be swapped. The inputs to the
            filter function are the module instance and the FQN.

    Returns:
        nn.Module: The modified module with swapped parameters.
    """
    if isinstance(module, nn.Parameter) and (
        module_filter_fn is None or module_filter_fn(module, "")
    ):
        if len(list(module.children())) > 0:
            raise AssertionError(
                f"Does not support a root nn.Parameter with children: {module}"
            )
        if not isinstance(module.data, ScaledGroupedMMTensor):
            new_data = ScaledGroupedMMTensor(module.data)
            return nn.Parameter(new_data, requires_grad=module.requires_grad)
        return module

    root_module = module

    def post_order_traversal(
        module: nn.Module,
        cur_fqn: Optional[str] = None,
        parent_module: Optional[nn.Module] = None,
    ):
        if cur_fqn is None:
            cur_fqn = ""

        for child_module_name, child_module in module.named_children():
            if cur_fqn == "":
                new_fqn = child_module_name
            else:
                new_fqn = f"{cur_fqn}.{child_module_name}"

            post_order_traversal(child_module, new_fqn, module)

        if module_filter_fn is None or module_filter_fn(module, cur_fqn):
            for param_name, param in module.named_parameters(recurse=False):
                if not isinstance(param.data, ScaledGroupedMMTensor):
                    new_param = nn.Parameter(
                        ScaledGroupedMMTensor(param), requires_grad=param.requires_grad
                    )
                    setattr(module, param_name, new_param)
                    print(f"Swapped {cur_fqn}.{param_name} to ScaledGroupedMMTensor")

    post_order_traversal(root_module)
    return root_module
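
To make the `module_filter_fn` semantics concrete: `post_order_traversal` builds dot-separated FQNs and calls the filter on each module. A small sketch of the substring-style filter used in the new test (the FQN values here are illustrative):

def moe_module_filter_fn(mod, cur_fqn: str, target_fqns=("experts",)) -> bool:
    # Same shape as the filter in the new test: match if any target FQN is a substring of cur_fqn.
    return any(target_fqn in cur_fqn for target_fqn in target_fqns)


print(moe_module_filter_fn(None, "experts"))   # True  -> this module's params are swapped
print(moe_module_filter_fn(None, "gate"))      # False -> gate weights are left untouched
print(moe_module_filter_fn(None, "", ("",)))   # True  -> an empty target matches every FQN,
                                               #          used when converting the experts module directly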

torchao/prototype/scaled_grouped_mm/scaled_grouped_mm.py

Lines changed: 4 additions & 1 deletion
@@ -83,7 +83,10 @@ def forward(
         assert not _is_column_major(A), "A must be row-major"
 
         # Due to hardware requirements, the right operand in a scaled grouped GEMM must be column-major.
-        assert _is_column_major(B_t), "B must be column-major"
+        if not _is_column_major(B_t):
+            # FSDP will complain if B_t (weights) is not contiguous, so we can't require B_t to be column-major.
+            # TODO: figure out a better solution than transposing on each forward pass.
+            B_t = B_t.transpose(-2, -1).contiguous().transpose(-2, -1)
 
         # Convert high precision input tensor to float8, row-major for left operand of grouped GEMM.
         # A shape: (M, K)
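
The double transpose above is the usual trick for getting a column-major copy without changing the logical shape; a short illustrative sketch of the stride effect (shapes are arbitrary):

import torch

# B_t has logical shape (num_groups, K, N) and is row-major (contiguous), e.g. after FSDP gathers it.
B_t = torch.randn(4, 16, 32)
print(B_t.stride())  # (512, 32, 1): innermost dim has stride 1 -> row-major

# transpose -> contiguous -> transpose back: same logical shape, but now dim -2 has stride 1,
# i.e. each (K, N) matrix is laid out column-major, as the right operand of the grouped GEMM requires.
B_t_col_major = B_t.transpose(-2, -1).contiguous().transpose(-2, -1)
print(B_t_col_major.shape)    # torch.Size([4, 16, 32]) - unchanged
print(B_t_col_major.stride()) # (512, 1, 16): column-major per matrix
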
torchao/prototype/scaled_grouped_mm/tensor.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
import torch

from torchao.prototype.scaled_grouped_mm import _scaled_grouped_mm


class ScaledGroupedMMTensor(torch.Tensor):
    """
    ScaledGroupedMMTensor is a simple tensor subclass that wraps a regular tensor
    and overrides the torch._grouped_mm op by dispatching to the
    differentiable _scaled_grouped_mm autograd function.
    """

    grouped_mm_func_name = "_grouped_mm"
    offs_arg_name = "offs"

    def __init__(self, data: torch.Tensor):
        self._data = data

    @classmethod
    def __torch_function__(cls, func, types, args, kwargs={}):
        if func.__name__ == cls.grouped_mm_func_name:
            # Use torchao scaled grouped mm with dynamic quant for the
            # "2d x 3d with offsets" case (used for routed experts).
            # Otherwise, fall back to regular grouped mm.
            #
            # TODO: support the "3d x 3d without offsets" case, which is
            # used for shared experts. This is basically the grouped_mm
            # kernel handling a bmm.
            A, B = args[0], args[1]
            A_is_2d = A.dim() == 2
            B_is_3d = B.dim() == 3
            has_offs = kwargs.get(cls.offs_arg_name) is not None
            if A_is_2d and B_is_3d and has_offs:
                return _scaled_grouped_mm(*args, **kwargs)
        return super().__torch_function__(func, types, args, kwargs)
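
For clarity, the guard in `__torch_function__` only routes the "2D activations x 3D expert weights with offsets" pattern to the float8 path; everything else falls back to the default `torch._grouped_mm`. A plain-tensor sketch of the same checks (shapes are illustrative):

import torch

A = torch.randn(256, 128, dtype=torch.bfloat16)     # (tokens, dim): routed-experts input
B = torch.randn(2, 128, 64, dtype=torch.bfloat16)   # (num_experts, dim, hidden): expert weights
offs = torch.tensor([128, 256], dtype=torch.int32)  # cumulative tokens per expert

# Same conditions the subclass checks before dispatching to _scaled_grouped_mm:
use_float8_path = A.dim() == 2 and B.dim() == 3 and offs is not None
print(use_float8_path)  # True -> would be routed to the differentiable float8 _scaled_grouped_mm

# A "3D x 3D without offsets" call (the shared-experts bmm case) fails the check and would
# fall through to the regular torch._grouped_mm, per the TODO above.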
