
Commit a10b3a0

float8 moe training conversion API prototype
stack-info: PR: #2275, branch: danielvegamyhre/stack/1
1 parent d963a88 commit a10b3a0

File tree: 3 files changed (+131, -1 lines changed)
Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
from typing import Callable, Optional

from torch import nn

from torchao.prototype.scaled_grouped_mm.tensor import ScaledGroupedMMTensor


def swap_params(
    module: nn.Module,
    *,
    module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
) -> nn.Module:
    """
    Recurses through the nn.Module, swapping the data tensor of each
    nn.Parameter with a ScaledGroupedMMTensor. Only applies to modules that
    pass the module_filter_fn, if one is specified.

    Args:
        module: Module to modify.
        module_filter_fn: If specified, only the `torch.nn.Parameter` instances of
            modules that pass the filter function will be swapped. The inputs to the
            filter function are the module instance and the FQN.

    Returns:
        nn.Module: The modified module with swapped parameters.
    """
    if isinstance(module, nn.Parameter) and (
        module_filter_fn is None or module_filter_fn(module, "")
    ):
        if len(list(module.children())) > 0:
            raise AssertionError(
                f"Does not support a root nn.Parameter with children: {module}"
            )
        if not isinstance(module.data, ScaledGroupedMMTensor):
            new_data = ScaledGroupedMMTensor(module.data)
            return nn.Parameter(new_data, requires_grad=module.requires_grad)
        return module

    root_module = module

    def post_order_traversal(
        module: nn.Module,
        cur_fqn: Optional[str] = None,
        parent_module: Optional[nn.Module] = None,
    ):
        if cur_fqn is None:
            cur_fqn = ""

        for child_module_name, child_module in module.named_children():
            if cur_fqn == "":
                new_fqn = child_module_name
            else:
                new_fqn = f"{cur_fqn}.{child_module_name}"

            post_order_traversal(child_module, new_fqn, module)

        if module_filter_fn is None or module_filter_fn(module, cur_fqn):
            for param_name, param in module.named_parameters(recurse=False):
                if not isinstance(param.data, ScaledGroupedMMTensor):
                    new_param = nn.Parameter(
                        ScaledGroupedMMTensor(param), requires_grad=param.requires_grad
                    )
                    setattr(module, param_name, new_param)
                    print(f"Swapped {cur_fqn}.{param_name} to ScaledGroupedMMTensor")

    post_order_traversal(root_module)
    return root_module


# TODO: migrate to quantize_, will need to define a config for this etc.
def convert_moe_to_float8_training(
    module: nn.Module,
    *,
    module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
) -> nn.Module:
    """
    Swaps each `torch.nn.Parameter`'s data tensor with a ScaledGroupedMMTensor.

    Args:
        module: Module to modify.
        module_filter_fn: If specified, only the `torch.nn.Parameter` instances of
            modules that pass the filter function will be swapped. The inputs to the
            filter function are the module instance and the FQN.

    Returns:
        nn.Module: The modified module with swapped parameters.
    """
    out = swap_params(
        module,
        module_filter_fn=module_filter_fn,
    )
    return out
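For context, here is a minimal usage sketch of the new conversion API. The ToyExperts/ToyModel modules and the "experts"-substring filter are illustrative assumptions, not part of this commit, and the import of convert_moe_to_float8_training is left out because the new file's path is not shown in this extract:

import torch
from torch import nn

# Hypothetical stand-in for an MoE expert bank: a 3D weight of shape
# (num_experts, dim, dim), the layout consumed by grouped GEMMs.
class ToyExperts(nn.Module):
    def __init__(self, num_experts: int = 4, dim: int = 16):
        super().__init__()
        self.w = nn.Parameter(torch.randn(num_experts, dim, dim))

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.experts = ToyExperts()
        self.norm = nn.LayerNorm(16)

def experts_only(module: nn.Module, fqn: str) -> bool:
    # Only swap parameters of submodules whose FQN mentions "experts".
    return "experts" in fqn

model = ToyModel()
model = convert_moe_to_float8_training(model, module_filter_fn=experts_only)
# Prints "Swapped experts.w to ScaledGroupedMMTensor"; model.norm is left untouched.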

torchao/prototype/scaled_grouped_mm/scaled_grouped_mm.py

Lines changed: 4 additions & 1 deletion
@@ -83,7 +83,10 @@ def forward(
         assert not _is_column_major(A), "A must be row-major"
 
         # Due to hardware requirements, the right operand in a scaled grouped GEMM must be column-major.
-        assert _is_column_major(B_t), "B must be column-major"
+        if not _is_column_major(B_t):
+            # FSDP will complain if B_t (weights) is not contiguous, so we can't require B_t to be column-major.
+            # TODO: figure out a better solution than transposing on each forward pass.
+            B_t = B_t.transpose(-2, -1).contiguous().transpose(-2, -1)
 
         # Convert high precision input tensor to float8, row-major for left operand of grouped GEMM.
         # A shape: (M, K)
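To make the new branch concrete, here is a small standalone sketch of the layout trick: transposing the last two dims, forcing contiguity, and transposing back keeps the logical shape but leaves the data column-major in memory. The is_column_major helper below is a simplified stand-in for the _is_column_major check used in this file, written here only for illustration:

import torch

def is_column_major(x: torch.Tensor) -> bool:
    # Simplified illustration: the last two dims are column-major when
    # elements within a column are adjacent in memory, i.e. stride(-2) == 1.
    return x.stride(-2) == 1

B_t = torch.randn(4, 16, 32)   # row-major (contiguous), e.g. weights after an FSDP all-gather
print(B_t.stride(), is_column_major(B_t))          # (512, 32, 1) False

B_t_col = B_t.transpose(-2, -1).contiguous().transpose(-2, -1)
print(B_t_col.shape)                               # torch.Size([4, 16, 32]), unchanged
print(B_t_col.stride(), is_column_major(B_t_col))  # (512, 1, 16) True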
Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
import torch

from torchao.prototype.scaled_grouped_mm import _scaled_grouped_mm


class ScaledGroupedMMTensor(torch.Tensor):
    """
    ScaledGroupedMMTensor is a simple tensor subclass that wraps a regular tensor
    and overrides the torch._grouped_mm op by dispatching to the
    differentiable _scaled_grouped_mm autograd function.
    """

    grouped_mm_func_name = "_grouped_mm"
    offs_arg_name = "offs"

    def __init__(self, data: torch.Tensor):
        self._data = data

    @classmethod
    def __torch_function__(cls, func, types, args, kwargs={}):
        if func.__name__ == cls.grouped_mm_func_name:
            # Use torchao scaled grouped mm with dynamic quant for
            # "2d x 3d with offsets" case (used for routed experts).
            # Otherwise, fall back to regular grouped mm.
            #
            # TODO: support "3d x 3d without offsets" case, which is
            # used for shared experts. This is basically the grouped_mm
            # kernel handling a bmm.
            A, B = args[0], args[1]
            A_is_2d = A.dim() == 2
            B_is_3d = B.dim() == 3
            has_offs = kwargs.get(cls.offs_arg_name) is not None
            if A_is_2d and B_is_3d and has_offs:
                return _scaled_grouped_mm(*args, **kwargs)
        return super().__torch_function__(func, types, args, kwargs)
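The subclass relies on the standard __torch_function__ override protocol. Below is a self-contained toy analogue of the same dispatch pattern, intercepting torch.mm instead of torch._grouped_mm purely for illustration; LoggingMMTensor is hypothetical and not part of this commit:

import torch

class LoggingMMTensor(torch.Tensor):
    # Toy analogue of ScaledGroupedMMTensor: intercept one op by name in
    # __torch_function__ and (optionally) dispatch to a custom implementation,
    # falling back to the default behavior for everything else.
    intercepted_func_name = "mm"

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func.__name__ == cls.intercepted_func_name:
            print(f"intercepted {func.__name__}")
            # A real implementation would call a custom autograd function here,
            # the way ScaledGroupedMMTensor calls _scaled_grouped_mm.
        return super().__torch_function__(func, types, args, kwargs)

a = torch.randn(4, 8).as_subclass(LoggingMMTensor)
b = torch.randn(8, 2)
out = torch.mm(a, b)   # prints "intercepted mm" and returns the normal result
print(out.shape)       # torch.Size([4, 2])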
