Skip to content

Commit c387d96

Browse files
committed
hlsl_generator: handwritten BDA instructions
Signed-off-by: Ali Cheraghi <[email protected]>
1 parent a2e0b6a commit c387d96

File tree

2 files changed

+63
-87
lines changed

2 files changed

+63
-87
lines changed

tools/hlsl_generator/gen.py

+27-19
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,6 @@
2929
{
3030
3131
//! General Decls
32-
template<class T>
33-
NBL_CONSTEXPR_STATIC_INLINE bool is_pointer_v = is_spirv_type<T>::value;
34-
3532
template<uint32_t StorageClass, typename T>
3633
struct pointer
3734
{
@@ -47,6 +44,9 @@
4744
template<uint32_t StorageClass, typename T>
4845
using pointer_t = typename pointer<StorageClass, T>::type;
4946
47+
template<uint32_t StorageClass, typename T>
48+
NBL_CONSTEXPR_STATIC_INLINE bool is_pointer_v = is_same_v<T, typename pointer<StorageClass, T>::type >;
49+
5050
// The holy operation that makes addrof possible
5151
template<uint32_t StorageClass, typename T>
5252
[[vk::ext_instruction(spv::OpCopyObject)]]
@@ -58,11 +58,31 @@
5858
[[vk::ext_instruction(34 /* GLSLstd450MatrixInverse */, "GLSL.std.450")]]
5959
SquareMatrix matrixInverse(NBL_CONST_REF_ARG(SquareMatrix) mat);
6060
61+
//! Memory instructions
62+
template<typename T, uint32_t alignment>
63+
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
64+
[[vk::ext_instruction(spv::OpLoad)]]
65+
T load(pointer_t<spv::StorageClassPhysicalStorageBuffer, T> pointer, [[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002, [[vk::ext_literal]] uint32_t __alignment = alignment);
66+
67+
template<typename T, typename P>
68+
[[vk::ext_instruction(spv::OpLoad)]]
69+
enable_if_t<is_spirv_type_v<P>, T> load(P pointer);
70+
71+
template<typename T, uint32_t alignment>
72+
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
73+
[[vk::ext_instruction(spv::OpStore)]]
74+
void store(pointer_t<spv::StorageClassPhysicalStorageBuffer, T> pointer, T obj, [[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002, [[vk::ext_literal]] uint32_t __alignment = alignment);
75+
76+
template<typename T, typename P>
77+
[[vk::ext_instruction(spv::OpStore)]]
78+
enable_if_t<is_spirv_type_v<P>, void> store(P pointer, T obj);
79+
80+
//! Bitcast Instructions
6181
// Add specializations if you need to emit a `ext_capability` (this means that the instruction needs to forward through an `impl::` struct and so on)
6282
template<typename T, typename U>
6383
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
6484
[[vk::ext_instruction(spv::OpBitcast)]]
65-
enable_if_t<is_pointer_v<T>, T> bitcast(U);
85+
enable_if_t<is_pointer_v<spv::StorageClassPhysicalStorageBuffer, T>, T> bitcast(U);
6686
6787
template<typename T>
6888
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
@@ -181,9 +201,6 @@ def gen(grammer_path, output_path):
181201
case "Atomic":
182202
processInst(writer, instruction)
183203
processInst(writer, instruction, Shape.PTR_TEMPLATE)
184-
case "Memory":
185-
processInst(writer, instruction, Shape.PTR_TEMPLATE)
186-
processInst(writer, instruction, Shape.BDA)
187204
case "Barrier" | "Bit":
188205
processInst(writer, instruction)
189206
case "Reserved":
@@ -208,7 +225,6 @@ def gen(grammer_path, output_path):
208225
class Shape(Enum):
209226
DEFAULT = 0,
210227
PTR_TEMPLATE = 1, # TODO: this is a DXC Workaround
211-
BDA = 2, # PhysicalStorageBuffer Result Type
212228

213229
def processInst(writer: io.TextIOWrapper,
214230
instruction,
@@ -231,8 +247,6 @@ def processInst(writer: io.TextIOWrapper,
231247
if shape == Shape.PTR_TEMPLATE:
232248
templates.append("typename P")
233249
conds.append("is_spirv_type_v<P>")
234-
elif shape == Shape.BDA:
235-
caps.append("PhysicalStorageBufferAddresses")
236250

237251
# split upper case words
238252
matches = [(m.group(1), m.span(1)) for m in re.finditer(r'([A-Z])[A-Z][a-z]', fn_name)]
@@ -249,7 +263,7 @@ def processInst(writer: io.TextIOWrapper,
249263
conds.append("is_signed_v<T>")
250264
break
251265
case "F":
252-
conds.append("is_floating_point<T>")
266+
conds.append("(is_same_v<float16_t, T> || is_same_v<float32_t, T> || is_same_v<float64_t, T>)")
253267
break
254268
else:
255269
if instruction["class"] == "Bit":
@@ -303,10 +317,6 @@ def processInst(writer: io.TextIOWrapper,
303317
case "'Pointer'":
304318
if shape == Shape.PTR_TEMPLATE:
305319
args.append("P " + operand_name)
306-
elif shape == Shape.BDA:
307-
if (not "typename T" in final_templates) and (result_ty == "T" or op_ty == "T"):
308-
final_templates = ["typename T"] + final_templates
309-
args.append("pointer_t<spv::StorageClassPhysicalStorageBuffer, " + op_ty + "> " + operand_name)
310320
else:
311321
if (not "typename T" in final_templates) and (result_ty == "T" or op_ty == "T"):
312322
final_templates = ["typename T"] + final_templates
@@ -327,10 +337,8 @@ def processInst(writer: io.TextIOWrapper,
327337
case "GroupOperation": args.append("[[vk::ext_literal]] uint32_t " + operand_name)
328338
case "MemoryAccess":
329339
assert len(caps) <= 1
330-
if shape != Shape.BDA:
331-
writeInst(writer, final_templates, cap, exts, op_name, final_fn_name, conds, result_ty, args + ["[[vk::ext_literal]] uint32_t memoryAccess"])
332-
writeInst(writer, final_templates, cap, exts, op_name, final_fn_name, conds, result_ty, args + ["[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam"])
333-
writeInst(writer, final_templates + ["uint32_t alignment"], cap, exts, op_name, final_fn_name, conds, result_ty, args + ["[[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002", "[[vk::ext_literal]] uint32_t __alignment = alignment"])
340+
writeInst(writer, final_templates, cap, exts, op_name, final_fn_name, conds, result_ty, args + ["[[vk::ext_literal]] uint32_t memoryAccess"])
341+
writeInst(writer, final_templates, cap, exts, op_name, final_fn_name, conds, result_ty, args + ["[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam"])
334342
case _: return ignore(op_name) # TODO
335343

336344
writeInst(writer, final_templates, cap, exts, op_name, final_fn_name, conds, result_ty, args)

tools/hlsl_generator/out.hlsl

+36-68
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,6 @@ namespace spirv
2020
{
2121

2222
//! General Decls
23-
template<class T>
24-
NBL_CONSTEXPR_STATIC_INLINE bool is_pointer_v = is_spirv_type<T>::value;
25-
2623
template<uint32_t StorageClass, typename T>
2724
struct pointer
2825
{
@@ -38,6 +35,9 @@ struct pointer<spv::StorageClassPhysicalStorageBuffer, T>
3835
template<uint32_t StorageClass, typename T>
3936
using pointer_t = typename pointer<StorageClass, T>::type;
4037

38+
template<uint32_t StorageClass, typename T>
39+
NBL_CONSTEXPR_STATIC_INLINE bool is_pointer_v = is_same_v<T, typename pointer<StorageClass, T>::type >;
40+
4141
// The holy operation that makes addrof possible
4242
template<uint32_t StorageClass, typename T>
4343
[[vk::ext_instruction(spv::OpCopyObject)]]
@@ -49,11 +49,31 @@ template<typename SquareMatrix>
4949
[[vk::ext_instruction(34 /* GLSLstd450MatrixInverse */, "GLSL.std.450")]]
5050
SquareMatrix matrixInverse(NBL_CONST_REF_ARG(SquareMatrix) mat);
5151

52+
//! Memory instructions
53+
template<typename T, uint32_t alignment>
54+
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
55+
[[vk::ext_instruction(spv::OpLoad)]]
56+
T load(pointer_t<spv::StorageClassPhysicalStorageBuffer, T> pointer, [[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002, [[vk::ext_literal]] uint32_t __alignment = alignment);
57+
58+
template<typename T, typename P>
59+
[[vk::ext_instruction(spv::OpLoad)]]
60+
enable_if_t<is_spirv_type_v<P>, T> load(P pointer);
61+
62+
template<typename T, uint32_t alignment>
63+
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
64+
[[vk::ext_instruction(spv::OpStore)]]
65+
void store(pointer_t<spv::StorageClassPhysicalStorageBuffer, T> pointer, T obj, [[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002, [[vk::ext_literal]] uint32_t __alignment = alignment);
66+
67+
template<typename T, typename P>
68+
[[vk::ext_instruction(spv::OpStore)]]
69+
enable_if_t<is_spirv_type_v<P>, void> store(P pointer, T obj);
70+
71+
//! Bitcast Instructions
5272
// Add specializations if you need to emit a `ext_capability` (this means that the instruction needs to forward through an `impl::` struct and so on)
5373
template<typename T, typename U>
5474
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
5575
[[vk::ext_instruction(spv::OpBitcast)]]
56-
enable_if_t<is_pointer_v<T>, T> bitcast(U);
76+
enable_if_t<is_pointer_v<spv::StorageClassPhysicalStorageBuffer, T>, T> bitcast(U);
5777

5878
template<typename T>
5979
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
@@ -548,58 +568,6 @@ namespace group_operation
548568
}
549569

550570
//! Instructions
551-
template<typename T, typename P>
552-
[[vk::ext_instruction(spv::OpLoad)]]
553-
enable_if_t<is_spirv_type_v<P>, T> load(P pointer, [[vk::ext_literal]] uint32_t memoryAccess);
554-
555-
template<typename T, typename P>
556-
[[vk::ext_instruction(spv::OpLoad)]]
557-
enable_if_t<is_spirv_type_v<P>, T> load(P pointer, [[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam);
558-
559-
template<typename T, typename P, uint32_t alignment>
560-
[[vk::ext_instruction(spv::OpLoad)]]
561-
enable_if_t<is_spirv_type_v<P>, T> load(P pointer, [[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002, [[vk::ext_literal]] uint32_t __alignment = alignment);
562-
563-
template<typename T, typename P>
564-
[[vk::ext_instruction(spv::OpLoad)]]
565-
enable_if_t<is_spirv_type_v<P>, T> load(P pointer);
566-
567-
template<typename T, uint32_t alignment>
568-
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
569-
[[vk::ext_instruction(spv::OpLoad)]]
570-
T load(pointer_t<spv::StorageClassPhysicalStorageBuffer, T> pointer, [[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002, [[vk::ext_literal]] uint32_t __alignment = alignment);
571-
572-
template<typename T>
573-
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
574-
[[vk::ext_instruction(spv::OpLoad)]]
575-
T load(pointer_t<spv::StorageClassPhysicalStorageBuffer, T> pointer);
576-
577-
template<typename T, typename P>
578-
[[vk::ext_instruction(spv::OpStore)]]
579-
enable_if_t<is_spirv_type_v<P>, void> store(P pointer, T object, [[vk::ext_literal]] uint32_t memoryAccess);
580-
581-
template<typename T, typename P>
582-
[[vk::ext_instruction(spv::OpStore)]]
583-
enable_if_t<is_spirv_type_v<P>, void> store(P pointer, T object, [[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam);
584-
585-
template<typename T, typename P, uint32_t alignment>
586-
[[vk::ext_instruction(spv::OpStore)]]
587-
enable_if_t<is_spirv_type_v<P>, void> store(P pointer, T object, [[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002, [[vk::ext_literal]] uint32_t __alignment = alignment);
588-
589-
template<typename T, typename P>
590-
[[vk::ext_instruction(spv::OpStore)]]
591-
enable_if_t<is_spirv_type_v<P>, void> store(P pointer, T object);
592-
593-
template<typename T, uint32_t alignment>
594-
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
595-
[[vk::ext_instruction(spv::OpStore)]]
596-
void store(pointer_t<spv::StorageClassPhysicalStorageBuffer, T> pointer, T object, [[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002, [[vk::ext_literal]] uint32_t __alignment = alignment);
597-
598-
template<typename T>
599-
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
600-
[[vk::ext_instruction(spv::OpStore)]]
601-
void store(pointer_t<spv::StorageClassPhysicalStorageBuffer, T> pointer, T object);
602-
603571
template<typename T>
604572
[[vk::ext_capability(spv::CapabilityBitInstructions)]]
605573
[[vk::ext_instruction(spv::OpBitFieldInsert)]]
@@ -838,17 +806,17 @@ enable_if_t<(is_signed_v<T> || is_unsigned_v<T>), T> groupNonUniformIAdd_GroupNo
838806
template<typename T>
839807
[[vk::ext_capability(spv::CapabilityGroupNonUniformArithmetic)]]
840808
[[vk::ext_instruction(spv::OpGroupNonUniformFAdd)]]
841-
enable_if_t<is_floating_point<T>, T> groupNonUniformFAdd_GroupNonUniformArithmetic(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
809+
enable_if_t<(is_same_v<float16_t, T> || is_same_v<float32_t, T> || is_same_v<float64_t, T>), T> groupNonUniformFAdd_GroupNonUniformArithmetic(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
842810

843811
template<typename T>
844812
[[vk::ext_capability(spv::CapabilityGroupNonUniformClustered)]]
845813
[[vk::ext_instruction(spv::OpGroupNonUniformFAdd)]]
846-
enable_if_t<is_floating_point<T>, T> groupNonUniformFAdd_GroupNonUniformClustered(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
814+
enable_if_t<(is_same_v<float16_t, T> || is_same_v<float32_t, T> || is_same_v<float64_t, T>), T> groupNonUniformFAdd_GroupNonUniformClustered(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
847815

848816
template<typename T>
849817
[[vk::ext_capability(spv::CapabilityGroupNonUniformPartitionedNV)]]
850818
[[vk::ext_instruction(spv::OpGroupNonUniformFAdd)]]
851-
enable_if_t<is_floating_point<T>, T> groupNonUniformFAdd_GroupNonUniformPartitionedNV(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
819+
enable_if_t<(is_same_v<float16_t, T> || is_same_v<float32_t, T> || is_same_v<float64_t, T>), T> groupNonUniformFAdd_GroupNonUniformPartitionedNV(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
852820

853821
template<typename T>
854822
[[vk::ext_capability(spv::CapabilityGroupNonUniformArithmetic)]]
@@ -868,17 +836,17 @@ enable_if_t<(is_signed_v<T> || is_unsigned_v<T>), T> groupNonUniformIMul_GroupNo
868836
template<typename T>
869837
[[vk::ext_capability(spv::CapabilityGroupNonUniformArithmetic)]]
870838
[[vk::ext_instruction(spv::OpGroupNonUniformFMul)]]
871-
enable_if_t<is_floating_point<T>, T> groupNonUniformFMul_GroupNonUniformArithmetic(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
839+
enable_if_t<(is_same_v<float16_t, T> || is_same_v<float32_t, T> || is_same_v<float64_t, T>), T> groupNonUniformFMul_GroupNonUniformArithmetic(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
872840

873841
template<typename T>
874842
[[vk::ext_capability(spv::CapabilityGroupNonUniformClustered)]]
875843
[[vk::ext_instruction(spv::OpGroupNonUniformFMul)]]
876-
enable_if_t<is_floating_point<T>, T> groupNonUniformFMul_GroupNonUniformClustered(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
844+
enable_if_t<(is_same_v<float16_t, T> || is_same_v<float32_t, T> || is_same_v<float64_t, T>), T> groupNonUniformFMul_GroupNonUniformClustered(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
877845

878846
template<typename T>
879847
[[vk::ext_capability(spv::CapabilityGroupNonUniformPartitionedNV)]]
880848
[[vk::ext_instruction(spv::OpGroupNonUniformFMul)]]
881-
enable_if_t<is_floating_point<T>, T> groupNonUniformFMul_GroupNonUniformPartitionedNV(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
849+
enable_if_t<(is_same_v<float16_t, T> || is_same_v<float32_t, T> || is_same_v<float64_t, T>), T> groupNonUniformFMul_GroupNonUniformPartitionedNV(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
882850

883851
template<typename T>
884852
[[vk::ext_capability(spv::CapabilityGroupNonUniformArithmetic)]]
@@ -913,17 +881,17 @@ enable_if_t<is_unsigned_v<T>, T> groupNonUniformUMin_GroupNonUniformPartitionedN
913881
template<typename T>
914882
[[vk::ext_capability(spv::CapabilityGroupNonUniformArithmetic)]]
915883
[[vk::ext_instruction(spv::OpGroupNonUniformFMin)]]
916-
enable_if_t<is_floating_point<T>, T> groupNonUniformFMin_GroupNonUniformArithmetic(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
884+
enable_if_t<(is_same_v<float16_t, T> || is_same_v<float32_t, T> || is_same_v<float64_t, T>), T> groupNonUniformFMin_GroupNonUniformArithmetic(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
917885

918886
template<typename T>
919887
[[vk::ext_capability(spv::CapabilityGroupNonUniformClustered)]]
920888
[[vk::ext_instruction(spv::OpGroupNonUniformFMin)]]
921-
enable_if_t<is_floating_point<T>, T> groupNonUniformFMin_GroupNonUniformClustered(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
889+
enable_if_t<(is_same_v<float16_t, T> || is_same_v<float32_t, T> || is_same_v<float64_t, T>), T> groupNonUniformFMin_GroupNonUniformClustered(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
922890

923891
template<typename T>
924892
[[vk::ext_capability(spv::CapabilityGroupNonUniformPartitionedNV)]]
925893
[[vk::ext_instruction(spv::OpGroupNonUniformFMin)]]
926-
enable_if_t<is_floating_point<T>, T> groupNonUniformFMin_GroupNonUniformPartitionedNV(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
894+
enable_if_t<(is_same_v<float16_t, T> || is_same_v<float32_t, T> || is_same_v<float64_t, T>), T> groupNonUniformFMin_GroupNonUniformPartitionedNV(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
927895

928896
template<typename T>
929897
[[vk::ext_capability(spv::CapabilityGroupNonUniformArithmetic)]]
@@ -958,17 +926,17 @@ enable_if_t<is_unsigned_v<T>, T> groupNonUniformUMax_GroupNonUniformPartitionedN
958926
template<typename T>
959927
[[vk::ext_capability(spv::CapabilityGroupNonUniformArithmetic)]]
960928
[[vk::ext_instruction(spv::OpGroupNonUniformFMax)]]
961-
enable_if_t<is_floating_point<T>, T> groupNonUniformFMax_GroupNonUniformArithmetic(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
929+
enable_if_t<(is_same_v<float16_t, T> || is_same_v<float32_t, T> || is_same_v<float64_t, T>), T> groupNonUniformFMax_GroupNonUniformArithmetic(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
962930

963931
template<typename T>
964932
[[vk::ext_capability(spv::CapabilityGroupNonUniformClustered)]]
965933
[[vk::ext_instruction(spv::OpGroupNonUniformFMax)]]
966-
enable_if_t<is_floating_point<T>, T> groupNonUniformFMax_GroupNonUniformClustered(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
934+
enable_if_t<(is_same_v<float16_t, T> || is_same_v<float32_t, T> || is_same_v<float64_t, T>), T> groupNonUniformFMax_GroupNonUniformClustered(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
967935

968936
template<typename T>
969937
[[vk::ext_capability(spv::CapabilityGroupNonUniformPartitionedNV)]]
970938
[[vk::ext_instruction(spv::OpGroupNonUniformFMax)]]
971-
enable_if_t<is_floating_point<T>, T> groupNonUniformFMax_GroupNonUniformPartitionedNV(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
939+
enable_if_t<(is_same_v<float16_t, T> || is_same_v<float32_t, T> || is_same_v<float64_t, T>), T> groupNonUniformFMax_GroupNonUniformPartitionedNV(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
972940

973941
template<typename T>
974942
[[vk::ext_capability(spv::CapabilityGroupNonUniformArithmetic)]]

0 commit comments

Comments
 (0)