Can FP8 GEMM be enabled via module hooks instead of module swapping? #1887
For float8 training, module swap is currently the only supported method. However, you could implement a module hook which uses our float8 quantization primitives. Then, in the forward/backward pass, any mm/matmul ops operating on these float8 tensors would be handled by `torch._scaled_mm`, which dispatches to cuBLAS (for tensorwise scales) or CUTLASS (for rowwise scales) for the actual GEMMs. We could potentially provide such a hook in torchao, though it is not currently planned; it could be useful for use cases like this. @vkuzo @drisspg any thoughts on this?
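For concreteness, here is a minimal sketch of the dispatch path described above, using plain PyTorch rather than torchao's internal primitives. The helper names (`to_float8`, `fp8_mm`) are made up for illustration, and the sketch assumes dynamic tensorwise scaling, 2D inputs, float8-capable hardware (e.g. H100), and a recent PyTorch where `torch._scaled_mm` returns a single tensor:

```python
import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

def to_float8(x: torch.Tensor):
    # Dynamic tensorwise scaling: pick the scale from the current absolute max,
    # cast to float8, and return the inverse (dequantization) scale that
    # torch._scaled_mm expects.
    amax = x.abs().max().clamp(min=1e-12).float()
    scale = FP8_E4M3_MAX / amax
    x_fp8 = (x.float() * scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
    return x_fp8, scale.reciprocal()

def fp8_mm(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # x: (M, K) activations, w: (N, K) linear weight; K should be a multiple of 16.
    x_fp8, x_inv_scale = to_float8(x)
    w_fp8, w_inv_scale = to_float8(w)
    # cuBLAS wants the second operand in column-major layout, hence the transpose.
    return torch._scaled_mm(
        x_fp8, w_fp8.t(),
        scale_a=x_inv_scale, scale_b=w_inv_scale,
        out_dtype=x.dtype,
    )
```

A hook installed on a custom linear module could then call something like `fp8_mm` in place of the high-precision matmul; autograd support would still need a custom `torch.autograd.Function`, which is part of what `Float8Linear` handles internally.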
We started with module swapping because it's the easiest way to iterate on performance/accuracy/usability. If we can cover more important use cases with alternate UX options, I'm in favor. @zigzagcai, a couple of questions for you.
From what you have shared so far, a tensor subclass weight wrapper sounds like the right UX, but it would be good to see the callsites to confirm that.
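To illustrate what that UX could look like, below is a rough, hypothetical sketch of a weight-wrapper tensor subclass that intercepts `F.linear` and reroutes it to a float8 GEMM. None of these names are torchao APIs; the float8 math assumes dynamic tensorwise scaling and 2D inputs, and the forward is shown without autograd support:

```python
import torch
import torch.nn.functional as F

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def _to_fp8(t: torch.Tensor):
    # Dynamic tensorwise scaling; returns fp8 data plus its dequantization scale.
    scale = FP8_MAX / t.abs().max().clamp(min=1e-12).float()
    t_fp8 = (t.float() * scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return t_fp8, scale.reciprocal()

class Fp8WeightWrapper(torch.Tensor):
    """Hypothetical weight-wrapper subclass (not a torchao class)."""

    @staticmethod
    def __new__(cls, data: torch.Tensor):
        return torch.Tensor._make_subclass(cls, data, data.requires_grad)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            x, w = args[0], args[1].as_subclass(torch.Tensor)  # unwrap the weight
            bias = args[2] if len(args) > 2 else kwargs.get("bias")
            # Forward-only illustration: a real training integration would route
            # this through a custom autograd.Function, as Float8Linear does.
            x_fp8, sx = _to_fp8(x.detach())
            w_fp8, sw = _to_fp8(w.detach())
            out = torch._scaled_mm(x_fp8, w_fp8.t(), scale_a=sx, scale_b=sw,
                                   out_dtype=x.dtype)
            return out if bias is None else out + bias
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

# Usage sketch: wrap a custom module's weight in place, no module swap needed.
# layer.weight = torch.nn.Parameter(Fp8WeightWrapper(layer.weight.data))
```

The key point is that the module itself is untouched: only its `weight` parameter is replaced, which composes more easily with frameworks that build their own TP/ZeRO3 linear modules.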
Thanks for your follow-up!
The basic idea of our ZeRO3 weight parallel implementation:

So, I just wonder how we could integrate torchao FP8 with our customized ZeRO3 weight parallel implementation?
Thanks, @zigzagcai. I think a tensor subclass weight wrapper is promising for your use case. We have a prototype feature for quantized training with int8 with a tensor subclass weight wrapping UX here: https://github.com/pytorch/ao/tree/main/torchao/prototype/quantized_training#int8-mixed-precision-training . Would you be up for trying this and seeing if it works with your use case? If yes, that would be a great signal for adding an option for this UX for float8.
Thanks @vkuzo!
Yes. When you do the following: `quantize_(model, int8_mixed_precision_training())`, the weights of the model's `nn.Linear` modules are wrapped with the int8 tensor subclass. To restrict the conversion to your custom modules, pass a filter function: `quantize_(model, int8_mixed_precision_training(), filter_fn=lambda mod, fqn: isinstance(mod, YourCustomModule))`.
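Putting the suggestion together, a minimal end-to-end sketch of that experiment might look like the following; `CustomLinear` is a made-up stand-in for a framework-specific linear module, and whether the prototype's transform accepts such non-`nn.Linear` modules out of the box is exactly what the experiment would confirm:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchao.quantization import quantize_
from torchao.prototype.quantized_training import int8_mixed_precision_training

class CustomLinear(nn.Module):
    # Stand-in for a framework-specific linear (e.g. one that shards its weight).
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.weight)

# Assumes a CUDA device; the int8 training kernels target GPU.
model = nn.Sequential(CustomLinear(1024, 4096), nn.GELU(), CustomLinear(4096, 1024)).cuda()

# Only convert the custom linear modules; their weights get wrapped with the
# int8 mixed-precision training tensor subclass.
quantize_(
    model,
    int8_mixed_precision_training(),
    filter_fn=lambda mod, fqn: isinstance(mod, CustomLinear),
)

out = model(torch.randn(8, 1024, device="cuda"))
out.sum().backward()
```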
Oh, I'm just asking if you're up for trying the int8 training feature to see if it already works for the way you set up your custom linear modules. If it does, that will make it easy for us to see if we can add an equivalent float8 UX in the future. If not, we'd love to learn what didn't work, which would help us brainstorm how we can solve your issue.
Thank you @vkuzo!
Hi developers,
Thanks for such a great project!
I want to integrate torchao FP8 GEMM into our training framework. However, in our framework the linear layers are defined in customized modules (where we implement Tensor Parallel or ZeRO3 weight parallel), so it is hard to directly swap the linear layers with torchao `Float8Linear`. Can FP8 GEMM be enabled in a friendlier way, such as via module hooks, since module swapping is not very flexible for our setup?