@@ -28,8 +28,8 @@ namespace fbgemm_gpu {
 
 at::Tensor nope_qkv_varseq_prefill(
     at::Tensor XQ,
-    at::Tensor XK,
-    at::Tensor XV,
+    std::optional<at::Tensor> XK,
+    std::optional<at::Tensor> XV,
     at::Tensor cache_K,
     at::Tensor cache_V,
     at::Tensor varseq_batch,
@@ -39,12 +39,13 @@ at::Tensor nope_qkv_varseq_prefill(
     std::optional<at::Tensor> varseq_cache_seqpos,
     std::optional<at::Tensor> qparam_k,
     std::optional<at::Tensor> qparam_v,
-    bool k_norm);
+    bool k_norm,
+    bool update_kv);
 
 at::Tensor nope_qkv_decoding(
     at::Tensor XQ,
-    at::Tensor XK,
-    at::Tensor XV,
+    std::optional<at::Tensor> XK,
+    std::optional<at::Tensor> XV,
     at::Tensor cache_K,
     at::Tensor cache_V,
     at::Tensor seqpos,
@@ -55,12 +56,13 @@ at::Tensor nope_qkv_decoding(
     std::optional<at::Tensor> cache_seqpos,
     std::optional<at::Tensor> qparam_k,
     std::optional<at::Tensor> qparam_v,
-    bool k_norm);
+    bool k_norm,
+    bool update_kv);
 
 at::Tensor rope_qkv_varseq_prefill(
     at::Tensor XQ,
-    at::Tensor XK,
-    at::Tensor XV,
+    std::optional<at::Tensor> XK,
+    std::optional<at::Tensor> XV,
     at::Tensor cache_K,
     at::Tensor cache_V,
     at::Tensor varseq_batch,
@@ -79,12 +81,13 @@ at::Tensor rope_qkv_varseq_prefill(
     std::optional<at::Tensor> qparam_k,
     std::optional<at::Tensor> qparam_v,
     bool write_k_back,
-    bool k_norm);
+    bool k_norm,
+    bool update_kv);
 
 at::Tensor rope_qkv_decoding(
     at::Tensor XQ,
-    at::Tensor XK,
-    at::Tensor XV,
+    std::optional<at::Tensor> XK,
+    std::optional<at::Tensor> XV,
     at::Tensor cache_K,
     at::Tensor cache_V,
     at::Tensor seqpos,
@@ -103,7 +106,8 @@ at::Tensor rope_qkv_decoding(
     double hi_freq_factor,
     std::optional<at::Tensor> qparam_k,
     std::optional<at::Tensor> qparam_v,
-    bool k_norm);
+    bool k_norm,
+    bool update_kv);
 
 at::Tensor xpos_qkv_varseq_prefill(
     at::Tensor XQ,
@@ -181,15 +185,15 @@ at::Tensor mqa_attn(
     int64_t cache_logical_dtype_int);
 
 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
-  m.def("rope_qkv_varseq_prefill(Tensor XQ, Tensor(a!) XK, Tensor XV, Tensor(b!) cache_K, Tensor(c!) cache_V, Tensor varseq_batch, Tensor varseq_seqpos, float theta, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
-      DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192"
-      ", float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None, bool write_k_back=False, bool k_norm=False) -> Tensor");
-  m.def("rope_qkv_decoding(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor seqpos, float theta, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
-      DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? batch=None, Tensor? cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None, bool k_norm=False) -> Tensor");
-  m.def("nope_qkv_varseq_prefill(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor varseq_batch, Tensor varseq_seqpos, Tensor? block_tables=None, int page_size=" STRING(
-      DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None, Tensor? qparam_k=None, Tensor? qparam_v=None, bool k_norm=False) -> Tensor");
-  m.def("nope_qkv_decoding(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor seqpos, Tensor? block_tables=None, int page_size=" STRING(
-      DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? batch=None, Tensor? cache_seqpos=None, Tensor? qparam_k=None, Tensor? qparam_v=None, bool k_norm=False) -> Tensor");
+  m.def("rope_qkv_varseq_prefill(Tensor XQ, Tensor(a!)? XK, Tensor? XV, Tensor(b!) cache_K, Tensor(c!) cache_V, Tensor varseq_batch, Tensor varseq_seqpos, float theta, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
+      DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192"
+      ", float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None, bool write_k_back=False, bool k_norm=False, bool update_kv=True) -> Tensor");
+  m.def("rope_qkv_decoding(Tensor XQ, Tensor? XK, Tensor? XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor seqpos, float theta, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
+      DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? batch=None, Tensor? cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None, bool k_norm=False, bool update_kv=True) -> Tensor");
+  m.def("nope_qkv_varseq_prefill(Tensor XQ, Tensor? XK, Tensor? XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor varseq_batch, Tensor varseq_seqpos, Tensor? block_tables=None, int page_size=" STRING(
+      DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None, Tensor? qparam_k=None, Tensor? qparam_v=None, bool k_norm=False, bool update_kv=True) -> Tensor");
+  m.def("nope_qkv_decoding(Tensor XQ, Tensor? XK, Tensor? XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor seqpos, Tensor? block_tables=None, int page_size=" STRING(
+      DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? batch=None, Tensor? cache_seqpos=None, Tensor? qparam_k=None, Tensor? qparam_v=None, bool k_norm=False, bool update_kv=True) -> Tensor");
   m.def("xpos_qkv_varseq_prefill(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor varseq_batch, Tensor varseq_seqpos, float theta, float gamma, float scale_base, float exponent_offset, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
       DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None) -> Tensor");
   m.def("xpos_qkv_decoding(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor seqpos, float theta, float gamma, float scale_base, float exponent_offset, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
@@ -225,8 +229,8 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
 
 at::Tensor rope_qkv_varseq_prefill_meta(
     at::Tensor XQ,
-    at::Tensor /* XK */,
-    at::Tensor /* XV */,
+    std::optional<at::Tensor> /* XK */,
+    std::optional<at::Tensor> /* XV */,
     at::Tensor /* cache_K */,
     at::Tensor /* cache_V */,
     at::Tensor /* varseq_batch */,
@@ -245,15 +249,16 @@ at::Tensor rope_qkv_varseq_prefill_meta(
     std::optional<at::Tensor> /* qparam_k */,
     std::optional<at::Tensor> /* qparam_v */,
     bool /* write_k_back */,
-    bool /* k_norm */
+    bool /* k_norm */,
+    bool /* update_kv */
 ) {
   return at::empty_like(XQ);
 }
 
 at::Tensor rope_qkv_decoding_meta(
     at::Tensor XQ,
-    at::Tensor /* XK */,
-    at::Tensor /* XV */,
+    std::optional<at::Tensor> /* XK */,
+    std::optional<at::Tensor> /* XV */,
     at::Tensor /* cache_K */,
     at::Tensor /* cache_V */,
     at::Tensor /* seqpos */,
@@ -272,15 +277,16 @@ at::Tensor rope_qkv_decoding_meta(
     double /* hi_freq_factor */,
     std::optional<at::Tensor> /* qparam_k */,
     std::optional<at::Tensor> /* qparam_v */,
-    bool /* k_norm */
+    bool /* k_norm */,
+    bool /* update_kv */
 ) {
   return at::empty_like(XQ);
 }
 
 at::Tensor nope_qkv_varseq_prefill_meta(
     at::Tensor XQ,
-    at::Tensor /* XK */,
-    at::Tensor /* XV */,
+    std::optional<at::Tensor> /* XK */,
+    std::optional<at::Tensor> /* XV */,
     at::Tensor /* cache_K */,
     at::Tensor /* cache_V */,
     at::Tensor /* varseq_batch */,
@@ -290,15 +296,16 @@ at::Tensor nope_qkv_varseq_prefill_meta(
     std::optional<at::Tensor> /* varseq_cache_seqpos */,
    std::optional<at::Tensor> /* qparam_k */,
     std::optional<at::Tensor> /* qparam_v */,
-    bool /* k_norm */
+    bool /* k_norm */,
+    bool /* update_kv */
 ) {
   return at::empty_like(XQ);
 }
 
 at::Tensor nope_qkv_decoding_meta(
     at::Tensor XQ,
-    at::Tensor /* XK */,
-    at::Tensor /* XV */,
+    std::optional<at::Tensor> /* XK */,
+    std::optional<at::Tensor> /* XV */,
     at::Tensor /* cache_K */,
     at::Tensor /* cache_V */,
     at::Tensor /* seqpos */,
@@ -309,7 +316,8 @@ at::Tensor nope_qkv_decoding_meta(
     std::optional<at::Tensor> /* cache_seqpos */,
     std::optional<at::Tensor> /* qparam_k */,
     std::optional<at::Tensor> /* qparam_v */,
-    bool /* k_norm */
+    bool /* k_norm */,
+    bool /* update_kv */
 ) {
   return at::empty_like(XQ);
 }
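
Usage note (not part of the diff): with this change XK and XV become optional and the new update_kv flag lets a caller skip the cache write and attend against an already-populated KV cache. The sketch below is a minimal illustration of that calling pattern for nope_qkv_decoding; the wrapper name decode_with_frozen_cache, the tensor shapes, and the include are assumptions, and the intermediate arguments simply mirror the defaults in the registered schema above.

#include <torch/torch.h>
// plus the fbgemm_gpu header that declares nope_qkv_decoding (include path assumed)

// Hypothetical helper: run one decode step against a frozen KV cache.
// With update_kv=false, no new K/V is written, so XK/XV can be std::nullopt.
at::Tensor decode_with_frozen_cache(
    at::Tensor XQ,       // assumed [B, 1, N_H, D_H] query for the current step
    at::Tensor cache_K,  // assumed [B, MAX_T, N_KVH, D_H] pre-filled key cache
    at::Tensor cache_V,  // assumed [B, MAX_T, N_KVH, D_H] pre-filled value cache
    at::Tensor seqpos) { // assumed [B] current sequence positions
  return fbgemm_gpu::nope_qkv_decoding(
      XQ,
      /*XK=*/std::nullopt,
      /*XV=*/std::nullopt,
      cache_K,
      cache_V,
      seqpos,
      /*block_tables=*/std::nullopt,
      /*page_size=*/DEFAULT_PAGE_SIZE,
      /*actual_batch_size=*/std::nullopt,
      /*batch=*/std::nullopt,
      /*cache_seqpos=*/std::nullopt,
      /*qparam_k=*/std::nullopt,
      /*qparam_v=*/std::nullopt,
      /*k_norm=*/false,
      /*update_kv=*/false);
}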