Commit c86b278

Fix a performance issue with Helion-emitted Flash Attention (#181)
1 parent 9554e7e commit c86b278

4 files changed: +104 -50 lines changed

examples/attention.py

Lines changed: 3 additions & 3 deletions
@@ -12,9 +12,9 @@
 @helion.kernel(
     config=helion.Config(
         # This config was autotuned on a 3090, it won't be fast for other architectures
-        block_sizes=[32, 16],
-        num_warps=1,
-        num_stages=2,
+        block_sizes=[128, 64],
+        num_warps=4,
+        num_stages=3,
         indexing="block_ptr",
     ),
     # Static shapes provides a speedup for attention
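
One practical effect of the larger tiles is visible in the launch grid of the generated host code (see the updated expected output in test/test_examples.py below): the grid is q_in.size(1) * triton.cdiv(m_dim, _BLOCK_SIZE_1), so a 128-row query tile launches a quarter as many programs as a 32-row tile, each covering four times the query rows and stepping over the keys 64 at a time instead of 16. A tiny arithmetic sketch of that, using illustrative sizes only (512 for m_dim and 64 for q_in.size(1), consistent with the shapes in the expected Triton code below but not values asserted by this commit):

    import math

    q_heads, m_dim = 64, 512                     # illustrative sizes, not from the commit
    old_grid = q_heads * math.ceil(m_dim / 32)   # _BLOCK_SIZE_1 = 32  -> 1024 programs
    new_grid = q_heads * math.ceil(m_dim / 128)  # _BLOCK_SIZE_1 = 128 ->  256 programs
    print(old_grid, new_grid)                    # 1024 256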

helion/_compiler/inductor_lowering.py

Lines changed: 83 additions & 29 deletions
@@ -777,52 +777,106 @@ def apply_dot_requirements(
     return LambdaLowering(handler, masked_value_fn=masked_value_fn)


+def reduce_3d_dot(
+    ctx: GraphInterpreter, node: torch.fx.Node, with_acc: bool
+) -> ast.AST:
+    datatype = CompileEnvironment.current().settings.dot_precision
+    acc = None
+    if with_acc:
+        acc, lhs, rhs = map_arg(node.args, lambda arg: ctx.env[arg])
+        assert isinstance(acc, ast.AST)
+        lhs_node = node.args[1]
+        rhs_node = node.args[2]
+    else:
+        lhs, rhs = map_arg(node.args, lambda arg: ctx.env[arg])
+        lhs_node = node.args[0]
+        rhs_node = node.args[1]
+    assert isinstance(lhs, ast.AST)
+    assert isinstance(rhs, ast.AST)
+
+    lhs_size = lhs_node.meta["val"].size()
+    rhs_size = rhs_node.meta["val"].size()
+    # check to see if it is 3D and the highest dim is 1
+    reduce_dim = False
+    if len(lhs_size) == 3:
+        env = CompileEnvironment.current()
+        lhs_dim_idx = env.get_block_id(lhs_size[0])
+        rhs_dim_idx = env.get_block_id(rhs_size[0])
+        if lhs_dim_idx is not None and rhs_dim_idx is not None:
+            lhs_dim_val = env.block_sizes[lhs_dim_idx]
+            rhs_dim_val = env.block_sizes[rhs_dim_idx]
+            if (
+                lhs_dim_val.from_config(ctx.cg.device_function.config) == 1
+                and rhs_dim_val.from_config(ctx.cg.device_function.config) == 1
+            ):
+                reduce_dim = True
+
+    if not reduce_dim:
+        if with_acc:
+            return expr_from_string(
+                f"tl.dot(lhs, rhs, acc=acc, input_precision={datatype!r})",
+                lhs=lhs,
+                rhs=rhs,
+                acc=acc,
+            )
+        # without accumulator
+        return expr_from_string(
+            f"tl.dot(lhs, rhs, input_precision={datatype!r})", lhs=lhs, rhs=rhs
+        )
+
+    # create reshape, dot, then reshape
+    lhs_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
+        [*lhs_node.meta["val"].size()[1:]]
+    )
+    rhs_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
+        [*rhs_node.meta["val"].size()[1:]]
+    )
+    out_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
+        [*node.meta["val"].size()]
+    )
+    lhs_reshape = expr_from_string(f"tl.reshape(lhs, {lhs_shape_str})", lhs=lhs)
+    rhs_reshape = expr_from_string(f"tl.reshape(rhs, {rhs_shape_str})", rhs=rhs)
+    if with_acc:
+        acc_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
+            [*node.args[0].meta["val"].size()[1:]]
+        )
+        acc_reshape = expr_from_string(f"tl.reshape(rhs, {acc_shape_str})", rhs=acc)
+        comp = expr_from_string(
+            f"tl.dot(lhs, rhs, acc=acc, input_precision={datatype!r})",
+            lhs=lhs_reshape,
+            rhs=rhs_reshape,
+            acc=acc_reshape,
+        )
+    else:
+        comp = expr_from_string(
+            f"tl.dot(lhs, rhs, input_precision={datatype!r})",
+            lhs=lhs_reshape,
+            rhs=rhs_reshape,
+        )
+    return expr_from_string(f"tl.reshape(lhs, {out_shape_str})", lhs=comp)
+
+
 @register_lowering(torch.ops.aten.bmm.default, apply_dot_requirements)
 # pyre-fixme[56]
 @register_lowering(torch.ops.aten.mm.default, apply_dot_requirements)
 def codegen_mm(ctx: GraphInterpreter, node: torch.fx.Node) -> ast.AST:
     assert not node.kwargs, "matmul kwargs not supported"
-    lhs, rhs = map_arg(node.args, lambda arg: ctx.env[arg])
-    assert isinstance(lhs, ast.AST)
-    assert isinstance(rhs, ast.AST)
-    tf32 = CompileEnvironment.current().settings.dot_precision
-    return expr_from_string(
-        f"tl.dot(lhs, rhs, input_precision={tf32!r})", lhs=lhs, rhs=rhs
-    )
+
+    return reduce_3d_dot(ctx, node, False)


 # pyre-fixme[56]
 @register_lowering(torch.ops.aten.addmm.default, apply_dot_requirements)
 def codegen_addmm(ctx: GraphInterpreter, node: torch.fx.Node) -> ast.AST:
     assert not node.kwargs, "addmm kwargs not supported"
-    acc, lhs, rhs = map_arg(node.args, lambda arg: ctx.env[arg])
-    assert isinstance(acc, ast.AST)
-    assert isinstance(lhs, ast.AST)
-    assert isinstance(rhs, ast.AST)
-    tf32 = CompileEnvironment.current().settings.dot_precision
-    return expr_from_string(
-        f"tl.dot(lhs, rhs, acc=acc, input_precision={tf32!r})",
-        lhs=lhs,
-        rhs=rhs,
-        acc=acc,
-    )
+    return reduce_3d_dot(ctx, node, True)


 # pyre-fixme[56]
 @register_lowering(torch.ops.aten.baddbmm.default, apply_dot_requirements)
 def codegen_baddbmm(ctx: GraphInterpreter, node: torch.fx.Node) -> ast.AST:
     assert not node.kwargs, "baddbmm kwargs not supported"
-    acc, lhs, rhs = map_arg(node.args, lambda arg: ctx.env[arg])
-    assert isinstance(acc, ast.AST)
-    assert isinstance(lhs, ast.AST)
-    assert isinstance(rhs, ast.AST)
-    tf32 = CompileEnvironment.current().settings.dot_precision
-    return expr_from_string(
-        f"tl.dot(lhs, rhs, acc=acc, input_precision={tf32!r})",
-        lhs=lhs,
-        rhs=rhs,
-        acc=acc,
-    )
+    return reduce_3d_dot(ctx, node, True)


 class GenerateASTFromInductor(DefaultHandler):
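
The new reduce_3d_dot helper checks whether both operands of a 3D dot have a leading block dimension that the chosen config pins to 1; if so, it emits a 2D tl.dot wrapped in tl.reshape calls instead of a batch-of-one 3D tl.dot. The rewrite is purely shape bookkeeping, as the following minimal PyTorch sketch of the underlying identity illustrates (this is not Helion or Triton code, just the algebra the lowering relies on; the sizes are illustrative):

    import torch

    m, k, n = 128, 64, 64                      # illustrative tile sizes
    lhs = torch.randn(1, m, k)                 # leading (batch) block dimension of 1
    rhs = torch.randn(1, k, n)

    out_3d = torch.bmm(lhs, rhs)               # the batch-of-one dot the old lowering emitted
    out_2d = (lhs.reshape(m, k) @ rhs.reshape(k, n)).reshape(1, m, n)  # reshape -> dot -> reshape
    torch.testing.assert_close(out_3d, out_2d)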

test/test_examples.py

Lines changed: 12 additions & 12 deletions
@@ -1144,7 +1144,7 @@ def _attention_kernel(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
         l_i_copy_0 = l_i_copy
         acc_copy_0 = acc_copy
         k = tl.load(k_view + (indices_0[:, None, None] * 32768 + indices_4[None, :, None] * 1 + indices_2[None, None, :] * 64), None)
-        qk = tl.dot(q_copy_0, k, input_precision='tf32')
+        qk = tl.reshape(tl.dot(tl.reshape(q_copy_0, [_BLOCK_SIZE_1, 64]), tl.reshape(k, [64, _BLOCK_SIZE_3]), input_precision='tf32'), [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3])
         amax = tl.max(qk, 2)
         v_0 = 0.18033688
         v_1 = amax * v_0
@@ -1162,7 +1162,7 @@ def _attention_kernel(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
         subscript_1 = v_8[:, :, None]
         v_11 = acc_copy_0 * subscript_1
         v = tl.load(v_view + (indices_0[:, None, None] * 32768 + indices_2[None, :, None] * 64 + indices_4[None, None, :] * 1), None)
-        acc = tl.dot(v_6, v, acc=v_11, input_precision='tf32')
+        acc = tl.reshape(tl.dot(tl.reshape(v_6, [_BLOCK_SIZE_1, _BLOCK_SIZE_3]), tl.reshape(v, [_BLOCK_SIZE_3, 64]), acc=tl.reshape(v_11, [_BLOCK_SIZE_1, 64]), input_precision='tf32'), [1, _BLOCK_SIZE_1, 64])
     subscript_2 = l_i[:, :, None]
     v_12 = acc / subscript_2
     tl.store(out + (indices_0[:, None, None] * 32768 + indices_1[None, :, None] * 64 + indices_4[None, None, :] * 1), v_12, None)
@@ -1249,7 +1249,7 @@ def _attention_kernel(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
         l_i_copy_0 = l_i_copy
         acc_copy_0 = acc_copy
         k = tl.load(tl.make_block_ptr(k_view, [64, 64, 512], [32768, 1, 64], [offset_0, 0, offset_2], [1, 64, _BLOCK_SIZE_3], [2, 0, 1]), boundary_check=[0, 1, 2], padding_option='zero')
-        qk = tl.dot(q_copy_0, k, input_precision='tf32')
+        qk = tl.reshape(tl.dot(tl.reshape(q_copy_0, [_BLOCK_SIZE_1, 64]), tl.reshape(k, [64, _BLOCK_SIZE_3]), input_precision='tf32'), [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3])
         amax = tl.max(qk, 2)
         v_0 = tl.full([], 0.18033688, tl.float16)
         v_1 = amax * v_0
@@ -1270,7 +1270,7 @@ def _attention_kernel(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
         v_13 = acc_copy_0 * subscript_1
         v = tl.load(tl.make_block_ptr(v_view, [64, 512, 64], [32768, 64, 1], [offset_0, offset_2, 0], [1, _BLOCK_SIZE_3, 64], [2, 1, 0]), boundary_check=[0, 1, 2], padding_option='zero')
         v_14 = v_8.to(tl.float16)
-        acc = tl.dot(v_14, v, acc=v_13, input_precision='tf32')
+        acc = tl.reshape(tl.dot(tl.reshape(v_14, [_BLOCK_SIZE_1, _BLOCK_SIZE_3]), tl.reshape(v, [_BLOCK_SIZE_3, 64]), acc=tl.reshape(v_13, [_BLOCK_SIZE_1, 64]), input_precision='tf32'), [1, _BLOCK_SIZE_1, 64])
     subscript_2 = l_i[:, :, None]
     v_15 = acc / subscript_2
     v_16 = v_15.to(tl.float16)
@@ -1361,7 +1361,7 @@ def _attention_kernel(q_view, k_view, v_view, out, k_view_size_0, k_view_size_2,
         l_i_copy_0 = l_i_copy
         acc_copy_0 = acc_copy
         k = tl.load(tl.make_block_ptr(k_view, [k_view_size_0, 64, k_view_size_2], [k_view_stride_0, k_view_stride_1, k_view_stride_2], [offset_0, 0, offset_2], [1, 64, _BLOCK_SIZE_3], [2, 0, 1]), boundary_check=[0, 1, 2], padding_option='zero')
-        qk = tl.dot(q_copy_0, k, input_precision='tf32')
+        qk = tl.reshape(tl.dot(tl.reshape(q_copy_0, [_BLOCK_SIZE_1, 64]), tl.reshape(k, [64, _BLOCK_SIZE_3]), input_precision='tf32'), [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3])
         _mask_to_2 = tl.where(tl.broadcast_to(mask_1[None, :, None] & mask_3[None, None, :], [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3]), qk, float('-inf'))
         amax = tl.max(_mask_to_2, 2)
         v_0 = 0.18033688
@@ -1381,7 +1381,7 @@ def _attention_kernel(q_view, k_view, v_view, out, k_view_size_0, k_view_size_2,
         subscript_1 = v_8[:, :, None]
         v_11 = acc_copy_0 * subscript_1
         v = tl.load(tl.make_block_ptr(v_view, [v_view_size_0, v_view_size_1, 64], [v_view_stride_0, v_view_stride_1, v_view_stride_2], [offset_0, offset_2, 0], [1, _BLOCK_SIZE_3, 64], [2, 1, 0]), boundary_check=[0, 1, 2], padding_option='zero')
-        acc = tl.dot(_mask_to_3, v, acc=v_11, input_precision='tf32')
+        acc = tl.reshape(tl.dot(tl.reshape(_mask_to_3, [_BLOCK_SIZE_1, _BLOCK_SIZE_3]), tl.reshape(v, [_BLOCK_SIZE_3, 64]), acc=tl.reshape(v_11, [_BLOCK_SIZE_1, 64]), input_precision='tf32'), [1, _BLOCK_SIZE_1, 64])
     subscript_2 = l_i[:, :, None]
     v_12 = acc / subscript_2
     tl.store(tl.make_block_ptr(out, [out_size_0, out_size_1, 64], [out_stride_0, out_stride_1, out_stride_2], [offset_0, offset_1, 0], [1, _BLOCK_SIZE_1, 64], [2, 1, 0]), v_12, boundary_check=[0, 1, 2])
@@ -1398,10 +1398,10 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor):
     out = torch.empty_like(q_view)
     sm_scale = 1.0 / math.sqrt(head_dim)
     qk_scale = sm_scale * 1.44269504
-    _BLOCK_SIZE_1 = 32
+    _BLOCK_SIZE_1 = 128
     _RDIM_SIZE_2 = 64
-    _BLOCK_SIZE_3 = 16
-    _attention_kernel[q_in.size(1) * triton.cdiv(m_dim, _BLOCK_SIZE_1),](q_view, k_view, v_view, out, k_view.size(0), k_view.size(2), out.size(0), out.size(1), q_in.size(1), q_view.size(0), q_view.size(1), v_view.size(0), v_view.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=1, num_stages=2)
+    _BLOCK_SIZE_3 = 64
+    _attention_kernel[q_in.size(1) * triton.cdiv(m_dim, _BLOCK_SIZE_1),](q_view, k_view, v_view, out, k_view.size(0), k_view.size(2), out.size(0), out.size(1), q_in.size(1), q_view.size(0), q_view.size(1), v_view.size(0), v_view.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
     return out.view(q_in.size())

 def _attention_make_precompiler(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor):
@@ -1416,11 +1416,11 @@ def _attention_make_precompiler(q_in: torch.Tensor, k_in: torch.Tensor, v_in: to
     out = torch.empty_like(q_view)
     sm_scale = 1.0 / math.sqrt(head_dim)
    qk_scale = sm_scale * 1.44269504
-    _BLOCK_SIZE_1 = 32
+    _BLOCK_SIZE_1 = 128
     _RDIM_SIZE_2 = 64
-    _BLOCK_SIZE_3 = 16
+    _BLOCK_SIZE_3 = 64
     from helion.runtime.precompile_shim import make_precompiler
-    return make_precompiler(_attention_kernel)(q_view, k_view, v_view, out, k_view.size(0), k_view.size(2), out.size(0), out.size(1), q_in.size(1), q_view.size(0), q_view.size(1), v_view.size(0), v_view.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=1, num_stages=2)""",
+    return make_precompiler(_attention_kernel)(q_view, k_view, v_view, out, k_view.size(0), k_view.size(2), out.size(0), out.size(1), q_in.size(1), q_view.size(0), q_view.size(1), v_view.size(0), v_view.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)""",
         )

     def test_concat(self):
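
The accumulating variants in the expected output (the acc = tl.reshape(tl.dot(..., acc=tl.reshape(...)), ...) lines) follow the same pattern, with the running accumulator reshaped to 2D as well. A minimal PyTorch sketch of that identity, using the tile sizes from the updated host code (_BLOCK_SIZE_1 = 128, _BLOCK_SIZE_3 = 64, head_dim = 64); again this is only the shape algebra, not Helion or Triton code, and the variable names are illustrative:

    import torch

    b1, b3, d = 128, 64, 64
    p = torch.randn(1, b1, b3)     # attention-weight tile (lhs of the second dot)
    v = torch.randn(1, b3, d)      # value tile
    acc = torch.randn(1, b1, d)    # running accumulator

    out_3d = torch.baddbmm(acc, p, v)  # acc + p @ v, the batch-of-one form
    out_2d = (acc.reshape(b1, d) + p.reshape(b1, b3) @ v.reshape(b3, d)).reshape(1, b1, d)
    torch.testing.assert_close(out_3d, out_2d)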

test/test_matmul.py

Lines changed: 6 additions & 6 deletions
@@ -8,7 +8,6 @@

 import helion
 from helion import Config
-from helion._compat import get_triton_tensor_descriptor_class_import_path
 from helion._compat import supports_tensor_descriptor
 from helion._testing import DEVICE
 from helion._testing import code_and_output
@@ -353,13 +352,13 @@ def test_matmul_tensor_descriptor(self):
         code = examples_matmul.bind(args).to_triton_code(config)
         self.assertExpectedInline(
             code,
-            f"""\
+            """\
 from __future__ import annotations

 import torch
 import triton
 import triton.language as tl
-{get_triton_tensor_descriptor_class_import_path()}
+from triton.tools.tensor_descriptor import TensorDescriptor

 @triton.jit
 def _matmul_kernel(x_desc, y_desc, out_desc, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
@@ -376,15 +375,16 @@ def _matmul_kernel(x_desc, y_desc, out_desc, _BLOCK_SIZE_0: tl.constexpr, _BLOCK
     acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
     for offset_2 in range(0, 128, _BLOCK_SIZE_2):
         acc_copy = acc
+        acc_copy_0 = acc_copy
         load = x_desc.load([offset_0, offset_2])
         load_1 = y_desc.load([offset_2, offset_1])
-        acc = tl.dot(load, load_1, acc=acc_copy, input_precision='tf32')
+        acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
     out_desc.store([offset_0, offset_1], acc)

 def matmul(x: torch.Tensor, y: torch.Tensor):
     m, k = x.size()
     k2, n = y.size()
-    assert k == k2, f'size mismatch {{k}} != {{k2}}'
+    assert k == k2, f'size mismatch {k} != {k2}'
     out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
     _BLOCK_SIZE_0 = 16
     _BLOCK_SIZE_1 = 16
@@ -395,7 +395,7 @@ def matmul(x: torch.Tensor, y: torch.Tensor):
 def _matmul_make_precompiler(x: torch.Tensor, y: torch.Tensor):
     m, k = x.size()
     k2, n = y.size()
-    assert k == k2, f'size mismatch {{k}} != {{k2}}'
+    assert k == k2, f'size mismatch {k} != {k2}'
     out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
     _BLOCK_SIZE_0 = 16
     _BLOCK_SIZE_1 = 16
