@@ -96,6 +96,7 @@ Tensor& custom_quantized_sdpa_out_no_context(
     const optional<Tensor> k_scales,
     const optional<Tensor> v_zero_points,
     const optional<Tensor> v_scales,
+    const bool is_seq_at_dim_2,
     Tensor& output);
 
 at::Tensor custom_quantized_sdpa_aten(
@@ -115,7 +116,8 @@ at::Tensor custom_quantized_sdpa_aten(
     const std::optional<at::Tensor>& k_zero_points,
     const std::optional<at::Tensor>& k_scales,
     const std::optional<at::Tensor>& v_zero_points,
-    const std::optional<at::Tensor>& v_scales);
+    const std::optional<at::Tensor>& v_scales,
+    const bool is_seq_at_dim_2);
 #endif // ENABLE_CUSTOM_QUANTIZED_SDPA
 
 Tensor& update_cache_out_no_context(
@@ -258,6 +260,7 @@ Tensor& custom_quantized_sdpa_out_no_context(
     const optional<Tensor> k_scales,
     const optional<Tensor> v_zero_points,
     const optional<Tensor> v_scales,
+    const bool is_seq_at_dim_2,
     Tensor& output) {
   executorch::aten::RuntimeContext context{};
   return torch::executor::native::custom_quantized_sdpa_out(
@@ -276,6 +279,7 @@ Tensor& custom_quantized_sdpa_out_no_context(
       k_scales,
       v_zero_points,
       v_scales,
+      is_seq_at_dim_2,
       output);
 }
 
@@ -296,9 +300,10 @@ at::Tensor custom_quantized_sdpa_aten(
     const std::optional<at::Tensor>& k_zero_points,
     const std::optional<at::Tensor>& k_scales,
     const std::optional<at::Tensor>& v_zero_points,
-    const std::optional<at::Tensor>& v_scales) {
+    const std::optional<at::Tensor>& v_scales,
+    const bool is_seq_at_dim_2) {
   auto output = at::empty(q.sizes());
-  WRAP_TO_ATEN(custom_quantized_sdpa_out_no_context, 14)
+  WRAP_TO_ATEN(custom_quantized_sdpa_out_no_context, 15)
   (q,
    k,
    v,
@@ -313,6 +318,7 @@ at::Tensor custom_quantized_sdpa_aten(
    k_scales,
    v_zero_points,
    v_scales,
+   is_seq_at_dim_2,
    output);
   return output;
 }
@@ -371,13 +377,13 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
       "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
       "float? scale=None, Tensor? q_zero_points=None, Tensor? q_scales=None, "
       "Tensor? k_zero_points=None, Tensor? k_scales=None, Tensor? v_zero_points=None, "
-      "Tensor? v_scales=None) -> Tensor");
+      "Tensor? v_scales=None, bool is_seq_at_dim_2=False) -> Tensor");
   m.def(
       "custom_quantized_sdpa.out(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
       "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
       "float? scale=None, Tensor? q_zero_points=None, Tensor? q_scales=None, "
      "Tensor? k_zero_points=None, Tensor? k_scales=None, Tensor? v_zero_points=None, "
-      "Tensor? v_scales=None, *, Tensor(a!) out) -> Tensor(a!)");
+      "Tensor? v_scales=None, bool is_seq_at_dim_2=False, *, Tensor(a!) out) -> Tensor(a!)");
 #endif // ENABLE_CUSTOM_QUANTIZED_SDPA
 }
 
@@ -404,6 +410,6 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
   m.impl(
       "custom_quantized_sdpa.out",
       WRAP_TO_ATEN(
-          torch::executor::native::custom_quantized_sdpa_out_no_context, 14));
+          torch::executor::native::custom_quantized_sdpa_out_no_context, 15));
 #endif // ENABLE_CUSTOM_QUANTIZED_SDPA
 }
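
For context, a minimal caller sketch (not part of the diff) showing how the new trailing flag lines up with the 15-argument WRAP_TO_ATEN wrapper above. The leading parameter types are assumed from the custom_quantized_sdpa schema registered in this file; the tensor shapes and float inputs are illustrative placeholders (the real op expects properly quantized inputs), and the snippet assumes the declaration of custom_quantized_sdpa_aten is in scope.

// Illustrative sketch only: assumes custom_quantized_sdpa_aten's leading
// parameters mirror the registered schema (query, key, value, start_pos,
// attn_mask, dropout_p, is_causal, scale, q/k/v zero_points and scales).
#include <ATen/ATen.h>
#include <optional>

at::Tensor example_call() {
  // [batch, num_heads, seq_len, head_dim]: the sequence dimension is dim 2,
  // which is the layout the new is_seq_at_dim_2 flag describes.
  at::Tensor q = at::randn({1, 8, 16, 64});
  at::Tensor k = at::randn({1, 8, 16, 64});
  at::Tensor v = at::randn({1, 8, 16, 64});
  return custom_quantized_sdpa_aten(
      q, k, v,
      /*start_pos=*/0,
      /*attn_mask=*/std::nullopt,
      /*dropout_p=*/0.0,
      /*is_causal=*/true,
      /*scale=*/std::nullopt,
      /*q_zero_points=*/std::nullopt,
      /*q_scales=*/std::nullopt,
      /*k_zero_points=*/std::nullopt,
      /*k_scales=*/std::nullopt,
      /*v_zero_points=*/std::nullopt,
      /*v_scales=*/std::nullopt,
      /*is_seq_at_dim_2=*/true);  // new 15th argument added by this change
}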