Skip to content

Commit 20034bb

Browse files
committed
Add HLSL generator
1 parent 719dc53 commit 20034bb

File tree

2 files changed

+1651
-0
lines changed

2 files changed

+1651
-0
lines changed

tools/hlsl_generator/gen.py

+311
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,311 @@
1+
import json
2+
import io
3+
import os
4+
import re
5+
from enum import Enum
6+
from argparse import ArgumentParser
7+
from typing import NamedTuple
8+
from typing import Optional
9+
10+
head = """// Copyright (C) 2023 - DevSH Graphics Programming Sp. z O.O.
11+
// This file is part of the "Nabla Engine".
12+
// For conditions of distribution and use, see copyright notice in nabla.h
13+
#ifndef _NBL_BUILTIN_HLSL_SPIRV_INTRINSICS_CORE_INCLUDED_
14+
#define _NBL_BUILTIN_HLSL_SPIRV_INTRINSICS_CORE_INCLUDED_
15+
16+
#ifdef __HLSL_VERSION
17+
#include "spirv/unified1/spirv.hpp"
18+
#include "spirv/unified1/GLSL.std.450.h"
19+
#endif
20+
21+
#include "nbl/builtin/hlsl/type_traits.hlsl"
22+
23+
namespace nbl
24+
{
25+
namespace hlsl
26+
{
27+
#ifdef __HLSL_VERSION
28+
namespace spirv
29+
{
30+
31+
//! General Decls
32+
template<uint32_t StorageClass, typename T>
33+
using pointer_t = vk::SpirvOpaqueType<spv::OpTypePointer, vk::Literal< vk::integral_constant<uint32_t, StorageClass> >, T>;
34+
35+
// The holy operation that makes addrof possible
36+
template<uint32_t StorageClass, typename T>
37+
[[vk::ext_instruction(spv::OpCopyObject)]]
38+
pointer_t<StorageClass, T> copyObject([[vk::ext_reference]] T value);
39+
40+
//! Std 450 Extended set operations
41+
template<typename SquareMatrix>
42+
[[vk::ext_instruction(GLSLstd450MatrixInverse)]]
43+
SquareMatrix matrixInverse(NBL_CONST_REF_ARG(SquareMatrix) mat);
44+
45+
// Add specializations if you need to emit a `ext_capability` (this means that the instruction needs to forward through an `impl::` struct and so on)
46+
template<typename T, typename U>
47+
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
48+
[[vk::ext_instruction(spv::OpBitcast)]]
49+
enable_if_t<is_spirv_type_v<T> && is_spirv_type_v<U>, T> bitcast(U);
50+
51+
template<typename T>
52+
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
53+
[[vk::ext_instruction(spv::OpBitcast)]]
54+
uint64_t bitcast(pointer_t<spv::StorageClassPhysicalStorageBuffer,T>);
55+
56+
template<typename T>
57+
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
58+
[[vk::ext_instruction(spv::OpBitcast)]]
59+
pointer_t<spv::StorageClassPhysicalStorageBuffer,T> bitcast(uint64_t);
60+
61+
template<class T, class U>
62+
[[vk::ext_instruction(spv::OpBitcast)]]
63+
T bitcast(U);
64+
"""
65+
66+
foot = """}
67+
68+
#endif
69+
}
70+
}
71+
72+
#endif
73+
"""
74+
75+
def gen(grammer_path, output_path):
76+
grammer_raw = open(grammer_path, "r").read()
77+
grammer = json.loads(grammer_raw)
78+
del grammer_raw
79+
80+
output = open(output_path, "w", buffering=1024**2)
81+
82+
builtins = [x for x in grammer["operand_kinds"] if x["kind"] == "BuiltIn"][0]["enumerants"]
83+
execution_modes = [x for x in grammer["operand_kinds"] if x["kind"] == "ExecutionMode"][0]["enumerants"]
84+
group_operations = [x for x in grammer["operand_kinds"] if x["kind"] == "GroupOperation"][0]["enumerants"]
85+
86+
with output as writer:
87+
writer.write(head)
88+
89+
writer.write("\n//! Builtins\nnamespace builtin\n{")
90+
for b in builtins:
91+
builtin_type = None
92+
is_output = False
93+
builtin_name = b["enumerant"]
94+
match builtin_name:
95+
case "HelperInvocation": builtin_type = "bool"
96+
case "VertexIndex": builtin_type = "uint32_t"
97+
case "InstanceIndex": builtin_type = "uint32_t"
98+
case "NumWorkgroups": builtin_type = "uint32_t3"
99+
case "WorkgroupId": builtin_type = "uint32_t3"
100+
case "LocalInvocationId": builtin_type = "uint32_t3"
101+
case "GlobalInvocationId": builtin_type = "uint32_t3"
102+
case "LocalInvocationIndex": builtin_type = "uint32_t"
103+
case "SubgroupEqMask": builtin_type = "uint32_t4"
104+
case "SubgroupGeMask": builtin_type = "uint32_t4"
105+
case "SubgroupGtMask": builtin_type = "uint32_t4"
106+
case "SubgroupLeMask": builtin_type = "uint32_t4"
107+
case "SubgroupLtMask": builtin_type = "uint32_t4"
108+
case "SubgroupSize": builtin_type = "uint32_t"
109+
case "NumSubgroups": builtin_type = "uint32_t"
110+
case "SubgroupId": builtin_type = "uint32_t"
111+
case "SubgroupLocalInvocationId": builtin_type = "uint32_t"
112+
case "Position":
113+
builtin_type = "float32_t4"
114+
is_output = True
115+
case _: continue
116+
if is_output:
117+
writer.write("[[vk::ext_builtin_output(spv::BuiltIn" + builtin_name + ")]]\n")
118+
writer.write("static " + builtin_type + " " + builtin_name + ";\n")
119+
else:
120+
writer.write("[[vk::ext_builtin_input(spv::BuiltIn" + builtin_name + ")]]\n")
121+
writer.write("static const " + builtin_type + " " + builtin_name + ";\n")
122+
writer.write("}\n")
123+
124+
writer.write("\n//! Execution Modes\nnamespace execution_mode\n{")
125+
for em in execution_modes:
126+
name = em["enumerant"]
127+
name_l = name[0].lower() + name[1:]
128+
writer.write("\n\tvoid " + name_l + "()\n\t{\n\t\tvk::ext_execution_mode(spv::ExecutionMode" + name + ");\n\t}\n")
129+
writer.write("}\n")
130+
131+
writer.write("\n//! Group Operations\nnamespace group_operation\n{\n")
132+
for go in group_operations:
133+
name = go["enumerant"]
134+
value = go["value"]
135+
writer.write("\tstatic const uint32_t " + name + " = " + str(value) + ";\n")
136+
writer.write("}\n")
137+
138+
writer.write("\n//! Instructions\n")
139+
for instruction in grammer["instructions"]:
140+
match instruction["class"]:
141+
case "Atomic":
142+
processInst(writer, instruction, InstOptions())
143+
processInst(writer, instruction, InstOptions(shape=Shape.PTR_TEMPLATE))
144+
case "Memory":
145+
processInst(writer, instruction, InstOptions(shape=Shape.PTR_TEMPLATE))
146+
processInst(writer, instruction, InstOptions(shape=Shape.PSB_RT))
147+
case "Barrier" | "Bit":
148+
processInst(writer, instruction, InstOptions())
149+
case "Reserved":
150+
match instruction["opname"]:
151+
case "OpBeginInvocationInterlockEXT" | "OpEndInvocationInterlockEXT":
152+
processInst(writer, instruction, InstOptions())
153+
case "Non-Uniform":
154+
match instruction["opname"]:
155+
case "OpGroupNonUniformElect" | "OpGroupNonUniformAll" | "OpGroupNonUniformAny" | "OpGroupNonUniformAllEqual":
156+
processInst(writer, instruction, InstOptions(result_ty="bool"))
157+
case "OpGroupNonUniformBallot":
158+
processInst(writer, instruction, InstOptions(result_ty="uint32_t4",op_ty="bool"))
159+
case "OpGroupNonUniformInverseBallot" | "OpGroupNonUniformBallotBitExtract":
160+
processInst(writer, instruction, InstOptions(result_ty="bool",op_ty="uint32_t4"))
161+
case "OpGroupNonUniformBallotBitCount" | "OpGroupNonUniformBallotFindLSB" | "OpGroupNonUniformBallotFindMSB":
162+
processInst(writer, instruction, InstOptions(result_ty="uint32_t",op_ty="uint32_t4"))
163+
case _: processInst(writer, instruction, InstOptions())
164+
case _: continue # TODO
165+
166+
writer.write(foot)
167+
168+
class Shape(Enum):
169+
DEFAULT = 0,
170+
PTR_TEMPLATE = 1, # TODO: this is a DXC Workaround
171+
PSB_RT = 2, # PhysicalStorageBuffer Result Type
172+
173+
class InstOptions(NamedTuple):
174+
shape: Shape = Shape.DEFAULT
175+
result_ty: Optional[str] = None
176+
op_ty: Optional[str] = None
177+
178+
def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
179+
templates = []
180+
caps = []
181+
conds = []
182+
op_name = instruction["opname"]
183+
fn_name = op_name[2].lower() + op_name[3:]
184+
result_types = []
185+
186+
if "capabilities" in instruction and len(instruction["capabilities"]) > 0:
187+
for cap in instruction["capabilities"]:
188+
if cap == "Shader" or cap == "Kernel": continue
189+
caps.append(cap)
190+
191+
if options.shape == Shape.PTR_TEMPLATE:
192+
templates.append("typename P")
193+
conds.append("is_spirv_type_v<P>")
194+
195+
# split upper case words
196+
matches = [(m.group(1), m.span(1)) for m in re.finditer(r'([A-Z])[A-Z][a-z]', fn_name)]
197+
198+
for m in matches:
199+
match m[0]:
200+
case "I":
201+
conds.append("(is_signed_v<T> || is_unsigned_v<T>)")
202+
break
203+
case "U":
204+
fn_name = fn_name[0:m[1][0]] + fn_name[m[1][1]:]
205+
result_types = ["uint32_t", "uint64_t"]
206+
break
207+
case "S":
208+
fn_name = fn_name[0:m[1][0]] + fn_name[m[1][1]:]
209+
result_types = ["int32_t", "int64_t"]
210+
break
211+
case "F":
212+
fn_name = fn_name[0:m[1][0]] + fn_name[m[1][1]:]
213+
result_types = ["float"]
214+
break
215+
216+
if "operands" in instruction:
217+
operands = instruction["operands"]
218+
if operands[0]["kind"] == "IdResultType":
219+
operands = operands[2:]
220+
if len(result_types) == 0:
221+
if options.result_ty == None:
222+
result_types = ["T"]
223+
else:
224+
result_types = [options.result_ty]
225+
else:
226+
assert len(result_types) == 0
227+
result_types = ["void"]
228+
229+
for rt in result_types:
230+
op_ty = "T"
231+
if options.op_ty != None:
232+
op_ty = options.op_ty
233+
elif rt != "void":
234+
op_ty = rt
235+
236+
if (not "typename T" in templates) and (rt == "T"):
237+
templates = ["typename T"] + templates
238+
239+
args = []
240+
for operand in operands:
241+
operand_name = operand["name"].strip("'") if "name" in operand else None
242+
operand_name = operand_name[0].lower() + operand_name[1:] if (operand_name != None) else ""
243+
match operand["kind"]:
244+
case "IdRef":
245+
match operand["name"]:
246+
case "'Pointer'":
247+
if options.shape == Shape.PTR_TEMPLATE:
248+
args.append("P " + operand_name)
249+
elif options.shape == Shape.PSB_RT:
250+
if (not "typename T" in templates) and (rt == "T" or op_ty == "T"):
251+
templates = ["typename T"] + templates
252+
args.append("pointer_t<spv::StorageClassPhysicalStorageBuffer, " + op_ty + "> " + operand_name)
253+
else:
254+
if (not "typename T" in templates) and (rt == "T" or op_ty == "T"):
255+
templates = ["typename T"] + templates
256+
args.append("[[vk::ext_reference]] " + op_ty + " " + operand_name)
257+
case "'Value'" | "'Object'" | "'Comparator'" | "'Base'" | "'Insert'":
258+
if (not "typename T" in templates) and (rt == "T" or op_ty == "T"):
259+
templates = ["typename T"] + templates
260+
args.append(op_ty + " " + operand_name)
261+
case "'Offset'" | "'Count'" | "'Id'" | "'Index'" | "'Mask'" | "'Delta'":
262+
args.append("uint32_t " + operand_name)
263+
case "'Predicate'": args.append("bool " + operand_name)
264+
case "'ClusterSize'":
265+
if "quantifier" in operand and operand["quantifier"] == "?": continue # TODO: overload
266+
else: return # TODO
267+
case _: return # TODO
268+
case "IdScope": args.append("uint32_t " + operand_name.lower() + "Scope")
269+
case "IdMemorySemantics": args.append(" uint32_t " + operand_name)
270+
case "GroupOperation": args.append("[[vk::ext_literal]] uint32_t " + operand_name)
271+
case "MemoryAccess":
272+
writeInst(writer, templates, caps, op_name, fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t memoryAccess"])
273+
writeInst(writer, templates, caps, op_name, fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam"])
274+
writeInst(writer, templates + ["uint32_t alignment"], caps, op_name, fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002", "[[vk::ext_literal]] uint32_t __alignment = alignment"])
275+
case _: return # TODO
276+
277+
writeInst(writer, templates, caps, op_name, fn_name, conds, rt, args)
278+
279+
280+
def writeInst(writer: io.TextIOWrapper, templates, caps, op_name, fn_name, conds, result_type, args):
281+
if len(caps) > 0:
282+
for cap in caps:
283+
final_fn_name = fn_name
284+
if (len(caps) > 1): final_fn_name = fn_name + "_" + cap
285+
writeInstInner(writer, templates, cap, op_name, final_fn_name, conds, result_type, args)
286+
else:
287+
writeInstInner(writer, templates, None, op_name, fn_name, conds, result_type, args)
288+
289+
def writeInstInner(writer: io.TextIOWrapper, templates, cap, op_name, fn_name, conds, result_type, args):
290+
if len(templates) > 0:
291+
writer.write("template<" + ", ".join(templates) + ">\n")
292+
if (cap != None):
293+
writer.write("[[vk::ext_capability(spv::Capability" + cap + ")]]\n")
294+
writer.write("[[vk::ext_instruction(spv::" + op_name + ")]]\n")
295+
if len(conds) > 0:
296+
writer.write("enable_if_t<" + " && ".join(conds) + ", " + result_type + ">")
297+
else:
298+
writer.write(result_type)
299+
writer.write(" " + fn_name + "(" + ", ".join(args) + ");\n\n")
300+
301+
302+
if __name__ == "__main__":
303+
script_dir_path = os.path.abspath(os.path.dirname(__file__))
304+
305+
parser = ArgumentParser(description="Generate HLSL from SPIR-V instructions")
306+
parser.add_argument("output", type=str, help="HLSL output file")
307+
parser.add_argument("--grammer", required=False, type=str, help="Input SPIR-V grammer JSON file", default=os.path.join(script_dir_path, "../../include/spirv/unified1/spirv.core.grammar.json"))
308+
args = parser.parse_args()
309+
310+
gen(args.grammer, args.output)
311+

0 commit comments

Comments
 (0)