Skip to content

[naga] Support pipeline overrides in the WGSL backend #6310

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ By @bradwerth [#6216](https://github.com/gfx-rs/wgpu/pull/6216).
- Support constant evaluation for `firstLeadingBit` and `firstTrailingBit` numeric built-ins in WGSL. Front-ends that translate to these built-ins also benefit from constant evaluation. By @ErichDonGubler in [#5101](https://github.com/gfx-rs/wgpu/pull/5101).
- Add `first` and `either` sampling types for `@interpolate(flat, …)` in WGSL. By @ErichDonGubler in [#6181](https://github.com/gfx-rs/wgpu/pull/6181).
- Support for more atomic ops in the SPIR-V frontend. By @schell in [#5824](https://github.com/gfx-rs/wgpu/pull/5824).
- Support pipeline overrides in WGSL backend. By @cbbowen in [#6310](https://github.com/gfx-rs/wgpu/pull/6310).

#### General

Expand Down
118 changes: 96 additions & 22 deletions naga/src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,6 @@ impl<W: Write> Writer<W> {
}

pub fn write(&mut self, module: &Module, info: &valid::ModuleInfo) -> BackendResult {
if !module.overrides.is_empty() {
return Err(Error::Unimplemented(
"Pipeline constants are not yet supported for this back-end".to_string(),
));
}

self.reset(module);

// Save all ep result types
Expand Down Expand Up @@ -147,6 +141,14 @@ impl<W: Write> Writer<W> {
}
}

// Write all overrides
if !module.overrides.is_empty() {
for (handle, _) in module.overrides.iter() {
self.write_override(module, handle)?;
}
writeln!(self.out)?;
}

// Write all globals
for (ty, global) in module.global_variables.iter() {
self.write_global(module, global, ty)?;
Expand Down Expand Up @@ -1260,6 +1262,9 @@ impl<W: Write> Writer<W> {
self.write_const_expression(module, constant.init)?;
}
}
Expression::Override(handle) => {
write!(self.out, "{}", self.names[&NameKey::Override(handle)])?;
}
Expression::ZeroValue(ty) => {
self.write_type(module, ty)?;
write!(self.out, "()")?;
Expand All @@ -1281,6 +1286,12 @@ impl<W: Write> Writer<W> {
write_expression(self, value)?;
write!(self.out, ")")?;
}
Expression::Unary { op, expr } => {
self.write_unary_expression(op, expr, write_expression)?
}
Expression::Binary { op, left, right } => {
self.write_binary_expression(op, left, right, write_expression)?
}
_ => unreachable!(),
}

Expand Down Expand Up @@ -1331,18 +1342,18 @@ impl<W: Write> Writer<W> {
|writer, expr| writer.write_expr(module, expr, func_ctx),
)?;
}
Expression::Override(_) => unreachable!(),
Expression::Override(handle) => {
write!(self.out, "{}", self.names[&NameKey::Override(handle)])?;
}
Expression::FunctionArgument(pos) => {
let name_key = func_ctx.argument_key(pos);
let name = &self.names[&name_key];
write!(self.out, "{name}")?;
}
Expression::Binary { op, left, right } => {
write!(self.out, "(")?;
self.write_expr(module, left, func_ctx)?;
write!(self.out, " {} ", back::binary_operation_str(op))?;
self.write_expr(module, right, func_ctx)?;
write!(self.out, ")")?;
self.write_binary_expression(op, left, right, |writer, expr| {
writer.write_expr(module, expr, func_ctx)
})?
}
Expression::Access { base, index } => {
self.write_expr_with_indirection(module, base, func_ctx, indirection)?;
Expand Down Expand Up @@ -1766,16 +1777,9 @@ impl<W: Write> Writer<W> {
}
}
Expression::Unary { op, expr } => {
let unary = match op {
crate::UnaryOperator::Negate => "-",
crate::UnaryOperator::LogicalNot => "!",
crate::UnaryOperator::BitwiseNot => "~",
};

write!(self.out, "{unary}(")?;
self.write_expr(module, expr, func_ctx)?;

write!(self.out, ")")?
self.write_unary_expression(op, expr, |writer, expr| {
writer.write_expr(module, expr, func_ctx)
})?
}

Expression::Select {
Expand Down Expand Up @@ -1836,6 +1840,49 @@ impl<W: Write> Writer<W> {
Ok(())
}

/// Helper method used to write unary expressions
fn write_unary_expression<E>(
&mut self,
op: crate::UnaryOperator,
expr: Handle<crate::Expression>,
write_expression: E,
) -> BackendResult
where
E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
{
let unary = match op {
crate::UnaryOperator::Negate => "-",
crate::UnaryOperator::LogicalNot => "!",
crate::UnaryOperator::BitwiseNot => "~",
};

write!(self.out, "{unary}(")?;
write_expression(self, expr)?;
write!(self.out, ")")?;

Ok(())
}

/// Helper method used to write binary expressions
fn write_binary_expression<E>(
&mut self,
op: crate::BinaryOperator,
left: Handle<crate::Expression>,
right: Handle<crate::Expression>,
write_expression: E,
) -> BackendResult
where
E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
{
write!(self.out, "(")?;
write_expression(self, left)?;
write!(self.out, " {} ", back::binary_operation_str(op))?;
write_expression(self, right)?;
write!(self.out, ")")?;

Ok(())
}

/// Helper method used to write global variables
/// # Notes
/// Always adds a newline
Expand Down Expand Up @@ -1906,6 +1953,33 @@ impl<W: Write> Writer<W> {
Ok(())
}

/// Helper method used to write override declarations
///
/// # Notes
/// Ends in a newline
fn write_override(
&mut self,
module: &Module,
handle: Handle<crate::Override>,
) -> BackendResult {
let r#override = &module.overrides[handle];
let name = self.namer.call(r#override.name.as_deref().unwrap_or(""));
if let Some(id) = r#override.id {
write!(self.out, "@id({id}) ")?;
}
write!(self.out, "override {name}: ")?;
self.write_type(module, r#override.ty)?;
if let Some(init) = r#override.init {
write!(self.out, " = ")?;
self.write_const_expression(module, init)?;
}
writeln!(self.out, ";")?;

self.names.insert(NameKey::Override(handle), name);

Ok(())
}

// See https://github.com/rust-lang/rust-clippy/issues/4979.
#[allow(clippy::missing_const_for_fn)]
pub fn finish(self) -> W {
Expand Down
1 change: 1 addition & 0 deletions naga/src/proc/namer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ const SEPARATOR: char = '_';
#[derive(Debug, Eq, Hash, PartialEq)]
pub enum NameKey {
Constant(Handle<crate::Constant>),
Override(Handle<crate::Override>),
GlobalVariable(Handle<crate::GlobalVariable>),
Type(Handle<crate::Type>),
StructMember(Handle<crate::Type>, u32),
Expand Down
24 changes: 24 additions & 0 deletions naga/tests/out/wgsl/overrides.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
@id(0) override has_point_light: bool = true;
@id(1200) override specular_param: f32 = 2.3f;
@id(1300) override gain: f32;
override width: f32 = 0f;
override depth: f32;
override height: f32 = (2f * depth);
override inferred_f32_: f32 = 2.718f;

var<private> gain_x_10_: f32 = (gain * 10f);
var<private> store_override: f32;

@compute @workgroup_size(1, 1, 1)
fn main() {
var t: f32 = (height * 5f);
var x: bool;
var gain_x_100_: f32;

let a = !(has_point_light);
x = a;
let _e7 = gain_x_10_;
gain_x_100_ = (_e7 * 10f);
store_override = gain;
return;
}
3 changes: 2 additions & 1 deletion naga/tests/snapshots.rs
Original file line number Diff line number Diff line change
Expand Up @@ -913,7 +913,8 @@ fn convert_wgsl() {
| Targets::SPIRV
| Targets::METAL
| Targets::HLSL
| Targets::GLSL,
| Targets::GLSL
| Targets::WGSL,
),
(
"overrides-atomicCompareExchangeWeak",
Expand Down