MLA - Flashinfer Ragged Prefill #20034

Draft · alexm-redhat wants to merge 6 commits into main

Conversation

alexm-redhat (Collaborator) commented Jun 24, 2025

This is a draft PR that runs the FlashInfer ragged prefill for deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct. It is still in rough form, but it produces correct results. Currently, there is a slowdown when using the FlashInfer ragged prefill. For example:

Batch size 1 with FlashInfer Ragged Prefill

============ Serving Benchmark Result ============
Successful requests:                     1         
Benchmark duration (s):                  0.67      
Total input tokens:                      999       
Total generated tokens:                  100       
Request throughput (req/s):              1.48      
Output token throughput (tok/s):         148.28    
Total Token throughput (tok/s):          1629.57   
---------------Time to First Token----------------
Mean TTFT (ms):                          37.58     
Median TTFT (ms):                        37.58     
P99 TTFT (ms):                           37.58     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          6.43      
Median TPOT (ms):                        6.43      
P99 TPOT (ms):                           6.43      
---------------Inter-token Latency----------------
Mean ITL (ms):                           6.43      
Median ITL (ms):                         6.42      
P99 ITL (ms):                            6.68      
==================================================

Batch size 1 with the original main - uses FA2 for prefill

============ Serving Benchmark Result ============
Successful requests:                     1         
Benchmark duration (s):                  0.65      
Total input tokens:                      999       
Total generated tokens:                  100       
Request throughput (req/s):              1.53      
Output token throughput (tok/s):         153.49    
Total Token throughput (tok/s):          1686.83   
---------------Time to First Token----------------
Mean TTFT (ms):                          29.76     
Median TTFT (ms):                        29.76     
P99 TTFT (ms):                           29.76     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          6.27      
Median TPOT (ms):                        6.27      
P99 TPOT (ms):                           6.27      
---------------Inter-token Latency----------------
Mean ITL (ms):                           6.27      
Median ITL (ms):                         6.25      
P99 ITL (ms):                            6.74      
==================================================

Batch size 100 with FlashInfer Ragged Prefill

============ Serving Benchmark Result ============
Successful requests:                     100       
Benchmark duration (s):                  3.63      
Total input tokens:                      99900     
Total generated tokens:                  10000     
Request throughput (req/s):              27.51     
Output token throughput (tok/s):         2751.26   
Total Token throughput (tok/s):          30236.31  
---------------Time to First Token----------------
Mean TTFT (ms):                          1196.92   
Median TTFT (ms):                        1253.75   
P99 TTFT (ms):                           2187.23   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          23.25     
Median TPOT (ms):                        22.71     
P99 TPOT (ms):                           33.14     
---------------Inter-token Latency----------------
Mean ITL (ms):                           23.25     
Median ITL (ms):                         13.48     
P99 ITL (ms):                            185.82    
==================================================

Batch size 100 with the original main - uses FA2 for prefill

============ Serving Benchmark Result ============
Successful requests:                     100       
Benchmark duration (s):                  2.63      
Total input tokens:                      99900     
Total generated tokens:                  10000     
Request throughput (req/s):              37.99     
Output token throughput (tok/s):         3799.28   
Total Token throughput (tok/s):          41754.09  
---------------Time to First Token----------------
Mean TTFT (ms):                          659.17    
Median TTFT (ms):                        666.53    
P99 TTFT (ms):                           1196.01   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          18.57     
Median TPOT (ms):                        18.52     
P99 TPOT (ms):                           24.14     
---------------Inter-token Latency----------------
Mean ITL (ms):                           18.57     
Median ITL (ms):                         13.39     
P99 ITL (ms):                            108.37    
==================================================

Things I'm currently doing to improve this result:

  1. Enable CUDA graphs for the FlashInfer ragged prefill (requires creating a set of pre-planned wrapper instances with fixed batch sizes; see the sketch after this list). This is in progress and will be ready today.
  2. Trace profiling.
  3. Investigate what SGLang does for this case.
  4. Any other suggestions?
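
For item 1, here is a minimal sketch of the pre-planned-wrapper idea, assuming one wrapper per fixed capture batch size; the names (CAPTURE_SIZES, wrappers) and the workspace size are illustrative assumptions, not the PR's actual code:

    import torch
    import flashinfer

    # Illustrative sketch: build one ragged-prefill wrapper per fixed capture
    # batch size so each can be planned once up front and reused at runtime.
    CAPTURE_SIZES = [1, 2, 4, 8, 16, 32]  # assumed capture batch sizes
    wrappers = {}
    for bs in CAPTURE_SIZES:
        # A separate workspace per wrapper keeps each wrapper's planned state independent.
        workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
        wrappers[bs] = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
            workspace, kv_layout="NHD")

At runtime the incoming prefill batch would be bucketed to the nearest capture size and the matching pre-planned wrapper reused, giving the kind of stable launch pattern that CUDA graph capture needs.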


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which covers a small and essential subset of tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

gemini-code-assist bot (Contributor) left a comment

Summary of Changes

Hello @alexm-redhat, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces the initial integration of FlashInfer's ragged prefill functionality into the MLA backend. The aim is to leverage FlashInfer for prefill operations, with current benchmarks indicating a performance regression compared to the existing FlashAttention 2 implementation. The changes primarily involve setting up FlashInfer-specific metadata, managing its prefill wrappers, and adapting the attention computation logic to utilize the new backend, along with an updated example for testing.

Highlights

  • FlashInfer Ragged Prefill Integration: I've integrated FlashInfer's BatchPrefillWithRaggedKVCacheWrapper to handle ragged prefill operations within the Multi-head Latent Attention (MLA) backend. This involves setting up the necessary FlashInfer wrappers and planning their execution (a minimal sketch of this flow follows the Highlights list).
  • FlashInfer-Specific Metadata and Hyperparameter Management: I've introduced a new FIPrefillMetadata dataclass to store FlashInfer-specific prefill parameters and wrappers. Additionally, I've added logic to extract and validate attention layer hyperparameters (such as window size and scale) from the VLLM configuration, ensuring these are consistent and correctly passed to the FlashInfer kernels.
  • Attention Logic Adaptation for Prefill: The core attention computation paths for prefill (_compute_prefill_context and _forward_prefill) have been updated. They now utilize a new _fi_prefill_run method, which dispatches to the FlashInfer prefill kernel, replacing the previous FlashAttention 2 calls for these operations.
  • Model Update for Benchmarking: The basic.py offline inference example has been updated to use the deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct model. This change facilitates focused testing and benchmarking of the new FlashInfer ragged prefill integration with a specific model.
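
As a rough illustration of the plan/run flow described above (a hand-written sketch, not code from this PR): the 192/128 head dimensions match the values in the diff below, while the head counts, sequence lengths, and dtypes are assumptions.

    import torch
    import flashinfer

    num_qo_heads = num_kv_heads = 16      # assumed per-rank head count
    head_dim_qk, head_dim_vo = 192, 128   # DeepSeek MLA: 128 nope + 64 rope / 128 value

    workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
    wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(workspace, kv_layout="NHD")

    # Ragged layout: cumulative token offsets per request (here 2 requests of 5 and 3 tokens).
    qo_indptr = torch.tensor([0, 5, 8], dtype=torch.int32, device="cuda")
    kv_indptr = qo_indptr.clone()  # pure ragged prefill: keys cover the same new tokens

    wrapper.plan(qo_indptr, kv_indptr, num_qo_heads, num_kv_heads,
                 head_dim_qk, causal=True, head_dim_vo=head_dim_vo)

    q = torch.randn(8, num_qo_heads, head_dim_qk, dtype=torch.float16, device="cuda")
    k = torch.randn(8, num_kv_heads, head_dim_qk, dtype=torch.float16, device="cuda")
    v = torch.randn(8, num_kv_heads, head_dim_vo, dtype=torch.float16, device="cuda")
    out = wrapper.run(q, k, v)  # [8, num_qo_heads, head_dim_vo]

In the PR, the equivalent plan() call happens while the prefill metadata is built, and run() is issued from the new _fi_prefill_run path.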

mergify bot added the documentation and v1 labels on Jun 24, 2025
gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces FlashInfer ragged prefill for the DeepSeek-Coder-V2-Lite-Instruct model. While the changes are functionally correct, there is currently a performance slowdown. The author is actively working on improvements such as enabling CUDA graphs and trace profiling. The review identifies areas for code cleanup, consistency, and verification of key logic.

Comment on lines +1213 to +1214
cu_seqlens_q=attn_metadata.prefill.query_start_loc,
cu_seqlens_k=attn_metadata.prefill.query_start_loc,

high

The cu_seqlens_k argument is being passed the same value as cu_seqlens_q. Verify that this is the intended behavior, as it might lead to incorrect attention calculations if the key and query sequence lengths are different.

Comment on lines +1215 to +1216
max_seqlen_q=attn_metadata.prefill.max_query_len,
max_seqlen_k=attn_metadata.prefill.max_query_len,

high

The max_seqlen_k argument is being passed the same value as max_seqlen_q. Verify that this is the intended behavior, as it might lead to incorrect attention calculations if the key and query sequence lengths are different.

Comment on lines 1288 to 1289
cu_seqlens_q=attn_metadata.prefill.query_start_loc,
cu_seqlens_k=attn_metadata.prefill.query_start_loc,

high

The cu_seqlens_k argument is being passed the same value as cu_seqlens_q. Verify that this is the intended behavior, as it might lead to incorrect attention calculations if the key and query sequence lengths are different.

Comment on lines 1290 to 1291
max_seqlen_q=attn_metadata.prefill.max_query_len,
max_seqlen_k=attn_metadata.prefill.max_query_len,

high

The max_seqlen_k argument is being passed the same value as max_seqlen_q. Verify that this is the intended behavior, as it might lead to incorrect attention calculations if the key and query sequence lengths are different.

@@ -564,6 +564,8 @@ def forward(
"""
assert output is not None, "Output tensor must be provided."

print("kv_cache.shape = {}".format(kv_cache.shape))

medium

This print statement should be removed before merging. It's useful for debugging, but not for production code.

Comment on lines +649 to +651
192, #head_dim_qk,
causal=True,
head_dim_vo=128,

medium

The head dimensions are hardcoded here. It would be better to use variables or constants defined elsewhere in the code to ensure consistency and avoid magic numbers.

            head_dim_qk,  #head_dim_qk,
            causal=True,
            head_dim_vo=self.kv_cache_spec.head_size,
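
If the surrounding MLA implementation exposes the usual DeepSeek head-dimension attributes (the attribute names below are an assumption, not verified against this diff), both plan() call sites could derive these values instead of hardcoding them:

    # Hypothetical derivation, assuming the class carries these attributes:
    head_dim_qk = self.qk_nope_head_dim + self.qk_rope_head_dim  # 128 + 64 = 192 for DeepSeek-V2
    head_dim_vo = self.v_head_dim                                # 128 for DeepSeek-V2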

Comment on lines +680 to +682
192, #head_dim_qk,
causal=False,
head_dim_vo=128,

medium

The head dimensions are hardcoded here. It would be better to use variables or constants defined elsewhere in the code to ensure consistency and avoid magic numbers.

            head_dim_qk,  #head_dim_qk,
            causal=False,
            head_dim_vo=self.kv_cache_spec.head_size,

Comment on lines +1352 to +1353
# print("INSIDE forward")
# print(" kv_cache.shape = {}".format(kv_cache.shape))

medium

These commented-out debug print statements should be removed before merging. They're useful for debugging, but not for production code.
