
[RFC]: Introduce a Triton-only Transformer Execution Path in vLLM #13319


Description

@JackChuang

Motivation.

Heterogeneous hardware is becoming increasingly popular and will only grow more so, and vLLM's Roadmap shows a clear, ongoing trend of expanding hardware support. This makes it essential to establish an execution path that can run seamlessly across all of this heterogeneous hardware.

CUDA and other platform-specific languages are clearly not sufficient to achieve this goal. Triton, in contrast, is designed to target a wide range of heterogeneous hardware: as long as a vendor supports Triton, our execution path can run on that vendor's chips.

Additionally, vLLM already uses Triton for the prefill and decode phases of attention (e.g., triton_attention and triton_decode_attention). Unfortunately, there is still no fully Triton-based execution path for vLLM inference serving.

As a result, we propose a fully CUDA-free, Triton-only transformer execution path in vLLM, covering both attention and non-attention operators, to address the problems above and keep pace with this trend.

Proposed Change.

We plan to submit two major PRs to implement the Triton-only transformer execution path. Each PR will include multiple commits.

PR#1: Add VLLM_USE_TRITON_NON_ATTN Flag for Triton-Based Non-Attention Operators
Add a new environment flag called VLLM_USE_TRITON_NON_ATTN and implement all operators on the non-attention path in Triton, categorized into two major types (a kernel sketch follows the list below):

  1. CustomOP-Based Operators: Includes RMSNorm, activation functions, and RoPE.
  2. Non-CustomOP Operators: Covers other operators, including linear operators.
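To make the CustomOP category concrete, here is a minimal sketch of what a Triton RMSNorm kernel could look like. It is illustrative only: the kernel and wrapper names, dtype handling, and launch parameters are our assumptions, not the code PR#1 will actually contain.

```python
# Minimal illustrative sketch of a Triton RMSNorm (CustomOP category).
# Names and details are assumptions, not the actual PR#1 code.
import torch
import triton
import triton.language as tl


@triton.jit
def _rmsnorm_kernel(x_ptr, w_ptr, out_ptr, hidden_size, eps,
                    BLOCK_SIZE: tl.constexpr):
    # One program instance normalizes one token (one row of the input).
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < hidden_size
    x = tl.load(x_ptr + row * hidden_size + cols, mask=mask, other=0.0)
    x = x.to(tl.float32)
    # RMSNorm: y = x / sqrt(mean(x^2) + eps) * weight
    rms = tl.sqrt(tl.sum(x * x, axis=0) / hidden_size + eps)
    w = tl.load(w_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    y = (x / rms) * w
    tl.store(out_ptr + row * hidden_size + cols,
             y.to(out_ptr.dtype.element_ty), mask=mask)


def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    # x: [num_tokens, hidden_size], contiguous.
    num_tokens, hidden_size = x.shape
    out = torch.empty_like(x)
    BLOCK_SIZE = triton.next_power_of_2(hidden_size)
    _rmsnorm_kernel[(num_tokens,)](x, weight, out, hidden_size, eps,
                                   BLOCK_SIZE=BLOCK_SIZE)
    return out
```

Since Triton only requires a vendor-provided compiler backend, a kernel like this runs unchanged on any hardware with Triton support, which is exactly the portability argument above.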

PR#2: Add Triton Multihead Attention (MHA) Backend
Add a new Triton-based Multihead Attention (MHA) backend (see the usage sketch after this list), which includes:

  1. Prefill phase implementation
  2. Decode phase implementation
  3. KV cache management
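Putting the two PRs together, an end-to-end run on the Triton-only path could look roughly like the sketch below. The VLLM_USE_TRITON_NON_ATTN flag comes from this proposal and does not exist in vLLM yet; how the Triton MHA backend would be selected (for example, through the existing VLLM_ATTENTION_BACKEND variable) is an assumption on our part, and the model name is just a placeholder.

```python
import os

# Hypothetical: opt in to the proposed Triton non-attention operators
# (flag name from this RFC; not yet in vLLM).
os.environ["VLLM_USE_TRITON_NON_ATTN"] = "1"
# Assumption: the new MHA backend would be chosen via the existing
# attention-backend selector; the value "TRITON_MHA" is a placeholder.
os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_MHA"

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")  # model is illustrative
outputs = llm.generate(["Hello, Triton!"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```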

We are currently preparing these PRs. The detailed design and list of planned changes are here: [Design doc] Introduce a Triton-only Transformer Execution Path in vLLM

Feedback Period.

No response

CC List.

@WoosukKwon @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @simon-mo @russellb @LucasWilkinson @cadedaniel @ywang96 @LiuXiaoxuanPKU @bringlein @chizhang118 @yicwang @zixuanzhang226 @chenqianfzh @thesues @rainj-me @deepak-vij @XiaoningDing

Any Other Things.

As previously discussed in #5083 ("[RFC]: OpenAI Triton-only backend"), a complete Triton transformer execution path requires more than just attention prefill/decode. Our solution implements the entire pipeline.

