Commit acd6b1c

[Executorch][SDPA] Refactor + Make quantized sdpa handle sequence at dim 1 or 2
Differential Revision: D71833060
Pull Request resolved: #9943
1 parent 2197e98

5 files changed (+189 −62 lines)
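
For orientation before the per-file diffs: the new is_seq_at_dim_2 flag selects which tensor dimension holds the sequence length. The layout names in the sketch below are an inference from the q.size(1) / q.size(2) reads in the hunks, not something the commit states, so treat it as a hedged illustration rather than a documented contract.

#include <cstdint>

// Hedged sketch of the layout convention implied by the diffs below (an
// assumption, not stated by the commit):
//   is_seq_at_dim_2 == false -> q/k/v shaped [Batch, SeqLen, NumHeads, HeadDim]
//   is_seq_at_dim_2 == true  -> q/k/v shaped [Batch, NumHeads, SeqLen, HeadDim]
// TensorT stands in for whichever tensor type is in scope (at::Tensor in the
// AOT wrapper, ExecuTorch's Tensor in the kernel); both expose size(dim).
template <typename TensorT>
int64_t sdpa_seq_len(const TensorT& q, bool is_seq_at_dim_2) {
  return is_seq_at_dim_2 ? q.size(2) : q.size(1);
}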

extension/llm/custom_ops/op_sdpa.cpp (+20 −15)
@@ -264,14 +264,14 @@ Tensor& flash_attention_kernel_out(
       InvalidArgument,
       output);
 
-  auto q_seq_len = query.size(2);
+  auto seq_len = query.size(2);
 
   ET_SWITCH_FLOAT_TYPES(
       query.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
         // TODO we need to re-evaluate this for ARM CPUs
         // And there can be many so instead of templatizing
         // we might consider another appraoch
-        if (q_seq_len >= 768) {
+        if (seq_len >= 768) {
           sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
               output,
               query,
@@ -287,7 +287,7 @@ Tensor& flash_attention_kernel_out(
               nullopt,
               nullopt,
               nullopt);
-        } else if (q_seq_len >= 192) {
+        } else if (seq_len >= 192) {
           sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
               output,
               query,
@@ -341,7 +341,8 @@ Tensor& custom_sdpa_out_impl(
     const optional<Tensor>& k_zero_points = nullopt,
     const optional<Tensor>& k_scales = nullopt,
     const optional<Tensor>& v_zero_points = nullopt,
-    const optional<Tensor>& v_scales = nullopt) {
+    const optional<Tensor>& v_scales = nullopt,
+    bool is_seq_at_dim_2 = false) {
   ET_KERNEL_CHECK_MSG(
       ctx,
       !attn_mask.has_value() || !is_causal,
@@ -357,13 +358,15 @@ Tensor& custom_sdpa_out_impl(
       "Invalid arguments");
 
   int64_t seq_len = q.size(1);
-  auto q_seq_len = q.size(1);
+  SeqDim seq_dim{SeqDim::TWO};
+  if (!is_seq_at_dim_2) {
+    seq_dim = SeqDim::ONE;
+  }
 
-  bool is_seq_at_dim_1{true};
   if (q.scalar_type() == ScalarType::Char) {
-    is_seq_at_dim_1 = false;
-    seq_len = q.size(2);
-    q_seq_len = q.size(2);
+    if (seq_dim == SeqDim::TWO) {
+      seq_len = q.size(2);
+    }
     ET_KERNEL_CHECK_MSG(
         ctx,
         q_scales.has_value() && q_zero_points.has_value() &&
@@ -412,7 +415,7 @@ Tensor& custom_sdpa_out_impl(
         // TODO we need to re-evaluate this for ARM CPUs
         // And there can be many so instead of templatizing
         // we might consider another appraoch
-        if (q_seq_len >= 768) {
+        if (seq_len >= 768) {
           sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
               output,
               q,
@@ -428,10 +431,10 @@ Tensor& custom_sdpa_out_impl(
               k_scales, // k_scales
               v_zero_points, // v_zero_points
               v_scales, // v_scales
-              is_seq_at_dim_1, /* is_seq_at_dim_1 */
+              seq_dim, /* seq_dim */
               start_pos,
               num_keys_for_causal_attention);
-        } else if (q_seq_len >= 192) {
+        } else if (seq_len >= 192) {
           sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
               output,
               q,
@@ -447,7 +450,7 @@ Tensor& custom_sdpa_out_impl(
               k_scales, // k_scales
               v_zero_points, // v_zero_points
               v_scales, // v_scales
-              is_seq_at_dim_1, /* is_seq_at_dim_1 */
+              seq_dim, /* seq_dim */
               start_pos,
               num_keys_for_causal_attention);
         } else {
@@ -466,7 +469,7 @@ Tensor& custom_sdpa_out_impl(
               k_scales, // k_scales
               v_zero_points, // v_zero_points
               v_scales, // v_scales
-              is_seq_at_dim_1, /* is_seq_at_dim_1 */
+              seq_dim, /* seq_dim */
               start_pos,
               num_keys_for_causal_attention);
         }
@@ -492,6 +495,7 @@ Tensor& custom_quantized_sdpa_out(
     const optional<Tensor>& k_scales,
     const optional<Tensor>& v_zero_points,
     const optional<Tensor>& v_scales,
+    const bool is_seq_at_dim_2,
     Tensor& output) {
   return custom_sdpa_out_impl(
       ctx,
@@ -509,7 +513,8 @@ Tensor& custom_quantized_sdpa_out(
       k_zero_points,
       k_scales,
       v_zero_points,
-      v_scales);
+      v_scales,
+      is_seq_at_dim_2);
 }
 #endif // ENABLE_CUSTOM_QUANTIZED_SDPA
 
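The quantized path above now threads a SeqDim value into sdpa::impl::cpu_flash_attention instead of a bool. The enum's real definition lives in the sdpa::impl headers touched elsewhere in this commit and is not shown in this hunk; the sketch below only illustrates the shape the call sites above rely on, with the namespace, underlying type, and values assumed.

#include <cstdint>

// Assumed sketch of the SeqDim enum consumed by sdpa::impl::cpu_flash_attention.
// Only the enumerator names are taken from the call sites above; everything
// else is a guess and may differ from the real definition.
enum class SeqDim : int32_t {
  ONE = 1, // sequence length stored at dim 1: [Batch, SeqLen, NumHeads, HeadDim]
  TWO = 2, // sequence length stored at dim 2: [Batch, NumHeads, SeqLen, HeadDim]
};

// Mirrors the selection logic added to custom_sdpa_out_impl above.
inline SeqDim select_seq_dim(bool is_seq_at_dim_2) {
  return is_seq_at_dim_2 ? SeqDim::TWO : SeqDim::ONE;
}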

extension/llm/custom_ops/op_sdpa.h (+1 −0)
@@ -74,6 +74,7 @@ Tensor& custom_quantized_sdpa_out(
     const optional<Tensor>& k_scales,
     const optional<Tensor>& v_zero_points,
     const optional<Tensor>& v_scales,
+    const bool is_seq_at_dim_1,
     Tensor& output);
 #endif // ENABLE_CUSTOM_QUANTIZED_SDPA
 } // namespace native

extension/llm/custom_ops/op_sdpa_aot.cpp (+12 −6)
@@ -96,6 +96,7 @@ Tensor& custom_quantized_sdpa_out_no_context(
     const optional<Tensor> k_scales,
     const optional<Tensor> v_zero_points,
     const optional<Tensor> v_scales,
+    const bool is_seq_at_dim_2,
     Tensor& output);
 
 at::Tensor custom_quantized_sdpa_aten(
@@ -115,7 +116,8 @@ at::Tensor custom_quantized_sdpa_aten(
     const std::optional<at::Tensor>& k_zero_points,
     const std::optional<at::Tensor>& k_scales,
     const std::optional<at::Tensor>& v_zero_points,
-    const std::optional<at::Tensor>& v_scales);
+    const std::optional<at::Tensor>& v_scales,
+    const bool is_seq_at_dim_2);
 #endif // ENABLE_CUSTOM_QUANTIZED_SDPA
 
 Tensor& update_cache_out_no_context(
@@ -258,6 +260,7 @@ Tensor& custom_quantized_sdpa_out_no_context(
     const optional<Tensor> k_scales,
     const optional<Tensor> v_zero_points,
     const optional<Tensor> v_scales,
+    const bool is_seq_at_dim_2,
     Tensor& output) {
   executorch::aten::RuntimeContext context{};
   return torch::executor::native::custom_quantized_sdpa_out(
@@ -276,6 +279,7 @@ Tensor& custom_quantized_sdpa_out_no_context(
       k_scales,
       v_zero_points,
       v_scales,
+      is_seq_at_dim_2,
       output);
 }
 
@@ -296,9 +300,10 @@ at::Tensor custom_quantized_sdpa_aten(
     const std::optional<at::Tensor>& k_zero_points,
     const std::optional<at::Tensor>& k_scales,
     const std::optional<at::Tensor>& v_zero_points,
-    const std::optional<at::Tensor>& v_scales) {
+    const std::optional<at::Tensor>& v_scales,
+    const bool is_seq_at_dim_2) {
   auto output = at::empty(q.sizes());
-  WRAP_TO_ATEN(custom_quantized_sdpa_out_no_context, 14)
+  WRAP_TO_ATEN(custom_quantized_sdpa_out_no_context, 15)
   (q,
    k,
    v,
@@ -313,6 +318,7 @@ at::Tensor custom_quantized_sdpa_aten(
    k_scales,
    v_zero_points,
    v_scales,
+   is_seq_at_dim_2,
    output);
   return output;
 }
@@ -371,13 +377,13 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
       "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
       "float? scale=None, Tensor? q_zero_points=None, Tensor? q_scales=None, "
       "Tensor? k_zero_points=None, Tensor? k_scales=None, Tensor? v_zero_points=None, "
-      "Tensor? v_scales=None) -> Tensor");
+      "Tensor? v_scales=None, bool is_seq_at_dim_2=False) -> Tensor");
   m.def(
       "custom_quantized_sdpa.out(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
       "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
       "float? scale=None, Tensor? q_zero_points=None, Tensor? q_scales=None, "
       "Tensor? k_zero_points=None, Tensor? k_scales=None, Tensor? v_zero_points=None, "
-      "Tensor? v_scales=None, *, Tensor(a!) out) -> Tensor(a!)");
+      "Tensor? v_scales=None, bool is_seq_at_dim_2=False, *, Tensor(a!) out) -> Tensor(a!)");
 #endif // ENABLE_CUSTOM_QUANTIZED_SDPA
 }
 
@@ -404,6 +410,6 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
   m.impl(
       "custom_quantized_sdpa.out",
       WRAP_TO_ATEN(
-          torch::executor::native::custom_quantized_sdpa_out_no_context, 14));
+          torch::executor::native::custom_quantized_sdpa_out_no_context, 15));
 #endif // ENABLE_CUSTOM_QUANTIZED_SDPA
 }
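
Two knock-on effects of the extra parameter are visible above: both schema strings gain a defaulted bool is_seq_at_dim_2=False just before the out argument, and the index handed to WRAP_TO_ATEN moves from 14 to 15 to account for the extra parameter ahead of the out tensor. A hedged sketch of exercising the updated .out overload through the PyTorch dispatcher follows; the operator name and argument order come from the registrations above, while the shapes, dtypes, and the all-None quantization parameters are placeholder assumptions, not how a real quantized model would feed the op.

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/stack.h>

// Hedged sketch: invokes llama::custom_quantized_sdpa.out via a boxed
// dispatcher call so the exact C++ signature does not have to be restated.
// All 16 schema arguments (15 inputs plus the out tensor) are pushed in
// declaration order.
at::Tensor call_quantized_sdpa_out(
    const at::Tensor& q,
    const at::Tensor& k,
    const at::Tensor& v,
    int64_t start_pos,
    bool is_seq_at_dim_2) {
  at::Tensor out = at::empty(q.sizes(), at::kFloat); // output dtype assumed
  auto op = c10::Dispatcher::singleton().findSchemaOrThrow(
      "llama::custom_quantized_sdpa", "out");
  torch::jit::Stack stack{
      q, k, v, start_pos,
      c10::IValue(),                // attn_mask = None
      0.0,                          // dropout_p
      true,                         // is_causal
      c10::IValue(),                // scale = None
      c10::IValue(), c10::IValue(), // q_zero_points, q_scales
      c10::IValue(), c10::IValue(), // k_zero_points, k_scales
      c10::IValue(), c10::IValue(), // v_zero_points, v_scales
      is_seq_at_dim_2,
      out};
  op.callBoxed(&stack);
  return stack.back().toTensor();
}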
