Skip to content

Commit a6825c8

Browse files
committed
Potentially optimize dot4{I,U}8Packed on Metal
This might allow the Metal compiler to emit faster code (but that's not confirmed). See <gpuweb/gpuweb#2677 (comment)> for the optimization. The limitation to Metal 2.1+ is discussed here: <#7574 (comment)>.
1 parent 50eb207 commit a6825c8

File tree

5 files changed

+204
-42
lines changed

5 files changed

+204
-42
lines changed

naga/src/back/msl/writer.rs

+135-37
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,9 @@ const fn scalar_is_int(scalar: crate::Scalar) -> bool {
121121
/// Prefix for cached clamped level-of-detail values for `ImageLoad` expressions.
122122
const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";
123123

124+
/// Prefix for reinterpreted expressions using `as_type<T>(...)`.
125+
const REINTERPRET_PREFIX: &str = "reinterpreted_";
126+
124127
/// Wrapper for identifier names for clamped level-of-detail values
125128
///
126129
/// Values of this type implement [`core::fmt::Display`], formatting as
@@ -156,6 +159,30 @@ impl Display for ArraySizeMember {
156159
}
157160
}
158161

162+
/// Wrapper for reinterpreted variables using `as_type<target_type>(orig)`.
163+
///
164+
/// Implements [`core::fmt::Display`], formatting as a name derived from
165+
/// `target_type` and the variable name of `orig`.
166+
#[derive(Clone, Copy)]
167+
struct Reinterpreted<'a> {
168+
target_type: &'a str,
169+
orig: Handle<crate::Expression>,
170+
}
171+
172+
impl<'a> Reinterpreted<'a> {
173+
const fn new(target_type: &'a str, orig: Handle<crate::Expression>) -> Self {
174+
Self { target_type, orig }
175+
}
176+
}
177+
178+
impl Display for Reinterpreted<'_> {
179+
fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
180+
f.write_str(REINTERPRET_PREFIX)?;
181+
f.write_str(self.target_type)?;
182+
self.orig.write_prefixed(f, "_e")
183+
}
184+
}
185+
159186
struct TypeContext<'a> {
160187
handle: Handle<crate::Type>,
161188
gctx: proc::GlobalCtx<'a>,
@@ -1470,14 +1497,14 @@ impl<W: Write> Writer<W> {
14701497

14711498
/// Emit code for the arithmetic expression of the dot product.
14721499
///
1473-
/// The argument `extractor` is a function that accepts a `Writer`, a handle to a vector,
1474-
/// and an index. writes out the expression for the component at that index.
1475-
fn put_dot_product(
1500+
/// The argument `extractor` is a function that accepts a `Writer`, a vector, and
1501+
/// an index. It writes out the expression for the vector component at that index.
1502+
fn put_dot_product<T: Copy>(
14761503
&mut self,
1477-
arg: Handle<crate::Expression>,
1478-
arg1: Handle<crate::Expression>,
1504+
arg: T,
1505+
arg1: T,
14791506
size: usize,
1480-
extractor: impl Fn(&mut Self, Handle<crate::Expression>, usize) -> BackendResult,
1507+
extractor: impl Fn(&mut Self, T, usize) -> BackendResult,
14811508
) -> BackendResult {
14821509
// Write parentheses around the dot product expression to prevent operators
14831510
// with different precedences from applying earlier.
@@ -2206,27 +2233,53 @@ impl<W: Write> Writer<W> {
22062233
),
22072234
},
22082235
fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
2209-
let conversion = match fun {
2210-
Mf::Dot4I8Packed => "int",
2211-
Mf::Dot4U8Packed => "",
2212-
_ => unreachable!(),
2213-
};
2236+
if context.lang_version >= (2, 1) {
2237+
// Write potentially optimizable code using `packed_(u?)char4`.
2238+
// The two function arguments were already reinterpreted as packed (signed
2239+
// or unsigned) chars in `Self::put_block`.
2240+
let packed_type = match fun {
2241+
Mf::Dot4I8Packed => "packed_char4",
2242+
Mf::Dot4U8Packed => "packed_uchar4",
2243+
_ => unreachable!(),
2244+
};
22142245

2215-
return self.put_dot_product(
2216-
arg,
2217-
arg1.unwrap(),
2218-
4,
2219-
|writer, arg, index| {
2220-
write!(writer.out, "({}(", conversion)?;
2221-
writer.put_expression(arg, context, true)?;
2222-
if index == 3 {
2223-
write!(writer.out, ") >> 24)")?;
2224-
} else {
2225-
write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?;
2226-
}
2227-
Ok(())
2228-
},
2229-
);
2246+
return self.put_dot_product(
2247+
Reinterpreted::new(packed_type, arg),
2248+
Reinterpreted::new(packed_type, arg1.unwrap()),
2249+
4,
2250+
|writer, arg, index| {
2251+
// MSL implicitly promotes these (signed or unsigned) chars to
2252+
// `int` or `uint` in the multiplication, so no overflow can occur.
2253+
write!(writer.out, "{arg}[{index}]")?;
2254+
Ok(())
2255+
},
2256+
);
2257+
} else {
2258+
// Fall back to a polyfill since MSL < 2.1 doesn't seem to support
2259+
// bitcasting from uint to `packed_char4` or `packed_uchar4`.
2260+
// See <https://github.com/gfx-rs/wgpu/pull/7574#issuecomment-2835464472>.
2261+
let conversion = match fun {
2262+
Mf::Dot4I8Packed => "int",
2263+
Mf::Dot4U8Packed => "",
2264+
_ => unreachable!(),
2265+
};
2266+
2267+
return self.put_dot_product(
2268+
arg,
2269+
arg1.unwrap(),
2270+
4,
2271+
|writer, arg, index| {
2272+
write!(writer.out, "({}(", conversion)?;
2273+
writer.put_expression(arg, context, true)?;
2274+
if index == 3 {
2275+
write!(writer.out, ") >> 24)")?;
2276+
} else {
2277+
write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?;
2278+
}
2279+
Ok(())
2280+
},
2281+
);
2282+
}
22302283
}
22312284
Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
22322285
Mf::Cross => "cross",
@@ -3362,17 +3415,62 @@ impl<W: Write> Writer<W> {
33623415
match *statement {
33633416
crate::Statement::Emit(ref range) => {
33643417
for handle in range.clone() {
3365-
// `ImageLoad` expressions covered by the `Restrict` bounds check policy
3366-
// may need to cache a clamped version of their level-of-detail argument.
3367-
if let crate::Expression::ImageLoad {
3368-
image,
3369-
level: mip_level,
3370-
..
3371-
} = context.expression.function.expressions[handle]
3372-
{
3373-
self.put_cache_restricted_level(
3374-
handle, image, mip_level, level, context,
3375-
)?;
3418+
use crate::MathFunction as Mf;
3419+
3420+
match context.expression.function.expressions[handle] {
3421+
// `ImageLoad` expressions covered by the `Restrict` bounds check policy
3422+
// may need to cache a clamped version of their level-of-detail argument.
3423+
crate::Expression::ImageLoad {
3424+
image,
3425+
level: mip_level,
3426+
..
3427+
} => {
3428+
self.put_cache_restricted_level(
3429+
handle, image, mip_level, level, context,
3430+
)?;
3431+
}
3432+
3433+
// If we are going to write a `Dot4I8Packed` or `Dot4U8Packed` on Metal
3434+
// 2.1+ then we introduce two intermediate variables that recast the two
3435+
// arguments as packed (signed or unsigned) chars. The actual dot product
3436+
// is implemented in `Self::put_expression`, and it uses both of these
3437+
// intermediate variables multiple times. There's no danger that the
3438+
// origianal arguments get modified between the definition of these
3439+
// intermediate variables and the implementation of the actual dot
3440+
// product since we require the inputs of `Dot4{I, U}Packed` to be baked.
3441+
crate::Expression::Math {
3442+
fun: fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed),
3443+
arg,
3444+
arg1,
3445+
..
3446+
} => {
3447+
if context.expression.lang_version >= (2, 1) {
3448+
let arg1 = arg1.unwrap();
3449+
let packed_type = match fun {
3450+
Mf::Dot4I8Packed => "packed_char4",
3451+
Mf::Dot4U8Packed => "packed_uchar4",
3452+
_ => unreachable!(),
3453+
};
3454+
3455+
write!(
3456+
self.out,
3457+
"{level}{packed_type} {0} = as_type<{packed_type}>(",
3458+
Reinterpreted::new(packed_type, arg)
3459+
)?;
3460+
self.put_expression(arg, &context.expression, true)?;
3461+
writeln!(self.out, ");")?;
3462+
3463+
write!(
3464+
self.out,
3465+
"{level}{packed_type} {0} = as_type<{packed_type}>(",
3466+
Reinterpreted::new(packed_type, arg1)
3467+
)?;
3468+
self.put_expression(arg1, &context.expression, true)?;
3469+
writeln!(self.out, ");")?;
3470+
}
3471+
}
3472+
3473+
_ => (),
33763474
}
33773475

33783476
let ptr_class = context.expression.resolve_type(handle).pointer_space();
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
# Turn on optimizations for `dot4I8Packed` and `dot4U8Packed` on SPIR-V and HLSL by
2-
# using a version of SPIR-V / shader model that supports these without any extensions.
1+
# Turn on optimizations for `dot4I8Packed` and `dot4U8Packed` on SPIR-V, HLSL, and Metal
2+
# by using a language version / shader model that supports these (without any extensions).
33

4-
targets = "SPIRV | HLSL"
4+
targets = "SPIRV | HLSL | METAL"
55

66
[spv]
77
# We also need to provide the corresponding capabilities (which are part of SPIR-V >= 1.6).
@@ -10,3 +10,6 @@ version = [1, 6]
1010

1111
[hlsl]
1212
shader_model = "V6_4"
13+
14+
[msl]
15+
lang_version = [2, 1]
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Explicitly turn off optimizations for `dot4I8Packed` and `dot4U8Packed`
2-
# on SPIRV and HLSL.
2+
# on SPIRV, HLSL, and Metal.
33

4-
targets = "SPIRV | HLSL"
4+
targets = "SPIRV | HLSL | METAL"
55

66
[spv]
77
# Provide some unrelated capability because an empty list of capabilities would
@@ -11,3 +11,6 @@ capabilities = ["Matrix"]
1111

1212
[hlsl]
1313
shader_model = "V6_3"
14+
15+
[msl]
16+
lang_version = [2, 0]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// language: metal2.1
2+
#include <metal_stdlib>
3+
#include <simd/simd.h>
4+
5+
using metal::uint;
6+
7+
8+
uint test_packed_integer_dot_product(
9+
) {
10+
packed_char4 reinterpreted_packed_char4_e0 = as_type<packed_char4>(1u);
11+
packed_char4 reinterpreted_packed_char4_e1 = as_type<packed_char4>(2u);
12+
int c_5_ = ( + reinterpreted_packed_char4_e0[0] * reinterpreted_packed_char4_e1[0] + reinterpreted_packed_char4_e0[1] * reinterpreted_packed_char4_e1[1] + reinterpreted_packed_char4_e0[2] * reinterpreted_packed_char4_e1[2] + reinterpreted_packed_char4_e0[3] * reinterpreted_packed_char4_e1[3]);
13+
packed_uchar4 reinterpreted_packed_uchar4_e3 = as_type<packed_uchar4>(3u);
14+
packed_uchar4 reinterpreted_packed_uchar4_e4 = as_type<packed_uchar4>(4u);
15+
uint c_6_ = ( + reinterpreted_packed_uchar4_e3[0] * reinterpreted_packed_uchar4_e4[0] + reinterpreted_packed_uchar4_e3[1] * reinterpreted_packed_uchar4_e4[1] + reinterpreted_packed_uchar4_e3[2] * reinterpreted_packed_uchar4_e4[2] + reinterpreted_packed_uchar4_e3[3] * reinterpreted_packed_uchar4_e4[3]);
16+
uint _e7 = 5u + c_6_;
17+
uint _e9 = 6u + c_6_;
18+
packed_char4 reinterpreted_packed_char4_e7 = as_type<packed_char4>(_e7);
19+
packed_char4 reinterpreted_packed_char4_e9 = as_type<packed_char4>(_e9);
20+
int c_7_ = ( + reinterpreted_packed_char4_e7[0] * reinterpreted_packed_char4_e9[0] + reinterpreted_packed_char4_e7[1] * reinterpreted_packed_char4_e9[1] + reinterpreted_packed_char4_e7[2] * reinterpreted_packed_char4_e9[2] + reinterpreted_packed_char4_e7[3] * reinterpreted_packed_char4_e9[3]);
21+
uint _e12 = 7u + c_6_;
22+
uint _e14 = 8u + c_6_;
23+
packed_uchar4 reinterpreted_packed_uchar4_e12 = as_type<packed_uchar4>(_e12);
24+
packed_uchar4 reinterpreted_packed_uchar4_e14 = as_type<packed_uchar4>(_e14);
25+
uint c_8_ = ( + reinterpreted_packed_uchar4_e12[0] * reinterpreted_packed_uchar4_e14[0] + reinterpreted_packed_uchar4_e12[1] * reinterpreted_packed_uchar4_e14[1] + reinterpreted_packed_uchar4_e12[2] * reinterpreted_packed_uchar4_e14[2] + reinterpreted_packed_uchar4_e12[3] * reinterpreted_packed_uchar4_e14[3]);
26+
return c_8_;
27+
}
28+
29+
kernel void main_(
30+
) {
31+
uint _e0 = test_packed_integer_dot_product();
32+
return;
33+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// language: metal2.0
2+
#include <metal_stdlib>
3+
#include <simd/simd.h>
4+
5+
using metal::uint;
6+
7+
8+
uint test_packed_integer_dot_product(
9+
) {
10+
int c_5_ = ( + (int(1u) << 24 >> 24) * (int(2u) << 24 >> 24) + (int(1u) << 16 >> 24) * (int(2u) << 16 >> 24) + (int(1u) << 8 >> 24) * (int(2u) << 8 >> 24) + (int(1u) >> 24) * (int(2u) >> 24));
11+
uint c_6_ = ( + ((3u) << 24 >> 24) * ((4u) << 24 >> 24) + ((3u) << 16 >> 24) * ((4u) << 16 >> 24) + ((3u) << 8 >> 24) * ((4u) << 8 >> 24) + ((3u) >> 24) * ((4u) >> 24));
12+
uint _e7 = 5u + c_6_;
13+
uint _e9 = 6u + c_6_;
14+
int c_7_ = ( + (int(_e7) << 24 >> 24) * (int(_e9) << 24 >> 24) + (int(_e7) << 16 >> 24) * (int(_e9) << 16 >> 24) + (int(_e7) << 8 >> 24) * (int(_e9) << 8 >> 24) + (int(_e7) >> 24) * (int(_e9) >> 24));
15+
uint _e12 = 7u + c_6_;
16+
uint _e14 = 8u + c_6_;
17+
uint c_8_ = ( + ((_e12) << 24 >> 24) * ((_e14) << 24 >> 24) + ((_e12) << 16 >> 24) * ((_e14) << 16 >> 24) + ((_e12) << 8 >> 24) * ((_e14) << 8 >> 24) + ((_e12) >> 24) * ((_e14) >> 24));
18+
return c_8_;
19+
}
20+
21+
kernel void main_(
22+
) {
23+
uint _e0 = test_packed_integer_dot_product();
24+
return;
25+
}

0 commit comments

Comments
 (0)