diff --git a/examples/attention.py b/examples/attention.py
index 14698e45..87811295 100644
--- a/examples/attention.py
+++ b/examples/attention.py
@@ -12,9 +12,9 @@
 @helion.kernel(
     config=helion.Config(
         # This config was autotuned on a 3090, it won't be fast for other architectures
-        block_sizes=[32, 16],
-        num_warps=1,
-        num_stages=2,
+        block_sizes=[128, 64],
+        num_warps=4,
+        num_stages=3,
         indexing="block_ptr",
     ),
     # Static shapes provides a speedup for attention
diff --git a/helion/_compiler/inductor_lowering.py b/helion/_compiler/inductor_lowering.py
index 1eea5deb..3efcae99 100644
--- a/helion/_compiler/inductor_lowering.py
+++ b/helion/_compiler/inductor_lowering.py
@@ -777,52 +777,122 @@ def apply_dot_requirements(
     return LambdaLowering(handler, masked_value_fn=masked_value_fn)
 
 
+def reduce_3d_dot(
+    ctx: GraphInterpreter, node: torch.fx.Node, with_acc: bool
+) -> ast.AST:
+    """Emit ``tl.dot`` for a matmul node, flattening a size-1 batch dim.
+
+    If the lhs is 3D and the block sizes configured for the leading
+    dimension of both operands are 1, the operands (and the accumulator,
+    if any) are reshaped to 2D, multiplied with a 2D ``tl.dot``, and the
+    result is reshaped back to the node's output shape. Otherwise a plain
+    ``tl.dot`` on the unmodified operands is emitted.
+
+    Args:
+        ctx: Interpreter whose ``env`` maps FX nodes to generated AST.
+        node: FX node with args ``(lhs, rhs)`` when ``with_acc`` is false,
+            or ``(acc, lhs, rhs)`` when it is true.
+        with_acc: Whether ``node.args[0]`` is an accumulator.
+
+    Returns:
+        The AST expression computing the dot product.
+    """
+    env = CompileEnvironment.current()
+    precision = env.settings.dot_precision
+    acc = None
+    if with_acc:
+        acc, lhs, rhs = map_arg(node.args, lambda arg: ctx.env[arg])
+        assert isinstance(acc, ast.AST)
+        lhs_node = node.args[1]
+        rhs_node = node.args[2]
+    else:
+        lhs, rhs = map_arg(node.args, lambda arg: ctx.env[arg])
+        lhs_node = node.args[0]
+        rhs_node = node.args[1]
+    assert isinstance(lhs, ast.AST)
+    assert isinstance(rhs, ast.AST)
+
+    lhs_size = lhs_node.meta["val"].size()
+    rhs_size = rhs_node.meta["val"].size()
+    # Flatten only when the lhs is 3D and both leading block sizes are 1.
+    reduce_dim = False
+    if len(lhs_size) == 3:
+        lhs_dim_idx = env.get_block_id(lhs_size[0])
+        rhs_dim_idx = env.get_block_id(rhs_size[0])
+        if lhs_dim_idx is not None and rhs_dim_idx is not None:
+            lhs_dim_val = env.block_sizes[lhs_dim_idx]
+            rhs_dim_val = env.block_sizes[rhs_dim_idx]
+            if (
+                lhs_dim_val.from_config(ctx.cg.device_function.config) == 1
+                and rhs_dim_val.from_config(ctx.cg.device_function.config) == 1
+            ):
+                reduce_dim = True
+
+    if not reduce_dim:
+        if with_acc:
+            return expr_from_string(
+                f"tl.dot(lhs, rhs, acc=acc, input_precision={precision!r})",
+                lhs=lhs,
+                rhs=rhs,
+                acc=acc,
+            )
+        # without accumulator
+        return expr_from_string(
+            f"tl.dot(lhs, rhs, input_precision={precision!r})", lhs=lhs, rhs=rhs
+        )
+
+    # Reshape operands to 2D, dot, then reshape back to the output shape.
+    lhs_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
+        [*lhs_node.meta["val"].size()[1:]]
+    )
+    rhs_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
+        [*rhs_node.meta["val"].size()[1:]]
+    )
+    out_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
+        [*node.meta["val"].size()]
+    )
+    lhs_reshape = expr_from_string(f"tl.reshape(lhs, {lhs_shape_str})", lhs=lhs)
+    rhs_reshape = expr_from_string(f"tl.reshape(rhs, {rhs_shape_str})", rhs=rhs)
+    if with_acc:
+        acc_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
+            [*node.args[0].meta["val"].size()[1:]]
+        )
+        acc_reshape = expr_from_string(f"tl.reshape(acc, {acc_shape_str})", acc=acc)
+        comp = expr_from_string(
+            f"tl.dot(lhs, rhs, acc=acc, input_precision={precision!r})",
+            lhs=lhs_reshape,
+            rhs=rhs_reshape,
+            acc=acc_reshape,
+        )
+    else:
+        comp = expr_from_string(
+            f"tl.dot(lhs, rhs, input_precision={precision!r})",
+            lhs=lhs_reshape,
+            rhs=rhs_reshape,
+        )
+    return expr_from_string(f"tl.reshape(lhs, {out_shape_str})", lhs=comp)
+
+
 @register_lowering(torch.ops.aten.bmm.default, apply_dot_requirements)
 # pyre-fixme[56]
 @register_lowering(torch.ops.aten.mm.default, apply_dot_requirements)
 def codegen_mm(ctx: GraphInterpreter, node: torch.fx.Node) -> ast.AST:
     assert not node.kwargs, "matmul kwargs not supported"
-    lhs, rhs = map_arg(node.args, lambda arg: ctx.env[arg])
-    assert isinstance(lhs, ast.AST)
-    assert isinstance(rhs, ast.AST)
-    tf32 = CompileEnvironment.current().settings.dot_precision
-    return expr_from_string(
-        f"tl.dot(lhs, rhs, input_precision={tf32!r})", lhs=lhs, rhs=rhs
-    )
+    return reduce_3d_dot(ctx, node, False)
 
 
 # pyre-fixme[56]
 @register_lowering(torch.ops.aten.addmm.default, apply_dot_requirements)
 def codegen_addmm(ctx: GraphInterpreter, node: torch.fx.Node) -> ast.AST:
     assert not node.kwargs, "addmm kwargs not supported"
-    acc, lhs, rhs = map_arg(node.args, lambda arg: ctx.env[arg])
-    assert isinstance(acc, ast.AST)
-    assert isinstance(lhs, ast.AST)
-    assert isinstance(rhs, ast.AST)
-    tf32 = CompileEnvironment.current().settings.dot_precision
-    return expr_from_string(
-        f"tl.dot(lhs, rhs, acc=acc, input_precision={tf32!r})",
-        lhs=lhs,
-        rhs=rhs,
-        acc=acc,
-    )
+    return reduce_3d_dot(ctx, node, True)
 
 
 # pyre-fixme[56]
 @register_lowering(torch.ops.aten.baddbmm.default, apply_dot_requirements)
 def codegen_baddbmm(ctx: GraphInterpreter, node: torch.fx.Node) -> ast.AST:
     assert not node.kwargs, "baddbmm kwargs not supported"
-    acc, lhs, rhs = map_arg(node.args, lambda arg: ctx.env[arg])
-    assert isinstance(acc, ast.AST)
-    assert isinstance(lhs, ast.AST)
-    assert isinstance(rhs, ast.AST)
-    tf32 = CompileEnvironment.current().settings.dot_precision
-    return expr_from_string(
-        f"tl.dot(lhs, rhs, acc=acc, input_precision={tf32!r})",
-        lhs=lhs,
-        rhs=rhs,
-        acc=acc,
-    )
+    return reduce_3d_dot(ctx, node, True)
 
 
 class GenerateASTFromInductor(DefaultHandler):
diff --git a/test/test_examples.py b/test/test_examples.py
index 0bd6578c..ec7f7e20 100644
--- a/test/test_examples.py
+++ b/test/test_examples.py
@@ -1144,7 +1144,7 @@ def _attention_kernel(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
         l_i_copy_0 = l_i_copy
         acc_copy_0 = acc_copy
         k = tl.load(k_view + (indices_0[:, None, None] * 32768 + indices_4[None, :, None] * 1 + indices_2[None, None, :] * 64), None)
-        qk = tl.dot(q_copy_0, k, input_precision='tf32')
+        qk = tl.reshape(tl.dot(tl.reshape(q_copy_0, [_BLOCK_SIZE_1, 64]), tl.reshape(k, [64, _BLOCK_SIZE_3]), input_precision='tf32'), [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3])
         amax = tl.max(qk, 2)
         v_0 = 0.18033688
         v_1 = amax * v_0
@@ -1162,7 +1162,7 @@ def _attention_kernel(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
         subscript_1 = v_8[:, :, None]
         v_11 = acc_copy_0 * subscript_1
         v = tl.load(v_view + (indices_0[:, None, None] * 32768 + indices_2[None, :, None] * 64 + indices_4[None, None, :] * 1), None)
-        acc = tl.dot(v_6, v, acc=v_11, input_precision='tf32')
+        acc = tl.reshape(tl.dot(tl.reshape(v_6, [_BLOCK_SIZE_1, _BLOCK_SIZE_3]), tl.reshape(v, [_BLOCK_SIZE_3, 64]), acc=tl.reshape(v_11, [_BLOCK_SIZE_1, 64]), input_precision='tf32'), [1, _BLOCK_SIZE_1, 64])
     subscript_2 = l_i[:, :, None]
     v_12 = acc / subscript_2
     tl.store(out + (indices_0[:, None, None] * 32768 + indices_1[None, :, None] * 64 + indices_4[None, None, :] * 1), v_12, None)
@@ -1249,7 +1249,7 @@ def _attention_kernel(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
         l_i_copy_0 = l_i_copy
         acc_copy_0 = acc_copy
         k = tl.load(tl.make_block_ptr(k_view, [64, 64, 512], [32768, 1, 64], [offset_0, 0, offset_2], [1, 64, _BLOCK_SIZE_3], [2, 0, 1]), boundary_check=[0, 1, 2], padding_option='zero')
-        qk = tl.dot(q_copy_0, k, input_precision='tf32')
+        qk = tl.reshape(tl.dot(tl.reshape(q_copy_0, [_BLOCK_SIZE_1, 64]), tl.reshape(k, [64, _BLOCK_SIZE_3]), input_precision='tf32'), [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3])
         amax = tl.max(qk, 2)
         v_0 = tl.full([], 0.18033688, tl.float16)
         v_1 = amax * v_0
@@ -1270,7 +1270,7 @@ def _attention_kernel(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
         v_13 = acc_copy_0 * subscript_1
         v = tl.load(tl.make_block_ptr(v_view, [64, 512, 64], [32768, 64, 1], [offset_0, offset_2, 0], [1, _BLOCK_SIZE_3, 64], [2, 1, 0]), boundary_check=[0, 1, 2], padding_option='zero')
         v_14 = v_8.to(tl.float16)
-        acc = tl.dot(v_14, v, acc=v_13, input_precision='tf32')
+        acc = tl.reshape(tl.dot(tl.reshape(v_14, [_BLOCK_SIZE_1, _BLOCK_SIZE_3]), tl.reshape(v, [_BLOCK_SIZE_3, 64]), acc=tl.reshape(v_13, [_BLOCK_SIZE_1, 64]), input_precision='tf32'), [1, _BLOCK_SIZE_1, 64])
     subscript_2 = l_i[:, :, None]
     v_15 = acc / subscript_2
     v_16 = v_15.to(tl.float16)
@@ -1361,7 +1361,7 @@ def _attention_kernel(q_view, k_view, v_view, out, k_view_size_0, k_view_size_2,
         l_i_copy_0 = l_i_copy
         acc_copy_0 = acc_copy
         k = tl.load(tl.make_block_ptr(k_view, [k_view_size_0, 64, k_view_size_2], [k_view_stride_0, k_view_stride_1, k_view_stride_2], [offset_0, 0, offset_2], [1, 64, _BLOCK_SIZE_3], [2, 0, 1]), boundary_check=[0, 1, 2], padding_option='zero')
-        qk = tl.dot(q_copy_0, k, input_precision='tf32')
+        qk = tl.reshape(tl.dot(tl.reshape(q_copy_0, [_BLOCK_SIZE_1, 64]), tl.reshape(k, [64, _BLOCK_SIZE_3]), input_precision='tf32'), [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3])
         _mask_to_2 = tl.where(tl.broadcast_to(mask_1[None, :, None] & mask_3[None, None, :], [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3]), qk, float('-inf'))
         amax = tl.max(_mask_to_2, 2)
         v_0 = 0.18033688
@@ -1381,7 +1381,7 @@ def _attention_kernel(q_view, k_view, v_view, out, k_view_size_0, k_view_size_2,
         subscript_1 = v_8[:, :, None]
         v_11 = acc_copy_0 * subscript_1
         v = tl.load(tl.make_block_ptr(v_view, [v_view_size_0, v_view_size_1, 64], [v_view_stride_0, v_view_stride_1, v_view_stride_2], [offset_0, offset_2, 0], [1, _BLOCK_SIZE_3, 64], [2, 1, 0]), boundary_check=[0, 1, 2], padding_option='zero')
-        acc = tl.dot(_mask_to_3, v, acc=v_11, input_precision='tf32')
+        acc = tl.reshape(tl.dot(tl.reshape(_mask_to_3, [_BLOCK_SIZE_1, _BLOCK_SIZE_3]), tl.reshape(v, [_BLOCK_SIZE_3, 64]), acc=tl.reshape(v_11, [_BLOCK_SIZE_1, 64]), input_precision='tf32'), [1, _BLOCK_SIZE_1, 64])
     subscript_2 = l_i[:, :, None]
     v_12 = acc / subscript_2
     tl.store(tl.make_block_ptr(out, [out_size_0, out_size_1, 64], [out_stride_0, out_stride_1, out_stride_2], [offset_0, offset_1, 0], [1, _BLOCK_SIZE_1, 64], [2, 1, 0]), v_12, boundary_check=[0, 1, 2])
@@ -1398,10 +1398,10 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor):
     out = torch.empty_like(q_view)
     sm_scale = 1.0 / math.sqrt(head_dim)
     qk_scale = sm_scale * 1.44269504
-    _BLOCK_SIZE_1 = 32
+    _BLOCK_SIZE_1 = 128
     _RDIM_SIZE_2 = 64
-    _BLOCK_SIZE_3 = 16
-    _attention_kernel[q_in.size(1) * triton.cdiv(m_dim, _BLOCK_SIZE_1),](q_view, k_view, v_view, out, k_view.size(0), k_view.size(2), out.size(0), out.size(1), q_in.size(1), q_view.size(0), q_view.size(1), v_view.size(0), v_view.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=1, num_stages=2)
+    _BLOCK_SIZE_3 = 64
+    _attention_kernel[q_in.size(1) * triton.cdiv(m_dim, _BLOCK_SIZE_1),](q_view, k_view, v_view, out, k_view.size(0), k_view.size(2), out.size(0), out.size(1), q_in.size(1), q_view.size(0), q_view.size(1), v_view.size(0), v_view.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
     return out.view(q_in.size())
 
 def _attention_make_precompiler(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor):
@@ -1416,11 +1416,11 @@ def _attention_make_precompiler(q_in: torch.Tensor, k_in: torch.Tensor, v_in: to
     out = torch.empty_like(q_view)
     sm_scale = 1.0 / math.sqrt(head_dim)
     qk_scale = sm_scale * 1.44269504
-    _BLOCK_SIZE_1 = 32
+    _BLOCK_SIZE_1 = 128
     _RDIM_SIZE_2 = 64
-    _BLOCK_SIZE_3 = 16
+    _BLOCK_SIZE_3 = 64
     from helion.runtime.precompile_shim import make_precompiler
-    return make_precompiler(_attention_kernel)(q_view, k_view, v_view, out, k_view.size(0), k_view.size(2), out.size(0), out.size(1), q_in.size(1), q_view.size(0), q_view.size(1), v_view.size(0), v_view.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=1, num_stages=2)""",
+    return make_precompiler(_attention_kernel)(q_view, k_view, v_view, out, k_view.size(0), k_view.size(2), out.size(0), out.size(1), q_in.size(1), q_view.size(0), q_view.size(1), v_view.size(0), v_view.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)""",
         )
 
     def test_concat(self):
diff --git a/test/test_matmul.py b/test/test_matmul.py
index 5033a347..63734aa8 100644
--- a/test/test_matmul.py
+++ b/test/test_matmul.py
@@ -8,7 +8,6 @@
 
 import helion
 from helion import Config
-from helion._compat import get_triton_tensor_descriptor_class_import_path
 from helion._compat import supports_tensor_descriptor
 from helion._testing import DEVICE
 from helion._testing import code_and_output
@@ -353,13 +352,13 @@ def test_matmul_tensor_descriptor(self):
         code = examples_matmul.bind(args).to_triton_code(config)
         self.assertExpectedInline(
             code,
-            f"""\
+            """\
 from __future__ import annotations
 
 import torch
 import triton
 import triton.language as tl
-{get_triton_tensor_descriptor_class_import_path()}
+from triton.tools.tensor_descriptor import TensorDescriptor
 
 @triton.jit
 def _matmul_kernel(x_desc, y_desc, out_desc, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
@@ -376,15 +375,16 @@ def _matmul_kernel(x_desc, y_desc, out_desc, _BLOCK_SIZE_0: tl.constexpr, _BLOCK
     acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
     for offset_2 in range(0, 128, _BLOCK_SIZE_2):
         acc_copy = acc
+        acc_copy_0 = acc_copy
         load = x_desc.load([offset_0, offset_2])
         load_1 = y_desc.load([offset_2, offset_1])
-        acc = tl.dot(load, load_1, acc=acc_copy, input_precision='tf32')
+        acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
     out_desc.store([offset_0, offset_1], acc)
 
 def matmul(x: torch.Tensor, y: torch.Tensor):
     m, k = x.size()
     k2, n = y.size()
-    assert k == k2, f'size mismatch {{k}} != {{k2}}'
+    assert k == k2, f'size mismatch {k} != {k2}'
     out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
     _BLOCK_SIZE_0 = 16
     _BLOCK_SIZE_1 = 16
@@ -395,7 +395,7 @@ def matmul(x: torch.Tensor, y: torch.Tensor):
 def _matmul_make_precompiler(x: torch.Tensor, y: torch.Tensor):
     m, k = x.size()
     k2, n = y.size()
-    assert k == k2, f'size mismatch {{k}} != {{k2}}'
+    assert k == k2, f'size mismatch {k} != {k2}'
     out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
     _BLOCK_SIZE_0 = 16
     _BLOCK_SIZE_1 = 16