From 8813b38ad816157db53306ba7f34ddce81e0fd3b Mon Sep 17 00:00:00 2001 From: Robert Bamler Date: Sun, 27 Apr 2025 21:43:23 +0200 Subject: [PATCH 1/2] Potentially optimize `dot4{I,U}8Packed` on Metal This might allow the Metal compiler to emit faster code (but that's not confirmed). See for the optimization. The limitation to Metal 2.1+ is discussed here: . --- CHANGELOG.md | 2 +- naga/src/back/msl/writer.rs | 172 ++++++++++++++---- .../wgsl/functions-optimized-by-version.toml | 9 +- naga/tests/in/wgsl/functions-unoptimized.toml | 7 +- .../wgsl-functions-optimized-by-version.msl | 33 ++++ .../out/msl/wgsl-functions-unoptimized.msl | 25 +++ 6 files changed, 205 insertions(+), 43 deletions(-) create mode 100644 naga/tests/out/msl/wgsl-functions-optimized-by-version.msl create mode 100644 naga/tests/out/msl/wgsl-functions-unoptimized.msl diff --git a/CHANGELOG.md b/CHANGELOG.md index b92c3f02df..e42a705017 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -56,7 +56,7 @@ Bottom level categories: Naga now infers the correct binding layout when a resource appears only in an assignment to `_`. By @andyleiserson in [#7540](https://github.com/gfx-rs/wgpu/pull/7540). -- Implement `dot4U8Packed` and `dot4I8Packed` for all backends, using specialized intrinsics on SPIR-V and HSLS if available, and polyfills everywhere else. By @robamler in [#7494](https://github.com/gfx-rs/wgpu/pull/7494) and [#7574](https://github.com/gfx-rs/wgpu/pull/7574). +- Implement `dot4U8Packed` and `dot4I8Packed` for all backends, using specialized intrinsics on SPIR-V, HSLS, and Metal if available, and polyfills everywhere else. By @robamler in [#7494](https://github.com/gfx-rs/wgpu/pull/7494), [#7574](https://github.com/gfx-rs/wgpu/pull/7574), and [#7653](https://github.com/gfx-rs/wgpu/pull/7653). - Add polyfilled `pack4x{I,U}8Clamped` built-ins to all backends and WGSL frontend. By @ErichDonGubler in [#7546](https://github.com/gfx-rs/wgpu/pull/7546). - Allow textureLoad's sample index arg to be unsigned. By @jimblandy in [#7625](https://github.com/gfx-rs/wgpu/pull/7625). - Properly convert arguments to atomic operations. By @jimblandy in [#7573](https://github.com/gfx-rs/wgpu/pull/7573). diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index f05e5c233a..a9e69e237f 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -121,6 +121,9 @@ const fn scalar_is_int(scalar: crate::Scalar) -> bool { /// Prefix for cached clamped level-of-detail values for `ImageLoad` expressions. const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e"; +/// Prefix for reinterpreted expressions using `as_type(...)`. +const REINTERPRET_PREFIX: &str = "reinterpreted_"; + /// Wrapper for identifier names for clamped level-of-detail values /// /// Values of this type implement [`core::fmt::Display`], formatting as @@ -156,6 +159,30 @@ impl Display for ArraySizeMember { } } +/// Wrapper for reinterpreted variables using `as_type(orig)`. +/// +/// Implements [`core::fmt::Display`], formatting as a name derived from +/// `target_type` and the variable name of `orig`. +#[derive(Clone, Copy)] +struct Reinterpreted<'a> { + target_type: &'a str, + orig: Handle, +} + +impl<'a> Reinterpreted<'a> { + const fn new(target_type: &'a str, orig: Handle) -> Self { + Self { target_type, orig } + } +} + +impl Display for Reinterpreted<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + f.write_str(REINTERPRET_PREFIX)?; + f.write_str(self.target_type)?; + self.orig.write_prefixed(f, "_e") + } +} + struct TypeContext<'a> { handle: Handle, gctx: proc::GlobalCtx<'a>, @@ -1470,14 +1497,14 @@ impl Writer { /// Emit code for the arithmetic expression of the dot product. /// - /// The argument `extractor` is a function that accepts a `Writer`, a handle to a vector, - /// and an index. writes out the expression for the component at that index. - fn put_dot_product( + /// The argument `extractor` is a function that accepts a `Writer`, a vector, and + /// an index. It writes out the expression for the vector component at that index. + fn put_dot_product( &mut self, - arg: Handle, - arg1: Handle, + arg: T, + arg1: T, size: usize, - extractor: impl Fn(&mut Self, Handle, usize) -> BackendResult, + extractor: impl Fn(&mut Self, T, usize) -> BackendResult, ) -> BackendResult { // Write parentheses around the dot product expression to prevent operators // with different precedences from applying earlier. @@ -2206,27 +2233,53 @@ impl Writer { ), }, fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => { - let conversion = match fun { - Mf::Dot4I8Packed => "int", - Mf::Dot4U8Packed => "", - _ => unreachable!(), - }; + if context.lang_version >= (2, 1) { + // Write potentially optimizable code using `packed_(u?)char4`. + // The two function arguments were already reinterpreted as packed (signed + // or unsigned) chars in `Self::put_block`. + let packed_type = match fun { + Mf::Dot4I8Packed => "packed_char4", + Mf::Dot4U8Packed => "packed_uchar4", + _ => unreachable!(), + }; - return self.put_dot_product( - arg, - arg1.unwrap(), - 4, - |writer, arg, index| { - write!(writer.out, "({}(", conversion)?; - writer.put_expression(arg, context, true)?; - if index == 3 { - write!(writer.out, ") >> 24)")?; - } else { - write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?; - } - Ok(()) - }, - ); + return self.put_dot_product( + Reinterpreted::new(packed_type, arg), + Reinterpreted::new(packed_type, arg1.unwrap()), + 4, + |writer, arg, index| { + // MSL implicitly promotes these (signed or unsigned) chars to + // `int` or `uint` in the multiplication, so no overflow can occur. + write!(writer.out, "{arg}[{index}]")?; + Ok(()) + }, + ); + } else { + // Fall back to a polyfill since MSL < 2.1 doesn't seem to support + // bitcasting from uint to `packed_char4` or `packed_uchar4`. + // See . + let conversion = match fun { + Mf::Dot4I8Packed => "int", + Mf::Dot4U8Packed => "", + _ => unreachable!(), + }; + + return self.put_dot_product( + arg, + arg1.unwrap(), + 4, + |writer, arg, index| { + write!(writer.out, "({}(", conversion)?; + writer.put_expression(arg, context, true)?; + if index == 3 { + write!(writer.out, ") >> 24)")?; + } else { + write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?; + } + Ok(()) + }, + ); + } } Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))), Mf::Cross => "cross", @@ -3362,17 +3415,62 @@ impl Writer { match *statement { crate::Statement::Emit(ref range) => { for handle in range.clone() { - // `ImageLoad` expressions covered by the `Restrict` bounds check policy - // may need to cache a clamped version of their level-of-detail argument. - if let crate::Expression::ImageLoad { - image, - level: mip_level, - .. - } = context.expression.function.expressions[handle] - { - self.put_cache_restricted_level( - handle, image, mip_level, level, context, - )?; + use crate::MathFunction as Mf; + + match context.expression.function.expressions[handle] { + // `ImageLoad` expressions covered by the `Restrict` bounds check policy + // may need to cache a clamped version of their level-of-detail argument. + crate::Expression::ImageLoad { + image, + level: mip_level, + .. + } => { + self.put_cache_restricted_level( + handle, image, mip_level, level, context, + )?; + } + + // If we are going to write a `Dot4I8Packed` or `Dot4U8Packed` on Metal + // 2.1+ then we introduce two intermediate variables that recast the two + // arguments as packed (signed or unsigned) chars. The actual dot product + // is implemented in `Self::put_expression`, and it uses both of these + // intermediate variables multiple times. There's no danger that the + // original arguments get modified between the definition of these + // intermediate variables and the implementation of the actual dot + // product since we require the inputs of `Dot4{I, U}Packed` to be baked. + crate::Expression::Math { + fun: fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed), + arg, + arg1, + .. + } => { + if context.expression.lang_version >= (2, 1) { + let arg1 = arg1.unwrap(); + let packed_type = match fun { + Mf::Dot4I8Packed => "packed_char4", + Mf::Dot4U8Packed => "packed_uchar4", + _ => unreachable!(), + }; + + write!( + self.out, + "{level}{packed_type} {0} = as_type<{packed_type}>(", + Reinterpreted::new(packed_type, arg) + )?; + self.put_expression(arg, &context.expression, true)?; + writeln!(self.out, ");")?; + + write!( + self.out, + "{level}{packed_type} {0} = as_type<{packed_type}>(", + Reinterpreted::new(packed_type, arg1) + )?; + self.put_expression(arg1, &context.expression, true)?; + writeln!(self.out, ");")?; + } + } + + _ => (), } let ptr_class = context.expression.resolve_type(handle).pointer_space(); diff --git a/naga/tests/in/wgsl/functions-optimized-by-version.toml b/naga/tests/in/wgsl/functions-optimized-by-version.toml index e3d5bc1028..2df2ffdf87 100644 --- a/naga/tests/in/wgsl/functions-optimized-by-version.toml +++ b/naga/tests/in/wgsl/functions-optimized-by-version.toml @@ -1,7 +1,7 @@ -# Turn on optimizations for `dot4I8Packed` and `dot4U8Packed` on SPIR-V and HLSL by -# using a version of SPIR-V / shader model that supports these without any extensions. +# Turn on optimizations for `dot4I8Packed` and `dot4U8Packed` on SPIR-V, HLSL, and Metal +# by using a language version / shader model that supports these (without any extensions). -targets = "SPIRV | HLSL" +targets = "SPIRV | HLSL | METAL" [spv] # We also need to provide the corresponding capabilities (which are part of SPIR-V >= 1.6). @@ -10,3 +10,6 @@ version = [1, 6] [hlsl] shader_model = "V6_4" + +[msl] +lang_version = [2, 1] diff --git a/naga/tests/in/wgsl/functions-unoptimized.toml b/naga/tests/in/wgsl/functions-unoptimized.toml index 7361004be3..4a37aac5ff 100644 --- a/naga/tests/in/wgsl/functions-unoptimized.toml +++ b/naga/tests/in/wgsl/functions-unoptimized.toml @@ -1,7 +1,7 @@ # Explicitly turn off optimizations for `dot4I8Packed` and `dot4U8Packed` -# on SPIRV and HLSL. +# on SPIRV, HLSL, and Metal. -targets = "SPIRV | HLSL" +targets = "SPIRV | HLSL | METAL" [spv] # Provide some unrelated capability because an empty list of capabilities would @@ -11,3 +11,6 @@ capabilities = ["Matrix"] [hlsl] shader_model = "V6_3" + +[msl] +lang_version = [2, 0] diff --git a/naga/tests/out/msl/wgsl-functions-optimized-by-version.msl b/naga/tests/out/msl/wgsl-functions-optimized-by-version.msl new file mode 100644 index 0000000000..341cbf84de --- /dev/null +++ b/naga/tests/out/msl/wgsl-functions-optimized-by-version.msl @@ -0,0 +1,33 @@ +// language: metal2.1 +#include +#include + +using metal::uint; + + +uint test_packed_integer_dot_product( +) { + packed_char4 reinterpreted_packed_char4_e0 = as_type(1u); + packed_char4 reinterpreted_packed_char4_e1 = as_type(2u); + int c_5_ = ( + reinterpreted_packed_char4_e0[0] * reinterpreted_packed_char4_e1[0] + reinterpreted_packed_char4_e0[1] * reinterpreted_packed_char4_e1[1] + reinterpreted_packed_char4_e0[2] * reinterpreted_packed_char4_e1[2] + reinterpreted_packed_char4_e0[3] * reinterpreted_packed_char4_e1[3]); + packed_uchar4 reinterpreted_packed_uchar4_e3 = as_type(3u); + packed_uchar4 reinterpreted_packed_uchar4_e4 = as_type(4u); + uint c_6_ = ( + reinterpreted_packed_uchar4_e3[0] * reinterpreted_packed_uchar4_e4[0] + reinterpreted_packed_uchar4_e3[1] * reinterpreted_packed_uchar4_e4[1] + reinterpreted_packed_uchar4_e3[2] * reinterpreted_packed_uchar4_e4[2] + reinterpreted_packed_uchar4_e3[3] * reinterpreted_packed_uchar4_e4[3]); + uint _e7 = 5u + c_6_; + uint _e9 = 6u + c_6_; + packed_char4 reinterpreted_packed_char4_e7 = as_type(_e7); + packed_char4 reinterpreted_packed_char4_e9 = as_type(_e9); + int c_7_ = ( + reinterpreted_packed_char4_e7[0] * reinterpreted_packed_char4_e9[0] + reinterpreted_packed_char4_e7[1] * reinterpreted_packed_char4_e9[1] + reinterpreted_packed_char4_e7[2] * reinterpreted_packed_char4_e9[2] + reinterpreted_packed_char4_e7[3] * reinterpreted_packed_char4_e9[3]); + uint _e12 = 7u + c_6_; + uint _e14 = 8u + c_6_; + packed_uchar4 reinterpreted_packed_uchar4_e12 = as_type(_e12); + packed_uchar4 reinterpreted_packed_uchar4_e14 = as_type(_e14); + uint c_8_ = ( + reinterpreted_packed_uchar4_e12[0] * reinterpreted_packed_uchar4_e14[0] + reinterpreted_packed_uchar4_e12[1] * reinterpreted_packed_uchar4_e14[1] + reinterpreted_packed_uchar4_e12[2] * reinterpreted_packed_uchar4_e14[2] + reinterpreted_packed_uchar4_e12[3] * reinterpreted_packed_uchar4_e14[3]); + return c_8_; +} + +kernel void main_( +) { + uint _e0 = test_packed_integer_dot_product(); + return; +} diff --git a/naga/tests/out/msl/wgsl-functions-unoptimized.msl b/naga/tests/out/msl/wgsl-functions-unoptimized.msl new file mode 100644 index 0000000000..4ec98b0ec7 --- /dev/null +++ b/naga/tests/out/msl/wgsl-functions-unoptimized.msl @@ -0,0 +1,25 @@ +// language: metal2.0 +#include +#include + +using metal::uint; + + +uint test_packed_integer_dot_product( +) { + int c_5_ = ( + (int(1u) << 24 >> 24) * (int(2u) << 24 >> 24) + (int(1u) << 16 >> 24) * (int(2u) << 16 >> 24) + (int(1u) << 8 >> 24) * (int(2u) << 8 >> 24) + (int(1u) >> 24) * (int(2u) >> 24)); + uint c_6_ = ( + ((3u) << 24 >> 24) * ((4u) << 24 >> 24) + ((3u) << 16 >> 24) * ((4u) << 16 >> 24) + ((3u) << 8 >> 24) * ((4u) << 8 >> 24) + ((3u) >> 24) * ((4u) >> 24)); + uint _e7 = 5u + c_6_; + uint _e9 = 6u + c_6_; + int c_7_ = ( + (int(_e7) << 24 >> 24) * (int(_e9) << 24 >> 24) + (int(_e7) << 16 >> 24) * (int(_e9) << 16 >> 24) + (int(_e7) << 8 >> 24) * (int(_e9) << 8 >> 24) + (int(_e7) >> 24) * (int(_e9) >> 24)); + uint _e12 = 7u + c_6_; + uint _e14 = 8u + c_6_; + uint c_8_ = ( + ((_e12) << 24 >> 24) * ((_e14) << 24 >> 24) + ((_e12) << 16 >> 24) * ((_e14) << 16 >> 24) + ((_e12) << 8 >> 24) * ((_e14) << 8 >> 24) + ((_e12) >> 24) * ((_e14) >> 24)); + return c_8_; +} + +kernel void main_( +) { + uint _e0 = test_packed_integer_dot_product(); + return; +} From 2a41b9562935588248d5fede96aff78cee887af5 Mon Sep 17 00:00:00 2001 From: Robert Bamler Date: Fri, 2 May 2025 00:08:59 +0200 Subject: [PATCH 2/2] [naga] Factor out new part of `put_block` on msl CI on test failed because the latest changes to `put_block` made its stack too big. Factoring out the new code into a separate method fixes this issue. --- naga/src/back/msl/writer.rs | 65 +++++++++++++++++++++++-------------- 1 file changed, 40 insertions(+), 25 deletions(-) diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index a9e69e237f..846bd3df5a 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -3399,6 +3399,38 @@ impl Writer { Ok(()) } + /// Convert the arguments of `Dot4{I, U}Packed` to `packed_(u?)char4`. + /// + /// Caches the results in temporary variables (whose names are derived from + /// the original variable names). This caching avoids the need to redo the + /// casting for each vector component when emitting the dot product. + fn put_casting_to_packed_chars( + &mut self, + fun: crate::MathFunction, + arg0: Handle, + arg1: Handle, + indent: back::Level, + context: &StatementContext<'_>, + ) -> Result<(), Error> { + let packed_type = match fun { + crate::MathFunction::Dot4I8Packed => "packed_char4", + crate::MathFunction::Dot4U8Packed => "packed_uchar4", + _ => unreachable!(), + }; + + for arg in [arg0, arg1] { + write!( + self.out, + "{indent}{packed_type} {0} = as_type<{packed_type}>(", + Reinterpreted::new(packed_type, arg) + )?; + self.put_expression(arg, &context.expression, true)?; + writeln!(self.out, ");")?; + } + + Ok(()) + } + fn put_block( &mut self, level: back::Level, @@ -3443,31 +3475,14 @@ impl Writer { arg, arg1, .. - } => { - if context.expression.lang_version >= (2, 1) { - let arg1 = arg1.unwrap(); - let packed_type = match fun { - Mf::Dot4I8Packed => "packed_char4", - Mf::Dot4U8Packed => "packed_uchar4", - _ => unreachable!(), - }; - - write!( - self.out, - "{level}{packed_type} {0} = as_type<{packed_type}>(", - Reinterpreted::new(packed_type, arg) - )?; - self.put_expression(arg, &context.expression, true)?; - writeln!(self.out, ");")?; - - write!( - self.out, - "{level}{packed_type} {0} = as_type<{packed_type}>(", - Reinterpreted::new(packed_type, arg1) - )?; - self.put_expression(arg1, &context.expression, true)?; - writeln!(self.out, ");")?; - } + } if context.expression.lang_version >= (2, 1) => { + self.put_casting_to_packed_chars( + fun, + arg, + arg1.unwrap(), + level, + context, + )?; } _ => (),