
Commit 68aaeb3

tlrmchlsmth and Varun Sundar Rabindranath authored
[EP+DP] Optimize the little operations in the DeepGEMM + DeepEP low latency case (#19885)
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
1 parent c3649e4 commit 68aaeb3

3 files changed (+263, -18 lines)
Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+import torch
+
+from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
+    silu_mul_fp8_quant_deep_gemm)
+from vllm.platforms import current_platform
+
+# (E, T, H, group_size, seed)
+CASES = [
+    (1, 1, 128, 64, 0),
+    (1, 4, 128, 128, 0),
+    (2, 4, 256, 128, 0),
+    (32, 64, 256, 128, 0),
+    (17, 31, 768, 128, 0),
+]
+
+
+@pytest.mark.parametrize("E,T,H,group_size,seed", CASES)
+@torch.inference_mode()
+def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed):
+    current_platform.seed_everything(seed)
+
+    # Input tensor of shape (E, T, 2*H)
+    y = torch.randn((E, T, 2 * H), dtype=torch.float32, device="cuda")
+    tokens_per_expert = torch.randint(
+        low=0,
+        high=T,
+        size=(E, ),
+        dtype=torch.int32,
+        device="cuda",
+    )
+
+    # Run the Triton kernel
+    y_q, y_s = silu_mul_fp8_quant_deep_gemm(y,
+                                            tokens_per_expert,
+                                            group_size=group_size,
+                                            eps=1e-10)
+
+    # Reference implementation
+    fp8_info = torch.finfo(torch.float8_e4m3fn)
+    fp8_max = fp8_info.max
+    fp8_min = fp8_info.min
+    eps = 1e-10
+
+    # Compute silu activation and elementwise multiplication
+    y1 = y[..., :H]
+    y2 = y[..., H:]
+    silu_x = y1 * torch.sigmoid(y1)
+    merged = silu_x * y2
+
+    # Compute reference scales and quantized output, skipping padded tokens
+    for e in range(E):
+        nt = tokens_per_expert[e].item()
+        ref_s = torch.empty((T, H // group_size),
+                            dtype=torch.float32,
+                            device="cuda")
+        ref_q = torch.empty((T, H), dtype=torch.float8_e4m3fn, device="cuda")
+        for t in range(nt):
+            data = merged[e, t]
+            data_grp = data.view(H // group_size, group_size)
+            amax = data_grp.abs().amax(dim=1).clamp(min=eps)
+            scale = amax / fp8_max
+
+            scaled = data / scale.repeat_interleave(group_size)
+            clamped = scaled.clamp(fp8_min, fp8_max)
+            q = clamped.to(torch.float8_e4m3fn)
+
+            ref_s[t] = scale
+            ref_q[t] = q
+
+        y_se = y_s[e]
+        y_qe = y_q[e]
+
+        torch.testing.assert_close(y_se[:nt], ref_s[:nt])
+        torch.testing.assert_close(
+            y_qe[:nt].to(torch.float32),
+            ref_q[:nt].to(torch.float32),
+            atol=2,
+            rtol=2e-1,
+        )
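
The test above checks the kernel against a per-token reference. For intuition, the two outputs recombine into the unquantized activation: expanding each group scale across its group_size channels and multiplying by the FP8 values recovers silu(y1) * y2 up to FP8 rounding. Below is a minimal round-trip sketch, not part of the commit; it assumes a CUDA device with FP8 support and Triton available, and the E/T/H values are arbitrary illustrative choices.

# Sketch (not part of the commit): dequantize the kernel outputs and compare
# against the unquantized silu(y1) * y2.
import torch

from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    silu_mul_fp8_quant_deep_gemm)

E, T, H, group_size = 2, 4, 256, 128  # illustrative sizes
y = torch.randn((E, T, 2 * H), dtype=torch.float32, device="cuda")
counts = torch.full((E, ), T, dtype=torch.int32, device="cuda")  # all tokens valid

y_q, y_s = silu_mul_fp8_quant_deep_gemm(y, counts, group_size=group_size)

# Each scale covers group_size consecutive channels, so expanding y_s along the
# hidden dimension gives a per-element scale of shape (E, T, H).
scales = y_s.repeat_interleave(group_size, dim=-1)
dequant = y_q.to(torch.float32) * scales

y1, y2 = y[..., :H], y[..., H:]
expected = y1 * torch.sigmoid(y1) * y2
print((dequant - expected).abs().max())  # small; bounded by FP8 rounding error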

vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py

Lines changed: 170 additions & 16 deletions
@@ -6,14 +6,179 @@
 
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe.utils import (
-    _resize_cache, per_token_group_quant_fp8)
+from vllm.model_executor.layers.fused_moe.utils import _resize_cache
+from vllm.triton_utils import tl, triton
 
 logger = init_logger(__name__)
 
 has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
 
 
+@triton.jit
+def _silu_mul_fp8_quant_deep_gemm(
+    # Pointers ------------------------------------------------------------
+    input_ptr,  # 16-bit activations (E, T, 2*H)
+    y_q_ptr,  # fp8 quantized activations (E, T, H)
+    y_s_ptr,  # 16-bit scales (E, T, G)
+    counts_ptr,  # int32 num tokens per expert (E)
+
+    # Sizes ---------------------------------------------------------------
+    H: tl.constexpr,  # hidden dimension (per output)
+    GROUP_SIZE: tl.constexpr,  # elements per group (usually 128)
+
+    # Strides for input (elements) ---------------------------------------
+    stride_i_e,
+    stride_i_t,
+    stride_i_h,
+
+    # Strides for y_q (elements) -----------------------------------------
+    stride_yq_e,
+    stride_yq_t,
+    stride_yq_h,
+
+    # Strides for y_s (elements) -----------------------------------------
+    stride_ys_e,
+    stride_ys_t,
+    stride_ys_g,
+
+    # Stride for counts (elements)
+    stride_counts_e,
+
+    # Numeric params ------------------------------------------------------
+    eps: tl.constexpr,
+    fp8_min: tl.constexpr,
+    fp8_max: tl.constexpr,
+
+    # Meta ---------------------------------------------------------------
+    BLOCK: tl.constexpr,
+):
+    G = H // GROUP_SIZE
+
+    # map program id -> (e, g)
+    pid = tl.program_id(0)
+    e = pid // G
+    g = pid % G
+
+    e = e.to(tl.int64)
+    g = g.to(tl.int64)
+
+    # number of valid tokens for this expert
+    n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64)
+
+    cols = tl.arange(0, BLOCK)
+    cols = cols.to(tl.int64)
+    mask_h = cols < BLOCK
+
+    t = tl.zeros([], tl.int64)
+    while t < n_tokens:
+        base_i_offset = (e * stride_i_e + t * stride_i_t +
+                         g * GROUP_SIZE * stride_i_h)
+        base_yq_offset = (e * stride_yq_e + t * stride_yq_t +
+                          g * GROUP_SIZE * stride_yq_h)
+        base_ys_offset = e * stride_ys_e + t * stride_ys_t + g * stride_ys_g
+
+        mask = mask_h
+        x = tl.load(input_ptr + base_i_offset + cols * stride_i_h,
+                    mask=mask,
+                    other=0.0).to(tl.float32)
+        y2 = tl.load(input_ptr + base_i_offset + H * stride_i_h +
+                     cols * stride_i_h,
+                     mask=mask,
+                     other=0.0).to(tl.float32)
+
+        x = x * (1.0 / (1.0 + tl.exp(-x)))
+        y = x * y2
+
+        _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+        y_s = _absmax / fp8_max
+        y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+        tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, y_q, mask=mask)
+        tl.store(y_s_ptr + base_ys_offset, y_s)
+
+        t += 1
+
+
+def silu_mul_fp8_quant_deep_gemm(
+    y: torch.Tensor,  # (E, T, 2*H) float32
+    tokens_per_expert: torch.Tensor,  # (E,) number of valid tokens per expert
+    group_size: int = 128,
+    eps: float = 1e-10,
+):
+    """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
+
+    y has shape (E, T, 2*H). The first half of the last dimension is
+    silu-activated, multiplied by the second half, then quantized into FP8.
+
+    Returns `(y_q, y_s)` where
+    * `y_q` is the FP8 tensor of shape `(E, T, H)`, same layout as `y[..., :H]`.
+    * `y_s` has shape `(E, T, H // group_size)` and strides `(T*G, 1, T)`
+    """
+    assert y.ndim == 3, "y must be (E, T, 2*H)"
+    E, T, H2 = y.shape
+    assert H2 % 2 == 0, "last dim of y must be even (2*H)"
+    H = H2 // 2
+    G = H // group_size
+    assert H % group_size == 0, "H must be divisible by group_size"
+    assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, \
+        "tokens_per_expert must be shape (E,)"
+    tokens_per_expert = tokens_per_expert.to(device=y.device,
+                                             dtype=torch.int32)
+
+    # allocate outputs
+    fp8_dtype = torch.float8_e4m3fn
+    y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)
+
+    # strides (elements)
+    stride_i_e, stride_i_t, stride_i_h = y.stride()
+    stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride()
+
+    # desired scale strides (elements): (T*G, 1, T)
+    stride_ys_e = T * G
+    stride_ys_t = 1
+    stride_ys_g = T
+    y_s = torch.empty_strided((E, T, G),
+                              (stride_ys_e, stride_ys_t, stride_ys_g),
+                              dtype=torch.float32,
+                              device=y.device)
+
+    stride_cnt_e = tokens_per_expert.stride()[0]
+
+    # static grid over experts and H-groups.
+    # A loop inside the kernel handles the token dim
+    grid = (E * G, )
+
+    f_info = torch.finfo(fp8_dtype)
+    fp8_max = f_info.max
+    fp8_min = f_info.min
+
+    _silu_mul_fp8_quant_deep_gemm[grid](
+        y,
+        y_q,
+        y_s,
+        tokens_per_expert,
+        H,
+        group_size,
+        stride_i_e,
+        stride_i_t,
+        stride_i_h,
+        stride_yq_e,
+        stride_yq_t,
+        stride_yq_h,
+        stride_ys_e,
+        stride_ys_t,
+        stride_ys_g,
+        stride_cnt_e,
+        eps,
+        fp8_min,
+        fp8_max,
+        BLOCK=group_size,
+        num_warps=4,
+    )
+
+    return y_q, y_s
+
+
 class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
 
     # The Deep Gemm kernels only support block size of 128
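
One detail worth calling out in the wrapper above: `y_s` is allocated with `torch.empty_strided` and explicit strides `(T*G, 1, T)` rather than the default contiguous strides `(T*G, G, 1)`, so for a fixed expert and group the per-token scales sit contiguously in memory. The following CPU-only sketch is not part of the commit and only demonstrates what those strides mean; the sizes are toy values.

import torch

E, T, G = 2, 4, 3  # toy sizes, illustration only
y_s = torch.empty_strided((E, T, G), (T * G, 1, T), dtype=torch.float32)

# Element (e, t, g) lives at flat offset e*(T*G) + t*1 + g*T, so for a fixed
# expert and group the T per-token scales are contiguous in memory.
e, t, g = 1, 2, 1
elem_offset = (y_s[e, t, g].data_ptr() - y_s.data_ptr()) // y_s.element_size()
assert elem_offset == e * (T * G) + t * 1 + g * T
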
@@ -96,7 +261,6 @@ def apply(
             hidden_states, w1, w2, topk_ids)
 
         workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
-        workspace2 = _resize_cache(workspace2, (E, max_num_tokens, N // 2))
 
         # (from deepgemm docs) : A value hint (which is a value on CPU)
         # for the M expectation of each batch, correctly setting this value
@@ -109,19 +273,9 @@ def apply(
             masked_m=expert_num_tokens,
             expected_m=expected_m)
 
-        # TODO (varun) [Optimization]: Use a batched version of activation.
-        # Similarly for the quant below.
-        self.activation(activation, workspace2, workspace1.view(-1, N))
-
-        w2_hidden_size = workspace2.size(-1)
-        workspace2 = workspace2.view(-1, w2_hidden_size)
-
-        a2q_scale: Optional[torch.Tensor] = None
-        a2q, a2q_scale = per_token_group_quant_fp8(workspace2,
-                                                   self.block_shape[1],
-                                                   column_major_scales=False)
-        a2q = a2q.view(E, max_num_tokens, -1)
-        a2q_scale = a2q_scale.view(E, max_num_tokens, -1)
+        assert expert_num_tokens is not None
+        a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1,
+                                                      expert_num_tokens)
 
         dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a2q, a2q_scale),
                                                  (w2, w2_scale),
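
The replacement above collapses what used to be a separate activation call, several reshapes, and a standalone `per_token_group_quant_fp8` call into a single Triton kernel launch between the two grouped GEMMs. A rough way to time the fused kernel in isolation is sketched below; this is not part of the commit, it assumes a CUDA device with Triton and vLLM installed, and the E/T/H sizes are hypothetical.

import torch

from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    silu_mul_fp8_quant_deep_gemm)

E, T, H = 32, 64, 2048  # hypothetical expert/token/hidden sizes
y = torch.randn((E, T, 2 * H), dtype=torch.float32, device="cuda")
counts = torch.randint(1, T, (E, ), dtype=torch.int32, device="cuda")

# Warm up (first call compiles the Triton kernel), then time with CUDA events.
for _ in range(3):
    silu_mul_fp8_quant_deep_gemm(y, counts)

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(100):
    silu_mul_fp8_quant_deep_gemm(y, counts)
end.record()
torch.cuda.synchronize()
print(f"avg {start.elapsed_time(end) / 100:.3f} ms per call")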

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 10 additions & 2 deletions
@@ -45,7 +45,8 @@
         from .pplx_prepare_finalize import PplxPrepareAndFinalize
     if has_deepep:
         from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
-        from .deepep_ll_prepare_finalize import DeepEPLLPrepareAndFinalize
+        from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SIZE,
+                                                 DeepEPLLPrepareAndFinalize)
 else:
     fused_experts = None  # type: ignore
     FusedMoEPermuteExpertsUnpermute = None  # type: ignore
@@ -377,6 +378,13 @@ def init_prepare_finalize(self, moe: MoEConfig,
                 all2all_manager.world_size)
             handle = all2all_manager.get_handle(all_to_all_args)
 
+            # Note : We may want to use FP8 dispatch even otherwise just to
+            # reduce datamovement
+            assert act_quant_block_size is not None
+            use_fp8_dispatch = (quant_dtype == current_platform.fp8_dtype()
+                                and act_quant_block_size[1]
+                                == DEEPEP_QUANT_BLOCK_SIZE)
+
             # Note (varun): Whether to use FP8 dispatch or not needs some
             # profiling. Turning it off for now.
             prepare_finalize = DeepEPLLPrepareAndFinalize(
@@ -386,7 +394,7 @@ def init_prepare_finalize(self, moe: MoEConfig,
                 max_tokens_per_rank=moe.max_num_tokens,
                 quant_dtype=quant_dtype,
                 block_shape=act_quant_block_size,
-                use_fp8_dispatch=False,
+                use_fp8_dispatch=use_fp8_dispatch,
             )
 
         self.topk_indices_dtype = None
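
With this change, FP8 dispatch is enabled exactly when the activation quant dtype matches the platform FP8 dtype and the activation quant group size equals `DEEPEP_QUANT_BLOCK_SIZE`. A standalone sketch of that predicate follows; it is not part of the commit, and the `quant_dtype` and `act_quant_block_size` values below are hypothetical stand-ins for what `init_prepare_finalize` receives.

import torch

from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (
    DEEPEP_QUANT_BLOCK_SIZE)
from vllm.platforms import current_platform

# Hypothetical per-layer quantization settings.
quant_dtype = torch.float8_e4m3fn
act_quant_block_size = [128, 128]  # e.g. 128x128 block quantization

use_fp8_dispatch = (quant_dtype == current_platform.fp8_dtype()
                    and act_quant_block_size[1] == DEEPEP_QUANT_BLOCK_SIZE)
print(use_fp8_dispatch)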
