Skip to content

Commit bc20390

Browse files
quic-akuruvilquic-rishinr
authored andcommitted
Added HybridCache class and function
Signed-off-by: Ann <[email protected]>
1 parent fdb2de2 commit bc20390

File tree

1 file changed

+48
-1
lines changed

1 file changed

+48
-1
lines changed

QEfficient/transformers/cache_utils.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from typing import Any, Dict, Optional, Tuple
1010

1111
import torch
12-
from transformers.cache_utils import DynamicCache, EncoderDecoderCache
12+
from transformers.cache_utils import DynamicCache, EncoderDecoderCache, HybridCache
1313

1414
from QEfficient.customop import (
1515
CtxGatherFunc,
@@ -428,3 +428,50 @@ def from_legacy_cache(
428428
cache.cross_attention_cache.update(key_states, value_states, layer_idx)
429429
cache.is_updated[layer_idx] = True
430430
return cache
431+
432+
433+
class QEffHybridCache(HybridCache):
434+
def update(
435+
self,
436+
key_states: torch.Tensor,
437+
value_states: torch.Tensor,
438+
layer_idx: int,
439+
cache_kwargs: Optional[Dict[str, Any]] = None,
440+
) -> Tuple[torch.Tensor, torch.Tensor]:
441+
position_ids = cache_kwargs.get("position_ids")
442+
sliding_window_pattern = cache_kwargs.get("sliding_window_pattern")
443+
is_sliding_layer = torch.tensor(bool((layer_idx + 1) % sliding_window_pattern))
444+
N = self.key_cache[layer_idx].shape[2]
445+
446+
kv_position_ids = torch.where((~is_sliding_layer | (position_ids == -1)), position_ids, position_ids % (N - 1))
447+
448+
kv_position_ids = torch.where(
449+
is_sliding_layer & (position_ids.max() >= (N - 1) * 2), (position_ids + 1) % N, kv_position_ids
450+
)
451+
452+
valid_mask = (kv_position_ids != -1).unsqueeze(1).unsqueeze(-1)
453+
key_states = torch.where(valid_mask == 1, key_states, torch.zeros_like(key_states))
454+
value_states = torch.where(valid_mask == 1, value_states, torch.zeros_like(value_states))
455+
self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
456+
self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], kv_position_ids, value_states)
457+
k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
458+
459+
# Original Gather
460+
ctx_len = self.key_cache[layer_idx].shape[2]
461+
ctx_indices = torch.arange(ctx_len)[None, None, ...]
462+
gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1)
463+
invalid_mask = ctx_indices > gather_limit
464+
if torch.onnx.is_in_onnx_export():
465+
invalid_idx_value = torch.iinfo(torch.int32).max
466+
else:
467+
invalid_idx_value = 0
468+
ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
469+
470+
all_indices = torch.arange(N) + kv_position_ids.max() + 1
471+
rolling_indices = torch.where(all_indices > N - 1, all_indices % N, all_indices)
472+
final_indices = torch.where((is_sliding_layer & (position_ids.max() >= (N - 1))), rolling_indices, ctx_indices)
473+
k_out = CtxGatherFunc.apply(k_out, final_indices)
474+
v_out = CtxGatherFunc.apply(v_out, final_indices)
475+
ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
476+
v_out = torch.where((is_sliding_layer & (position_ids.max() >= (N - 1))), v_out, ctx_v_out)
477+
return k_out, v_out

0 commit comments

Comments
 (0)