Skip to content

Vectorize [un]pack4x{I, U}8[Clamp] on spv and msl #7664

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: trunk
Choose a base branch
from

Conversation

robamler
Copy link
Contributor

@robamler robamler commented May 3, 2025

Connections
Related to, but independent of:

Description
Emits simpler code (using vectorized ops) for unpack4xI8, unpack4xU8, pack4xI8, pack4xU8, pack4xI8Clamp, and pack4xU8Clamp on SPIR-V and Metal (with MSL 2.1+).

  • on SPIR-V (if capability "Int8" is available): exploit that SClamp, UClamp, and OpUConvert accept vector arguments (in which case results are computed per component), and that OpBitcast can cast between vectors and scalars, with a well-defined bit order that matches that required by the WGSL spec (see below for details).
  • on Metal (MSL 2.1+ only): translate (see shader playground)
    • unpack4x{I, U}8(x)[u]int4(as_type<packed_[u]char4>(x)),
    • pack4x{I, U}8(x)as_type<uint>(packed_[u]char4(x)), and
    • pack4x{I, U}8Clamp(x)as_type<uint>(packed_uchar4(metal::clamp(x, 0, 255))).

Regarding byte order (this was not relevant for the dot product in #7653 but is relevant now):

  • WGSL states:

    Component e[i] of the input is mapped to bits 8 x i through 8 x i + 7 of the result.

  • This corresponds to little-endian byte order, which is what Apple uses for Metal.
  • SPIR-V states the same byte order in more complicated words:

    Within this mapping, any single component of S [remark: a 32-bit int in our case] (mapping to multiple components of L [remark: a 4-vector of 8-bit ints in our case]) maps its lower-ordered bits to the lower-numbered components of L

Testing

  • SPIR-V part was already covered by snapshot tests.
  • Added a snapshot test for MSL 2.1+
  • It would be nice to add tests for correctness (rather than just validity) of the MSL and SPIR-V output here (i.e., given some input vector, test that does it get packed into the correct uint). Is there an easy way to do this?
    • [Update: this is already implemented in wgpu_gpu::shader::data_builtins::pack4x_i8, but currently uses the polyfill]

Squash or Rebase?
All commits should pass CI.
Needs squashing.

Checklist

  • Run cargo fmt.
  • Run taplo format.
  • Run cargo clippy --tests. If applicable, add:
    • --target wasm32-unknown-unknown
  • Run cargo xtask test to run tests.
  • If this contains user-facing changes, add a CHANGELOG.md entry.

robamler added a commit to robamler/wgpu that referenced this pull request May 3, 2025
@robamler robamler force-pushed the optimize-pack4x8 branch 2 times, most recently from c8e640f to ac120e9 Compare May 3, 2025 21:16
@robamler

This comment was marked as resolved.

@robamler robamler requested a review from a team as a code owner May 4, 2025 18:57
@robamler robamler changed the title [naga] Vectorize [un]pack4x{I, U}8[Clamp] on spv and msl Vectorize [un]pack4x{I, U}8[Clamp] on spv and msl May 4, 2025
@robamler robamler force-pushed the optimize-pack4x8 branch from ced786f to fcb58b3 Compare May 4, 2025 20:40
robamler added 5 commits May 4, 2025 23:10
Emits vectorized SPIR-V code for the WGSL functions `unpack4xI8`,
`unpack4xU8`, `pack4xI8`, `pack4xU8`, `pack4xI8Clamp`, `pack4xU8Clamp`.

Exploits the following facts about SPIR-V ops:
- `SClamp`, `UClamp`, and `OpUConvert` accept vector arguments, in which
  case results are computed per component; and
- `OpBitcast` can cast between vectors and scalars, with a well-defined
  bit order that matches that required by the WGSL spec, see below.

WGSL spec for `pack4xI8` [1]:

> Component e[i] of the input is mapped to bits 8 x i through 8 x i + 7
> of the result.

SPIR-V spec for `OpBitcast` [2]:

> Within this mapping, any single component of `S` [remark: the type
> with fewer but wider components] (mapping to multiple components of
> `L` [remark: the type with more but narrower components]) maps its
> lower-ordered bits to the lower-numbered components of `L`.

[1] https://www.w3.org/TR/WGSL/#pack4xI8-builtin
[2] https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast
Implements more direct conversions between 32-bit integers and 4x8-bit
integer vectors using bit casting to/from `packed_[u]char4` when on
MSL 2.1+ (older versions of MSL don't seem to support these bit casts).

- `unpack4x{I, U}8(x)` becomes `[u]int4(as_type<packed_[u]char4>(x))`;
- `pack4x{I, U}8(x)` becomes `as_type<uint>(packed_[u]char4(x))`; and
- `pack4x{I, U}8Clamp(x)` becomes
  `as_type<uint>(packed_uchar4(metal::clamp(x, 0, 255)))`.

These bit casts match the WGSL spec for these functions because Metal
runs on little-endian machines.
Separates the Vulkan feature sets
`VkPhysicalDeviceShaderFloat16Int8Features` and
`VkPhysicalDevice16BitStorageFeatures`, which previously were used
"together, or not at all".

This commit should not change any behavior yet, but I'd like to run full
CI tests on it for now. If the CI tests pass, I'll use this separation
to enable the `shader_int8` feature separately from the rest of the
features to enable optimizations of `[un]pack4x{I,U}8[Clamp]` on SPIR-V.
@robamler robamler force-pushed the optimize-pack4x8 branch from fcb58b3 to e7ebc1b Compare May 4, 2025 21:22
robamler added 2 commits May 4, 2025 23:48
This allows declaring the SPIR-V capability "Int8", which allows us to
generate faster code for `[un]pack4x{I, U}8[Clamp]`.
@robamler robamler force-pushed the optimize-pack4x8 branch from a136017 to 659df18 Compare May 4, 2025 21:48
@robamler
Copy link
Contributor Author

robamler commented May 4, 2025

I added the required changes to wgpu-hal so that these optimizations now do actually get triggered on SPIR-V.

This will need some squashing before being merged, but it might be easier to review the PR with the current history (please let me know if you'd like me to squash first).

Notes to self:

  • 62c32b3 must be squashed into 1f695b1 because the latter would fail CI.
  • c197de1 should become a fixup because it's trivial.

@robamler
Copy link
Contributor Author

robamler commented May 4, 2025

FWIW, the optimization of pack4xU8 speeds up our motivating use case (same as in #7595 (comment)) by between 1.1% and 1.7%, with a standard error of at most 0.06% across all benchmarks. So the improvement here is much smaller than what we found for dot4I8Packed, but it's definitely statistically significant. Curiously, the impact of the optimizations of unpack4xU8 is within the error bars, even though this function is also used in our shader, and I verified manually that the vectorized code for it gets emitted. These are very ad-hoc benchmarks, and I ran them only on SPIR-V for now.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant