diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs index a2ad70a010..5c7d94b838 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs @@ -3,13 +3,13 @@ use crate::maybe_pqp_cg_ssa as rustc_codegen_ssa; use super::CodegenCx; use crate::abi::ConvSpirvType; -use crate::attr::{AggregatedSpirvAttributes, Entry, Spanned, SpecConstant}; +use crate::attr::{AggregatedSpirvAttributes, Entry, ExecutionModeExtra, Spanned, SpecConstant}; use crate::builder::Builder; use crate::builder_spirv::{SpirvValue, SpirvValueExt}; use crate::spirv_type::SpirvType; use rspirv::dr::Operand; use rspirv::spirv::{ - Capability, Decoration, Dim, ExecutionModel, FunctionControl, StorageClass, Word, + Capability, Decoration, Dim, ExecutionMode, ExecutionModel, FunctionControl, StorageClass, Word, }; use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods}; use rustc_data_structures::fx::FxHashMap; @@ -139,6 +139,7 @@ impl<'tcx> CodegenCx<'tcx> { hir_params, name, entry.execution_model, + &entry.execution_modes, ); let mut emit = self.emit_global(); entry @@ -157,6 +158,7 @@ impl<'tcx> CodegenCx<'tcx> { hir_params: &[hir::Param<'tcx>], name: String, execution_model: ExecutionModel, + execution_modes: &[(ExecutionMode, ExecutionModeExtra)], ) -> Word { let stub_fn = { let void = SpirvType::Void.def(span, self); @@ -182,6 +184,7 @@ impl<'tcx> CodegenCx<'tcx> { bx.set_span(hir_param.span); self.declare_shader_interface_for_param( execution_model, + execution_modes, entry_arg_abi, hir_param, &mut op_entry_point_interface_operands, @@ -419,6 +422,7 @@ impl<'tcx> CodegenCx<'tcx> { fn declare_shader_interface_for_param( &self, execution_model: ExecutionModel, + execution_modes: &[(ExecutionMode, ExecutionModeExtra)], entry_arg_abi: &ArgAbi<'tcx, Ty<'tcx>>, hir_param: &hir::Param<'tcx>, op_entry_point_interface_operands: &mut Vec, @@ -428,6 +432,21 @@ impl<'tcx> CodegenCx<'tcx> { ) { let attrs = AggregatedSpirvAttributes::parse(self, self.tcx.hir().attrs(hir_param.hir_id)); + // Handle WorkgroupSize builtin specially as it's a constant, not a variable + if let Some(builtin) = &attrs.builtin { + if builtin.value == rspirv::spirv::BuiltIn::WorkgroupSize { + self.handle_workgroup_size_builtin( + execution_modes, + entry_arg_abi, + hir_param, + &attrs, + bx, + call_args, + ); + return; + } + } + let EntryParamDeducedFromRustRefOrValue { value_layout, storage_class, @@ -984,4 +1003,78 @@ impl<'tcx> CodegenCx<'tcx> { } } } + + fn handle_workgroup_size_builtin( + &self, + execution_modes: &[(ExecutionMode, ExecutionModeExtra)], + entry_arg_abi: &ArgAbi<'tcx, Ty<'tcx>>, + hir_param: &hir::Param<'tcx>, + attrs: &AggregatedSpirvAttributes, + _bx: &mut Builder<'_, 'tcx>, + call_args: &mut Vec, + ) { + // Find the LocalSize execution mode + let local_size = execution_modes.iter().find_map(|(mode, extra)| { + if *mode == ExecutionMode::LocalSize { + Some(extra.as_ref()) + } else { + None + } + }); + + let local_size_values = match local_size { + Some(values) => values, + None => { + self.tcx.dcx().span_err( + attrs.builtin.as_ref().unwrap().span, + "WorkgroupSize builtin requires LocalSize execution mode to be set (e.g., #[spirv(compute(threads(x, y, z)))])", + ); + return; + } + }; + + // Validate the parameter type and get the SPIR-V type + let value_spirv_type = entry_arg_abi.layout.spirv_type(hir_param.ty_span, self); + let expected_element_type = SpirvType::Integer(32, false).def(hir_param.span, self); + let expected_vec_type = SpirvType::Vector { + element: expected_element_type, + count: 3, + } + .def(hir_param.span, self); + + // Verify the type matches what we expect (Vec3) + if value_spirv_type != expected_vec_type { + self.tcx.dcx().span_err( + hir_param.ty_span, + "WorkgroupSize builtin must have type UVec3", + ); + return; + } + + let components: Vec<_> = local_size_values + .iter() + .map(|&size| self.constant_u32(hir_param.span, size).def_cx(self)) + .collect(); + + let const_composite_id = self + .emit_global() + .constant_composite(expected_vec_type, components); + + // Decorate with BuiltIn WorkgroupSize + self.emit_global().decorate( + const_composite_id, + Decoration::BuiltIn, + std::iter::once(Operand::BuiltIn(rspirv::spirv::BuiltIn::WorkgroupSize)), + ); + + // Emit OpName + if let hir::PatKind::Binding(_, _, ident, _) = &hir_param.pat.kind { + self.emit_global() + .name(const_composite_id, ident.to_string()); + } + + // Add the constant as the call argument + let value = const_composite_id.with_type(expected_vec_type); + call_args.push(value); + } } diff --git a/crates/rustc_codegen_spirv/src/symbols.rs b/crates/rustc_codegen_spirv/src/symbols.rs index 8b1c610a32..418e89f4c4 100644 --- a/crates/rustc_codegen_spirv/src/symbols.rs +++ b/crates/rustc_codegen_spirv/src/symbols.rs @@ -59,7 +59,7 @@ const BUILTINS: &[(&str, BuiltIn)] = { ("frag_depth", FragDepth), ("helper_invocation", HelperInvocation), ("num_workgroups", NumWorkgroups), - // ("workgroup_size", WorkgroupSize), -- constant + ("workgroup_size", WorkgroupSize), ("workgroup_id", WorkgroupId), ("local_invocation_id", LocalInvocationId), ("global_invocation_id", GlobalInvocationId), diff --git a/tests/compiletests/ui/dis/workgroup-size.rs b/tests/compiletests/ui/dis/workgroup-size.rs new file mode 100644 index 0000000000..ce71d23c8c --- /dev/null +++ b/tests/compiletests/ui/dis/workgroup-size.rs @@ -0,0 +1,20 @@ +#![crate_name = "workgroup_size"] + +// Tests that the WorkgroupSize builtin is correctly generated as a constant. + +// build-pass +// compile-flags: -C llvm-args=--disassemble-globals +// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> "" +// normalize-stderr-test "OpSource .*\n" -> "" +// normalize-stderr-test "OpExtension .SPV_KHR_vulkan_memory_model.\n" -> "" +// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple" + +use spirv_std::glam::UVec3; +use spirv_std::spirv; + +#[spirv(compute(threads(8, 4, 2)))] +pub fn main(#[spirv(workgroup_size)] size: UVec3, #[spirv(local_invocation_id)] local_id: UVec3) { + // The workgroup_size should be (8, 4, 2) + // Using the size parameter ensures it's included in the generated SPIR-V + let _total = size.x + size.y + size.z + local_id.x; +} diff --git a/tests/compiletests/ui/dis/workgroup-size.stderr b/tests/compiletests/ui/dis/workgroup-size.stderr new file mode 100644 index 0000000000..00ceae0fae --- /dev/null +++ b/tests/compiletests/ui/dis/workgroup-size.stderr @@ -0,0 +1,27 @@ +OpCapability Shader +OpCapability Float64 +OpCapability Int64 +OpCapability Int16 +OpCapability Int8 +OpCapability ShaderClockKHR +OpExtension "SPV_KHR_shader_clock" +OpMemoryModel Logical Simple +OpEntryPoint GLCompute %1 "main" %2 +OpExecutionMode %1 LocalSize 8 4 2 +%3 = OpString "$OPSTRING_FILENAME/workgroup-size.rs" +OpName %2 "local_id" +OpName %4 "size" +OpName %5 "workgroup_size::main" +OpDecorate %2 BuiltIn LocalInvocationId +OpDecorate %4 BuiltIn WorkgroupSize +%6 = OpTypeInt 32 0 +%7 = OpTypeVector %6 3 +%8 = OpTypePointer Input %7 +%9 = OpTypeVoid +%10 = OpTypeFunction %9 +%2 = OpVariable %8 Input +%11 = OpTypeFunction %9 %7 %7 +%12 = OpConstant %6 8 +%13 = OpConstant %6 4 +%14 = OpConstant %6 2 +%4 = OpConstantComposite %7 %12 %13 %14 diff --git a/tests/compiletests/ui/spirv-attr/all-builtins.rs b/tests/compiletests/ui/spirv-attr/all-builtins.rs index cd83259d39..a4a642c754 100644 --- a/tests/compiletests/ui/spirv-attr/all-builtins.rs +++ b/tests/compiletests/ui/spirv-attr/all-builtins.rs @@ -35,6 +35,7 @@ pub fn compute( #[spirv(num_workgroups)] num_workgroups: UVec3, #[spirv(subgroup_id)] subgroup_id: u32, #[spirv(workgroup_id)] workgroup_id: UVec3, + #[spirv(workgroup_size)] workgroup_size: UVec3, #[spirv(workgroup)] workgroup_local_memory: &mut [u32; 256], ) { } diff --git a/tests/compiletests/ui/spirv-attr/workgroup-size.rs b/tests/compiletests/ui/spirv-attr/workgroup-size.rs new file mode 100644 index 0000000000..ac8daee106 --- /dev/null +++ b/tests/compiletests/ui/spirv-attr/workgroup-size.rs @@ -0,0 +1,12 @@ +// build-pass + +use spirv_std::glam::UVec3; +use spirv_std::spirv; + +#[spirv(compute(threads(8, 4, 2)))] +pub fn main(#[spirv(workgroup_size)] size: UVec3, #[spirv(local_invocation_id)] local_id: UVec3) { + // The workgroup_size should be (8, 4, 2) + assert!(size.x == 8); + assert!(size.y == 4); + assert!(size.z == 2); +} diff --git a/tests/difftests/tests/Cargo.lock b/tests/difftests/tests/Cargo.lock index 686e22f50d..031100c4bd 100644 --- a/tests/difftests/tests/Cargo.lock +++ b/tests/difftests/tests/Cargo.lock @@ -1400,6 +1400,21 @@ dependencies = [ "bitflags 2.9.0", ] +[[package]] +name = "workgroup-size-rust" +version = "0.0.0" +dependencies = [ + "difftest", + "spirv-std", +] + +[[package]] +name = "workgroup-size-wgsl" +version = "0.0.0" +dependencies = [ + "difftest", +] + [[package]] name = "xml-rs" version = "0.8.25" diff --git a/tests/difftests/tests/Cargo.toml b/tests/difftests/tests/Cargo.toml index 34f08f9250..4322d04d9e 100644 --- a/tests/difftests/tests/Cargo.toml +++ b/tests/difftests/tests/Cargo.toml @@ -3,6 +3,8 @@ resolver = "2" members = [ "simple-compute/simple-compute-rust", "simple-compute/simple-compute-wgsl", + "workgroup-size/workgroup-size-rust", + "workgroup-size/workgroup-size-wgsl", ] [workspace.package] diff --git a/tests/difftests/tests/workgroup-size/workgroup-size-rust/Cargo.toml b/tests/difftests/tests/workgroup-size/workgroup-size-rust/Cargo.toml new file mode 100644 index 0000000000..64bf45656e --- /dev/null +++ b/tests/difftests/tests/workgroup-size/workgroup-size-rust/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "workgroup-size-rust" +edition.workspace = true + +[lints] +workspace = true + +[lib] +crate-type = ["dylib"] + +# Common deps +[dependencies] + +# GPU deps +spirv-std.workspace = true + +# CPU deps +[target.'cfg(not(target_arch = "spirv"))'.dependencies] +difftest.workspace = true \ No newline at end of file diff --git a/tests/difftests/tests/workgroup-size/workgroup-size-rust/src/lib.rs b/tests/difftests/tests/workgroup-size/workgroup-size-rust/src/lib.rs new file mode 100644 index 0000000000..ac1898ca63 --- /dev/null +++ b/tests/difftests/tests/workgroup-size/workgroup-size-rust/src/lib.rs @@ -0,0 +1,30 @@ +#![no_std] + +use spirv_std::glam::UVec3; +use spirv_std::spirv; + +#[spirv(compute(threads(8, 4, 2)))] +pub fn main_cs( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] output: &mut [u32], + #[spirv(workgroup_size)] workgroup_size: UVec3, + #[spirv(global_invocation_id)] global_id: UVec3, + #[spirv(local_invocation_id)] local_id: UVec3, +) { + let idx = global_id.x as usize; + + if idx < output.len() { + // Store a value that encodes the workgroup dimensions + // This allows us to verify that the workgroup_size builtin is working correctly + let encoded = (workgroup_size.x << 16) | (workgroup_size.y << 8) | workgroup_size.z; + + // Also encode the local invocation ID to show it's within the workgroup bounds + let local_encoded = (local_id.x << 16) | (local_id.y << 8) | local_id.z; + + // Store: encoded workgroup size in even indices, local ID in odd indices + if idx % 2 == 0 { + output[idx] = encoded; // Should be (8 << 16) | (4 << 8) | 2 = 0x080402 + } else { + output[idx] = local_encoded; + } + } +} diff --git a/tests/difftests/tests/workgroup-size/workgroup-size-rust/src/main.rs b/tests/difftests/tests/workgroup-size/workgroup-size-rust/src/main.rs new file mode 100644 index 0000000000..08ce97487c --- /dev/null +++ b/tests/difftests/tests/workgroup-size/workgroup-size-rust/src/main.rs @@ -0,0 +1,15 @@ +use difftest::config::Config; +use difftest::scaffold::compute::{RustComputeShader, WgpuComputeTest}; + +fn main() { + // Load the config from the harness. + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + // Define test parameters, loading the rust shader from the current crate. + // Dispatch 2x2x1 workgroups with workgroup size [8, 4, 2] + // This gives us a total of 2*2*1 * 8*4*2 = 256 invocations + let test = WgpuComputeTest::new(RustComputeShader::default(), [2, 2, 1], 1024); + + // Run the test and write the output to a file. + test.run_test(&config).unwrap(); +} diff --git a/tests/difftests/tests/workgroup-size/workgroup-size-wgsl/Cargo.toml b/tests/difftests/tests/workgroup-size/workgroup-size-wgsl/Cargo.toml new file mode 100644 index 0000000000..17f6be9ecc --- /dev/null +++ b/tests/difftests/tests/workgroup-size/workgroup-size-wgsl/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "workgroup-size-wgsl" +edition.workspace = true + +[lints] +workspace = true + +[[bin]] +name = "workgroup-size-wgsl" +path = "src/main.rs" + +[dependencies] +difftest.workspace = true \ No newline at end of file diff --git a/tests/difftests/tests/workgroup-size/workgroup-size-wgsl/shader.wgsl b/tests/difftests/tests/workgroup-size/workgroup-size-wgsl/shader.wgsl new file mode 100644 index 0000000000..b1dcbe371a --- /dev/null +++ b/tests/difftests/tests/workgroup-size/workgroup-size-wgsl/shader.wgsl @@ -0,0 +1,33 @@ +@group(0) @binding(0) +var output: array; + +// Define workgroup dimensions as named constants +const WORKGROUP_SIZE_X: u32 = 8u; +const WORKGROUP_SIZE_Y: u32 = 4u; +const WORKGROUP_SIZE_Z: u32 = 2u; + +@compute @workgroup_size(WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y, WORKGROUP_SIZE_Z) +fn main_cs( + @builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3 +) { + let idx = global_id.x; + + if (idx < arrayLength(&output)) { + // Use the named constants to create the workgroup_size vector + let workgroup_size = vec3(WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y, WORKGROUP_SIZE_Z); + + // Store a value that encodes the workgroup dimensions + let encoded = (workgroup_size.x << 16u) | (workgroup_size.y << 8u) | workgroup_size.z; + + // Also encode the local invocation ID + let local_encoded = (local_id.x << 16u) | (local_id.y << 8u) | local_id.z; + + // Store: encoded workgroup size in even indices, local ID in odd indices + if (idx % 2u == 0u) { + output[idx] = encoded; // Should be (8 << 16) | (4 << 8) | 2 = 0x080402 + } else { + output[idx] = local_encoded; + } + } +} diff --git a/tests/difftests/tests/workgroup-size/workgroup-size-wgsl/src/main.rs b/tests/difftests/tests/workgroup-size/workgroup-size-wgsl/src/main.rs new file mode 100644 index 0000000000..f3450a45d2 --- /dev/null +++ b/tests/difftests/tests/workgroup-size/workgroup-size-wgsl/src/main.rs @@ -0,0 +1,15 @@ +use difftest::config::Config; +use difftest::scaffold::compute::{WgpuComputeTest, WgslComputeShader}; + +fn main() { + // Load the config from the harness. + let config = Config::from_path(std::env::args().nth(1).unwrap()).unwrap(); + + // Define test parameters, loading the wgsl shader from the crate directory. + // Dispatch 2x2x1 workgroups with workgroup size [8, 4, 2] + // This gives us a total of 2*2*1 * 8*4*2 = 256 invocations + let test = WgpuComputeTest::new(WgslComputeShader::default(), [2, 2, 1], 1024); + + // Run the test and write the output to a file. + test.run_test(&config).unwrap(); +}