Skip to content

[Kernel] Add Conch Triton Attention backend #19625

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

jmanning-stackav
Copy link

@jmanning-stackav jmanning-stackav commented Jun 13, 2025

Purpose

This PR adds a new V1 Attention backend for Conch. Conch implements all required kernels in Triton (varlen attention w/ kv cache, reshape_and_cache, and scaled_fp8_quant) and should be compatible with any hardware platform that is supported by Triton (but this PR has only been tested on H100 and MI300X).

Why add another Triton backend? We already have one!

In my testing of both microbenchmarks and end-to-end serving, Conch provides slightly better performance for both prefill and decode attention than the existing Triton backend in vLLM. I haven't invested a tremendous amount of time into analyzing the differences, but I believe Conch performs better on prefill because vLLM's Triton attention impl has unnecessary loop iterations and Conch performs better on decode because it uses FlashDecoding for long sequences.

Conch's implementation of reshape_and_cache likely gives similar performance to CUDA, but I haven't tried tuning it. Conch's implementation of scaled_fp8_quant is likely a bit slower than CUDA, but again, I haven't spent time trying to optimize it.

Test Plan

Disclaimer: This is my first PR to vLLM, so I'm happy to run additional testing/performance measurements.

I measured end-to-end serving performance on both H100 and MI300X via the following commands (which I copied from another kernel PR; I'm not sure if this is standard at all):

VLLM_USE_V1=1 VLLM_LOGGING_LEVEL=DEBUG VLLM_WORKER_MULTIPROC_METHOD=spawn   vllm serve meta-llama/Llama-3.1-8B-Instruct     --trust-remote-code     --max-model-len=2048     --block-size=128     --max-num-seqs=128     --gpu_memory_utilization=0.90     --data-parallel-size 1     --disable-log-requests  --port 8001
python benchmarks/benchmark_serving.py   --model meta-llama/Llama-3.1-8B-Instruct   --dataset-name random   --ignore-eos   --num-prompts 3000   --max-concurrency 3000   --random-input-len 500   --random-output-len 500   --seed 1 --port 8001

I modified the attention backend to test Flash Attention (H100), Triton Attention (H100 and MI300X) and Conch (H100 and MI300X).

Test Result

H100

(Baseline) Flash Attention

============ Serving Benchmark Result ============
Successful requests:                     990
Benchmark duration (s):                  81.65
Total input tokens:                      494010
Total generated tokens:                  495000
Request throughput (req/s):              12.13
Output token throughput (tok/s):         6062.76
Total Token throughput (tok/s):          12113.39
---------------Time to First Token----------------
Mean TTFT (ms):                          41367.51
Median TTFT (ms):                        38679.81
P99 TTFT (ms):                           73222.22
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          17.29
Median TPOT (ms):                        17.63
P99 TPOT (ms):                           17.97
---------------Inter-token Latency----------------
Mean ITL (ms):                           17.55
Median ITL (ms):                         14.78
P99 ITL (ms):                            199.87
==================================================

(Baseline) Triton Attention

============ Serving Benchmark Result ============
Successful requests:                     990
Benchmark duration (s):                  84.44
Total input tokens:                      494010
Total generated tokens:                  495000
Request throughput (req/s):              11.72
Output token throughput (tok/s):         5862.35
Total Token throughput (tok/s):          11712.98
---------------Time to First Token----------------
Mean TTFT (ms):                          42167.85
Median TTFT (ms):                        39447.24
P99 TTFT (ms):                           75402.71
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          18.12
Median TPOT (ms):                        18.48
P99 TPOT (ms):                           18.78
---------------Inter-token Latency----------------
Mean ITL (ms):                           18.38
Median ITL (ms):                         15.31
P99 ITL (ms):                            220.03
==================================================

Conch

============ Serving Benchmark Result ============
Successful requests:                     990
Benchmark duration (s):                  83.25
Total input tokens:                      494010
Total generated tokens:                  495000
Request throughput (req/s):              11.89
Output token throughput (tok/s):         5946.11
Total Token throughput (tok/s):          11880.34
---------------Time to First Token----------------
Mean TTFT (ms):                          41762.37
Median TTFT (ms):                        38986.60
P99 TTFT (ms):                           74405.56
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          17.65
Median TPOT (ms):                        17.96
P99 TPOT (ms):                           18.23
---------------Inter-token Latency----------------
Mean ITL (ms):                           17.92
Median ITL (ms):                         15.01
P99 ITL (ms):                            209.06
==================================================

MI300X

(Baseline) Triton Attention

============ Serving Benchmark Result ============
Successful requests:                     3000
Benchmark duration (s):                  273.38
Total input tokens:                      1497000
Total generated tokens:                  1500000
Request throughput (req/s):              10.97
Output token throughput (tok/s):         5486.92
Total Token throughput (tok/s):          10962.86
---------------Time to First Token----------------
Mean TTFT (ms):                          134943.71
Median TTFT (ms):                        132286.62
P99 TTFT (ms):                           264667.74
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          21.54
Median TPOT (ms):                        21.74
P99 TPOT (ms):                           21.90
---------------Inter-token Latency----------------
Mean ITL (ms):                           21.96
Median ITL (ms):                         17.84
P99 ITL (ms):                            287.36
==================================================

Conch

============ Serving Benchmark Result ============
Successful requests:                     3000
Benchmark duration (s):                  246.01
Total input tokens:                      1497000
Total generated tokens:                  1500000
Request throughput (req/s):              12.19
Output token throughput (tok/s):         6097.27
Total Token throughput (tok/s):          12182.34
---------------Time to First Token----------------
Mean TTFT (ms):                          121659.70
Median TTFT (ms):                        119735.36
P99 TTFT (ms):                           237212.95
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          19.29
Median TPOT (ms):                        19.42
P99 TPOT (ms):                           19.74
---------------Inter-token Latency----------------
Mean ITL (ms):                           19.69
Median ITL (ms):                         16.17
P99 ITL (ms):                            239.03
==================================================

Conclusion

Conch performs slightly better than the existing Triton backend on H100, but quite a bit better on MI300X. I also only tested on relatively short sequences, but the difference should be more pronounced for long sequences (>=4096 tokens) because Conch is using FlashDecoding. I also have more benchmark results for quantized models and FP8 KV cache that I can share as well, and I'm happy to collect more as-needed. Please ask any questions below and thank you in advance for your review!

Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @jmanning-stackav, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a new attention backend based on the Conch library's Triton kernels. The primary goal is to offer a performance alternative to the existing Triton backend within vLLM's V1 engine, especially targeting improvements on platforms like MI300X and potentially for longer sequences. The change involves adding the Conch dependency and integrating the new backend into vLLM's attention backend selection logic.

Highlights

  • New Attention Backend: Adds a new V1 attention backend leveraging the Conch library, which provides Triton-based kernels for variable-length attention with KV cache, reshape_and_cache, and scaled FP8 quantization.
  • Performance Improvements: The Conch backend aims to provide better performance than the existing vLLM Triton backend, particularly noted for decode performance (using FlashDecoding) and showing significant gains on MI300X in provided benchmarks.
  • Platform Support: The Conch backend is intended to be compatible with any hardware supported by Triton, with specific testing reported on H100 and MI300X.
  • Configuration: The new backend can be enabled by setting the environment variable VLLM_ATTENTION_BACKEND=CONCH.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configureGemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@mergify mergify bot added ci/build rocm Related to AMD ROCm v1 labels Jun 13, 2025
Copy link

mergify bot commented Jun 13, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @jmanning-stackav.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jun 13, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This PR adds the Conch Triton Attention backend, improving performance, especially on MI300X. The integration is well-structured, and the new backend code is clear. I've suggested a minor formatting improvement and a docstring correction for better consistency.

@jmanning-stackav jmanning-stackav force-pushed the feature/jmanning/conch-backend branch from 31a5525 to 6cf839f Compare June 13, 2025 21:41
@mergify mergify bot removed the needs-rebase label Jun 13, 2025
Copy link

mergify bot commented Jun 19, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @jmanning-stackav.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jun 19, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant