When using FSDP2 for Float8 training, an issue occurs when the number of GPUs exceeds the out_features of an nn.Linear layer. Specifically, FSDP2 shards the weight tensor to shape [0, in_features] on some ranks, which causes an error during tensor-wise FP8 training here:

RuntimeError: max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.

To address this, I modified the code, but that introduces another issue:

RuntimeError: setStorage: sizes [4, 16], strides [16, 1], storage offset 0, and itemsize 1 requiring a storage size of 64 are out of bounds for storage of size 0
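For context, the first error is easy to see without FSDP2 at all: tensor-wise dynamic scaling needs the amax (abs().max()) of the local weight shard, and a [0, in_features] shard has no elements to reduce over. A minimal sketch that only mimics the dim-0 sharding (plain PyTorch, 8 "ranks", out_features = 4; not the actual FSDP2 code path):

import torch

# Mimic dim-0 sharding of the (4, 16) weight across 8 ranks:
# only 4 ranks can receive a row, the remaining ranks hold an empty [0, 16] shard.
weight = torch.randn(4, 16, dtype=torch.bfloat16)
shards = list(torch.chunk(weight, 8, dim=0))              # 4 shards of shape [1, 16]
shards += [weight.new_empty(0, 16)] * (8 - len(shards))   # 4 empty [0, 16] shards

empty_shard = shards[-1]
print(empty_shard.shape)   # torch.Size([0, 16])
empty_shard.abs().max()    # RuntimeError: max(): Expected reduction dim to be specified
                           # for input.numel() == 0. ...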
Here is my complete reproducible code:
import os

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed._composable.fsdp import (
    CPUOffloadPolicy,
    MixedPrecisionPolicy,
    fully_shard,
)
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh


class Float8Handler:
    def __init__(
        self,
        enable_float8_linear=True,
        enable_fsdp_float8_all_gather=True,
        precompute_float8_dynamic_scale_for_fsdp=True,
        scaling_type_input='dynamic',
        scaling_type_weight='dynamic',
        scaling_type_grad_output='dynamic',
        scaling_granularity_input='tensorwise',
        scaling_granularity_weight='tensorwise',
        scaling_granularity_grad_output='tensorwise',
        compile=True,
        pad_inner_dim=False,
    ):
        self.enabled = False
        if not enable_float8_linear:
            return
        try:
            from torchao.float8 import CastConfig, Float8LinearConfig, ScalingType
            from torchao.float8.config import ScalingGranularity
        except ImportError as e:
            raise ImportError(
                "torchao is not installed. Please install it to use float8 linear layers."
            ) from e

        scaling_type_input = ScalingType(scaling_type_input)
        scaling_type_weight = ScalingType(scaling_type_weight)
        scaling_type_grad_output = ScalingType(scaling_type_grad_output)
        scaling_granularity_input = ScalingGranularity(scaling_granularity_input)
        scaling_granularity_weight = ScalingGranularity(scaling_granularity_weight)
        scaling_granularity_grad_output = ScalingGranularity(scaling_granularity_grad_output)
        self.config = Float8LinearConfig(
            enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
            cast_config_input=CastConfig(scaling_type=scaling_type_input, scaling_granularity=scaling_granularity_input),
            cast_config_weight=CastConfig(scaling_type=scaling_type_weight, scaling_granularity=scaling_granularity_weight),
            cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output, scaling_granularity=scaling_granularity_grad_output),
            enable_pre_and_post_forward=False,
            pad_inner_dim=pad_inner_dim,
        )
        self.enabled = True
        # for precompute_float8_dynamic_scale_for_fsdp
        self.precompute_scale = (
            enable_fsdp_float8_all_gather
            and precompute_float8_dynamic_scale_for_fsdp
        )
        # for sync_float8_amax_and_scale_history
        self.delayed_scaling = (
            scaling_type_input == "delayed"
            or scaling_type_weight == "delayed"
            or scaling_type_grad_output == "delayed"
        )
        self._sync_float8_amax_and_scale_history = None
        self.compile = compile

    def convert_to_float8_training(self, model: nn.Module):
        """
        This function converts the linear layers of `model` to `Float8Linear`.
        Note that today, only dynamic tensor scaling (the default) is supported.
        This will mutate the model in place.
        """
        if not self.enabled:
            return
        from torchao.float8 import convert_to_float8_training
        # Mutates the model in place, replacing instances of torch.nn.Linear with Float8Linear
        convert_to_float8_training(
            model,
            config=self.config,
        )


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 4, bias=True)

    def forward(self, x):
        return self.fc(x)


rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)
dist.init_process_group(backend='nccl')
world_size = dist.get_world_size()
world_mesh = init_device_mesh('cuda', (world_size,), mesh_dim_names=("world",))['world']

float8_handler = Float8Handler(
    compile=True,
    enable_fsdp_float8_all_gather=True,
    pad_inner_dim=True,
)
model = Model().cuda().to(torch.bfloat16)
float8_handler.convert_to_float8_training(model)

mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16)
fully_shard(
    model,
    mesh=world_mesh,
    mp_policy=mp_policy,
    reshard_after_forward=True,
)

print(model.fc.weight.to_local().shape)
x = torch.randn(16, 16, requires_grad=True, device='cuda', dtype=torch.bfloat16)
out = model(x)
out.mean().backward()
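(The script reads RANK and LOCAL_RANK from the environment, so it is meant to be launched with something like torchrun --nproc_per_node=8 repro.py, where repro.py is just a placeholder file name; with 8 processes the world size exceeds out_features = 4 and the empty shard appears.)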
torch 2.6.0+cu126
torchao 0.9.0+cu126
Any suggestions?
IMHO, you probably should not be using FP8 when you have such a small out_features size, but having a better check/error message for this might be good.
As Andrew noted, the overhead of float8 quantization will outweigh the benefit of using float8 GEMMs when the shapes are this small. See this table in the docs, which shows for which shapes float8 is faster than bf16.
In my actual use case, I'm working with a Float8Linear module where out_features = 512 and in_features = 5120, running on 1024 GPUs, and I hit the same error as above: since 512 < 1024, dim-0 sharding leaves half the ranks with an empty [0, 5120] weight shard. To make the issue easier to reproduce, I scaled the configuration down to 8 GPUs and proportionally reduced out_features to 4.
As a temporary workaround, I implemented a rather hacky solution: I flatten the linear weight and reshape it back to (out_features, in_features) during the forward/backward pass. This makes the error go away, but it's ugly, and I'm looking for a better way to handle this (see the sketch below).
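Roughly, the shape of the workaround looks like the sketch below (illustrative only, not the actual patch; in my case the reshape happens inside Float8Linear's forward/backward so the float8 cast still applies). The idea is to keep the weight as a 1-D parameter of out_features * in_features elements, so dim-0 sharding gives every rank a non-empty slice, and to view it back to 2-D where the matmul happens.

import torch
import torch.nn as nn
import torch.nn.functional as F

class FlattenedLinear(nn.Module):
    """Sketch only: stores the weight flattened so dim-0 sharding splits
    out_features * in_features elements instead of out_features rows."""
    def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # 1-D parameter: as long as out_features * in_features >= world size,
        # no rank ends up with an empty shard after fully_shard().
        self.weight = nn.Parameter(torch.empty(out_features * in_features, dtype=dtype))
        nn.init.normal_(self.weight, std=0.02)
        self.bias = nn.Parameter(torch.zeros(out_features, dtype=dtype)) if bias else None

    def forward(self, x):
        # View back to (out_features, in_features) for the actual matmul.
        w = self.weight.view(self.out_features, self.in_features)
        return F.linear(x, w, self.bias)

This keeps every local shard non-empty, but it is clearly not a proper fix.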