flex_attention support for Qwen2.5/Gemma is broken #37299

Open
2 of 4 tasks
flukeskywalker opened this issue Apr 5, 2025 · 5 comments

@flukeskywalker

System Info

  • transformers version: 4.50.3
  • Platform: Linux-6.8.0-52-generic-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.30.1
  • Safetensors version: 0.5.3
  • Accelerate version: not installed
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (GPU?): 2.6.0+cu124 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: No
  • Using GPU in script?: Yes
  • GPU type: NVIDIA GeForce RTX 4090

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import torch
from torch.nn.attention.flex_attention import create_block_mask
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2.5-0.5B" # AttributeError: 'BlockMask' object has no attribute 'dim'
# model_name = "meta-llama/Llama-3.2-1B" # works

model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="flex_attention").to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_name)
print(model._supports_flex_attn)

tokens = tokenizer("Hello model, this is human. ")["input_ids"]

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_block_mask(causal_mask, None, None, len(tokens), len(tokens), device="cuda")
model(input_ids=torch.tensor(tokens, dtype=torch.long, device="cuda").unsqueeze(0), attention_mask=block_mask)

Expected behavior

This snippet should run without error, as it does for meta-llama/Llama-3.2-1B, since the Qwen2.5 model is based on the Llama architecture and both models support flex_attention (_supports_flex_attn = True).

The error occurs because Qwen2Model._update_causal_mask() doesn't handle the case where flex_attention is enabled and a BlockMask is passed in as attention_mask. LlamaModel._update_causal_mask() does handle it:

        if self.config._attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = make_flex_block_causal_mask(attention_mask)
            if isinstance(attention_mask, BlockMask):
                return attention_mask

IIUC, adding the same handling to Qwen2Model should fix the issue, and indeed it does on my local fork. But the Qwen2Model code is auto-generated, so the fix must go elsewhere.
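
For reference, here is a minimal, hedged sketch of that early-return, pulled out as a standalone helper (the helper name _maybe_return_block_mask is hypothetical, and the import path for make_flex_block_causal_mask is my assumption based on the Llama code quoted above):

import torch
from torch.nn.attention.flex_attention import BlockMask
# Assumed import path, mirroring what the Llama implementation uses:
from transformers.integrations.flex_attention import make_flex_block_causal_mask

def _maybe_return_block_mask(config, attention_mask):
    # Hypothetical helper: the early-return that LlamaModel._update_causal_mask() performs
    # and that the Qwen2/Mistral/Gemma2 overrides are missing.
    if config._attn_implementation == "flex_attention":
        if isinstance(attention_mask, torch.Tensor):
            # A regular 2D padding mask is converted into a flex-attention BlockMask.
            attention_mask = make_flex_block_causal_mask(attention_mask)
        if isinstance(attention_mask, BlockMask):
            # An already-built BlockMask (as in the repro above) is passed through unchanged.
            return attention_mask
    # Otherwise fall through to the usual SDPA/eager mask construction.
    return None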

@flukeskywalker
Author

I haven't used Gemma yet, but I noticed that its implementation also omits handling of the flex_attention case in its _update_causal_mask() method.

Please correct me if I'm wrong, but this appears to be a wider bug in the model code-generation logic, which misses the flex_attention handling even when the model has _supports_flex_attn = True.

@flukeskywalker flukeskywalker changed the title flex_attention support for Qwen2.5 is broken flex_attention support for Qwen2.5/Gemma is broken Apr 7, 2025
@ArthurZucker
Collaborator

Yes, this is indeed required! Do you want to open a PR for a fix?

@flukeskywalker
Author

flukeskywalker commented Apr 7, 2025

Yes, I'd like to! I'll do it once I'm sure what the fix is.

I'm new to the internals of transformers and the details of the model codegen aren't fully clear to me, but after digging into the codebase I think the Qwen2 issue may be caused by Qwen2Model inheriting from MistralModel instead of LlamaModel.

Now, MistralModel itself inherits from LlamaModel, so this would be okay, except that it also overrides _update_causal_mask() without handling the block mask. I think this means the issue also affects Mistral models. I'm not sure whether there is a good reason why Qwen2Model inherits from MistralModel, or why MistralModel overrides _update_causal_mask().

Potential fix ideas:

  1. Make Qwen2Model inherit directly from LlamaModel to fix Qwen2.
  2. Remove the overridden methods _update_causal_mask() and _prepare_4d_causal_attention_mask_with_cache_position() from MistralModel, since these are already (mostly) defined by LlamaModel.

Technically, 2 alone would also solve the Qwen2 issue, but 1 would still simplify things (a rough sketch of option 1 follows below).
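
For illustration, a minimal sketch of what option 1 could look like, assuming Qwen2Model is generated from a modular definition that currently subclasses MistralModel (the file name modular_qwen2.py and the exact import path are my assumptions, not verified against the repo):

# Hypothetical excerpt from a modular_qwen2.py implementing option 1.
# Subclassing LlamaModel instead of MistralModel would let the generated Qwen2Model
# pick up the flex_attention early-return already present in LlamaModel._update_causal_mask().
from transformers.models.llama.modeling_llama import LlamaModel


class Qwen2Model(LlamaModel):
    pass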


Similarly, the Gemma3 text model inheritance chain is Gemma3TextModel <- Gemma2Model <- GemmaModel <- LlamaModel. This would in principle work with all attention interfaces, except that Gemma2Model overrides _update_causal_mask(), which breaks flex_attention support.

  3. The fix here could be to simply remove Gemma2Model._update_causal_mask(). Edit: as in the Mistral case, we'd need to handle the Hybrid/Static cache types.

It comes down to keeping all masking-related logic contained in LlamaModel, which would hopefully fix flex_attention support for the Qwen2, Gemma2, Gemma3 and Mistral models. What do you think about fixes 1, 2 and 3?

Edit: the above fixes would require some additions to LlamaModel._update_causal_mask() to handle/check the different cache types supported by different models. If you'd rather not bloat LlamaModel like that, the alternative is to check the attn_implementation and create/pass along the block mask in MistralModel (fixing Mistral/Qwen2) and in Gemma2Model (fixing Gemma2 and 3). Do you prefer this simpler fix?

PS: why does Qwen2Model inherit from MistralModel?

@ArthurZucker
Collaborator

Just update the overridden _update_causal_mask where needed, the simpler the better. Flex should be on most / all models if possible!

@ArthurZucker
Collaborator

@Cyrilvallez is reworking our attention creation API.
