[TPU] add kv cache update kernel #19928

Open · wants to merge 3 commits into base: main
2 changes: 2 additions & 0 deletions .buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
@@ -159,6 +159,8 @@ run_and_track_test 14 "test_tpu_qkv_linear.py" \
  "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_qkv_linear.py"
run_and_track_test 15 "test_spmd_model_weight_loading.py" \
  "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py"
run_and_track_test 16 "test_kv_cache_update_kernel.py" \
  "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_kv_cache_update_kernel.py"

# After all tests have been attempted, exit with the overall status.
if [ "$overall_script_exit_code" -ne 0 ]; then
59 changes: 59 additions & 0 deletions tests/v1/tpu/test_kv_cache_update_kernel.py
@@ -0,0 +1,59 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import numpy as np
import pytest
import torch
import torch_xla

import vllm.v1.attention.backends.pallas # noqa: F401
from vllm.platforms import current_platform


@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This is a test for TPU only")
def test_kv_cache_update_kernel():
    page_num = 1000
    page_size = 32
    combined_kv_head_num = 16
    head_dim = 128
    kernel_block_size = 16
Collaborator: Could you also test the case where the kernel needs to iterate through multiple blocks? (A sketch covering this follows this file's diff.)
    padded_num_tokens = 128
Member (on lines +16 to +21): Are there reasonable edge cases we could test here? Like block size != 16, an odd number of tokens, a small KV head count, etc.
    kv_cache_cpu = torch.zeros(
        (page_num * page_size, combined_kv_head_num, head_dim),
        dtype=torch.bfloat16,
        device="cpu")
    kv_cache_xla = kv_cache_cpu.to(torch_xla.device())
    new_kv_cpu = torch.randn(
        (padded_num_tokens, combined_kv_head_num, head_dim),
        dtype=torch.bfloat16,
        device="cpu")
    new_kv_xla = new_kv_cpu.to(torch_xla.device())
    slice_lens = np.array([7, 32, 32, 1, 1, 1, 9], dtype=np.int32)
    kv_cache_start_indices = np.array([57, 64, 96, 104, 213, 345, 488],
                                      dtype=np.int32)
    new_kv_cache_indices = np.array([0, 7, 39, 71, 72, 73, 74], dtype=np.int32)
    slot_mapping = np.stack(
        [kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1)
    slot_mapping = np.pad(
        slot_mapping, [[0, kernel_block_size - slot_mapping.shape[0]], [0, 0]],
        constant_values=0)
    slot_mapping_cpu = torch.tensor(slot_mapping, device="cpu")
    slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device())
    torch_xla.sync()

    torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True)
Collaborator (vanbasten23, Jun 23, 2025): Why do we want to do this torch.ops.xla.dynamo_set_buffer_donor_?
Author: Because it should be an in-place update.
    new_kv_cache_xla = torch.ops.xla.kv_cache_update_op(
        new_kv_xla, slot_mapping_xla, kv_cache_xla, page_size,
        kernel_block_size)
    kv_cache_xla.copy_(new_kv_cache_xla)
    torch_xla.sync()

    for ni, ci, sl in zip(new_kv_cache_indices, kv_cache_start_indices,
                          slice_lens):
        kv_cache_cpu[ci:ci + sl, :, :] = new_kv_cpu[ni:ni + sl, :, :]

    assert torch.allclose(kv_cache_xla.cpu(),
                          kv_cache_cpu,
                          atol=1e-4,
                          rtol=1e-4)
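
To give the two review questions above something concrete to point at, here is a minimal sketch of how a parametrized variant of this test might cover several kernel blocks and a few edge cases. It is not part of the PR: the test name, the parameter values, and the slice layout are illustrative, and it assumes the same torch.ops.xla.kv_cache_update_op signature used above.

# Illustrative sketch only (not part of this PR): uses more slices than one
# kernel block and parametrizes the kernel block size and KV head count.
import numpy as np
import pytest
import torch
import torch_xla

import vllm.v1.attention.backends.pallas  # noqa: F401
from vllm.platforms import current_platform


@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This is a test for TPU only")
@pytest.mark.parametrize("combined_kv_head_num", [2, 16])
@pytest.mark.parametrize("kernel_block_size", [8, 16])
def test_kv_cache_update_kernel_multi_block(combined_kv_head_num,
                                            kernel_block_size):
    page_num, page_size, head_dim = 1000, 32, 128
    # Enough slices that the kernel grid iterates over several blocks.
    num_slices = 3 * kernel_block_size
    rng = np.random.default_rng(0)
    # Odd slice lengths are fine as long as each slice stays within one page.
    slice_lens = rng.integers(1, page_size + 1,
                              size=num_slices).astype(np.int32)
    # One destination slice per page, starting at the page boundary, so no
    # slice crosses a page.
    kv_cache_start_indices = np.arange(num_slices, dtype=np.int32) * page_size
    # Source tokens are packed back to back in new_kv.
    new_kv_cache_indices = np.concatenate(
        [[0], np.cumsum(slice_lens[:-1])]).astype(np.int32)
    num_tokens = int(slice_lens.sum())

    kv_cache_cpu = torch.zeros(
        (page_num * page_size, combined_kv_head_num, head_dim),
        dtype=torch.bfloat16, device="cpu")
    new_kv_cpu = torch.randn((num_tokens, combined_kv_head_num, head_dim),
                             dtype=torch.bfloat16, device="cpu")
    kv_cache_xla = kv_cache_cpu.to(torch_xla.device())
    new_kv_xla = new_kv_cpu.to(torch_xla.device())

    # num_slices is already a multiple of kernel_block_size, so no padding of
    # the slot mapping is needed here.
    slot_mapping = np.stack(
        [kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1)
    slot_mapping_xla = torch.tensor(slot_mapping,
                                    device="cpu").to(torch_xla.device())
    torch_xla.sync()

    torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True)
    new_kv_cache_xla = torch.ops.xla.kv_cache_update_op(
        new_kv_xla, slot_mapping_xla, kv_cache_xla, page_size,
        kernel_block_size)
    kv_cache_xla.copy_(new_kv_cache_xla)
    torch_xla.sync()

    for ni, ci, sl in zip(new_kv_cache_indices, kv_cache_start_indices,
                          slice_lens):
        kv_cache_cpu[ci:ci + sl, :, :] = new_kv_cpu[ni:ni + sl, :, :]
    assert torch.allclose(kv_cache_xla.cpu(), kv_cache_cpu,
                          atol=1e-4, rtol=1e-4)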
109 changes: 109 additions & 0 deletions vllm/attention/ops/pallas_kv_cache_update.py
@@ -0,0 +1,109 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import functools

import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu


def _kv_cache_update_kernel(
    # Prefetch
    slices_ref,  # [num_slices, 3]
    # Input
    new_kv_hbm_ref,  # [tokens, num_combined_kv_heads, head_dim]
    kv_cache_hbm_ref,
    # Output
    _,  # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
    # Scratch
    scratch,  # [block_size, page_size, num_combined_kv_heads, head_dim]
    sem,
):
    async_copies = []
    block_idx = pl.program_id(0)
    block_size = scratch.shape[0]

    # Copy from new_kv_hbm_ref to scratch
    for i in range(block_size):
        offset_i = i + block_idx * block_size
        new_kv_start = slices_ref[offset_i, 1]
        length = slices_ref[offset_i, 2]
        async_copy = pltpu.make_async_copy(
            new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...],
            scratch.at[i, pl.ds(0, length), ...],
            sem,
        )
        async_copy.start()
        async_copies.append(async_copy)

    for async_copy in async_copies:
        async_copy.wait()

    # Copy from scratch to kv_cache_hbm_ref
    async_copies.clear()
    for i in range(block_size):
        offset_i = i + block_idx * block_size
        kv_cache_start = slices_ref[offset_i, 0]
        length = slices_ref[offset_i, 2]
        async_copy = pltpu.make_async_copy(
            scratch.at[i, pl.ds(0, length), ...],
            kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...],
            sem,
        )
        async_copy.start()
        async_copies.append(async_copy)
    for async_copy in async_copies:
        async_copy.wait()


@functools.partial(
    jax.jit,
    static_argnames=["page_size", "block_size"],
)
def kv_cache_update(
    new_kv: jax.Array,  # [total_num_token, num_combined_kv_heads, head_dim]
    # [num_slices, 3], list of (kv_cache_start, new_kv_start, slice_len)
    slices: jax.Array,
    # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
    kv_cache: jax.Array,
    *,
    page_size: int = 32,
    block_size: int = 8,
):
    assert slices.shape[0] % block_size == 0
    _, num_combined_kv_heads, head_dim = new_kv.shape

    in_specs = [
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
    ]

    out_specs = [pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)]
    out_shape = [jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)]

    scalar_prefetches = [slices]
    scratch = pltpu.VMEM(
        (block_size, page_size, num_combined_kv_heads, head_dim),
        new_kv.dtype,
    )

    scratch_shapes = [
        scratch,
        pltpu.SemaphoreType.DMA,
    ]

    kernel = pl.pallas_call(
        _kv_cache_update_kernel,
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=len(scalar_prefetches),
            in_specs=in_specs,
            out_specs=out_specs,
            grid=(slices.shape[0] // block_size, ),
            scratch_shapes=scratch_shapes,
        ),
        out_shape=out_shape,
        input_output_aliases={len(scalar_prefetches) + 1: 0},
Collaborator: I guess this maps kv_cache_hbm_ref to the output so that you don't need to specify the output in _kv_cache_update_kernel?
Author: Yeah, they're just aliases.
    )

    return kernel(*scalar_prefetches, new_kv, kv_cache)[0]
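
For reference, a minimal sketch (not part of the PR) of how kv_cache_update could be called directly from JAX; the shapes and slice values are illustrative. Because input_output_aliases maps the kv_cache argument (index len(scalar_prefetches) + 1) to output 0, the kernel writes into the donated cache buffer instead of materializing a second cache-sized array, which is what the review exchange above is pointing at.

# Illustrative sketch only: driving the Pallas kernel straight from JAX.
import jax
import jax.numpy as jnp
import numpy as np

from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update

page_size, block_size = 32, 8
num_pages, num_combined_kv_heads, head_dim = 64, 16, 128

kv_cache = jnp.zeros((num_pages * page_size, num_combined_kv_heads, head_dim),
                     dtype=jnp.bfloat16)
new_kv = jax.random.normal(jax.random.PRNGKey(0),
                           (128, num_combined_kv_heads, head_dim),
                           dtype=jnp.bfloat16)

# Each row is (kv_cache_start, new_kv_start, slice_len). Pad the number of
# slices up to a multiple of block_size with zero-length slices, as the
# unit test above does.
slices = np.array([[0, 0, 32], [32, 32, 32], [64, 64, 7]], dtype=np.int32)
pad = -slices.shape[0] % block_size
slices = np.pad(slices, [[0, pad], [0, 0]], constant_values=0)

updated = kv_cache_update(new_kv, jnp.asarray(slices), kv_cache,
                          page_size=page_size, block_size=block_size)
assert updated.shape == kv_cache.shape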
55 changes: 50 additions & 5 deletions vllm/v1/attention/backends/pallas.py
@@ -5,8 +5,12 @@
from typing import Any, Optional

import torch
# Required to register custom ops.
import torch_xla.core.xla_builder as xb
import torch_xla.experimental.custom_kernel # noqa: F401
# Required to register custom ops.
from torch.library import impl
from torch_xla._internal.jax_workarounds import requires_jax
from torch_xla.experimental.custom_kernel import XLA_LIB

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionLayer, AttentionType)
@@ -108,6 +112,7 @@ class PallasMetadata:
    context_lens: torch.Tensor
    query_start_loc: torch.Tensor
    num_seqs: torch.Tensor
    kv_cache_update_block_size: int


class PallasAttentionBackendImpl(AttentionImpl):
@@ -213,7 +218,10 @@ def forward(
        # Write input keys and values to the KV cache.
        # Skip this if sharing KV cache with an earlier attention layer.
        slot_mapping = attn_metadata.slot_mapping
        write_to_kv_cache(key, value, kv_cache, slot_mapping)
        kv_cache_update_block_size = \
            attn_metadata.kv_cache_update_block_size
        write_to_kv_cache(key, value, kv_cache, slot_mapping,
                          kv_cache_update_block_size)

        output = torch.ops.xla.ragged_paged_attention(
            query,
@@ -245,16 +253,17 @@ def write_to_kv_cache(
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_update_block_size: int,
) -> None:
    """ Write the key and values to the KV cache.

    Args:
        key: shape = [num_tokens, num_kv_heads * head_size]
        value: shape = [num_tokens, num_kv_heads * head_size]
        kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]

        kv_cache_update_block_size: int
    """
    _, _, num_combined_kv_heads, head_size = kv_cache.shape
    _, page_size, num_combined_kv_heads, head_size = kv_cache.shape
    head_size = cdiv(head_size,
                     TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
    kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
@@ -263,4 +272,40 @@ def write_to_kv_cache(
    torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True)

    kv_cache = kv_cache.flatten(0, 1)
    kv_cache.index_copy_(0, slot_mapping, kv)
    new_kv_cache = torch.ops.xla.kv_cache_update_op(
        kv, slot_mapping, kv_cache, page_size, kv_cache_update_block_size)
    # NOTE: the in-place copy will be optimized away by XLA compiler.
    kv_cache.copy_(new_kv_cache)


@requires_jax
def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
                            kv_cache: torch.Tensor, page_size: int,
                            block_size: int):
    from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
    new_kv_cache = xb.call_jax(kv_cache_update, (kv, slot_mapping, kv_cache), {
Collaborator: Very nice!
"page_size": page_size,
"block_size": block_size
})
return new_kv_cache


XLA_LIB.define(
"kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache, "
"int page_size, int block_size) -> Tensor", )


@impl(XLA_LIB, "kv_cache_update_op", "XLA")
def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
                           kv_cache: torch.Tensor, page_size: int,
                           block_size: int) -> torch.Tensor:
    new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
                                           page_size, block_size)
    return new_kv_cache


@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
                               kv_cache: torch.Tensor, page_size: int,
                               block_size: int) -> torch.Tensor:
    return kv_cache
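
A note on the two registrations above: on XLA devices the op lowers to the Pallas kernel through xb.call_jax, while the CompositeExplicitAutograd registration simply returns kv_cache, presumably so that tracing on a non-XLA backend still produces a tensor of the right shape. A minimal sketch of that fallback path, with illustrative shapes, assuming torch_xla is installed so the registering import succeeds:

# Illustrative sketch only: on a non-XLA device the op dispatches to the
# CompositeExplicitAutograd registration and simply returns kv_cache
# untouched; no Pallas kernel runs and no data is written.
import torch

# Importing the backend registers kv_cache_update_op with torch.library.
import vllm.v1.attention.backends.pallas  # noqa: F401

kv = torch.randn(8, 16, 128)
slot_mapping = torch.zeros(16, 3, dtype=torch.int32)
kv_cache = torch.zeros(64 * 32, 16, 128)

out = torch.ops.xla.kv_cache_update_op(kv, slot_mapping, kv_cache, 32, 16)
assert out.shape == kv_cache.shape
assert torch.equal(out, kv_cache)  # fallback returns the cache unchanged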