Skip to content

Commit e79e647

Browse files
committed
Merge branch 're-arch-support-extra-ops' of github.com:uTensor/utensor_cgen into re-arch-support-extra-ops
2 parents 3464cfd + 05e8882 commit e79e647

File tree

3 files changed

+181
-0
lines changed

3 files changed

+181
-0
lines changed

utensor_cgen/backend/utensor/code_generator/rearch/_operators/_impls.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,39 @@ def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
544544
tensor_var_map=tensor_var_map,
545545
)
546546

547+
@OperatorFactory.register
548+
class _ConvOperator(_CommonParams):
549+
op_type = "ConvOperator"
550+
551+
@classmethod
552+
@must_return_type(Hashable)
553+
def get_constructor_parameters(cls, op_info):
554+
555+
strides = [
556+
1,
557+
op_info.op_attr['StrideW'],
558+
op_info.op_attr['StrideH'],
559+
1,
560+
]
561+
padding = cls._PADDING_MAP[op_info.op_attr['Padding']]
562+
strides_str = ','.join(map(str, strides))
563+
return ("{{ {} }}".format(strides_str), padding)
564+
565+
def get_declare_snippet(self, op_var_name, tensor_var_map):
566+
return DeclareOpSnippet(
567+
op=self,
568+
templ_dtypes=[self.out_dtypes[0]],
569+
op_var_name=op_var_name,
570+
)
571+
572+
def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
573+
return ConvOpEvalSnippet(
574+
op_info=op_info,
575+
templ_dtypes=[self.out_dtypes[0]],
576+
op_name=op_var_name,
577+
tensor_var_map=tensor_var_map,
578+
)
579+
547580

548581
@OperatorFactory.register
549582
class _QuantizedFullyConnectedOperator(_CommonParams):
@@ -842,3 +875,142 @@ def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
842875
op_name=op_var_name,
843876
tensor_var_map=tensor_var_map,
844877
)
878+
879+
@OperatorFactory.register
880+
class _BatchNormOperator(_CommonParams):
881+
op_type = "BatchNormOperator"
882+
883+
@classmethod
884+
@must_return_type(Hashable)
885+
def get_constructor_parameters(cls, op_info):
886+
strides = [
887+
1,
888+
op_info.op_attr['StrideW'],
889+
op_info.op_attr['StrideH'],
890+
1,
891+
]
892+
padding = cls._PADDING_MAP[op_info.op_attr['Padding']]
893+
strides_str = ','.join(map(str, strides))
894+
return ("{{ {} }}".format(strides_str), padding)
895+
896+
def get_declare_snippet(self, op_var_name, tensor_var_map):
897+
return DeclareOpSnippet(
898+
op=self,
899+
templ_dtypes=[self.out_dtypes[0]],
900+
op_var_name=op_var_name,
901+
)
902+
903+
def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
904+
return BatchNormSnippet(
905+
op_info=op_info,
906+
templ_dtypes=[self.out_dtypes[0]],
907+
op_name=op_var_name,
908+
tensor_var_map=tensor_var_map,
909+
)
910+
911+
@OperatorFactory.register
912+
class _MeanOperator(_CommonParams):
913+
op_type = "MeanOperator"
914+
915+
@classmethod
916+
@must_return_type(Hashable)
917+
def get_constructor_parameters(cls, op_info):
918+
keep_dims = str(op_info.op_attr["keep_dims"])
919+
return (" {} ".format(keep_dims), )
920+
921+
def get_declare_snippet(self, op_var_name, tensor_var_map):
922+
return DeclareOpSnippet(
923+
op=self,
924+
templ_dtypes=[self.out_dtypes[0]],
925+
op_var_name=op_var_name,
926+
)
927+
928+
def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
929+
return BatchNormSnippet(
930+
op_info=op_info,
931+
templ_dtypes=[self.out_dtypes[0]],
932+
op_name=op_var_name,
933+
tensor_var_map=tensor_var_map,
934+
)
935+
936+
@OperatorFactory.register
937+
class _SoftmaxOperator(_CommonParams):
938+
op_type = "SoftmaxOperator"
939+
940+
@classmethod
941+
@must_return_type(Hashable)
942+
def get_constructor_parameters(cls, op_info):
943+
Beta = op_info.op_attr["Beta"]
944+
return (" %f " % Beta,)
945+
946+
def get_declare_snippet(self, op_var_name, tensor_var_map):
947+
return DeclareOpSnippet(
948+
op=self,
949+
templ_dtypes=[self.out_dtypes[0]],
950+
op_var_name=op_var_name,
951+
)
952+
953+
def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
954+
return BatchNormSnippet(
955+
op_info=op_info,
956+
templ_dtypes=[self.out_dtypes[0]],
957+
op_name=op_var_name,
958+
tensor_var_map=tensor_var_map,
959+
)
960+
961+
@OperatorFactory.register
962+
class _MulOperator(_Operator):
963+
op_type = 'MulOperator'
964+
965+
def get_declare_snippet(self, op_var_name, tensor_var_map):
966+
return DeclareOpSnippet(
967+
op=self,
968+
templ_dtypes=[self.in_dtypes[0]],
969+
op_var_name=op_var_name,
970+
)
971+
972+
def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
973+
return MulOpEvalSnippet(
974+
op_info=op_info,
975+
templ_dtypes=[self.in_dtypes[0]],
976+
op_name=op_var_name,
977+
tensor_var_map=tensor_var_map,
978+
)
979+
980+
@OperatorFactory.register
981+
class _SubOperator(_Operator):
982+
op_type = 'SubOperator'
983+
984+
def get_declare_snippet(self, op_var_name, tensor_var_map):
985+
return DeclareOpSnippet(
986+
op=self,
987+
templ_dtypes=[self.in_dtypes[0]],
988+
op_var_name=op_var_name,
989+
)
990+
991+
def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
992+
return SubOpEvalSnippet(
993+
op_info=op_info,
994+
templ_dtypes=[self.in_dtypes[0]],
995+
op_name=op_var_name,
996+
tensor_var_map=tensor_var_map,
997+
)
998+
999+
@OperatorFactory.register
1000+
class _SigmoidOperator(_Operator):
1001+
op_type = 'SigmoidOperator'
1002+
1003+
def get_declare_snippet(self, op_var_name, tensor_var_map):
1004+
return DeclareOpSnippet(
1005+
op=self,
1006+
templ_dtypes=[self.in_dtypes[0]],
1007+
op_var_name=op_var_name,
1008+
)
1009+
1010+
def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
1011+
return SigmoidOpEvalSnippet(
1012+
op_info=op_info,
1013+
templ_dtypes=[self.in_dtypes[0]],
1014+
op_name=op_var_name,
1015+
tensor_var_map=tensor_var_map,
1016+
)

utensor_cgen/backend/utensor/snippets/rearch/_snippets.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,14 @@
2929
"MaxPoolEvalSnippet",
3030
"QuantizedFullyConnectedSnippet",
3131
"MissingOpEvalSnippet",
32+
"BatchNormSnippet",
3233
"TimeSlotContainer",
3334
"MulOpEvalSnippet",
3435
"SubOpEvalSnippet",
3536
"ConvOpEvalSnippet",
3637
"MeanOpEvalSnippet",
3738
"SoftmaxOpEvalSnippet",
39+
"SigmoidOpEvalSnippet",
3840
"SimpleContainer",
3941
]
4042

@@ -256,6 +258,7 @@ class SoftmaxOpEvalSnippet(OpEvalSnippet):
256258
__inputs__ = ['input']
257259
__outputs__ = ['output']
258260

261+
<<<<<<< HEAD
259262

260263
class MissingOpEvalSnippet(OpEvalSnippet):
261264
__template_name__ = "snippets/rearch/op_missing.cpp"
@@ -277,6 +280,11 @@ def __init__(self, op_info, tensor_var_map):
277280
]
278281
self.template_vars['output_tensors'] = op_info.output_tensors[:]
279282
self.template_vars['quant_params_map'] = quant_params_map
283+
=======
284+
class SigmoidOpEvalSnippet(OpEvalSnippet):
285+
__inputs__ = ['in']
286+
__outputs__ = ['out']
287+
>>>>>>> 05e8882f5f9fd828586bbb708782c9e173677041
280288

281289

282290
class TimeSlotContainer(SnippetBase):

utensor_cgen/legalizer/tflite.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class _OpTypeRename(object):
3838
"Mean": "MeanOperator",
3939
"Softmax": "SoftmaxOperator",
4040
"Sigmoid": "SigmoidOperator",
41+
"Logistic": "SigmoidOperator",
4142
}
4243

4344
@classmethod

0 commit comments

Comments
 (0)