
import vllm.model_executor.layers.fused_moe.modular_kernel as mk

from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe.utils import (
-    _resize_cache, per_token_group_quant_fp8)
+from vllm.model_executor.layers.fused_moe.utils import _resize_cache
+from vllm.triton_utils import tl, triton

logger = init_logger(__name__)

has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None


+@triton.jit
+def _silu_mul_fp8_quant_deep_gemm(
+    # Pointers ------------------------------------------------------------
+    input_ptr,  # 16-bit activations (E, T, 2*H)
+    y_q_ptr,  # fp8 quantized activations (E, T, H)
+    y_s_ptr,  # fp32 scales (E, T, G)
+    counts_ptr,  # int32 num tokens per expert (E)
+
+    # Sizes ---------------------------------------------------------------
+    H: tl.constexpr,  # hidden dimension (per output)
+    GROUP_SIZE: tl.constexpr,  # elements per group (usually 128)
+
+    # Strides for input (elements) ----------------------------------------
+    stride_i_e,
+    stride_i_t,
+    stride_i_h,
+
+    # Strides for y_q (elements) -------------------------------------------
+    stride_yq_e,
+    stride_yq_t,
+    stride_yq_h,
+
+    # Strides for y_s (elements) -------------------------------------------
+    stride_ys_e,
+    stride_ys_t,
+    stride_ys_g,
+
+    # Stride for counts (elements)
+    stride_counts_e,
+
+    # Numeric params --------------------------------------------------------
+    eps: tl.constexpr,
+    fp8_min: tl.constexpr,
+    fp8_max: tl.constexpr,
+
+    # Meta -------------------------------------------------------------------
+    BLOCK: tl.constexpr,
+):
+    G = H // GROUP_SIZE
+
+    # map program id -> (e, g)
+    pid = tl.program_id(0)
+    e = pid // G
+    g = pid % G
+
+    e = e.to(tl.int64)
+    g = g.to(tl.int64)
+
+    # number of valid tokens for this expert
+    n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64)
+
+    cols = tl.arange(0, BLOCK)
+    cols = cols.to(tl.int64)
+    mask_h = cols < BLOCK
+
+    t = tl.zeros([], tl.int64)
+    while t < n_tokens:
+        base_i_offset = (e * stride_i_e + t * stride_i_t +
+                         g * GROUP_SIZE * stride_i_h)
+        base_yq_offset = (e * stride_yq_e + t * stride_yq_t +
+                          g * GROUP_SIZE * stride_yq_h)
+        base_ys_offset = e * stride_ys_e + t * stride_ys_t + g * stride_ys_g
+
+        mask = mask_h
+        x = tl.load(input_ptr + base_i_offset + cols * stride_i_h,
+                    mask=mask,
+                    other=0.0).to(tl.float32)
+        y2 = tl.load(input_ptr + base_i_offset + H * stride_i_h +
+                     cols * stride_i_h,
+                     mask=mask,
+                     other=0.0).to(tl.float32)
+
+        # silu on the gate half, then elementwise product with the up half
+        x = x * (1.0 / (1.0 + tl.exp(-x)))
+        y = x * y2
+
+        # per-group scale from the absmax, then clamp into the fp8 range
+        _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+        y_s = _absmax / fp8_max
+        y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+        tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, y_q, mask=mask)
+        tl.store(y_s_ptr + base_ys_offset, y_s)
+
+        t += 1
+
+
+def silu_mul_fp8_quant_deep_gemm(
+    y: torch.Tensor,  # (E, T, 2*H) float32
+    tokens_per_expert: torch.Tensor,  # (E,) number of valid tokens per expert
+    group_size: int = 128,
+    eps: float = 1e-10,
+):
+    """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
+
+    y has shape (E, T, 2*H). The first half of the last dimension is
+    silu-activated, multiplied by the second half, then quantized into FP8.
+
+    Returns `(y_q, y_s)` where
+    * `y_q` is the FP8 tensor of shape `(E, T, H)`, same layout as `y[..., :H]`.
+    * `y_s` has shape `(E, T, H // group_size)` and strides `(T*G, 1, T)`
+    """
+    assert y.ndim == 3, "y must be (E, T, 2*H)"
+    E, T, H2 = y.shape
+    assert H2 % 2 == 0, "last dim of y must be even (2*H)"
+    H = H2 // 2
+    G = H // group_size
+    assert H % group_size == 0, "H must be divisible by group_size"
+    assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, \
+        "tokens_per_expert must be shape (E,)"
+    tokens_per_expert = tokens_per_expert.to(device=y.device,
+                                             dtype=torch.int32)
+
+    # allocate outputs
+    fp8_dtype = torch.float8_e4m3fn
+    y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)
+
+    # strides (elements)
+    stride_i_e, stride_i_t, stride_i_h = y.stride()
+    stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride()
+
+    # desired scale strides (elements): (T*G, 1, T)
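+    # e.g. with T=256 and G=8 each expert's (T, G) scale block has strides
+    # (1, 256), i.e. column-major: the T scales of one group are contiguous.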
+    stride_ys_e = T * G
+    stride_ys_t = 1
+    stride_ys_g = T
+    y_s = torch.empty_strided((E, T, G),
+                              (stride_ys_e, stride_ys_t, stride_ys_g),
+                              dtype=torch.float32,
+                              device=y.device)
+
+    stride_cnt_e = tokens_per_expert.stride()[0]
+
+    # static grid over experts and H-groups.
+    # A loop inside the kernel handles the token dim
+    grid = (E * G, )
+
+    f_info = torch.finfo(fp8_dtype)
+    fp8_max = f_info.max
+    fp8_min = f_info.min
+
+    _silu_mul_fp8_quant_deep_gemm[grid](
+        y,
+        y_q,
+        y_s,
+        tokens_per_expert,
+        H,
+        group_size,
+        stride_i_e,
+        stride_i_t,
+        stride_i_h,
+        stride_yq_e,
+        stride_yq_t,
+        stride_yq_h,
+        stride_ys_e,
+        stride_ys_t,
+        stride_ys_g,
+        stride_cnt_e,
+        eps,
+        fp8_min,
+        fp8_max,
+        BLOCK=group_size,
+        num_warps=4,
+    )
+
+    return y_q, y_s
+
+
class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):

    # The Deep Gemm kernels only support block size of 128
@@ -96,7 +261,6 @@ def apply(
            hidden_states, w1, w2, topk_ids)

        workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
-        workspace2 = _resize_cache(workspace2, (E, max_num_tokens, N // 2))

        # (from deepgemm docs) : A value hint (which is a value on CPU)
        # for the M expectation of each batch, correctly setting this value
@@ -109,19 +273,9 @@ def apply(
            masked_m=expert_num_tokens,
            expected_m=expected_m)

-        # TODO (varun) [Optimization]: Use a batched version of activation.
-        # Similarly for the quant below.
-        self.activation(activation, workspace2, workspace1.view(-1, N))
-
-        w2_hidden_size = workspace2.size(-1)
-        workspace2 = workspace2.view(-1, w2_hidden_size)
-
-        a2q_scale: Optional[torch.Tensor] = None
-        a2q, a2q_scale = per_token_group_quant_fp8(workspace2,
-                                                   self.block_shape[1],
-                                                   column_major_scales=False)
-        a2q = a2q.view(E, max_num_tokens, -1)
-        a2q_scale = a2q_scale.view(E, max_num_tokens, -1)
+        assert expert_num_tokens is not None
+        a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1,
+                                                      expert_num_tokens)

        dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a2q, a2q_scale),
                                                 (w2, w2_scale),
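For reference, here is a minimal sketch (not part of this commit) of how the new silu_mul_fp8_quant_deep_gemm helper could be exercised on its own. It assumes a CUDA device with Triton available, a PyTorch build that provides torch.float8_e4m3fn, and that the file shown in this diff is vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py.

```python
import torch

# Assumed module path for the file touched by this diff.
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    silu_mul_fp8_quant_deep_gemm)

E, T, H, group_size = 4, 32, 512, 128  # experts, max tokens per expert, hidden, group
G = H // group_size

# Batched activations laid out as (E, T, 2*H), e.g. the first grouped GEMM output.
y = torch.randn(E, T, 2 * H, device="cuda", dtype=torch.bfloat16)
tokens_per_expert = torch.randint(0, T + 1, (E,), device="cuda", dtype=torch.int32)

y_q, y_s = silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=group_size)

assert y_q.shape == (E, T, H) and y_q.dtype == torch.float8_e4m3fn
assert y_s.shape == (E, T, G)
assert y_s.stride() == (T * G, 1, T)  # scale layout documented in the docstring
```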