float8 moe training conversion API prototype #2275
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
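In brief, this PR moves the scaled grouped GEMM prototype from torchao.prototype.scaled_grouped_mm to torchao.prototype.moe_training and adds a conversion API: a MoETrainingConfig that plugs into quantize_ and swaps expert weights for ScaledGroupedMMTensor, a tensor subclass that reroutes torch._grouped_mm through a differentiable float8 rowwise-quantized scaled grouped mm. A minimal usage sketch based on the test in this diff — TinyExperts is a hypothetical stand-in (the real test converts a torchtitan llama4 MoE), and the feature requires CUDA SM89+:

import torch
from torch import nn

from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.quantization.quant_api import quantize_


class TinyExperts(nn.Module):
    # Hypothetical stand-in for routed experts: a 3D weight with
    # num_experts as the leading dim, as MoETrainingConfig expects.
    def __init__(self, num_experts: int = 4, dim: int = 16):
        super().__init__()
        self.w = nn.Parameter(torch.randn(num_experts, dim, dim))


model = nn.ModuleDict({"experts": TinyExperts()})

# Convert only modules whose FQN matches the filter, mirroring the test below.
quantize_(
    model,
    config=MoETrainingConfig(),
    filter_fn=lambda mod, fqn: "experts" in fqn,
)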
File renamed without changes.
New file: MoE float8 training test (140 lines; path not shown in this view).

@@ -0,0 +1,140 @@
import copy

import pytest
import torch
from torch import nn
from torch.nn import functional as F

# this feature requires CUDA and SM89+
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9):
    pytest.skip(
        "CUDA not available or compute capability < 8.9", allow_module_level=True
    )

from torchao.float8.float8_utils import compute_error
from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
from torchao.quantization.quant_api import quantize_

# this test requires torchtitan
try:
    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
    from torchtitan.experiments.llama4.model.moe import MoE
except ImportError:
    import warnings

    warnings.warn("torchtitan not installed, skipping MoE tests.")
    pytest.skip(allow_module_level=True)


@pytest.mark.parametrize(
    "target_fqns",
    [
        ["experts"],
        ["does.not.exist"],
    ],
)
def test_moe_float8_training(target_fqns: list[str]):
    model_args = TransformerModelArgs(
        moe_enabled=True,
        num_experts=8,
        dim=256,
    )
    init_std = 0.02
    device = torch.device("cuda")

    # reference bf16 MoE
    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
    torch.manual_seed(42)
    ref_model.init_weights(init_std, device)

    # target MoE for testing conversion
    model = copy.deepcopy(ref_model)

    # assert starting params are identical for both models
    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
        assert torch.equal(param1, param2)

    # convert MoE to float8 training
    def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
        for target_fqn in target_fqns:
            if target_fqn in cur_fqn:
                return True
        return False

    # quantize test model
    config = MoETrainingConfig()
    quantize_(model, config=config, filter_fn=moe_module_filter_fn)

    # validate that only the experts were converted
    _validate_model_conversion(
        model,
        target_fqns=target_fqns,
    )

    # inputs
    batch, seq, dim = 8, 2048, 256
    ref_x = torch.randn(
        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
    )
    x = ref_x.detach().clone().requires_grad_(True)

    # forward pass
    ref_out = ref_model(ref_x)
    out = model(x)

    # validate output
    out_sqnr = compute_error(out, ref_out)
    assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}."

    # compute loss
    labels = torch.ones_like(ref_out)
    ref_loss = F.mse_loss(ref_out, labels)
    out_loss = F.mse_loss(out, labels)

    # backward pass
    ref_loss.backward()
    out_loss.backward()

    # validate input gradient
    input_grad_sqnr = compute_error(x.grad, ref_x.grad)
    assert input_grad_sqnr.item() >= 30.0, (
        f"SQNR must be >= 30.0, got {input_grad_sqnr.item()}."
    )

    # validate param gradients
    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
        param_grad_sqnr = compute_error(param1.grad, param2.grad)
        assert param_grad_sqnr.item() >= 25.0, (
            f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}."
        )


def _validate_model_conversion(
    root_module: nn.Module,
    target_fqns: list[str],
):
    def _recursive_validate(
        module: nn.Module,
        cur_fqn: str,
    ):
        is_allowed_module = cur_fqn in target_fqns

        # check current module params
        for param_name, param in module.named_parameters(recurse=False):
            is_converted_type = isinstance(param, ScaledGroupedMMTensor)
            if is_converted_type:
                assert is_allowed_module, (
                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
                )
            if not is_allowed_module:
                assert not is_converted_type, (
                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
                )

        # recursively check child modules
        for child_name, child_module in module.named_children():
            child_fqn = f"{cur_fqn}.{child_name}" if cur_fqn else child_name
            _recursive_validate(child_module, child_fqn)

    _recursive_validate(root_module, "")
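The pass/fail thresholds above are SQNR in decibels. To my understanding, torchao's compute_error is equivalent to the following sketch, which may help when reading the 30.0/25.0 bounds (worth verifying against torchao.float8.float8_utils):

import torch


def sqnr_db(ref: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB:
    # 20 * log10(||ref|| / ||ref - actual||); higher means closer to ref.
    signal = torch.linalg.norm(ref.float())
    noise = torch.linalg.norm(ref.float() - actual.float())
    return 20 * torch.log10(signal / noise)

# e.g. an SQNR of 30 dB means the error norm is about 3% of the signal norm.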
New file: moe_training package __init__ (3 lines):

@@ -0,0 +1,3 @@
from torchao.prototype.moe_training.scaled_grouped_mm import _scaled_grouped_mm

__all__ = ["_scaled_grouped_mm"]
New file: conversion_utils.py with the MoETrainingConfig (112 lines):

@@ -0,0 +1,112 @@
from typing import Callable, Optional

from torch import nn

from torchao.core.config import AOBaseConfig
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
from torchao.quantization.transform_module import (
    register_quantize_module_handler,
)


class MoETrainingConfig(AOBaseConfig):
    """
    The MoETrainingConfig is specifically designed to be used on MoE models using
    `torch._grouped_mm` to implement expert computation in token-choice routing,
    where expert weights are implemented as 3D nn.Parameters with `num_experts` as
    the leading dim.

    MoETrainingConfig has a module handler registered to it which will
    find all nn.Parameters whose parent module matches the module filter function,
    and swap their data tensor with a ScaledGroupedMMTensor.

    The ScaledGroupedMMTensor is a tensor subclass which overrides the
    `torch._grouped_mm` op by dispatching to a differentiable scaled grouped mm,
    which performs dynamic float8 rowwise quantization on scaled grouped GEMM
    operands in both the forward and backward pass.

    For all other ops, ScaledGroupedMMTensor behaves like a regular torch.Tensor.
    """

    pass


@register_quantize_module_handler(MoETrainingConfig)
def _moe_training_transform(
    module: nn.Module,
    config: MoETrainingConfig,
) -> nn.Module:
    """
    Swaps `torch.nn.Parameter` data tensor with a ScaledGroupedMMTensor.

    Args:
        module: Module to modify.
        config: MoETrainingConfig which defines how to perform the MoE training transform.

    Returns:
        nn.Module: The modified module with swapped parameters.
    """
    out = _swap_params(module)
    return out


def _swap_params(
    module: nn.Module,
    *,
    module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
) -> nn.Module:
    """
    Recurses through the nn.Module, recursively swapping the data tensor of
    each nn.Parameter with a ScaledGroupedMMTensor. Only applies if the module
    passes the module_filter_fn, if specified.

    Args:
        module: Module to modify.
        module_filter_fn: If specified, only the `torch.nn.Parameter` subclasses
            that pass the filter function will be swapped. The inputs to the
            filter function are the module instance and the FQN.

    Returns:
        nn.Module: The modified module with swapped parameters.
    """
    if isinstance(module, nn.Parameter) and (
        module_filter_fn is None or module_filter_fn(module, "")
    ):
        if len(list(module.children())) > 0:
            raise AssertionError(
                f"Does not support a root nn.Parameter with children: {module}"
            )
        if not isinstance(module.data, ScaledGroupedMMTensor):
            new_data = ScaledGroupedMMTensor(module.data)
            return nn.Parameter(new_data, requires_grad=module.requires_grad)
        return module

    root_module = module

    def post_order_traversal(
        module: nn.Module,
        cur_fqn: Optional[str] = None,
        parent_module: Optional[nn.Module] = None,
    ):
        if cur_fqn is None:
            cur_fqn = ""

        for child_module_name, child_module in module.named_children():
            if cur_fqn == "":
                new_fqn = child_module_name
            else:
                new_fqn = f"{cur_fqn}.{child_module_name}"

            post_order_traversal(child_module, new_fqn, module)

        if module_filter_fn is None or module_filter_fn(module, cur_fqn):
            for param_name, param in module.named_parameters(recurse=False):
                if not isinstance(param.data, ScaledGroupedMMTensor):
                    new_param = nn.Parameter(
                        ScaledGroupedMMTensor(param), requires_grad=param.requires_grad
                    )
                    setattr(module, param_name, new_param)
                    print(f"Swapped {cur_fqn}.{param_name} to ScaledGroupedMMTensor")

    post_order_traversal(root_module)
    return root_module
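A note on the traversal above: children are visited before their parents (post-order), and FQNs are built by joining child names with dots from an empty root. A self-contained toy of the same pattern (not part of the diff) makes the visit order concrete:

from torch import nn


def walk_post_order(module: nn.Module, cur_fqn: str = ""):
    # Recurse into children first, building dotted FQNs the same way
    # _swap_params does, then visit the current module.
    for name, child in module.named_children():
        child_fqn = name if cur_fqn == "" else f"{cur_fqn}.{name}"
        walk_post_order(child, child_fqn)
    print(f"visiting: {cur_fqn or '<root>'}")


walk_post_order(nn.Sequential(nn.Linear(2, 2), nn.ReLU()))
# visiting: 0
# visiting: 1
# visiting: <root>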
Renamed (4 changes: 2 additions & 2 deletions):
...ype/scaled_grouped_mm/kernels/__init__.py → ...rototype/moe_training/kernels/__init__.py

@@ -1,6 +1,6 @@
-from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
+from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_col_major_jagged_colwise_scales as triton_fp8_col_major_jagged_colwise_scales,
 )
-from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
+from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_row_major_jagged_rowwise_scales as triton_fp8_row_major_jagged_rowwise_scales,
 )
New file: tensor.py with ScaledGroupedMMTensor (35 lines):

@@ -0,0 +1,35 @@
import torch

from torchao.prototype.moe_training import _scaled_grouped_mm


class ScaledGroupedMMTensor(torch.Tensor):
    """
    ScaledGroupedMMTensor is a simple tensor subclass that wraps a regular tensor
    and overrides the torch._grouped_mm op by dispatching to the
    differentiable _scaled_grouped_mm autograd function.
    """

    grouped_mm_func_name = "_grouped_mm"
    offs_arg_name = "offs"

    def __init__(self, data: torch.Tensor):
        self._data = data

    @classmethod
    def __torch_function__(cls, func, types, args, kwargs=None):
        # Avoid a mutable default argument; treat missing kwargs as empty.
        kwargs = kwargs or {}
        if func.__name__ == cls.grouped_mm_func_name:
            # Use torchao scaled grouped mm with dynamic quant for
            # "2d x 3d with offsets" case (used for routed experts).
            # Otherwise, fall back to regular grouped mm.
            #
            # TODO: support "3d x 3d without offsets" case, which is
            # used for shared experts. This is basically the grouped_mm
            # kernel handling a bmm.
            A, B = args[0], args[1]
            A_is_2d = A.dim() == 2
            B_is_3d = B.dim() == 3
            has_offs = kwargs.get(cls.offs_arg_name) is not None
            if A_is_2d and B_is_3d and has_offs:
                return _scaled_grouped_mm(*args, **kwargs)
        return super().__torch_function__(func, types, args, kwargs)
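The subclass relies on __torch_function__ interposition: any torch-level call that sees a ScaledGroupedMMTensor argument is routed through the classmethod above, which intercepts just the grouped-mm op and defers everything else to the default implementation. The same mechanism with a public op, as a self-contained toy (LoggingTensor is hypothetical, not part of this PR):

import torch


class LoggingTensor(torch.Tensor):
    # Toy subclass demonstrating the interposition pattern used by
    # ScaledGroupedMMTensor: intercept one op, fall through for the rest.

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.mm:
            print("intercepted torch.mm")
        return super().__torch_function__(func, types, args, kwargs)


a = torch.randn(2, 3).as_subclass(LoggingTensor)
b = torch.randn(3, 4)
out = torch.mm(a, b)  # prints "intercepted torch.mm"; out is a LoggingTensor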
File renamed without changes.

This file was deleted.