Skip to content

Commit f78153b

Browse files
committed
[naga] Vectorize [un]pack4x{I, U}8[Clamp] on msl
Implements more direct conversions between 32-bit integers and 4x8-bit integer vectors using bit casting to/from `packed_[u]char4` when on MSL 2.1+ (older versions of MSL don't seem to support these bit casts). - `unpack4x{I, U}8(x)` becomes `[u]int4(as_type<packed_[u]char4>(x))`; - `pack4x{I, U}8(x)` becomes `as_type<uint>(packed_[u]char4(x))`; and - `pack4x{I, U}8Clamp(x)` becomes `as_type<uint>(packed_uchar4(metal::clamp(x, 0, 255)))`. These bit casts match the WGSL spec for these functions because Metal runs on little-endian machines.
1 parent 1de2c4f commit f78153b

File tree

4 files changed

+308
-47
lines changed

4 files changed

+308
-47
lines changed

naga/src/back/msl/writer.rs

+98-47
Original file line numberDiff line numberDiff line change
@@ -1497,6 +1497,63 @@ impl<W: Write> Writer<W> {
14971497
Ok(())
14981498
}
14991499

1500+
/// Emit code for the WGSL functions `pack4x{I, U}8[Clamp]`.
1501+
fn put_pack4x8(
1502+
&mut self,
1503+
arg: Handle<crate::Expression>,
1504+
context: &ExpressionContext<'_>,
1505+
was_signed: bool,
1506+
clamp_bounds: Option<(&str, &str)>,
1507+
) -> Result<(), Error> {
1508+
if context.lang_version >= (2, 1) {
1509+
let packed_type = if was_signed {
1510+
"packed_char4"
1511+
} else {
1512+
"packed_uchar4"
1513+
};
1514+
// Metal uses little endian byte order, which matches what WGSL expects here.
1515+
write!(self.out, "as_type<uint>({packed_type}(")?;
1516+
if let Some((min, max)) = clamp_bounds {
1517+
// Clamping a vector to scalar bounds works and operates component-wise.
1518+
write!(self.out, "{NAMESPACE}::clamp(")?;
1519+
self.put_expression(arg, context, true)?;
1520+
write!(self.out, ", {min}, {max})")?;
1521+
} else {
1522+
self.put_expression(arg, context, true)?;
1523+
}
1524+
write!(self.out, "))")?;
1525+
} else {
1526+
// MSL < 2.1 doesn't support `as_type` casting between packed chars and scalars.
1527+
if was_signed {
1528+
write!(self.out, "uint(")?;
1529+
}
1530+
let write_arg = |this: &mut Self| -> BackendResult {
1531+
if let Some((min, max)) = clamp_bounds {
1532+
write!(this.out, "{NAMESPACE}::clamp(")?;
1533+
this.put_expression(arg, context, true)?;
1534+
write!(this.out, ", {min}, {max})")?;
1535+
} else {
1536+
this.put_expression(arg, context, true)?;
1537+
}
1538+
Ok(())
1539+
};
1540+
write!(self.out, "(")?;
1541+
write_arg(self)?;
1542+
write!(self.out, "[0] & 0xFF) | ((")?;
1543+
write_arg(self)?;
1544+
write!(self.out, "[1] & 0xFF) << 8) | ((")?;
1545+
write_arg(self)?;
1546+
write!(self.out, "[2] & 0xFF) << 16) | ((")?;
1547+
write_arg(self)?;
1548+
write!(self.out, "[3] & 0xFF) << 24)")?;
1549+
if was_signed {
1550+
write!(self.out, ")")?;
1551+
}
1552+
}
1553+
1554+
Ok(())
1555+
}
1556+
15001557
/// Emit code for the isign expression.
15011558
///
15021559
fn put_isign(
@@ -2437,53 +2494,41 @@ impl<W: Write> Writer<W> {
24372494
write!(self.out, "{fun_name}")?;
24382495
self.put_call_parameters(iter::once(arg), context)?;
24392496
}
2440-
fun @ (Mf::Pack4xI8 | Mf::Pack4xU8 | Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp) => {
2441-
let was_signed = matches!(fun, Mf::Pack4xI8 | Mf::Pack4xI8Clamp);
2442-
let clamp_bounds = match fun {
2443-
Mf::Pack4xI8Clamp => Some(("-128", "127")),
2444-
Mf::Pack4xU8Clamp => Some(("0", "255")),
2445-
_ => None,
2446-
};
2447-
if was_signed {
2448-
write!(self.out, "uint(")?;
2449-
}
2450-
let write_arg = |this: &mut Self| -> BackendResult {
2451-
if let Some((min, max)) = clamp_bounds {
2452-
write!(this.out, "{NAMESPACE}::clamp(")?;
2453-
this.put_expression(arg, context, true)?;
2454-
write!(this.out, ", {min}, {max})")?;
2455-
} else {
2456-
this.put_expression(arg, context, true)?;
2457-
}
2458-
Ok(())
2459-
};
2460-
write!(self.out, "(")?;
2461-
write_arg(self)?;
2462-
write!(self.out, "[0] & 0xFF) | ((")?;
2463-
write_arg(self)?;
2464-
write!(self.out, "[1] & 0xFF) << 8) | ((")?;
2465-
write_arg(self)?;
2466-
write!(self.out, "[2] & 0xFF) << 16) | ((")?;
2467-
write_arg(self)?;
2468-
write!(self.out, "[3] & 0xFF) << 24)")?;
2469-
if was_signed {
2470-
write!(self.out, ")")?;
2471-
}
2497+
Mf::Pack4xI8 => self.put_pack4x8(arg, context, true, None)?,
2498+
Mf::Pack4xU8 => self.put_pack4x8(arg, context, false, None)?,
2499+
Mf::Pack4xI8Clamp => {
2500+
self.put_pack4x8(arg, context, true, Some(("-128", "127")))?
2501+
}
2502+
Mf::Pack4xU8Clamp => {
2503+
self.put_pack4x8(arg, context, false, Some(("0", "255")))?
24722504
}
24732505
fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
2474-
write!(self.out, "(")?;
2475-
if matches!(fun, Mf::Unpack4xU8) {
2476-
write!(self.out, "u")?;
2506+
let sign_prefix = if matches!(fun, Mf::Unpack4xU8) {
2507+
"u"
2508+
} else {
2509+
""
2510+
};
2511+
2512+
if context.lang_version >= (2, 1) {
2513+
// Metal uses little endian byte order, which matches what WGSL expects here.
2514+
write!(
2515+
self.out,
2516+
"{sign_prefix}int4(as_type<packed_{sign_prefix}char4>("
2517+
)?;
2518+
self.put_expression(arg, context, true)?;
2519+
write!(self.out, "))")?;
2520+
} else {
2521+
// MSL < 2.1 doesn't support `as_type` casting between packed chars and scalars.
2522+
write!(self.out, "({sign_prefix}int4(")?;
2523+
self.put_expression(arg, context, true)?;
2524+
write!(self.out, ", ")?;
2525+
self.put_expression(arg, context, true)?;
2526+
write!(self.out, " >> 8, ")?;
2527+
self.put_expression(arg, context, true)?;
2528+
write!(self.out, " >> 16, ")?;
2529+
self.put_expression(arg, context, true)?;
2530+
write!(self.out, " >> 24) << 24 >> 24)")?;
24772531
}
2478-
write!(self.out, "int4(")?;
2479-
self.put_expression(arg, context, true)?;
2480-
write!(self.out, ", ")?;
2481-
self.put_expression(arg, context, true)?;
2482-
write!(self.out, " >> 8, ")?;
2483-
self.put_expression(arg, context, true)?;
2484-
write!(self.out, " >> 16, ")?;
2485-
self.put_expression(arg, context, true)?;
2486-
write!(self.out, " >> 24) << 24 >> 24)")?;
24872532
}
24882533
Mf::QuantizeToF16 => {
24892534
match *context.resolve_type(arg) {
@@ -3226,14 +3271,20 @@ impl<W: Write> Writer<W> {
32263271
self.need_bake_expressions.insert(arg);
32273272
self.need_bake_expressions.insert(arg1.unwrap());
32283273
}
3229-
crate::MathFunction::FirstLeadingBit
3230-
| crate::MathFunction::Pack4xI8
3274+
crate::MathFunction::FirstLeadingBit => {
3275+
self.need_bake_expressions.insert(arg);
3276+
}
3277+
crate::MathFunction::Pack4xI8
32313278
| crate::MathFunction::Pack4xU8
32323279
| crate::MathFunction::Pack4xI8Clamp
32333280
| crate::MathFunction::Pack4xU8Clamp
32343281
| crate::MathFunction::Unpack4xI8
32353282
| crate::MathFunction::Unpack4xU8 => {
3236-
self.need_bake_expressions.insert(arg);
3283+
// On MSL < 2.1, we emit a polyfill for these functions that uses the
3284+
// argument multiple times. This is no longer necessary on MSL >= 2.1.
3285+
if context.lang_version < (2, 1) {
3286+
self.need_bake_expressions.insert(arg);
3287+
}
32373288
}
32383289
crate::MathFunction::ExtractBits => {
32393290
// Only argument 1 is re-used.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
targets = "METAL"
2+
3+
[msl]
4+
lang_version = [2, 1]
+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
// Keep in sync with `bits_downlevel` and `bits_downlevel_webgl`
2+
3+
@compute @workgroup_size(1)
4+
fn main() {
5+
var i = 0;
6+
var i2 = vec2<i32>(0);
7+
var i3 = vec3<i32>(0);
8+
var i4 = vec4<i32>(0);
9+
var u = 0u;
10+
var u2 = vec2<u32>(0u);
11+
var u3 = vec3<u32>(0u);
12+
var u4 = vec4<u32>(0u);
13+
var f2 = vec2<f32>(0.0);
14+
var f4 = vec4<f32>(0.0);
15+
u = pack4x8snorm(f4);
16+
u = pack4x8unorm(f4);
17+
u = pack2x16snorm(f2);
18+
u = pack2x16unorm(f2);
19+
u = pack2x16float(f2);
20+
u = pack4xI8(i4);
21+
u = pack4xU8(u4);
22+
u = pack4xI8Clamp(i4);
23+
u = pack4xU8Clamp(u4);
24+
f4 = unpack4x8snorm(u);
25+
f4 = unpack4x8unorm(u);
26+
f2 = unpack2x16snorm(u);
27+
f2 = unpack2x16unorm(u);
28+
f2 = unpack2x16float(u);
29+
i4 = unpack4xI8(u);
30+
u4 = unpack4xU8(u);
31+
i = insertBits(i, i, 5u, 10u);
32+
i2 = insertBits(i2, i2, 5u, 10u);
33+
i3 = insertBits(i3, i3, 5u, 10u);
34+
i4 = insertBits(i4, i4, 5u, 10u);
35+
u = insertBits(u, u, 5u, 10u);
36+
u2 = insertBits(u2, u2, 5u, 10u);
37+
u3 = insertBits(u3, u3, 5u, 10u);
38+
u4 = insertBits(u4, u4, 5u, 10u);
39+
i = extractBits(i, 5u, 10u);
40+
i2 = extractBits(i2, 5u, 10u);
41+
i3 = extractBits(i3, 5u, 10u);
42+
i4 = extractBits(i4, 5u, 10u);
43+
u = extractBits(u, 5u, 10u);
44+
u2 = extractBits(u2, 5u, 10u);
45+
u3 = extractBits(u3, 5u, 10u);
46+
u4 = extractBits(u4, 5u, 10u);
47+
i = firstTrailingBit(i);
48+
u2 = firstTrailingBit(u2);
49+
i3 = firstLeadingBit(i3);
50+
u3 = firstLeadingBit(u3);
51+
i = firstLeadingBit(i);
52+
u = firstLeadingBit(u);
53+
i = countOneBits(i);
54+
i2 = countOneBits(i2);
55+
i3 = countOneBits(i3);
56+
i4 = countOneBits(i4);
57+
u = countOneBits(u);
58+
u2 = countOneBits(u2);
59+
u3 = countOneBits(u3);
60+
u4 = countOneBits(u4);
61+
i = reverseBits(i);
62+
i2 = reverseBits(i2);
63+
i3 = reverseBits(i3);
64+
i4 = reverseBits(i4);
65+
u = reverseBits(u);
66+
u2 = reverseBits(u2);
67+
u3 = reverseBits(u3);
68+
u4 = reverseBits(u4);
69+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
// language: metal2.1
2+
#include <metal_stdlib>
3+
#include <simd/simd.h>
4+
5+
using metal::uint;
6+
7+
8+
kernel void main_(
9+
) {
10+
int i = 0;
11+
metal::int2 i2_ = metal::int2(0);
12+
metal::int3 i3_ = metal::int3(0);
13+
metal::int4 i4_ = metal::int4(0);
14+
uint u = 0u;
15+
metal::uint2 u2_ = metal::uint2(0u);
16+
metal::uint3 u3_ = metal::uint3(0u);
17+
metal::uint4 u4_ = metal::uint4(0u);
18+
metal::float2 f2_ = metal::float2(0.0);
19+
metal::float4 f4_ = metal::float4(0.0);
20+
metal::float4 _e28 = f4_;
21+
u = metal::pack_float_to_snorm4x8(_e28);
22+
metal::float4 _e30 = f4_;
23+
u = metal::pack_float_to_unorm4x8(_e30);
24+
metal::float2 _e32 = f2_;
25+
u = metal::pack_float_to_snorm2x16(_e32);
26+
metal::float2 _e34 = f2_;
27+
u = metal::pack_float_to_unorm2x16(_e34);
28+
metal::float2 _e36 = f2_;
29+
u = as_type<uint>(half2(_e36));
30+
metal::int4 _e38 = i4_;
31+
u = as_type<uint>(packed_char4(_e38));
32+
metal::uint4 _e40 = u4_;
33+
u = as_type<uint>(packed_uchar4(_e40));
34+
metal::int4 _e42 = i4_;
35+
u = as_type<uint>(packed_char4(metal::clamp(_e42, -128, 127)));
36+
metal::uint4 _e44 = u4_;
37+
u = as_type<uint>(packed_uchar4(metal::clamp(_e44, 0, 255)));
38+
uint _e46 = u;
39+
f4_ = metal::unpack_snorm4x8_to_float(_e46);
40+
uint _e48 = u;
41+
f4_ = metal::unpack_unorm4x8_to_float(_e48);
42+
uint _e50 = u;
43+
f2_ = metal::unpack_snorm2x16_to_float(_e50);
44+
uint _e52 = u;
45+
f2_ = metal::unpack_unorm2x16_to_float(_e52);
46+
uint _e54 = u;
47+
f2_ = float2(as_type<half2>(_e54));
48+
uint _e56 = u;
49+
i4_ = int4(as_type<packed_char4>(_e56));
50+
uint _e58 = u;
51+
u4_ = uint4(as_type<packed_uchar4>(_e58));
52+
int _e60 = i;
53+
int _e61 = i;
54+
i = metal::insert_bits(_e60, _e61, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u)));
55+
metal::int2 _e65 = i2_;
56+
metal::int2 _e66 = i2_;
57+
i2_ = metal::insert_bits(_e65, _e66, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u)));
58+
metal::int3 _e70 = i3_;
59+
metal::int3 _e71 = i3_;
60+
i3_ = metal::insert_bits(_e70, _e71, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u)));
61+
metal::int4 _e75 = i4_;
62+
metal::int4 _e76 = i4_;
63+
i4_ = metal::insert_bits(_e75, _e76, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u)));
64+
uint _e80 = u;
65+
uint _e81 = u;
66+
u = metal::insert_bits(_e80, _e81, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u)));
67+
metal::uint2 _e85 = u2_;
68+
metal::uint2 _e86 = u2_;
69+
u2_ = metal::insert_bits(_e85, _e86, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u)));
70+
metal::uint3 _e90 = u3_;
71+
metal::uint3 _e91 = u3_;
72+
u3_ = metal::insert_bits(_e90, _e91, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u)));
73+
metal::uint4 _e95 = u4_;
74+
metal::uint4 _e96 = u4_;
75+
u4_ = metal::insert_bits(_e95, _e96, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u)));
76+
int _e100 = i;
77+
i = metal::extract_bits(_e100, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u)));
78+
metal::int2 _e104 = i2_;
79+
i2_ = metal::extract_bits(_e104, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u)));
80+
metal::int3 _e108 = i3_;
81+
i3_ = metal::extract_bits(_e108, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u)));
82+
metal::int4 _e112 = i4_;
83+
i4_ = metal::extract_bits(_e112, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u)));
84+
uint _e116 = u;
85+
u = metal::extract_bits(_e116, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u)));
86+
metal::uint2 _e120 = u2_;
87+
u2_ = metal::extract_bits(_e120, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u)));
88+
metal::uint3 _e124 = u3_;
89+
u3_ = metal::extract_bits(_e124, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u)));
90+
metal::uint4 _e128 = u4_;
91+
u4_ = metal::extract_bits(_e128, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u)));
92+
int _e132 = i;
93+
i = (((metal::ctz(_e132) + 1) % 33) - 1);
94+
metal::uint2 _e134 = u2_;
95+
u2_ = (((metal::ctz(_e134) + 1) % 33) - 1);
96+
metal::int3 _e136 = i3_;
97+
i3_ = metal::select(31 - metal::clz(metal::select(_e136, ~_e136, _e136 < 0)), int3(-1), _e136 == 0 || _e136 == -1);
98+
metal::uint3 _e138 = u3_;
99+
u3_ = metal::select(31 - metal::clz(_e138), uint3(-1), _e138 == 0 || _e138 == -1);
100+
int _e140 = i;
101+
i = metal::select(31 - metal::clz(metal::select(_e140, ~_e140, _e140 < 0)), int(-1), _e140 == 0 || _e140 == -1);
102+
uint _e142 = u;
103+
u = metal::select(31 - metal::clz(_e142), uint(-1), _e142 == 0 || _e142 == -1);
104+
int _e144 = i;
105+
i = metal::popcount(_e144);
106+
metal::int2 _e146 = i2_;
107+
i2_ = metal::popcount(_e146);
108+
metal::int3 _e148 = i3_;
109+
i3_ = metal::popcount(_e148);
110+
metal::int4 _e150 = i4_;
111+
i4_ = metal::popcount(_e150);
112+
uint _e152 = u;
113+
u = metal::popcount(_e152);
114+
metal::uint2 _e154 = u2_;
115+
u2_ = metal::popcount(_e154);
116+
metal::uint3 _e156 = u3_;
117+
u3_ = metal::popcount(_e156);
118+
metal::uint4 _e158 = u4_;
119+
u4_ = metal::popcount(_e158);
120+
int _e160 = i;
121+
i = metal::reverse_bits(_e160);
122+
metal::int2 _e162 = i2_;
123+
i2_ = metal::reverse_bits(_e162);
124+
metal::int3 _e164 = i3_;
125+
i3_ = metal::reverse_bits(_e164);
126+
metal::int4 _e166 = i4_;
127+
i4_ = metal::reverse_bits(_e166);
128+
uint _e168 = u;
129+
u = metal::reverse_bits(_e168);
130+
metal::uint2 _e170 = u2_;
131+
u2_ = metal::reverse_bits(_e170);
132+
metal::uint3 _e172 = u3_;
133+
u3_ = metal::reverse_bits(_e172);
134+
metal::uint4 _e174 = u4_;
135+
u4_ = metal::reverse_bits(_e174);
136+
return;
137+
}

0 commit comments

Comments
 (0)