
[Llama4] Chunked Attention #37351

Closed
2 of 4 tasks
vasqu opened this issue Apr 7, 2025 · 5 comments

@vasqu
Contributor

vasqu commented Apr 7, 2025

System Info

  • transformers version: 4.52.0.dev0 (around commit 8bbcdf5)
  • Platform: Linux-6.8.0-111057-tuxedo-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.30.1
  • Safetensors version: 0.4.3
  • Accelerate version: 0.34.2
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (GPU?): 2.6.0+cu124 (True)
  • Tensorflow version (GPU?): 2.15.1 (True)
  • Flax version (CPU?/GPU?/TPU?): 0.7.0 (cpu)
  • Jax version: 0.4.13
  • JaxLib version: 0.4.13
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA GeForce RTX 3080 Ti Laptop GPU

Who can help?

@ArthurZucker @winglian (fyi)

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Because I'm GPU poor, I modified llama4 to only have one layer and a lower hidden size.

Rough script:

import torch

from transformers import AutoConfig, AutoTokenizer
from transformers.models.llama4.modeling_llama4 import Llama4ForCausalLM


config = AutoConfig.from_pretrained("meta-llama/Llama-4-Scout-17B-16E-Instruct").get_text_config()

# modify config for debugging
config._attn_implementation = "eager"  # or "flex_attention", possibly also "sdpa"
config.hidden_size = 128
config.num_hidden_layers = 1
config.attention_chunk_size = 3  # causes eager issues; leaving the default causes issues in flex attention

# some dummy data
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-4-Scout-17B-16E-Instruct")
tokenizer.padding_side = "left"  # irrelevant tbh
input_text = ["What are we having for dinner?", "How are you?"]
input_ids = tokenizer(input_text, padding=True, return_tensors="pt").to("cuda")

# init module, half precision to save on vram 
test_module = Llama4ForCausalLM(config).to("cuda", torch.bfloat16)

# simple forward pass
test_module.forward(**input_ids)

This can cause various issues, e.g.

  • In eager: RuntimeError: The size of tensor a (8) must match the size of tensor b (2) at non-singleton dimension 0
  • In flex: ValueError: block_mask was created for block_mask.shape=(2, 1, 8, tensor(8192, device='cuda:0')) but got q_len=8 and kv_len=8. (...) - looks like more fixes for post-training llama4 #37329 (comment); see the flex usage sketch below
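
For reference on the flex path, here is a minimal standalone sketch of how a chunked-causal block mask is typically built with torch's flex attention API. The CHUNK constant and the tensor shapes are toy assumptions mirroring the repro config above, not the transformers internals; the error above suggests the block mask is instead created with KV_LEN pinned to the chunk size rather than the actual key/value length.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

CHUNK = 3                  # toy value, mirroring config.attention_chunk_size above
B, H, S, D = 2, 4, 8, 16   # toy batch / heads / seq_len / head_dim

def chunked_causal(b, h, q_idx, kv_idx):
    # causal attention restricted to the query's own chunk
    return (q_idx >= kv_idx) & (q_idx // CHUNK == kv_idx // CHUNK)

# Q_LEN / KV_LEN must match the actual query/key lengths, otherwise flex_attention
# raises the "block_mask was created for ... but got q_len/kv_len" error quoted above
block_mask = create_block_mask(chunked_causal, B=B, H=None, Q_LEN=S, KV_LEN=S, device="cuda")

q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
out = flex_attention(q, k, v, block_mask=block_mask)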

Expected behavior

Chunked attention doesn't seem to be handled correctly at the moment. Most code paths never hit this territory because the default chunk size is so large that inputs rarely exceed it.
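
For illustration only, a rough sketch (toy values, not the actual transformers mask code) of what the chunked causal mask should look like and how it has to broadcast against the 2D padding mask, which is roughly where the shape mismatches above come from:

import torch

chunk_size, seq_len, batch = 3, 8, 2   # toy values matching the repro config above

positions = torch.arange(seq_len)
causal = positions[:, None] >= positions[None, :]                                # (S, S), lower triangular
same_chunk = positions[:, None] // chunk_size == positions[None, :] // chunk_size
chunked_causal = causal & same_chunk                                             # attend causally within a chunk only

# the 2D padding mask (batch, S) has to be broadcast against the 4D chunked mask,
# e.g. (1, 1, S, S) & (batch, 1, 1, S) -> (batch, 1, S, S)
padding_mask = torch.ones(batch, seq_len, dtype=torch.bool)
full_mask = chunked_causal[None, None] & padding_mask[:, None, None, :]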

vasqu added the bug label Apr 7, 2025
@vasqu
Contributor Author

vasqu commented Apr 7, 2025

Also as a note:

  • It has to be torch==2.6.0; previous versions (2.5.1) can handle this somewhat for flex attention (when disabling the mask creation compilation, which is another bug that was introduced).
  • The purpose/usage of chunked attention is unclear to me. It might also not handle padding correctly (but that's only at first glance).

@ArthurZucker
Collaborator

Chunked attention enables the 10M context length! It discards unused cache on the go!
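
To sketch the idea (a toy illustration of the principle, not transformers' actual cache class): since a token only attends within its own chunk, a chunked-attention layer never needs to keep more than one chunk of keys/values around, which is what keeps memory bounded at very long contexts.

import torch

class ToyChunkedKVCache:
    """Toy per-layer cache that only keeps keys/values of the current chunk."""

    def __init__(self, chunk_size: int):
        self.chunk_size = chunk_size
        self.seen = 0          # total tokens processed so far
        self.keys = None
        self.values = None

    def update(self, k: torch.Tensor, v: torch.Tensor):
        # k, v: (batch, heads, new_tokens, head_dim)
        self.keys = k if self.keys is None else torch.cat([self.keys, k], dim=-2)
        self.values = v if self.values is None else torch.cat([self.values, v], dim=-2)
        self.seen += k.shape[-2]
        # everything before the start of the current chunk can never be attended to
        # again, so it is discarded on the go
        keep = self.seen - ((self.seen - 1) // self.chunk_size) * self.chunk_size
        self.keys = self.keys[..., -keep:, :]
        self.values = self.values[..., -keep:, :]
        return self.keys, self.values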

@vasqu
Contributor Author

vasqu commented Apr 7, 2025

The attention mask reminded me of packed sequences 😄 But that sounds nice!

I forgot to mention that I also modified the mask padding (please see vasqu@5f9b658); otherwise smaller sequences will crash on flex attention.

@overwindows

Same issue here when running the sample code from the LLaMA 4 release blog. Even without using Flex Attention (e.g. with eager or sdpa), it still throws an error if the input length exceeds 8K:

chunked_attention_mask = chunked_attention_mask & attention_mask
RuntimeError: The size of tensor a (8192) must match the size of tensor b (16395) at non-singleton dimension 1
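
If it helps triage: the numbers in that trace look like the chunked mask being built at the chunk width (8192, the default attention_chunk_size) while attention_mask keeps the full padded length. A toy standalone repro of just that broadcasting failure (not the actual transformers code path):

import torch

chunk_size, seq_len = 8192, 16395

chunked_attention_mask = torch.ones(1, chunk_size, dtype=torch.bool)  # built at chunk width
attention_mask = torch.ones(1, seq_len, dtype=torch.bool)             # full padded length

# RuntimeError: The size of tensor a (8192) must match the size of tensor b (16395)
# at non-singleton dimension 1
chunked_attention_mask & attention_mask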

@ArthurZucker
Collaborator

This should be fixed with the latest patch release!
