@@ -777,17 +777,23 @@ def apply_dot_requirements(
     return LambdaLowering(handler, masked_value_fn=masked_value_fn)
 
 
-@register_lowering(torch.ops.aten.bmm.default, apply_dot_requirements)
-# pyre-fixme[56]
-@register_lowering(torch.ops.aten.mm.default, apply_dot_requirements)
-def codegen_mm(ctx: GraphInterpreter, node: torch.fx.Node) -> ast.AST:
-    assert not node.kwargs, "matmul kwargs not supported"
-    lhs, rhs = map_arg(node.args, lambda arg: ctx.env[arg])
+def reduce_3d_dot(ctx: GraphInterpreter, node: torch.fx.Node, withAcc: bool) -> ast.AST:
+    datatype = CompileEnvironment.current().settings.dot_precision
+    if withAcc:
+        acc, lhs, rhs = map_arg(node.args, lambda arg: ctx.env[arg])
+        assert isinstance(acc, ast.AST)
+        lhsNode = node.args[1]
+        rhsNode = node.args[2]
+    else:
+        lhs, rhs = map_arg(node.args, lambda arg: ctx.env[arg])
+        lhsNode = node.args[0]
+        rhsNode = node.args[1]
     assert isinstance(lhs, ast.AST)
     assert isinstance(rhs, ast.AST)
-    lhsSize = node.args[0].meta["val"].size()
-    rhsSize = node.args[1].meta["val"].size()
-    # check to see if it is 3D
+
+    lhsSize = lhsNode.meta["val"].size()
+    rhsSize = rhsNode.meta["val"].size()
+    # check to see if it is 3D and the highest dim is 1
     reduceDim = False
     if len(lhsSize) == 3:
         env = CompileEnvironment.current()
@@ -802,102 +808,72 @@ def codegen_mm(ctx: GraphInterpreter, node: torch.fx.Node) -> ast.AST:
         ):
             reduceDim = True
 
-    tf32 = CompileEnvironment.current().settings.dot_precision
     if not reduceDim:
+        if withAcc:
+            return expr_from_string(
+                f"tl.dot(lhs, rhs, acc=acc, input_precision={datatype!r})",
+                lhs=lhs,
+                rhs=rhs,
+                acc=acc,
+            )
+        # without accumulator
         return expr_from_string(
-            f"tl.dot(lhs, rhs, input_precision={tf32!r})", lhs=lhs, rhs=rhs
+            f"tl.dot(lhs, rhs, input_precision={datatype!r})", lhs=lhs, rhs=rhs
         )
+
     # create reshape, dot, then reshape
     lhs_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
-        [*node.args[0].meta["val"].size()[1:]]
+        [*lhsNode.meta["val"].size()[1:]]
     )
     rhs_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
-        [*node.args[1].meta["val"].size()[1:]]
+        [*rhsNode.meta["val"].size()[1:]]
     )
     out_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
         [*node.meta["val"].size()]
     )
     lhs_reshape = expr_from_string(f"tl.reshape(lhs, {lhs_shape_str})", lhs=lhs)
     rhs_reshape = expr_from_string(f"tl.reshape(rhs, {rhs_shape_str})", rhs=rhs)
-    comp = expr_from_string(
-        f"tl.dot(lhs, rhs, input_precision={tf32!r})", lhs=lhs_reshape, rhs=rhs_reshape
-    )
+    if withAcc:
+        acc_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
+            [*node.args[0].meta["val"].size()[1:]]
+        )
+        acc_reshape = expr_from_string(f"tl.reshape(acc, {acc_shape_str})", acc=acc)
+        comp = expr_from_string(
+            f"tl.dot(lhs, rhs, acc=acc, input_precision={datatype!r})",
+            lhs=lhs_reshape,
+            rhs=rhs_reshape,
+            acc=acc_reshape,
+        )
+    else:
+        comp = expr_from_string(
+            f"tl.dot(lhs, rhs, input_precision={datatype!r})",
+            lhs=lhs_reshape,
+            rhs=rhs_reshape,
+        )
     return expr_from_string(f"tl.reshape(lhs, {out_shape_str})", lhs=comp)
 
 
+@register_lowering(torch.ops.aten.bmm.default, apply_dot_requirements)
+# pyre-fixme[56]
+@register_lowering(torch.ops.aten.mm.default, apply_dot_requirements)
+def codegen_mm(ctx: GraphInterpreter, node: torch.fx.Node) -> ast.AST:
+    assert not node.kwargs, "matmul kwargs not supported"
+
+    return reduce_3d_dot(ctx, node, False)
+
+
 # pyre-fixme[56]
 @register_lowering(torch.ops.aten.addmm.default, apply_dot_requirements)
 def codegen_addmm(ctx: GraphInterpreter, node: torch.fx.Node) -> ast.AST:
     assert not node.kwargs, "addmm kwargs not supported"
-    acc, lhs, rhs = map_arg(node.args, lambda arg: ctx.env[arg])
-    assert isinstance(acc, ast.AST)
-    assert isinstance(lhs, ast.AST)
-    assert isinstance(rhs, ast.AST)
-    tf32 = CompileEnvironment.current().settings.dot_precision
-    return expr_from_string(
-        f"tl.dot(lhs, rhs, acc=acc, input_precision={tf32!r})",
-        lhs=lhs,
-        rhs=rhs,
-        acc=acc,
-    )
+    return reduce_3d_dot(ctx, node, True)
 
 
 # pyre-fixme[56]
 @register_lowering(torch.ops.aten.baddbmm.default, apply_dot_requirements)
 def codegen_baddbmm(ctx: GraphInterpreter, node: torch.fx.Node) -> ast.AST:
     assert not node.kwargs, "baddbmm kwargs not supported"
-    acc, lhs, rhs = map_arg(node.args, lambda arg: ctx.env[arg])
-    assert isinstance(acc, ast.AST)
-    assert isinstance(lhs, ast.AST)
-    assert isinstance(rhs, ast.AST)
-    tf32 = CompileEnvironment.current().settings.dot_precision
-    lhsSize = node.args[1].meta["val"].size()
-    rhsSize = node.args[2].meta["val"].size()
-    # check to see if it is 3D
-    reduceDim = False
-    if len(lhsSize) == 3:
-        env = CompileEnvironment.current()
-        lhsDimIdx = env.get_block_id(lhsSize[0])
-        rhsDimIdx = env.get_block_id(rhsSize[0])
-        if lhsDimIdx is not None and rhsDimIdx is not None:
-            lhsDimVal = env.block_sizes[lhsDimIdx]
-            rhsDimVal = env.block_sizes[rhsDimIdx]
-            if (
-                lhsDimVal.from_config(ctx.cg.device_function.config) == 1
-                and rhsDimVal.from_config(ctx.cg.device_function.config) == 1
-            ):
-                reduceDim = True
-
-    if not reduceDim:
-        return expr_from_string(
-            f"tl.dot(lhs, rhs, acc=acc, input_precision={tf32!r})",
-            lhs=lhs,
-            rhs=rhs,
-            acc=acc,
-        )
-    # create reshape, dot, then reshape
-    lhs_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
-        [*node.args[1].meta["val"].size()[1:]]
-    )
-    rhs_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
-        [*node.args[2].meta["val"].size()[1:]]
-    )
-    acc_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
-        [*node.args[0].meta["val"].size()[1:]]
-    )
-    out_shape_str = ctx.cg.device_function.tile_strategy.shape_str(
-        [*node.meta["val"].size()]
-    )
-    lhs_reshape = expr_from_string(f"tl.reshape(lhs, {lhs_shape_str})", lhs=lhs)
-    rhs_reshape = expr_from_string(f"tl.reshape(rhs, {rhs_shape_str})", rhs=rhs)
-    acc_reshape = expr_from_string(f"tl.reshape(rhs, {acc_shape_str})", rhs=acc)
-    comp = expr_from_string(
-        f"tl.dot(lhs, rhs, acc=acc, input_precision={tf32!r})",
-        lhs=lhs_reshape,
-        rhs=rhs_reshape,
-        acc=acc_reshape,
-    )
-    return expr_from_string(f"tl.reshape(lhs, {out_shape_str})", lhs=comp)
+    return reduce_3d_dot(ctx, node, True)
 
 
 class GenerateASTFromInductor(DefaultHandler):
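Net effect: when both batch block sizes resolve to 1 in the kernel config, reduce_3d_dot collapses a rank-3 dot into a 2D tl.dot wrapped in reshapes; otherwise it emits tl.dot directly on the 3D tiles. A minimal sketch of the Triton that the reshape/dot/reshape strings above would generate for a bmm tile, assuming illustrative shapes [1, M, K] x [1, K, N] and dot_precision='tf32' (neither taken from this diff):

    # Illustrative generated code: both batch block sizes resolved to 1
    lhs_2d = tl.reshape(lhs, [M, K])     # drop the size-1 batch dim
    rhs_2d = tl.reshape(rhs, [K, N])
    out_2d = tl.dot(lhs_2d, rhs_2d, input_precision='tf32')
    out = tl.reshape(out_2d, [1, M, N])  # restore the node's 3D output shape

When the reduce branch is taken with an accumulator (the baddbmm case), acc is reshaped to 2D the same way and passed through tl.dot's acc= argument before the final reshape.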