|
9 | 9 | from typing import Any, Dict, Optional, Tuple
|
10 | 10 |
|
11 | 11 | import torch
|
12 |
| -from transformers.cache_utils import DynamicCache, EncoderDecoderCache |
| 12 | +from transformers.cache_utils import DynamicCache, EncoderDecoderCache, HybridCache |
13 | 13 |
|
14 | 14 | from QEfficient.customop import (
|
15 | 15 | CtxGatherFunc,
|
@@ -428,3 +428,50 @@ def from_legacy_cache(
|
428 | 428 | cache.cross_attention_cache.update(key_states, value_states, layer_idx)
|
429 | 429 | cache.is_updated[layer_idx] = True
|
430 | 430 | return cache
|
| 431 | + |
| 432 | + |
| 433 | +class QEffHybridCache(HybridCache): |
| 434 | + def update( |
| 435 | + self, |
| 436 | + key_states: torch.Tensor, |
| 437 | + value_states: torch.Tensor, |
| 438 | + layer_idx: int, |
| 439 | + cache_kwargs: Optional[Dict[str, Any]] = None, |
| 440 | + ) -> Tuple[torch.Tensor, torch.Tensor]: |
| 441 | + position_ids = cache_kwargs.get("position_ids") |
| 442 | + sliding_window_pattern = cache_kwargs.get("sliding_window_pattern") |
| 443 | + is_sliding_layer = torch.tensor(bool((layer_idx + 1) % sliding_window_pattern)) |
| 444 | + N = self.key_cache[layer_idx].shape[2] |
| 445 | + |
| 446 | + kv_position_ids = torch.where((~is_sliding_layer | (position_ids == -1)), position_ids, position_ids % (N - 1)) |
| 447 | + |
| 448 | + kv_position_ids = torch.where( |
| 449 | + is_sliding_layer & (position_ids.max() >= (N - 1) * 2), (position_ids + 1) % N, kv_position_ids |
| 450 | + ) |
| 451 | + |
| 452 | + valid_mask = (kv_position_ids != -1).unsqueeze(1).unsqueeze(-1) |
| 453 | + key_states = torch.where(valid_mask == 1, key_states, torch.zeros_like(key_states)) |
| 454 | + value_states = torch.where(valid_mask == 1, value_states, torch.zeros_like(value_states)) |
| 455 | + self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) |
| 456 | + self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], kv_position_ids, value_states) |
| 457 | + k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] |
| 458 | + |
| 459 | + # Original Gather |
| 460 | + ctx_len = self.key_cache[layer_idx].shape[2] |
| 461 | + ctx_indices = torch.arange(ctx_len)[None, None, ...] |
| 462 | + gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1) |
| 463 | + invalid_mask = ctx_indices > gather_limit |
| 464 | + if torch.onnx.is_in_onnx_export(): |
| 465 | + invalid_idx_value = torch.iinfo(torch.int32).max |
| 466 | + else: |
| 467 | + invalid_idx_value = 0 |
| 468 | + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) |
| 469 | + |
| 470 | + all_indices = torch.arange(N) + kv_position_ids.max() + 1 |
| 471 | + rolling_indices = torch.where(all_indices > N - 1, all_indices % N, all_indices) |
| 472 | + final_indices = torch.where((is_sliding_layer & (position_ids.max() >= (N - 1))), rolling_indices, ctx_indices) |
| 473 | + k_out = CtxGatherFunc.apply(k_out, final_indices) |
| 474 | + v_out = CtxGatherFunc.apply(v_out, final_indices) |
| 475 | + ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) |
| 476 | + v_out = torch.where((is_sliding_layer & (position_ids.max() >= (N - 1))), v_out, ctx_v_out) |
| 477 | + return k_out, v_out |
0 commit comments