Skip to content

Commit 00fece7

Browse files
committed
Revise group conv and transposed conv
1 parent c349831 commit 00fece7

File tree

2 files changed

+33
-20
lines changed

2 files changed

+33
-20
lines changed

code_generator/operators/group_conv2d.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -290,9 +290,9 @@ def generate_inference_str(
290290
function_name += "_int8weight"
291291

292292
if (params["output_c"] / params["groups"]) % 16 == 0 or (params["output_c"] / params["groups"]) % 8 == 0:
293-
function_name += "_inplace_revised"
294-
else:
295293
function_name += "_inplace"
294+
else:
295+
raise NotImplementedError
296296

297297
# weight name
298298
if isinstance(params["weight_name"], str) and isweightstr(params["weight_name"]):

code_generator/operators/transpose_conv2d.py

+31-18
Original file line numberDiff line numberDiff line change
@@ -197,25 +197,38 @@ def generate_inference_str(
197197
)
198198

199199
if params["input2_dtype"] == "int8" and params["input_dtype"] in ["float32", "int8"]:
200-
if params["first_k_channel"] is not None:
201-
string += (
202-
f"{function_name}_int8w_partialCH(conv_params,{input_address_string},"
203-
+ f"{params['input_h']},{params['input_w']},{params['input_c']},"
204-
+ f"(q7_t*){weight_string},(q7_t*){weight_string}Flash,{params['first_k_channel']},"
205-
+ f"{params['kernel_h']},{params['kernel_w']},NULL,"
206-
+ f"{output_address_string},"
207-
+ f"{str(params['output_h'])},{str(params['output_w'])},{str(params['output_c'])},"
208-
+ "(float*)sbuf,1);\n"
209-
)
210-
else:
200+
if params["first_k_channel"] is None:
211201
string += (
212-
f"{function_name}_int8w(conv_params,{input_address_string},"
202+
f"{function_name}_int8weight(conv_params,{input_address_string},"
213203
+ f"{params['input_h']},{params['input_w']},{params['input_c']},"
214204
+ f"(q7_t*){weight_string},{params['kernel_h']},{params['kernel_w']},NULL,"
215205
+ f"{output_address_string},"
216206
+ f"{str(params['output_h'])},{str(params['output_w'])},{str(params['output_c'])},"
217207
+ "(float*)sbuf,1);\n"
218208
)
209+
else:
210+
if params["first_k_channel"] % 8 == 0:
211+
string += (
212+
f"{function_name}_int8weight_partialCH_8innercol(conv_params,{input_address_string},"
213+
+ f"{params['input_h']},{params['input_w']},{params['input_c']},"
214+
+ f"(q7_t*){weight_string},(q7_t*){weight_string}Flash,{params['first_k_channel']},"
215+
+ f"{params['kernel_h']},{params['kernel_w']},NULL,"
216+
+ f"{output_address_string},"
217+
+ f"{str(params['output_h'])},{str(params['output_w'])},{str(params['output_c'])},"
218+
+ "(float*)sbuf,1);\n"
219+
)
220+
elif params["first_k_channel"] % 4 == 0:
221+
string += (
222+
f"{function_name}_int8weight_partialCH_4innercol(conv_params,{input_address_string},"
223+
+ f"{params['input_h']},{params['input_w']},{params['input_c']},"
224+
+ f"(q7_t*){weight_string},(q7_t*){weight_string}Flash,{params['first_k_channel']},"
225+
+ f"{params['kernel_h']},{params['kernel_w']},NULL,"
226+
+ f"{output_address_string},"
227+
+ f"{str(params['output_h'])},{str(params['output_w'])},{str(params['output_c'])},"
228+
+ "(float*)sbuf,1);\n"
229+
)
230+
else:
231+
raise NotImplementedError
219232
else:
220233
string += (
221234
f"{function_name}(conv_params,{input_address_string},"
@@ -272,7 +285,7 @@ def generate_inference_str(
272285
if params["input2_dtype"] == "int8" and params["input_dtype"] in ["float32", "int8"]:
273286
if params["first_k_channel"] is not None:
274287
string += (
275-
f"{function_name}_int8w_partialCH("
288+
f"{function_name}_int8weight_partialCH("
276289
+ f"{self._getBufferstrCast(params['input_buf_add'], params['input_buf_add_offset'])},"
277290
+ f"{params['input_h']},{params['input_w']},{params['input_c']},"
278291
+ f"(q7_t*){weight_string},(q7_t*){weight_string}Flash,{params['first_k_channel']},NULL,"
@@ -291,7 +304,7 @@ def generate_inference_str(
291304
string += "(float*)sbuf, 1);\n"
292305
else:
293306
string += (
294-
f"{function_name}_int8w("
307+
f"{function_name}_int8weight("
295308
+ f"{self._getBufferstrCast(params['input_buf_add'], params['input_buf_add_offset'])},"
296309
+ f"{params['input_h']},{params['input_w']},{params['input_c']},"
297310
+ f"(q7_t*){weight_string},NULL,"
@@ -353,7 +366,7 @@ def generate_inference_str(
353366

354367
if params["input2_dtype"] == "int8" and params["input_dtype"] == "float32":
355368
string += (
356-
f"{function_name}_int8w(conv_params,"
369+
f"{function_name}_int8weight(conv_params,"
357370
+ f"{self._getBufferstrCast(params['input_buf_add'], params['input_buf_add_offset'])},"
358371
+ f"{params['input_h']},{params['input_w']},{params['input_c']},"
359372
+ f"{weight_string},{params['kernel_h']},{params['kernel_w']},NULL,"
@@ -373,7 +386,7 @@ def generate_inference_str(
373386
string += "(float*)sbuf,1);\n"
374387
elif params["group"] == params["input_c"] and params["group"] == params["output_c"] and not tflite_op:
375388
# function name
376-
function_name = "transpose_depthwise_conv_kernel"
389+
function_name = "transpose_depthwise_conv_fp_kernel"
377390
if params["stride_h"] == 1:
378391
outpad = 0
379392
elif params["stride_h"] == 2:
@@ -384,7 +397,7 @@ def generate_inference_str(
384397
raise NotImplementedError
385398
function_name += (
386399
f"{str(params['kernel_h'])}_stride{str(params['stride_h'])}_"
387-
+ f"inpad{str(params['padding_h'])}_outpad{str(outpad)}_revised"
400+
+ f"inpad{str(params['padding_h'])}_outpad{str(outpad)}"
388401
)
389402

390403
if params["kernel_layout"] == "IOHW":
@@ -398,7 +411,7 @@ def generate_inference_str(
398411

399412
if params["input2_dtype"] == "int8" and params["input_dtype"] == "float32":
400413
string += (
401-
f"{function_name}_int8w("
414+
f"{function_name}_int8weight("
402415
+ f"{self._getBufferstrCast(params['input_buf_add'], params['input_buf_add_offset'])},"
403416
+ f"{params['input_h']},{params['input_w']},{params['input_c']},"
404417
+ f"{weight_string},NULL,"

0 commit comments

Comments
 (0)