From 1f695b12101fd0baff8595accc2fc68e2b89bd38 Mon Sep 17 00:00:00 2001 From: Robert Bamler Date: Sat, 3 May 2025 13:27:19 +0200 Subject: [PATCH 1/7] [naga] Vectorize `[un]pack4x{I, U}8[Clamp]` on spv Emits vectorized SPIR-V code for the WGSL functions `unpack4xI8`, `unpack4xU8`, `pack4xI8`, `pack4xU8`, `pack4xI8Clamp`, `pack4xU8Clamp`. Exploits the following facts about SPIR-V ops: - `SClamp`, `UClamp`, and `OpUConvert` accept vector arguments, in which case results are computed per component; and - `OpBitcast` can cast between vectors and scalars, with a well-defined bit order that matches that required by the WGSL spec, see below. WGSL spec for `pack4xI8` [1]: > Component e[i] of the input is mapped to bits 8 x i through 8 x i + 7 > of the result. SPIR-V spec for `OpBitcast` [2]: > Within this mapping, any single component of `S` [remark: the type > with fewer but wider components] (mapping to multiple components of > `L` [remark: the type with more but narrower components]) maps its > lower-ordered bits to the lower-numbered components of `L`. [1] https://www.w3.org/TR/WGSL/#pack4xI8-builtin [2] https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast --- naga/src/back/spv/block.rs | 226 ++++----- .../spv/wgsl-6772-unpack-expr-accesses.spvasm | 32 +- naga/tests/out/spv/wgsl-bits.spvasm | 463 ++++++++---------- 3 files changed, 317 insertions(+), 404 deletions(-) diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index 92edbcb05c..8c6ae97a22 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -1557,102 +1557,81 @@ impl BlockContext<'_> { Mf::Pack4xU8 | Mf::Pack4xU8Clamp => (crate::ScalarKind::Uint, false), _ => unreachable!(), }; + let should_clamp = matches!(fun, Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp); - let uint_type_id = - self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32)); - let int_type_id = - self.get_numeric_type_id(NumericType::Scalar(crate::Scalar { + let wide_vector_type_id = self.get_numeric_type_id(NumericType::Vector { + size: crate::VectorSize::Quad, + scalar: crate::Scalar { kind: int_type, width: 4, - })); - - let mut last_instruction = Instruction::new(spirv::Op::Nop); - - let zero = self.writer.get_constant_scalar(crate::Literal::U32(0)); - let mut preresult = zero; - block - .body - .reserve(usize::from(VEC_LENGTH) * (2 + usize::from(is_signed))); - - let eight = self.writer.get_constant_scalar(crate::Literal::U32(8)); - const VEC_LENGTH: u8 = 4; - for i in 0..u32::from(VEC_LENGTH) { - let offset = - self.writer.get_constant_scalar(crate::Literal::U32(i * 8)); - let mut extracted = self.gen_id(); - block.body.push(Instruction::binary( - spirv::Op::CompositeExtract, - int_type_id, - extracted, - arg0_id, - i, - )); - if is_signed { - let casted = self.gen_id(); - block.body.push(Instruction::unary( - spirv::Op::Bitcast, - uint_type_id, - casted, - extracted, - )); - extracted = casted; - } - if should_clamp { - let (min, max, clamp_op) = if is_signed { - ( - crate::Literal::I32(-128), - crate::Literal::I32(127), - spirv::GLOp::SClamp, - ) - } else { - ( - crate::Literal::U32(0), - crate::Literal::U32(255), - spirv::GLOp::UClamp, - ) - }; - let [min, max] = - [min, max].map(|lit| self.writer.get_constant_scalar(lit)); - - let clamp_id = self.gen_id(); - block.body.push(Instruction::ext_inst( - self.writer.gl450_ext_inst_id, - clamp_op, - result_type_id, - clamp_id, - &[extracted, min, max], - )); + }, + }); + let packed_vector_type_id = self.get_numeric_type_id(NumericType::Vector { + size: crate::VectorSize::Quad, + scalar: crate::Scalar { + kind: crate::ScalarKind::Uint, + width: 1, + }, + }); - extracted = clamp_id; - } - let is_last = i == u32::from(VEC_LENGTH - 1); - if is_last { - last_instruction = Instruction::quaternary( - spirv::Op::BitFieldInsert, - result_type_id, - id, - preresult, - extracted, - offset, - eight, + let mut wide_vector = arg0_id; + if should_clamp { + let (min, max, clamp_op) = if is_signed { + ( + crate::Literal::I32(-128), + crate::Literal::I32(127), + spirv::GLOp::SClamp, ) } else { - let new_preresult = self.gen_id(); - block.body.push(Instruction::quaternary( - spirv::Op::BitFieldInsert, - result_type_id, - new_preresult, - preresult, - extracted, - offset, - eight, + ( + crate::Literal::U32(0), + crate::Literal::U32(255), + spirv::GLOp::UClamp, + ) + }; + let [min, max] = [min, max].map(|lit| { + let scalar = self.writer.get_constant_scalar(lit); + // TODO: can we cache these constant vectors somehow? + let id = self.gen_id(); + block.body.push(Instruction::composite_construct( + wide_vector_type_id, + id, + &[scalar; 4], )); - preresult = new_preresult; - } + id + }); + + let clamp_id = self.gen_id(); + block.body.push(Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + clamp_op, + wide_vector_type_id, + clamp_id, + &[wide_vector, min, max], + )); + + wide_vector = clamp_id; } - MathOp::Custom(last_instruction) + let packed_vector = self.gen_id(); + block.body.push(Instruction::unary( + spirv::Op::UConvert, // We truncate, so `UConvert` and `SConvert` behave identically. + packed_vector_type_id, + packed_vector, + wide_vector, + )); + + // The SPIR-V spec [1] defines the bit order for bit casting between a vector + // and a scalar precisely as required by the WGSL spec [2]. + // [1]: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast + // [2]: https://www.w3.org/TR/WGSL/#pack4xI8-builtin + MathOp::Custom(Instruction::unary( + spirv::Op::Bitcast, + result_type_id, + id, + packed_vector, + )) } Mf::Unpack4x8unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm4x8), Mf::Unpack4x8snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm4x8), @@ -1660,59 +1639,38 @@ impl BlockContext<'_> { Mf::Unpack2x16unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm2x16), Mf::Unpack2x16snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm2x16), fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => { - let (int_type, extract_op, is_signed) = match fun { - Mf::Unpack4xI8 => { - (crate::ScalarKind::Sint, spirv::Op::BitFieldSExtract, true) - } - Mf::Unpack4xU8 => { - (crate::ScalarKind::Uint, spirv::Op::BitFieldUExtract, false) - } + let (int_type, convert_op) = match fun { + Mf::Unpack4xI8 => (crate::ScalarKind::Sint, spirv::Op::SConvert), + Mf::Unpack4xU8 => (crate::ScalarKind::Uint, spirv::Op::UConvert), _ => unreachable!(), }; - let sint_type_id = - self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::I32)); - - let eight = self.writer.get_constant_scalar(crate::Literal::U32(8)); - let int_type_id = - self.get_numeric_type_id(NumericType::Scalar(crate::Scalar { + let packed_vector_type_id = self.get_numeric_type_id(NumericType::Vector { + size: crate::VectorSize::Quad, + scalar: crate::Scalar { kind: int_type, - width: 4, - })); - block - .body - .reserve(usize::from(VEC_LENGTH) * 2 + usize::from(is_signed)); - let arg_id = if is_signed { - let new_arg_id = self.gen_id(); - block.body.push(Instruction::unary( - spirv::Op::Bitcast, - sint_type_id, - new_arg_id, - arg0_id, - )); - new_arg_id - } else { - arg0_id - }; - - const VEC_LENGTH: u8 = 4; - let parts: [_; VEC_LENGTH as usize] = - core::array::from_fn(|_| self.gen_id()); - for (i, part_id) in parts.into_iter().enumerate() { - let index = self - .writer - .get_constant_scalar(crate::Literal::U32(i as u32 * 8)); - block.body.push(Instruction::ternary( - extract_op, - int_type_id, - part_id, - arg_id, - index, - eight, - )); - } + width: 1, + }, + }); + + // The SPIR-V spec [1] defines the bit order for bit casting between a vector + // and a scalar precisely as required by the WGSL spec [2]. + // [1]: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast + // [2]: https://www.w3.org/TR/WGSL/#pack4xI8-builtin + let packed_vector = self.gen_id(); + block.body.push(Instruction::unary( + spirv::Op::Bitcast, + packed_vector_type_id, + packed_vector, + arg0_id, + )); - MathOp::Custom(Instruction::composite_construct(result_type_id, id, &parts)) + MathOp::Custom(Instruction::unary( + convert_op, + result_type_id, + id, + packed_vector, + )) } }; diff --git a/naga/tests/out/spv/wgsl-6772-unpack-expr-accesses.spvasm b/naga/tests/out/spv/wgsl-6772-unpack-expr-accesses.spvasm index 973557789e..eb3edc3b36 100644 --- a/naga/tests/out/spv/wgsl-6772-unpack-expr-accesses.spvasm +++ b/naga/tests/out/spv/wgsl-6772-unpack-expr-accesses.spvasm @@ -1,8 +1,9 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 30 +; Bound: 23 OpCapability Shader +OpCapability Int8 %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 OpEntryPoint GLCompute %4 "main" @@ -14,27 +15,20 @@ OpExecutionMode %4 LocalSize 1 1 1 %8 = OpTypeInt 32 0 %9 = OpConstant %8 12 %11 = OpTypeVector %6 4 -%13 = OpConstant %8 8 -%19 = OpConstant %8 0 -%20 = OpConstant %8 16 -%21 = OpConstant %8 24 -%23 = OpTypeVector %8 4 +%14 = OpTypeInt 8 1 +%13 = OpTypeVector %14 4 +%17 = OpTypeVector %8 4 +%20 = OpTypeInt 8 0 +%19 = OpTypeVector %20 4 %4 = OpFunction %2 None %5 %3 = OpLabel OpBranch %10 %10 = OpLabel -%14 = OpBitcast %6 %9 -%15 = OpBitFieldSExtract %6 %14 %19 %13 -%16 = OpBitFieldSExtract %6 %14 %13 %13 -%17 = OpBitFieldSExtract %6 %14 %20 %13 -%18 = OpBitFieldSExtract %6 %14 %21 %13 -%12 = OpCompositeConstruct %11 %15 %16 %17 %18 -%22 = OpCompositeExtract %6 %12 2 -%25 = OpBitFieldUExtract %8 %9 %19 %13 -%26 = OpBitFieldUExtract %8 %9 %13 %13 -%27 = OpBitFieldUExtract %8 %9 %20 %13 -%28 = OpBitFieldUExtract %8 %9 %21 %13 -%24 = OpCompositeConstruct %23 %25 %26 %27 %28 -%29 = OpCompositeExtract %8 %24 1 +%15 = OpBitcast %13 %9 +%12 = OpSConvert %11 %15 +%16 = OpCompositeExtract %6 %12 2 +%21 = OpBitcast %19 %9 +%18 = OpUConvert %17 %21 +%22 = OpCompositeExtract %8 %18 1 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/out/spv/wgsl-bits.spvasm b/naga/tests/out/spv/wgsl-bits.spvasm index 76e221aea1..a55f4a9478 100644 --- a/naga/tests/out/spv/wgsl-bits.spvasm +++ b/naga/tests/out/spv/wgsl-bits.spvasm @@ -1,8 +1,9 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 275 +; Bound: 235 OpCapability Shader +OpCapability Int8 %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 OpEntryPoint GLCompute %15 "main" @@ -43,13 +44,14 @@ OpExecutionMode %15 LocalSize 1 1 1 %45 = OpTypePointer Function %10 %47 = OpTypePointer Function %11 %49 = OpTypePointer Function %13 -%63 = OpConstant %7 8 -%70 = OpConstant %7 16 -%74 = OpConstant %7 24 -%90 = OpConstant %3 -128 -%91 = OpConstant %3 127 -%108 = OpConstant %7 255 -%145 = OpConstant %7 32 +%64 = OpTypeInt 8 0 +%63 = OpTypeVector %64 4 +%71 = OpConstant %3 -128 +%73 = OpConstant %3 127 +%80 = OpConstant %7 255 +%97 = OpTypeInt 8 1 +%96 = OpTypeVector %97 4 +%105 = OpConstant %7 32 %15 = OpFunction %2 None %16 %14 = OpLabel %48 = OpVariable %49 Function %27 @@ -80,260 +82,219 @@ OpStore %38 %58 %60 = OpExtInst %7 %1 PackHalf2x16 %59 OpStore %38 %60 %61 = OpLoad %6 %36 -%64 = OpCompositeExtract %3 %61 0 -%65 = OpBitcast %7 %64 -%66 = OpBitFieldInsert %7 %21 %65 %21 %63 -%67 = OpCompositeExtract %3 %61 1 -%68 = OpBitcast %7 %67 -%69 = OpBitFieldInsert %7 %66 %68 %63 %63 -%71 = OpCompositeExtract %3 %61 2 -%72 = OpBitcast %7 %71 -%73 = OpBitFieldInsert %7 %69 %72 %70 %63 -%75 = OpCompositeExtract %3 %61 3 -%76 = OpBitcast %7 %75 -%62 = OpBitFieldInsert %7 %73 %76 %74 %63 +%65 = OpUConvert %63 %61 +%62 = OpBitcast %7 %65 OpStore %38 %62 +%66 = OpLoad %10 %44 +%68 = OpUConvert %63 %66 +%67 = OpBitcast %7 %68 +OpStore %38 %67 +%69 = OpLoad %6 %36 +%72 = OpCompositeConstruct %6 %71 %71 %71 %71 +%74 = OpCompositeConstruct %6 %73 %73 %73 %73 +%75 = OpExtInst %6 %1 SClamp %69 %72 %74 +%76 = OpUConvert %63 %75 +%70 = OpBitcast %7 %76 +OpStore %38 %70 %77 = OpLoad %10 %44 -%79 = OpCompositeExtract %7 %77 0 -%80 = OpBitFieldInsert %7 %21 %79 %21 %63 -%81 = OpCompositeExtract %7 %77 1 -%82 = OpBitFieldInsert %7 %80 %81 %63 %63 -%83 = OpCompositeExtract %7 %77 2 -%84 = OpBitFieldInsert %7 %82 %83 %70 %63 -%85 = OpCompositeExtract %7 %77 3 -%78 = OpBitFieldInsert %7 %84 %85 %74 %63 +%79 = OpCompositeConstruct %10 %21 %21 %21 %21 +%81 = OpCompositeConstruct %10 %80 %80 %80 %80 +%82 = OpExtInst %10 %1 UClamp %77 %79 %81 +%83 = OpUConvert %63 %82 +%78 = OpBitcast %7 %83 OpStore %38 %78 -%86 = OpLoad %6 %36 -%88 = OpCompositeExtract %3 %86 0 -%89 = OpBitcast %7 %88 -%92 = OpExtInst %7 %1 SClamp %89 %90 %91 -%93 = OpBitFieldInsert %7 %21 %92 %21 %63 -%94 = OpCompositeExtract %3 %86 1 -%95 = OpBitcast %7 %94 -%96 = OpExtInst %7 %1 SClamp %95 %90 %91 -%97 = OpBitFieldInsert %7 %93 %96 %63 %63 -%98 = OpCompositeExtract %3 %86 2 -%99 = OpBitcast %7 %98 -%100 = OpExtInst %7 %1 SClamp %99 %90 %91 -%101 = OpBitFieldInsert %7 %97 %100 %70 %63 -%102 = OpCompositeExtract %3 %86 3 -%103 = OpBitcast %7 %102 -%104 = OpExtInst %7 %1 SClamp %103 %90 %91 -%87 = OpBitFieldInsert %7 %101 %104 %74 %63 -OpStore %38 %87 -%105 = OpLoad %10 %44 -%107 = OpCompositeExtract %7 %105 0 -%109 = OpExtInst %7 %1 UClamp %107 %21 %108 -%110 = OpBitFieldInsert %7 %21 %109 %21 %63 -%111 = OpCompositeExtract %7 %105 1 -%112 = OpExtInst %7 %1 UClamp %111 %21 %108 -%113 = OpBitFieldInsert %7 %110 %112 %63 %63 -%114 = OpCompositeExtract %7 %105 2 -%115 = OpExtInst %7 %1 UClamp %114 %21 %108 -%116 = OpBitFieldInsert %7 %113 %115 %70 %63 -%117 = OpCompositeExtract %7 %105 3 -%118 = OpExtInst %7 %1 UClamp %117 %21 %108 -%106 = OpBitFieldInsert %7 %116 %118 %74 %63 -OpStore %38 %106 -%119 = OpLoad %7 %38 -%120 = OpExtInst %13 %1 UnpackSnorm4x8 %119 -OpStore %48 %120 -%121 = OpLoad %7 %38 -%122 = OpExtInst %13 %1 UnpackUnorm4x8 %121 -OpStore %48 %122 -%123 = OpLoad %7 %38 -%124 = OpExtInst %11 %1 UnpackSnorm2x16 %123 -OpStore %46 %124 -%125 = OpLoad %7 %38 -%126 = OpExtInst %11 %1 UnpackUnorm2x16 %125 -OpStore %46 %126 +%84 = OpLoad %7 %38 +%85 = OpExtInst %13 %1 UnpackSnorm4x8 %84 +OpStore %48 %85 +%86 = OpLoad %7 %38 +%87 = OpExtInst %13 %1 UnpackUnorm4x8 %86 +OpStore %48 %87 +%88 = OpLoad %7 %38 +%89 = OpExtInst %11 %1 UnpackSnorm2x16 %88 +OpStore %46 %89 +%90 = OpLoad %7 %38 +%91 = OpExtInst %11 %1 UnpackUnorm2x16 %90 +OpStore %46 %91 +%92 = OpLoad %7 %38 +%93 = OpExtInst %11 %1 UnpackHalf2x16 %92 +OpStore %46 %93 +%94 = OpLoad %7 %38 +%98 = OpBitcast %96 %94 +%95 = OpSConvert %6 %98 +OpStore %36 %95 +%99 = OpLoad %7 %38 +%101 = OpBitcast %63 %99 +%100 = OpUConvert %10 %101 +OpStore %44 %100 +%102 = OpLoad %3 %30 +%103 = OpLoad %3 %30 +%106 = OpExtInst %7 %1 UMin %28 %105 +%107 = OpISub %7 %105 %106 +%108 = OpExtInst %7 %1 UMin %29 %107 +%104 = OpBitFieldInsert %3 %102 %103 %106 %108 +OpStore %30 %104 +%109 = OpLoad %4 %32 +%110 = OpLoad %4 %32 +%112 = OpExtInst %7 %1 UMin %28 %105 +%113 = OpISub %7 %105 %112 +%114 = OpExtInst %7 %1 UMin %29 %113 +%111 = OpBitFieldInsert %4 %109 %110 %112 %114 +OpStore %32 %111 +%115 = OpLoad %5 %34 +%116 = OpLoad %5 %34 +%118 = OpExtInst %7 %1 UMin %28 %105 +%119 = OpISub %7 %105 %118 +%120 = OpExtInst %7 %1 UMin %29 %119 +%117 = OpBitFieldInsert %5 %115 %116 %118 %120 +OpStore %34 %117 +%121 = OpLoad %6 %36 +%122 = OpLoad %6 %36 +%124 = OpExtInst %7 %1 UMin %28 %105 +%125 = OpISub %7 %105 %124 +%126 = OpExtInst %7 %1 UMin %29 %125 +%123 = OpBitFieldInsert %6 %121 %122 %124 %126 +OpStore %36 %123 %127 = OpLoad %7 %38 -%128 = OpExtInst %11 %1 UnpackHalf2x16 %127 -OpStore %46 %128 -%129 = OpLoad %7 %38 -%131 = OpBitcast %3 %129 -%132 = OpBitFieldSExtract %3 %131 %21 %63 -%133 = OpBitFieldSExtract %3 %131 %63 %63 -%134 = OpBitFieldSExtract %3 %131 %70 %63 -%135 = OpBitFieldSExtract %3 %131 %74 %63 -%130 = OpCompositeConstruct %6 %132 %133 %134 %135 -OpStore %36 %130 -%136 = OpLoad %7 %38 -%138 = OpBitFieldUExtract %7 %136 %21 %63 -%139 = OpBitFieldUExtract %7 %136 %63 %63 -%140 = OpBitFieldUExtract %7 %136 %70 %63 -%141 = OpBitFieldUExtract %7 %136 %74 %63 -%137 = OpCompositeConstruct %10 %138 %139 %140 %141 -OpStore %44 %137 -%142 = OpLoad %3 %30 -%143 = OpLoad %3 %30 -%146 = OpExtInst %7 %1 UMin %28 %145 -%147 = OpISub %7 %145 %146 -%148 = OpExtInst %7 %1 UMin %29 %147 -%144 = OpBitFieldInsert %3 %142 %143 %146 %148 -OpStore %30 %144 -%149 = OpLoad %4 %32 -%150 = OpLoad %4 %32 -%152 = OpExtInst %7 %1 UMin %28 %145 -%153 = OpISub %7 %145 %152 -%154 = OpExtInst %7 %1 UMin %29 %153 -%151 = OpBitFieldInsert %4 %149 %150 %152 %154 -OpStore %32 %151 -%155 = OpLoad %5 %34 -%156 = OpLoad %5 %34 -%158 = OpExtInst %7 %1 UMin %28 %145 -%159 = OpISub %7 %145 %158 +%128 = OpLoad %7 %38 +%130 = OpExtInst %7 %1 UMin %28 %105 +%131 = OpISub %7 %105 %130 +%132 = OpExtInst %7 %1 UMin %29 %131 +%129 = OpBitFieldInsert %7 %127 %128 %130 %132 +OpStore %38 %129 +%133 = OpLoad %8 %40 +%134 = OpLoad %8 %40 +%136 = OpExtInst %7 %1 UMin %28 %105 +%137 = OpISub %7 %105 %136 +%138 = OpExtInst %7 %1 UMin %29 %137 +%135 = OpBitFieldInsert %8 %133 %134 %136 %138 +OpStore %40 %135 +%139 = OpLoad %9 %42 +%140 = OpLoad %9 %42 +%142 = OpExtInst %7 %1 UMin %28 %105 +%143 = OpISub %7 %105 %142 +%144 = OpExtInst %7 %1 UMin %29 %143 +%141 = OpBitFieldInsert %9 %139 %140 %142 %144 +OpStore %42 %141 +%145 = OpLoad %10 %44 +%146 = OpLoad %10 %44 +%148 = OpExtInst %7 %1 UMin %28 %105 +%149 = OpISub %7 %105 %148 +%150 = OpExtInst %7 %1 UMin %29 %149 +%147 = OpBitFieldInsert %10 %145 %146 %148 %150 +OpStore %44 %147 +%151 = OpLoad %3 %30 +%153 = OpExtInst %7 %1 UMin %28 %105 +%154 = OpISub %7 %105 %153 +%155 = OpExtInst %7 %1 UMin %29 %154 +%152 = OpBitFieldSExtract %3 %151 %153 %155 +OpStore %30 %152 +%156 = OpLoad %4 %32 +%158 = OpExtInst %7 %1 UMin %28 %105 +%159 = OpISub %7 %105 %158 %160 = OpExtInst %7 %1 UMin %29 %159 -%157 = OpBitFieldInsert %5 %155 %156 %158 %160 -OpStore %34 %157 -%161 = OpLoad %6 %36 -%162 = OpLoad %6 %36 -%164 = OpExtInst %7 %1 UMin %28 %145 -%165 = OpISub %7 %145 %164 -%166 = OpExtInst %7 %1 UMin %29 %165 -%163 = OpBitFieldInsert %6 %161 %162 %164 %166 -OpStore %36 %163 -%167 = OpLoad %7 %38 -%168 = OpLoad %7 %38 -%170 = OpExtInst %7 %1 UMin %28 %145 -%171 = OpISub %7 %145 %170 -%172 = OpExtInst %7 %1 UMin %29 %171 -%169 = OpBitFieldInsert %7 %167 %168 %170 %172 -OpStore %38 %169 -%173 = OpLoad %8 %40 -%174 = OpLoad %8 %40 -%176 = OpExtInst %7 %1 UMin %28 %145 -%177 = OpISub %7 %145 %176 -%178 = OpExtInst %7 %1 UMin %29 %177 -%175 = OpBitFieldInsert %8 %173 %174 %176 %178 -OpStore %40 %175 -%179 = OpLoad %9 %42 -%180 = OpLoad %9 %42 -%182 = OpExtInst %7 %1 UMin %28 %145 -%183 = OpISub %7 %145 %182 -%184 = OpExtInst %7 %1 UMin %29 %183 -%181 = OpBitFieldInsert %9 %179 %180 %182 %184 -OpStore %42 %181 -%185 = OpLoad %10 %44 +%157 = OpBitFieldSExtract %4 %156 %158 %160 +OpStore %32 %157 +%161 = OpLoad %5 %34 +%163 = OpExtInst %7 %1 UMin %28 %105 +%164 = OpISub %7 %105 %163 +%165 = OpExtInst %7 %1 UMin %29 %164 +%162 = OpBitFieldSExtract %5 %161 %163 %165 +OpStore %34 %162 +%166 = OpLoad %6 %36 +%168 = OpExtInst %7 %1 UMin %28 %105 +%169 = OpISub %7 %105 %168 +%170 = OpExtInst %7 %1 UMin %29 %169 +%167 = OpBitFieldSExtract %6 %166 %168 %170 +OpStore %36 %167 +%171 = OpLoad %7 %38 +%173 = OpExtInst %7 %1 UMin %28 %105 +%174 = OpISub %7 %105 %173 +%175 = OpExtInst %7 %1 UMin %29 %174 +%172 = OpBitFieldUExtract %7 %171 %173 %175 +OpStore %38 %172 +%176 = OpLoad %8 %40 +%178 = OpExtInst %7 %1 UMin %28 %105 +%179 = OpISub %7 %105 %178 +%180 = OpExtInst %7 %1 UMin %29 %179 +%177 = OpBitFieldUExtract %8 %176 %178 %180 +OpStore %40 %177 +%181 = OpLoad %9 %42 +%183 = OpExtInst %7 %1 UMin %28 %105 +%184 = OpISub %7 %105 %183 +%185 = OpExtInst %7 %1 UMin %29 %184 +%182 = OpBitFieldUExtract %9 %181 %183 %185 +OpStore %42 %182 %186 = OpLoad %10 %44 -%188 = OpExtInst %7 %1 UMin %28 %145 -%189 = OpISub %7 %145 %188 +%188 = OpExtInst %7 %1 UMin %28 %105 +%189 = OpISub %7 %105 %188 %190 = OpExtInst %7 %1 UMin %29 %189 -%187 = OpBitFieldInsert %10 %185 %186 %188 %190 +%187 = OpBitFieldUExtract %10 %186 %188 %190 OpStore %44 %187 %191 = OpLoad %3 %30 -%193 = OpExtInst %7 %1 UMin %28 %145 -%194 = OpISub %7 %145 %193 -%195 = OpExtInst %7 %1 UMin %29 %194 -%192 = OpBitFieldSExtract %3 %191 %193 %195 +%192 = OpExtInst %3 %1 FindILsb %191 OpStore %30 %192 -%196 = OpLoad %4 %32 -%198 = OpExtInst %7 %1 UMin %28 %145 -%199 = OpISub %7 %145 %198 -%200 = OpExtInst %7 %1 UMin %29 %199 -%197 = OpBitFieldSExtract %4 %196 %198 %200 -OpStore %32 %197 -%201 = OpLoad %5 %34 -%203 = OpExtInst %7 %1 UMin %28 %145 -%204 = OpISub %7 %145 %203 -%205 = OpExtInst %7 %1 UMin %29 %204 -%202 = OpBitFieldSExtract %5 %201 %203 %205 -OpStore %34 %202 -%206 = OpLoad %6 %36 -%208 = OpExtInst %7 %1 UMin %28 %145 -%209 = OpISub %7 %145 %208 -%210 = OpExtInst %7 %1 UMin %29 %209 -%207 = OpBitFieldSExtract %6 %206 %208 %210 -OpStore %36 %207 +%193 = OpLoad %8 %40 +%194 = OpExtInst %8 %1 FindILsb %193 +OpStore %40 %194 +%195 = OpLoad %5 %34 +%196 = OpExtInst %5 %1 FindSMsb %195 +OpStore %34 %196 +%197 = OpLoad %9 %42 +%198 = OpExtInst %9 %1 FindUMsb %197 +OpStore %42 %198 +%199 = OpLoad %3 %30 +%200 = OpExtInst %3 %1 FindSMsb %199 +OpStore %30 %200 +%201 = OpLoad %7 %38 +%202 = OpExtInst %7 %1 FindUMsb %201 +OpStore %38 %202 +%203 = OpLoad %3 %30 +%204 = OpBitCount %3 %203 +OpStore %30 %204 +%205 = OpLoad %4 %32 +%206 = OpBitCount %4 %205 +OpStore %32 %206 +%207 = OpLoad %5 %34 +%208 = OpBitCount %5 %207 +OpStore %34 %208 +%209 = OpLoad %6 %36 +%210 = OpBitCount %6 %209 +OpStore %36 %210 %211 = OpLoad %7 %38 -%213 = OpExtInst %7 %1 UMin %28 %145 -%214 = OpISub %7 %145 %213 -%215 = OpExtInst %7 %1 UMin %29 %214 -%212 = OpBitFieldUExtract %7 %211 %213 %215 +%212 = OpBitCount %7 %211 OpStore %38 %212 -%216 = OpLoad %8 %40 -%218 = OpExtInst %7 %1 UMin %28 %145 -%219 = OpISub %7 %145 %218 -%220 = OpExtInst %7 %1 UMin %29 %219 -%217 = OpBitFieldUExtract %8 %216 %218 %220 -OpStore %40 %217 -%221 = OpLoad %9 %42 -%223 = OpExtInst %7 %1 UMin %28 %145 -%224 = OpISub %7 %145 %223 -%225 = OpExtInst %7 %1 UMin %29 %224 -%222 = OpBitFieldUExtract %9 %221 %223 %225 -OpStore %42 %222 -%226 = OpLoad %10 %44 -%228 = OpExtInst %7 %1 UMin %28 %145 -%229 = OpISub %7 %145 %228 -%230 = OpExtInst %7 %1 UMin %29 %229 -%227 = OpBitFieldUExtract %10 %226 %228 %230 -OpStore %44 %227 -%231 = OpLoad %3 %30 -%232 = OpExtInst %3 %1 FindILsb %231 -OpStore %30 %232 -%233 = OpLoad %8 %40 -%234 = OpExtInst %8 %1 FindILsb %233 -OpStore %40 %234 -%235 = OpLoad %5 %34 -%236 = OpExtInst %5 %1 FindSMsb %235 -OpStore %34 %236 -%237 = OpLoad %9 %42 -%238 = OpExtInst %9 %1 FindUMsb %237 -OpStore %42 %238 -%239 = OpLoad %3 %30 -%240 = OpExtInst %3 %1 FindSMsb %239 -OpStore %30 %240 -%241 = OpLoad %7 %38 -%242 = OpExtInst %7 %1 FindUMsb %241 -OpStore %38 %242 -%243 = OpLoad %3 %30 -%244 = OpBitCount %3 %243 -OpStore %30 %244 -%245 = OpLoad %4 %32 -%246 = OpBitCount %4 %245 -OpStore %32 %246 -%247 = OpLoad %5 %34 -%248 = OpBitCount %5 %247 -OpStore %34 %248 -%249 = OpLoad %6 %36 -%250 = OpBitCount %6 %249 -OpStore %36 %250 -%251 = OpLoad %7 %38 -%252 = OpBitCount %7 %251 -OpStore %38 %252 -%253 = OpLoad %8 %40 -%254 = OpBitCount %8 %253 -OpStore %40 %254 -%255 = OpLoad %9 %42 -%256 = OpBitCount %9 %255 -OpStore %42 %256 -%257 = OpLoad %10 %44 -%258 = OpBitCount %10 %257 -OpStore %44 %258 -%259 = OpLoad %3 %30 -%260 = OpBitReverse %3 %259 -OpStore %30 %260 -%261 = OpLoad %4 %32 -%262 = OpBitReverse %4 %261 -OpStore %32 %262 -%263 = OpLoad %5 %34 -%264 = OpBitReverse %5 %263 -OpStore %34 %264 -%265 = OpLoad %6 %36 -%266 = OpBitReverse %6 %265 -OpStore %36 %266 -%267 = OpLoad %7 %38 -%268 = OpBitReverse %7 %267 -OpStore %38 %268 -%269 = OpLoad %8 %40 -%270 = OpBitReverse %8 %269 -OpStore %40 %270 -%271 = OpLoad %9 %42 -%272 = OpBitReverse %9 %271 -OpStore %42 %272 -%273 = OpLoad %10 %44 -%274 = OpBitReverse %10 %273 -OpStore %44 %274 +%213 = OpLoad %8 %40 +%214 = OpBitCount %8 %213 +OpStore %40 %214 +%215 = OpLoad %9 %42 +%216 = OpBitCount %9 %215 +OpStore %42 %216 +%217 = OpLoad %10 %44 +%218 = OpBitCount %10 %217 +OpStore %44 %218 +%219 = OpLoad %3 %30 +%220 = OpBitReverse %3 %219 +OpStore %30 %220 +%221 = OpLoad %4 %32 +%222 = OpBitReverse %4 %221 +OpStore %32 %222 +%223 = OpLoad %5 %34 +%224 = OpBitReverse %5 %223 +OpStore %34 %224 +%225 = OpLoad %6 %36 +%226 = OpBitReverse %6 %225 +OpStore %36 %226 +%227 = OpLoad %7 %38 +%228 = OpBitReverse %7 %227 +OpStore %38 %228 +%229 = OpLoad %8 %40 +%230 = OpBitReverse %8 %229 +OpStore %40 %230 +%231 = OpLoad %9 %42 +%232 = OpBitReverse %9 %231 +OpStore %42 %232 +%233 = OpLoad %10 %44 +%234 = OpBitReverse %10 %233 +OpStore %44 %234 OpReturn OpFunctionEnd \ No newline at end of file From ee63be44e3668fda2bc68859284743eb799c170e Mon Sep 17 00:00:00 2001 From: Robert Bamler Date: Sat, 3 May 2025 20:56:26 +0200 Subject: [PATCH 2/7] [naga] Vectorize `[un]pack4x{I, U}8[Clamp]` on msl Implements more direct conversions between 32-bit integers and 4x8-bit integer vectors using bit casting to/from `packed_[u]char4` when on MSL 2.1+ (older versions of MSL don't seem to support these bit casts). - `unpack4x{I, U}8(x)` becomes `[u]int4(as_type(x))`; - `pack4x{I, U}8(x)` becomes `as_type(packed_[u]char4(x))`; and - `pack4x{I, U}8Clamp(x)` becomes `as_type(packed_uchar4(metal::clamp(x, 0, 255)))`. These bit casts match the WGSL spec for these functions because Metal runs on little-endian machines. --- naga/src/back/msl/writer.rs | 145 ++++++++++++------ naga/tests/in/wgsl/bits-optimized-msl.toml | 4 + naga/tests/in/wgsl/bits-optimized-msl.wgsl | 69 +++++++++ .../tests/out/msl/wgsl-bits-optimized-msl.msl | 137 +++++++++++++++++ 4 files changed, 308 insertions(+), 47 deletions(-) create mode 100644 naga/tests/in/wgsl/bits-optimized-msl.toml create mode 100644 naga/tests/in/wgsl/bits-optimized-msl.wgsl create mode 100644 naga/tests/out/msl/wgsl-bits-optimized-msl.msl diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index f05e5c233a..cba177b5dd 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -1497,6 +1497,63 @@ impl Writer { Ok(()) } + /// Emit code for the WGSL functions `pack4x{I, U}8[Clamp]`. + fn put_pack4x8( + &mut self, + arg: Handle, + context: &ExpressionContext<'_>, + was_signed: bool, + clamp_bounds: Option<(&str, &str)>, + ) -> Result<(), Error> { + if context.lang_version >= (2, 1) { + let packed_type = if was_signed { + "packed_char4" + } else { + "packed_uchar4" + }; + // Metal uses little endian byte order, which matches what WGSL expects here. + write!(self.out, "as_type({packed_type}(")?; + if let Some((min, max)) = clamp_bounds { + // Clamping a vector to scalar bounds works and operates component-wise. + write!(self.out, "{NAMESPACE}::clamp(")?; + self.put_expression(arg, context, true)?; + write!(self.out, ", {min}, {max})")?; + } else { + self.put_expression(arg, context, true)?; + } + write!(self.out, "))")?; + } else { + // MSL < 2.1 doesn't support `as_type` casting between packed chars and scalars. + if was_signed { + write!(self.out, "uint(")?; + } + let write_arg = |this: &mut Self| -> BackendResult { + if let Some((min, max)) = clamp_bounds { + write!(this.out, "{NAMESPACE}::clamp(")?; + this.put_expression(arg, context, true)?; + write!(this.out, ", {min}, {max})")?; + } else { + this.put_expression(arg, context, true)?; + } + Ok(()) + }; + write!(self.out, "(")?; + write_arg(self)?; + write!(self.out, "[0] & 0xFF) | ((")?; + write_arg(self)?; + write!(self.out, "[1] & 0xFF) << 8) | ((")?; + write_arg(self)?; + write!(self.out, "[2] & 0xFF) << 16) | ((")?; + write_arg(self)?; + write!(self.out, "[3] & 0xFF) << 24)")?; + if was_signed { + write!(self.out, ")")?; + } + } + + Ok(()) + } + /// Emit code for the isign expression. /// fn put_isign( @@ -2437,53 +2494,41 @@ impl Writer { write!(self.out, "{fun_name}")?; self.put_call_parameters(iter::once(arg), context)?; } - fun @ (Mf::Pack4xI8 | Mf::Pack4xU8 | Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp) => { - let was_signed = matches!(fun, Mf::Pack4xI8 | Mf::Pack4xI8Clamp); - let clamp_bounds = match fun { - Mf::Pack4xI8Clamp => Some(("-128", "127")), - Mf::Pack4xU8Clamp => Some(("0", "255")), - _ => None, - }; - if was_signed { - write!(self.out, "uint(")?; - } - let write_arg = |this: &mut Self| -> BackendResult { - if let Some((min, max)) = clamp_bounds { - write!(this.out, "{NAMESPACE}::clamp(")?; - this.put_expression(arg, context, true)?; - write!(this.out, ", {min}, {max})")?; - } else { - this.put_expression(arg, context, true)?; - } - Ok(()) - }; - write!(self.out, "(")?; - write_arg(self)?; - write!(self.out, "[0] & 0xFF) | ((")?; - write_arg(self)?; - write!(self.out, "[1] & 0xFF) << 8) | ((")?; - write_arg(self)?; - write!(self.out, "[2] & 0xFF) << 16) | ((")?; - write_arg(self)?; - write!(self.out, "[3] & 0xFF) << 24)")?; - if was_signed { - write!(self.out, ")")?; - } + Mf::Pack4xI8 => self.put_pack4x8(arg, context, true, None)?, + Mf::Pack4xU8 => self.put_pack4x8(arg, context, false, None)?, + Mf::Pack4xI8Clamp => { + self.put_pack4x8(arg, context, true, Some(("-128", "127")))? + } + Mf::Pack4xU8Clamp => { + self.put_pack4x8(arg, context, false, Some(("0", "255")))? } fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => { - write!(self.out, "(")?; - if matches!(fun, Mf::Unpack4xU8) { - write!(self.out, "u")?; + let sign_prefix = if matches!(fun, Mf::Unpack4xU8) { + "u" + } else { + "" + }; + + if context.lang_version >= (2, 1) { + // Metal uses little endian byte order, which matches what WGSL expects here. + write!( + self.out, + "{sign_prefix}int4(as_type(" + )?; + self.put_expression(arg, context, true)?; + write!(self.out, "))")?; + } else { + // MSL < 2.1 doesn't support `as_type` casting between packed chars and scalars. + write!(self.out, "({sign_prefix}int4(")?; + self.put_expression(arg, context, true)?; + write!(self.out, ", ")?; + self.put_expression(arg, context, true)?; + write!(self.out, " >> 8, ")?; + self.put_expression(arg, context, true)?; + write!(self.out, " >> 16, ")?; + self.put_expression(arg, context, true)?; + write!(self.out, " >> 24) << 24 >> 24)")?; } - write!(self.out, "int4(")?; - self.put_expression(arg, context, true)?; - write!(self.out, ", ")?; - self.put_expression(arg, context, true)?; - write!(self.out, " >> 8, ")?; - self.put_expression(arg, context, true)?; - write!(self.out, " >> 16, ")?; - self.put_expression(arg, context, true)?; - write!(self.out, " >> 24) << 24 >> 24)")?; } Mf::QuantizeToF16 => { match *context.resolve_type(arg) { @@ -3226,14 +3271,20 @@ impl Writer { self.need_bake_expressions.insert(arg); self.need_bake_expressions.insert(arg1.unwrap()); } - crate::MathFunction::FirstLeadingBit - | crate::MathFunction::Pack4xI8 + crate::MathFunction::FirstLeadingBit => { + self.need_bake_expressions.insert(arg); + } + crate::MathFunction::Pack4xI8 | crate::MathFunction::Pack4xU8 | crate::MathFunction::Pack4xI8Clamp | crate::MathFunction::Pack4xU8Clamp | crate::MathFunction::Unpack4xI8 | crate::MathFunction::Unpack4xU8 => { - self.need_bake_expressions.insert(arg); + // On MSL < 2.1, we emit a polyfill for these functions that uses the + // argument multiple times. This is no longer necessary on MSL >= 2.1. + if context.lang_version < (2, 1) { + self.need_bake_expressions.insert(arg); + } } crate::MathFunction::ExtractBits => { // Only argument 1 is re-used. diff --git a/naga/tests/in/wgsl/bits-optimized-msl.toml b/naga/tests/in/wgsl/bits-optimized-msl.toml new file mode 100644 index 0000000000..9409d2ac77 --- /dev/null +++ b/naga/tests/in/wgsl/bits-optimized-msl.toml @@ -0,0 +1,4 @@ +targets = "METAL" + +[msl] +lang_version = [2, 1] diff --git a/naga/tests/in/wgsl/bits-optimized-msl.wgsl b/naga/tests/in/wgsl/bits-optimized-msl.wgsl new file mode 100644 index 0000000000..a77266ad34 --- /dev/null +++ b/naga/tests/in/wgsl/bits-optimized-msl.wgsl @@ -0,0 +1,69 @@ +// Keep in sync with `bits_downlevel` and `bits_downlevel_webgl` + +@compute @workgroup_size(1) +fn main() { + var i = 0; + var i2 = vec2(0); + var i3 = vec3(0); + var i4 = vec4(0); + var u = 0u; + var u2 = vec2(0u); + var u3 = vec3(0u); + var u4 = vec4(0u); + var f2 = vec2(0.0); + var f4 = vec4(0.0); + u = pack4x8snorm(f4); + u = pack4x8unorm(f4); + u = pack2x16snorm(f2); + u = pack2x16unorm(f2); + u = pack2x16float(f2); + u = pack4xI8(i4); + u = pack4xU8(u4); + u = pack4xI8Clamp(i4); + u = pack4xU8Clamp(u4); + f4 = unpack4x8snorm(u); + f4 = unpack4x8unorm(u); + f2 = unpack2x16snorm(u); + f2 = unpack2x16unorm(u); + f2 = unpack2x16float(u); + i4 = unpack4xI8(u); + u4 = unpack4xU8(u); + i = insertBits(i, i, 5u, 10u); + i2 = insertBits(i2, i2, 5u, 10u); + i3 = insertBits(i3, i3, 5u, 10u); + i4 = insertBits(i4, i4, 5u, 10u); + u = insertBits(u, u, 5u, 10u); + u2 = insertBits(u2, u2, 5u, 10u); + u3 = insertBits(u3, u3, 5u, 10u); + u4 = insertBits(u4, u4, 5u, 10u); + i = extractBits(i, 5u, 10u); + i2 = extractBits(i2, 5u, 10u); + i3 = extractBits(i3, 5u, 10u); + i4 = extractBits(i4, 5u, 10u); + u = extractBits(u, 5u, 10u); + u2 = extractBits(u2, 5u, 10u); + u3 = extractBits(u3, 5u, 10u); + u4 = extractBits(u4, 5u, 10u); + i = firstTrailingBit(i); + u2 = firstTrailingBit(u2); + i3 = firstLeadingBit(i3); + u3 = firstLeadingBit(u3); + i = firstLeadingBit(i); + u = firstLeadingBit(u); + i = countOneBits(i); + i2 = countOneBits(i2); + i3 = countOneBits(i3); + i4 = countOneBits(i4); + u = countOneBits(u); + u2 = countOneBits(u2); + u3 = countOneBits(u3); + u4 = countOneBits(u4); + i = reverseBits(i); + i2 = reverseBits(i2); + i3 = reverseBits(i3); + i4 = reverseBits(i4); + u = reverseBits(u); + u2 = reverseBits(u2); + u3 = reverseBits(u3); + u4 = reverseBits(u4); +} diff --git a/naga/tests/out/msl/wgsl-bits-optimized-msl.msl b/naga/tests/out/msl/wgsl-bits-optimized-msl.msl new file mode 100644 index 0000000000..e33ed65f46 --- /dev/null +++ b/naga/tests/out/msl/wgsl-bits-optimized-msl.msl @@ -0,0 +1,137 @@ +// language: metal2.1 +#include +#include + +using metal::uint; + + +kernel void main_( +) { + int i = 0; + metal::int2 i2_ = metal::int2(0); + metal::int3 i3_ = metal::int3(0); + metal::int4 i4_ = metal::int4(0); + uint u = 0u; + metal::uint2 u2_ = metal::uint2(0u); + metal::uint3 u3_ = metal::uint3(0u); + metal::uint4 u4_ = metal::uint4(0u); + metal::float2 f2_ = metal::float2(0.0); + metal::float4 f4_ = metal::float4(0.0); + metal::float4 _e28 = f4_; + u = metal::pack_float_to_snorm4x8(_e28); + metal::float4 _e30 = f4_; + u = metal::pack_float_to_unorm4x8(_e30); + metal::float2 _e32 = f2_; + u = metal::pack_float_to_snorm2x16(_e32); + metal::float2 _e34 = f2_; + u = metal::pack_float_to_unorm2x16(_e34); + metal::float2 _e36 = f2_; + u = as_type(half2(_e36)); + metal::int4 _e38 = i4_; + u = as_type(packed_char4(_e38)); + metal::uint4 _e40 = u4_; + u = as_type(packed_uchar4(_e40)); + metal::int4 _e42 = i4_; + u = as_type(packed_char4(metal::clamp(_e42, -128, 127))); + metal::uint4 _e44 = u4_; + u = as_type(packed_uchar4(metal::clamp(_e44, 0, 255))); + uint _e46 = u; + f4_ = metal::unpack_snorm4x8_to_float(_e46); + uint _e48 = u; + f4_ = metal::unpack_unorm4x8_to_float(_e48); + uint _e50 = u; + f2_ = metal::unpack_snorm2x16_to_float(_e50); + uint _e52 = u; + f2_ = metal::unpack_unorm2x16_to_float(_e52); + uint _e54 = u; + f2_ = float2(as_type(_e54)); + uint _e56 = u; + i4_ = int4(as_type(_e56)); + uint _e58 = u; + u4_ = uint4(as_type(_e58)); + int _e60 = i; + int _e61 = i; + i = metal::insert_bits(_e60, _e61, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::int2 _e65 = i2_; + metal::int2 _e66 = i2_; + i2_ = metal::insert_bits(_e65, _e66, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::int3 _e70 = i3_; + metal::int3 _e71 = i3_; + i3_ = metal::insert_bits(_e70, _e71, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::int4 _e75 = i4_; + metal::int4 _e76 = i4_; + i4_ = metal::insert_bits(_e75, _e76, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + uint _e80 = u; + uint _e81 = u; + u = metal::insert_bits(_e80, _e81, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::uint2 _e85 = u2_; + metal::uint2 _e86 = u2_; + u2_ = metal::insert_bits(_e85, _e86, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::uint3 _e90 = u3_; + metal::uint3 _e91 = u3_; + u3_ = metal::insert_bits(_e90, _e91, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::uint4 _e95 = u4_; + metal::uint4 _e96 = u4_; + u4_ = metal::insert_bits(_e95, _e96, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + int _e100 = i; + i = metal::extract_bits(_e100, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::int2 _e104 = i2_; + i2_ = metal::extract_bits(_e104, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::int3 _e108 = i3_; + i3_ = metal::extract_bits(_e108, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::int4 _e112 = i4_; + i4_ = metal::extract_bits(_e112, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + uint _e116 = u; + u = metal::extract_bits(_e116, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::uint2 _e120 = u2_; + u2_ = metal::extract_bits(_e120, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::uint3 _e124 = u3_; + u3_ = metal::extract_bits(_e124, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::uint4 _e128 = u4_; + u4_ = metal::extract_bits(_e128, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + int _e132 = i; + i = (((metal::ctz(_e132) + 1) % 33) - 1); + metal::uint2 _e134 = u2_; + u2_ = (((metal::ctz(_e134) + 1) % 33) - 1); + metal::int3 _e136 = i3_; + i3_ = metal::select(31 - metal::clz(metal::select(_e136, ~_e136, _e136 < 0)), int3(-1), _e136 == 0 || _e136 == -1); + metal::uint3 _e138 = u3_; + u3_ = metal::select(31 - metal::clz(_e138), uint3(-1), _e138 == 0 || _e138 == -1); + int _e140 = i; + i = metal::select(31 - metal::clz(metal::select(_e140, ~_e140, _e140 < 0)), int(-1), _e140 == 0 || _e140 == -1); + uint _e142 = u; + u = metal::select(31 - metal::clz(_e142), uint(-1), _e142 == 0 || _e142 == -1); + int _e144 = i; + i = metal::popcount(_e144); + metal::int2 _e146 = i2_; + i2_ = metal::popcount(_e146); + metal::int3 _e148 = i3_; + i3_ = metal::popcount(_e148); + metal::int4 _e150 = i4_; + i4_ = metal::popcount(_e150); + uint _e152 = u; + u = metal::popcount(_e152); + metal::uint2 _e154 = u2_; + u2_ = metal::popcount(_e154); + metal::uint3 _e156 = u3_; + u3_ = metal::popcount(_e156); + metal::uint4 _e158 = u4_; + u4_ = metal::popcount(_e158); + int _e160 = i; + i = metal::reverse_bits(_e160); + metal::int2 _e162 = i2_; + i2_ = metal::reverse_bits(_e162); + metal::int3 _e164 = i3_; + i3_ = metal::reverse_bits(_e164); + metal::int4 _e166 = i4_; + i4_ = metal::reverse_bits(_e166); + uint _e168 = u; + u = metal::reverse_bits(_e168); + metal::uint2 _e170 = u2_; + u2_ = metal::reverse_bits(_e170); + metal::uint3 _e172 = u3_; + u3_ = metal::reverse_bits(_e172); + metal::uint4 _e174 = u4_; + u4_ = metal::reverse_bits(_e174); + return; +} From c22e8fa1a5e3181010303e9528dde609cd482034 Mon Sep 17 00:00:00 2001 From: Robert Bamler Date: Sat, 3 May 2025 22:43:04 +0200 Subject: [PATCH 3/7] Add changelog entry for #7664 --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 951238009a..187b89e43d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -64,6 +64,7 @@ Naga now infers the correct binding layout when a resource appears only in an as - Apply necessary automatic conversions to the `value` argument of `textureStore`. By @jimblandy in [#7567](https://github.com/gfx-rs/wgpu/pull/7567). - Properly apply WGSL's automatic conversions to the arguments to texture sampling functions. By @jimblandy in [#7548](https://github.com/gfx-rs/wgpu/pull/7548). - Properly evaluate `abs(most negative abstract int)`. By @jimblandy in [#7507](https://github.com/gfx-rs/wgpu/pull/7507). +- Generate vectorized code for `[un]pack4x{I,U}8[Clamp]` on SPIR-V and MSL 2.1+. By @robamler in [#7664](https://github.com/gfx-rs/wgpu/pull/7664). #### DX12 From 62c32b3133af31485afd7309bdbc3c5fd09e56bf Mon Sep 17 00:00:00 2001 From: Robert Bamler Date: Sat, 3 May 2025 22:52:49 +0200 Subject: [PATCH 4/7] Require `Capability::Int8` for vectorized [un]pack --- naga/src/back/spv/block.rs | 426 ++++++++++++++++++++++++++++--------- 1 file changed, 323 insertions(+), 103 deletions(-) diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index 8c6ae97a22..f2423407b7 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -1552,86 +1552,31 @@ impl BlockContext<'_> { Mf::Pack2x16unorm => MathOp::Ext(spirv::GLOp::PackUnorm2x16), Mf::Pack2x16snorm => MathOp::Ext(spirv::GLOp::PackSnorm2x16), fun @ (Mf::Pack4xI8 | Mf::Pack4xU8 | Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp) => { - let (int_type, is_signed) = match fun { - Mf::Pack4xI8 | Mf::Pack4xI8Clamp => (crate::ScalarKind::Sint, true), - Mf::Pack4xU8 | Mf::Pack4xU8Clamp => (crate::ScalarKind::Uint, false), - _ => unreachable!(), - }; - + let is_signed = matches!(fun, Mf::Pack4xI8 | Mf::Pack4xI8Clamp); let should_clamp = matches!(fun, Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp); - let wide_vector_type_id = self.get_numeric_type_id(NumericType::Vector { - size: crate::VectorSize::Quad, - scalar: crate::Scalar { - kind: int_type, - width: 4, - }, - }); - let packed_vector_type_id = self.get_numeric_type_id(NumericType::Vector { - size: crate::VectorSize::Quad, - scalar: crate::Scalar { - kind: crate::ScalarKind::Uint, - width: 1, - }, - }); - - let mut wide_vector = arg0_id; - if should_clamp { - let (min, max, clamp_op) = if is_signed { - ( - crate::Literal::I32(-128), - crate::Literal::I32(127), - spirv::GLOp::SClamp, + let last_instruction = + if self.writer.require_all(&[spirv::Capability::Int8]).is_ok() { + self.write_pack4x8_optimized( + block, + result_type_id, + arg0_id, + id, + is_signed, + should_clamp, ) } else { - ( - crate::Literal::U32(0), - crate::Literal::U32(255), - spirv::GLOp::UClamp, + self.write_pack4x8_polyfill( + block, + result_type_id, + arg0_id, + id, + is_signed, + should_clamp, ) }; - let [min, max] = [min, max].map(|lit| { - let scalar = self.writer.get_constant_scalar(lit); - // TODO: can we cache these constant vectors somehow? - let id = self.gen_id(); - block.body.push(Instruction::composite_construct( - wide_vector_type_id, - id, - &[scalar; 4], - )); - id - }); - - let clamp_id = self.gen_id(); - block.body.push(Instruction::ext_inst( - self.writer.gl450_ext_inst_id, - clamp_op, - wide_vector_type_id, - clamp_id, - &[wide_vector, min, max], - )); - - wide_vector = clamp_id; - } - - let packed_vector = self.gen_id(); - block.body.push(Instruction::unary( - spirv::Op::UConvert, // We truncate, so `UConvert` and `SConvert` behave identically. - packed_vector_type_id, - packed_vector, - wide_vector, - )); - // The SPIR-V spec [1] defines the bit order for bit casting between a vector - // and a scalar precisely as required by the WGSL spec [2]. - // [1]: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast - // [2]: https://www.w3.org/TR/WGSL/#pack4xI8-builtin - MathOp::Custom(Instruction::unary( - spirv::Op::Bitcast, - result_type_id, - id, - packed_vector, - )) + MathOp::Custom(last_instruction) } Mf::Unpack4x8unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm4x8), Mf::Unpack4x8snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm4x8), @@ -1639,38 +1584,28 @@ impl BlockContext<'_> { Mf::Unpack2x16unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm2x16), Mf::Unpack2x16snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm2x16), fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => { - let (int_type, convert_op) = match fun { - Mf::Unpack4xI8 => (crate::ScalarKind::Sint, spirv::Op::SConvert), - Mf::Unpack4xU8 => (crate::ScalarKind::Uint, spirv::Op::UConvert), - _ => unreachable!(), - }; + let is_signed = matches!(fun, Mf::Unpack4xI8); - let packed_vector_type_id = self.get_numeric_type_id(NumericType::Vector { - size: crate::VectorSize::Quad, - scalar: crate::Scalar { - kind: int_type, - width: 1, - }, - }); - - // The SPIR-V spec [1] defines the bit order for bit casting between a vector - // and a scalar precisely as required by the WGSL spec [2]. - // [1]: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast - // [2]: https://www.w3.org/TR/WGSL/#pack4xI8-builtin - let packed_vector = self.gen_id(); - block.body.push(Instruction::unary( - spirv::Op::Bitcast, - packed_vector_type_id, - packed_vector, - arg0_id, - )); + let last_instruction = + if self.writer.require_all(&[spirv::Capability::Int8]).is_ok() { + self.write_unpack4x8_optimized( + block, + result_type_id, + arg0_id, + id, + is_signed, + ) + } else { + self.write_unpack4x8_polyfill( + block, + result_type_id, + arg0_id, + id, + is_signed, + ) + }; - MathOp::Custom(Instruction::unary( - convert_op, - result_type_id, - id, - packed_vector, - )) + MathOp::Custom(last_instruction) } }; @@ -2679,6 +2614,291 @@ impl BlockContext<'_> { } } + /// Emit code for `pack4x{I,U}8[Clamp]` if capability "Int8" is available. + fn write_pack4x8_optimized( + &mut self, + block: &mut Block, + result_type_id: u32, + arg0_id: u32, + id: u32, + is_signed: bool, + should_clamp: bool, + ) -> Instruction { + let int_type = if is_signed { + crate::ScalarKind::Sint + } else { + crate::ScalarKind::Uint + }; + let wide_vector_type_id = self.get_numeric_type_id(NumericType::Vector { + size: crate::VectorSize::Quad, + scalar: crate::Scalar { + kind: int_type, + width: 4, + }, + }); + let packed_vector_type_id = self.get_numeric_type_id(NumericType::Vector { + size: crate::VectorSize::Quad, + scalar: crate::Scalar { + kind: crate::ScalarKind::Uint, + width: 1, + }, + }); + + let mut wide_vector = arg0_id; + if should_clamp { + let (min, max, clamp_op) = if is_signed { + ( + crate::Literal::I32(-128), + crate::Literal::I32(127), + spirv::GLOp::SClamp, + ) + } else { + ( + crate::Literal::U32(0), + crate::Literal::U32(255), + spirv::GLOp::UClamp, + ) + }; + let [min, max] = [min, max].map(|lit| { + let scalar = self.writer.get_constant_scalar(lit); + // TODO: can we cache these constant vectors somehow? + let id = self.gen_id(); + block.body.push(Instruction::composite_construct( + wide_vector_type_id, + id, + &[scalar; 4], + )); + id + }); + + let clamp_id = self.gen_id(); + block.body.push(Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + clamp_op, + wide_vector_type_id, + clamp_id, + &[wide_vector, min, max], + )); + + wide_vector = clamp_id; + } + + let packed_vector = self.gen_id(); + block.body.push(Instruction::unary( + spirv::Op::UConvert, // We truncate, so `UConvert` and `SConvert` behave identically. + packed_vector_type_id, + packed_vector, + wide_vector, + )); + + // The SPIR-V spec [1] defines the bit order for bit casting between a vector + // and a scalar precisely as required by the WGSL spec [2]. + // [1]: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast + // [2]: https://www.w3.org/TR/WGSL/#pack4xI8-builtin + Instruction::unary(spirv::Op::Bitcast, result_type_id, id, packed_vector) + } + + /// Emit code for `pack4x{I,U}8[Clamp]` if capability "Int8" is not available. + fn write_pack4x8_polyfill( + &mut self, + block: &mut Block, + result_type_id: u32, + arg0_id: u32, + id: u32, + is_signed: bool, + should_clamp: bool, + ) -> Instruction { + let int_type = if is_signed { + crate::ScalarKind::Sint + } else { + crate::ScalarKind::Uint + }; + let uint_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32)); + let int_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar { + kind: int_type, + width: 4, + })); + + let mut last_instruction = Instruction::new(spirv::Op::Nop); + + let zero = self.writer.get_constant_scalar(crate::Literal::U32(0)); + let mut preresult = zero; + block + .body + .reserve(usize::from(VEC_LENGTH) * (2 + usize::from(is_signed))); + + let eight = self.writer.get_constant_scalar(crate::Literal::U32(8)); + const VEC_LENGTH: u8 = 4; + for i in 0..u32::from(VEC_LENGTH) { + let offset = self.writer.get_constant_scalar(crate::Literal::U32(i * 8)); + let mut extracted = self.gen_id(); + block.body.push(Instruction::binary( + spirv::Op::CompositeExtract, + int_type_id, + extracted, + arg0_id, + i, + )); + if is_signed { + let casted = self.gen_id(); + block.body.push(Instruction::unary( + spirv::Op::Bitcast, + uint_type_id, + casted, + extracted, + )); + extracted = casted; + } + if should_clamp { + let (min, max, clamp_op) = if is_signed { + ( + crate::Literal::I32(-128), + crate::Literal::I32(127), + spirv::GLOp::SClamp, + ) + } else { + ( + crate::Literal::U32(0), + crate::Literal::U32(255), + spirv::GLOp::UClamp, + ) + }; + let [min, max] = [min, max].map(|lit| self.writer.get_constant_scalar(lit)); + + let clamp_id = self.gen_id(); + block.body.push(Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + clamp_op, + result_type_id, + clamp_id, + &[extracted, min, max], + )); + + extracted = clamp_id; + } + let is_last = i == u32::from(VEC_LENGTH - 1); + if is_last { + last_instruction = Instruction::quaternary( + spirv::Op::BitFieldInsert, + result_type_id, + id, + preresult, + extracted, + offset, + eight, + ) + } else { + let new_preresult = self.gen_id(); + block.body.push(Instruction::quaternary( + spirv::Op::BitFieldInsert, + result_type_id, + new_preresult, + preresult, + extracted, + offset, + eight, + )); + preresult = new_preresult; + } + } + last_instruction + } + + /// Emit code for `unpack4x{I,U}8` if capability "Int8" is available. + fn write_unpack4x8_optimized( + &mut self, + block: &mut Block, + result_type_id: u32, + arg0_id: u32, + id: u32, + is_signed: bool, + ) -> Instruction { + let (int_type, convert_op) = if is_signed { + (crate::ScalarKind::Sint, spirv::Op::SConvert) + } else { + (crate::ScalarKind::Uint, spirv::Op::UConvert) + }; + + let packed_vector_type_id = self.get_numeric_type_id(NumericType::Vector { + size: crate::VectorSize::Quad, + scalar: crate::Scalar { + kind: int_type, + width: 1, + }, + }); + + // The SPIR-V spec [1] defines the bit order for bit casting between a vector + // and a scalar precisely as required by the WGSL spec [2]. + // [1]: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast + // [2]: https://www.w3.org/TR/WGSL/#pack4xI8-builtin + let packed_vector = self.gen_id(); + block.body.push(Instruction::unary( + spirv::Op::Bitcast, + packed_vector_type_id, + packed_vector, + arg0_id, + )); + + Instruction::unary(convert_op, result_type_id, id, packed_vector) + } + + /// Emit code for `unpack4x{I,U}8` if capability "Int8" is not available. + fn write_unpack4x8_polyfill( + &mut self, + block: &mut Block, + result_type_id: u32, + arg0_id: u32, + id: u32, + is_signed: bool, + ) -> Instruction { + let (int_type, extract_op) = if is_signed { + (crate::ScalarKind::Sint, spirv::Op::BitFieldSExtract) + } else { + (crate::ScalarKind::Uint, spirv::Op::BitFieldUExtract) + }; + + let sint_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::I32)); + + let eight = self.writer.get_constant_scalar(crate::Literal::U32(8)); + let int_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar { + kind: int_type, + width: 4, + })); + block + .body + .reserve(usize::from(VEC_LENGTH) * 2 + usize::from(is_signed)); + let arg_id = if is_signed { + let new_arg_id = self.gen_id(); + block.body.push(Instruction::unary( + spirv::Op::Bitcast, + sint_type_id, + new_arg_id, + arg0_id, + )); + new_arg_id + } else { + arg0_id + }; + + const VEC_LENGTH: u8 = 4; + let parts: [_; VEC_LENGTH as usize] = core::array::from_fn(|_| self.gen_id()); + for (i, part_id) in parts.into_iter().enumerate() { + let index = self + .writer + .get_constant_scalar(crate::Literal::U32(i as u32 * 8)); + block.body.push(Instruction::ternary( + extract_op, + int_type_id, + part_id, + arg_id, + index, + eight, + )); + } + + Instruction::composite_construct(result_type_id, id, &parts) + } + /// Generate one or more SPIR-V blocks for `naga_block`. /// /// Use `label_id` as the label for the SPIR-V entry point block. From e7ebc1b3b86a2a3c8dd98da7f5dd10dd1f32d743 Mon Sep 17 00:00:00 2001 From: Robert Bamler Date: Sun, 4 May 2025 22:11:16 +0200 Subject: [PATCH 5/7] [wgpu-hal] separate 2 float16-related vk features Separates the Vulkan feature sets `VkPhysicalDeviceShaderFloat16Int8Features` and `VkPhysicalDevice16BitStorageFeatures`, which previously were used "together, or not at all". This commit should not change any behavior yet, but I'd like to run full CI tests on it for now. If the CI tests pass, I'll use this separation to enable the `shader_int8` feature separately from the rest of the features to enable optimizations of `[un]pack4x{I,U}8[Clamp]` on SPIR-V. --- wgpu-hal/src/vulkan/adapter.rs | 56 +++++++++++++++++++--------------- 1 file changed, 32 insertions(+), 24 deletions(-) diff --git a/wgpu-hal/src/vulkan/adapter.rs b/wgpu-hal/src/vulkan/adapter.rs index e9ae1e597a..49dcd19aed 100644 --- a/wgpu-hal/src/vulkan/adapter.rs +++ b/wgpu-hal/src/vulkan/adapter.rs @@ -62,13 +62,11 @@ pub struct PhysicalDeviceFeatures { /// Features provided by `VK_EXT_texture_compression_astc_hdr`, promoted to Vulkan 1.3. astc_hdr: Option>, - /// Features provided by `VK_KHR_shader_float16_int8` (promoted to Vulkan - /// 1.2) and `VK_KHR_16bit_storage` (promoted to Vulkan 1.1). We use these - /// features together, or not at all. - shader_float16: Option<( - vk::PhysicalDeviceShaderFloat16Int8Features<'static>, - vk::PhysicalDevice16BitStorageFeatures<'static>, - )>, + /// Features provided by `VK_KHR_shader_float16_int8`, promoted to Vulkan 1.2 + shader_float16_int8: Option>, + + /// Features provided by `VK_KHR_16bit_storage`, promoted to Vulkan 1.1 + features_16bit_storage: Option>, /// Features provided by `VK_KHR_acceleration_structure`. acceleration_structure: Option>, @@ -154,9 +152,11 @@ impl PhysicalDeviceFeatures { if let Some(ref mut feature) = self.astc_hdr { info = info.push_next(feature); } - if let Some((ref mut f16_i8_feature, ref mut _16bit_feature)) = self.shader_float16 { - info = info.push_next(f16_i8_feature); - info = info.push_next(_16bit_feature); + if let Some(ref mut feature) = self.shader_float16_int8 { + info = info.push_next(feature); + } + if let Some(ref mut feature) = self.features_16bit_storage { + info = info.push_next(feature); } if let Some(ref mut feature) = self.zero_initialize_workgroup_memory { info = info.push_next(feature); @@ -386,14 +386,18 @@ impl PhysicalDeviceFeatures { } else { None }, - shader_float16: if requested_features.contains(wgt::Features::SHADER_F16) { - Some(( - vk::PhysicalDeviceShaderFloat16Int8Features::default().shader_float16(true), + shader_float16_int8: if requested_features.contains(wgt::Features::SHADER_F16) { + Some(vk::PhysicalDeviceShaderFloat16Int8Features::default().shader_float16(true)) + } else { + None + }, + features_16bit_storage: if requested_features.contains(wgt::Features::SHADER_F16) { + Some( vk::PhysicalDevice16BitStorageFeatures::default() .storage_buffer16_bit_access(true) .storage_input_output16(true) .uniform_and_storage_buffer16_bit_access(true), - )) + ) } else { None }, @@ -724,7 +728,9 @@ impl PhysicalDeviceFeatures { ); } - if let Some((ref f16_i8, ref bit16)) = self.shader_float16 { + if let (Some(ref f16_i8), Some(ref bit16)) = + (self.shader_float16_int8, self.features_16bit_storage) + { features.set( F::SHADER_F16, f16_i8.shader_float16 != 0 @@ -1474,15 +1480,17 @@ impl super::InstanceShared { .insert(vk::PhysicalDeviceTextureCompressionASTCHDRFeaturesEXT::default()); features2 = features2.push_next(next); } - if capabilities.supports_extension(khr::shader_float16_int8::NAME) - && capabilities.supports_extension(khr::_16bit_storage::NAME) - { - let next = features.shader_float16.insert(( - vk::PhysicalDeviceShaderFloat16Int8FeaturesKHR::default(), - vk::PhysicalDevice16BitStorageFeaturesKHR::default(), - )); - features2 = features2.push_next(&mut next.0); - features2 = features2.push_next(&mut next.1); + if capabilities.supports_extension(khr::shader_float16_int8::NAME) { + let next = features + .shader_float16_int8 + .insert(vk::PhysicalDeviceShaderFloat16Int8FeaturesKHR::default()); + features2 = features2.push_next(next); + } + if capabilities.supports_extension(khr::_16bit_storage::NAME) { + let next = features + .features_16bit_storage + .insert(vk::PhysicalDevice16BitStorageFeaturesKHR::default()); + features2 = features2.push_next(next); } if capabilities.supports_extension(khr::acceleration_structure::NAME) { let next = features From c197de1b5764e2b08e8e4c9cf85f131c056a6803 Mon Sep 17 00:00:00 2001 From: Robert Bamler Date: Sun, 4 May 2025 23:44:57 +0200 Subject: [PATCH 6/7] Rename a field to follow convention set by `ash` --- wgpu-hal/src/vulkan/adapter.rs | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/wgpu-hal/src/vulkan/adapter.rs b/wgpu-hal/src/vulkan/adapter.rs index 49dcd19aed..da54d3187f 100644 --- a/wgpu-hal/src/vulkan/adapter.rs +++ b/wgpu-hal/src/vulkan/adapter.rs @@ -66,7 +66,7 @@ pub struct PhysicalDeviceFeatures { shader_float16_int8: Option>, /// Features provided by `VK_KHR_16bit_storage`, promoted to Vulkan 1.1 - features_16bit_storage: Option>, + _16bit_storage: Option>, /// Features provided by `VK_KHR_acceleration_structure`. acceleration_structure: Option>, @@ -155,7 +155,7 @@ impl PhysicalDeviceFeatures { if let Some(ref mut feature) = self.shader_float16_int8 { info = info.push_next(feature); } - if let Some(ref mut feature) = self.features_16bit_storage { + if let Some(ref mut feature) = self._16bit_storage { info = info.push_next(feature); } if let Some(ref mut feature) = self.zero_initialize_workgroup_memory { @@ -391,7 +391,7 @@ impl PhysicalDeviceFeatures { } else { None }, - features_16bit_storage: if requested_features.contains(wgt::Features::SHADER_F16) { + _16bit_storage: if requested_features.contains(wgt::Features::SHADER_F16) { Some( vk::PhysicalDevice16BitStorageFeatures::default() .storage_buffer16_bit_access(true) @@ -728,8 +728,7 @@ impl PhysicalDeviceFeatures { ); } - if let (Some(ref f16_i8), Some(ref bit16)) = - (self.shader_float16_int8, self.features_16bit_storage) + if let (Some(ref f16_i8), Some(ref bit16)) = (self.shader_float16_int8, self._16bit_storage) { features.set( F::SHADER_F16, @@ -1488,7 +1487,7 @@ impl super::InstanceShared { } if capabilities.supports_extension(khr::_16bit_storage::NAME) { let next = features - .features_16bit_storage + ._16bit_storage .insert(vk::PhysicalDevice16BitStorageFeaturesKHR::default()); features2 = features2.push_next(next); } From 659df1853a61e2e76cd4db8f9c3b52aac7ad5338 Mon Sep 17 00:00:00 2001 From: Robert Bamler Date: Sun, 4 May 2025 23:42:00 +0200 Subject: [PATCH 7/7] [wgpu-hal] Add `PrivateCapabilities::shader_int8` on Vulkan This allows declaring the SPIR-V capability "Int8", which allows us to generate faster code for `[un]pack4x{I, U}8[Clamp]`. --- wgpu-hal/src/vulkan/adapter.rs | 43 +++++++++++++++++++++++++--------- wgpu-hal/src/vulkan/mod.rs | 15 ++++++++++++ 2 files changed, 47 insertions(+), 11 deletions(-) diff --git a/wgpu-hal/src/vulkan/adapter.rs b/wgpu-hal/src/vulkan/adapter.rs index da54d3187f..7e7ca926e9 100644 --- a/wgpu-hal/src/vulkan/adapter.rs +++ b/wgpu-hal/src/vulkan/adapter.rs @@ -386,10 +386,13 @@ impl PhysicalDeviceFeatures { } else { None }, - shader_float16_int8: if requested_features.contains(wgt::Features::SHADER_F16) { - Some(vk::PhysicalDeviceShaderFloat16Int8Features::default().shader_float16(true)) - } else { - None + shader_float16_int8: match requested_features.contains(wgt::Features::SHADER_F16) { + shader_float16 if shader_float16 || private_caps.shader_int8 => Some( + vk::PhysicalDeviceShaderFloat16Int8Features::default() + .shader_float16(shader_float16) + .shader_int8(private_caps.shader_int8), + ), + _ => None, }, _16bit_storage: if requested_features.contains(wgt::Features::SHADER_F16) { Some( @@ -981,6 +984,15 @@ impl PhysicalDeviceProperties { if requested_features.contains(wgt::Features::TEXTURE_FORMAT_NV12) { extensions.push(khr::sampler_ycbcr_conversion::NAME); } + + // Require `VK_KHR_16bit_storage` if the feature `SHADER_F16` was requested + if requested_features.contains(wgt::Features::SHADER_F16) { + // - Feature `SHADER_F16` also requires `VK_KHR_shader_float16_int8`, but we always + // require that anyway (if it is available) below. + // - `VK_KHR_16bit_storage` requires `VK_KHR_storage_buffer_storage_class`, however + // we require that one already. + extensions.push(khr::_16bit_storage::NAME); + } } if self.device_api_version < vk::API_VERSION_1_2 { @@ -1004,13 +1016,10 @@ impl PhysicalDeviceProperties { extensions.push(ext::descriptor_indexing::NAME); } - // Require `VK_KHR_shader_float16_int8` and `VK_KHR_16bit_storage` if the associated feature was requested - if requested_features.contains(wgt::Features::SHADER_F16) { + // Always require `VK_KHR_shader_float16_int8` if available as it enables Int8 + // optimizations. It is also _needed_ if `wgt::Features::SHADER_F16` is requested. + if self.supports_extension(khr::shader_float16_int8::NAME) { extensions.push(khr::shader_float16_int8::NAME); - // `VK_KHR_16bit_storage` requires `VK_KHR_storage_buffer_storage_class`, however we require that one already - if self.device_api_version < vk::API_VERSION_1_1 { - extensions.push(khr::_16bit_storage::NAME); - } } if requested_features.intersects(wgt::Features::EXPERIMENTAL_MESH_SHADER) { @@ -1479,12 +1488,17 @@ impl super::InstanceShared { .insert(vk::PhysicalDeviceTextureCompressionASTCHDRFeaturesEXT::default()); features2 = features2.push_next(next); } - if capabilities.supports_extension(khr::shader_float16_int8::NAME) { + + // `VK_KHR_shader_float16_int8` is promoted to 1.2 + if capabilities.device_api_version >= vk::API_VERSION_1_2 + || capabilities.supports_extension(khr::shader_float16_int8::NAME) + { let next = features .shader_float16_int8 .insert(vk::PhysicalDeviceShaderFloat16Int8FeaturesKHR::default()); features2 = features2.push_next(next); } + if capabilities.supports_extension(khr::_16bit_storage::NAME) { let next = features ._16bit_storage @@ -1728,6 +1742,9 @@ impl super::Instance { shader_integer_dot_product: phd_features .shader_integer_dot_product .is_some_and(|ext| ext.shader_integer_dot_product != 0), + shader_int8: phd_features + .shader_float16_int8 + .is_some_and(|features| features.shader_int8 != 0), }; let capabilities = crate::Capabilities { limits: phd_capabilities.to_wgpu_limits(), @@ -2029,6 +2046,10 @@ impl super::Adapter { spv::Capability::DotProductKHR, ]); } + if self.private_caps.shader_int8 { + // See . + capabilities.extend(&[spv::Capability::Int8]); + } spv::Options { lang_version: match self.phd_capabilities.device_api_version { // Use maximum supported SPIR-V version according to diff --git a/wgpu-hal/src/vulkan/mod.rs b/wgpu-hal/src/vulkan/mod.rs index b492f33987..47c91e1d1b 100644 --- a/wgpu-hal/src/vulkan/mod.rs +++ b/wgpu-hal/src/vulkan/mod.rs @@ -536,6 +536,21 @@ struct PrivateCapabilities { /// /// [`VK_KHR_shader_integer_dot_product`]: https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VK_KHR_shader_integer_dot_product.html shader_integer_dot_product: bool, + + /// True if this adapter supports 8-bit integers provided by the + /// [`VK_KHR_shader_float16_int8`] extension (promoted to Vulkan 1.2). + /// + /// Allows shaders to declare the "Int8" capability. Note, however, that this + /// feature alone allows the use of 8-bit integers "only in the `Private`, + /// `Workgroup` (for non-Block variables), and `Function` storage classes" + /// ([see spec]). To use 8-bit integers in the interface storage classes (e.g., + /// `StorageBuffer`), you also need to enable the corresponding feature in + /// `VkPhysicalDevice8BitStorageFeatures` and declare the corresponding SPIR-V + /// capability (e.g., `StorageBuffer8BitAccess`). + /// + /// [`VK_KHR_shader_float16_int8`]: https://registry.khronos.org/vulkan/specs/latest/man/html/VK_KHR_shader_float16_int8.html + /// [see spec]: https://registry.khronos.org/vulkan/specs/latest/man/html/VkPhysicalDeviceShaderFloat16Int8Features.html#extension-features-shaderInt8 + shader_int8: bool, } bitflags::bitflags!(