
[generate] move SinkCache to a custom_generate repo #38399


Merged: 1 commit, Jun 2, 2025
6 changes: 0 additions & 6 deletions docs/source/en/internal/generation_utils.md
@@ -380,11 +380,6 @@ A [`Constraint`] can be used to force the generation to include specific tokens

[[autodoc]] HQQQuantizedCache

[[autodoc]] SinkCache
- update
- get_seq_length
- reorder_cache

[[autodoc]] OffloadedCache
- update
- prefetch_layer
@@ -443,4 +438,3 @@ A [`Constraint`] can be used to force the generation to include specific tokens

[[autodoc]] CompileConfig
- __call__

34 changes: 3 additions & 31 deletions docs/source/en/kv_cache.md
@@ -30,7 +30,6 @@ Transformers offers several [`Cache`] classes that implement different caching m
| Offloaded Static Cache | No | Yes | Yes | High | Yes |
| Quantized Cache | Yes | No | No | Low | Yes |
| Sliding Window Cache | No | Yes | Yes | High | No |
| Sink Cache | Yes | No | Yes | Mid | Yes |

This guide introduces you to the different [`Cache`] classes and shows you how to use them for generation.

@@ -174,28 +173,6 @@ I like rock music because it's loud and energetic. It's a great way to express m
</hfoption>
</hfoptions>

### Sink cache

[`SinkCache`] is capable of generating very long sequences ("infinite length" according to the paper) by only retaining a few initial tokens from the sequence. These are called the *sink tokens* because they account for a significant portion of the attention scores during generation. Subsequent tokens are discarded on a sliding windowed basis, and only the latest `window_size` tokens are kept. This means most of the previous knowledge is discarded.

The sink tokens allow a model to maintain stable performance even when it's dealing with very long text sequences.

Enable [`SinkCache`] by initializing it first with the [window_length](https://hf.co/docs/transformers/main/en/internal/generation_utils#transformers.SinkCache.window_length) and [num_sink_tokens](https://hf.co/docs/transformers/main/en/internal/generation_utils#transformers.SinkCache.num_sink_tokens) parameters before passing it to [past_key_values](https://hf.co/docs/transformers/internal/generation_utils#transformers.generation.GenerateDecoderOnlyOutput.past_key_values) in [`~GenerationMixin.generate`].

```py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16).to("cuda:0")
inputs = tokenizer("This is a long story about unicorns, fairies and magic.", return_tensors="pt").to(model.device)

past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
out = model.generate(**inputs, do_sample=False, max_new_tokens=30, past_key_values=past_key_values)
tokenizer.batch_decode(out, skip_special_tokens=True)[0]
"This is a long story about unicorns, fairies and magic. It is a fantasy world where unicorns and fairies live together in harmony. The story follows a young girl named Lily"
```
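
The example removed above now lives in the [transformers-community/sink_cache](https://huggingface.co/transformers-community/sink_cache) repository on the Hub. Below is a minimal sketch of how the same generation could be run through the `custom_generate` entry point; the `window_length` and `num_sink_tokens` kwargs are assumed to be forwarded to the remote implementation, so check the repository's README for the exact interface.

```py
# Sketch only: assumes the Hub-based `custom_generate` dispatch described in the
# generation docs, and that the remote decoding loop accepts the former SinkCache
# arguments as keyword arguments.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16).to("cuda:0")
inputs = tokenizer("This is a long story about unicorns, fairies and magic.", return_tensors="pt").to(model.device)

# `trust_remote_code=True` is required because the decoding loop is downloaded from the Hub.
out = model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=30,
    custom_generate="transformers-community/sink_cache",
    trust_remote_code=True,
    window_length=256,
    num_sink_tokens=4,
)
tokenizer.batch_decode(out, skip_special_tokens=True)[0]
```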

## Speed optimized caches

The default [`DynamicCache`] prevents you from taking advantage of just-in-time (JIT) optimizations because the cache size isn't fixed. JIT optimizations enable you to minimize latency at the expense of memory usage. All of the following cache types are compatible with JIT optimizations like [torch.compile](./llm_optims#static-kv-cache-and-torchcompile) to accelerate generation.
@@ -247,7 +224,7 @@ Enable [`SlidingWindowCache`] by configuring `cache_implementation="sliding_wind

```py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16).to("cuda:0")
@@ -284,16 +261,13 @@ A cache can also work in iterative generation settings where there is back-and-f

For iterative generation with a cache, start by initializing an empty cache class and then you can feed in your new prompts. Keep track of dialogue history with a [chat template](./chat_templating).

If you're using [`SinkCache`], the inputs need to be truncated to the maximum length because [`SinkCache`] can generate text that exceeds its maximum window size. However, the first input shouldn't exceed the maximum cache length.

The example below demonstrates how to use a cache for iterative generation.

```py
import torch
from transformers import AutoTokenizer,AutoModelForCausalLM
from transformers.cache_utils import (
DynamicCache,
SinkCache,
StaticCache,
SlidingWindowCache,
QuantoQuantizedCache,
@@ -313,8 +287,6 @@ messages = []
for prompt in user_prompts:
messages.append({"role": "user", "content": prompt})
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
if isinstance(past_key_values, SinkCache):
inputs = {k: v[:, -max_cache_length:] for k, v in inputs.items()}
input_length = inputs["input_ids"].shape[1]
outputs = model.generate(**inputs, do_sample=False, max_new_tokens=256, past_key_values=past_key_values)
completion = tokenizer.decode(outputs[0, input_length: ], skip_special_tokens=True)
@@ -336,7 +308,7 @@ model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Init StaticCache with big enough max-length (1024 tokens for the below example)
# Init StaticCache with big enough max-length (1024 tokens for the below example)
# You can also init a DynamicCache, if that suits you better
prompt_cache = StaticCache(config=model.config, max_batch_size=1, max_cache_len=1024, device="cuda", dtype=torch.bfloat16)

@@ -351,7 +323,7 @@ responses = []
for prompt in prompts:
new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to("cuda")
past_key_values = copy.deepcopy(prompt_cache)
outputs = model.generate(**new_inputs, past_key_values=past_key_values,max_new_tokens=20)
outputs = model.generate(**new_inputs, past_key_values=past_key_values,max_new_tokens=20)
response = tokenizer.batch_decode(outputs)[0]
responses.append(response)

5 changes: 0 additions & 5 deletions docs/source/ko/internal/generation_utils.md
@@ -366,11 +366,6 @@ generation_output[:2]

[[autodoc]] HQQQuantizedCache

[[autodoc]] SinkCache
- update
- get_seq_length
- reorder_cache

[[autodoc]] OffloadedCache
- update
- prefetch_layer
198 changes: 8 additions & 190 deletions src/transformers/cache_utils.py
@@ -2,7 +2,6 @@
import importlib.metadata
import json
import os
import warnings
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

@@ -1063,199 +1062,18 @@ def _dequantize(self, qtensor):

class SinkCache(Cache):
"""
Deprecated.

A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.

It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
`[batch_size, num_heads, seq_len, head_dim]`.

Parameters:
window_length (`int`):
The length of the context window.
num_sink_tokens (`int`):
The number of sink tokens. See the original paper for more information.

Example:

```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache

>>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

>>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")

>>> # Prepare a cache class and pass it to model's forward
>>> past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values # access cache filled with key/values from generation
SinkCache()
```
It now lives in a `custom_generate` repository on the Hub: https://huggingface.co/transformers-community/sink_cache.
See [these docs](https://huggingface.co/docs/transformers/generation_strategies#custom-decoding-methods) for
general `custom_generate` usage.
"""

def __init__(self, window_length: int, num_sink_tokens: int) -> None:
super().__init__()
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
self.window_length = window_length
self.num_sink_tokens = num_sink_tokens
self.cos_sin_rerotation_cache = {}
self._cos_cache = None
self._sin_cache = None
self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen

warnings.warn(
"`SinkCache` is deprecated and will be removed in v4.53.0. You can achieve similar functionality by "
"using a model with a sliding window attention mechanism, or by expanding RoPE and optionally using an "
"offloaded cache implementation.",
FutureWarning,
# TODO (joao, manuel): Remove this class in v4.59.0
def __init__(self, **kwargs) -> None:
raise NotImplementedError(
"`SinkCache` has been moved as a `custom_generate` repository on the Hub: "
"https://huggingface.co/transformers-community/sink_cache. See the repository for usage examples."
)

@staticmethod
def _rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)

def _apply_key_rotary_pos_emb(
self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
return rotated_key_states

def _get_rerotation_cos_sin(
self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
if key_states.shape[-2] not in self.cos_sin_rerotation_cache:
# Upcast to float32 temporarily for better accuracy
cos = cos.to(torch.float32)
sin = sin.to(torch.float32)

# Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin

self.cos_sin_rerotation_cache[key_states.shape[-2]] = (
rerotation_cos.to(key_states.dtype).unsqueeze(0),
rerotation_sin.to(key_states.dtype).unsqueeze(0),
)
return self.cos_sin_rerotation_cache[key_states.shape[-2]]

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
# TODO: deprecate this function in favor of `cache_position`
# Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
if len(self.key_cache) <= layer_idx:
return 0
return self.key_cache[layer_idx].shape[-2]

def get_max_cache_shape(self) -> Optional[int]:
"""Returns the maximum sequence length of the cache object, in case of SinkCache it is the window length."""
return self.window_length

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
`cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
rotation as the tokens are shifted.

Return:
A tuple containing the updated key and value states.
"""
# Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
# with partially rotated position embeddings, like Phi or Persimmon.
if cache_kwargs is None:
cache_kwargs = {}
sin = cache_kwargs.get("sin")
cos = cache_kwargs.get("cos")
partial_rotation_size = cache_kwargs.get("partial_rotation_size")
using_rope = cos is not None and sin is not None

# Update the number of seen tokens
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]

# Update the sin/cos cache, which holds sin/cos values for all possible positions
if using_rope and layer_idx == 0:
# BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove
# after all RoPE models have a llama-like cache utilization.
if cos.dim() == 2:
self._cos_cache = cos
self._sin_cache = sin
else:
if self._cos_cache is None:
self._cos_cache = cos[0, ...]
self._sin_cache = sin[0, ...]
elif self._cos_cache.shape[0] < self.window_length:
self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0)
self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0)

# [bsz, num_heads, seq_len, head_dim]
if len(self.key_cache) <= layer_idx:
# Empty cache
self.key_cache.append(key_states)
self.value_cache.append(value_states)

elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
# Growing cache
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

else:
# Shifting cache
keys_to_keep = self.key_cache[layer_idx][
:, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
]

# On RoPE models, we need to recompute the Key rotation as the tokens are shifted
if using_rope:
rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length]
)
if partial_rotation_size is not None:
keys_to_keep, keys_pass = (
keys_to_keep[..., :partial_rotation_size],
keys_to_keep[..., partial_rotation_size:],
)
keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
if partial_rotation_size is not None:
keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)

# Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)

sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
values_to_keep = self.value_cache[layer_idx][
:, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
]
self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)

return self.key_cache[layer_idx], self.value_cache[layer_idx]


class StaticCache(Cache):
"""
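
With this change, constructing `SinkCache` inside `transformers` only produces a pointer to the Hub repository. A small sketch of the expected behaviour, assuming the class stays importable from the top-level package until its removal in v4.59.0:

```py
# Sketch: after this PR, SinkCache is a stub that redirects users to the Hub repository.
from transformers import SinkCache  # assumed to remain importable until removal in v4.59.0

try:
    SinkCache(window_length=256, num_sink_tokens=4)
except NotImplementedError as err:
    # The message points to https://huggingface.co/transformers-community/sink_cache
    print(err)
```
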
9 changes: 1 addition & 8 deletions src/transformers/utils/fx.py
@@ -34,7 +34,7 @@
from torch.fx.proxy import ParameterProxy

from .. import logging
from ..cache_utils import Cache, DynamicCache, SinkCache, StaticCache
from ..cache_utils import Cache, DynamicCache, StaticCache
from ..modeling_utils import PretrainedConfig, PreTrainedModel
from ..models.auto import get_values
from ..models.auto.modeling_auto import (
@@ -832,12 +832,6 @@ def cache_proxy_factory_fn(n: Node) -> HFCacheProxy:
{},
proxy_factory_fn=create_cache_proxy_factory_fn(DynamicCache),
)
ProxyableSinkCache = HFProxyableClassMeta(
"ProxyableSinkCache",
(SinkCache,),
{},
proxy_factory_fn=create_cache_proxy_factory_fn(SinkCache),
)
ProxyableStaticCache = HFProxyableClassMeta(
"ProxyableStaticCache",
(StaticCache,),
@@ -880,7 +874,6 @@ class HFTracer(Tracer):
_CLASSES_TO_PATCH = {
Cache: ProxyableCache,
DynamicCache: ProxyableDynamicCache,
SinkCache: ProxyableSinkCache,
StaticCache: ProxyableStaticCache,
}

1 change: 1 addition & 0 deletions utils/check_repo.py
@@ -1050,6 +1050,7 @@ def find_all_documented_objects() -> List[str]:
"VitPoseBackbone", # Internal module
"VitPoseBackboneConfig", # Internal module
"get_values", # Internal object
"SinkCache", # Moved to a custom_generate repository, to be deleted from transformers in v4.59.0
]

# This list should be empty. Objects in it should get their own doc page.