
[V1][Core] Add async kv cache offload #16159


Open
wants to merge 1 commit into main

Conversation


@zeroorhero commented Apr 7, 2025

TL;DR

In V1, we have already implemented storing the KV cache in CPU memory (#13377). However, when multiple vLLM nodes are involved, the router needs to forward requests based on cache awareness. If one node has too many active requests, requests that should have hit its cache may get forwarded to other nodes for recomputation. To avoid this, we can pool the CPU memory across nodes. However, performing cross-node KV cache swap-in and swap-out operations can hurt inference performance. Inspired by NVIDIA's open-source project Dynamo (https://github.com/ai-dynamo/dynamo), we implemented an asynchronous KV cache transmission scheme. In our tests, the performance of asynchronous KV cache transfer is on par with, or even better than, storing the cache locally in CPU memory (#13377).

Swap Strategy

Memory Pool -> GPU: happens when a block of a request hits the cache.
GPU -> Memory Pool: happens when a newly generated block does not yet exist in the memory pool.

During swap-in, operations are performed at request granularity: only once all cache-hit blocks of a request have been swapped in can the request be scheduled normally for inference. Swap-out, in contrast, happens at block granularity: each newly generated block is swapped out immediately.
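
As a minimal sketch of this split in granularity (the KVBlockSwapper class and its method names below are hypothetical illustrations, not the PR's actual code):

```python
# Illustrative only: KVBlockSwapper and its methods are hypothetical and are
# not the data structures introduced by this PR.
from dataclasses import dataclass, field


@dataclass
class KVBlockSwapper:
    """Tracks transfers between the shared memory pool and the GPU."""

    # block_hash -> serialized KV bytes held in the cross-node memory pool
    pool: dict = field(default_factory=dict)
    # request_id -> cache-hit blocks still being copied pool -> GPU
    pending_swap_in: dict = field(default_factory=dict)

    def swap_in_request(self, request_id, hit_block_hashes):
        # Memory pool -> GPU: triggered when a request's blocks hit the cache.
        # The request becomes schedulable only once *all* hit blocks arrive.
        self.pending_swap_in[request_id] = set(hit_block_hashes)

    def mark_block_arrived(self, request_id, block_hash):
        # Called when one pool -> GPU copy for this request completes.
        self.pending_swap_in[request_id].discard(block_hash)

    def is_request_ready(self, request_id):
        # Swap-in completes at request granularity.
        return not self.pending_swap_in.get(request_id)

    def swap_out_block(self, block_hash, kv_bytes):
        # GPU -> memory pool: triggered per block, as soon as it is generated,
        # without waiting for the rest of the request.
        self.pool.setdefault(block_hash, kv_bytes)
```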

Implementation

  1. Replace the previous step method with async_step (a condensed sketch of this loop appears after the list). In this method, first retrieve the in-flight swap-in and swap-out requests and blocks, check whether they have completed, and record the completed ones.
  2. Pass the requests that have completed swap-in and the blocks that have been swapped out to the schedule method. During scheduling, prioritize requests that have finished swap-in, then schedule requests from the waiting queue. Additionally, collect the requests requiring swap-in and the blocks to be swapped out, and return them via schedule_out.
  3. For the requests in schedule_out that require swap-in, send an asynchronous swap-in request to the model runner. Since Python's multithreading cannot fully exploit multiple cores, a better approach is to spawn a dedicated thread inside each swapper implementation to handle the transfer; here we simply put the request into a queue and return immediately.
  4. Perform the model inference step.
  5. Swap out the newly generated blocks. This stage also runs asynchronously, handled by the swapper's own thread.
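
The sketch below condenses steps 1-5 into one loop. The names used here (poll_finished, schedule_out, execute_model, the queue-based hand-off) are placeholders assumed for illustration, not the interfaces actually added by this PR:

```python
# Condensed, assumption-laden view of the async_step flow described above.
import queue


class AsyncStepLoop:
    def __init__(self, scheduler, model_runner, swapper):
        self.scheduler = scheduler
        self.model_runner = model_runner
        self.swapper = swapper
        # The swapper drains these queues on its own thread, so async_step
        # only enqueues transfer work and never blocks on it (steps 3 and 5).
        self.swap_in_queue: queue.Queue = queue.Queue()
        self.swap_out_queue: queue.Queue = queue.Queue()

    def async_step(self):
        # Step 1: find out which swap-ins / swap-outs have completed.
        finished_swap_in, finished_swap_out = self.swapper.poll_finished()

        # Step 2: schedule, preferring requests whose swap-in has finished,
        # then the waiting queue; also collect new swap work via schedule_out.
        scheduler_output, schedule_out = self.scheduler.schedule(
            finished_swap_in, finished_swap_out)

        # Step 3: enqueue asynchronous swap-in requests and return immediately.
        for request in schedule_out.swap_in_requests:
            self.swap_in_queue.put(request)

        # Step 4: run model inference on the scheduled requests.
        model_output = self.model_runner.execute_model(scheduler_output)

        # Step 5: newly generated blocks are offloaded asynchronously as well.
        for block in schedule_out.swap_out_blocks:
            self.swap_out_queue.put(block)

        return model_output
```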

Benchmark

This PR currently implements only Redis and Valkey backends, and the performance observed with them was suboptimal. In our internal benchmarking, however, the asynchronous scheme achieves performance comparable to, or even better than, the local CPU implementation in vLLM PR #13377. Future iterations could extend support to other open-source distributed storage systems for further optimization.
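
As a rough illustration of what a Redis/Valkey-backed swapper could look like (the put/get interface, key layout, and serialization helpers below are assumptions made for this example and do not mirror the PR's actual code):

```python
# Hedged sketch of a Redis/Valkey-backed block store; not the PR's real code.
import io
from typing import Optional

import redis  # Valkey speaks the same protocol, so redis-py works for both
import torch


def tensor_to_bytes(t: torch.Tensor) -> bytes:
    buf = io.BytesIO()
    torch.save(t.cpu(), buf)
    return buf.getvalue()


def tensor_from_bytes(raw: bytes) -> torch.Tensor:
    return torch.load(io.BytesIO(raw))


class RedisSwapper:
    """Stores each KV block under its hash in a shared Redis/Valkey instance."""

    def __init__(self, host: str = "localhost", port: int = 6379):
        self.client = redis.Redis(host=host, port=port)

    def put(self, block_hash: str, kv_tensor: torch.Tensor) -> None:
        # GPU -> memory pool: serialize the block and store it by hash.
        self.client.set(block_hash, tensor_to_bytes(kv_tensor))

    def get(self, block_hash: str) -> Optional[bytes]:
        # Memory pool -> GPU: the caller deserializes with tensor_from_bytes
        # and moves the result onto the device.
        return self.client.get(block_hash)
```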


github-actions bot commented Apr 7, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, which covers a small but essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀


mergify bot commented Apr 7, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @zeroorhero.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Apr 7, 2025
for i in range(len(self.kv_caches)):
layer_cache = self.kv_caches[i]
key_cache = layer_cache[0]
val_cache = layer_cache[1]
Contributor

Could you please support MLA, since DeepSeek is powerful and popular?

Author

Let's wait for the community's feedback; it can be implemented later.

@mergify mergify bot added the tpu Related to Google TPUs label Apr 9, 2025

mergify bot commented Apr 9, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @zeroorhero.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Apr 9, 2025
key_cache_bytes = self.swapper.get(key_cache_key)
val_cache_bytes = self.swapper.get(val_cache_key)

gpu_key_cache = tensor_from_bytes(key_cache_bytes).to(
@singzhou commented Apr 10, 2025

Hi there, it seems you start two Python threads for swapping in and out, respectively. I think there is a chance that performance will be significantly affected by the GIL. Perhaps consider starting two native threads to trigger the tensor loading/offloading and get rid of the GIL?

Author

Yes, it could be handled by native threads written in C or Rust.

@mergify mergify bot removed the tpu Related to Google TPUs label Apr 10, 2025