Skip to content

Commit a2e0b6a

Browse files
committed
hlsl_generator: don't emit unneccesary overloads
Signed-off-by: Ali Cheraghi <[email protected]>
1 parent e0919e8 commit a2e0b6a

File tree

2 files changed

+149
-638
lines changed

2 files changed

+149
-638
lines changed

tools/hlsl_generator/gen.py

+81-92
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,6 @@ def processInst(writer: io.TextIOWrapper,
220220
conds = []
221221
op_name = instruction["opname"]
222222
fn_name = op_name[2].lower() + op_name[3:]
223-
result_types = []
224223
exts = instruction["extensions"] if "extensions" in instruction else []
225224

226225
if "capabilities" in instruction and len(instruction["capabilities"]) > 0:
@@ -244,107 +243,97 @@ def processInst(writer: io.TextIOWrapper,
244243
conds.append("(is_signed_v<T> || is_unsigned_v<T>)")
245244
break
246245
case "U":
247-
fn_name = fn_name[0:m[1][0]] + fn_name[m[1][1]:]
248-
result_types = ["uint16_t", "uint32_t", "uint64_t"]
246+
conds.append("is_unsigned_v<T>")
249247
break
250248
case "S":
251-
fn_name = fn_name[0:m[1][0]] + fn_name[m[1][1]:]
252-
result_types = ["int16_t", "int32_t", "int64_t"]
249+
conds.append("is_signed_v<T>")
253250
break
254251
case "F":
255-
fn_name = fn_name[0:m[1][0]] + fn_name[m[1][1]:]
256-
result_types = ["float16_t", "float32_t", "float64_t"]
252+
conds.append("is_floating_point<T>")
257253
break
258-
259-
match instruction["class"]:
260-
case "Bit":
261-
if len(result_types) == 0: conds.append("(is_signed_v<T> || is_unsigned_v<T>)")
254+
else:
255+
if instruction["class"] == "Bit":
256+
conds.append("(is_signed_v<T> || is_unsigned_v<T>)")
262257

263258
if "operands" in instruction and instruction["operands"][0]["kind"] == "IdResultType":
264-
if len(result_types) == 0:
265-
if result_ty == None:
266-
result_types = ["T"]
267-
else:
268-
result_types = [result_ty]
259+
if result_ty == None:
260+
result_ty = "T"
269261
else:
270-
assert len(result_types) == 0
271-
result_types = ["void"]
272-
273-
for rt in result_types:
274-
overload_caps = caps.copy()
275-
match rt:
276-
case "uint16_t" | "int16_t": overload_caps.append("Int16")
277-
case "uint64_t" | "int64_t": overload_caps.append("Int64")
278-
case "float16_t": overload_caps.append("Float16")
279-
case "float64_t": overload_caps.append("Float64")
280-
281-
for cap in overload_caps or [None]:
282-
final_fn_name = fn_name + "_" + cap if (len(overload_caps) > 1) else fn_name
283-
final_templates = templates.copy()
262+
result_ty = "void"
263+
264+
match result_ty:
265+
case "uint16_t" | "int16_t": caps.append("Int16")
266+
case "uint64_t" | "int64_t": caps.append("Int64")
267+
case "float16_t": caps.append("Float16")
268+
case "float64_t": caps.append("Float64")
269+
270+
for cap in caps or [None]:
271+
final_fn_name = fn_name + "_" + cap if (len(caps) > 1) else fn_name
272+
final_templates = templates.copy()
273+
274+
if (not "typename T" in final_templates) and (result_ty == "T"):
275+
final_templates = ["typename T"] + final_templates
276+
277+
if len(caps) > 0:
278+
if (("Float16" in cap and result_ty != "float16_t") or
279+
("Float32" in cap and result_ty != "float32_t") or
280+
("Float64" in cap and result_ty != "float64_t") or
281+
("Int16" in cap and result_ty != "int16_t" and result_ty != "uint16_t") or
282+
("Int64" in cap and result_ty != "int64_t" and result_ty != "uint64_t")): continue
284283

285-
if (not "typename T" in final_templates) and (rt == "T"):
286-
final_templates = ["typename T"] + final_templates
287-
288-
if len(overload_caps) > 0:
289-
if (("Float16" in cap and rt != "float16_t") or
290-
("Float32" in cap and rt != "float32_t") or
291-
("Float64" in cap and rt != "float64_t") or
292-
("Int16" in cap and rt != "int16_t" and rt != "uint16_t") or
293-
("Int64" in cap and rt != "int64_t" and rt != "uint64_t")): continue
294-
295-
if "Vector" in cap:
296-
rt = "vector<" + rt + ", N> "
297-
final_templates.append("uint32_t N")
298-
299-
op_ty = "T"
300-
if prefered_op_ty != None:
301-
op_ty = prefered_op_ty
302-
elif rt != "void":
303-
op_ty = rt
304-
305-
args = []
306-
if "operands" in instruction:
307-
for operand in instruction["operands"]:
308-
operand_name = operand["name"].strip("'") if "name" in operand else None
309-
operand_name = operand_name[0].lower() + operand_name[1:] if (operand_name != None) else ""
310-
match operand["kind"]:
311-
case "IdResult" | "IdResultType": continue
312-
case "IdRef":
313-
match operand["name"]:
314-
case "'Pointer'":
315-
if shape == Shape.PTR_TEMPLATE:
316-
args.append("P " + operand_name)
317-
elif shape == Shape.BDA:
318-
if (not "typename T" in final_templates) and (rt == "T" or op_ty == "T"):
319-
final_templates = ["typename T"] + final_templates
320-
args.append("pointer_t<spv::StorageClassPhysicalStorageBuffer, " + op_ty + "> " + operand_name)
321-
else:
322-
if (not "typename T" in final_templates) and (rt == "T" or op_ty == "T"):
323-
final_templates = ["typename T"] + final_templates
324-
args.append("[[vk::ext_reference]] " + op_ty + " " + operand_name)
325-
case "'Value'" | "'Object'" | "'Comparator'" | "'Base'" | "'Insert'":
326-
if (not "typename T" in final_templates) and (rt == "T" or op_ty == "T"):
284+
if "Vector" in cap:
285+
result_ty = "vector<" + result_ty + ", N> "
286+
final_templates.append("uint32_t N")
287+
288+
op_ty = "T"
289+
if prefered_op_ty != None:
290+
op_ty = prefered_op_ty
291+
elif result_ty != "void":
292+
op_ty = result_ty
293+
294+
args = []
295+
if "operands" in instruction:
296+
for operand in instruction["operands"]:
297+
operand_name = operand["name"].strip("'") if "name" in operand else None
298+
operand_name = operand_name[0].lower() + operand_name[1:] if (operand_name != None) else ""
299+
match operand["kind"]:
300+
case "IdResult" | "IdResultType": continue
301+
case "IdRef":
302+
match operand["name"]:
303+
case "'Pointer'":
304+
if shape == Shape.PTR_TEMPLATE:
305+
args.append("P " + operand_name)
306+
elif shape == Shape.BDA:
307+
if (not "typename T" in final_templates) and (result_ty == "T" or op_ty == "T"):
308+
final_templates = ["typename T"] + final_templates
309+
args.append("pointer_t<spv::StorageClassPhysicalStorageBuffer, " + op_ty + "> " + operand_name)
310+
else:
311+
if (not "typename T" in final_templates) and (result_ty == "T" or op_ty == "T"):
327312
final_templates = ["typename T"] + final_templates
328-
args.append(op_ty + " " + operand_name)
329-
case "'Offset'" | "'Count'" | "'Id'" | "'Index'" | "'Mask'" | "'Delta'":
330-
args.append("uint32_t " + operand_name)
331-
case "'Predicate'": args.append("bool " + operand_name)
332-
case "'ClusterSize'":
333-
if "quantifier" in operand and operand["quantifier"] == "?": continue # TODO: overload
334-
else: return ignore(op_name) # TODO
335-
case _: return ignore(op_name) # TODO
336-
case "IdScope": args.append("uint32_t " + operand_name.lower() + "Scope")
337-
case "IdMemorySemantics": args.append(" uint32_t " + operand_name)
338-
case "GroupOperation": args.append("[[vk::ext_literal]] uint32_t " + operand_name)
339-
case "MemoryAccess":
340-
assert len(overload_caps) <= 1
341-
if shape != Shape.BDA:
342-
writeInst(writer, final_templates, cap, exts, op_name, final_fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t memoryAccess"])
343-
writeInst(writer, final_templates, cap, exts, op_name, final_fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam"])
344-
writeInst(writer, final_templates + ["uint32_t alignment"], cap, exts, op_name, final_fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002", "[[vk::ext_literal]] uint32_t __alignment = alignment"])
345-
case _: return ignore(op_name) # TODO
346-
347-
writeInst(writer, final_templates, cap, exts, op_name, final_fn_name, conds, rt, args)
313+
args.append("[[vk::ext_reference]] " + op_ty + " " + operand_name)
314+
case "'Value'" | "'Object'" | "'Comparator'" | "'Base'" | "'Insert'":
315+
if (not "typename T" in final_templates) and (result_ty == "T" or op_ty == "T"):
316+
final_templates = ["typename T"] + final_templates
317+
args.append(op_ty + " " + operand_name)
318+
case "'Offset'" | "'Count'" | "'Id'" | "'Index'" | "'Mask'" | "'Delta'":
319+
args.append("uint32_t " + operand_name)
320+
case "'Predicate'": args.append("bool " + operand_name)
321+
case "'ClusterSize'":
322+
if "quantifier" in operand and operand["quantifier"] == "?": continue # TODO: overload
323+
else: return ignore(op_name) # TODO
324+
case _: return ignore(op_name) # TODO
325+
case "IdScope": args.append("uint32_t " + operand_name.lower() + "Scope")
326+
case "IdMemorySemantics": args.append(" uint32_t " + operand_name)
327+
case "GroupOperation": args.append("[[vk::ext_literal]] uint32_t " + operand_name)
328+
case "MemoryAccess":
329+
assert len(caps) <= 1
330+
if shape != Shape.BDA:
331+
writeInst(writer, final_templates, cap, exts, op_name, final_fn_name, conds, result_ty, args + ["[[vk::ext_literal]] uint32_t memoryAccess"])
332+
writeInst(writer, final_templates, cap, exts, op_name, final_fn_name, conds, result_ty, args + ["[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam"])
333+
writeInst(writer, final_templates + ["uint32_t alignment"], cap, exts, op_name, final_fn_name, conds, result_ty, args + ["[[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002", "[[vk::ext_literal]] uint32_t __alignment = alignment"])
334+
case _: return ignore(op_name) # TODO
335+
336+
writeInst(writer, final_templates, cap, exts, op_name, final_fn_name, conds, result_ty, args)
348337

349338

350339
def writeInst(writer: io.TextIOWrapper, templates, cap, exts, op_name, fn_name, conds, result_type, args):

0 commit comments

Comments
 (0)