@@ -197,25 +197,38 @@ def generate_inference_str(
197
197
)
198
198
199
199
if params ["input2_dtype" ] == "int8" and params ["input_dtype" ] in ["float32" , "int8" ]:
200
- if params ["first_k_channel" ] is not None :
201
- string += (
202
- f"{ function_name } _int8w_partialCH(conv_params,{ input_address_string } ,"
203
- + f"{ params ['input_h' ]} ,{ params ['input_w' ]} ,{ params ['input_c' ]} ,"
204
- + f"(q7_t*){ weight_string } ,(q7_t*){ weight_string } Flash,{ params ['first_k_channel' ]} ,"
205
- + f"{ params ['kernel_h' ]} ,{ params ['kernel_w' ]} ,NULL,"
206
- + f"{ output_address_string } ,"
207
- + f"{ str (params ['output_h' ])} ,{ str (params ['output_w' ])} ,{ str (params ['output_c' ])} ,"
208
- + "(float*)sbuf,1);\n "
209
- )
210
- else :
200
+ if params ["first_k_channel" ] is None :
211
201
string += (
212
- f"{ function_name } _int8w (conv_params,{ input_address_string } ,"
202
+ f"{ function_name } _int8weight (conv_params,{ input_address_string } ,"
213
203
+ f"{ params ['input_h' ]} ,{ params ['input_w' ]} ,{ params ['input_c' ]} ,"
214
204
+ f"(q7_t*){ weight_string } ,{ params ['kernel_h' ]} ,{ params ['kernel_w' ]} ,NULL,"
215
205
+ f"{ output_address_string } ,"
216
206
+ f"{ str (params ['output_h' ])} ,{ str (params ['output_w' ])} ,{ str (params ['output_c' ])} ,"
217
207
+ "(float*)sbuf,1);\n "
218
208
)
209
+ else :
210
+ if params ["first_k_channel" ] % 8 == 0 :
211
+ string += (
212
+ f"{ function_name } _int8weight_partialCH_8innercol(conv_params,{ input_address_string } ,"
213
+ + f"{ params ['input_h' ]} ,{ params ['input_w' ]} ,{ params ['input_c' ]} ,"
214
+ + f"(q7_t*){ weight_string } ,(q7_t*){ weight_string } Flash,{ params ['first_k_channel' ]} ,"
215
+ + f"{ params ['kernel_h' ]} ,{ params ['kernel_w' ]} ,NULL,"
216
+ + f"{ output_address_string } ,"
217
+ + f"{ str (params ['output_h' ])} ,{ str (params ['output_w' ])} ,{ str (params ['output_c' ])} ,"
218
+ + "(float*)sbuf,1);\n "
219
+ )
220
+ elif params ["first_k_channel" ] % 4 == 0 :
221
+ string += (
222
+ f"{ function_name } _int8weight_partialCH_4innercol(conv_params,{ input_address_string } ,"
223
+ + f"{ params ['input_h' ]} ,{ params ['input_w' ]} ,{ params ['input_c' ]} ,"
224
+ + f"(q7_t*){ weight_string } ,(q7_t*){ weight_string } Flash,{ params ['first_k_channel' ]} ,"
225
+ + f"{ params ['kernel_h' ]} ,{ params ['kernel_w' ]} ,NULL,"
226
+ + f"{ output_address_string } ,"
227
+ + f"{ str (params ['output_h' ])} ,{ str (params ['output_w' ])} ,{ str (params ['output_c' ])} ,"
228
+ + "(float*)sbuf,1);\n "
229
+ )
230
+ else :
231
+ raise NotImplementedError
219
232
else :
220
233
string += (
221
234
f"{ function_name } (conv_params,{ input_address_string } ,"
@@ -272,7 +285,7 @@ def generate_inference_str(
272
285
if params ["input2_dtype" ] == "int8" and params ["input_dtype" ] in ["float32" , "int8" ]:
273
286
if params ["first_k_channel" ] is not None :
274
287
string += (
275
- f"{ function_name } _int8w_partialCH ("
288
+ f"{ function_name } _int8weight_partialCH ("
276
289
+ f"{ self ._getBufferstrCast (params ['input_buf_add' ], params ['input_buf_add_offset' ])} ,"
277
290
+ f"{ params ['input_h' ]} ,{ params ['input_w' ]} ,{ params ['input_c' ]} ,"
278
291
+ f"(q7_t*){ weight_string } ,(q7_t*){ weight_string } Flash,{ params ['first_k_channel' ]} ,NULL,"
@@ -291,7 +304,7 @@ def generate_inference_str(
291
304
string += "(float*)sbuf, 1);\n "
292
305
else :
293
306
string += (
294
- f"{ function_name } _int8w ("
307
+ f"{ function_name } _int8weight ("
295
308
+ f"{ self ._getBufferstrCast (params ['input_buf_add' ], params ['input_buf_add_offset' ])} ,"
296
309
+ f"{ params ['input_h' ]} ,{ params ['input_w' ]} ,{ params ['input_c' ]} ,"
297
310
+ f"(q7_t*){ weight_string } ,NULL,"
@@ -353,7 +366,7 @@ def generate_inference_str(
353
366
354
367
if params ["input2_dtype" ] == "int8" and params ["input_dtype" ] == "float32" :
355
368
string += (
356
- f"{ function_name } _int8w (conv_params,"
369
+ f"{ function_name } _int8weight (conv_params,"
357
370
+ f"{ self ._getBufferstrCast (params ['input_buf_add' ], params ['input_buf_add_offset' ])} ,"
358
371
+ f"{ params ['input_h' ]} ,{ params ['input_w' ]} ,{ params ['input_c' ]} ,"
359
372
+ f"{ weight_string } ,{ params ['kernel_h' ]} ,{ params ['kernel_w' ]} ,NULL,"
@@ -373,7 +386,7 @@ def generate_inference_str(
373
386
string += "(float*)sbuf,1);\n "
374
387
elif params ["group" ] == params ["input_c" ] and params ["group" ] == params ["output_c" ] and not tflite_op :
375
388
# function name
376
- function_name = "transpose_depthwise_conv_kernel "
389
+ function_name = "transpose_depthwise_conv_fp_kernel "
377
390
if params ["stride_h" ] == 1 :
378
391
outpad = 0
379
392
elif params ["stride_h" ] == 2 :
@@ -384,7 +397,7 @@ def generate_inference_str(
384
397
raise NotImplementedError
385
398
function_name += (
386
399
f"{ str (params ['kernel_h' ])} _stride{ str (params ['stride_h' ])} _"
387
- + f"inpad{ str (params ['padding_h' ])} _outpad{ str (outpad )} _revised "
400
+ + f"inpad{ str (params ['padding_h' ])} _outpad{ str (outpad )} "
388
401
)
389
402
390
403
if params ["kernel_layout" ] == "IOHW" :
@@ -398,7 +411,7 @@ def generate_inference_str(
398
411
399
412
if params ["input2_dtype" ] == "int8" and params ["input_dtype" ] == "float32" :
400
413
string += (
401
- f"{ function_name } _int8w ("
414
+ f"{ function_name } _int8weight ("
402
415
+ f"{ self ._getBufferstrCast (params ['input_buf_add' ], params ['input_buf_add_offset' ])} ,"
403
416
+ f"{ params ['input_h' ]} ,{ params ['input_w' ]} ,{ params ['input_c' ]} ,"
404
417
+ f"{ weight_string } ,NULL,"
0 commit comments