-
-
Notifications
You must be signed in to change notification settings - Fork 8.3k
[TPU] add kv cache update kernel #19928
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
||
import numpy as np | ||
import pytest | ||
import torch | ||
import torch_xla | ||
|
||
import vllm.v1.attention.backends.pallas # noqa: F401 | ||
from vllm.platforms import current_platform | ||
|
||
|
||
@pytest.mark.skipif(not current_platform.is_tpu(), | ||
reason="This is a test for TPU only") | ||
def test_kv_cache_update_kernel(): | ||
page_num = 1000 | ||
page_size = 32 | ||
combined_kv_head_num = 16 | ||
head_dim = 128 | ||
kernel_block_size = 16 | ||
padded_num_tokens = 128 | ||
Comment on lines
+16
to
+21
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are there reasonable edge cases we could test here? Like block size != 16, an odd number of tokens, small kv head num, etc |
||
kv_cache_cpu = torch.zeros( | ||
(page_num * page_size, combined_kv_head_num, head_dim), | ||
dtype=torch.bfloat16, | ||
device="cpu") | ||
kv_cache_xla = kv_cache_cpu.to(torch_xla.device()) | ||
new_kv_cpu = torch.randn( | ||
(padded_num_tokens, combined_kv_head_num, head_dim), | ||
dtype=torch.bfloat16, | ||
device="cpu") | ||
new_kv_xla = new_kv_cpu.to(torch_xla.device()) | ||
slice_lens = np.array([7, 32, 32, 1, 1, 1, 9], dtype=np.int32) | ||
kv_cache_start_indices = np.array([57, 64, 96, 104, 213, 345, 488], | ||
dtype=np.int32) | ||
new_kv_cache_indices = np.array([0, 7, 39, 71, 72, 73, 74], dtype=np.int32) | ||
slot_mapping = np.stack( | ||
[kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1) | ||
slot_mapping = np.pad( | ||
slot_mapping, [[0, kernel_block_size - slot_mapping.shape[0]], [0, 0]], | ||
constant_values=0) | ||
slot_mapping_cpu = torch.tensor(slot_mapping, device="cpu") | ||
slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device()) | ||
torch_xla.sync() | ||
|
||
torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why do we want to do this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Because it should be an inplace-update. |
||
new_kv_cache_xla = torch.ops.xla.kv_cache_update_op( | ||
new_kv_xla, slot_mapping_xla, kv_cache_xla, page_size, | ||
kernel_block_size) | ||
yaochengji marked this conversation as resolved.
Show resolved
Hide resolved
|
||
kv_cache_xla.copy_(new_kv_cache_xla) | ||
torch_xla.sync() | ||
|
||
for ni, ci, sl in zip(new_kv_cache_indices, kv_cache_start_indices, | ||
slice_lens): | ||
kv_cache_cpu[ci:ci + sl, :, :] = new_kv_cpu[ni:ni + sl, :, :] | ||
|
||
assert torch.allclose(kv_cache_xla.cpu(), | ||
kv_cache_cpu, | ||
atol=1e-4, | ||
rtol=1e-4) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
||
import functools | ||
|
||
import jax | ||
from jax.experimental import pallas as pl | ||
from jax.experimental.pallas import tpu as pltpu | ||
|
||
|
||
def _kv_cache_update_kernel( | ||
# Prefetch | ||
slices_ref, # [num_slices, 3] | ||
# Input | ||
new_kv_hbm_ref, # [tokens, num_combined_kv_heads, head_dim] | ||
kv_cache_hbm_ref, | ||
# Output | ||
_, # [total_num_pages * page_size, num_combined_kv_heads, head_dim] | ||
# Scratch | ||
scratch, # [block_size, page_size, num_combined_kv_heads, head_dim] | ||
sem, | ||
): | ||
async_copies = [] | ||
block_idx = pl.program_id(0) | ||
block_size = scratch.shape[0] | ||
|
||
# Copy from new_kv_hbm_ref to scratch | ||
for i in range(block_size): | ||
offset_i = i + block_idx * block_size | ||
new_kv_start = slices_ref[offset_i, 1] | ||
length = slices_ref[offset_i, 2] | ||
async_copy = pltpu.make_async_copy( | ||
new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...], | ||
scratch.at[i, pl.ds(0, length), ...], | ||
sem, | ||
) | ||
async_copy.start() | ||
async_copies.append(async_copy) | ||
|
||
for async_copy in async_copies: | ||
async_copy.wait() | ||
|
||
# Copy from scratch to kv_cache_hbm_ref | ||
async_copies.clear() | ||
for i in range(block_size): | ||
offset_i = i + block_idx * block_size | ||
kv_cache_start = slices_ref[offset_i, 0] | ||
length = slices_ref[offset_i, 2] | ||
async_copy = pltpu.make_async_copy( | ||
scratch.at[i, pl.ds(0, length), ...], | ||
kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...], | ||
sem, | ||
) | ||
async_copy.start() | ||
async_copies.append(async_copy) | ||
for async_copy in async_copies: | ||
async_copy.wait() | ||
|
||
|
||
@functools.partial( | ||
jax.jit, | ||
static_argnames=["page_size", "block_size"], | ||
) | ||
def kv_cache_update( | ||
new_kv: jax.Array, # [total_num_token, num_combined_kv_heads, head_dim] | ||
slices: jax. | ||
Array, # [num_slices, 3], list of (kv_cache_start, new_kv_start, slice_len) | ||
kv_cache: jax. | ||
Array, # [total_num_pages * page_size, num_combined_kv_heads, head_dim] | ||
*, | ||
page_size: int = 32, | ||
block_size: int = 8, | ||
): | ||
assert slices.shape[0] % block_size == 0 | ||
_, num_combined_kv_heads, head_dim = new_kv.shape | ||
|
||
in_specs = [ | ||
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), | ||
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), | ||
] | ||
|
||
out_specs = [pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)] | ||
out_shape = [jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)] | ||
|
||
scalar_prefetches = [slices] | ||
scratch = pltpu.VMEM( | ||
(block_size, page_size, num_combined_kv_heads, head_dim), | ||
new_kv.dtype, | ||
) | ||
|
||
scratch_shapes = [ | ||
scratch, | ||
pltpu.SemaphoreType.DMA, | ||
] | ||
|
||
kernel = pl.pallas_call( | ||
_kv_cache_update_kernel, | ||
grid_spec=pltpu.PrefetchScalarGridSpec( | ||
num_scalar_prefetch=len(scalar_prefetches), | ||
in_specs=in_specs, | ||
out_specs=out_specs, | ||
grid=(slices.shape[0] // block_size, ), | ||
scratch_shapes=scratch_shapes, | ||
), | ||
out_shape=out_shape, | ||
input_output_aliases={len(scalar_prefetches) + 1: 0}, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess this maps kv_cache_hbm_ref to the output so that you don't need to specify the output in "_kv_cache_update_kernel"? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, they're just aliases. |
||
) | ||
|
||
return kernel(*scalar_prefetches, new_kv, kv_cache)[0] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,8 +5,12 @@ | |
from typing import Any, Optional | ||
|
||
import torch | ||
# Required to register custom ops. | ||
import torch_xla.core.xla_builder as xb | ||
import torch_xla.experimental.custom_kernel # noqa: F401 | ||
# Required to register custom ops. | ||
from torch.library import impl | ||
from torch_xla._internal.jax_workarounds import requires_jax | ||
from torch_xla.experimental.custom_kernel import XLA_LIB | ||
|
||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, | ||
AttentionLayer, AttentionType) | ||
|
@@ -108,6 +112,7 @@ class PallasMetadata: | |
context_lens: torch.Tensor | ||
query_start_loc: torch.Tensor | ||
num_seqs: torch.Tensor | ||
kv_cache_update_block_size: int | ||
|
||
|
||
class PallasAttentionBackendImpl(AttentionImpl): | ||
|
@@ -213,7 +218,10 @@ def forward( | |
# Write input keys and values to the KV cache. | ||
# Skip this if sharing KV cache with an earlier attention layer. | ||
slot_mapping = attn_metadata.slot_mapping | ||
write_to_kv_cache(key, value, kv_cache, slot_mapping) | ||
kv_cache_update_block_size = \ | ||
attn_metadata.kv_cache_update_block_size | ||
write_to_kv_cache(key, value, kv_cache, slot_mapping, | ||
kv_cache_update_block_size) | ||
|
||
output = torch.ops.xla.ragged_paged_attention( | ||
query, | ||
|
@@ -245,16 +253,17 @@ def write_to_kv_cache( | |
value: torch.Tensor, | ||
kv_cache: torch.Tensor, | ||
slot_mapping: torch.Tensor, | ||
kv_cache_update_block_size: int, | ||
) -> None: | ||
""" Write the key and values to the KV cache. | ||
|
||
Args: | ||
key: shape = [num_tokens, num_kv_heads * head_size] | ||
value: shape = [num_tokens, num_kv_heads * head_size] | ||
kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size] | ||
|
||
kv_cache_update_block_size: int | ||
""" | ||
_, _, num_combined_kv_heads, head_size = kv_cache.shape | ||
_, page_size, num_combined_kv_heads, head_size = kv_cache.shape | ||
head_size = cdiv(head_size, | ||
TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT | ||
kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads, | ||
|
@@ -263,4 +272,40 @@ def write_to_kv_cache( | |
torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True) | ||
|
||
kv_cache = kv_cache.flatten(0, 1) | ||
kv_cache.index_copy_(0, slot_mapping, kv) | ||
new_kv_cache = torch.ops.xla.kv_cache_update_op( | ||
kv, slot_mapping, kv_cache, page_size, kv_cache_update_block_size) | ||
# NOTE: the in-place copy will be optimized away by XLA compiler. | ||
kv_cache.copy_(new_kv_cache) | ||
|
||
|
||
@requires_jax | ||
def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor, | ||
kv_cache: torch.Tensor, page_size: int, | ||
block_size: int): | ||
from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update | ||
new_kv_cache = xb.call_jax(kv_cache_update, (kv, slot_mapping, kv_cache), { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Very nice! |
||
"page_size": page_size, | ||
"block_size": block_size | ||
}) | ||
return new_kv_cache | ||
|
||
|
||
XLA_LIB.define( | ||
"kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache, " | ||
"int page_size, int block_size) -> Tensor", ) | ||
|
||
|
||
@impl(XLA_LIB, "kv_cache_update_op", "XLA") | ||
def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor, | ||
kv_cache: torch.Tensor, page_size: int, | ||
block_size: int) -> torch.Tensor: | ||
new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache, | ||
page_size, block_size) | ||
return new_kv_cache | ||
|
||
|
||
@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd") | ||
def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor, | ||
kv_cache: torch.Tensor, page_size: int, | ||
block_size: int) -> torch.Tensor: | ||
return kv_cache |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you also test when the kernel need to iterate thru multiple blocks?