
Commit a71744c

float8 moe training conversion API prototype
stack-info: PR: #2275, branch: danielvegamyhre/stack/1
1 parent 1017c7e commit a71744c

File tree: 3 files changed, +130 -1 lines changed
Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
from typing import Callable, Optional

from torch import nn

from torchao.prototype.scaled_grouped_mm.tensor import ScaledGroupedMMTensor


def swap_params(
    module: nn.Module,
    *,
    module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
) -> nn.Module:
    """
    Recurses through the nn.Module, recursively swapping the data tensor of
    each nn.Parameter with a ScaledGroupedMMTensor. Only applies if the module
    passes the module_filter_fn, if specified.

    Args:
        module: Module to modify.
        module_filter_fn: If specified, only the `torch.nn.Parameter` subclasses
            that pass the filter function will be swapped. The inputs to the
            filter function are the module instance and the FQN.

    Returns:
        nn.Module: The modified module with swapped parameters.
    """
    if isinstance(module, nn.Parameter) and (
        module_filter_fn is None or module_filter_fn(module, "")
    ):
        if len(list(module.children())) > 0:
            raise AssertionError(
                f"Does not support a root nn.Parameter with children: {module}"
            )
        if not isinstance(module.data, ScaledGroupedMMTensor):
            new_data = ScaledGroupedMMTensor(module.data)
            return nn.Parameter(new_data, requires_grad=module.requires_grad)
        return module

    root_module = module

    def post_order_traversal(
        module: nn.Module,
        cur_fqn: Optional[str] = None,
        parent_module: Optional[nn.Module] = None,
    ):
        if cur_fqn is None:
            cur_fqn = ""

        for child_module_name, child_module in module.named_children():
            if cur_fqn == "":
                new_fqn = child_module_name
            else:
                new_fqn = f"{cur_fqn}.{child_module_name}"

            post_order_traversal(child_module, new_fqn, module)

        if module_filter_fn is None or module_filter_fn(module, cur_fqn):
            for param_name, param in module.named_parameters(recurse=False):
                if not isinstance(param.data, ScaledGroupedMMTensor):
                    new_param = nn.Parameter(
                        ScaledGroupedMMTensor(param), requires_grad=param.requires_grad
                    )
                    setattr(module, param_name, new_param)
                    print(f"Swapped {cur_fqn}.{param_name} to ScaledGroupedMMTensor")

    post_order_traversal(root_module)
    return root_module


def convert_moe_to_float8_training(
    module: nn.Module,
    *,
    module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
) -> nn.Module:
    """
    Swaps the data tensor of each `torch.nn.Parameter` with a ScaledGroupedMMTensor.

    Args:
        module: Module to modify.
        module_filter_fn: If specified, only the `torch.nn.Parameter` instances of
            modules that pass the filter function will be swapped. The inputs to the
            filter function are the module instance and the FQN.

    Returns:
        nn.Module: The modified module with swapped parameters.
    """
    out = swap_params(
        module,
        module_filter_fn=module_filter_fn,
    )
    return out
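
A rough usage sketch of the conversion API above, assuming a toy routed-experts module. The diff view shows the new file's contents but not its name, so the import path used below is a hypothetical placeholder:

import torch
from torch import nn

# Hypothetical import path: the new conversion file's name is not shown in
# this diff view, so "conversion_utils" is a placeholder.
from torchao.prototype.scaled_grouped_mm.conversion_utils import (
    convert_moe_to_float8_training,
)


class ToyExperts(nn.Module):
    """Stand-in for a routed-experts module holding a 3d expert weight."""

    def __init__(self, num_experts: int = 4, dim: int = 64, hidden_dim: int = 128):
        super().__init__()
        self.w1 = nn.Parameter(torch.randn(num_experts, dim, hidden_dim))


model = nn.Sequential(nn.Linear(64, 64), ToyExperts())


# Only swap parameters owned by the expert module; leave the nn.Linear alone.
def expert_filter(module: nn.Module, fqn: str) -> bool:
    return isinstance(module, ToyExperts)


convert_moe_to_float8_training(model, module_filter_fn=expert_filter)

# Expect "1.w1" to be wrapped in ScaledGroupedMMTensor while "0.weight" and
# "0.bias" stay plain parameters.
for name, param in model.named_parameters():
    print(name, type(param).__name__)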

torchao/prototype/scaled_grouped_mm/scaled_grouped_mm.py

Lines changed: 4 additions & 1 deletion
@@ -83,7 +83,10 @@ def forward(
         assert not _is_column_major(A), "A must be row-major"
 
         # Due to hardware requirements, the right operand in a scaled grouped GEMM must be column-major.
-        assert _is_column_major(B_t), "B must be column-major"
+        if not _is_column_major(B_t):
+            # FSDP will complain if B_t (weights) is not contiguous, so we can't require B_t to be column-major.
+            # TODO: figure out a better solution than transposing on each forward pass.
+            B_t = B_t.transpose(-2, -1).contiguous().transpose(-2, -1)
 
         # Convert high precision input tensor to float8, row-major for left operand of grouped GEMM.
         # A shape: (M, K)
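
For reference, a small sketch (not from this commit) of what the double transpose accomplishes: it leaves the values untouched but rebuilds the memory layout so the last two dims are column-major, i.e. stride(-2) == 1, which is presumably what _is_column_major checks:

import torch

# Arbitrary (num_groups, K, N) weight shape; the values don't matter here.
B_t = torch.randn(4, 64, 128)
print(B_t.stride())  # (8192, 128, 1): row-major, stride(-1) == 1

# Transpose the last two dims, materialize the result, transpose back.
B_col = B_t.transpose(-2, -1).contiguous().transpose(-2, -1)
print(B_col.stride())           # (8192, 1, 64): stride(-2) == 1, i.e. column-major
print(torch.equal(B_t, B_col))  # True: same values, different memory layout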
torchao/prototype/scaled_grouped_mm/tensor.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
import torch

from torchao.prototype.scaled_grouped_mm import _scaled_grouped_mm


class ScaledGroupedMMTensor(torch.Tensor):
    """
    ScaledGroupedMMTensor is a simple tensor subclass that wraps a regular tensor
    and overrides the torch._grouped_mm op by dispatching to the
    differentiable _scaled_grouped_mm autograd function.
    """

    grouped_mm_func_name = "_grouped_mm"
    offs_arg_name = "offs"

    def __init__(self, data: torch.Tensor):
        self._data = data

    @classmethod
    def __torch_function__(cls, func, types, args, kwargs={}):
        if func.__name__ == cls.grouped_mm_func_name:
            # Use torchao scaled grouped mm with dynamic quant for
            # "2d x 3d with offsets" case (used for routed experts).
            # Otherwise, fall back to regular grouped mm.
            #
            # TODO: support "3d x 3d without offsets" case, which is
            # used for shared experts. This is basically the grouped_mm
            # kernel handling a bmm.
            A, B = args[0], args[1]
            A_is_2d = A.dim() == 2
            B_is_3d = B.dim() == 3
            has_offs = kwargs.get(cls.offs_arg_name) is not None
            if A_is_2d and B_is_3d and has_offs:
                return _scaled_grouped_mm(*args, **kwargs)
        return super().__torch_function__(func, types, args, kwargs)
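
For context, a hypothetical dispatch sketch showing how the subclass is expected to intercept the op for routed experts. torch._grouped_mm is a private op whose availability and hardware requirements depend on the PyTorch build, so the snippet is guarded and illustrative rather than guaranteed to run:

import torch

from torchao.prototype.scaled_grouped_mm.tensor import ScaledGroupedMMTensor

num_experts, total_tokens, dim, hidden_dim = 4, 256, 64, 128

# Only attempt the call on builds that expose the private op and have a GPU.
if hasattr(torch, "_grouped_mm") and torch.cuda.is_available():
    # Routed tokens (2d) and stacked expert weights (3d).
    A = torch.randn(total_tokens, dim, device="cuda", dtype=torch.bfloat16)
    B = ScaledGroupedMMTensor(
        torch.randn(num_experts, dim, hidden_dim, device="cuda", dtype=torch.bfloat16)
    )
    # End offset of each expert's contiguous token group along dim 0 of A.
    offs = torch.arange(
        total_tokens // num_experts,
        total_tokens + 1,
        total_tokens // num_experts,
        device="cuda",
        dtype=torch.int32,
    )
    # __torch_function__ sees func.__name__ == "_grouped_mm" with a 2d A,
    # a 3d B, and an offs kwarg, so it routes to _scaled_grouped_mm.
    out = torch._grouped_mm(A, B, offs=offs)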
