DynamicFlexAttention wrapper class for dynamic sequence lengths #1960
Conversation
See the following gist: https://gist.github.com/zyklotomic/527cb96da86c2b5f5984bede3be9b227
Hey! Great PR! Do you know why …?
I have some interesting findings to report back! I should have dug deeper initially. It turns out that getting dynamic shapes to work is something that has been worked on, and it is apparently available in the nightly version of PyTorch. Links of interest: https://github.com/pytorch/pytorch/blob/8d08b4901586f230353a558ee00c16ad57f95178/torch/_inductor/kernel/flex_attention.py#L705 (most recent commit as of writing), which points to https://github.com/pytorch/pytorch/blob/main/torch/_inductor/kernel/flex_decoding.py#L336
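For reference, a minimal sketch of what relying on that nightly support might look like, assuming the linked inductor changes are in the installed build (the exact behaviour may differ from build to build):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Marking the compile as dynamic is intended to let the generated kernel
# handle varying sequence lengths without recompiling for every new shape.
flex_attention_dynamic = torch.compile(flex_attention, dynamic=True)

q = torch.randn(1, 8, 1000, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flex_attention_dynamic(q, k, v)  # later calls may use a different seq_len
```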
I did try my example notebook and set … Not a … As for your question on why … What do you think is the best course of action? Should we wait for the PyTorch folks to stabilize instead?
I think I only just understood what you mean. If I understand correctly, my wrapper class handles the padding for you based on the input size.
https://colab.research.google.com/drive/1X7CpQgIqgRpV2aIUgS_p7u1TR4ITUfXF?usp=sharing It might be a bit primitive to use a temporary print statement to confirm that the flex attention module was indeed being invoked, but I don't think there was a better way.
Still WIP, trying to debug performance issues. Will try to cache the block masks and enable dynamic kernel size selection. It also seems I forgot to account for GQA; the number of heads might differ between the query and KV.
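A sketch of what caching block masks per padded length could look like (using `functools.lru_cache` here; the PR may do this differently), with the GQA case handled by passing `enable_gqa=True` to `flex_attention`:

```python
import torch
from functools import lru_cache
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

@lru_cache(maxsize=8)
def cached_causal_block_mask(padded_len: int, device: str = "cuda"):
    # One BlockMask per padded bucket length; reused across calls of that length.
    def causal(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx
    return create_block_mask(causal, B=None, H=None,
                             Q_LEN=padded_len, KV_LEN=padded_len, device=device)

compiled_flex_attention = torch.compile(flex_attention)

def gqa_attention(q, k, v):
    # q: (B, n_q_heads, S, D); k/v: (B, n_kv_heads, S, D) with fewer KV heads.
    block_mask = cached_causal_block_mask(q.shape[-2], str(q.device))
    # enable_gqa lets flex_attention broadcast the KV heads over query heads.
    return compiled_flex_attention(q, k, v, block_mask=block_mask, enable_gqa=True)
```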
https://colab.research.google.com/drive/1LlAbzLWeC7Js3S19NhMFKRHmWtP7oFrP?usp=sharing

Comparing against https://github.com/unslothai/notebooks/blob/main/nb/Gemma2_(2B)-Alpaca.ipynb, we use less GPU memory but unfortunately a lot more time. The peak reserved memory for training reported in my notebook is wrong because I forgot to run the stats-collection cell right beforehand; I was trying to conserve Colab credits.

I suspect the slowdown has to do with the block mask: the padding strategy unfortunately interferes with the strategy of using one larger block mask for everything. Maybe a custom BlockMask constructor would help, but we would need to properly understand the BlockMask format. The padding also most likely means extra operations in the matmuls. There is also definitely a compile cost for each of the Flex Attention kernels, but that is a fixed cost.

Taking the [WIP] tag off, but I am not sure this is merge-worthy given the performance problems.
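One way to separate that fixed compile cost from the steady-state padding overhead when profiling is a simple CUDA-event timing helper (generic measurement code, not part of this PR):

```python
import torch

def time_attention(fn, q, k, v, warmup=3, iters=20):
    # Warmup calls absorb the one-off torch.compile cost; the timed loop then
    # measures steady-state kernel time, which is where the padding and
    # block-mask overhead would show up.
    for _ in range(warmup):
        fn(q, k, v)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(q, k, v)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average milliseconds per call
```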
Had a stab at making Flex Attention work without excessive recompilation. I am not fully confident in this approach; it honestly feels pretty janky, so I wanted confirmation that this is the right direction.
In essence, the kernel has to recompile every time the input sizes change. So why not compile a kernel for a larger size, pad the inputs when necessary, and then slice the result back down before returning? See the code for more thorough comments.
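A minimal sketch of the pad-and-slice idea, as an illustration rather than the exact code in this PR (the bucket size and the causal-plus-padding mask here are assumptions):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

# Compile once; the kernel specializes on the padded (bucketed) length.
compiled_flex_attention = torch.compile(flex_attention)

def padded_flex_attention(q, k, v, bucket=128):
    # q, k, v: (batch, heads, seq_len, head_dim)
    seq_len = q.shape[-2]
    padded_len = ((seq_len + bucket - 1) // bucket) * bucket
    pad = padded_len - seq_len
    if pad:
        # F.pad pads trailing dims first: (head_dim_l, head_dim_r, seq_l, seq_r)
        q = F.pad(q, (0, 0, 0, pad))
        k = F.pad(k, (0, 0, 0, pad))
        v = F.pad(v, (0, 0, 0, pad))

    def causal_unpadded(b, h, q_idx, kv_idx):
        # Causal mask restricted to the real (unpadded) key positions.
        return (q_idx >= kv_idx) & (kv_idx < seq_len)

    block_mask = create_block_mask(causal_unpadded, B=None, H=None,
                                   Q_LEN=padded_len, KV_LEN=padded_len,
                                   device=q.device)
    out = compiled_flex_attention(q, k, v, block_mask=block_mask)
    return out[..., :seq_len, :]  # slice the padded query rows back off
```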
I haven't had the chance to really test the performance yet. There are also potential enhancements, which I mention in the comments.
Will attach testing code for a demo in a bit.