diff --git a/CHANGELOG.md b/CHANGELOG.md index 894da6ddc7..f1f65d4bf0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -82,6 +82,7 @@ By @bradwerth [#6216](https://github.com/gfx-rs/wgpu/pull/6216). - Support constant evaluation for `firstLeadingBit` and `firstTrailingBit` numeric built-ins in WGSL. Front-ends that translate to these built-ins also benefit from constant evaluation. By @ErichDonGubler in [#5101](https://github.com/gfx-rs/wgpu/pull/5101). - Add `first` and `either` sampling types for `@interpolate(flat, …)` in WGSL. By @ErichDonGubler in [#6181](https://github.com/gfx-rs/wgpu/pull/6181). - Support for more atomic ops in the SPIR-V frontend. By @schell in [#5824](https://github.com/gfx-rs/wgpu/pull/5824). +- Support pipeline overrides in WGSL backend. By @cbbowen in [#6310](https://github.com/gfx-rs/wgpu/pull/6310). #### General diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index 0f2635eb0d..466aad2844 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -106,12 +106,6 @@ impl Writer { } pub fn write(&mut self, module: &Module, info: &valid::ModuleInfo) -> BackendResult { - if !module.overrides.is_empty() { - return Err(Error::Unimplemented( - "Pipeline constants are not yet supported for this back-end".to_string(), - )); - } - self.reset(module); // Save all ep result types @@ -147,6 +141,14 @@ impl Writer { } } + // Write all overrides + if !module.overrides.is_empty() { + for (handle, _) in module.overrides.iter() { + self.write_override(module, handle)?; + } + writeln!(self.out)?; + } + // Write all globals for (ty, global) in module.global_variables.iter() { self.write_global(module, global, ty)?; @@ -1260,6 +1262,9 @@ impl Writer { self.write_const_expression(module, constant.init)?; } } + Expression::Override(handle) => { + write!(self.out, "{}", self.names[&NameKey::Override(handle)])?; + } Expression::ZeroValue(ty) => { self.write_type(module, ty)?; write!(self.out, "()")?; @@ -1281,6 +1286,12 @@ impl Writer { write_expression(self, value)?; write!(self.out, ")")?; } + Expression::Unary { op, expr } => { + self.write_unary_expression(op, expr, write_expression)? + } + Expression::Binary { op, left, right } => { + self.write_binary_expression(op, left, right, write_expression)? + } _ => unreachable!(), } @@ -1331,18 +1342,18 @@ impl Writer { |writer, expr| writer.write_expr(module, expr, func_ctx), )?; } - Expression::Override(_) => unreachable!(), + Expression::Override(handle) => { + write!(self.out, "{}", self.names[&NameKey::Override(handle)])?; + } Expression::FunctionArgument(pos) => { let name_key = func_ctx.argument_key(pos); let name = &self.names[&name_key]; write!(self.out, "{name}")?; } Expression::Binary { op, left, right } => { - write!(self.out, "(")?; - self.write_expr(module, left, func_ctx)?; - write!(self.out, " {} ", back::binary_operation_str(op))?; - self.write_expr(module, right, func_ctx)?; - write!(self.out, ")")?; + self.write_binary_expression(op, left, right, |writer, expr| { + writer.write_expr(module, expr, func_ctx) + })? } Expression::Access { base, index } => { self.write_expr_with_indirection(module, base, func_ctx, indirection)?; @@ -1766,16 +1777,9 @@ impl Writer { } } Expression::Unary { op, expr } => { - let unary = match op { - crate::UnaryOperator::Negate => "-", - crate::UnaryOperator::LogicalNot => "!", - crate::UnaryOperator::BitwiseNot => "~", - }; - - write!(self.out, "{unary}(")?; - self.write_expr(module, expr, func_ctx)?; - - write!(self.out, ")")? + self.write_unary_expression(op, expr, |writer, expr| { + writer.write_expr(module, expr, func_ctx) + })? } Expression::Select { @@ -1836,6 +1840,49 @@ impl Writer { Ok(()) } + /// Helper method used to write unary expressions + fn write_unary_expression( + &mut self, + op: crate::UnaryOperator, + expr: Handle, + write_expression: E, + ) -> BackendResult + where + E: Fn(&mut Self, Handle) -> BackendResult, + { + let unary = match op { + crate::UnaryOperator::Negate => "-", + crate::UnaryOperator::LogicalNot => "!", + crate::UnaryOperator::BitwiseNot => "~", + }; + + write!(self.out, "{unary}(")?; + write_expression(self, expr)?; + write!(self.out, ")")?; + + Ok(()) + } + + /// Helper method used to write binary expressions + fn write_binary_expression( + &mut self, + op: crate::BinaryOperator, + left: Handle, + right: Handle, + write_expression: E, + ) -> BackendResult + where + E: Fn(&mut Self, Handle) -> BackendResult, + { + write!(self.out, "(")?; + write_expression(self, left)?; + write!(self.out, " {} ", back::binary_operation_str(op))?; + write_expression(self, right)?; + write!(self.out, ")")?; + + Ok(()) + } + /// Helper method used to write global variables /// # Notes /// Always adds a newline @@ -1906,6 +1953,33 @@ impl Writer { Ok(()) } + /// Helper method used to write override declarations + /// + /// # Notes + /// Ends in a newline + fn write_override( + &mut self, + module: &Module, + handle: Handle, + ) -> BackendResult { + let r#override = &module.overrides[handle]; + let name = self.namer.call(r#override.name.as_deref().unwrap_or("")); + if let Some(id) = r#override.id { + write!(self.out, "@id({id}) ")?; + } + write!(self.out, "override {name}: ")?; + self.write_type(module, r#override.ty)?; + if let Some(init) = r#override.init { + write!(self.out, " = ")?; + self.write_const_expression(module, init)?; + } + writeln!(self.out, ";")?; + + self.names.insert(NameKey::Override(handle), name); + + Ok(()) + } + // See https://github.com/rust-lang/rust-clippy/issues/4979. #[allow(clippy::missing_const_for_fn)] pub fn finish(self) -> W { diff --git a/naga/src/proc/namer.rs b/naga/src/proc/namer.rs index 8afacb593d..dbbd3c26c0 100644 --- a/naga/src/proc/namer.rs +++ b/naga/src/proc/namer.rs @@ -8,6 +8,7 @@ const SEPARATOR: char = '_'; #[derive(Debug, Eq, Hash, PartialEq)] pub enum NameKey { Constant(Handle), + Override(Handle), GlobalVariable(Handle), Type(Handle), StructMember(Handle, u32), diff --git a/naga/tests/out/wgsl/overrides.wgsl b/naga/tests/out/wgsl/overrides.wgsl new file mode 100644 index 0000000000..37af12e505 --- /dev/null +++ b/naga/tests/out/wgsl/overrides.wgsl @@ -0,0 +1,24 @@ +@id(0) override has_point_light: bool = true; +@id(1200) override specular_param: f32 = 2.3f; +@id(1300) override gain: f32; +override width: f32 = 0f; +override depth: f32; +override height: f32 = (2f * depth); +override inferred_f32_: f32 = 2.718f; + +var gain_x_10_: f32 = (gain * 10f); +var store_override: f32; + +@compute @workgroup_size(1, 1, 1) +fn main() { + var t: f32 = (height * 5f); + var x: bool; + var gain_x_100_: f32; + + let a = !(has_point_light); + x = a; + let _e7 = gain_x_10_; + gain_x_100_ = (_e7 * 10f); + store_override = gain; + return; +} diff --git a/naga/tests/snapshots.rs b/naga/tests/snapshots.rs index 936203986d..1962ca3eed 100644 --- a/naga/tests/snapshots.rs +++ b/naga/tests/snapshots.rs @@ -913,7 +913,8 @@ fn convert_wgsl() { | Targets::SPIRV | Targets::METAL | Targets::HLSL - | Targets::GLSL, + | Targets::GLSL + | Targets::WGSL, ), ( "overrides-atomicCompareExchangeWeak",