Commit fdb2de2

asmigosw authored and quic-rishinr committed
Added Hybrid Chunked Cache for Llama4
Signed-off-by: Asmita Goswami <[email protected]>
1 parent 3909508 commit fdb2de2

File tree

2 files changed: +160 −3 lines

QEfficient/transformers/cache_utils.py

Lines changed: 145 additions & 0 deletions
@@ -259,6 +259,151 @@ def update3D(
 
         return k_out, v_out
 
+    def _sliding_update(
+        self,
+        layer_idx,
+        key_states,
+        value_states,
+        position_ids,
+        batch_index,
+        k_out,
+        v_out,
+    ):
+        N = self.key_cache[layer_idx].shape[2]
+
+        # Update the position_ids to handle the sliding window
+        kv_position_ids = torch.where(position_ids == -1, position_ids, position_ids % (N - 1))
+        kv_position_ids = torch.where(position_ids.max() >= (N - 1) * 2, (position_ids + 1) % N, kv_position_ids)
+
+        # Update the cache
+        self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
+        self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], kv_position_ids, value_states)
+
+        k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+        # Original gather
+        ctx_len = min(N, k_out.shape[2])
+        ctx_indices = torch.arange(ctx_len)[None, None, ...]
+        gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1)
+        invalid_mask = ctx_indices > gather_limit
+        if torch.onnx.is_in_onnx_export():
+            invalid_idx_value = torch.iinfo(torch.int32).max
+        else:
+            invalid_idx_value = 0
+        ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
+
+        # Rolling indices
+        all_indices = torch.arange(N) + kv_position_ids.max() + 1
+        rolling_indices = torch.where(all_indices > N - 1, all_indices % N, all_indices)
+
+        final_indices = torch.where(position_ids.max() >= (N - 1), rolling_indices, ctx_indices)
+
+        k_out = CtxGatherFunc.apply(k_out, final_indices)
+        v_out = CtxGatherFunc.apply(v_out, final_indices)
+        prefill_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
+
+        # Handle the rolling indices
+        v_out = torch.where(position_ids.max() >= (N - 1), v_out, prefill_v_out)
+        return k_out, v_out
+
+    def _static_update(
+        self,
+        layer_idx,
+        key_states,
+        value_states,
+        position_ids,
+        batch_index,
+        k_out,
+        v_out,
+    ):
+        # Update the cache
+        if len(self.key_cache) <= layer_idx:
+            self.key_cache.append(key_states)
+            self.value_cache.append(value_states)
+            k_out, v_out = key_states, value_states
+        else:
+            # Scatter
+            if batch_index is not None:
+                invalid_scatter_index = torch.iinfo(torch.int32).max
+                scatter_position_ids = torch.where(position_ids < 0, invalid_scatter_index, position_ids)
+
+                self.key_cache[layer_idx] = CtxScatterFuncCB.apply(
+                    self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states
+                )
+
+                self.value_cache[layer_idx] = CtxScatterFuncCB.apply(
+                    self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states
+                )
+            else:
+                self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states)
+                self.value_cache[layer_idx] = CtxScatterFunc.apply(
+                    self.value_cache[layer_idx], position_ids, value_states
+                )
+
+            k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+            # Gather
+            ctx_len = k_out.shape[2]
+            ctx_indices = torch.arange(ctx_len)[None, None, ...]
+            gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
+            invalid_mask = ctx_indices > gather_limit
+
+            if torch.onnx.is_in_onnx_export():
+                invalid_idx_value = torch.iinfo(torch.int32).max
+            else:
+                invalid_idx_value = 0
+
+            ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
+            if batch_index is not None:
+                k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices)
+                v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices)
+            else:
+                k_out = CtxGatherFunc.apply(k_out, ctx_indices)
+                v_out = CtxGatherFunc.apply(v_out, ctx_indices)
+            v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
+
+        return k_out, v_out
+
+    def update_hybrid_chunked(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates cache with support for both sliding-window and position-based (static) updates.
+        """
+        if cache_kwargs is None:
+            cache_kwargs = {}
+
+        k_out = self.key_cache[layer_idx]
+        v_out = self.value_cache[layer_idx]
+        key_states = key_states.to(k_out.dtype)
+        value_states = value_states.to(v_out.dtype)
+
+        # Get cache parameters
+        position_ids = cache_kwargs.get("position_ids")
+        batch_index = cache_kwargs.get("batch_index", None)
+        sliding_window = cache_kwargs.get("is_sliding", None)
+
+        if sliding_window[layer_idx]:
+            update_fn = self._sliding_update
+        else:
+            update_fn = self._static_update
+
+        k_out, v_out = update_fn(
+            layer_idx,
+            key_states,
+            value_states,
+            position_ids,
+            batch_index,
+            k_out,
+            v_out,
+        )
+
+        return k_out, v_out
 
 class QEffEncoderDecoderCache(EncoderDecoderCache):
     """

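Both update paths rely on the same gather-limit trick: any cache-slot index above the largest position written so far is flagged invalid, gathered through a safe index (INT32 max under ONNX export, 0 otherwise), and its value row zeroed so unwritten slots cannot contribute to attention. A toy illustration of the mask, with shapes and values assumed for this sketch rather than taken from the commit:

import torch

ctx_len = 6                                  # total cache slots
position_ids = torch.tensor([[0, 1, 2, 3]])  # one sequence, 4 positions written
ctx_indices = torch.arange(ctx_len)[None, None, ...]                  # (1, 1, 6)
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)  # (1, 1, 1)
invalid_mask = ctx_indices > gather_limit
print(invalid_mask)  # tensor([[[False, False, False, False,  True,  True]]])
# Slots 4 and 5 were never written: they are gathered via the safe index and
# their value rows are then zeroed by the torch.where(invalid_mask.unsqueeze(-1), ...) step.
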
QEfficient/transformers/models/llama4/modeling_llama4.py

Lines changed: 15 additions & 3 deletions
@@ -10,7 +10,7 @@
 
 import torch
 from torch import nn
-from transformers.cache_utils import Cache, DynamicCache
+from transformers.cache_utils import Cache
 from transformers.modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPast,
@@ -491,6 +491,7 @@ def forward(
 
         query_states = query_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)
+        is_sliding = kwargs.get("is_sliding")
 
         if past_key_value is not None:
             chunk_postion_ids = position_ids
@@ -501,8 +502,10 @@
             )
 
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"batch_index": batch_index, "position_ids": chunk_postion_ids}
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+            cache_kwargs = {"batch_index": batch_index, "position_ids": chunk_postion_ids, "is_sliding": is_sliding}
+            key_states, value_states = past_key_value.update_hybrid_chunked(
+                key_states, value_states, self.layer_idx, cache_kwargs
+            )
 
         attention_interface: Callable = eager_attention_forward
 
@@ -736,6 +739,15 @@ def forward(
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
+        is_sliding = None
+        if hasattr(self.config.get_text_config(), "no_rope_layers"):
+            is_sliding = self.config.no_rope_layers
+        else:
+            layer_switch = getattr(self.config, "sliding_window_pattern", 2)
+            is_sliding = [bool((i + 1) % layer_switch) for i in range(self.config.num_hidden_layers)]
+
+        kwargs["is_sliding"] = is_sliding
+
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.model(
             input_ids=input_ids,