@@ -544,6 +544,39 @@ def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
544
544
tensor_var_map = tensor_var_map ,
545
545
)
546
546
547
+ @OperatorFactory .register
548
+ class _ConvOperator (_CommonParams ):
549
+ op_type = "ConvOperator"
550
+
551
+ @classmethod
552
+ @must_return_type (Hashable )
553
+ def get_constructor_parameters (cls , op_info ):
554
+
555
+ strides = [
556
+ 1 ,
557
+ op_info .op_attr ['StrideW' ],
558
+ op_info .op_attr ['StrideH' ],
559
+ 1 ,
560
+ ]
561
+ padding = cls ._PADDING_MAP [op_info .op_attr ['Padding' ]]
562
+ strides_str = ',' .join (map (str , strides ))
563
+ return ("{{ {} }}" .format (strides_str ), padding )
564
+
565
+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
566
+ return DeclareOpSnippet (
567
+ op = self ,
568
+ templ_dtypes = [self .out_dtypes [0 ]],
569
+ op_var_name = op_var_name ,
570
+ )
571
+
572
+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
573
+ return ConvOpEvalSnippet (
574
+ op_info = op_info ,
575
+ templ_dtypes = [self .out_dtypes [0 ]],
576
+ op_name = op_var_name ,
577
+ tensor_var_map = tensor_var_map ,
578
+ )
579
+
547
580
548
581
@OperatorFactory .register
549
582
class _QuantizedFullyConnectedOperator (_CommonParams ):
@@ -842,3 +875,142 @@ def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
842
875
op_name = op_var_name ,
843
876
tensor_var_map = tensor_var_map ,
844
877
)
878
+
879
+ @OperatorFactory .register
880
+ class _BatchNormOperator (_CommonParams ):
881
+ op_type = "BatchNormOperator"
882
+
883
+ @classmethod
884
+ @must_return_type (Hashable )
885
+ def get_constructor_parameters (cls , op_info ):
886
+ strides = [
887
+ 1 ,
888
+ op_info .op_attr ['StrideW' ],
889
+ op_info .op_attr ['StrideH' ],
890
+ 1 ,
891
+ ]
892
+ padding = cls ._PADDING_MAP [op_info .op_attr ['Padding' ]]
893
+ strides_str = ',' .join (map (str , strides ))
894
+ return ("{{ {} }}" .format (strides_str ), padding )
895
+
896
+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
897
+ return DeclareOpSnippet (
898
+ op = self ,
899
+ templ_dtypes = [self .out_dtypes [0 ]],
900
+ op_var_name = op_var_name ,
901
+ )
902
+
903
+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
904
+ return BatchNormSnippet (
905
+ op_info = op_info ,
906
+ templ_dtypes = [self .out_dtypes [0 ]],
907
+ op_name = op_var_name ,
908
+ tensor_var_map = tensor_var_map ,
909
+ )
910
+
911
+ @OperatorFactory .register
912
+ class _MeanOperator (_CommonParams ):
913
+ op_type = "MeanOperator"
914
+
915
+ @classmethod
916
+ @must_return_type (Hashable )
917
+ def get_constructor_parameters (cls , op_info ):
918
+ keep_dims = str (op_info .op_attr ["keep_dims" ])
919
+ return (" {} " .format (keep_dims ), )
920
+
921
+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
922
+ return DeclareOpSnippet (
923
+ op = self ,
924
+ templ_dtypes = [self .out_dtypes [0 ]],
925
+ op_var_name = op_var_name ,
926
+ )
927
+
928
+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
929
+ return BatchNormSnippet (
930
+ op_info = op_info ,
931
+ templ_dtypes = [self .out_dtypes [0 ]],
932
+ op_name = op_var_name ,
933
+ tensor_var_map = tensor_var_map ,
934
+ )
935
+
936
+ @OperatorFactory .register
937
+ class _SoftmaxOperator (_CommonParams ):
938
+ op_type = "SoftmaxOperator"
939
+
940
+ @classmethod
941
+ @must_return_type (Hashable )
942
+ def get_constructor_parameters (cls , op_info ):
943
+ Beta = op_info .op_attr ["Beta" ]
944
+ return (" %f " % Beta ,)
945
+
946
+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
947
+ return DeclareOpSnippet (
948
+ op = self ,
949
+ templ_dtypes = [self .out_dtypes [0 ]],
950
+ op_var_name = op_var_name ,
951
+ )
952
+
953
+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
954
+ return BatchNormSnippet (
955
+ op_info = op_info ,
956
+ templ_dtypes = [self .out_dtypes [0 ]],
957
+ op_name = op_var_name ,
958
+ tensor_var_map = tensor_var_map ,
959
+ )
960
+
961
+ @OperatorFactory .register
962
+ class _MulOperator (_Operator ):
963
+ op_type = 'MulOperator'
964
+
965
+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
966
+ return DeclareOpSnippet (
967
+ op = self ,
968
+ templ_dtypes = [self .in_dtypes [0 ]],
969
+ op_var_name = op_var_name ,
970
+ )
971
+
972
+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
973
+ return MulOpEvalSnippet (
974
+ op_info = op_info ,
975
+ templ_dtypes = [self .in_dtypes [0 ]],
976
+ op_name = op_var_name ,
977
+ tensor_var_map = tensor_var_map ,
978
+ )
979
+
980
+ @OperatorFactory .register
981
+ class _SubOperator (_Operator ):
982
+ op_type = 'SubOperator'
983
+
984
+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
985
+ return DeclareOpSnippet (
986
+ op = self ,
987
+ templ_dtypes = [self .in_dtypes [0 ]],
988
+ op_var_name = op_var_name ,
989
+ )
990
+
991
+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
992
+ return SubOpEvalSnippet (
993
+ op_info = op_info ,
994
+ templ_dtypes = [self .in_dtypes [0 ]],
995
+ op_name = op_var_name ,
996
+ tensor_var_map = tensor_var_map ,
997
+ )
998
+
999
+ @OperatorFactory .register
1000
+ class _SigmoidOperator (_Operator ):
1001
+ op_type = 'SigmoidOperator'
1002
+
1003
+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
1004
+ return DeclareOpSnippet (
1005
+ op = self ,
1006
+ templ_dtypes = [self .in_dtypes [0 ]],
1007
+ op_var_name = op_var_name ,
1008
+ )
1009
+
1010
+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
1011
+ return SigmoidOpEvalSnippet (
1012
+ op_info = op_info ,
1013
+ templ_dtypes = [self .in_dtypes [0 ]],
1014
+ op_name = op_var_name ,
1015
+ tensor_var_map = tensor_var_map ,
1016
+ )
0 commit comments