Skip to content

Potentially optimize dot4{I,U}8Packed on Metal #7653

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: trunk
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ Bottom level categories:

Naga now infers the correct binding layout when a resource appears only in an assignment to `_`. By @andyleiserson in [#7540](https://github.com/gfx-rs/wgpu/pull/7540).

- Implement `dot4U8Packed` and `dot4I8Packed` for all backends, using specialized intrinsics on SPIR-V and HSLS if available, and polyfills everywhere else. By @robamler in [#7494](https://github.com/gfx-rs/wgpu/pull/7494) and [#7574](https://github.com/gfx-rs/wgpu/pull/7574).
- Implement `dot4U8Packed` and `dot4I8Packed` for all backends, using specialized intrinsics on SPIR-V, HSLS, and Metal if available, and polyfills everywhere else. By @robamler in [#7494](https://github.com/gfx-rs/wgpu/pull/7494), [#7574](https://github.com/gfx-rs/wgpu/pull/7574), and [#7653](https://github.com/gfx-rs/wgpu/pull/7653).
- Add polyfilled `pack4x{I,U}8Clamped` built-ins to all backends and WGSL frontend. By @ErichDonGubler in [#7546](https://github.com/gfx-rs/wgpu/pull/7546).
- Allow textureLoad's sample index arg to be unsigned. By @jimblandy in [#7625](https://github.com/gfx-rs/wgpu/pull/7625).
- Properly convert arguments to atomic operations. By @jimblandy in [#7573](https://github.com/gfx-rs/wgpu/pull/7573).
Expand Down
187 changes: 150 additions & 37 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ const fn scalar_is_int(scalar: crate::Scalar) -> bool {
/// Prefix for cached clamped level-of-detail values for `ImageLoad` expressions.
const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";

/// Prefix for reinterpreted expressions using `as_type<T>(...)`.
const REINTERPRET_PREFIX: &str = "reinterpreted_";

/// Wrapper for identifier names for clamped level-of-detail values
///
/// Values of this type implement [`core::fmt::Display`], formatting as
Expand Down Expand Up @@ -156,6 +159,30 @@ impl Display for ArraySizeMember {
}
}

/// Wrapper for reinterpreted variables using `as_type<target_type>(orig)`.
///
/// Implements [`core::fmt::Display`], formatting as a name derived from
/// `target_type` and the variable name of `orig`.
#[derive(Clone, Copy)]
struct Reinterpreted<'a> {
target_type: &'a str,
orig: Handle<crate::Expression>,
}

impl<'a> Reinterpreted<'a> {
const fn new(target_type: &'a str, orig: Handle<crate::Expression>) -> Self {
Self { target_type, orig }
}
}

impl Display for Reinterpreted<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
f.write_str(REINTERPRET_PREFIX)?;
f.write_str(self.target_type)?;
self.orig.write_prefixed(f, "_e")
}
}

struct TypeContext<'a> {
handle: Handle<crate::Type>,
gctx: proc::GlobalCtx<'a>,
Expand Down Expand Up @@ -1470,14 +1497,14 @@ impl<W: Write> Writer<W> {

/// Emit code for the arithmetic expression of the dot product.
///
/// The argument `extractor` is a function that accepts a `Writer`, a handle to a vector,
/// and an index. writes out the expression for the component at that index.
fn put_dot_product(
/// The argument `extractor` is a function that accepts a `Writer`, a vector, and
/// an index. It writes out the expression for the vector component at that index.
fn put_dot_product<T: Copy>(
&mut self,
arg: Handle<crate::Expression>,
arg1: Handle<crate::Expression>,
arg: T,
arg1: T,
size: usize,
extractor: impl Fn(&mut Self, Handle<crate::Expression>, usize) -> BackendResult,
extractor: impl Fn(&mut Self, T, usize) -> BackendResult,
) -> BackendResult {
// Write parentheses around the dot product expression to prevent operators
// with different precedences from applying earlier.
Expand Down Expand Up @@ -2206,27 +2233,53 @@ impl<W: Write> Writer<W> {
),
},
fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
let conversion = match fun {
Mf::Dot4I8Packed => "int",
Mf::Dot4U8Packed => "",
_ => unreachable!(),
};
if context.lang_version >= (2, 1) {
// Write potentially optimizable code using `packed_(u?)char4`.
// The two function arguments were already reinterpreted as packed (signed
// or unsigned) chars in `Self::put_block`.
let packed_type = match fun {
Mf::Dot4I8Packed => "packed_char4",
Mf::Dot4U8Packed => "packed_uchar4",
_ => unreachable!(),
};

return self.put_dot_product(
arg,
arg1.unwrap(),
4,
|writer, arg, index| {
write!(writer.out, "({}(", conversion)?;
writer.put_expression(arg, context, true)?;
if index == 3 {
write!(writer.out, ") >> 24)")?;
} else {
write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?;
}
Ok(())
},
);
return self.put_dot_product(
Reinterpreted::new(packed_type, arg),
Reinterpreted::new(packed_type, arg1.unwrap()),
4,
|writer, arg, index| {
// MSL implicitly promotes these (signed or unsigned) chars to
// `int` or `uint` in the multiplication, so no overflow can occur.
write!(writer.out, "{arg}[{index}]")?;
Ok(())
},
);
} else {
// Fall back to a polyfill since MSL < 2.1 doesn't seem to support
// bitcasting from uint to `packed_char4` or `packed_uchar4`.
// See <https://github.com/gfx-rs/wgpu/pull/7574#issuecomment-2835464472>.
let conversion = match fun {
Mf::Dot4I8Packed => "int",
Mf::Dot4U8Packed => "",
_ => unreachable!(),
};

return self.put_dot_product(
arg,
arg1.unwrap(),
4,
|writer, arg, index| {
write!(writer.out, "({}(", conversion)?;
writer.put_expression(arg, context, true)?;
if index == 3 {
write!(writer.out, ") >> 24)")?;
} else {
write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?;
}
Ok(())
},
);
}
}
Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
Mf::Cross => "cross",
Expand Down Expand Up @@ -3346,6 +3399,38 @@ impl<W: Write> Writer<W> {
Ok(())
}

/// Convert the arguments of `Dot4{I, U}Packed` to `packed_(u?)char4`.
///
/// Caches the results in temporary variables (whose names are derived from
/// the original variable names). This caching avoids the need to redo the
/// casting for each vector component when emitting the dot product.
fn put_casting_to_packed_chars(
&mut self,
fun: crate::MathFunction,
arg0: Handle<crate::Expression>,
arg1: Handle<crate::Expression>,
indent: back::Level,
context: &StatementContext<'_>,
) -> Result<(), Error> {
let packed_type = match fun {
crate::MathFunction::Dot4I8Packed => "packed_char4",
crate::MathFunction::Dot4U8Packed => "packed_uchar4",
_ => unreachable!(),
};

for arg in [arg0, arg1] {
write!(
self.out,
"{indent}{packed_type} {0} = as_type<{packed_type}>(",
Reinterpreted::new(packed_type, arg)
)?;
self.put_expression(arg, &context.expression, true)?;
writeln!(self.out, ");")?;
}

Ok(())
}

fn put_block(
&mut self,
level: back::Level,
Expand All @@ -3362,17 +3447,45 @@ impl<W: Write> Writer<W> {
match *statement {
crate::Statement::Emit(ref range) => {
for handle in range.clone() {
// `ImageLoad` expressions covered by the `Restrict` bounds check policy
// may need to cache a clamped version of their level-of-detail argument.
if let crate::Expression::ImageLoad {
image,
level: mip_level,
..
} = context.expression.function.expressions[handle]
{
self.put_cache_restricted_level(
handle, image, mip_level, level, context,
)?;
use crate::MathFunction as Mf;

match context.expression.function.expressions[handle] {
// `ImageLoad` expressions covered by the `Restrict` bounds check policy
// may need to cache a clamped version of their level-of-detail argument.
crate::Expression::ImageLoad {
image,
level: mip_level,
..
} => {
self.put_cache_restricted_level(
handle, image, mip_level, level, context,
)?;
}

// If we are going to write a `Dot4I8Packed` or `Dot4U8Packed` on Metal
// 2.1+ then we introduce two intermediate variables that recast the two
// arguments as packed (signed or unsigned) chars. The actual dot product
// is implemented in `Self::put_expression`, and it uses both of these
// intermediate variables multiple times. There's no danger that the
// original arguments get modified between the definition of these
// intermediate variables and the implementation of the actual dot
// product since we require the inputs of `Dot4{I, U}Packed` to be baked.
crate::Expression::Math {
fun: fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed),
arg,
arg1,
..
} if context.expression.lang_version >= (2, 1) => {
self.put_casting_to_packed_chars(
fun,
arg,
arg1.unwrap(),
level,
context,
)?;
}

_ => (),
}

let ptr_class = context.expression.resolve_type(handle).pointer_space();
Expand Down
9 changes: 6 additions & 3 deletions naga/tests/in/wgsl/functions-optimized-by-version.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Turn on optimizations for `dot4I8Packed` and `dot4U8Packed` on SPIR-V and HLSL by
# using a version of SPIR-V / shader model that supports these without any extensions.
# Turn on optimizations for `dot4I8Packed` and `dot4U8Packed` on SPIR-V, HLSL, and Metal
# by using a language version / shader model that supports these (without any extensions).

targets = "SPIRV | HLSL"
targets = "SPIRV | HLSL | METAL"

[spv]
# We also need to provide the corresponding capabilities (which are part of SPIR-V >= 1.6).
Expand All @@ -10,3 +10,6 @@ version = [1, 6]

[hlsl]
shader_model = "V6_4"

[msl]
lang_version = [2, 1]
7 changes: 5 additions & 2 deletions naga/tests/in/wgsl/functions-unoptimized.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Explicitly turn off optimizations for `dot4I8Packed` and `dot4U8Packed`
# on SPIRV and HLSL.
# on SPIRV, HLSL, and Metal.

targets = "SPIRV | HLSL"
targets = "SPIRV | HLSL | METAL"

[spv]
# Provide some unrelated capability because an empty list of capabilities would
Expand All @@ -11,3 +11,6 @@ capabilities = ["Matrix"]

[hlsl]
shader_model = "V6_3"

[msl]
lang_version = [2, 0]
33 changes: 33 additions & 0 deletions naga/tests/out/msl/wgsl-functions-optimized-by-version.msl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// language: metal2.1
#include <metal_stdlib>
#include <simd/simd.h>

using metal::uint;


uint test_packed_integer_dot_product(
) {
packed_char4 reinterpreted_packed_char4_e0 = as_type<packed_char4>(1u);
packed_char4 reinterpreted_packed_char4_e1 = as_type<packed_char4>(2u);
int c_5_ = ( + reinterpreted_packed_char4_e0[0] * reinterpreted_packed_char4_e1[0] + reinterpreted_packed_char4_e0[1] * reinterpreted_packed_char4_e1[1] + reinterpreted_packed_char4_e0[2] * reinterpreted_packed_char4_e1[2] + reinterpreted_packed_char4_e0[3] * reinterpreted_packed_char4_e1[3]);
packed_uchar4 reinterpreted_packed_uchar4_e3 = as_type<packed_uchar4>(3u);
packed_uchar4 reinterpreted_packed_uchar4_e4 = as_type<packed_uchar4>(4u);
uint c_6_ = ( + reinterpreted_packed_uchar4_e3[0] * reinterpreted_packed_uchar4_e4[0] + reinterpreted_packed_uchar4_e3[1] * reinterpreted_packed_uchar4_e4[1] + reinterpreted_packed_uchar4_e3[2] * reinterpreted_packed_uchar4_e4[2] + reinterpreted_packed_uchar4_e3[3] * reinterpreted_packed_uchar4_e4[3]);
uint _e7 = 5u + c_6_;
uint _e9 = 6u + c_6_;
packed_char4 reinterpreted_packed_char4_e7 = as_type<packed_char4>(_e7);
packed_char4 reinterpreted_packed_char4_e9 = as_type<packed_char4>(_e9);
int c_7_ = ( + reinterpreted_packed_char4_e7[0] * reinterpreted_packed_char4_e9[0] + reinterpreted_packed_char4_e7[1] * reinterpreted_packed_char4_e9[1] + reinterpreted_packed_char4_e7[2] * reinterpreted_packed_char4_e9[2] + reinterpreted_packed_char4_e7[3] * reinterpreted_packed_char4_e9[3]);
uint _e12 = 7u + c_6_;
uint _e14 = 8u + c_6_;
packed_uchar4 reinterpreted_packed_uchar4_e12 = as_type<packed_uchar4>(_e12);
packed_uchar4 reinterpreted_packed_uchar4_e14 = as_type<packed_uchar4>(_e14);
uint c_8_ = ( + reinterpreted_packed_uchar4_e12[0] * reinterpreted_packed_uchar4_e14[0] + reinterpreted_packed_uchar4_e12[1] * reinterpreted_packed_uchar4_e14[1] + reinterpreted_packed_uchar4_e12[2] * reinterpreted_packed_uchar4_e14[2] + reinterpreted_packed_uchar4_e12[3] * reinterpreted_packed_uchar4_e14[3]);
return c_8_;
}

kernel void main_(
) {
uint _e0 = test_packed_integer_dot_product();
return;
}
25 changes: 25 additions & 0 deletions naga/tests/out/msl/wgsl-functions-unoptimized.msl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// language: metal2.0
#include <metal_stdlib>
#include <simd/simd.h>

using metal::uint;


uint test_packed_integer_dot_product(
) {
int c_5_ = ( + (int(1u) << 24 >> 24) * (int(2u) << 24 >> 24) + (int(1u) << 16 >> 24) * (int(2u) << 16 >> 24) + (int(1u) << 8 >> 24) * (int(2u) << 8 >> 24) + (int(1u) >> 24) * (int(2u) >> 24));
uint c_6_ = ( + ((3u) << 24 >> 24) * ((4u) << 24 >> 24) + ((3u) << 16 >> 24) * ((4u) << 16 >> 24) + ((3u) << 8 >> 24) * ((4u) << 8 >> 24) + ((3u) >> 24) * ((4u) >> 24));
uint _e7 = 5u + c_6_;
uint _e9 = 6u + c_6_;
int c_7_ = ( + (int(_e7) << 24 >> 24) * (int(_e9) << 24 >> 24) + (int(_e7) << 16 >> 24) * (int(_e9) << 16 >> 24) + (int(_e7) << 8 >> 24) * (int(_e9) << 8 >> 24) + (int(_e7) >> 24) * (int(_e9) >> 24));
uint _e12 = 7u + c_6_;
uint _e14 = 8u + c_6_;
uint c_8_ = ( + ((_e12) << 24 >> 24) * ((_e14) << 24 >> 24) + ((_e12) << 16 >> 24) * ((_e14) << 16 >> 24) + ((_e12) << 8 >> 24) * ((_e14) << 8 >> 24) + ((_e12) >> 24) * ((_e14) >> 24));
return c_8_;
}

kernel void main_(
) {
uint _e0 = test_packed_integer_dot_product();
return;
}