@@ -220,7 +220,6 @@ def processInst(writer: io.TextIOWrapper,
220
220
conds = []
221
221
op_name = instruction ["opname" ]
222
222
fn_name = op_name [2 ].lower () + op_name [3 :]
223
- result_types = []
224
223
exts = instruction ["extensions" ] if "extensions" in instruction else []
225
224
226
225
if "capabilities" in instruction and len (instruction ["capabilities" ]) > 0 :
@@ -244,107 +243,97 @@ def processInst(writer: io.TextIOWrapper,
244
243
conds .append ("(is_signed_v<T> || is_unsigned_v<T>)" )
245
244
break
246
245
case "U" :
247
- fn_name = fn_name [0 :m [1 ][0 ]] + fn_name [m [1 ][1 ]:]
248
- result_types = ["uint16_t" , "uint32_t" , "uint64_t" ]
246
+ conds .append ("is_unsigned_v<T>" )
249
247
break
250
248
case "S" :
251
- fn_name = fn_name [0 :m [1 ][0 ]] + fn_name [m [1 ][1 ]:]
252
- result_types = ["int16_t" , "int32_t" , "int64_t" ]
249
+ conds .append ("is_signed_v<T>" )
253
250
break
254
251
case "F" :
255
- fn_name = fn_name [0 :m [1 ][0 ]] + fn_name [m [1 ][1 ]:]
256
- result_types = ["float16_t" , "float32_t" , "float64_t" ]
252
+ conds .append ("is_floating_point<T>" )
257
253
break
258
-
259
- match instruction ["class" ]:
260
- case "Bit" :
261
- if len (result_types ) == 0 : conds .append ("(is_signed_v<T> || is_unsigned_v<T>)" )
254
+ else :
255
+ if instruction ["class" ] == "Bit" :
256
+ conds .append ("(is_signed_v<T> || is_unsigned_v<T>)" )
262
257
263
258
if "operands" in instruction and instruction ["operands" ][0 ]["kind" ] == "IdResultType" :
264
- if len (result_types ) == 0 :
265
- if result_ty == None :
266
- result_types = ["T" ]
267
- else :
268
- result_types = [result_ty ]
259
+ if result_ty == None :
260
+ result_ty = "T"
269
261
else :
270
- assert len (result_types ) == 0
271
- result_types = ["void" ]
272
-
273
- for rt in result_types :
274
- overload_caps = caps .copy ()
275
- match rt :
276
- case "uint16_t" | "int16_t" : overload_caps .append ("Int16" )
277
- case "uint64_t" | "int64_t" : overload_caps .append ("Int64" )
278
- case "float16_t" : overload_caps .append ("Float16" )
279
- case "float64_t" : overload_caps .append ("Float64" )
280
-
281
- for cap in overload_caps or [None ]:
282
- final_fn_name = fn_name + "_" + cap if (len (overload_caps ) > 1 ) else fn_name
283
- final_templates = templates .copy ()
262
+ result_ty = "void"
263
+
264
+ match result_ty :
265
+ case "uint16_t" | "int16_t" : caps .append ("Int16" )
266
+ case "uint64_t" | "int64_t" : caps .append ("Int64" )
267
+ case "float16_t" : caps .append ("Float16" )
268
+ case "float64_t" : caps .append ("Float64" )
269
+
270
+ for cap in caps or [None ]:
271
+ final_fn_name = fn_name + "_" + cap if (len (caps ) > 1 ) else fn_name
272
+ final_templates = templates .copy ()
273
+
274
+ if (not "typename T" in final_templates ) and (result_ty == "T" ):
275
+ final_templates = ["typename T" ] + final_templates
276
+
277
+ if len (caps ) > 0 :
278
+ if (("Float16" in cap and result_ty != "float16_t" ) or
279
+ ("Float32" in cap and result_ty != "float32_t" ) or
280
+ ("Float64" in cap and result_ty != "float64_t" ) or
281
+ ("Int16" in cap and result_ty != "int16_t" and result_ty != "uint16_t" ) or
282
+ ("Int64" in cap and result_ty != "int64_t" and result_ty != "uint64_t" )): continue
284
283
285
- if (not "typename T" in final_templates ) and (rt == "T" ):
286
- final_templates = ["typename T" ] + final_templates
287
-
288
- if len (overload_caps ) > 0 :
289
- if (("Float16" in cap and rt != "float16_t" ) or
290
- ("Float32" in cap and rt != "float32_t" ) or
291
- ("Float64" in cap and rt != "float64_t" ) or
292
- ("Int16" in cap and rt != "int16_t" and rt != "uint16_t" ) or
293
- ("Int64" in cap and rt != "int64_t" and rt != "uint64_t" )): continue
294
-
295
- if "Vector" in cap :
296
- rt = "vector<" + rt + ", N> "
297
- final_templates .append ("uint32_t N" )
298
-
299
- op_ty = "T"
300
- if prefered_op_ty != None :
301
- op_ty = prefered_op_ty
302
- elif rt != "void" :
303
- op_ty = rt
304
-
305
- args = []
306
- if "operands" in instruction :
307
- for operand in instruction ["operands" ]:
308
- operand_name = operand ["name" ].strip ("'" ) if "name" in operand else None
309
- operand_name = operand_name [0 ].lower () + operand_name [1 :] if (operand_name != None ) else ""
310
- match operand ["kind" ]:
311
- case "IdResult" | "IdResultType" : continue
312
- case "IdRef" :
313
- match operand ["name" ]:
314
- case "'Pointer'" :
315
- if shape == Shape .PTR_TEMPLATE :
316
- args .append ("P " + operand_name )
317
- elif shape == Shape .BDA :
318
- if (not "typename T" in final_templates ) and (rt == "T" or op_ty == "T" ):
319
- final_templates = ["typename T" ] + final_templates
320
- args .append ("pointer_t<spv::StorageClassPhysicalStorageBuffer, " + op_ty + "> " + operand_name )
321
- else :
322
- if (not "typename T" in final_templates ) and (rt == "T" or op_ty == "T" ):
323
- final_templates = ["typename T" ] + final_templates
324
- args .append ("[[vk::ext_reference]] " + op_ty + " " + operand_name )
325
- case "'Value'" | "'Object'" | "'Comparator'" | "'Base'" | "'Insert'" :
326
- if (not "typename T" in final_templates ) and (rt == "T" or op_ty == "T" ):
284
+ if "Vector" in cap :
285
+ result_ty = "vector<" + result_ty + ", N> "
286
+ final_templates .append ("uint32_t N" )
287
+
288
+ op_ty = "T"
289
+ if prefered_op_ty != None :
290
+ op_ty = prefered_op_ty
291
+ elif result_ty != "void" :
292
+ op_ty = result_ty
293
+
294
+ args = []
295
+ if "operands" in instruction :
296
+ for operand in instruction ["operands" ]:
297
+ operand_name = operand ["name" ].strip ("'" ) if "name" in operand else None
298
+ operand_name = operand_name [0 ].lower () + operand_name [1 :] if (operand_name != None ) else ""
299
+ match operand ["kind" ]:
300
+ case "IdResult" | "IdResultType" : continue
301
+ case "IdRef" :
302
+ match operand ["name" ]:
303
+ case "'Pointer'" :
304
+ if shape == Shape .PTR_TEMPLATE :
305
+ args .append ("P " + operand_name )
306
+ elif shape == Shape .BDA :
307
+ if (not "typename T" in final_templates ) and (result_ty == "T" or op_ty == "T" ):
308
+ final_templates = ["typename T" ] + final_templates
309
+ args .append ("pointer_t<spv::StorageClassPhysicalStorageBuffer, " + op_ty + "> " + operand_name )
310
+ else :
311
+ if (not "typename T" in final_templates ) and (result_ty == "T" or op_ty == "T" ):
327
312
final_templates = ["typename T" ] + final_templates
328
- args .append (op_ty + " " + operand_name )
329
- case "'Offset'" | "'Count'" | "'Id'" | "'Index'" | "'Mask'" | "'Delta'" :
330
- args .append ("uint32_t " + operand_name )
331
- case "'Predicate'" : args .append ("bool " + operand_name )
332
- case "'ClusterSize'" :
333
- if "quantifier" in operand and operand ["quantifier" ] == "?" : continue # TODO: overload
334
- else : return ignore (op_name ) # TODO
335
- case _: return ignore (op_name ) # TODO
336
- case "IdScope" : args .append ("uint32_t " + operand_name .lower () + "Scope" )
337
- case "IdMemorySemantics" : args .append (" uint32_t " + operand_name )
338
- case "GroupOperation" : args .append ("[[vk::ext_literal]] uint32_t " + operand_name )
339
- case "MemoryAccess" :
340
- assert len (overload_caps ) <= 1
341
- if shape != Shape .BDA :
342
- writeInst (writer , final_templates , cap , exts , op_name , final_fn_name , conds , rt , args + ["[[vk::ext_literal]] uint32_t memoryAccess" ])
343
- writeInst (writer , final_templates , cap , exts , op_name , final_fn_name , conds , rt , args + ["[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam" ])
344
- writeInst (writer , final_templates + ["uint32_t alignment" ], cap , exts , op_name , final_fn_name , conds , rt , args + ["[[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002" , "[[vk::ext_literal]] uint32_t __alignment = alignment" ])
345
- case _: return ignore (op_name ) # TODO
346
-
347
- writeInst (writer , final_templates , cap , exts , op_name , final_fn_name , conds , rt , args )
313
+ args .append ("[[vk::ext_reference]] " + op_ty + " " + operand_name )
314
+ case "'Value'" | "'Object'" | "'Comparator'" | "'Base'" | "'Insert'" :
315
+ if (not "typename T" in final_templates ) and (result_ty == "T" or op_ty == "T" ):
316
+ final_templates = ["typename T" ] + final_templates
317
+ args .append (op_ty + " " + operand_name )
318
+ case "'Offset'" | "'Count'" | "'Id'" | "'Index'" | "'Mask'" | "'Delta'" :
319
+ args .append ("uint32_t " + operand_name )
320
+ case "'Predicate'" : args .append ("bool " + operand_name )
321
+ case "'ClusterSize'" :
322
+ if "quantifier" in operand and operand ["quantifier" ] == "?" : continue # TODO: overload
323
+ else : return ignore (op_name ) # TODO
324
+ case _: return ignore (op_name ) # TODO
325
+ case "IdScope" : args .append ("uint32_t " + operand_name .lower () + "Scope" )
326
+ case "IdMemorySemantics" : args .append (" uint32_t " + operand_name )
327
+ case "GroupOperation" : args .append ("[[vk::ext_literal]] uint32_t " + operand_name )
328
+ case "MemoryAccess" :
329
+ assert len (caps ) <= 1
330
+ if shape != Shape .BDA :
331
+ writeInst (writer , final_templates , cap , exts , op_name , final_fn_name , conds , result_ty , args + ["[[vk::ext_literal]] uint32_t memoryAccess" ])
332
+ writeInst (writer , final_templates , cap , exts , op_name , final_fn_name , conds , result_ty , args + ["[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam" ])
333
+ writeInst (writer , final_templates + ["uint32_t alignment" ], cap , exts , op_name , final_fn_name , conds , result_ty , args + ["[[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002" , "[[vk::ext_literal]] uint32_t __alignment = alignment" ])
334
+ case _: return ignore (op_name ) # TODO
335
+
336
+ writeInst (writer , final_templates , cap , exts , op_name , final_fn_name , conds , result_ty , args )
348
337
349
338
350
339
def writeInst (writer : io .TextIOWrapper , templates , cap , exts , op_name , fn_name , conds , result_type , args ):
0 commit comments