Can FP8 GEMM be enabled via module hooks instead of module swapping? #1887

zigzagcai opened this issue Mar 14, 2025 · 7 comments

Hi developers,

Thanks for such a great project!

I want to integrate torchao FP8 GEMM into our training framework. However, in our framework the linear layers are defined in customized modules (where we implement Tensor Parallel or ZeRO3 weight parallel), so it is hard to swap them directly with torchao's Float8Linear.

So, can FP8 GEMM be enabled in a friendlier way, such as via module hooks? Module swapping is not very flexible for our setup.

@danielvegamyhre (Contributor) commented Mar 14, 2025

For float8 training, module swap is currently the only supported method.

However, you could implement a module hook which uses our float8 quantization primitives. Then in the forward/backward pass, any mm/matmul ops operating on these float8 tensors would be handled by torch._scaled_mm, which dispatches to cublas (for tensorwise scales) or cutlass (for rowwise scales) for the actual GEMMs.
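To make that concrete, here is a minimal sketch of what the float8 GEMM path boils down to, using plain PyTorch. The `to_float8` / `fp8_linear` helpers are hypothetical names (torchao's float8 quantization primitives would replace the hand-rolled casting), and the exact `torch._scaled_mm` signature and return type vary across PyTorch releases, so treat this as illustrative only:

```python
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

def to_float8(t: torch.Tensor):
    # Tensorwise dynamic scaling: map the tensor's max magnitude onto the e4m3 range.
    scale = E4M3_MAX / t.abs().max().clamp(min=1e-12)
    t_fp8 = (t * scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
    # _scaled_mm expects the dequantization scale (reciprocal of the multiplier above).
    return t_fp8, scale.reciprocal().float()

def fp8_linear(x: torch.Tensor, w: torch.Tensor, bias=None):
    # x: (M, K) activations, w: (N, K) weight, both cast to float8 on the fly.
    # Assumes a recent PyTorch where _scaled_mm returns a single tensor and the
    # shapes satisfy its alignment requirements (dims divisible by 16).
    x_fp8, x_scale = to_float8(x)
    w_fp8, w_scale = to_float8(w)
    # The second operand must be column-major, hence the transpose of the row-major weight.
    out = torch._scaled_mm(
        x_fp8, w_fp8.t(),
        scale_a=x_scale, scale_b=w_scale,
        out_dtype=torch.bfloat16,
    )
    return out if bias is None else out + bias
```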

We could potentially provide such a hook in torchao, though it is not currently planned. It could be useful for use cases like this. @vkuzo @drisspg any thoughts on this?

@vkuzo (Contributor) commented Mar 14, 2025

We started with module swapping because it's the easiest way to iterate on performance/accuracy/usability. If we can cover more important use cases with alternate UX options, I'm in favor.

@zigzagcai, a couple of questions for you.

  1. are you using torch.compile / are you open to using torch.compile
  2. could you share a pointer to your code if you have it

From what you have shared so far, a tensor subclass weight wrapper sounds like the right UX, but would be good to see the callsites to confirm that.
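To illustrate what such a weight wrapper could look like, here is a rough, illustrative sketch (not a torchao class) of a tensor subclass that intercepts F.linear on the wrapped weight and reroutes it through a float8 matmul such as the `fp8_linear` helper sketched above; autograd, serialization, and distributed details are ignored:

```python
import torch
import torch.nn.functional as F

class Float8WeightWrapper(torch.Tensor):
    # Illustrative only: wraps a high-precision weight and reroutes F.linear
    # through a float8 GEMM (e.g. an fp8_linear helper as sketched earlier).

    @staticmethod
    def __new__(cls, weight: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls, weight.shape, dtype=weight.dtype, device=weight.device,
            requires_grad=weight.requires_grad,
        )

    def __init__(self, weight: torch.Tensor):
        self._weight = weight

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            x, w = args[0], args[1]
            bias = args[2] if len(args) > 2 else kwargs.get("bias")
            return fp8_linear(x, w._weight, bias)  # the float8 path
        # Every other op falls back to the default behavior.
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)
```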

@zigzagcai (Author) commented Mar 17, 2025

> @zigzagcai, a couple of questions for you.
>
>   1. are you using torch.compile / are you open to using torch.compile
>   2. could you share a pointer to your code if you have it

Hi @vkuzo @danielvegamyhre

Thanks for your follow-up!

  1. We are open to using torch.compile.
  2. The code of our ZeRO3 weight parallel implementation is here: https://github.com/InternLM/InternEvo/blob/feat/refactor-impl/internlm/model/model_ops/modules/linear.py#L171-L315

The basic idea of our ZeRO3 weight parallel implementation: in WPFusedDenseFunc, we all-gather the weights in the forward pass, then all-gather the weights again and reduce-scatter the gradients in the backward pass (roughly as sketched below). We then apply this customized autograd function in https://github.com/InternLM/InternEvo/blob/feat/refactor-impl/internlm/model/model_ops/modules/linear.py#L532-L678
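For readers skimming the thread, here is a heavily simplified, illustrative sketch of that pattern. This is not the actual WPFusedDenseFunc; the class name, dim-0 weight sharding, and process-group plumbing are assumptions, and communication/computation overlap is omitted:

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

class ZeRO3LinearSketch(torch.autograd.Function):
    # Illustrative only -- assumes the weight is sharded along dim 0 and that
    # out_features divides evenly by the world size.

    @staticmethod
    def forward(ctx, x, weight_shard, bias, process_group):
        world_size = dist.get_world_size(process_group)
        full_weight = weight_shard.new_empty(
            (weight_shard.shape[0] * world_size,) + tuple(weight_shard.shape[1:])
        )
        # All-gather the full weight for the forward GEMM.
        dist.all_gather_into_tensor(full_weight, weight_shard, group=process_group)
        ctx.save_for_backward(x, weight_shard)
        ctx.process_group = process_group
        return F.linear(x, full_weight, bias)

    @staticmethod
    def backward(ctx, grad_out):
        x, weight_shard = ctx.saved_tensors
        pg = ctx.process_group
        world_size = dist.get_world_size(pg)
        # All-gather the weight again for the input-gradient GEMM.
        full_weight = weight_shard.new_empty(
            (weight_shard.shape[0] * world_size,) + tuple(weight_shard.shape[1:])
        )
        dist.all_gather_into_tensor(full_weight, weight_shard, group=pg)
        grad_x = grad_out @ full_weight
        # Full weight gradient, then reduce-scatter it back to a per-rank shard.
        grad_out_2d = grad_out.reshape(-1, grad_out.shape[-1])
        x_2d = x.reshape(-1, x.shape[-1])
        grad_w_full = grad_out_2d.t() @ x_2d
        grad_w_shard = torch.empty_like(weight_shard)
        dist.reduce_scatter_tensor(grad_w_shard, grad_w_full, group=pg)
        grad_bias = grad_out_2d.sum(0) if ctx.needs_input_grad[2] else None
        return grad_x, grad_w_shard, grad_bias, None
```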

So, I just wonder how we could integrate torchao FP8 with our customized ZeRO3 weight parallel implementation?

@vkuzo (Contributor) commented Mar 18, 2025

Thanks, @zigzagcai. I think a tensor subclass weight wrapper is promising for your use case.

We have a prototype feature for quantized training with int8 with a tensor subclass weight wrapping UX here: https://github.com/pytorch/ao/tree/main/torchao/prototype/quantized_training#int8-mixed-precision-training . Would you be up for trying this and seeing if it works with your use case? If yes, that would be great signal for adding an option for this UX for float8.

@zigzagcai (Author) commented
> We have a prototype feature for quantized training with int8 with a tensor subclass weight wrapping UX here: https://github.com/pytorch/ao/tree/main/torchao/prototype/quantized_training#int8-mixed-precision-training . Would you be up for trying this and seeing if it works with your use case?

Thanks @vkuzo!
Could you please share a pointer to the tensor subclass wrapper?
And I see it is for INT8 training, so how could it be adjusted to FP8 training?

@vkuzo (Contributor) commented Mar 18, 2025

> Could you please share a pointer to the tensor subclass wrapper?

Yes, when you do the following:

quantize_(model, int8_mixed_precision_training())

Then the weights of torch.nn.Linear modules will be swapped with Int8MixedPrecisionTrainingLinearWeight (code: https://github.com/pytorch/ao/blob/main/torchao/prototype/quantized_training/int8_mixed_precision.py#L300C22-L300C60). You may need to adjust filter_fn to apply this to your custom linear modules, something like

quantize_(model, int8_mixed_precision_training(), filter_fn=lambda mod, fqn: isinstance(mod, YourCustomModule))

> And I see it is for INT8 training, so how could it be adjusted to FP8 training?

Oh, I'm just asking if you're up for trying the int8 training feature to see if it already works for the way you set up your custom linear modules. If it does, that will make it easy for us to see if we can add an equivalent float8 UX in the future. If not, we'd love to learn what didn't work, which would help us brainstorm how we can solve your issue.
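For completeness, a self-contained version of the snippet above. The import paths follow the linked prototype README, and YourCustomLinear is a placeholder for your framework's own module class, so adjust both as needed:

```python
import torch.nn as nn
from torchao.quantization import quantize_
from torchao.prototype.quantized_training import int8_mixed_precision_training

class YourCustomLinear(nn.Linear):
    # Placeholder for a framework-specific linear (e.g. a TP/ZeRO3 wrapper).
    pass

model = nn.Sequential(
    YourCustomLinear(1024, 1024),
    nn.ReLU(),
    YourCustomLinear(1024, 1024),
)

# Restrict the weight swap to the custom linear class instead of plain nn.Linear.
quantize_(
    model,
    int8_mixed_precision_training(),
    filter_fn=lambda mod, fqn: isinstance(mod, YourCustomLinear),
)
```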

@zigzagcai (Author) commented
> Oh, I'm just asking if you're up for trying the int8 training feature to see if it already works for the way you set up your custom linear modules. […]

Thank you @vkuzo!
I will give it a try!
