Skip to content

Commit f44bd46

Browse files
committed
Optimize dot4{I, U}8Packed for all spv versions
Emit optimized code for `dot4{I, U}8Packed` regardless of SPIR-V version as long as the required capabilities are available. On SPIR-V < 1.6, require the extension "SPV_KHR_integer_dot_product" for this. On SPIR-V >= 1.6, don't require the extension because the corresponding capabilities are part of SPIR-V >= 1.6 proper.
1 parent be9debd commit f44bd46

11 files changed

+122
-108
lines changed

naga/src/back/spv/block.rs

+13-9
Original file line numberDiff line numberDiff line change
@@ -1143,17 +1143,21 @@ impl BlockContext<'_> {
11431143
),
11441144
},
11451145
fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
1146-
if self.writer.lang_version() >= (1, 6)
1147-
&& self
1148-
.writer
1149-
.require_all(&[
1150-
spirv::Capability::DotProduct,
1151-
spirv::Capability::DotProductInput4x8BitPacked,
1152-
])
1153-
.is_ok()
1146+
if self
1147+
.writer
1148+
.require_all(&[
1149+
spirv::Capability::DotProduct,
1150+
spirv::Capability::DotProductInput4x8BitPacked,
1151+
])
1152+
.is_ok()
11541153
{
11551154
// Write optimized code using `PackedVectorFormat4x8Bit`.
1156-
self.writer.use_extension("SPV_KHR_integer_dot_product");
1155+
if self.writer.lang_version() < (1, 6) {
1156+
// SPIR-V 1.6 supports the required capabilities natively, so the extension
1157+
// is only required for earlier versions. See right column of
1158+
// <https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpSDot>.
1159+
self.writer.use_extension("SPV_KHR_integer_dot_product");
1160+
}
11571161

11581162
let op = match fun {
11591163
Mf::Dot4I8Packed => spirv::Op::SDot,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Turn on optimizations for `dot4I8Packed` and `dot4U8Packed` by enabling the
2+
# required capabilities on a SPIR-V version where these capabilities are only
3+
# available via the extension "SPV_KHR_integer_dot_product".
4+
5+
targets = "SPIRV"
6+
7+
[spv]
8+
capabilities = ["DotProduct", "DotProductInput4x8BitPacked"]
9+
version = [1, 0]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Turn on optimizations for `dot4I8Packed` and `dot4U8Packed` on SPIR-V and HLSL by
2+
# using a version of SPIR-V / shader model that supports these without any extensions.
3+
4+
targets = "SPIRV | HLSL"
5+
6+
[spv]
7+
# We also need to provide the corresponding capabilities (which are part of SPIR-V >= 1.6).
8+
capabilities = ["DotProduct", "DotProductInput4x8BitPacked"]
9+
version = [1, 6]
10+
11+
[hlsl]
12+
shader_model = "V6_4"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
fn test_packed_integer_dot_product() -> u32 {
2+
let a_5 = 1u;
3+
let b_5 = 2u;
4+
let c_5: i32 = dot4I8Packed(a_5, b_5);
5+
6+
let a_6 = 3u;
7+
let b_6 = 4u;
8+
let c_6: u32 = dot4U8Packed(a_6, b_6);
9+
10+
// test baking of arguments
11+
let c_7: i32 = dot4I8Packed(5u + c_6, 6u + c_6);
12+
let c_8: u32 = dot4U8Packed(7u + c_6, 8u + c_6);
13+
return c_8;
14+
}
15+
16+
@compute @workgroup_size(1)
17+
fn main() {
18+
let c = test_packed_integer_dot_product();
19+
}

naga/tests/in/wgsl/functions-optimized.toml

-11
This file was deleted.

naga/tests/out/spv/wgsl-functions-optimized.spvasm renamed to naga/tests/out/spv/wgsl-functions-optimized-by-capability.spvasm

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
; SPIR-V
2-
; Version: 1.6
2+
; Version: 1.0
33
; Generator: rspirv
44
; Bound: 30
55
OpCapability Shader
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
; SPIR-V
2+
; Version: 1.6
3+
; Generator: rspirv
4+
; Bound: 30
5+
OpCapability Shader
6+
OpCapability DotProductKHR
7+
OpCapability DotProductInput4x8BitPackedKHR
8+
%1 = OpExtInstImport "GLSL.std.450"
9+
OpMemoryModel Logical GLSL450
10+
OpEntryPoint GLCompute %26 "main"
11+
OpExecutionMode %26 LocalSize 1 1 1
12+
%2 = OpTypeVoid
13+
%3 = OpTypeInt 32 0
14+
%6 = OpTypeFunction %3
15+
%7 = OpConstant %3 1
16+
%8 = OpConstant %3 2
17+
%9 = OpConstant %3 3
18+
%10 = OpConstant %3 4
19+
%11 = OpConstant %3 5
20+
%12 = OpConstant %3 6
21+
%13 = OpConstant %3 7
22+
%14 = OpConstant %3 8
23+
%16 = OpTypeInt 32 1
24+
%27 = OpTypeFunction %2
25+
%5 = OpFunction %3 None %6
26+
%4 = OpLabel
27+
OpBranch %15
28+
%15 = OpLabel
29+
%17 = OpSDotKHR %16 %7 %8 PackedVectorFormat4x8BitKHR
30+
%18 = OpUDotKHR %3 %9 %10 PackedVectorFormat4x8BitKHR
31+
%19 = OpIAdd %3 %11 %18
32+
%20 = OpIAdd %3 %12 %18
33+
%21 = OpSDotKHR %16 %19 %20 PackedVectorFormat4x8BitKHR
34+
%22 = OpIAdd %3 %13 %18
35+
%23 = OpIAdd %3 %14 %18
36+
%24 = OpUDotKHR %3 %22 %23 PackedVectorFormat4x8BitKHR
37+
OpReturnValue %24
38+
OpFunctionEnd
39+
%26 = OpFunction %2 None %27
40+
%25 = OpLabel
41+
OpBranch %28
42+
%28 = OpLabel
43+
%29 = OpFunctionCall %3 %5
44+
OpReturn
45+
OpFunctionEnd
+23-87
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
; SPIR-V
22
; Version: 1.1
33
; Generator: rspirv
4-
; Bound: 162
4+
; Bound: 95
55
OpCapability Shader
6+
OpCapability DotProductKHR
7+
OpCapability DotProductInput4x8BitPackedKHR
8+
OpExtension "SPV_KHR_integer_dot_product"
69
%1 = OpExtInstImport "GLSL.std.450"
710
OpMemoryModel Logical GLSL450
8-
OpEntryPoint GLCompute %156 "main"
9-
OpExecutionMode %156 LocalSize 1 1 1
11+
OpEntryPoint GLCompute %89 "main"
12+
OpExecutionMode %89 LocalSize 1 1 1
1013
%2 = OpTypeVoid
1114
%4 = OpTypeFloat 32
1215
%3 = OpTypeVector %4 2
@@ -39,10 +42,7 @@ OpExecutionMode %156 LocalSize 1 1 1
3942
%76 = OpConstant %6 6
4043
%77 = OpConstant %6 7
4144
%78 = OpConstant %6 8
42-
%83 = OpConstant %6 0
43-
%84 = OpConstant %6 16
44-
%85 = OpConstant %6 24
45-
%157 = OpTypeFunction %2
45+
%90 = OpTypeFunction %2
4646
%8 = OpFunction %3 None %9
4747
%7 = OpLabel
4848
OpBranch %14
@@ -96,86 +96,22 @@ OpFunctionEnd
9696
%69 = OpLabel
9797
OpBranch %79
9898
%79 = OpLabel
99-
%81 = OpBitcast %5 %22
100-
%82 = OpBitcast %5 %72
101-
%86 = OpBitFieldSExtract %5 %81 %83 %78
102-
%87 = OpBitFieldSExtract %5 %82 %83 %78
103-
%88 = OpIMul %5 %86 %87
104-
%89 = OpIAdd %5 %32 %88
105-
%90 = OpBitFieldSExtract %5 %81 %78 %78
106-
%91 = OpBitFieldSExtract %5 %82 %78 %78
107-
%92 = OpIMul %5 %90 %91
108-
%93 = OpIAdd %5 %89 %92
109-
%94 = OpBitFieldSExtract %5 %81 %84 %78
110-
%95 = OpBitFieldSExtract %5 %82 %84 %78
111-
%96 = OpIMul %5 %94 %95
112-
%97 = OpIAdd %5 %93 %96
113-
%98 = OpBitFieldSExtract %5 %81 %85 %78
114-
%99 = OpBitFieldSExtract %5 %82 %85 %78
115-
%100 = OpIMul %5 %98 %99
116-
%80 = OpIAdd %5 %97 %100
117-
%102 = OpBitFieldUExtract %6 %73 %83 %78
118-
%103 = OpBitFieldUExtract %6 %74 %83 %78
119-
%104 = OpIMul %6 %102 %103
120-
%105 = OpIAdd %6 %41 %104
121-
%106 = OpBitFieldUExtract %6 %73 %78 %78
122-
%107 = OpBitFieldUExtract %6 %74 %78 %78
123-
%108 = OpIMul %6 %106 %107
124-
%109 = OpIAdd %6 %105 %108
125-
%110 = OpBitFieldUExtract %6 %73 %84 %78
126-
%111 = OpBitFieldUExtract %6 %74 %84 %78
127-
%112 = OpIMul %6 %110 %111
128-
%113 = OpIAdd %6 %109 %112
129-
%114 = OpBitFieldUExtract %6 %73 %85 %78
130-
%115 = OpBitFieldUExtract %6 %74 %85 %78
131-
%116 = OpIMul %6 %114 %115
132-
%101 = OpIAdd %6 %113 %116
133-
%117 = OpIAdd %6 %75 %101
134-
%118 = OpIAdd %6 %76 %101
135-
%120 = OpBitcast %5 %117
136-
%121 = OpBitcast %5 %118
137-
%122 = OpBitFieldSExtract %5 %120 %83 %78
138-
%123 = OpBitFieldSExtract %5 %121 %83 %78
139-
%124 = OpIMul %5 %122 %123
140-
%125 = OpIAdd %5 %32 %124
141-
%126 = OpBitFieldSExtract %5 %120 %78 %78
142-
%127 = OpBitFieldSExtract %5 %121 %78 %78
143-
%128 = OpIMul %5 %126 %127
144-
%129 = OpIAdd %5 %125 %128
145-
%130 = OpBitFieldSExtract %5 %120 %84 %78
146-
%131 = OpBitFieldSExtract %5 %121 %84 %78
147-
%132 = OpIMul %5 %130 %131
148-
%133 = OpIAdd %5 %129 %132
149-
%134 = OpBitFieldSExtract %5 %120 %85 %78
150-
%135 = OpBitFieldSExtract %5 %121 %85 %78
151-
%136 = OpIMul %5 %134 %135
152-
%119 = OpIAdd %5 %133 %136
153-
%137 = OpIAdd %6 %77 %101
154-
%138 = OpIAdd %6 %78 %101
155-
%140 = OpBitFieldUExtract %6 %137 %83 %78
156-
%141 = OpBitFieldUExtract %6 %138 %83 %78
157-
%142 = OpIMul %6 %140 %141
158-
%143 = OpIAdd %6 %41 %142
159-
%144 = OpBitFieldUExtract %6 %137 %78 %78
160-
%145 = OpBitFieldUExtract %6 %138 %78 %78
161-
%146 = OpIMul %6 %144 %145
162-
%147 = OpIAdd %6 %143 %146
163-
%148 = OpBitFieldUExtract %6 %137 %84 %78
164-
%149 = OpBitFieldUExtract %6 %138 %84 %78
165-
%150 = OpIMul %6 %148 %149
166-
%151 = OpIAdd %6 %147 %150
167-
%152 = OpBitFieldUExtract %6 %137 %85 %78
168-
%153 = OpBitFieldUExtract %6 %138 %85 %78
169-
%154 = OpIMul %6 %152 %153
170-
%139 = OpIAdd %6 %151 %154
171-
OpReturnValue %139
99+
%80 = OpSDotKHR %5 %22 %72 PackedVectorFormat4x8BitKHR
100+
%81 = OpUDotKHR %6 %73 %74 PackedVectorFormat4x8BitKHR
101+
%82 = OpIAdd %6 %75 %81
102+
%83 = OpIAdd %6 %76 %81
103+
%84 = OpSDotKHR %5 %82 %83 PackedVectorFormat4x8BitKHR
104+
%85 = OpIAdd %6 %77 %81
105+
%86 = OpIAdd %6 %78 %81
106+
%87 = OpUDotKHR %6 %85 %86 PackedVectorFormat4x8BitKHR
107+
OpReturnValue %87
172108
OpFunctionEnd
173-
%156 = OpFunction %2 None %157
174-
%155 = OpLabel
175-
OpBranch %158
176-
%158 = OpLabel
177-
%159 = OpFunctionCall %3 %8
178-
%160 = OpFunctionCall %5 %17
179-
%161 = OpFunctionCall %6 %70
109+
%89 = OpFunction %2 None %90
110+
%88 = OpLabel
111+
OpBranch %91
112+
%91 = OpLabel
113+
%92 = OpFunctionCall %3 %8
114+
%93 = OpFunctionCall %5 %17
115+
%94 = OpFunctionCall %6 %70
180116
OpReturn
181117
OpFunctionEnd

0 commit comments

Comments
 (0)