[Bug] FSDP2 FP8 compatibility problem with nn.Linear layers (GPU count > out_features) #1938

Open
HIT-cwh opened this issue Mar 24, 2025 · 4 comments

Comments

HIT-cwh commented Mar 24, 2025

When using FSDP2 for Float8 training, an issue occurs when the number of GPUs exceeds the out_features of an nn.Linear layer. Specifically, FSDP2 shards the weight tensor to a shape of [0, in_features] on some ranks, which causes an error in the tensor-wise amax computation during FP8 training:

RuntimeError: max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.

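For context, here is a minimal sketch (plain arithmetic, not the actual DTensor sharding code) of why some ranks end up with an empty shard when the GPU count exceeds out_features:

import math
import torch

out_features, in_features, world_size = 4, 16, 8

# FSDP2 shards parameters along dim 0 (out_features) in torch.chunk-style
# pieces of ceil(out_features / world_size) rows, so trailing ranks can get 0 rows.
rows_per_rank = math.ceil(out_features / world_size)
for rank in range(world_size):
    rows = min(max(out_features - rank * rows_per_rank, 0), rows_per_rank)
    local_weight = torch.empty(rows, in_features)
    print(f"rank {rank}: local weight shape {tuple(local_weight.shape)}")
    # On ranks 4..7, local_weight.numel() == 0, and torch.max(torch.abs(local_weight))
    # raises the RuntimeError shown above.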

To address this, I modified the code as follows:

if x.numel() == 0:
    # Empty local shard on this rank (world_size > out_features): fall back
    # to an amax of 0 instead of calling max() on an empty tensor.
    amax = torch.tensor(0., device=x.device, dtype=x.dtype)
else:
    amax = torch.max(torch.abs(x))

However, this introduces another issue:

RuntimeError: setStorage: sizes [4, 16], strides [16, 1], storage offset 0, and itemsize 1 requiring a storage size of 64 are out of bounds for storage of size 0


Here is my complete reproducible code:

import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed._composable.fsdp import (
    CPUOffloadPolicy,
    MixedPrecisionPolicy,
    fully_shard,
)
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

class Float8Handler:
    def __init__(self, 
        enable_float8_linear=True, enable_fsdp_float8_all_gather=True,
        precompute_float8_dynamic_scale_for_fsdp=True,
        scaling_type_input='dynamic', scaling_type_weight='dynamic', scaling_type_grad_output='dynamic',
        scaling_granularity_input='tensorwise', scaling_granularity_weight='tensorwise', scaling_granularity_grad_output='tensorwise',
        compile=True, pad_inner_dim=False,
    ):
        self.enabled = False

        if not enable_float8_linear:
            return

        try:
            from torchao.float8 import CastConfig, Float8LinearConfig, ScalingType
            from torchao.float8.config import ScalingGranularity
        except ImportError as e:
            raise ImportError(
                "torchao is not installed. Please install it to use float8 linear layers."
            ) from e

        scaling_type_input = ScalingType(scaling_type_input)
        scaling_type_weight = ScalingType(scaling_type_weight)
        scaling_type_grad_output = ScalingType(scaling_type_grad_output)
        scaling_granularity_input = ScalingGranularity(scaling_granularity_input)
        scaling_granularity_weight = ScalingGranularity(scaling_granularity_weight)
        scaling_granularity_grad_output = ScalingGranularity(scaling_granularity_grad_output)
        self.config = Float8LinearConfig(
            enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
            cast_config_input=CastConfig(scaling_type=scaling_type_input, scaling_granularity=scaling_granularity_input),
            cast_config_weight=CastConfig(scaling_type=scaling_type_weight, scaling_granularity=scaling_granularity_weight),
            cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output, scaling_granularity=scaling_granularity_grad_output),
            enable_pre_and_post_forward=False,
            pad_inner_dim=pad_inner_dim
        )

        self.enabled = True

        # for precompute_float8_dynamic_scale_for_fsdp
        self.precompute_scale = (
            enable_fsdp_float8_all_gather
            and precompute_float8_dynamic_scale_for_fsdp
        )

        # for sync_float8_amax_and_scale_history
        self.delayed_scaling = (
            scaling_type_input == ScalingType.DELAYED
            or scaling_type_weight == ScalingType.DELAYED
            or scaling_type_grad_output == ScalingType.DELAYED
        )
        self._sync_float8_amax_and_scale_history = None
        self.compile = compile

    def convert_to_float8_training(self, model: nn.Module):
        """
        This function converts the linear layers of `model` to `Float8Linear`.
        Note that today, only dynamic tensor scaling (the default) is supported.
        This will mutate the model inplace.
        """
        
        if not self.enabled:
            return

        from torchao.float8 import convert_to_float8_training

        convert_to_float8_training(
            model,
            config=self.config,
        )


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 4, bias=True)
    
    def forward(self, x):
        return self.fc(x)

rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)
dist.init_process_group(backend='nccl')

world_size = dist.get_world_size()
world_mesh = init_device_mesh('cuda', (world_size, ), mesh_dim_names=("world",))['world']
float8_handler = Float8Handler(
    compile=True,
    enable_fsdp_float8_all_gather=True,
    pad_inner_dim=True,
)

model = Model().cuda().to(torch.bfloat16)
float8_handler.convert_to_float8_training(model)
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16)
fully_shard(
    model,
    mesh=world_mesh,
    mp_policy=mp_policy,
    reshard_after_forward=True
)

print(model.fc.weight.to_local().shape)

x = torch.randn(16, 16, requires_grad=True, device='cuda', dtype=torch.bfloat16)
out = model(x)
out.mean().backward()

Environment:
torch    2.6.0+cu126
torchao  0.9.0+cu126

Any suggestions?

awgu (Contributor) commented Mar 24, 2025

IMHO, you probably should not be using FP8 when you have such a small out_features size, but having a better check/error message for this might be good.
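For reference, such a check could look roughly like this (a hypothetical helper, not an existing torchao API), run on the model before convert_to_float8_training / fully_shard:

import torch.nn as nn

def check_float8_shardable(model: nn.Module, shard_world_size: int) -> None:
    # Fail fast if any Linear that would become Float8Linear has fewer output
    # rows than the number of ranks sharding it, since those ranks would hold
    # empty [0, in_features] FP8 weight shards.
    for name, mod in model.named_modules():
        if isinstance(mod, nn.Linear) and mod.out_features < shard_world_size:
            raise ValueError(
                f"{name}: out_features={mod.out_features} is smaller than the "
                f"shard world size {shard_world_size}; some ranks would get "
                f"empty FP8 weight shards."
            )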

danielvegamyhre (Contributor) commented:

As Andrew noted, the overhead of float8 quantization will outweigh the benefit of using float8 GEMMs when the shapes are this small. See this table in the docs, which shows for which shapes float8 is faster than bf16.
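If only some layers are large enough to benefit, one option is to skip the small ones at conversion time via convert_to_float8_training's module_filter_fn (a sketch using the model and float8_handler from the repro above; the thresholds are illustrative, and module_filter_fn availability should be checked against your torchao version):

from torchao.float8 import convert_to_float8_training

def _worth_float8(module, fqn):
    # Keep bf16 for Linear layers that are too small for float8 GEMMs to pay off
    # (and whose out_features would not comfortably exceed the shard count).
    return module.in_features >= 4096 and module.out_features >= 1024

convert_to_float8_training(model, config=float8_handler.config,
                           module_filter_fn=_worth_float8)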

@drisspg drisspg added the float8 label Mar 24, 2025
HIT-cwh (Author) commented Mar 25, 2025

Hi @awgu @danielvegamyhre ! Thanks a lot for your responses.

In my actual use case, I'm working with a Float8Linear Module where out_features = 512 and in_features = 5120, running on 1024 GPUs. I encountered the same error as mentioned above. To make it easier to reproduce the issue, I scaled down the configuration to 8 GPUs and proportionally adjusted out_features to 4.

As a temporary workaround, I implemented a rather hacky solution: flattening the linear weights and reshaping them back to (out_features, in_features) during the forward/backward pass. This avoids the error, but it is not a clean fix, and I'm looking for a better way to handle it.
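Roughly, the idea was along these lines (a simplified sketch, not my exact implementation):

import torch
import torch.nn as nn
import torch.nn.functional as F

class FlatLinear(nn.Module):
    # Simplified sketch: keep the weight flattened so FSDP2's dim-0 sharding
    # splits out_features * in_features elements instead of out_features rows,
    # then view it back to (out_features, in_features) in forward.
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.weight = nn.Parameter(torch.empty(out_features * in_features))
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, x):
        w = self.weight.view(self.out_features, self.in_features)
        return F.linear(x, w, self.bias)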

cassanof commented:

I've been running into this as well; a workaround is to use HSDP.
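For example, a minimal sketch of that workaround on the 8-GPU repro, assuming a 2-way replicate x 4-way shard mesh (so the [4, 16] weight is split only 4 ways and no rank gets an empty shard):

from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh

# 2D mesh -> HSDP: replicate across dim 0, shard across dim 1.
hsdp_mesh = init_device_mesh('cuda', (2, 4), mesh_dim_names=('replicate', 'shard'))
fully_shard(
    model,                        # the Float8-converted model from the repro above
    mesh=hsdp_mesh,
    mp_policy=mp_policy,
    reshard_after_forward=True,
)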
