[Executorch][SDPA] Refactor + Make quantized sdpa handle sequence at dim 1 or 2 #9943

Merged (4 commits, Apr 10, 2025)
35 changes: 20 additions & 15 deletions extension/llm/custom_ops/op_sdpa.cpp
@@ -264,14 +264,14 @@ Tensor& flash_attention_kernel_out(
       InvalidArgument,
       output);

-  auto q_seq_len = query.size(2);
+  auto seq_len = query.size(2);

   ET_SWITCH_FLOAT_TYPES(
       query.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
         // TODO we need to re-evaluate this for ARM CPUs
         // And there can be many so instead of templatizing
         // we might consider another appraoch
-        if (q_seq_len >= 768) {
+        if (seq_len >= 768) {
           sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
               output,
               query,
@@ -287,7 +287,7 @@ Tensor& flash_attention_kernel_out(
               nullopt,
               nullopt,
               nullopt);
-        } else if (q_seq_len >= 192) {
+        } else if (seq_len >= 192) {
           sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
               output,
               query,
@@ -341,7 +341,8 @@ Tensor& custom_sdpa_out_impl(
     const optional<Tensor>& k_zero_points = nullopt,
     const optional<Tensor>& k_scales = nullopt,
     const optional<Tensor>& v_zero_points = nullopt,
-    const optional<Tensor>& v_scales = nullopt) {
+    const optional<Tensor>& v_scales = nullopt,
+    bool is_seq_at_dim_2 = false) {
   ET_KERNEL_CHECK_MSG(
       ctx,
       !attn_mask.has_value() || !is_causal,
@@ -357,13 +358,15 @@ Tensor& custom_sdpa_out_impl(
       "Invalid arguments");

   int64_t seq_len = q.size(1);
-  auto q_seq_len = q.size(1);
+  SeqDim seq_dim{SeqDim::TWO};
+  if (!is_seq_at_dim_2) {
+    seq_dim = SeqDim::ONE;
+  }

-  bool is_seq_at_dim_1{true};
   if (q.scalar_type() == ScalarType::Char) {
-    is_seq_at_dim_1 = false;
-    seq_len = q.size(2);
-    q_seq_len = q.size(2);
+    if (seq_dim == SeqDim::TWO) {
+      seq_len = q.size(2);
+    }
     ET_KERNEL_CHECK_MSG(
         ctx,
         q_scales.has_value() && q_zero_points.has_value() &&
@@ -412,7 +415,7 @@ Tensor& custom_sdpa_out_impl(
         // TODO we need to re-evaluate this for ARM CPUs
         // And there can be many so instead of templatizing
         // we might consider another appraoch
-        if (q_seq_len >= 768) {
+        if (seq_len >= 768) {
           sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
               output,
               q,
@@ -428,10 +431,10 @@ Tensor& custom_sdpa_out_impl(
               k_scales, // k_scales
               v_zero_points, // v_zero_points
               v_scales, // v_scales
-              is_seq_at_dim_1, /* is_seq_at_dim_1 */
+              seq_dim, /* seq_dim */
               start_pos,
               num_keys_for_causal_attention);
-        } else if (q_seq_len >= 192) {
+        } else if (seq_len >= 192) {
           sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
               output,
               q,
@@ -447,7 +450,7 @@ Tensor& custom_sdpa_out_impl(
               k_scales, // k_scales
               v_zero_points, // v_zero_points
               v_scales, // v_scales
-              is_seq_at_dim_1, /* is_seq_at_dim_1 */
+              seq_dim, /* seq_dim */
               start_pos,
               num_keys_for_causal_attention);
         } else {
@@ -466,7 +469,7 @@ Tensor& custom_sdpa_out_impl(
               k_scales, // k_scales
               v_zero_points, // v_zero_points
               v_scales, // v_scales
-              is_seq_at_dim_1, /* is_seq_at_dim_1 */
+              seq_dim, /* seq_dim */
               start_pos,
               num_keys_for_causal_attention);
         }
@@ -492,6 +495,7 @@ Tensor& custom_quantized_sdpa_out(
     const optional<Tensor>& k_scales,
     const optional<Tensor>& v_zero_points,
     const optional<Tensor>& v_scales,
+    const bool is_seq_at_dim_2,
     Tensor& output) {
   return custom_sdpa_out_impl(
       ctx,
@@ -509,7 +513,8 @@ Tensor& custom_quantized_sdpa_out(
       k_zero_points,
       k_scales,
       v_zero_points,
-      v_scales);
+      v_scales,
+      is_seq_at_dim_2);
 }
 #endif // ENABLE_CUSTOM_QUANTIZED_SDPA

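Summary of the op_sdpa.cpp changes above: in flash_attention_kernel_out the edit is a pure rename of q_seq_len to seq_len. In custom_sdpa_out_impl, the quantized (ScalarType::Char) path previously hard-coded the sequence dimension to 2 (seq_len = q.size(2)) and reported the layout to the kernel through a bool is_seq_at_dim_1; it now takes an explicit is_seq_at_dim_2 argument (default false, i.e. sequence at dim 1), folds it into a SeqDim enum, and passes that down to cpu_flash_attention. A minimal standalone sketch of the resulting seq_len selection, with illustrative names and an assumed SeqDim definition (not copied from the ExecuTorch sources):

#include <cstdint>
#include <vector>

// Assumed to mirror the SeqDim enum consumed by cpu_flash_attention.
enum class SeqDim { ONE = 1, TWO };

// Illustrative stand-in for the logic in custom_sdpa_out_impl:
// `sizes` plays the role of q.sizes(), `is_quantized` the role of
// q.scalar_type() == ScalarType::Char.
int64_t select_seq_len(
    const std::vector<int64_t>& sizes,
    bool is_quantized,
    bool is_seq_at_dim_2) {
  const SeqDim seq_dim = is_seq_at_dim_2 ? SeqDim::TWO : SeqDim::ONE;
  int64_t seq_len = sizes[1]; // default: sequence length lives at dim 1
  if (is_quantized && seq_dim == SeqDim::TWO) {
    seq_len = sizes[2]; // quantized inputs laid out with sequence at dim 2
  }
  return seq_len;
}

The block-size dispatch is unchanged apart from the variable it inspects: seq_len >= 768 selects the <256, 512> configuration, seq_len >= 192 selects <64, 512>, and anything smaller falls through to the remaining branch.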
1 change: 1 addition & 0 deletions extension/llm/custom_ops/op_sdpa.h
@@ -74,6 +74,7 @@ Tensor& custom_quantized_sdpa_out(
     const optional<Tensor>& k_scales,
     const optional<Tensor>& v_zero_points,
     const optional<Tensor>& v_scales,
+    const bool is_seq_at_dim_2,
     Tensor& output);
 #endif // ENABLE_CUSTOM_QUANTIZED_SDPA
 } // namespace native
18 changes: 12 additions & 6 deletions extension/llm/custom_ops/op_sdpa_aot.cpp
@@ -96,6 +96,7 @@ Tensor& custom_quantized_sdpa_out_no_context(
     const optional<Tensor> k_scales,
     const optional<Tensor> v_zero_points,
     const optional<Tensor> v_scales,
+    const bool is_seq_at_dim_2,
     Tensor& output);

 at::Tensor custom_quantized_sdpa_aten(
@@ -115,7 +116,8 @@ at::Tensor custom_quantized_sdpa_aten(
     const std::optional<at::Tensor>& k_zero_points,
     const std::optional<at::Tensor>& k_scales,
     const std::optional<at::Tensor>& v_zero_points,
-    const std::optional<at::Tensor>& v_scales);
+    const std::optional<at::Tensor>& v_scales,
+    const bool is_seq_at_dim_2);
 #endif // ENABLE_CUSTOM_QUANTIZED_SDPA

 Tensor& update_cache_out_no_context(
@@ -258,6 +260,7 @@ Tensor& custom_quantized_sdpa_out_no_context(
     const optional<Tensor> k_scales,
     const optional<Tensor> v_zero_points,
     const optional<Tensor> v_scales,
+    const bool is_seq_at_dim_2,
     Tensor& output) {
   executorch::aten::RuntimeContext context{};
   return torch::executor::native::custom_quantized_sdpa_out(
@@ -276,6 +279,7 @@ Tensor& custom_quantized_sdpa_out_no_context(
       k_scales,
       v_zero_points,
       v_scales,
+      is_seq_at_dim_2,
       output);
 }

@@ -296,9 +300,10 @@ at::Tensor custom_quantized_sdpa_aten(
     const std::optional<at::Tensor>& k_zero_points,
     const std::optional<at::Tensor>& k_scales,
     const std::optional<at::Tensor>& v_zero_points,
-    const std::optional<at::Tensor>& v_scales) {
+    const std::optional<at::Tensor>& v_scales,
+    const bool is_seq_at_dim_2) {
   auto output = at::empty(q.sizes());
-  WRAP_TO_ATEN(custom_quantized_sdpa_out_no_context, 14)
+  WRAP_TO_ATEN(custom_quantized_sdpa_out_no_context, 15)
   (q,
    k,
    v,
@@ -313,6 +318,7 @@ at::Tensor custom_quantized_sdpa_aten(
    k_scales,
    v_zero_points,
    v_scales,
+   is_seq_at_dim_2,
    output);
   return output;
 }
@@ -371,13 +377,13 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
       "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
       "float? scale=None, Tensor? q_zero_points=None, Tensor? q_scales=None, "
       "Tensor? k_zero_points=None, Tensor? k_scales=None, Tensor? v_zero_points=None, "
-      "Tensor? v_scales=None) -> Tensor");
+      "Tensor? v_scales=None, bool is_seq_at_dim_2=False) -> Tensor");
   m.def(
       "custom_quantized_sdpa.out(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
       "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
       "float? scale=None, Tensor? q_zero_points=None, Tensor? q_scales=None, "
       "Tensor? k_zero_points=None, Tensor? k_scales=None, Tensor? v_zero_points=None, "
-      "Tensor? v_scales=None, *, Tensor(a!) out) -> Tensor(a!)");
+      "Tensor? v_scales=None, bool is_seq_at_dim_2=False, *, Tensor(a!) out) -> Tensor(a!)");
 #endif // ENABLE_CUSTOM_QUANTIZED_SDPA
 }

@@ -404,6 +410,6 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
   m.impl(
       "custom_quantized_sdpa.out",
       WRAP_TO_ATEN(
-          torch::executor::native::custom_quantized_sdpa_out_no_context, 14));
+          torch::executor::native::custom_quantized_sdpa_out_no_context, 15));
 #endif // ENABLE_CUSTOM_QUANTIZED_SDPA
 }
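On the AOT side, the new argument is threaded through custom_quantized_sdpa_out_no_context and custom_quantized_sdpa_aten as the last parameter before the out tensor, so the index handed to WRAP_TO_ATEN moves from 14 to 15 (consistent with that index naming the position of the out argument in the wrapped signature), and both llama::custom_quantized_sdpa schemas gain a trailing bool is_seq_at_dim_2=False, so existing call sites keep working. A small illustrative check of the layout the flag selects; the [B, S, H, D] versus [B, H, S, D] naming is an assumption inferred from where the kernel reads seq_len, not taken from the sources:

#include <array>
#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical shapes; B = batch, H = attention heads, S = sequence length,
  // D = head dimension.
  constexpr int64_t B = 1, H = 8, S = 128, D = 64;

  // Default (is_seq_at_dim_2 = false): seq_len is read from dim 1,
  // matching a [B, S, H, D] layout.
  constexpr std::array<int64_t, 4> q_seq_at_dim_1{B, S, H, D};

  // New option (is_seq_at_dim_2 = true): seq_len is read from dim 2,
  // matching a [B, H, S, D] layout.
  constexpr std::array<int64_t, 4> q_seq_at_dim_2{B, H, S, D};

  std::cout << "seq_len when is_seq_at_dim_2=false: " << q_seq_at_dim_1[1]
            << "\nseq_len when is_seq_at_dim_2=true:  " << q_seq_at_dim_2[2]
            << std::endl;
  return 0;
}

With the default (false), quantized inputs are handled like the float path, with the sequence at dim 1; passing true restores the previous quantized behavior of reading the sequence length from dim 2.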