Llama4 chunked attention support #395

Closed · wants to merge 12 commits
194 changes: 193 additions & 1 deletion QEfficient/transformers/cache_utils.py
@@ -9,7 +9,7 @@
from typing import Any, Dict, Optional, Tuple

import torch
from transformers.cache_utils import DynamicCache, EncoderDecoderCache
from transformers.cache_utils import DynamicCache, EncoderDecoderCache, HybridCache

from QEfficient.customop import (
CtxGatherFunc,
@@ -259,6 +259,151 @@ def update3D(

return k_out, v_out

    def _sliding_update(

Contributor: @asmigosw As we have discussed, please restructure this to reuse the hybrid cache function.

Contributor: Also, please report the output match against the reference, using a smaller config for chunked_window. It will be hard to verify at 8K, so change the config for testing purposes.


        self,
        layer_idx,
        key_states,
        value_states,
        position_ids,
        batch_index,
        k_out,
        v_out,
    ):
        N = self.key_cache[layer_idx].shape[2]

        # Update the position_ids to handle the sliding window
        kv_position_ids = torch.where(position_ids == -1, position_ids, position_ids % (N - 1))
        kv_position_ids = torch.where(position_ids.max() >= (N - 1) * 2, (position_ids + 1) % N, kv_position_ids)

        # Update the cache
        self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
        self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], kv_position_ids, value_states)

        k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]

        # Original Gather
        ctx_len = min(N, k_out.shape[2])
        ctx_indices = torch.arange(ctx_len)[None, None, ...]
        gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1)
        invalid_mask = ctx_indices > gather_limit
        if torch.onnx.is_in_onnx_export():
            invalid_idx_value = torch.iinfo(torch.int32).max
        else:
            invalid_idx_value = 0
        ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)

        # rolling indices
        all_indices = torch.arange(N) + kv_position_ids.max() + 1
        rolling_indices = torch.where(all_indices > N - 1, all_indices % N, all_indices)

        final_indices = torch.where(position_ids.max() >= (N - 1), rolling_indices, ctx_indices)

        k_out = CtxGatherFunc.apply(k_out, final_indices)
        v_out = CtxGatherFunc.apply(v_out, final_indices)
        prefill_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)

        # Handle the rolling indices
        v_out = torch.where(position_ids.max() >= (N - 1), v_out, prefill_v_out)
        return k_out, v_out
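
For intuition, here is a minimal standalone sketch (plain PyTorch; the window length N=8 and the position stream are illustrative assumptions, not values from this PR) of how the remapping at the top of _sliding_update wraps positions into a fixed-size sliding cache:

import torch

N = 8  # assumed sliding-window cache length, chosen only for illustration
position_ids = torch.arange(10).unsqueeze(0)  # a batch of one sequence at positions 0..9

# Same remapping as above: keep -1 padding untouched, wrap everything else into the window.
kv_position_ids = torch.where(position_ids == -1, position_ids, position_ids % (N - 1))
kv_position_ids = torch.where(position_ids.max() >= (N - 1) * 2, (position_ids + 1) % N, kv_position_ids)

print(kv_position_ids)  # tensor([[0, 1, 2, 3, 4, 5, 6, 0, 1, 2]]) -- positions 7..9 reuse the oldest slots

Once position_ids.max() reaches (N - 1) * 2, the second torch.where switches to the rolling (position_ids + 1) % N scheme, so new tokens keep overwriting the oldest cache slots.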

    def _static_update(
        self,
        layer_idx,
        key_states,
        value_states,
        position_ids,
        batch_index,
        k_out,
        v_out,
    ):
        # Update the cache
        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
            k_out, v_out = key_states, value_states
        else:
            # Scatter
            if batch_index is not None:
                invalid_scatter_index = torch.iinfo(torch.int32).max
                scatter_position_ids = torch.where(position_ids < 0, invalid_scatter_index, position_ids)

                self.key_cache[layer_idx] = CtxScatterFuncCB.apply(
                    self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states
                )

                self.value_cache[layer_idx] = CtxScatterFuncCB.apply(
                    self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states
                )
            else:
                self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states)
                self.value_cache[layer_idx] = CtxScatterFunc.apply(
                    self.value_cache[layer_idx], position_ids, value_states
                )

            k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]

            # Gather
            ctx_len = k_out.shape[2]
            ctx_indices = torch.arange(ctx_len)[None, None, ...]
            gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
            invalid_mask = ctx_indices > gather_limit

            if torch.onnx.is_in_onnx_export():
                invalid_idx_value = torch.iinfo(torch.int32).max
            else:
                invalid_idx_value = 0

            ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
            if batch_index is not None:
                k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices)
                v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices)
            else:
                k_out = CtxGatherFunc.apply(k_out, ctx_indices)
                v_out = CtxGatherFunc.apply(v_out, ctx_indices)
            v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)

        return k_out, v_out
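
As a toy illustration of the gather masking shared by both update paths (the shapes here are made up; in the real code ctx_indices spans the cache length and gather_limit comes from position_ids), slots beyond the largest valid position are flagged invalid and their values are later zeroed:

import torch

ctx_len = 6
position_ids = torch.tensor([[0, 1, 2]])  # a batch of one sequence at positions 0..2

ctx_indices = torch.arange(ctx_len)[None, None, ...]                  # shape (1, 1, 6)
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)  # shape (1, 1, 1), value 2
invalid_mask = ctx_indices > gather_limit

print(invalid_mask.squeeze())  # tensor([False, False, False,  True,  True,  True])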

    def update_hybrid_chunked(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates cache with support for both sliding window and position-based updates.
        """
        if cache_kwargs is None:
            cache_kwargs = {}

        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]
        key_states = key_states.to(k_out.dtype)
        value_states = value_states.to(v_out.dtype)

        # Get cache parameters
        position_ids = cache_kwargs.get("position_ids")
        batch_index = cache_kwargs.get("batch_index", None)
        sliding_window = cache_kwargs.get("is_sliding", None)

        if sliding_window[layer_idx]:
            update_fn = self._sliding_update
        else:
            update_fn = self._static_update

        k_out, v_out = update_fn(
            layer_idx,
            key_states,
            value_states,
            position_ids,
            batch_index,
            k_out,
            v_out,
        )

        return k_out, v_out


class QEffEncoderDecoderCache(EncoderDecoderCache):
"""
@@ -283,3 +428,50 @@ def from_legacy_cache(
cache.cross_attention_cache.update(key_states, value_states, layer_idx)
cache.is_updated[layer_idx] = True
return cache


class QEffHybridCache(HybridCache):
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        position_ids = cache_kwargs.get("position_ids")
        sliding_window_pattern = cache_kwargs.get("sliding_window_pattern")
        is_sliding_layer = torch.tensor(bool((layer_idx + 1) % sliding_window_pattern))
        N = self.key_cache[layer_idx].shape[2]

        kv_position_ids = torch.where((~is_sliding_layer | (position_ids == -1)), position_ids, position_ids % (N - 1))

        kv_position_ids = torch.where(
            is_sliding_layer & (position_ids.max() >= (N - 1) * 2), (position_ids + 1) % N, kv_position_ids
        )

        valid_mask = (kv_position_ids != -1).unsqueeze(1).unsqueeze(-1)
        key_states = torch.where(valid_mask == 1, key_states, torch.zeros_like(key_states))
        value_states = torch.where(valid_mask == 1, value_states, torch.zeros_like(value_states))
        self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
        self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], kv_position_ids, value_states)
        k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]

        # Original Gather
        ctx_len = self.key_cache[layer_idx].shape[2]
        ctx_indices = torch.arange(ctx_len)[None, None, ...]
        gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1)
        invalid_mask = ctx_indices > gather_limit
        if torch.onnx.is_in_onnx_export():
            invalid_idx_value = torch.iinfo(torch.int32).max
        else:
            invalid_idx_value = 0
        ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)

        all_indices = torch.arange(N) + kv_position_ids.max() + 1
        rolling_indices = torch.where(all_indices > N - 1, all_indices % N, all_indices)
        final_indices = torch.where((is_sliding_layer & (position_ids.max() >= (N - 1))), rolling_indices, ctx_indices)
        k_out = CtxGatherFunc.apply(k_out, final_indices)
        v_out = CtxGatherFunc.apply(v_out, final_indices)
        ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
        v_out = torch.where((is_sliding_layer & (position_ids.max() >= (N - 1))), v_out, ctx_v_out)
        return k_out, v_out
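
A small self-contained sketch (the layer count is chosen only for illustration) of the layer classification that both QEffHybridCache and update_hybrid_chunked rely on: with sliding_window_pattern = 4, every fourth layer is a global-attention (NoPE) layer and the rest use chunked/sliding attention:

num_hidden_layers = 12        # assumed small layer count, for illustration only
sliding_window_pattern = 4    # every 4th layer is a global-attention (NoPE) layer

is_sliding = [bool((i + 1) % sliding_window_pattern) for i in range(num_hidden_layers)]
global_layers = [i for i, sliding in enumerate(is_sliding) if not sliding]
print(global_layers)  # [3, 7, 11] -- these layers keep the full-length (ctx_len) KV cache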
64 changes: 52 additions & 12 deletions QEfficient/transformers/models/llama4/modeling_llama4.py
@@ -12,8 +12,8 @@
from torch import nn
from transformers.cache_utils import Cache
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutput
BaseModelOutputWithPast,

Check failure on line 16 in QEfficient/transformers/models/llama4/modeling_llama4.py (GitHub Actions / lint, Ruff): QEfficient/transformers/models/llama4/modeling_llama4.py:16:5: SyntaxError: Expected ',', found name
CausalLMOutputWithPast,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
@@ -314,8 +314,7 @@
# self.max_seq_len_cached = config.max_position_embeddings
# TODO: vbaddi: For RoPE, shouldn't max_position_embeddings always be the original embeddings
# (chunk size 8192)? Revisit when >8K chunked attention is enabled.
self.max_seq_len_cached = config.rope_scaling["original_max_position_embeddings"]
# self.max_seq_len_cached = config.max_position_embeddings
self.max_seq_len_cached = constants.LLAMA4_MAX_POSITION_EMBEDDINGS

# Get inverse frequency and scaling function (handles yarn/etc)
inv_freq, self.attention_scaling = self.rope_init_fn(config, device)
@@ -492,11 +491,21 @@

query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
is_sliding = kwargs.get("is_sliding")

if past_key_value is not None:
chunk_position_ids = position_ids

if self.use_rope:
chunk_position_ids = torch.where(
chunk_position_ids != -1, chunk_position_ids % self.config.attention_chunk_size, chunk_position_ids
)

# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
cache_kwargs = {"batch_index": batch_index, "position_ids": chunk_position_ids, "is_sliding": is_sliding}
key_states, value_states = past_key_value.update_hybrid_chunked(
key_states, value_states, self.layer_idx, cache_kwargs
)

attention_interface: Callable = eager_attention_forward

@@ -540,7 +549,7 @@
residual = hidden_states

# use local attention mask for ROPE layers
if self.use_chunked_attention and chunk_causal_mask is not None:
if self.use_chunked_attention:
attention_mask = chunk_causal_mask

hidden_states = self.input_layernorm(hidden_states)
@@ -640,11 +649,14 @@
if position_ids is None:
position_ids = cache_position.unsqueeze(0)

causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_seen_tokens)

_, chunk_causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
causal_mask = _create_causal_mask(
position_ids=position_ids, target_length=past_key_values.key_cache[3].shape[-2]
Contributor: Here, instead of the hard-coded key_cache[3], can we generalize using some config value?

)
chunked_position_ids = torch.where(
position_ids != -1, position_ids % self.config.attention_chunk_size, position_ids
)
target_length = min(past_key_values.key_cache[0].shape[-2], torch.tensor(self.config.attention_chunk_size))
chunk_causal_mask = _create_causal_mask(position_ids=chunked_position_ids, target_length=target_length)

# embed positions
hidden_states = inputs_embeds
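
For reference, an illustrative-only sketch of the chunked_position_ids remapping used above to build the local causal mask; the chunk size is shrunk from Llama4's 8192 to 10 purely for readability:

import torch

attention_chunk_size = 10  # illustrative; the Llama4 config uses 8192
position_ids = torch.arange(25).unsqueeze(0)

chunked_position_ids = torch.where(
    position_ids != -1, position_ids % attention_chunk_size, position_ids
)
print(chunked_position_ids)  # positions restart at 0 every 10 tokens, so the mask is causal within each chunk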
@@ -727,6 +739,15 @@
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

is_sliding = None
if hasattr(self.config.get_text_config(), "no_rope_layers"):
is_sliding = self.config.no_rope_layers
else:
layer_switch = getattr(self.config, "sliding_window_pattern", 2)
is_sliding = [bool((i + 1) % layer_switch) for i in range(self.config.num_hidden_layers)]

kwargs["is_sliding"] = is_sliding

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
@@ -860,6 +881,15 @@

prefill_seq_len = prefill_seq_len if prefill_seq_len else 32
ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN
chunk_ctx_len = min(
ctx_len,
(
self.config.text_config.attention_chunk_size
if hasattr(self, "config")
else constants.LLAMA4_ATTENTION_CHUNK_SIZE
),
)

if img_size is None and hasattr(self.config.vision_config, "image_size"):
img_size = getattr(self.config.vision_config, "image_size")
elif img_size is None:
@@ -889,6 +919,8 @@
"max_num_tiles": max_num_tiles,
"img_size": img_size,
"vision_size": vision_size,
"chunk_length": prefill_seq_len,
"chunk_ctx_len": chunk_ctx_len,
Contributor: In specializations, we also need the total CL, right? For the NoPE layers' KV.

},
{
"batch_size": batch_size,
@@ -897,6 +929,8 @@
"max_num_tiles": max_num_tiles,
"img_size": img_size,
"vision_size": vision_size,
"chunk_length": prefill_seq_len,
"chunk_ctx_len": chunk_ctx_len,
},
]

Expand All @@ -918,8 +952,14 @@
lang_dynamic_axes["vision_embeds"] = {0: "vision_size"}
vision_dynamic_axes["pixel_values"] = {0: "max_num_tiles", 2: "img_size", 3: "img_size"}

pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
pkv_dynamic_axes = {0: "batch_size"}
for i in range(self.language_model.config.num_hidden_layers):
# switch between chunk_ctx_len and ctx_len for RoPE and NoPE layers;
# build a fresh dict per layer so all entries do not end up aliasing the same object.
if (i + 1) % 4 != 0:
pkv_dynamic_axes = {0: "batch_size", 2: "chunk_ctx_len"}
else:
pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}

for kv in ["key", "value"]:
lang_dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes

@@ -1006,7 +1046,7 @@
lang_inputs["past_key_values"] = [[] for _ in range(self.language_model.config.num_hidden_layers)]
for i in range(self.language_model.config.num_hidden_layers):
for kv in ["key", "value"]:
lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape[0][0].shape, dtype=torch.float32))

inputs = {}
if kv_offload: