# Llama4 chunked attention support #395
Changes from all commits: ca7ba6d, d0070d0, 4d49859, ab78163, 2ecb253, fe57cbf, 9bcb384, 311fdfa, ac7fb12, 3909508, fdb2de2, bc20390
**QEfficient/transformers/models/llama4/modeling_llama4.py**

```diff
@@ -12,8 +12,8 @@
 from torch import nn
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import (
-    BaseModelOutput,
+    BaseModelOutput
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
 )
 from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
```

> **Check failure** on line 16 in `QEfficient/transformers/models/llama4/modeling_llama4.py`.
```diff
@@ -314,8 +314,7 @@
         # self.max_seq_len_cached = config.max_position_embeddings
         # TODO: vbaddi Shouldn't for rope, the max posision_embeddings be original embeddings for rope,
         # chunk size 8192 always? and Revisit when >8K Chunked attention is enabled.
-        self.max_seq_len_cached = config.rope_scaling["original_max_position_embeddings"]
-        # self.max_seq_len_cached = config.max_position_embeddings
+        self.max_seq_len_cached = constants.LLAMA4_MAX_POSITION_EMBEDDINGS

         # Get inverse frequency and scaling function (handles yarn/etc)
         inv_freq, self.attention_scaling = self.rope_init_fn(config, device)
```
```diff
@@ -492,11 +491,21 @@

         query_states = query_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)
+        is_sliding = kwargs.get("is_sliding")

         if past_key_value is not None:
+            chunk_postion_ids = position_ids
+
+            if self.use_rope:
+                chunk_postion_ids = torch.where(
+                    chunk_postion_ids != -1, chunk_postion_ids % self.config.attention_chunk_size, chunk_postion_ids
+                )
+
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+            cache_kwargs = {"batch_index": batch_index, "position_ids": chunk_postion_ids, "is_sliding": is_sliding}
+            key_states, value_states = past_key_value.update_hybrid_chunked(
+                key_states, value_states, self.layer_idx, cache_kwargs
+            )

         attention_interface: Callable = eager_attention_forward
```
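The cache update above stores keys and values at chunk-local slots for RoPE layers: absolute position ids are wrapped modulo `attention_chunk_size`, while the `-1` padding sentinel is passed through untouched. A minimal standalone sketch of that wrapping behaviour (toy chunk size, not the PR's code):

```python
import torch

def wrap_positions(position_ids: torch.Tensor, attention_chunk_size: int) -> torch.Tensor:
    # Map absolute positions onto chunk-local cache slots; keep the -1 padding
    # sentinel as-is so padded tokens never address a real cache row.
    return torch.where(position_ids != -1, position_ids % attention_chunk_size, position_ids)

# Toy chunk size of 8: absolute positions 8 and 9 wrap back onto slots 0 and 1.
pos = torch.tensor([[0, 1, 2, 7, 8, 9, -1]])
print(wrap_positions(pos, attention_chunk_size=8))
# tensor([[ 0,  1,  2,  7,  0,  1, -1]])
```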
```diff
@@ -540,7 +549,7 @@
         residual = hidden_states

         # use local attention mask for ROPE layers
-        if self.use_chunked_attention and chunk_causal_mask is not None:
+        if self.use_chunked_attention:
             attention_mask = chunk_causal_mask

         hidden_states = self.input_layernorm(hidden_states)
```
```diff
@@ -640,11 +649,14 @@
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_seen_tokens)
-
-        _, chunk_causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        causal_mask = _create_causal_mask(
+            position_ids=position_ids, target_length=past_key_values.key_cache[3].shape[-2]
         )
+        chunked_position_ids = torch.where(
+            position_ids != -1, position_ids % self.config.attention_chunk_size, position_ids
+        )
+        target_length = min(past_key_values.key_cache[0].shape[-2], torch.tensor(self.config.attention_chunk_size))
+        chunk_causal_mask = _create_causal_mask(position_ids=chunked_position_ids, target_length=target_length)

         # embed positions
         hidden_states = inputs_embeds
```
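For the chunked mask, both the position ids and the target length are reduced to the chunk, so a token only attends to cached keys that are causal and fall inside its own chunk. A rough sketch of that masking rule (a standalone illustration, not the repo's `_create_causal_mask`):

```python
import torch

def chunked_causal_mask(q_pos: torch.Tensor, k_pos: torch.Tensor, chunk: int) -> torch.Tensor:
    # A key is visible iff it is causal (k <= q) and lives in the same attention chunk as the query.
    same_chunk = (q_pos[:, None] // chunk) == (k_pos[None, :] // chunk)
    causal = k_pos[None, :] <= q_pos[:, None]
    return same_chunk & causal

q = torch.arange(6)  # query positions 0..5
k = torch.arange(6)  # key positions 0..5
print(chunked_causal_mask(q, k, chunk=4).int())
# rows 4 and 5 start a new chunk, so they cannot see positions 0..3
```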
```diff
@@ -727,6 +739,15 @@
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

+        is_sliding = None
+        if hasattr(self.config.get_text_config(), "no_rope_layers"):
+            is_sliding = self.config.no_rope_layers
+        else:
+            layer_switch = getattr(self.config, "sliding_window_pattern", 2)
+            is_sliding = [bool((i + 1) % layer_switch) for i in range(self.config.num_hidden_layers)]
+
+        kwargs["is_sliding"] = is_sliding
+
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.model(
             input_ids=input_ids,
```
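`is_sliding` marks which layers take the chunked (RoPE) cache path. When the text config exposes `no_rope_layers`, that list is used directly; otherwise the fallback derives the pattern from `sliding_window_pattern`. For example, with an assumed pattern of 4 and 8 layers, every fourth layer comes out non-sliding:

```python
# Assumed values for illustration: sliding_window_pattern = 4, num_hidden_layers = 8.
layer_switch = 4
num_hidden_layers = 8
is_sliding = [bool((i + 1) % layer_switch) for i in range(num_hidden_layers)]
print(is_sliding)
# [True, True, True, False, True, True, True, False]
```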
```diff
@@ -860,6 +881,15 @@

         prefill_seq_len = prefill_seq_len if prefill_seq_len else 32
         ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN
+        chunk_ctx_len = min(
+            ctx_len,
+            (
+                self.config.text_config.attention_chunk_size
+                if hasattr(self, "config")
+                else constants.LLAMA4_ATTENTION_CHUNK_SIZE
+            ),
+        )

         if img_size is None and hasattr(self.config.vision_config, "image_size"):
             img_size = getattr(self.config.vision_config, "image_size")
         elif img_size is None:
```
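`chunk_ctx_len` caps the chunked-attention cache at the model's `attention_chunk_size`, so RoPE layers never allocate more context than one chunk. A quick illustration with assumed values (8192 is the chunk size quoted in the TODO above):

```python
attention_chunk_size = 8192               # assumed Llama 4 chunk size
print(min(4096, attention_chunk_size))    # ctx_len below the chunk size -> 4096
print(min(16384, attention_chunk_size))   # ctx_len above the chunk size -> 8192
```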
```diff
@@ -889,6 +919,8 @@
                 "max_num_tiles": max_num_tiles,
                 "img_size": img_size,
                 "vision_size": vision_size,
+                "chunk_length": prefill_seq_len,
+                "chunk_ctx_len": chunk_ctx_len,
             },
             {
                 "batch_size": batch_size,
```

> **Review comment:** In specializations, we need the total CL also, right? For the NoPE layers' KV.
```diff
@@ -897,6 +929,8 @@
                 "max_num_tiles": max_num_tiles,
                 "img_size": img_size,
                 "vision_size": vision_size,
+                "chunk_length": prefill_seq_len,
+                "chunk_ctx_len": chunk_ctx_len,
             },
         ]
```
```diff
@@ -918,8 +952,14 @@
         lang_dynamic_axes["vision_embeds"] = {0: "vision_size"}
         vision_dynamic_axes["pixel_values"] = {0: "max_num_tiles", 2: "img_size", 3: "img_size"}

-        pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
+        pkv_dynamic_axes = {0: "batch_size"}
         for i in range(self.language_model.config.num_hidden_layers):
+            # switch between chunk_ctx_len and ctx_len for RoPE and NoPE layers.
+            if int((i + 1) % 4 != 0):
+                pkv_dynamic_axes[2] = "chunk_ctx_len"
+            else:
+                pkv_dynamic_axes[2] = "ctx_len"
+
             for kv in ["key", "value"]:
                 lang_dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
```
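Because Llama 4 interleaves RoPE (chunked) and NoPE (global) layers, the exported KV inputs get different dynamic axes: every fourth layer keeps the full `ctx_len` axis and the rest use `chunk_ctx_len`. A toy sketch of the resulting mapping for an assumed 8-layer model (it prints a fresh mapping per layer purely for illustration):

```python
# Which dynamic axis each layer's KV input would get, for an assumed 8-layer model.
num_hidden_layers = 8
for i in range(num_hidden_layers):
    axis2 = "chunk_ctx_len" if (i + 1) % 4 != 0 else "ctx_len"
    print(f"past_key.{i}: {{0: 'batch_size', 2: '{axis2}'}}")
# layers 0-2 and 4-6 -> chunk_ctx_len (RoPE); layers 3 and 7 -> ctx_len (NoPE)
```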
```diff
@@ -1006,7 +1046,7 @@
         lang_inputs["past_key_values"] = [[] for _ in range(self.language_model.config.num_hidden_layers)]
         for i in range(self.language_model.config.num_hidden_layers):
             for kv in ["key", "value"]:
-                lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
+                lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape[0][0].shape, dtype=torch.float32))

         inputs = {}
         if kv_offload:
```
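Indexing `kv_cache_shape[0][0].shape` implies the dummy cache shapes are now held per layer rather than as a single tuple, which is what lets RoPE layers carry a chunk-sized cache while NoPE layers keep the full context. A hypothetical sketch of deriving such per-layer shapes (the helper and its arguments are assumptions, not this repo's API):

```python
# Hypothetical: (batch, n_kv_heads, cache_len, head_dim) per decoder layer.
def per_layer_kv_shapes(num_layers, batch, n_kv_heads, head_dim, ctx_len, chunk_ctx_len):
    shapes = []
    for i in range(num_layers):
        # RoPE layers get a chunk-sized cache, every fourth (NoPE) layer the full context.
        cache_len = chunk_ctx_len if (i + 1) % 4 != 0 else ctx_len
        shapes.append((batch, n_kv_heads, cache_len, head_dim))
    return shapes

for i, shape in enumerate(per_layer_kv_shapes(8, 1, 8, 128, ctx_len=4096, chunk_ctx_len=2048)):
    print(i, shape)
```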
> **Review comment:** @asmigosw As we have discussed, please restructure this to reuse the hybrid cache function.

> **Review comment:** Also, please report the output match against the reference with a smaller config for `chunked_window`. It will be hard to verify at 8K, so change the config for testing purposes.
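Following up on that suggestion, a hypothetical way to set up the parity check is to shrink `attention_chunk_size` (and the layer count) on the Hugging Face config before building both the reference and the transformed model; the values below are placeholders:

```python
from transformers import Llama4Config

# Start from the default Llama 4 config and shrink the pieces relevant to chunked attention.
cfg = Llama4Config()
cfg.text_config.attention_chunk_size = 128   # small window so a short prompt spans several chunks
cfg.text_config.num_hidden_layers = 4        # keep the test model quick to run
# Build the reference HF model and the QEfficient model from `cfg`
# and compare logits on the same inputs.
```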