
Fix a performance issue with Helion-emitted Flash Attention #181


Merged: 8 commits, Jun 23, 2025
6 changes: 3 additions & 3 deletions examples/attention.py
@@ -12,9 +12,9 @@
@helion.kernel(
    config=helion.Config(
        # This config was autotuned on a 3090, it won't be fast for other architectures
Contributor (review comment): is this comment still valid? maybe delete it?

        block_sizes=[32, 16],
        num_warps=1,
        num_stages=2,
        block_sizes=[128, 64],
        num_warps=4,
        num_stages=3,
        indexing="block_ptr",
    ),
    # Static shapes provide a speedup for attention
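For orientation (not part of the diff): a minimal smoke test of the example kernel with the new config might look like the sketch below. The 4D input shape, the import path, and the comparison against torch.nn.functional.scaled_dot_product_attention are assumptions about how examples/attention.py is typically exercised; only head_dim = 64 is implied by the expected kernels later in this PR.

    # Hedged usage sketch; shapes, import path, and tolerances are illustrative assumptions.
    import torch
    from examples.attention import attention  # the Helion kernel configured above (path assumed)

    q = torch.randn(2, 8, 512, 64, device="cuda", dtype=torch.float16)  # [batch, heads, seq, head_dim]
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    out = attention(q, k, v)  # Helion-emitted Triton path
    ref = torch.nn.functional.scaled_dot_product_attention(q, k, v)  # eager reference
    torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2)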
112 changes: 83 additions & 29 deletions helion/_compiler/inductor_lowering.py
@@ -777,52 +777,106 @@ def apply_dot_requirements(
    return LambdaLowering(handler, masked_value_fn=masked_value_fn)


def reduce_3d_dot(
    ctx: GraphInterpreter, node: torch.fx.Node, with_acc: bool
) -> ast.AST:
    datatype = CompileEnvironment.current().settings.dot_precision
    acc = None
    if with_acc:
        acc, lhs, rhs = map_arg(node.args, lambda arg: ctx.env[arg])
        assert isinstance(acc, ast.AST)
        lhs_node = node.args[1]
        rhs_node = node.args[2]
    else:
        lhs, rhs = map_arg(node.args, lambda arg: ctx.env[arg])
        lhs_node = node.args[0]
        rhs_node = node.args[1]
    assert isinstance(lhs, ast.AST)
    assert isinstance(rhs, ast.AST)

    lhs_size = lhs_node.meta["val"].size()
    rhs_size = rhs_node.meta["val"].size()
    # Check whether the operands are 3D with a leading (batch) block dimension of 1.
    reduce_dim = False
    if len(lhs_size) == 3:
        env = CompileEnvironment.current()
        lhs_dim_idx = env.get_block_id(lhs_size[0])
        rhs_dim_idx = env.get_block_id(rhs_size[0])
        if lhs_dim_idx is not None and rhs_dim_idx is not None:
            lhs_dim_val = env.block_sizes[lhs_dim_idx]
            rhs_dim_val = env.block_sizes[rhs_dim_idx]
            if (
                lhs_dim_val.from_config(ctx.cg.device_function.config) == 1
                and rhs_dim_val.from_config(ctx.cg.device_function.config) == 1
            ):
                reduce_dim = True

    if not reduce_dim:
        if with_acc:
            return expr_from_string(
                f"tl.dot(lhs, rhs, acc=acc, input_precision={datatype!r})",
                lhs=lhs,
                rhs=rhs,
                acc=acc,
            )
        # without accumulator
        return expr_from_string(
            f"tl.dot(lhs, rhs, input_precision={datatype!r})", lhs=lhs, rhs=rhs
        )

    # Collapse the size-1 leading dim: reshape to 2D, dot, then reshape back to 3D.
    lhs_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
        [*lhs_node.meta["val"].size()[1:]]
    )
    rhs_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
        [*rhs_node.meta["val"].size()[1:]]
    )
    out_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
        [*node.meta["val"].size()]
    )
    lhs_reshape = expr_from_string(f"tl.reshape(lhs, {lhs_shape_str})", lhs=lhs)
    rhs_reshape = expr_from_string(f"tl.reshape(rhs, {rhs_shape_str})", rhs=rhs)
    if with_acc:
        acc_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
            [*node.args[0].meta["val"].size()[1:]]
        )
        # The placeholder is named "rhs" but is bound to the accumulator AST here.
        acc_reshape = expr_from_string(f"tl.reshape(rhs, {acc_shape_str})", rhs=acc)
        comp = expr_from_string(
            f"tl.dot(lhs, rhs, acc=acc, input_precision={datatype!r})",
            lhs=lhs_reshape,
            rhs=rhs_reshape,
            acc=acc_reshape,
        )
    else:
        comp = expr_from_string(
            f"tl.dot(lhs, rhs, input_precision={datatype!r})",
            lhs=lhs_reshape,
            rhs=rhs_reshape,
        )
    return expr_from_string(f"tl.reshape(lhs, {out_shape_str})", lhs=comp)
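
To make the transformation concrete: when both operands carry a leading block dimension of size 1, the batched matmul is numerically identical to a single 2D matmul wrapped in reshapes, which is the tl.reshape / tl.dot / tl.reshape pattern reduce_3d_dot emits. A standalone plain-PyTorch sketch of that identity (illustrative sizes, not code from this PR):

    # Illustrative equivalence check: a bmm whose leading dim is 1 equals a 2D mm
    # between reshapes; reduce_3d_dot applies the same rewrite at codegen time so
    # tl.dot receives 2D operands instead of a degenerate 3D batch.
    import torch

    lhs = torch.randn(1, 128, 64)  # [1, M, K]
    rhs = torch.randn(1, 64, 32)   # [1, K, N]

    out_3d = torch.bmm(lhs, rhs)                                             # [1, M, N]
    out_2d = (lhs.reshape(128, 64) @ rhs.reshape(64, 32)).reshape(1, 128, 32)
    torch.testing.assert_close(out_3d, out_2d)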


@register_lowering(torch.ops.aten.bmm.default, apply_dot_requirements)
# pyre-fixme[56]
@register_lowering(torch.ops.aten.mm.default, apply_dot_requirements)
def codegen_mm(ctx: GraphInterpreter, node: torch.fx.Node) -> ast.AST:
    assert not node.kwargs, "matmul kwargs not supported"
    lhs, rhs = map_arg(node.args, lambda arg: ctx.env[arg])
    assert isinstance(lhs, ast.AST)
    assert isinstance(rhs, ast.AST)
    tf32 = CompileEnvironment.current().settings.dot_precision
    return expr_from_string(
        f"tl.dot(lhs, rhs, input_precision={tf32!r})", lhs=lhs, rhs=rhs
    )

    return reduce_3d_dot(ctx, node, False)


# pyre-fixme[56]
@register_lowering(torch.ops.aten.addmm.default, apply_dot_requirements)
def codegen_addmm(ctx: GraphInterpreter, node: torch.fx.Node) -> ast.AST:
    assert not node.kwargs, "addmm kwargs not supported"
    acc, lhs, rhs = map_arg(node.args, lambda arg: ctx.env[arg])
    assert isinstance(acc, ast.AST)
    assert isinstance(lhs, ast.AST)
    assert isinstance(rhs, ast.AST)
    tf32 = CompileEnvironment.current().settings.dot_precision
    return expr_from_string(
        f"tl.dot(lhs, rhs, acc=acc, input_precision={tf32!r})",
        lhs=lhs,
        rhs=rhs,
        acc=acc,
    )
    return reduce_3d_dot(ctx, node, True)


# pyre-fixme[56]
@register_lowering(torch.ops.aten.baddbmm.default, apply_dot_requirements)
def codegen_baddbmm(ctx: GraphInterpreter, node: torch.fx.Node) -> ast.AST:
    assert not node.kwargs, "baddbmm kwargs not supported"
    acc, lhs, rhs = map_arg(node.args, lambda arg: ctx.env[arg])
    assert isinstance(acc, ast.AST)
    assert isinstance(lhs, ast.AST)
    assert isinstance(rhs, ast.AST)
    tf32 = CompileEnvironment.current().settings.dot_precision
    return expr_from_string(
        f"tl.dot(lhs, rhs, acc=acc, input_precision={tf32!r})",
        lhs=lhs,
        rhs=rhs,
        acc=acc,
    )
    return reduce_3d_dot(ctx, node, True)


class GenerateASTFromInductor(DefaultHandler):
24 changes: 12 additions & 12 deletions test/test_examples.py
@@ -1144,7 +1144,7 @@ def _attention_kernel(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
l_i_copy_0 = l_i_copy
acc_copy_0 = acc_copy
k = tl.load(k_view + (indices_0[:, None, None] * 32768 + indices_4[None, :, None] * 1 + indices_2[None, None, :] * 64), None)
qk = tl.dot(q_copy_0, k, input_precision='tf32')
qk = tl.reshape(tl.dot(tl.reshape(q_copy_0, [_BLOCK_SIZE_1, 64]), tl.reshape(k, [64, _BLOCK_SIZE_3]), input_precision='tf32'), [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3])
amax = tl.max(qk, 2)
v_0 = 0.18033688
v_1 = amax * v_0
@@ -1162,7 +1162,7 @@ def _attention_kernel(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
subscript_1 = v_8[:, :, None]
v_11 = acc_copy_0 * subscript_1
v = tl.load(v_view + (indices_0[:, None, None] * 32768 + indices_2[None, :, None] * 64 + indices_4[None, None, :] * 1), None)
acc = tl.dot(v_6, v, acc=v_11, input_precision='tf32')
acc = tl.reshape(tl.dot(tl.reshape(v_6, [_BLOCK_SIZE_1, _BLOCK_SIZE_3]), tl.reshape(v, [_BLOCK_SIZE_3, 64]), acc=tl.reshape(v_11, [_BLOCK_SIZE_1, 64]), input_precision='tf32'), [1, _BLOCK_SIZE_1, 64])
subscript_2 = l_i[:, :, None]
v_12 = acc / subscript_2
tl.store(out + (indices_0[:, None, None] * 32768 + indices_1[None, :, None] * 64 + indices_4[None, None, :] * 1), v_12, None)
@@ -1249,7 +1249,7 @@ def _attention_kernel(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
l_i_copy_0 = l_i_copy
acc_copy_0 = acc_copy
k = tl.load(tl.make_block_ptr(k_view, [64, 64, 512], [32768, 1, 64], [offset_0, 0, offset_2], [1, 64, _BLOCK_SIZE_3], [2, 0, 1]), boundary_check=[0, 1, 2], padding_option='zero')
qk = tl.dot(q_copy_0, k, input_precision='tf32')
qk = tl.reshape(tl.dot(tl.reshape(q_copy_0, [_BLOCK_SIZE_1, 64]), tl.reshape(k, [64, _BLOCK_SIZE_3]), input_precision='tf32'), [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3])
amax = tl.max(qk, 2)
v_0 = tl.full([], 0.18033688, tl.float16)
v_1 = amax * v_0
@@ -1270,7 +1270,7 @@ def _attention_kernel(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
v_13 = acc_copy_0 * subscript_1
v = tl.load(tl.make_block_ptr(v_view, [64, 512, 64], [32768, 64, 1], [offset_0, offset_2, 0], [1, _BLOCK_SIZE_3, 64], [2, 1, 0]), boundary_check=[0, 1, 2], padding_option='zero')
v_14 = v_8.to(tl.float16)
acc = tl.dot(v_14, v, acc=v_13, input_precision='tf32')
acc = tl.reshape(tl.dot(tl.reshape(v_14, [_BLOCK_SIZE_1, _BLOCK_SIZE_3]), tl.reshape(v, [_BLOCK_SIZE_3, 64]), acc=tl.reshape(v_13, [_BLOCK_SIZE_1, 64]), input_precision='tf32'), [1, _BLOCK_SIZE_1, 64])
subscript_2 = l_i[:, :, None]
v_15 = acc / subscript_2
v_16 = v_15.to(tl.float16)
@@ -1361,7 +1361,7 @@ def _attention_kernel(q_view, k_view, v_view, out, k_view_size_0, k_view_size_2,
l_i_copy_0 = l_i_copy
acc_copy_0 = acc_copy
k = tl.load(tl.make_block_ptr(k_view, [k_view_size_0, 64, k_view_size_2], [k_view_stride_0, k_view_stride_1, k_view_stride_2], [offset_0, 0, offset_2], [1, 64, _BLOCK_SIZE_3], [2, 0, 1]), boundary_check=[0, 1, 2], padding_option='zero')
qk = tl.dot(q_copy_0, k, input_precision='tf32')
qk = tl.reshape(tl.dot(tl.reshape(q_copy_0, [_BLOCK_SIZE_1, 64]), tl.reshape(k, [64, _BLOCK_SIZE_3]), input_precision='tf32'), [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3])
_mask_to_2 = tl.where(tl.broadcast_to(mask_1[None, :, None] & mask_3[None, None, :], [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3]), qk, float('-inf'))
amax = tl.max(_mask_to_2, 2)
v_0 = 0.18033688
@@ -1381,7 +1381,7 @@ def _attention_kernel(q_view, k_view, v_view, out, k_view_size_0, k_view_size_2,
subscript_1 = v_8[:, :, None]
v_11 = acc_copy_0 * subscript_1
v = tl.load(tl.make_block_ptr(v_view, [v_view_size_0, v_view_size_1, 64], [v_view_stride_0, v_view_stride_1, v_view_stride_2], [offset_0, offset_2, 0], [1, _BLOCK_SIZE_3, 64], [2, 1, 0]), boundary_check=[0, 1, 2], padding_option='zero')
acc = tl.dot(_mask_to_3, v, acc=v_11, input_precision='tf32')
acc = tl.reshape(tl.dot(tl.reshape(_mask_to_3, [_BLOCK_SIZE_1, _BLOCK_SIZE_3]), tl.reshape(v, [_BLOCK_SIZE_3, 64]), acc=tl.reshape(v_11, [_BLOCK_SIZE_1, 64]), input_precision='tf32'), [1, _BLOCK_SIZE_1, 64])
subscript_2 = l_i[:, :, None]
v_12 = acc / subscript_2
tl.store(tl.make_block_ptr(out, [out_size_0, out_size_1, 64], [out_stride_0, out_stride_1, out_stride_2], [offset_0, offset_1, 0], [1, _BLOCK_SIZE_1, 64], [2, 1, 0]), v_12, boundary_check=[0, 1, 2])
Expand All @@ -1398,10 +1398,10 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor):
out = torch.empty_like(q_view)
sm_scale = 1.0 / math.sqrt(head_dim)
qk_scale = sm_scale * 1.44269504
_BLOCK_SIZE_1 = 32
_BLOCK_SIZE_1 = 128
_RDIM_SIZE_2 = 64
_BLOCK_SIZE_3 = 16
_attention_kernel[q_in.size(1) * triton.cdiv(m_dim, _BLOCK_SIZE_1),](q_view, k_view, v_view, out, k_view.size(0), k_view.size(2), out.size(0), out.size(1), q_in.size(1), q_view.size(0), q_view.size(1), v_view.size(0), v_view.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=1, num_stages=2)
_BLOCK_SIZE_3 = 64
_attention_kernel[q_in.size(1) * triton.cdiv(m_dim, _BLOCK_SIZE_1),](q_view, k_view, v_view, out, k_view.size(0), k_view.size(2), out.size(0), out.size(1), q_in.size(1), q_view.size(0), q_view.size(1), v_view.size(0), v_view.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
return out.view(q_in.size())

def _attention_make_precompiler(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor):
@@ -1416,11 +1416,11 @@ def _attention_make_precompiler(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor):
out = torch.empty_like(q_view)
sm_scale = 1.0 / math.sqrt(head_dim)
qk_scale = sm_scale * 1.44269504
_BLOCK_SIZE_1 = 32
_BLOCK_SIZE_1 = 128
_RDIM_SIZE_2 = 64
_BLOCK_SIZE_3 = 16
_BLOCK_SIZE_3 = 64
from helion.runtime.precompile_shim import make_precompiler
return make_precompiler(_attention_kernel)(q_view, k_view, v_view, out, k_view.size(0), k_view.size(2), out.size(0), out.size(1), q_in.size(1), q_view.size(0), q_view.size(1), v_view.size(0), v_view.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=1, num_stages=2)""",
return make_precompiler(_attention_kernel)(q_view, k_view, v_view, out, k_view.size(0), k_view.size(2), out.size(0), out.size(1), q_in.size(1), q_view.size(0), q_view.size(1), v_view.size(0), v_view.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)""",
)

def test_concat(self):
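As a rough sense check on the config change reflected in these expected kernels (block sizes 32/16 to 128/64), the arithmetic below counts query tiles and inner-loop steps for the 512-long sequence used in the test; it is illustrative bookkeeping, not a profiling result.

    # Tile-count arithmetic for the old vs. new block sizes (seq_len = 512 as in
    # the expected kernels above; purely illustrative).
    seq_len = 512
    old_m_tiles = -(-seq_len // 32)   # 16 query tiles per (batch * head)
    new_m_tiles = -(-seq_len // 128)  # 4 query tiles per (batch * head)
    old_k_steps = -(-seq_len // 16)   # 32 key/value steps per query tile
    new_k_steps = -(-seq_len // 64)   # 8 key/value steps per query tile
    print(old_m_tiles, new_m_tiles, old_k_steps, new_k_steps)  # 16 4 32 8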
12 changes: 6 additions & 6 deletions test/test_matmul.py
@@ -8,7 +8,6 @@

import helion
from helion import Config
from helion._compat import get_triton_tensor_descriptor_class_import_path
from helion._compat import supports_tensor_descriptor
from helion._testing import DEVICE
from helion._testing import code_and_output
@@ -353,13 +352,13 @@ def test_matmul_tensor_descriptor(self):
code = examples_matmul.bind(args).to_triton_code(config)
self.assertExpectedInline(
code,
f"""\
"""\
from __future__ import annotations

import torch
import triton
import triton.language as tl
{get_triton_tensor_descriptor_class_import_path()}
from triton.tools.tensor_descriptor import TensorDescriptor

@triton.jit
def _matmul_kernel(x_desc, y_desc, out_desc, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
@@ -376,15 +375,16 @@ def _matmul_kernel(x_desc, y_desc, out_desc, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
for offset_2 in range(0, 128, _BLOCK_SIZE_2):
acc_copy = acc
acc_copy_0 = acc_copy
load = x_desc.load([offset_0, offset_2])
load_1 = y_desc.load([offset_2, offset_1])
acc = tl.dot(load, load_1, acc=acc_copy, input_precision='tf32')
acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
out_desc.store([offset_0, offset_1], acc)

def matmul(x: torch.Tensor, y: torch.Tensor):
m, k = x.size()
k2, n = y.size()
assert k == k2, f'size mismatch {{k}} != {{k2}}'
assert k == k2, f'size mismatch {k} != {k2}'
out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
_BLOCK_SIZE_0 = 16
_BLOCK_SIZE_1 = 16
@@ -395,7 +395,7 @@ def matmul(x: torch.Tensor, y: torch.Tensor):
def _matmul_make_precompiler(x: torch.Tensor, y: torch.Tensor):
m, k = x.size()
k2, n = y.size()
assert k == k2, f'size mismatch {{k}} != {{k2}}'
assert k == k2, f'size mismatch {k} != {k2}'
out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
_BLOCK_SIZE_0 = 16
_BLOCK_SIZE_1 = 16