diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh index a2a5c2a02cb..90cad506ab1 100755 --- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh +++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh @@ -159,6 +159,8 @@ run_and_track_test 14 "test_tpu_qkv_linear.py" \ "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_qkv_linear.py" run_and_track_test 15 "test_spmd_model_weight_loading.py" \ "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py" +run_and_track_test 16 "test_kv_cache_update_kernel.py" \ + "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_kv_cache_update_kernel.py" # After all tests have been attempted, exit with the overall status. if [ "$overall_script_exit_code" -ne 0 ]; then diff --git a/tests/v1/tpu/test_kv_cache_update_kernel.py b/tests/v1/tpu/test_kv_cache_update_kernel.py new file mode 100644 index 00000000000..63a1f6777e4 --- /dev/null +++ b/tests/v1/tpu/test_kv_cache_update_kernel.py @@ -0,0 +1,71 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import numpy as np +import pytest +import torch +import torch_xla + +import vllm.v1.attention.backends.pallas # noqa: F401 +from vllm.platforms import current_platform + + +@pytest.mark.skipif(not current_platform.is_tpu(), + reason="This is a test for TPU only") +@pytest.mark.parametrize("page_size", [32, 33]) +@pytest.mark.parametrize("combined_kv_head_num", [2, 16]) +@pytest.mark.parametrize("head_dim", [128, 256]) +@pytest.mark.parametrize("num_slices_per_block", [4, 8]) +def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int, + head_dim: int, num_slices_per_block: int): + page_num = 1000 + padded_num_tokens = 128 + kv_cache_cpu = torch.zeros( + (page_num * page_size, combined_kv_head_num, head_dim), + dtype=torch.bfloat16, + device="cpu") + kv_cache_xla = kv_cache_cpu.to(torch_xla.device()) + new_kv_cpu = torch.randn( + (padded_num_tokens, combined_kv_head_num, head_dim), + dtype=torch.bfloat16, + device="cpu") + new_kv_xla = new_kv_cpu.to(torch_xla.device()) + slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9], + dtype=np.int32) + kv_cache_start_indices = np.array([ + page_size * 2 - 7, page_size * 2, page_size * 3, page_size * 4 + 6, + page_size * 5 + 7, page_size * 6 + 8, page_size * 15 + 3 + ], + dtype=np.int32) + new_kv_cache_indices = np.concatenate( + [np.array([0], dtype=np.int32), + np.cumsum(slice_lens[:-1])]) + slot_mapping = np.stack( + [kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1) + padded_size = (slot_mapping.shape[0] + num_slices_per_block - + 1) // num_slices_per_block * num_slices_per_block + slot_mapping = np.pad(slot_mapping, + [[0, padded_size - slot_mapping.shape[0]], [0, 0]], + constant_values=0) + slot_mapping = np.transpose(slot_mapping) + slot_mapping_cpu = torch.tensor(slot_mapping, + device="cpu", + dtype=torch.int32) + slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device()) + torch_xla.sync() + + torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True) + new_kv_cache_xla = torch.ops.xla.kv_cache_update_op( + new_kv_xla, slot_mapping_xla, kv_cache_xla, page_size, + num_slices_per_block) + kv_cache_xla.copy_(new_kv_cache_xla) + torch_xla.sync() + + for ni, ci, sl in zip(new_kv_cache_indices, kv_cache_start_indices, + slice_lens): + kv_cache_cpu[ci:ci + sl, :, :] = new_kv_cpu[ni:ni + sl, :, :] + + assert torch.allclose(kv_cache_xla.cpu(), + kv_cache_cpu, + atol=1e-4, + rtol=1e-4) diff --git a/tests/v1/tpu/test_pallas.py b/tests/v1/tpu/test_pallas.py index 3a9d80847a1..e279edfffbc 100644 --- a/tests/v1/tpu/test_pallas.py +++ b/tests/v1/tpu/test_pallas.py @@ -47,7 +47,7 @@ class FakeAttentionLayer: key = torch.zeros(num_tokens, num_kv_heads * head_size) value = torch.zeros(num_tokens, num_kv_heads * head_size) kv_cache = torch.zeros(num_blocks, block_size, num_kv_heads * 2, head_size) - slot_mapping = torch.zeros(num_tokens, dtype=torch.int64) + slot_mapping = torch.zeros((3, num_tokens), dtype=torch.int64) max_num_reqs = 8 max_num_blocks_per_req = 8 block_tables = torch.zeros((max_num_reqs, max_num_blocks_per_req), @@ -65,6 +65,7 @@ class FakeAttentionLayer: context_lens=context_lens, query_start_loc=query_start_loc, num_seqs=num_seqs, + num_slices_per_kv_cache_update_block=8, ) with patch("torch.ops.xla.ragged_paged_attention" diff --git a/vllm/attention/ops/pallas_kv_cache_update.py b/vllm/attention/ops/pallas_kv_cache_update.py new file mode 100644 index 00000000000..1a92b10e4f9 --- /dev/null +++ b/vllm/attention/ops/pallas_kv_cache_update.py @@ -0,0 +1,117 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import functools + +import jax +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu + + +def _kv_cache_update_kernel( + # Prefetch + slices_ref, # [3, num_slices], list of (kv_cache_start, new_kv_start, + # slice_len) + # Input + new_kv_hbm_ref, # [num_tokens, num_combined_kv_heads, head_dim] + kv_cache_hbm_ref, # [total_num_pages * page_size, num_combined_kv_heads, + # head_dim] + # Output + _, # [total_num_pages * page_size, num_combined_kv_heads, head_dim] + # Scratch + scratch, # [num_slices_per_block, page_size, num_combined_kv_heads, + # head_dim] + sem, +): + async_copies = [] + block_idx = pl.program_id(0) + num_slices_per_block = scratch.shape[0] + + # Copy from new_kv_hbm_ref to scratch + for i in range(num_slices_per_block): + offset_i = i + block_idx * num_slices_per_block + new_kv_start = slices_ref[1, offset_i] + length = slices_ref[2, offset_i] + async_copy = pltpu.make_async_copy( + new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...], + scratch.at[i, pl.ds(0, length), ...], + sem, + ) + async_copy.start() + async_copies.append(async_copy) + + for async_copy in async_copies: + async_copy.wait() + + # Copy from scratch to kv_cache_hbm_ref + async_copies.clear() + for i in range(num_slices_per_block): + offset_i = i + block_idx * num_slices_per_block + kv_cache_start = slices_ref[0, offset_i] + length = slices_ref[2, offset_i] + async_copy = pltpu.make_async_copy( + scratch.at[i, pl.ds(0, length), ...], + kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...], + sem, + ) + async_copy.start() + async_copies.append(async_copy) + for async_copy in async_copies: + async_copy.wait() + + +@functools.partial( + jax.jit, + static_argnames=["page_size", "num_slices_per_block"], +) +def kv_cache_update( + new_kv: jax.Array, # [total_num_token, num_combined_kv_heads, head_dim] + slices: jax. + Array, # [3, slices], list of (kv_cache_start, new_kv_start, slice_len) + kv_cache: jax. + Array, # [total_num_pages * page_size, num_combined_kv_heads, head_dim] + *, + page_size: int = 32, + num_slices_per_block: int = 8, +): + assert slices.shape[1] % num_slices_per_block == 0 + _, num_combined_kv_heads, head_dim = new_kv.shape + assert kv_cache.shape[1] == num_combined_kv_heads + assert kv_cache.shape[2] == head_dim + assert head_dim % 128 == 0 + # TODO: Add dynamic check to make sure that the all the slice lengths are + # smaller or equal to page_size + + in_specs = [ + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ] + + out_specs = [pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)] + out_shape = [jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)] + + scalar_prefetches = [slices] + scratch = pltpu.VMEM( + (num_slices_per_block, page_size, num_combined_kv_heads, head_dim), + new_kv.dtype, + ) + + scratch_shapes = [ + scratch, + pltpu.SemaphoreType.DMA, + ] + + kernel = pl.pallas_call( + _kv_cache_update_kernel, + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=len(scalar_prefetches), + in_specs=in_specs, + out_specs=out_specs, + grid=(slices.shape[1] // num_slices_per_block, ), + scratch_shapes=scratch_shapes, + ), + out_shape=out_shape, + input_output_aliases={len(scalar_prefetches) + 1: 0}, + ) + + return kernel(*scalar_prefetches, new_kv, kv_cache)[0] diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index ff2862edaa0..49f0772c62d 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -5,8 +5,12 @@ from typing import Any, Optional import torch -# Required to register custom ops. +import torch_xla.core.xla_builder as xb import torch_xla.experimental.custom_kernel # noqa: F401 +# Required to register custom ops. +from torch.library import impl +from torch_xla._internal.jax_workarounds import requires_jax +from torch_xla.experimental.custom_kernel import XLA_LIB from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionType) @@ -107,6 +111,7 @@ class PallasMetadata: context_lens: torch.Tensor query_start_loc: torch.Tensor num_seqs: torch.Tensor + num_slices_per_kv_cache_update_block: int class PallasAttentionBackendImpl(AttentionImpl): @@ -212,7 +217,9 @@ def forward( # Write input keys and values to the KV cache. # Skip this if sharing KV cache with an earlier attention layer. slot_mapping = attn_metadata.slot_mapping - write_to_kv_cache(key, value, kv_cache, slot_mapping) + write_to_kv_cache( + key, value, kv_cache, slot_mapping, + attn_metadata.num_slices_per_kv_cache_update_block) output = torch.ops.xla.ragged_paged_attention( query, @@ -244,6 +251,7 @@ def write_to_kv_cache( value: torch.Tensor, kv_cache: torch.Tensor, slot_mapping: torch.Tensor, + num_slices_per_kv_cache_update_block: int, ) -> None: """ Write the key and values to the KV cache. @@ -251,9 +259,9 @@ def write_to_kv_cache( key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size] - + num_slices_per_kv_cache_update_block: int """ - _, _, num_combined_kv_heads, head_size = kv_cache.shape + _, page_size, num_combined_kv_heads, head_size = kv_cache.shape head_size = cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads, @@ -262,4 +270,41 @@ def write_to_kv_cache( torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True) kv_cache = kv_cache.flatten(0, 1) - kv_cache.index_copy_(0, slot_mapping, kv) + new_kv_cache = torch.ops.xla.kv_cache_update_op( + kv, slot_mapping, kv_cache, page_size, + num_slices_per_kv_cache_update_block) + # NOTE: the in-place copy will be optimized away by XLA compiler. + kv_cache.copy_(new_kv_cache) + + +@requires_jax +def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor, + kv_cache: torch.Tensor, page_size: int, + num_slices_per_block: int): + from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update + new_kv_cache = xb.call_jax(kv_cache_update, (kv, slot_mapping, kv_cache), { + "page_size": page_size, + "num_slices_per_block": num_slices_per_block + }) + return new_kv_cache + + +XLA_LIB.define( + "kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache, " + "int page_size, int num_slices_per_block) -> Tensor", ) + + +@impl(XLA_LIB, "kv_cache_update_op", "XLA") +def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor, + kv_cache: torch.Tensor, page_size: int, + num_slices_per_block: int) -> torch.Tensor: + new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache, + page_size, num_slices_per_block) + return new_kv_cache + + +@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd") +def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor, + kv_cache: torch.Tensor, page_size: int, + num_slices_per_block: int) -> torch.Tensor: + return kv_cache diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 2d80bac3c95..bc334419c4c 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -53,12 +53,11 @@ logger = init_logger(__name__) -# Here we utilize the behavior that out-of-bound index is ignored. -# FIXME(woosuk): Find a more reliable way to prevent possible bugs. -_PAD_SLOT_ID = 1_000_000_000 INVALID_TOKEN_ID = -1 # Smallest output size MIN_NUM_SEQS = 8 +# Block size used for kv cache updating kernel +NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK = 8 ######################################################### @@ -526,6 +525,69 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: return kv_cache_spec + def _get_slot_mapping_metadata(self, num_reqs, + num_scheduled_tokens_per_req): + """ + Computes metadata for mapping slots to blocks in the key-value (KV) + cache for a batch of requests. + + This function determines, for each request in the batch, how the + scheduled tokens are distributed across memory blocks, and generates + metadata needed to map slices of tokens to their corresponding positions + in the KV cache. + + Args: + num_reqs (int): Number of requests in the current batch. + num_scheduled_tokens_per_req (int or np.ndarray): Number of tokens + to be scheduled for each request. + + Returns: + np.ndarray: A 2D array of shape (total_block_len, 3), where each row + contains: + - kv_cache_start_index (int): The starting index in the KV cache + for the corresponding slice. + - new_kv_start_index (int): The starting index in the new KV + cache for the corresponding slice. + - slice_len (int): The length of the slice. + """ + slices_start = self.input_batch.num_computed_tokens_cpu[:num_reqs] + slices_end = self.input_batch.num_computed_tokens_cpu[:num_reqs] + \ + num_scheduled_tokens_per_req + local_block_start_idx = slices_start // self.block_size + local_block_end_idx = (slices_end - 1) // self.block_size + no_repeat_req_indices = self.arange_np[:num_reqs] + global_block_start_idx = ( + no_repeat_req_indices * self.max_num_blocks_per_req + + local_block_start_idx) + block_lens = local_block_end_idx - local_block_start_idx + 1 + global_block_start_idx = np.repeat(global_block_start_idx, block_lens) + slice_arange = np.concatenate([self.arange_np[:n] for n in block_lens]) + global_block_indices = global_block_start_idx + slice_arange + block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor() + block_numbers = block_table_cpu.flatten()[global_block_indices].numpy() + total_block_len = np.sum(block_lens) + slot_mapping_slices = np.repeat(np.array([[0, self.block_size]], + dtype=np.int32), + total_block_len, + axis=0) + cu_block_lens = np.zeros(len(block_lens) + 1, dtype=np.int32) + np.cumsum(block_lens, out=cu_block_lens[1:]) + for req_idx in range(num_reqs): + slot_mapping_slices[cu_block_lens[req_idx]][ + 0] = slices_start[req_idx] % self.block_size + slot_mapping_slices[ + cu_block_lens[req_idx + 1] - + 1][1] = (slices_end[req_idx] - 1) % self.block_size + 1 + slice_lens = slot_mapping_slices[:, 1] - slot_mapping_slices[:, 0] + cu_slices_lens = np.zeros(len(slice_lens) + 1, dtype=np.int32) + np.cumsum(slice_lens, out=cu_slices_lens[1:]) + kv_cache_start_indices = slot_mapping_slices[:, 0] + \ + (block_numbers * self.block_size) + new_kv_start_indices = cu_slices_lens[:-1] + slot_mapping_metadata = np.stack( + [kv_cache_start_indices, new_kv_start_indices, slice_lens], axis=1) + return slot_mapping_metadata + def _prepare_inputs(self, scheduler_output: "SchedulerOutput", start_index: int): assert scheduler_output.total_num_scheduled_tokens > 0 @@ -603,26 +665,6 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput", torch.from_numpy(token_indices), out=self.input_ids_cpu[:total_num_scheduled_tokens]) - # Calculate the slot mapping. - # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] - # where K is the max_num_blocks_per_req and the block size is 2. - # NOTE(woosuk): We can't simply use `token_indices // block_size` here - # because M (max_model_len) is not necessarily divisible by block_size. - # req_indices: # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] - block_table_indices = (req_indices * self.max_num_blocks_per_req + - positions_np // self.block_size) - # NOTE(woosuk): We use torch.index_select instead of np.take here - # because torch.index_select is much faster than np.take for large - # tensors. - block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor() - block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() - block_offsets = positions_np % self.block_size - np.add(block_numbers * self.block_size, - block_offsets, - out=self.input_batch.block_table[0]. - slot_mapping_np[:total_num_scheduled_tokens]) - # Prepare the attention metadata. self.query_start_loc_np[0] = 0 np.cumsum(num_scheduled_tokens_per_req, @@ -645,12 +687,6 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput", self.position_ids = self.positions_cpu[: padded_total_num_scheduled_tokens].to( self.device) - self.input_batch.block_table[0].slot_mapping_cpu[ - total_num_scheduled_tokens:] = _PAD_SLOT_ID - slot_mapping = ( - self.input_batch.block_table[0]. - slot_mapping_cpu[:padded_total_num_scheduled_tokens].to( - self.device)) if use_max_model_len: block_tables = self.block_table_cpu[:self.num_reqs_max_model_len, : self.max_num_blocks_per_req] @@ -675,6 +711,19 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput", self.device) block_tables = block_tables.to(self.device) + slot_mapping_metadata = self._get_slot_mapping_metadata( + num_reqs, num_scheduled_tokens_per_req) + padded_num_slices = _get_padded_num_kv_cache_update_slices( + padded_total_num_scheduled_tokens, self.max_num_reqs, + self.block_size) + slot_mapping_metadata = np.pad( + slot_mapping_metadata, + [[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]], + constant_values=0) + slot_mapping_metadata = np.transpose(slot_mapping_metadata) + slot_mapping_metadata = torch.tensor(slot_mapping_metadata, + device=self.device) + if self.lora_config is not None: # We need to respect padding when activating LoRA adapters padded_num_scheduled_tokens_per_req = np.copy( @@ -687,13 +736,15 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput", padded_num_scheduled_tokens_per_req) attn_metadata = PallasMetadata( - slot_mapping=slot_mapping, + slot_mapping=slot_mapping_metadata, block_tables=block_tables, context_lens=seq_lens, query_start_loc=query_start_loc, num_seqs=torch.tensor([num_reqs], dtype=torch.int32, device=self.device), + num_slices_per_kv_cache_update_block= + NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK, ) # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial # request in the batch. While we should not sample any token from this @@ -1119,8 +1170,10 @@ def _dummy_run(self, num_tokens: int, num_reqs: int, actual_num_reqs = min(num_tokens, num_reqs) position_ids = torch.zeros(num_tokens, dtype=torch.int32).to(self.device) - slot_mapping = torch.zeros(num_tokens, - dtype=torch.int64).to(self.device) + padded_num_slices = _get_padded_num_kv_cache_update_slices( + num_tokens, self.max_num_reqs, self.block_size) + slot_mapping = torch.zeros((3, padded_num_slices), + dtype=torch.int32).to(self.device) block_tables = torch.zeros((num_reqs, num_blocks), dtype=torch.int32).to(self.device) query_lens = [1] * num_reqs @@ -1138,6 +1191,8 @@ def _dummy_run(self, num_tokens: int, num_reqs: int, context_lens=context_lens, query_start_loc=query_start_loc, num_seqs=num_seqs, + num_slices_per_kv_cache_update_block= + NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK, ) if self.is_multimodal_model: @@ -1742,6 +1797,19 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int: return paddings[index] +def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int, + page_size: int) -> int: + """Calculates the padded number of KV cache update slices to avoid + recompilation.""" + padded_num_slices = 2 * max_num_reqs + num_tokens // page_size + padded_num_slices = min(padded_num_slices, num_tokens) + padded_num_slices = ( + padded_num_slices + NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK - 1 + ) // NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK * \ + NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK + return padded_num_slices + + def replace_set_lora(model): def _tpu_set_lora(