[Kernel] support merge_attn_states CUDA kernel, 3x speedup #16173

Merged: 28 commits into vllm-project:main, Apr 11, 2025

Conversation

DefTruth (Contributor) commented Apr 7, 2025

Based on vllm/attention/ops/triton_merge_attn_states.py.

This PR uses a CUDA kernel instead of Triton to minimize CPU overhead. Compared to the Triton kernel, the CUDA kernel implemented in this PR achieves a maximum speedup of over 3x. @WoosukKwon, end-to-end performance improved for R1 with PP=3 + TP=8 on L20, 4K IN:1K OUT (TTFT 5687.80 ms -> 5654.02 ms). Inference performance does not degrade.

  • float32
  • float16
  • bfloat16
  • dispatch by scalar_t
  • fallback strategy
  • unit tests (performance & correctness)
  • end2end test
  • CEval benchmark
  • test cascade_flash_attn (used merge_attn_states), passed
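For context, merge_attn_states combines two partial attention results (e.g. prefix and suffix chunks in split-KV attention) using their log-sum-exp (LSE) values, per section 2.2 of https://www.arxiv.org/pdf/2501.01005. A minimal pure-Python sketch of the per-(token, head) math (illustrative only; the real Triton/CUDA kernels vectorize this over tokens, heads, and head size):

```python
import math

def merge_attn_state(prefix_out, prefix_lse, suffix_out, suffix_lse):
    """Merge two partial attention outputs for one (token, head) pair.

    Each partial result carries its output vector plus the log-sum-exp of
    its softmax denominator. Subtracting the running max keeps the
    exponentials finite, and an LSE of -inf (an empty partition) naturally
    receives zero weight. Assumes at least one LSE is finite.
    """
    m = max(prefix_lse, suffix_lse)
    p_scale = math.exp(prefix_lse - m)  # exp(-inf - m) == 0.0
    s_scale = math.exp(suffix_lse - m)
    total = p_scale + s_scale
    merged_lse = m + math.log(total)
    merged_out = [(p_scale * p + s_scale * s) / total
                  for p, s in zip(prefix_out, suffix_out)]
    return merged_out, merged_lse
```

With equal LSEs the merge reduces to an average of the two outputs; if one side's LSE is -inf, its contribution vanishes and the other side passes through unchanged.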

Performance

Cases are for MLA with TP=8; the number of query heads per rank is 16 and the head size is 128.

| tokens | heads | headsize | dtype | device | torch | triton | cuda | speedup |
|-------:|------:|---------:|:------|:-------|:------|:-------|:-----|--------:|
| 256 | 16 | 128 | float32 | L20 | 0.15258ms | 0.05647ms | 0.01638ms | 3.4475x |
| 512 | 16 | 128 | float32 | L20 | 0.14940ms | 0.05417ms | 0.01654ms | 3.2749x |
| 613 | 16 | 128 | float32 | L20 | 0.14996ms | 0.05386ms | 0.01628ms | 3.3080x |
| 1024 | 16 | 128 | float32 | L20 | 0.14817ms | 0.05432ms | 0.01618ms | 3.3583x |
| 1536 | 16 | 128 | float32 | L20 | 0.14960ms | 0.05878ms | 0.01643ms | 3.5774x |
| 4096 | 16 | 128 | float32 | L20 | 0.38063ms | 0.12160ms | 0.06748ms | 1.8021x |
| 256 | 16 | 128 | float16 | L20 | 0.14776ms | 0.05509ms | 0.01567ms | 3.5165x |
| 512 | 16 | 128 | float16 | L20 | 0.14807ms | 0.05524ms | 0.01551ms | 3.5621x |
| 613 | 16 | 128 | float16 | L20 | 0.14843ms | 0.05380ms | 0.01557ms | 3.4564x |
| 1024 | 16 | 128 | float16 | L20 | 0.14945ms | 0.05437ms | 0.01556ms | 3.4938x |
| 1536 | 16 | 128 | float16 | L20 | 0.15718ms | 0.05836ms | 0.01557ms | 3.7494x |
| 4096 | 16 | 128 | float16 | L20 | 0.31852ms | 0.08372ms | 0.01955ms | 4.2813x |
| 256 | 16 | 128 | bfloat16 | L20 | 0.14842ms | 0.05381ms | 0.01546ms | 3.4801x |
| 512 | 16 | 128 | bfloat16 | L20 | 0.14782ms | 0.05356ms | 0.01536ms | 3.4858x |
| 613 | 16 | 128 | bfloat16 | L20 | 0.14848ms | 0.05320ms | 0.01547ms | 3.4398x |
| 1024 | 16 | 128 | bfloat16 | L20 | 0.14935ms | 0.05376ms | 0.01562ms | 3.4423x |
| 1536 | 16 | 128 | bfloat16 | L20 | 0.15765ms | 0.05934ms | 0.01572ms | 3.7736x |
| 4096 | 16 | 128 | bfloat16 | L20 | 0.31912ms | 0.08524ms | 0.01925ms | 4.4283x |

Correctness

  • float16 (performance & correctness)
pytest -s test_merge_attn_states.py
----------------------------------------------------------------------------------------------------
NUM_TOKENS:512, NUM_HEADS:16, HEAD_SIZE:128, DTYPE: torch.float16, Device: NVIDIA L20
 Torch time: 0.149299ms
Triton time: 0.050995ms
  CUDA time: 0.015722ms, Performance: 3.24364x
----------------------------------------------------------------------------------------------------
Output all match, max abs diff:
 (CUDA  vs Triton): 0.0009765625
(Triton vs Torch) : 0.0015368461608886719
  (CUDA vs Torch) : 0.0015368461608886719
----------------------------------------------------------------------------------------------------
Output LSE all match, max abs diff:
(Triton vs Torch) : 2.384185791015625e-07
  (CUDA vs Torch) : 0.0
  (CUDA vs Triton): 2.384185791015625e-07
----------------------------------------------------------------------------------------------------
All output values test passed! All inf values are correctly replaced with -inf.
----------------------------------------------------------------------------------------------------
  • float32 (performance & correctness)
pytest -s test_merge_attn_states.py
----------------------------------------------------------------------------------------------------
NUM_TOKENS:512, NUM_HEADS:16, HEAD_SIZE:128, DTYPE: torch.float32, Device: NVIDIA L20
 Torch time: 0.150216ms
Triton time: 0.051350ms
  CUDA time: 0.016072ms, Performance: 3.19502x
----------------------------------------------------------------------------------------------------
Output all match, max abs diff:
 (CUDA  vs Triton): 4.76837158203125e-07
(Triton vs Torch) : 4.76837158203125e-07
  (CUDA vs Torch) : 2.384185791015625e-07
----------------------------------------------------------------------------------------------------
Output LSE all match, max abs diff:
(Triton vs Torch) : 4.76837158203125e-07
  (CUDA vs Torch) : 0.0
 (CUDA  vs Triton): 4.76837158203125e-07
----------------------------------------------------------------------------------------------------
All output values test passed! All inf values are correctly replaced with -inf.
----------------------------------------------------------------------------------------------------
  • bfloat16 (performance & correctness)
----------------------------------------------------------------------------------------------------
NUM_TOKENS:4096, NUM_HEADS:16, HEAD_SIZE:128, DTYPE: torch.bfloat16, Device: NVIDIA L20
 Torch time: 0.322397ms
Triton time: 0.081408ms
  CUDA time: 0.026824ms, Performance: 3.03489x
----------------------------------------------------------------------------------------------------
Output all match, max abs diff:
 (CUDA  vs Triton): 0.015625
(Triton vs Torch) : 0.011169910430908203
  (CUDA vs Torch) : 0.011169910430908203
----------------------------------------------------------------------------------------------------
Output LSE all match, max abs diff:
(Triton vs Torch) : 2.384185791015625e-07
  (CUDA vs Torch) : 0.0
 (CUDA  vs Triton): 2.384185791015625e-07
----------------------------------------------------------------------------------------------------
All output values test passed! All inf values are correctly replaced with -inf.
----------------------------------------------------------------------------------------------------

End2End test

R1 671B with L20x3, PP=3, TP=8

  • launch cmd
nohup python3 -m vllm.entrypoints.openai.api_server \
        --model=/workspace/dev/hf_models/DeepSeek-R1 \
        --dtype=auto \
        --block-size 32 \
        --tokenizer-mode=slow \
        --max-model-len 32768 \
        --max-num-batched-tokens 2048 \
        --tensor-parallel-size 8 \
        --pipeline-parallel-size 3 \
        --gpu-memory-utilization 0.90 \
        --max-num-seqs 128 \
        --trust-remote-code \
        --no-enable-prefix-caching \
        --enable-chunked-prefill=True \
        --disable-custom-all-reduce \
        --port 8862 > vllm.R1.log.3 2>&1 &

4K IN:1K OUT (TTFT 5687.80 ms -> 5654.02 ms). Inference performance does not degrade.


4K IN:1K OUT (TTFT 5687.80 ms -> 5654.02 ms)

  • w/o this opt, 4K IN:1K OUT
Maximum request concurrency: 16
============ Serving Benchmark Result ============
Successful requests:                     32
Benchmark duration (s):                  207.14
Total input tokens:                      131072
Total generated tokens:                  32768
Request throughput (req/s):              0.15
Output token throughput (tok/s):         158.19
Total Token throughput (tok/s):          790.96
---------------Time to First Token----------------
Mean TTFT (ms):                          5687.80
Median TTFT (ms):                        3969.86
P99 TTFT (ms):                           11952.93
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          95.51
Median TPOT (ms):                        96.38
P99 TPOT (ms):                           98.71
---------------Inter-token Latency----------------
Mean ITL (ms):                           95.51
Median ITL (ms):                         89.71
P99 ITL (ms):                            97.03
==================================================
  • w/ this opt, 4K IN:1K OUT (TTFT 5687.80 ms -> 5654.02 ms)
Maximum request concurrency: 16
============ Serving Benchmark Result ============
Successful requests:                     32
Benchmark duration (s):                  206.65
Total input tokens:                      131072
Total generated tokens:                  32768
Request throughput (req/s):              0.15
Output token throughput (tok/s):         158.57
Total Token throughput (tok/s):          792.83
---------------Time to First Token----------------
Mean TTFT (ms):                          5654.02
Median TTFT (ms):                        3958.66
P99 TTFT (ms):                           11861.09
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          95.30
Median TPOT (ms):                        95.98
P99 TPOT (ms):                           98.70
---------------Inter-token Latency----------------
Mean ITL (ms):                           95.30
Median ITL (ms):                         89.62
P99 ITL (ms):                            96.89
==================================================

8K IN:64 OUT (TTFT 8861.07ms -> 8767.16ms)

  • w/o this opt, 8K IN:64 OUT
Maximum request concurrency: 16
============ Serving Benchmark Result ============
Successful requests:                     48
Benchmark duration (s):                  115.37
Total input tokens:                      393216
Total generated tokens:                  3072
Request throughput (req/s):              0.42
Output token throughput (tok/s):         26.63
Total Token throughput (tok/s):          3434.90
---------------Time to First Token----------------
Mean TTFT (ms):                          8861.07
Median TTFT (ms):                        6167.50
P99 TTFT (ms):                           23576.12
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          454.74
Median TPOT (ms):                        484.97
P99 TPOT (ms):                           504.62
---------------Inter-token Latency----------------
Mean ITL (ms):                           454.74
Median ITL (ms):                         273.69
P99 ITL (ms):                            1065.00
==================================================
  • w/ this opt, 8K IN:64 OUT (TTFT 8861.07ms -> 8767.16ms)
Maximum request concurrency: 16
============ Serving Benchmark Result ============
Successful requests:                     48
Benchmark duration (s):                  115.19
Total input tokens:                      393216
Total generated tokens:                  3072
Request throughput (req/s):              0.42
Output token throughput (tok/s):         26.67
Total Token throughput (tok/s):          3440.28
---------------Time to First Token----------------
Mean TTFT (ms):                          8767.16
Median TTFT (ms):                        6170.44
P99 TTFT (ms):                           23594.15
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          455.34
Median TPOT (ms):                        483.54
P99 TPOT (ms):                           504.48
---------------Inter-token Latency----------------
Mean ITL (ms):                           455.34
Median ITL (ms):                         270.61
P99 ITL (ms):                            1066.51
==================================================
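The throughput rows in the tables above follow directly from the request counts and wall-clock duration. Assuming the standard definitions (counts divided by benchmark duration), the "w/o this opt" 4K IN:1K OUT run can be recomputed as:

```python
def serving_throughputs(num_requests, input_tokens, output_tokens, duration_s):
    """Recompute the throughput metrics shown in the benchmark tables."""
    return {
        "request_per_s": num_requests / duration_s,
        "output_tok_per_s": output_tokens / duration_s,
        "total_tok_per_s": (input_tokens + output_tokens) / duration_s,
    }

m = serving_throughputs(32, 131072, 32768, 207.14)
# matches the reported 0.15 req/s, 158.19 tok/s, and 790.96 tok/s
```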

CEval benchmark (0.90197884615385)

We use evalscope to run a benchmark on the CEval dataset.

evalscope eval \
 --model /workspace/dev/hf_models/DeepSeek-R1 \
 --api-url http://0.0.0.0:8862/v1/chat/completions \
 --api-key EMPTY \
 --eval-batch-size 32 \
 --eval-type service \
 --datasets ceval \
 --dataset-args '{"ceval": {"local_path": "/workspace/dev/openllm/benchmarks/data/ceval"}}'

Total AverageAccuracy: 0.90197884615385

+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| Model       | Dataset   | Metric          | Subset                                   |   Num |   Score | Cat.0          |
+=============+===========+=================+==========================================+=======+=========+================+
| DeepSeek-R1 | ceval     | AverageAccuracy | modern_chinese_history                   |    23 |  0.8696 | Humanities     |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | ideological_and_moral_cultivation        |    19 |  1      | Humanities     |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | logic                                    |    22 |  0.9091 | Humanities     |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | law                                      |    24 |  0.875  | Humanities     |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | chinese_language_and_literature          |    23 |  0.8261 | Humanities     |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | art_studies                              |    33 |  0.9091 | Humanities     |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | professional_tour_guide                  |    29 |  0.9655 | Humanities     |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | legal_professional                       |    23 |  0.913  | Humanities     |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | high_school_chinese                      |    19 |  0.7895 | Humanities     |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | high_school_history                      |    20 |  0.95   | Humanities     |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | middle_school_history                    |    22 |  0.9545 | Humanities     |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | civil_servant                            |    47 |  0.8723 | Other          |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | sports_science                           |    19 |  0.8947 | Other          |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | plant_protection                         |    22 |  1      | Other          |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | basic_medicine                           |    19 |  1      | Other          |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | clinical_medicine                        |    22 |  0.9091 | Other          |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | urban_and_rural_planner                  |    46 |  0.8913 | Other          |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | accountant                               |    49 |  0.9184 | Other          |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | fire_engineer                            |    31 |  1      | Other          |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | environmental_impact_assessment_engineer |    31 |  0.9032 | Other          |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | tax_accountant                           |    49 |  0.9184 | Other          |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | physician                                |    49 |  0.9184 | Other          |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | computer_network                         |    19 |  0.7895 | STEM           |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | operating_system                         |    19 |  0.8947 | STEM           |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | computer_architecture                    |    21 |  1      | STEM           |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | college_programming                      |    37 |  0.9189 | STEM           |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | college_physics                          |    19 |  0.8947 | STEM           |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | college_chemistry                        |    24 |  0.9167 | STEM           |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | advanced_mathematics                     |    19 |  0.9474 | STEM           |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | probability_and_statistics               |    18 |  0.7778 | STEM           |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | discrete_mathematics                     |    16 |  0.5625 | STEM           |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | electrical_engineer                      |    37 |  0.7027 | STEM           |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | metrology_engineer                       |    24 |  0.9583 | STEM           |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | high_school_mathematics                  |    18 |  0.7778 | STEM           |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | high_school_physics                      |    19 |  1      | STEM           |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | high_school_chemistry                    |    19 |  0.9474 | STEM           |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | high_school_biology                      |    19 |  0.9474 | STEM           |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | middle_school_mathematics                |    19 |  1      | STEM           |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | middle_school_biology                    |    21 |  0.8571 | STEM           |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | middle_school_physics                    |    19 |  1      | STEM           |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | middle_school_chemistry                  |    20 |  1      | STEM           |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | veterinary_medicine                      |    23 |  0.8696 | STEM           |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | college_economics                        |    55 |  0.8727 | Social Science |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | business_administration                  |    33 |  0.8182 | Social Science |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | marxism                                  |    19 |  0.9474 | Social Science |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | mao_zedong_thought                       |    24 |  1      | Social Science |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | education_science                        |    29 |  0.931  | Social Science |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | teacher_qualification                    |    44 |  0.9318 | Social Science |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | high_school_politics                     |    19 |  1      | Social Science |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | high_school_geography                    |    19 |  0.9474 | Social Science |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | middle_school_politics                   |    21 |  1      | Social Science |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+
| DeepSeek-R1 | ceval     | AverageAccuracy | middle_school_geography                  |    12 |  0.9167 | Social Science |
+-------------+-----------+-----------------+------------------------------------------+-------+---------+----------------+

Test cascade_flash_attn

pytest -s test_cascade_flash_attn.py
================================================================================== test session starts ===================================================================================
collected 198 items
Running 198 items in this shard

test_cascade_flash_attn.py ..............................................................................................................................ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss

============================================================================ 126 passed, 72 skipped in 1.05s =============================================================================


github-actions bot commented Apr 7, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the ci/build label Apr 7, 2025
@DefTruth DefTruth changed the title Cuda merge attn states [Kernel] support cuda merge_attn_states kernel, max ~3x improved Apr 7, 2025
@mergify mergify bot added the v1 label Apr 7, 2025
@DefTruth DefTruth marked this pull request as draft April 7, 2025 10:54
@DefTruth DefTruth marked this pull request as ready for review April 8, 2025 06:16

DefTruth commented Apr 8, 2025

@WoosukKwon, @tlrmchlsmth Hi~ This PR is ready. Could you please take a look? Compared to the Triton kernel, the CUDA kernel implemented in this PR can achieve a maximum speedup of over 3x.

@DefTruth DefTruth changed the title [Kernel] support cuda merge_attn_states kernel, max ~3x improved [Kernel] support merge_attn_states CUDA kernel Apr 8, 2025
DefTruth commented Apr 8, 2025

@WoosukKwon, @tlrmchlsmth End2End performance improved for R1 with PP=3 + TP=8 on L20, 4K IN:1K OUT (TTFT 5687.80 ms -> 5654.02 ms), 16 concurrency.

@DefTruth DefTruth changed the title [Kernel] support merge_attn_states CUDA kernel [Kernel] support merge_attn_states CUDA kernel, 3x speedup Apr 9, 2025

mergify bot commented Apr 9, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @DefTruth.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Apr 9, 2025
@mergify mergify bot removed the needs-rebase label Apr 9, 2025
DefTruth commented
The Lint and Deploy CI checks failed, but the failures seem unrelated to this PR.

ERROR 04-10 04:06:36 [registry.py:345]     from torch._inductor.runtime.hints import DeviceProperties
ERROR 04-10 04:06:36 [registry.py:345]   File "/opt/venv/lib/python3.12/site-packages/torch/_inductor/runtime/hints.py", line 67, in <module>
ERROR 04-10 04:06:36 [registry.py:345]     from triton.compiler.compiler import AttrsDescriptor
ERROR 04-10 04:06:36 [registry.py:345] ImportError: cannot import name 'AttrsDescriptor' from 'triton.compiler.compiler' (/opt/venv/lib/python3.12/site-packages/triton/compiler/compiler.py)

DefTruth commented
@WoosukKwon , @tlrmchlsmth @mgoin Hi~ This PR is ready. Could you please take a look?

mgoin (Member) left a comment:
Thank you very much for the careful work and verbose testing, we appreciate it! I think this is essentially good to go given the results you've already shared. Just a few questions

@mgoin mgoin enabled auto-merge (squash) April 11, 2025 02:34
@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Apr 11, 2025
DefTruth commented Apr 11, 2025

@mgoin The AMD build failed. Should I add an #ifndef USE_ROCM guard around the PyTorch binding?

#13 96.26 /app/vllm/csrc/torch_bindings.cpp:78:48: error: ‘merge_attn_states’ was not declared in this scope
#13 96.26    78 |   ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
#13 96.26       |                                                ^~~~~~~~~~~~~~~~~

like this:

#ifndef USE_ROCM
  // Merge attn states
  // Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
  // can be used to combine partial attention results (in the split-KV case)
  ops.def(
      "merge_attn_states("
      "    Tensor! output,"
      "    Tensor!? output_lse,"
      "    Tensor prefix_output,"
      "    Tensor prefix_lse,"
      "    Tensor suffix_output,"
      "    Tensor suffix_lse) -> ()");
  ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
#endif
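Since the op is only registered on non-ROCm builds, the Python side needs the fallback strategy listed in the PR checklist: use the CUDA op when it is registered, otherwise fall back to the Triton kernel. A hypothetical, dependency-free sketch of that dispatch (the names and registry here are illustrative stand-ins, not vLLM's actual API):

```python
def select_merge_attn_states(registry):
    """Prefer the compiled CUDA op when the build registered it
    (non-ROCm); otherwise fall back to the Triton kernel."""
    return registry.get("cuda") or registry["triton"]

# Illustrative stand-ins for the two kernel implementations.
ops = {"cuda": lambda: "cuda kernel", "triton": lambda: "triton kernel"}
print(select_merge_attn_states(ops)())                          # cuda kernel
print(select_merge_attn_states({"triton": ops["triton"]})())    # triton kernel
```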

auto-merge was automatically disabled April 11, 2025 06:04

Head branch was pushed to by a user without write access

DefTruth commented Apr 11, 2025

@mgoin All tests passed~ PTAL.

@DefTruth DefTruth requested a review from mgoin April 11, 2025 11:57
@mgoin mgoin merged commit e9528f6 into vllm-project:main Apr 11, 2025
67 checks passed
@DefTruth DefTruth deleted the cuda-merge-attn-states branch April 15, 2025 15:46