Commit d1134cf (1 parent: fead68f)

Commit message: refactoring

This commit deduplicates the matmul lowerings: codegen_mm, codegen_addmm, and codegen_baddbmm now share a single reduce_3d_dot helper instead of carrying three near-identical copies of the reshape/dot logic.

File tree: 1 file changed (+55 additions, -79 deletions)
helion/_compiler/inductor_lowering.py (+55, -79)
@@ -777,17 +777,23 @@ def apply_dot_requirements(
     return LambdaLowering(handler, masked_value_fn=masked_value_fn)


-@register_lowering(torch.ops.aten.bmm.default, apply_dot_requirements)
-# pyre-fixme[56]
-@register_lowering(torch.ops.aten.mm.default, apply_dot_requirements)
-def codegen_mm(ctx: GraphInterpreter, node: torch.fx.Node) -> ast.AST:
-    assert not node.kwargs, "matmul kwargs not supported"
-    lhs, rhs = map_arg(node.args, lambda arg: ctx.env[arg])
+def reduce_3d_dot(ctx: GraphInterpreter, node: torch.fx.Node, withAcc: bool) -> ast.AST:
+    datatype = CompileEnvironment.current().settings.dot_precision
+    if withAcc:
+        acc, lhs, rhs = map_arg(node.args, lambda arg: ctx.env[arg])
+        assert isinstance(acc, ast.AST)
+        lhsNode = node.args[1]
+        rhsNode = node.args[2]
+    else:
+        lhs, rhs = map_arg(node.args, lambda arg: ctx.env[arg])
+        lhsNode = node.args[0]
+        rhsNode = node.args[1]
     assert isinstance(lhs, ast.AST)
     assert isinstance(rhs, ast.AST)
-    lhsSize = node.args[0].meta["val"].size()
-    rhsSize = node.args[1].meta["val"].size()
-    # check to see if it is 3D
+
+    lhsSize = lhsNode.meta["val"].size()
+    rhsSize = rhsNode.meta["val"].size()
+    # check to see if it is 3D and the highest dim is 1
     reduceDim = False
     if len(lhsSize) == 3:
         env = CompileEnvironment.current()
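
(The withAcc flag exists because the accumulator-carrying ops place the accumulator first: aten.mm/aten.bmm have node.args == (lhs, rhs), while aten.addmm/aten.baddbmm have node.args == (acc, lhs, rhs). A minimal, illustrative sketch of that dispatch, using a hypothetical split_args helper rather than anything from the commit:

# Hypothetical helper mirroring the arg-unpacking branches in reduce_3d_dot.
# mm / bmm:        args == (lhs, rhs)
# addmm / baddbmm: args == (acc, lhs, rhs)
def split_args(args, with_acc):
    if with_acc:
        acc, lhs, rhs = args
    else:
        lhs, rhs = args
        acc = None
    return acc, lhs, rhs

assert split_args(("acc", "lhs", "rhs"), True) == ("acc", "lhs", "rhs")
assert split_args(("lhs", "rhs"), False) == (None, "lhs", "rhs")

)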
@@ -802,102 +808,72 @@ def codegen_mm(ctx: GraphInterpreter, node: torch.fx.Node) -> ast.AST:
             ):
                 reduceDim = True

-    tf32 = CompileEnvironment.current().settings.dot_precision
     if not reduceDim:
+        if withAcc:
+            return expr_from_string(
+                f"tl.dot(lhs, rhs, acc=acc, input_precision={datatype!r})",
+                lhs=lhs,
+                rhs=rhs,
+                acc=acc,
+            )
+        # without accumulator
         return expr_from_string(
-            f"tl.dot(lhs, rhs, input_precision={tf32!r})", lhs=lhs, rhs=rhs
+            f"tl.dot(lhs, rhs, input_precision={datatype!r})", lhs=lhs, rhs=rhs
         )
+
     # create reshape, dot, then reshape
     lhs_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
-        [*node.args[0].meta["val"].size()[1:]]
+        [*lhsNode.meta["val"].size()[1:]]
     )
     rhs_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
-        [*node.args[1].meta["val"].size()[1:]]
+        [*rhsNode.meta["val"].size()[1:]]
     )
     out_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
         [*node.meta["val"].size()]
     )
     lhs_reshape = expr_from_string(f"tl.reshape(lhs, {lhs_shape_str})", lhs=lhs)
     rhs_reshape = expr_from_string(f"tl.reshape(rhs, {rhs_shape_str})", rhs=rhs)
-    comp = expr_from_string(
-        f"tl.dot(lhs, rhs, input_precision={tf32!r})", lhs=lhs_reshape, rhs=rhs_reshape
-    )
+    if withAcc:
+        acc_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
+            [*node.args[0].meta["val"].size()[1:]]
+        )
+        acc_reshape = expr_from_string(f"tl.reshape(rhs, {acc_shape_str})", rhs=acc)
+        comp = expr_from_string(
+            f"tl.dot(lhs, rhs, acc=acc, input_precision={datatype!r})",
+            lhs=lhs_reshape,
+            rhs=rhs_reshape,
+            acc=acc_reshape,
+        )
+    else:
+        comp = expr_from_string(
+            f"tl.dot(lhs, rhs, input_precision={datatype!r})",
+            lhs=lhs_reshape,
+            rhs=rhs_reshape,
+        )
     return expr_from_string(f"tl.reshape(lhs, {out_shape_str})", lhs=comp)


+@register_lowering(torch.ops.aten.bmm.default, apply_dot_requirements)
+# pyre-fixme[56]
+@register_lowering(torch.ops.aten.mm.default, apply_dot_requirements)
+def codegen_mm(ctx: GraphInterpreter, node: torch.fx.Node) -> ast.AST:
+    assert not node.kwargs, "matmul kwargs not supported"
+
+    return reduce_3d_dot(ctx, node, False)
+
+
 # pyre-fixme[56]
 @register_lowering(torch.ops.aten.addmm.default, apply_dot_requirements)
 def codegen_addmm(ctx: GraphInterpreter, node: torch.fx.Node) -> ast.AST:
     assert not node.kwargs, "addmm kwargs not supported"
-    acc, lhs, rhs = map_arg(node.args, lambda arg: ctx.env[arg])
-    assert isinstance(acc, ast.AST)
-    assert isinstance(lhs, ast.AST)
-    assert isinstance(rhs, ast.AST)
-    tf32 = CompileEnvironment.current().settings.dot_precision
-    return expr_from_string(
-        f"tl.dot(lhs, rhs, acc=acc, input_precision={tf32!r})",
-        lhs=lhs,
-        rhs=rhs,
-        acc=acc,
-    )
+    return reduce_3d_dot(ctx, node, True)


 # pyre-fixme[56]
 @register_lowering(torch.ops.aten.baddbmm.default, apply_dot_requirements)
 def codegen_baddbmm(ctx: GraphInterpreter, node: torch.fx.Node) -> ast.AST:
     assert not node.kwargs, "baddbmm kwargs not supported"
-    acc, lhs, rhs = map_arg(node.args, lambda arg: ctx.env[arg])
-    assert isinstance(acc, ast.AST)
-    assert isinstance(lhs, ast.AST)
-    assert isinstance(rhs, ast.AST)
-    tf32 = CompileEnvironment.current().settings.dot_precision
-    lhsSize = node.args[1].meta["val"].size()
-    rhsSize = node.args[2].meta["val"].size()
-    # check to see if it is 3D
-    reduceDim = False
-    if len(lhsSize) == 3:
-        env = CompileEnvironment.current()
-        lhsDimIdx = env.get_block_id(lhsSize[0])
-        rhsDimIdx = env.get_block_id(rhsSize[0])
-        if lhsDimIdx is not None and rhsDimIdx is not None:
-            lhsDimVal = env.block_sizes[lhsDimIdx]
-            rhsDimVal = env.block_sizes[rhsDimIdx]
-            if (
-                lhsDimVal.from_config(ctx.cg.device_function.config) == 1
-                and rhsDimVal.from_config(ctx.cg.device_function.config) == 1
-            ):
-                reduceDim = True
-
-    if not reduceDim:
-        return expr_from_string(
-            f"tl.dot(lhs, rhs, acc=acc, input_precision={tf32!r})",
-            lhs=lhs,
-            rhs=rhs,
-            acc=acc,
-        )
-    # create reshape, dot, then reshape
-    lhs_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
-        [*node.args[1].meta["val"].size()[1:]]
-    )
-    rhs_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
-        [*node.args[2].meta["val"].size()[1:]]
-    )
-    acc_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
-        [*node.args[0].meta["val"].size()[1:]]
-    )
-    out_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
-        [*node.meta["val"].size()]
-    )
-    lhs_reshape = expr_from_string(f"tl.reshape(lhs, {lhs_shape_str})", lhs=lhs)
-    rhs_reshape = expr_from_string(f"tl.reshape(rhs, {rhs_shape_str})", rhs=rhs)
-    acc_reshape = expr_from_string(f"tl.reshape(rhs, {acc_shape_str})", rhs=acc)
-    comp = expr_from_string(
-        f"tl.dot(lhs, rhs, acc=acc, input_precision={tf32!r})",
-        lhs=lhs_reshape,
-        rhs=rhs_reshape,
-        acc=acc_reshape,
-    )
-    return expr_from_string(f"tl.reshape(lhs, {out_shape_str})", lhs=comp)
+    return reduce_3d_dot(ctx, node, True)


 class GenerateASTFromInductor(DefaultHandler):
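
For context on the fast path this helper preserves: when the leading batch dimension of a 3D dot is fixed at 1 by the config, the lowering reshapes the tiles to 2D, performs the dot there, and reshapes back. A minimal sketch of that generated pattern as a standalone Triton kernel (assumes a recent Triton; the kernel name, the 16x16x16 tile sizes, and the "tf32" precision are illustrative assumptions, not values from this commit):

import torch
import triton
import triton.language as tl


@triton.jit
def squeezed_dot_kernel(lhs_ptr, rhs_ptr, acc_ptr, out_ptr,
                        M: tl.constexpr, K: tl.constexpr, N: tl.constexpr):
    b = tl.arange(0, 1)  # the size-1 batch dimension
    m = tl.arange(0, M)
    k = tl.arange(0, K)
    n = tl.arange(0, N)
    # Load the 3D tiles: lhs [1, M, K], rhs [1, K, N], acc [1, M, N].
    lhs = tl.load(lhs_ptr + b[:, None, None] * M * K + m[None, :, None] * K + k[None, None, :])
    rhs = tl.load(rhs_ptr + b[:, None, None] * K * N + k[None, :, None] * N + n[None, None, :])
    acc = tl.load(acc_ptr + b[:, None, None] * M * N + m[None, :, None] * N + n[None, None, :])
    # The reshape -> tl.dot -> reshape sequence reduce_3d_dot emits.
    out2d = tl.dot(
        tl.reshape(lhs, [M, K]),
        tl.reshape(rhs, [K, N]),
        acc=tl.reshape(acc, [M, N]),
        input_precision="tf32",
    )
    out = tl.reshape(out2d, [1, M, N])  # restore the 3D output shape
    tl.store(out_ptr + b[:, None, None] * M * N + m[None, :, None] * N + n[None, None, :], out)


# Usage: equivalent to torch.baddbmm on a single [1, M, K] x [1, K, N] tile.
M = K = N = 16
lhs = torch.randn(1, M, K, device="cuda")
rhs = torch.randn(1, K, N, device="cuda")
acc = torch.randn(1, M, N, device="cuda")
out = torch.empty_like(acc)
squeezed_dot_kernel[(1,)](lhs, rhs, acc, out, M, K, N)
torch.testing.assert_close(out, torch.baddbmm(acc, lhs, rhs), atol=1e-2, rtol=1e-2)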
