From f6f6dba73b0dc132a0af88a162935c4e7c44177e Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Thu, 14 Mar 2024 18:01:00 +0100 Subject: [PATCH 01/12] refactor `try_eval_and_append` body --- naga/src/proc/constant_evaluator.rs | 72 +++++++++++++++++------------ 1 file changed, 43 insertions(+), 29 deletions(-) diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index 532f364532..f1f01e5855 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -258,6 +258,17 @@ enum Behavior<'a> { Glsl(GlslRestrictions<'a>), } +impl Behavior<'_> { + /// Returns `true` if the inner WGSL/GLSL restrictions are runtime restrictions. + const fn has_runtime_restrictions(&self) -> bool { + matches!( + self, + &Behavior::Wgsl(WgslRestrictions::Runtime(_)) + | &Behavior::Glsl(GlslRestrictions::Runtime(_)) + ) + } +} + /// A context for evaluating constant expressions. /// /// A `ConstantEvaluator` points at an expression arena to which it can append @@ -699,37 +710,40 @@ impl<'a> ConstantEvaluator<'a> { expr: Expression, span: Span, ) -> Result, ConstantEvaluatorError> { - match ( - &self.behavior, - self.expression_kind_tracker.type_of_with_expr(&expr), - ) { - // avoid errors on unimplemented functionality if possible - ( - &Behavior::Wgsl(WgslRestrictions::Runtime(_)) - | &Behavior::Glsl(GlslRestrictions::Runtime(_)), - ExpressionKind::Const, - ) => match self.try_eval_and_append_impl(&expr, span) { - Err( - ConstantEvaluatorError::NotImplemented(_) - | ConstantEvaluatorError::InvalidBinaryOpArgs, - ) => Ok(self.append_expr(expr, span, ExpressionKind::Runtime)), - res => res, + match self.expression_kind_tracker.type_of_with_expr(&expr) { + ExpressionKind::Const => { + let eval_result = self.try_eval_and_append_impl(&expr, span); + // avoid errors on unimplemented functionality if possible + if self.behavior.has_runtime_restrictions() + && matches!( + eval_result, + Err(ConstantEvaluatorError::NotImplemented(_) + | ConstantEvaluatorError::InvalidBinaryOpArgs,) + ) + { + Ok(self.append_expr(expr, span, ExpressionKind::Runtime)) + } else { + eval_result + } + } + ExpressionKind::Override => match self.behavior { + Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => { + Ok(self.append_expr(expr, span, ExpressionKind::Override)) + } + Behavior::Wgsl(WgslRestrictions::Const) => { + Err(ConstantEvaluatorError::OverrideExpr) + } + Behavior::Glsl(_) => { + unreachable!() + } }, - (_, ExpressionKind::Const) => self.try_eval_and_append_impl(&expr, span), - (&Behavior::Wgsl(WgslRestrictions::Const), ExpressionKind::Override) => { - Err(ConstantEvaluatorError::OverrideExpr) + ExpressionKind::Runtime => { + if self.behavior.has_runtime_restrictions() { + Ok(self.append_expr(expr, span, ExpressionKind::Runtime)) + } else { + Err(ConstantEvaluatorError::RuntimeExpr) + } } - ( - &Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)), - ExpressionKind::Override, - ) => Ok(self.append_expr(expr, span, ExpressionKind::Override)), - (&Behavior::Glsl(_), ExpressionKind::Override) => unreachable!(), - ( - &Behavior::Wgsl(WgslRestrictions::Runtime(_)) - | &Behavior::Glsl(GlslRestrictions::Runtime(_)), - ExpressionKind::Runtime, - ) => Ok(self.append_expr(expr, span, ExpressionKind::Runtime)), - (_, ExpressionKind::Runtime) => Err(ConstantEvaluatorError::RuntimeExpr), } } From fc497c459821f717dcf58702bbaf1af6b1b43fe6 Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Wed, 6 Mar 2024 12:22:33 +0100 Subject: [PATCH 02/12] evaluate override-expressions in functions --- naga/src/back/pipeline_constants.rs | 278 ++++++++++++++++++++- naga/tests/in/overrides.wgsl | 6 +- naga/tests/out/analysis/overrides.info.ron | 78 +++++- naga/tests/out/hlsl/overrides.hlsl | 5 + naga/tests/out/ir/overrides.compact.ron | 50 +++- naga/tests/out/ir/overrides.ron | 50 +++- naga/tests/out/msl/overrides.msl | 4 + naga/tests/out/spv/overrides.main.spvasm | 15 +- 8 files changed, 472 insertions(+), 14 deletions(-) diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 298ccbc0d3..bd9eec76ee 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -1,10 +1,11 @@ use super::PipelineConstants; use crate::{ - proc::{ConstantEvaluator, ConstantEvaluatorError}, + proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter}, valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator}, - Constant, Expression, Handle, Literal, Module, Override, Scalar, Span, TypeInner, WithSpan, + Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar, + Span, Statement, SwitchCase, TypeInner, WithSpan, }; -use std::{borrow::Cow, collections::HashSet}; +use std::{borrow::Cow, collections::HashSet, mem}; use thiserror::Error; #[derive(Error, Debug, Clone)] @@ -175,6 +176,18 @@ pub(super) fn process_overrides<'a>( } } + let mut functions = mem::take(&mut module.functions); + for (_, function) in functions.iter_mut() { + process_function(&mut module, &override_map, function)?; + } + let _ = mem::replace(&mut module.functions, functions); + + let mut entry_points = mem::take(&mut module.entry_points); + for ep in entry_points.iter_mut() { + process_function(&mut module, &override_map, &mut ep.function)?; + } + let _ = mem::replace(&mut module.entry_points, entry_points); + // Now that the global expression arena has changed, we need to // recompute those expressions' types. For the time being, do a // full re-validation. @@ -237,6 +250,64 @@ fn process_override( Ok(h) } +/// Replaces all `Expression::Override`s in this function's expression arena +/// with `Expression::Constant` and evaluates all expressions in its arena. +fn process_function( + module: &mut Module, + override_map: &[Handle], + function: &mut Function, +) -> Result<(), ConstantEvaluatorError> { + // A map from original local expression handles to + // handles in the new, local expression arena. + let mut adjusted_local_expressions = Vec::with_capacity(function.expressions.len()); + + let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new(); + + let mut expressions = mem::take(&mut function.expressions); + + // Dummy `emitter` and `block` for the constant evaluator. + // We can ignore the concept of emitting expressions here since + // expressions have already been covered by a `Statement::Emit` + // in the frontend. + // The only thing we might have to do is remove some expressions + // that have been covered by a `Statement::Emit`. See the docs of + // `filter_emits_in_block` for the reasoning. + let mut emitter = Emitter::default(); + let mut block = Block::new(); + + for (old_h, expr, span) in expressions.drain() { + let mut expr = match expr { + Expression::Override(h) => Expression::Constant(override_map[h.index()]), + expr => expr, + }; + let mut evaluator = ConstantEvaluator::for_wgsl_function( + module, + &mut function.expressions, + &mut local_expression_kind_tracker, + &mut emitter, + &mut block, + ); + adjust_expr(&adjusted_local_expressions, &mut expr); + let h = evaluator.try_eval_and_append(expr, span)?; + debug_assert_eq!(old_h.index(), adjusted_local_expressions.len()); + adjusted_local_expressions.push(h); + } + + adjust_block(&adjusted_local_expressions, &mut function.body); + + let new_body = filter_emits_in_block(&function.body, &function.expressions); + let _ = mem::replace(&mut function.body, new_body); + + let named_expressions = mem::take(&mut function.named_expressions); + for (expr_h, name) in named_expressions { + function + .named_expressions + .insert(adjusted_local_expressions[expr_h.index()], name); + } + + Ok(()) +} + /// Replace every expression handle in `expr` with its counterpart /// given by `new_pos`. fn adjust_expr(new_pos: &[Handle], expr: &mut Expression) { @@ -409,6 +480,207 @@ fn adjust_expr(new_pos: &[Handle], expr: &mut Expression) { } } +/// Replace every expression handle in `block` with its counterpart +/// given by `new_pos`. +fn adjust_block(new_pos: &[Handle], block: &mut Block) { + for stmt in block.iter_mut() { + adjust_stmt(new_pos, stmt); + } +} + +/// Replace every expression handle in `stmt` with its counterpart +/// given by `new_pos`. +fn adjust_stmt(new_pos: &[Handle], stmt: &mut Statement) { + let adjust = |expr: &mut Handle| { + *expr = new_pos[expr.index()]; + }; + match *stmt { + Statement::Emit(ref mut range) => { + if let Some((mut first, mut last)) = range.first_and_last() { + adjust(&mut first); + adjust(&mut last); + *range = Range::new_from_bounds(first, last); + } + } + Statement::Block(ref mut block) => { + adjust_block(new_pos, block); + } + Statement::If { + ref mut condition, + ref mut accept, + ref mut reject, + } => { + adjust(condition); + adjust_block(new_pos, accept); + adjust_block(new_pos, reject); + } + Statement::Switch { + ref mut selector, + ref mut cases, + } => { + adjust(selector); + for case in cases.iter_mut() { + adjust_block(new_pos, &mut case.body); + } + } + Statement::Loop { + ref mut body, + ref mut continuing, + ref mut break_if, + } => { + adjust_block(new_pos, body); + adjust_block(new_pos, continuing); + if let Some(e) = break_if.as_mut() { + adjust(e); + } + } + Statement::Return { ref mut value } => { + if let Some(e) = value.as_mut() { + adjust(e); + } + } + Statement::Store { + ref mut pointer, + ref mut value, + } => { + adjust(pointer); + adjust(value); + } + Statement::ImageStore { + ref mut image, + ref mut coordinate, + ref mut array_index, + ref mut value, + } => { + adjust(image); + adjust(coordinate); + if let Some(e) = array_index.as_mut() { + adjust(e); + } + adjust(value); + } + crate::Statement::Atomic { + ref mut pointer, + ref mut value, + ref mut result, + .. + } => { + adjust(pointer); + adjust(value); + adjust(result); + } + Statement::WorkGroupUniformLoad { + ref mut pointer, + ref mut result, + } => { + adjust(pointer); + adjust(result); + } + Statement::Call { + ref mut arguments, + ref mut result, + .. + } => { + for argument in arguments.iter_mut() { + adjust(argument); + } + if let Some(e) = result.as_mut() { + adjust(e); + } + } + Statement::RayQuery { ref mut query, .. } => { + adjust(query); + } + Statement::Break | Statement::Continue | Statement::Kill | Statement::Barrier(_) => {} + } +} + +/// Filters out expressions that `needs_pre_emit`. This step is necessary after +/// const evaluation since unevaluated expressions could have been included in +/// `Statement::Emit`; but since they have been evaluated we need to filter those +/// out. +fn filter_emits_in_block(block: &Block, expressions: &Arena) -> Block { + let mut out = Block::with_capacity(block.len()); + for (stmt, span) in block.span_iter() { + match stmt { + &Statement::Emit(ref range) => { + let mut current = None; + for expr_h in range.clone() { + if expressions[expr_h].needs_pre_emit() { + if let Some((first, last)) = current { + out.push(Statement::Emit(Range::new_from_bounds(first, last)), *span); + } + + current = None; + } else if let Some((_, ref mut last)) = current { + *last = expr_h; + } else { + current = Some((expr_h, expr_h)); + } + } + if let Some((first, last)) = current { + out.push(Statement::Emit(Range::new_from_bounds(first, last)), *span); + } + } + &Statement::Block(ref block) => { + let block = filter_emits_in_block(block, expressions); + out.push(Statement::Block(block), *span); + } + &Statement::If { + condition, + ref accept, + ref reject, + } => { + let accept = filter_emits_in_block(accept, expressions); + let reject = filter_emits_in_block(reject, expressions); + out.push( + Statement::If { + condition, + accept, + reject, + }, + *span, + ); + } + &Statement::Switch { + selector, + ref cases, + } => { + let cases = cases + .iter() + .map(|case| { + let body = filter_emits_in_block(&case.body, expressions); + SwitchCase { + value: case.value, + body, + fall_through: case.fall_through, + } + }) + .collect(); + out.push(Statement::Switch { selector, cases }, *span); + } + &Statement::Loop { + ref body, + ref continuing, + break_if, + } => { + let body = filter_emits_in_block(body, expressions); + let continuing = filter_emits_in_block(continuing, expressions); + out.push( + Statement::Loop { + body, + continuing, + break_if, + }, + *span, + ); + } + stmt => out.push(stmt.clone(), *span), + } + } + out +} + fn map_value_to_literal(value: f64, scalar: Scalar) -> Result { // note that in rust 0.0 == -0.0 match scalar { diff --git a/naga/tests/in/overrides.wgsl b/naga/tests/in/overrides.wgsl index 41e99f9426..b06edecdb9 100644 --- a/naga/tests/in/overrides.wgsl +++ b/naga/tests/in/overrides.wgsl @@ -14,4 +14,8 @@ override inferred_f32 = 2.718; @compute @workgroup_size(1) -fn main() {} \ No newline at end of file +fn main() { + var t = height * 5; + let a = !has_point_light; + var x = a; +} \ No newline at end of file diff --git a/naga/tests/out/analysis/overrides.info.ron b/naga/tests/out/analysis/overrides.info.ron index 7a2447f3c0..389e7fba7f 100644 --- a/naga/tests/out/analysis/overrides.info.ron +++ b/naga/tests/out/analysis/overrides.info.ron @@ -15,7 +15,83 @@ may_kill: false, sampling_set: [], global_uses: [], - expressions: [], + expressions: [ + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(2), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: Some(4), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 2, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: Some(7), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 1, + space: Function, + )), + ), + ], sampling: [], dual_source_blending: false, ), diff --git a/naga/tests/out/hlsl/overrides.hlsl b/naga/tests/out/hlsl/overrides.hlsl index 0a849fd4db..1541ae7281 100644 --- a/naga/tests/out/hlsl/overrides.hlsl +++ b/naga/tests/out/hlsl/overrides.hlsl @@ -9,5 +9,10 @@ static const float inferred_f32_ = 2.718; [numthreads(1, 1, 1)] void main() { + float t = (float)0; + bool x = (bool)0; + + t = 23.0; + x = true; return; } diff --git a/naga/tests/out/ir/overrides.compact.ron b/naga/tests/out/ir/overrides.compact.ron index 7a60f14239..b0a230a716 100644 --- a/naga/tests/out/ir/overrides.compact.ron +++ b/naga/tests/out/ir/overrides.compact.ron @@ -90,10 +90,54 @@ name: Some("main"), arguments: [], result: None, - local_variables: [], - expressions: [], - named_expressions: {}, + local_variables: [ + ( + name: Some("t"), + ty: 2, + init: None, + ), + ( + name: Some("x"), + ty: 1, + init: None, + ), + ], + expressions: [ + Override(6), + Literal(F32(5.0)), + Binary( + op: Multiply, + left: 1, + right: 2, + ), + LocalVariable(1), + Override(1), + Unary( + op: LogicalNot, + expr: 5, + ), + LocalVariable(2), + ], + named_expressions: { + 6: "a", + }, body: [ + Emit(( + start: 2, + end: 3, + )), + Store( + pointer: 4, + value: 3, + ), + Emit(( + start: 5, + end: 6, + )), + Store( + pointer: 7, + value: 6, + ), Return( value: None, ), diff --git a/naga/tests/out/ir/overrides.ron b/naga/tests/out/ir/overrides.ron index 7a60f14239..b0a230a716 100644 --- a/naga/tests/out/ir/overrides.ron +++ b/naga/tests/out/ir/overrides.ron @@ -90,10 +90,54 @@ name: Some("main"), arguments: [], result: None, - local_variables: [], - expressions: [], - named_expressions: {}, + local_variables: [ + ( + name: Some("t"), + ty: 2, + init: None, + ), + ( + name: Some("x"), + ty: 1, + init: None, + ), + ], + expressions: [ + Override(6), + Literal(F32(5.0)), + Binary( + op: Multiply, + left: 1, + right: 2, + ), + LocalVariable(1), + Override(1), + Unary( + op: LogicalNot, + expr: 5, + ), + LocalVariable(2), + ], + named_expressions: { + 6: "a", + }, body: [ + Emit(( + start: 2, + end: 3, + )), + Store( + pointer: 4, + value: 3, + ), + Emit(( + start: 5, + end: 6, + )), + Store( + pointer: 7, + value: 6, + ), Return( value: None, ), diff --git a/naga/tests/out/msl/overrides.msl b/naga/tests/out/msl/overrides.msl index 13a3b623a0..0bc9e6b12c 100644 --- a/naga/tests/out/msl/overrides.msl +++ b/naga/tests/out/msl/overrides.msl @@ -14,5 +14,9 @@ constant float inferred_f32_ = 2.718; kernel void main_( ) { + float t = {}; + bool x = {}; + t = 23.0; + x = true; return; } diff --git a/naga/tests/out/spv/overrides.main.spvasm b/naga/tests/out/spv/overrides.main.spvasm index 7731edfb93..d421606ca9 100644 --- a/naga/tests/out/spv/overrides.main.spvasm +++ b/naga/tests/out/spv/overrides.main.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.0 ; Generator: rspirv -; Bound: 17 +; Bound: 24 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -19,9 +19,18 @@ OpExecutionMode %14 LocalSize 1 1 1 %11 = OpConstant %4 4.6 %12 = OpConstant %4 2.718 %15 = OpTypeFunction %2 +%16 = OpConstant %4 23.0 +%18 = OpTypePointer Function %4 +%19 = OpConstantNull %4 +%21 = OpTypePointer Function %3 +%22 = OpConstantNull %3 %14 = OpFunction %2 None %15 %13 = OpLabel -OpBranch %16 -%16 = OpLabel +%17 = OpVariable %18 Function %19 +%20 = OpVariable %21 Function %22 +OpBranch %23 +%23 = OpLabel +OpStore %17 %16 +OpStore %20 %5 OpReturn OpFunctionEnd \ No newline at end of file From 4d73a819379462691aadc561463f8c76386b6c32 Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Wed, 6 Mar 2024 12:23:16 +0100 Subject: [PATCH 03/12] allow private variables to have an override-expression initializer --- naga/src/front/wgsl/lower/mod.rs | 2 +- naga/src/valid/interface.rs | 4 +- naga/tests/in/overrides.wgsl | 4 ++ naga/tests/out/analysis/overrides.info.ron | 70 +++++++++++++++++++++- naga/tests/out/hlsl/overrides.hlsl | 5 ++ naga/tests/out/ir/overrides.compact.ron | 45 +++++++++++++- naga/tests/out/ir/overrides.ron | 45 +++++++++++++- naga/tests/out/msl/overrides.msl | 4 ++ naga/tests/out/spv/overrides.main.spvasm | 43 +++++++------ 9 files changed, 199 insertions(+), 23 deletions(-) diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 1a8b75811b..7abd95114d 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -916,7 +916,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let init; if let Some(init_ast) = v.init { - let mut ectx = ctx.as_const(); + let mut ectx = ctx.as_override(); let lowered = self.expression_for_abstract(init_ast, &mut ectx)?; let ty_res = crate::proc::TypeResolution::Handle(ty); let converted = ectx diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index 0e42075de1..2435b34c29 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -31,7 +31,7 @@ pub enum GlobalVariableError { Handle, #[source] Disalignment, ), - #[error("Initializer must be a const-expression")] + #[error("Initializer must be an override-expression")] InitializerExprType, #[error("Initializer doesn't match the variable type")] InitializerType, @@ -529,7 +529,7 @@ impl super::Validator { } } - if !global_expr_kind.is_const(init) { + if !global_expr_kind.is_const_or_override(init) { return Err(GlobalVariableError::InitializerExprType); } diff --git a/naga/tests/in/overrides.wgsl b/naga/tests/in/overrides.wgsl index b06edecdb9..ab1d637a11 100644 --- a/naga/tests/in/overrides.wgsl +++ b/naga/tests/in/overrides.wgsl @@ -13,9 +13,13 @@ override inferred_f32 = 2.718; +var gain_x_10: f32 = gain * 10.; + @compute @workgroup_size(1) fn main() { var t = height * 5; let a = !has_point_light; var x = a; + + var gain_x_100 = gain_x_10 * 10.; } \ No newline at end of file diff --git a/naga/tests/out/analysis/overrides.info.ron b/naga/tests/out/analysis/overrides.info.ron index 389e7fba7f..6ea54bb296 100644 --- a/naga/tests/out/analysis/overrides.info.ron +++ b/naga/tests/out/analysis/overrides.info.ron @@ -14,7 +14,9 @@ ), may_kill: false, sampling_set: [], - global_uses: [], + global_uses: [ + ("READ"), + ], expressions: [ ( uniformity: ( @@ -91,6 +93,63 @@ space: Function, )), ), + ( + uniformity: ( + non_uniform_result: Some(8), + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(1), + ty: Value(Pointer( + base: 2, + space: Private, + )), + ), + ( + uniformity: ( + non_uniform_result: Some(8), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(2), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: Some(8), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: Some(12), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 2, + space: Function, + )), + ), ], sampling: [], dual_source_blending: false, @@ -119,5 +178,14 @@ kind: Float, width: 4, ))), + Handle(2), + Value(Scalar(( + kind: Float, + width: 4, + ))), + Value(Scalar(( + kind: Float, + width: 4, + ))), ], ) \ No newline at end of file diff --git a/naga/tests/out/hlsl/overrides.hlsl b/naga/tests/out/hlsl/overrides.hlsl index 1541ae7281..072cd9ffcc 100644 --- a/naga/tests/out/hlsl/overrides.hlsl +++ b/naga/tests/out/hlsl/overrides.hlsl @@ -6,13 +6,18 @@ static const float depth = 2.3; static const float height = 4.6; static const float inferred_f32_ = 2.718; +static float gain_x_10_ = 11.0; + [numthreads(1, 1, 1)] void main() { float t = (float)0; bool x = (bool)0; + float gain_x_100_ = (float)0; t = 23.0; x = true; + float _expr10 = gain_x_10_; + gain_x_100_ = (_expr10 * 10.0); return; } diff --git a/naga/tests/out/ir/overrides.compact.ron b/naga/tests/out/ir/overrides.compact.ron index b0a230a716..4188354224 100644 --- a/naga/tests/out/ir/overrides.compact.ron +++ b/naga/tests/out/ir/overrides.compact.ron @@ -65,7 +65,15 @@ init: Some(7), ), ], - global_variables: [], + global_variables: [ + ( + name: Some("gain_x_10"), + space: Private, + binding: None, + ty: 2, + init: Some(10), + ), + ], global_expressions: [ Literal(Bool(true)), Literal(F32(2.3)), @@ -78,6 +86,13 @@ right: 4, ), Literal(F32(2.718)), + Override(3), + Literal(F32(10.0)), + Binary( + op: Multiply, + left: 8, + right: 9, + ), ], functions: [], entry_points: [ @@ -101,6 +116,11 @@ ty: 1, init: None, ), + ( + name: Some("gain_x_100"), + ty: 2, + init: None, + ), ], expressions: [ Override(6), @@ -117,6 +137,17 @@ expr: 5, ), LocalVariable(2), + GlobalVariable(1), + Load( + pointer: 8, + ), + Literal(F32(10.0)), + Binary( + op: Multiply, + left: 9, + right: 10, + ), + LocalVariable(3), ], named_expressions: { 6: "a", @@ -138,6 +169,18 @@ pointer: 7, value: 6, ), + Emit(( + start: 8, + end: 9, + )), + Emit(( + start: 10, + end: 11, + )), + Store( + pointer: 12, + value: 11, + ), Return( value: None, ), diff --git a/naga/tests/out/ir/overrides.ron b/naga/tests/out/ir/overrides.ron index b0a230a716..4188354224 100644 --- a/naga/tests/out/ir/overrides.ron +++ b/naga/tests/out/ir/overrides.ron @@ -65,7 +65,15 @@ init: Some(7), ), ], - global_variables: [], + global_variables: [ + ( + name: Some("gain_x_10"), + space: Private, + binding: None, + ty: 2, + init: Some(10), + ), + ], global_expressions: [ Literal(Bool(true)), Literal(F32(2.3)), @@ -78,6 +86,13 @@ right: 4, ), Literal(F32(2.718)), + Override(3), + Literal(F32(10.0)), + Binary( + op: Multiply, + left: 8, + right: 9, + ), ], functions: [], entry_points: [ @@ -101,6 +116,11 @@ ty: 1, init: None, ), + ( + name: Some("gain_x_100"), + ty: 2, + init: None, + ), ], expressions: [ Override(6), @@ -117,6 +137,17 @@ expr: 5, ), LocalVariable(2), + GlobalVariable(1), + Load( + pointer: 8, + ), + Literal(F32(10.0)), + Binary( + op: Multiply, + left: 9, + right: 10, + ), + LocalVariable(3), ], named_expressions: { 6: "a", @@ -138,6 +169,18 @@ pointer: 7, value: 6, ), + Emit(( + start: 8, + end: 9, + )), + Emit(( + start: 10, + end: 11, + )), + Store( + pointer: 12, + value: 11, + ), Return( value: None, ), diff --git a/naga/tests/out/msl/overrides.msl b/naga/tests/out/msl/overrides.msl index 0bc9e6b12c..f884d1b527 100644 --- a/naga/tests/out/msl/overrides.msl +++ b/naga/tests/out/msl/overrides.msl @@ -14,9 +14,13 @@ constant float inferred_f32_ = 2.718; kernel void main_( ) { + float gain_x_10_ = 11.0; float t = {}; bool x = {}; + float gain_x_100_ = {}; t = 23.0; x = true; + float _e10 = gain_x_10_; + gain_x_100_ = _e10 * 10.0; return; } diff --git a/naga/tests/out/spv/overrides.main.spvasm b/naga/tests/out/spv/overrides.main.spvasm index d421606ca9..d4ce4752ed 100644 --- a/naga/tests/out/spv/overrides.main.spvasm +++ b/naga/tests/out/spv/overrides.main.spvasm @@ -1,12 +1,12 @@ ; SPIR-V ; Version: 1.0 ; Generator: rspirv -; Bound: 24 +; Bound: 32 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint GLCompute %14 "main" -OpExecutionMode %14 LocalSize 1 1 1 +OpEntryPoint GLCompute %18 "main" +OpExecutionMode %18 LocalSize 1 1 1 %2 = OpTypeVoid %3 = OpTypeBool %4 = OpTypeFloat 32 @@ -18,19 +18,28 @@ OpExecutionMode %14 LocalSize 1 1 1 %10 = OpConstant %4 2.0 %11 = OpConstant %4 4.6 %12 = OpConstant %4 2.718 -%15 = OpTypeFunction %2 -%16 = OpConstant %4 23.0 -%18 = OpTypePointer Function %4 -%19 = OpConstantNull %4 -%21 = OpTypePointer Function %3 -%22 = OpConstantNull %3 -%14 = OpFunction %2 None %15 -%13 = OpLabel -%17 = OpVariable %18 Function %19 -%20 = OpVariable %21 Function %22 -OpBranch %23 -%23 = OpLabel -OpStore %17 %16 -OpStore %20 %5 +%13 = OpConstant %4 10.0 +%14 = OpConstant %4 11.0 +%16 = OpTypePointer Private %4 +%15 = OpVariable %16 Private %14 +%19 = OpTypeFunction %2 +%20 = OpConstant %4 23.0 +%22 = OpTypePointer Function %4 +%23 = OpConstantNull %4 +%25 = OpTypePointer Function %3 +%26 = OpConstantNull %3 +%28 = OpConstantNull %4 +%18 = OpFunction %2 None %19 +%17 = OpLabel +%21 = OpVariable %22 Function %23 +%24 = OpVariable %25 Function %26 +%27 = OpVariable %22 Function %28 +OpBranch %29 +%29 = OpLabel +OpStore %21 %20 +OpStore %24 %5 +%30 = OpLoad %4 %15 +%31 = OpFMul %4 %30 %13 +OpStore %27 %31 OpReturn OpFunctionEnd \ No newline at end of file From 68ce24f6d5fa32e7a23563be15afcd4652d92b42 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Sat, 23 Mar 2024 07:52:14 -0700 Subject: [PATCH 04/12] [naga] Doc tweaks for `back::pipeline_constants`. --- naga/src/back/pipeline_constants.rs | 39 ++++++++++++++++++++++------- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index bd9eec76ee..143afd8a57 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -188,9 +188,9 @@ pub(super) fn process_overrides<'a>( } let _ = mem::replace(&mut module.entry_points, entry_points); - // Now that the global expression arena has changed, we need to - // recompute those expressions' types. For the time being, do a - // full re-validation. + // Now that we've rewritten all the expressions, we need to + // recompute their types and other metadata. For the time being, + // do a full re-validation. let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all()); let module_info = validator.validate_no_overrides(&module)?; @@ -250,8 +250,15 @@ fn process_override( Ok(h) } -/// Replaces all `Expression::Override`s in this function's expression arena -/// with `Expression::Constant` and evaluates all expressions in its arena. +/// Replace all override expressions in `function` with fully-evaluated constants. +/// +/// Replace all `Expression::Override`s in `function`'s expression arena with +/// the corresponding `Expression::Constant`s, as given in `override_map`. +/// Replace any expressions whose values are now known with their fully +/// evaluated form. +/// +/// If `h` is a `Handle`, then `override_map[h.index()]` is the +/// `Handle` for the override's final value. fn process_function( module: &mut Module, override_map: &[Handle], @@ -298,6 +305,8 @@ fn process_function( let new_body = filter_emits_in_block(&function.body, &function.expressions); let _ = mem::replace(&mut function.body, new_body); + // We've changed the keys of `function.named_expression`, so we have to + // rebuild it from scratch. let named_expressions = mem::take(&mut function.named_expressions); for (expr_h, name) in named_expressions { function @@ -595,10 +604,22 @@ fn adjust_stmt(new_pos: &[Handle], stmt: &mut Statement) { } } -/// Filters out expressions that `needs_pre_emit`. This step is necessary after -/// const evaluation since unevaluated expressions could have been included in -/// `Statement::Emit`; but since they have been evaluated we need to filter those -/// out. +/// Adjust [`Emit`] statements in `block` to skip [`needs_pre_emit`] expressions we have introduced. +/// +/// According to validation, [`Emit`] statements must not cover any expressions +/// for which [`Expression::needs_pre_emit`] returns true. All expressions built +/// by successful constant evaluation fall into that category, meaning that +/// `process_function` will usually rewrite [`Override`] expressions and those +/// that use their values into pre-emitted expressions, leaving any [`Emit`] +/// statements that cover them invalid. +/// +/// This function rewrites all [`Emit`] statements into zero or more new +/// [`Emit`] statements covering only those expressions in the original range +/// that are not pre-emitted. +/// +/// [`Emit`]: Statement::Emit +/// [`needs_pre_emit`]: Expression::needs_pre_emit +/// [`Override`]: Expression::Override fn filter_emits_in_block(block: &Block, expressions: &Arena) -> Block { let mut out = Block::with_capacity(block.len()); for (stmt, span) in block.span_iter() { From 8f582661ccbfa5d732d78df288686ebb4c7940c1 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Sat, 23 Mar 2024 07:53:05 -0700 Subject: [PATCH 05/12] [naga] Simplify uses of `replace` in `back::pipeline_constants`. --- naga/src/back/pipeline_constants.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 143afd8a57..0cc5df5732 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -180,13 +180,13 @@ pub(super) fn process_overrides<'a>( for (_, function) in functions.iter_mut() { process_function(&mut module, &override_map, function)?; } - let _ = mem::replace(&mut module.functions, functions); + module.functions = functions; let mut entry_points = mem::take(&mut module.entry_points); for ep in entry_points.iter_mut() { process_function(&mut module, &override_map, &mut ep.function)?; } - let _ = mem::replace(&mut module.entry_points, entry_points); + module.entry_points = entry_points; // Now that we've rewritten all the expressions, we need to // recompute their types and other metadata. For the time being, @@ -303,7 +303,7 @@ fn process_function( adjust_block(&adjusted_local_expressions, &mut function.body); let new_body = filter_emits_in_block(&function.body, &function.expressions); - let _ = mem::replace(&mut function.body, new_body); + function.body = new_body; // We've changed the keys of `function.named_expression`, so we have to // rebuild it from scratch. From bd9ae612d5224aa13a9f5a4d9b94efb373248e83 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Sat, 23 Mar 2024 07:53:39 -0700 Subject: [PATCH 06/12] [naga] Hoist `ConstantEvaluator` construction in `process_function`. There's no need to build a fresh `ConstantEvaluator` for every expression; just build it once and reuse it. --- naga/src/back/pipeline_constants.rs | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 0cc5df5732..b8789d3b93 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -282,18 +282,18 @@ fn process_function( let mut emitter = Emitter::default(); let mut block = Block::new(); - for (old_h, expr, span) in expressions.drain() { - let mut expr = match expr { - Expression::Override(h) => Expression::Constant(override_map[h.index()]), - expr => expr, - }; - let mut evaluator = ConstantEvaluator::for_wgsl_function( - module, - &mut function.expressions, - &mut local_expression_kind_tracker, - &mut emitter, - &mut block, - ); + let mut evaluator = ConstantEvaluator::for_wgsl_function( + module, + &mut function.expressions, + &mut local_expression_kind_tracker, + &mut emitter, + &mut block, + ); + + for (old_h, mut expr, span) in expressions.drain() { + if let Expression::Override(h) = expr { + expr = Expression::Constant(override_map[h.index()]); + } adjust_expr(&adjusted_local_expressions, &mut expr); let h = evaluator.try_eval_and_append(expr, span)?; debug_assert_eq!(old_h.index(), adjusted_local_expressions.len()); From b5cec4fca205fc689a2bd8f731552eee004cbaaa Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Sat, 23 Mar 2024 07:47:49 -0700 Subject: [PATCH 07/12] [naga] Let `filter_emits_with_block` operate on a `&mut Block`. This removes some clones and collects, simplifies call sites, and isn't any more complicated to implement. --- naga/src/back/pipeline_constants.rs | 76 +++++++++++++---------------- naga/src/block.rs | 6 +++ 2 files changed, 39 insertions(+), 43 deletions(-) diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index b8789d3b93..62e3cd0e42 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -3,7 +3,7 @@ use crate::{ proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter}, valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator}, Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar, - Span, Statement, SwitchCase, TypeInner, WithSpan, + Span, Statement, TypeInner, WithSpan, }; use std::{borrow::Cow, collections::HashSet, mem}; use thiserror::Error; @@ -302,8 +302,7 @@ fn process_function( adjust_block(&adjusted_local_expressions, &mut function.body); - let new_body = filter_emits_in_block(&function.body, &function.expressions); - function.body = new_body; + filter_emits_in_block(&mut function.body, &function.expressions); // We've changed the keys of `function.named_expression`, so we have to // rebuild it from scratch. @@ -620,16 +619,16 @@ fn adjust_stmt(new_pos: &[Handle], stmt: &mut Statement) { /// [`Emit`]: Statement::Emit /// [`needs_pre_emit`]: Expression::needs_pre_emit /// [`Override`]: Expression::Override -fn filter_emits_in_block(block: &Block, expressions: &Arena) -> Block { - let mut out = Block::with_capacity(block.len()); - for (stmt, span) in block.span_iter() { +fn filter_emits_in_block(block: &mut Block, expressions: &Arena) { + let original = std::mem::replace(block, Block::with_capacity(block.len())); + for (stmt, span) in original.span_into_iter() { match stmt { - &Statement::Emit(ref range) => { + Statement::Emit(range) => { let mut current = None; - for expr_h in range.clone() { + for expr_h in range { if expressions[expr_h].needs_pre_emit() { if let Some((first, last)) = current { - out.push(Statement::Emit(Range::new_from_bounds(first, last)), *span); + block.push(Statement::Emit(Range::new_from_bounds(first, last)), span); } current = None; @@ -640,66 +639,57 @@ fn filter_emits_in_block(block: &Block, expressions: &Arena) -> Bloc } } if let Some((first, last)) = current { - out.push(Statement::Emit(Range::new_from_bounds(first, last)), *span); + block.push(Statement::Emit(Range::new_from_bounds(first, last)), span); } } - &Statement::Block(ref block) => { - let block = filter_emits_in_block(block, expressions); - out.push(Statement::Block(block), *span); + Statement::Block(mut child) => { + filter_emits_in_block(&mut child, expressions); + block.push(Statement::Block(child), span); } - &Statement::If { + Statement::If { condition, - ref accept, - ref reject, + mut accept, + mut reject, } => { - let accept = filter_emits_in_block(accept, expressions); - let reject = filter_emits_in_block(reject, expressions); - out.push( + filter_emits_in_block(&mut accept, expressions); + filter_emits_in_block(&mut reject, expressions); + block.push( Statement::If { condition, accept, reject, }, - *span, + span, ); } - &Statement::Switch { + Statement::Switch { selector, - ref cases, + mut cases, } => { - let cases = cases - .iter() - .map(|case| { - let body = filter_emits_in_block(&case.body, expressions); - SwitchCase { - value: case.value, - body, - fall_through: case.fall_through, - } - }) - .collect(); - out.push(Statement::Switch { selector, cases }, *span); + for case in &mut cases { + filter_emits_in_block(&mut case.body, expressions); + } + block.push(Statement::Switch { selector, cases }, span); } - &Statement::Loop { - ref body, - ref continuing, + Statement::Loop { + mut body, + mut continuing, break_if, } => { - let body = filter_emits_in_block(body, expressions); - let continuing = filter_emits_in_block(continuing, expressions); - out.push( + filter_emits_in_block(&mut body, expressions); + filter_emits_in_block(&mut continuing, expressions); + block.push( Statement::Loop { body, continuing, break_if, }, - *span, + span, ); } - stmt => out.push(stmt.clone(), *span), + stmt => block.push(stmt.clone(), span), } } - out } fn map_value_to_literal(value: f64, scalar: Scalar) -> Result { diff --git a/naga/src/block.rs b/naga/src/block.rs index 0abda9da7c..2e86a928f1 100644 --- a/naga/src/block.rs +++ b/naga/src/block.rs @@ -65,6 +65,12 @@ impl Block { self.span_info.splice(range.clone(), other.span_info); self.body.splice(range, other.body); } + + pub fn span_into_iter(self) -> impl Iterator { + let Block { body, span_info } = self; + body.into_iter().zip(span_info) + } + pub fn span_iter(&self) -> impl Iterator { let span_iter = self.span_info.iter(); self.body.iter().zip(span_iter) From 7f8a56acca9e3a8e19cc151ed6e42386b572d954 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Sun, 24 Mar 2024 13:47:39 -0700 Subject: [PATCH 08/12] [naga] Tweak comments in `ConstantEvaluator::try_eval_and_append`. I found I needed a little bit more detail here. --- naga/src/proc/constant_evaluator.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index f1f01e5855..547fbbc652 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -713,7 +713,10 @@ impl<'a> ConstantEvaluator<'a> { match self.expression_kind_tracker.type_of_with_expr(&expr) { ExpressionKind::Const => { let eval_result = self.try_eval_and_append_impl(&expr, span); - // avoid errors on unimplemented functionality if possible + // We should be able to evaluate `Const` expressions at this + // point. If we failed to, then that probably means we just + // haven't implemented that part of constant evaluation. Work + // around this by simply emitting it as a run-time expression. if self.behavior.has_runtime_restrictions() && matches!( eval_result, From 96be56fa4f6b43cf008bed0efea01079edfd1b78 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Mon, 25 Mar 2024 13:31:00 -0700 Subject: [PATCH 09/12] [naga] Add missing newline to test input file. --- naga/tests/in/overrides.wgsl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/naga/tests/in/overrides.wgsl b/naga/tests/in/overrides.wgsl index ab1d637a11..6173c3463f 100644 --- a/naga/tests/in/overrides.wgsl +++ b/naga/tests/in/overrides.wgsl @@ -22,4 +22,4 @@ fn main() { var x = a; var gain_x_100 = gain_x_10 * 10.; -} \ No newline at end of file +} From ab08e05f9c33f0427b3dedc36da24686836d6e65 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Mon, 25 Mar 2024 18:19:17 -0700 Subject: [PATCH 10/12] [naga] Handle comparison operands in pipeline constant evaluation. Properly adjust `AtomicFunction::Exchange::compare` after pipeline constant evaluation. --- naga/src/back/pipeline_constants.rs | 17 ++- ...rrides-atomicCompareExchangeWeak.param.ron | 9 ++ .../overrides-atomicCompareExchangeWeak.wgsl | 7 + ...ides-atomicCompareExchangeWeak.compact.ron | 128 ++++++++++++++++++ .../overrides-atomicCompareExchangeWeak.ron | 128 ++++++++++++++++++ ...errides-atomicCompareExchangeWeak.f.spvasm | 52 +++++++ naga/tests/snapshots.rs | 4 + 7 files changed, 344 insertions(+), 1 deletion(-) create mode 100644 naga/tests/in/overrides-atomicCompareExchangeWeak.param.ron create mode 100644 naga/tests/in/overrides-atomicCompareExchangeWeak.wgsl create mode 100644 naga/tests/out/ir/overrides-atomicCompareExchangeWeak.compact.ron create mode 100644 naga/tests/out/ir/overrides-atomicCompareExchangeWeak.ron create mode 100644 naga/tests/out/spv/overrides-atomicCompareExchangeWeak.f.spvasm diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 62e3cd0e42..a7606f5bb7 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -571,11 +571,26 @@ fn adjust_stmt(new_pos: &[Handle], stmt: &mut Statement) { ref mut pointer, ref mut value, ref mut result, - .. + ref mut fun, } => { adjust(pointer); adjust(value); adjust(result); + match *fun { + crate::AtomicFunction::Exchange { + compare: Some(ref mut compare), + } => { + adjust(compare); + } + crate::AtomicFunction::Add + | crate::AtomicFunction::Subtract + | crate::AtomicFunction::And + | crate::AtomicFunction::ExclusiveOr + | crate::AtomicFunction::InclusiveOr + | crate::AtomicFunction::Min + | crate::AtomicFunction::Max + | crate::AtomicFunction::Exchange { compare: None } => {} + } } Statement::WorkGroupUniformLoad { ref mut pointer, diff --git a/naga/tests/in/overrides-atomicCompareExchangeWeak.param.ron b/naga/tests/in/overrides-atomicCompareExchangeWeak.param.ron new file mode 100644 index 0000000000..ff9c84ac61 --- /dev/null +++ b/naga/tests/in/overrides-atomicCompareExchangeWeak.param.ron @@ -0,0 +1,9 @@ +( + spv: ( + version: (1, 0), + separate_entry_points: true, + ), + pipeline_constants: { + "o": 2.0 + } +) diff --git a/naga/tests/in/overrides-atomicCompareExchangeWeak.wgsl b/naga/tests/in/overrides-atomicCompareExchangeWeak.wgsl new file mode 100644 index 0000000000..03376b5931 --- /dev/null +++ b/naga/tests/in/overrides-atomicCompareExchangeWeak.wgsl @@ -0,0 +1,7 @@ +override o: i32; +var a: atomic; + +@compute @workgroup_size(1) +fn f() { + atomicCompareExchangeWeak(&a, u32(o), 1u); +} diff --git a/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.compact.ron b/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.compact.ron new file mode 100644 index 0000000000..8c889382dd --- /dev/null +++ b/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.compact.ron @@ -0,0 +1,128 @@ +( + types: [ + ( + name: None, + inner: Scalar(( + kind: Sint, + width: 4, + )), + ), + ( + name: None, + inner: Atomic(( + kind: Uint, + width: 4, + )), + ), + ( + name: None, + inner: Scalar(( + kind: Uint, + width: 4, + )), + ), + ( + name: None, + inner: Scalar(( + kind: Bool, + width: 1, + )), + ), + ( + name: Some("__atomic_compare_exchange_result"), + inner: Struct( + members: [ + ( + name: Some("old_value"), + ty: 3, + binding: None, + offset: 0, + ), + ( + name: Some("exchanged"), + ty: 4, + binding: None, + offset: 4, + ), + ], + span: 8, + ), + ), + ], + special_types: ( + ray_desc: None, + ray_intersection: None, + predeclared_types: { + AtomicCompareExchangeWeakResult(( + kind: Uint, + width: 4, + )): 5, + }, + ), + constants: [], + overrides: [ + ( + name: Some("o"), + id: None, + ty: 1, + init: None, + ), + ], + global_variables: [ + ( + name: Some("a"), + space: WorkGroup, + binding: None, + ty: 2, + init: None, + ), + ], + global_expressions: [], + functions: [], + entry_points: [ + ( + name: "f", + stage: Compute, + early_depth_test: None, + workgroup_size: (1, 1, 1), + function: ( + name: Some("f"), + arguments: [], + result: None, + local_variables: [], + expressions: [ + GlobalVariable(1), + Override(1), + As( + expr: 2, + kind: Uint, + convert: Some(4), + ), + Literal(U32(1)), + AtomicResult( + ty: 5, + comparison: true, + ), + ], + named_expressions: {}, + body: [ + Emit(( + start: 2, + end: 3, + )), + Atomic( + pointer: 1, + fun: Exchange( + compare: Some(3), + ), + value: 4, + result: 5, + ), + Return( + value: None, + ), + ], + ), + ), + ], +) \ No newline at end of file diff --git a/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.ron b/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.ron new file mode 100644 index 0000000000..8c889382dd --- /dev/null +++ b/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.ron @@ -0,0 +1,128 @@ +( + types: [ + ( + name: None, + inner: Scalar(( + kind: Sint, + width: 4, + )), + ), + ( + name: None, + inner: Atomic(( + kind: Uint, + width: 4, + )), + ), + ( + name: None, + inner: Scalar(( + kind: Uint, + width: 4, + )), + ), + ( + name: None, + inner: Scalar(( + kind: Bool, + width: 1, + )), + ), + ( + name: Some("__atomic_compare_exchange_result"), + inner: Struct( + members: [ + ( + name: Some("old_value"), + ty: 3, + binding: None, + offset: 0, + ), + ( + name: Some("exchanged"), + ty: 4, + binding: None, + offset: 4, + ), + ], + span: 8, + ), + ), + ], + special_types: ( + ray_desc: None, + ray_intersection: None, + predeclared_types: { + AtomicCompareExchangeWeakResult(( + kind: Uint, + width: 4, + )): 5, + }, + ), + constants: [], + overrides: [ + ( + name: Some("o"), + id: None, + ty: 1, + init: None, + ), + ], + global_variables: [ + ( + name: Some("a"), + space: WorkGroup, + binding: None, + ty: 2, + init: None, + ), + ], + global_expressions: [], + functions: [], + entry_points: [ + ( + name: "f", + stage: Compute, + early_depth_test: None, + workgroup_size: (1, 1, 1), + function: ( + name: Some("f"), + arguments: [], + result: None, + local_variables: [], + expressions: [ + GlobalVariable(1), + Override(1), + As( + expr: 2, + kind: Uint, + convert: Some(4), + ), + Literal(U32(1)), + AtomicResult( + ty: 5, + comparison: true, + ), + ], + named_expressions: {}, + body: [ + Emit(( + start: 2, + end: 3, + )), + Atomic( + pointer: 1, + fun: Exchange( + compare: Some(3), + ), + value: 4, + result: 5, + ), + Return( + value: None, + ), + ], + ), + ), + ], +) \ No newline at end of file diff --git a/naga/tests/out/spv/overrides-atomicCompareExchangeWeak.f.spvasm b/naga/tests/out/spv/overrides-atomicCompareExchangeWeak.f.spvasm new file mode 100644 index 0000000000..59c69ae1fc --- /dev/null +++ b/naga/tests/out/spv/overrides-atomicCompareExchangeWeak.f.spvasm @@ -0,0 +1,52 @@ +; SPIR-V +; Version: 1.0 +; Generator: rspirv +; Bound: 33 +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %11 "f" %18 +OpExecutionMode %11 LocalSize 1 1 1 +OpMemberDecorate %6 0 Offset 0 +OpMemberDecorate %6 1 Offset 4 +OpDecorate %18 BuiltIn LocalInvocationId +%2 = OpTypeVoid +%3 = OpTypeInt 32 1 +%4 = OpTypeInt 32 0 +%5 = OpTypeBool +%6 = OpTypeStruct %4 %5 +%7 = OpConstant %3 2 +%9 = OpTypePointer Workgroup %4 +%8 = OpVariable %9 Workgroup +%12 = OpTypeFunction %2 +%13 = OpConstant %4 2 +%14 = OpConstant %4 1 +%16 = OpConstantNull %4 +%17 = OpTypeVector %4 3 +%19 = OpTypePointer Input %17 +%18 = OpVariable %19 Input +%21 = OpConstantNull %17 +%22 = OpTypeVector %5 3 +%27 = OpConstant %4 264 +%30 = OpConstant %4 256 +%11 = OpFunction %2 None %12 +%10 = OpLabel +OpBranch %15 +%15 = OpLabel +%20 = OpLoad %17 %18 +%23 = OpIEqual %22 %20 %21 +%24 = OpAll %5 %23 +OpSelectionMerge %25 None +OpBranchConditional %24 %26 %25 +%26 = OpLabel +OpStore %8 %16 +OpBranch %25 +%25 = OpLabel +OpControlBarrier %13 %13 %27 +OpBranch %28 +%28 = OpLabel +%31 = OpAtomicCompareExchange %4 %8 %7 %30 %30 %14 %13 +%32 = OpIEqual %5 %31 %13 +%29 = OpCompositeConstruct %6 %31 %32 +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/snapshots.rs b/naga/tests/snapshots.rs index e2f6dff25f..151e8b3da3 100644 --- a/naga/tests/snapshots.rs +++ b/naga/tests/snapshots.rs @@ -853,6 +853,10 @@ fn convert_wgsl() { "overrides", Targets::IR | Targets::ANALYSIS | Targets::SPIRV | Targets::METAL | Targets::HLSL, ), + ( + "overrides-atomicCompareExchangeWeak", + Targets::IR | Targets::SPIRV, + ), ]; for &(name, targets) in inputs.iter() { From 86b056ddf3aa0c97f81df8be92020a983f0f4a5b Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Mon, 25 Mar 2024 18:29:11 -0700 Subject: [PATCH 11/12] [naga] Spell out members in adjust_expr. --- naga/src/back/pipeline_constants.rs | 61 ++++++++++++++++++++++------- 1 file changed, 46 insertions(+), 15 deletions(-) diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index a7606f5bb7..d41eeedef2 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -324,7 +324,8 @@ fn adjust_expr(new_pos: &[Handle], expr: &mut Expression) { }; match *expr { Expression::Compose { - ref mut components, .. + ref mut components, + ty: _, } => { for c in components.iter_mut() { adjust(c); @@ -337,13 +338,23 @@ fn adjust_expr(new_pos: &[Handle], expr: &mut Expression) { adjust(base); adjust(index); } - Expression::AccessIndex { ref mut base, .. } => { + Expression::AccessIndex { + ref mut base, + index: _, + } => { adjust(base); } - Expression::Splat { ref mut value, .. } => { + Expression::Splat { + ref mut value, + size: _, + } => { adjust(value); } - Expression::Swizzle { ref mut vector, .. } => { + Expression::Swizzle { + ref mut vector, + size: _, + pattern: _, + } => { adjust(vector); } Expression::Load { ref mut pointer } => { @@ -357,7 +368,7 @@ fn adjust_expr(new_pos: &[Handle], expr: &mut Expression) { ref mut offset, ref mut level, ref mut depth_ref, - .. + gather: _, } => { adjust(image); adjust(sampler); @@ -416,16 +427,21 @@ fn adjust_expr(new_pos: &[Handle], expr: &mut Expression) { adjust(e); } } - _ => {} + crate::ImageQuery::NumLevels + | crate::ImageQuery::NumLayers + | crate::ImageQuery::NumSamples => {} } } - Expression::Unary { ref mut expr, .. } => { + Expression::Unary { + ref mut expr, + op: _, + } => { adjust(expr); } Expression::Binary { ref mut left, ref mut right, - .. + op: _, } => { adjust(left); adjust(right); @@ -439,11 +455,16 @@ fn adjust_expr(new_pos: &[Handle], expr: &mut Expression) { adjust(accept); adjust(reject); } - Expression::Derivative { ref mut expr, .. } => { + Expression::Derivative { + ref mut expr, + axis: _, + ctrl: _, + } => { adjust(expr); } Expression::Relational { - ref mut argument, .. + ref mut argument, + fun: _, } => { adjust(argument); } @@ -452,7 +473,7 @@ fn adjust_expr(new_pos: &[Handle], expr: &mut Expression) { ref mut arg1, ref mut arg2, ref mut arg3, - .. + fun: _, } => { adjust(arg); if let Some(e) = arg1.as_mut() { @@ -465,13 +486,20 @@ fn adjust_expr(new_pos: &[Handle], expr: &mut Expression) { adjust(e); } } - Expression::As { ref mut expr, .. } => { + Expression::As { + ref mut expr, + kind: _, + convert: _, + } => { adjust(expr); } Expression::ArrayLength(ref mut expr) => { adjust(expr); } - Expression::RayQueryGetIntersection { ref mut query, .. } => { + Expression::RayQueryGetIntersection { + ref mut query, + committed: _, + } => { adjust(query); } Expression::Literal(_) @@ -483,8 +511,11 @@ fn adjust_expr(new_pos: &[Handle], expr: &mut Expression) { | Expression::Constant(_) | Expression::Override(_) | Expression::ZeroValue(_) - | Expression::AtomicResult { .. } - | Expression::WorkGroupUniformLoadResult { .. } => {} + | Expression::AtomicResult { + ty: _, + comparison: _, + } + | Expression::WorkGroupUniformLoadResult { ty: _ } => {} } } From 28dc70ed1cb6ab4f80648e783242d95f90fab75d Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Mon, 25 Mar 2024 19:11:23 -0700 Subject: [PATCH 12/12] [naga] Adjust RayQuery statements in override processing. --- naga/src/back/pipeline_constants.rs | 20 +- naga/tests/in/overrides-ray-query.param.ron | 18 ++ naga/tests/in/overrides-ray-query.wgsl | 21 ++ .../out/ir/overrides-ray-query.compact.ron | 259 ++++++++++++++++++ naga/tests/out/ir/overrides-ray-query.ron | 259 ++++++++++++++++++ naga/tests/out/msl/overrides-ray-query.msl | 45 +++ .../out/spv/overrides-ray-query.main.spvasm | 77 ++++++ naga/tests/snapshots.rs | 5 + 8 files changed, 702 insertions(+), 2 deletions(-) create mode 100644 naga/tests/in/overrides-ray-query.param.ron create mode 100644 naga/tests/in/overrides-ray-query.wgsl create mode 100644 naga/tests/out/ir/overrides-ray-query.compact.ron create mode 100644 naga/tests/out/ir/overrides-ray-query.ron create mode 100644 naga/tests/out/msl/overrides-ray-query.msl create mode 100644 naga/tests/out/spv/overrides-ray-query.main.spvasm diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index d41eeedef2..c1fd2d02cc 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -633,7 +633,7 @@ fn adjust_stmt(new_pos: &[Handle], stmt: &mut Statement) { Statement::Call { ref mut arguments, ref mut result, - .. + function: _, } => { for argument in arguments.iter_mut() { adjust(argument); @@ -642,8 +642,24 @@ fn adjust_stmt(new_pos: &[Handle], stmt: &mut Statement) { adjust(e); } } - Statement::RayQuery { ref mut query, .. } => { + Statement::RayQuery { + ref mut query, + ref mut fun, + } => { adjust(query); + match *fun { + crate::RayQueryFunction::Initialize { + ref mut acceleration_structure, + ref mut descriptor, + } => { + adjust(acceleration_structure); + adjust(descriptor); + } + crate::RayQueryFunction::Proceed { ref mut result } => { + adjust(result); + } + crate::RayQueryFunction::Terminate => {} + } } Statement::Break | Statement::Continue | Statement::Kill | Statement::Barrier(_) => {} } diff --git a/naga/tests/in/overrides-ray-query.param.ron b/naga/tests/in/overrides-ray-query.param.ron new file mode 100644 index 0000000000..588656aaac --- /dev/null +++ b/naga/tests/in/overrides-ray-query.param.ron @@ -0,0 +1,18 @@ +( + god_mode: true, + spv: ( + version: (1, 4), + separate_entry_points: true, + ), + msl: ( + lang_version: (2, 4), + spirv_cross_compatibility: false, + fake_missing_bindings: true, + zero_initialize_workgroup_memory: false, + per_entry_point_map: {}, + inline_samplers: [], + ), + pipeline_constants: { + "o": 2.0 + } +) diff --git a/naga/tests/in/overrides-ray-query.wgsl b/naga/tests/in/overrides-ray-query.wgsl new file mode 100644 index 0000000000..dca7447ed0 --- /dev/null +++ b/naga/tests/in/overrides-ray-query.wgsl @@ -0,0 +1,21 @@ +override o: f32; + +@group(0) @binding(0) +var acc_struct: acceleration_structure; + +@compute @workgroup_size(1) +fn main() { + var rq: ray_query; + + let desc = RayDesc( + RAY_FLAG_TERMINATE_ON_FIRST_HIT, + 0xFFu, + o * 17.0, + o * 19.0, + vec3(o * 23.0), + vec3(o * 29.0, o * 31.0, o * 37.0), + ); + rayQueryInitialize(&rq, acc_struct, desc); + + while (rayQueryProceed(&rq)) {} +} diff --git a/naga/tests/out/ir/overrides-ray-query.compact.ron b/naga/tests/out/ir/overrides-ray-query.compact.ron new file mode 100644 index 0000000000..b127259bbb --- /dev/null +++ b/naga/tests/out/ir/overrides-ray-query.compact.ron @@ -0,0 +1,259 @@ +( + types: [ + ( + name: None, + inner: Scalar(( + kind: Float, + width: 4, + )), + ), + ( + name: None, + inner: AccelerationStructure, + ), + ( + name: None, + inner: RayQuery, + ), + ( + name: None, + inner: Scalar(( + kind: Uint, + width: 4, + )), + ), + ( + name: None, + inner: Vector( + size: Tri, + scalar: ( + kind: Float, + width: 4, + ), + ), + ), + ( + name: Some("RayDesc"), + inner: Struct( + members: [ + ( + name: Some("flags"), + ty: 4, + binding: None, + offset: 0, + ), + ( + name: Some("cull_mask"), + ty: 4, + binding: None, + offset: 4, + ), + ( + name: Some("tmin"), + ty: 1, + binding: None, + offset: 8, + ), + ( + name: Some("tmax"), + ty: 1, + binding: None, + offset: 12, + ), + ( + name: Some("origin"), + ty: 5, + binding: None, + offset: 16, + ), + ( + name: Some("dir"), + ty: 5, + binding: None, + offset: 32, + ), + ], + span: 48, + ), + ), + ], + special_types: ( + ray_desc: Some(6), + ray_intersection: None, + predeclared_types: {}, + ), + constants: [], + overrides: [ + ( + name: Some("o"), + id: None, + ty: 1, + init: None, + ), + ], + global_variables: [ + ( + name: Some("acc_struct"), + space: Handle, + binding: Some(( + group: 0, + binding: 0, + )), + ty: 2, + init: None, + ), + ], + global_expressions: [], + functions: [], + entry_points: [ + ( + name: "main", + stage: Compute, + early_depth_test: None, + workgroup_size: (1, 1, 1), + function: ( + name: Some("main"), + arguments: [], + result: None, + local_variables: [ + ( + name: Some("rq"), + ty: 3, + init: None, + ), + ], + expressions: [ + LocalVariable(1), + Literal(U32(4)), + Literal(U32(255)), + Override(1), + Literal(F32(17.0)), + Binary( + op: Multiply, + left: 4, + right: 5, + ), + Override(1), + Literal(F32(19.0)), + Binary( + op: Multiply, + left: 7, + right: 8, + ), + Override(1), + Literal(F32(23.0)), + Binary( + op: Multiply, + left: 10, + right: 11, + ), + Splat( + size: Tri, + value: 12, + ), + Override(1), + Literal(F32(29.0)), + Binary( + op: Multiply, + left: 14, + right: 15, + ), + Override(1), + Literal(F32(31.0)), + Binary( + op: Multiply, + left: 17, + right: 18, + ), + Override(1), + Literal(F32(37.0)), + Binary( + op: Multiply, + left: 20, + right: 21, + ), + Compose( + ty: 5, + components: [ + 16, + 19, + 22, + ], + ), + Compose( + ty: 6, + components: [ + 2, + 3, + 6, + 9, + 13, + 23, + ], + ), + GlobalVariable(1), + RayQueryProceedResult, + ], + named_expressions: { + 24: "desc", + }, + body: [ + Emit(( + start: 5, + end: 6, + )), + Emit(( + start: 8, + end: 9, + )), + Emit(( + start: 11, + end: 13, + )), + Emit(( + start: 15, + end: 16, + )), + Emit(( + start: 18, + end: 19, + )), + Emit(( + start: 21, + end: 24, + )), + RayQuery( + query: 1, + fun: Initialize( + acceleration_structure: 25, + descriptor: 24, + ), + ), + Loop( + body: [ + RayQuery( + query: 1, + fun: Proceed( + result: 26, + ), + ), + If( + condition: 26, + accept: [], + reject: [ + Break, + ], + ), + Block([]), + ], + continuing: [], + break_if: None, + ), + Return( + value: None, + ), + ], + ), + ), + ], +) \ No newline at end of file diff --git a/naga/tests/out/ir/overrides-ray-query.ron b/naga/tests/out/ir/overrides-ray-query.ron new file mode 100644 index 0000000000..b127259bbb --- /dev/null +++ b/naga/tests/out/ir/overrides-ray-query.ron @@ -0,0 +1,259 @@ +( + types: [ + ( + name: None, + inner: Scalar(( + kind: Float, + width: 4, + )), + ), + ( + name: None, + inner: AccelerationStructure, + ), + ( + name: None, + inner: RayQuery, + ), + ( + name: None, + inner: Scalar(( + kind: Uint, + width: 4, + )), + ), + ( + name: None, + inner: Vector( + size: Tri, + scalar: ( + kind: Float, + width: 4, + ), + ), + ), + ( + name: Some("RayDesc"), + inner: Struct( + members: [ + ( + name: Some("flags"), + ty: 4, + binding: None, + offset: 0, + ), + ( + name: Some("cull_mask"), + ty: 4, + binding: None, + offset: 4, + ), + ( + name: Some("tmin"), + ty: 1, + binding: None, + offset: 8, + ), + ( + name: Some("tmax"), + ty: 1, + binding: None, + offset: 12, + ), + ( + name: Some("origin"), + ty: 5, + binding: None, + offset: 16, + ), + ( + name: Some("dir"), + ty: 5, + binding: None, + offset: 32, + ), + ], + span: 48, + ), + ), + ], + special_types: ( + ray_desc: Some(6), + ray_intersection: None, + predeclared_types: {}, + ), + constants: [], + overrides: [ + ( + name: Some("o"), + id: None, + ty: 1, + init: None, + ), + ], + global_variables: [ + ( + name: Some("acc_struct"), + space: Handle, + binding: Some(( + group: 0, + binding: 0, + )), + ty: 2, + init: None, + ), + ], + global_expressions: [], + functions: [], + entry_points: [ + ( + name: "main", + stage: Compute, + early_depth_test: None, + workgroup_size: (1, 1, 1), + function: ( + name: Some("main"), + arguments: [], + result: None, + local_variables: [ + ( + name: Some("rq"), + ty: 3, + init: None, + ), + ], + expressions: [ + LocalVariable(1), + Literal(U32(4)), + Literal(U32(255)), + Override(1), + Literal(F32(17.0)), + Binary( + op: Multiply, + left: 4, + right: 5, + ), + Override(1), + Literal(F32(19.0)), + Binary( + op: Multiply, + left: 7, + right: 8, + ), + Override(1), + Literal(F32(23.0)), + Binary( + op: Multiply, + left: 10, + right: 11, + ), + Splat( + size: Tri, + value: 12, + ), + Override(1), + Literal(F32(29.0)), + Binary( + op: Multiply, + left: 14, + right: 15, + ), + Override(1), + Literal(F32(31.0)), + Binary( + op: Multiply, + left: 17, + right: 18, + ), + Override(1), + Literal(F32(37.0)), + Binary( + op: Multiply, + left: 20, + right: 21, + ), + Compose( + ty: 5, + components: [ + 16, + 19, + 22, + ], + ), + Compose( + ty: 6, + components: [ + 2, + 3, + 6, + 9, + 13, + 23, + ], + ), + GlobalVariable(1), + RayQueryProceedResult, + ], + named_expressions: { + 24: "desc", + }, + body: [ + Emit(( + start: 5, + end: 6, + )), + Emit(( + start: 8, + end: 9, + )), + Emit(( + start: 11, + end: 13, + )), + Emit(( + start: 15, + end: 16, + )), + Emit(( + start: 18, + end: 19, + )), + Emit(( + start: 21, + end: 24, + )), + RayQuery( + query: 1, + fun: Initialize( + acceleration_structure: 25, + descriptor: 24, + ), + ), + Loop( + body: [ + RayQuery( + query: 1, + fun: Proceed( + result: 26, + ), + ), + If( + condition: 26, + accept: [], + reject: [ + Break, + ], + ), + Block([]), + ], + continuing: [], + break_if: None, + ), + Return( + value: None, + ), + ], + ), + ), + ], +) \ No newline at end of file diff --git a/naga/tests/out/msl/overrides-ray-query.msl b/naga/tests/out/msl/overrides-ray-query.msl new file mode 100644 index 0000000000..3a508b6f61 --- /dev/null +++ b/naga/tests/out/msl/overrides-ray-query.msl @@ -0,0 +1,45 @@ +// language: metal2.4 +#include +#include + +using metal::uint; +struct _RayQuery { + metal::raytracing::intersector intersector; + metal::raytracing::intersector::result_type intersection; + bool ready = false; +}; +constexpr metal::uint _map_intersection_type(const metal::raytracing::intersection_type ty) { + return ty==metal::raytracing::intersection_type::triangle ? 1 : + ty==metal::raytracing::intersection_type::bounding_box ? 4 : 0; +} + +struct RayDesc { + uint flags; + uint cull_mask; + float tmin; + float tmax; + metal::float3 origin; + metal::float3 dir; +}; +constant float o = 2.0; + +kernel void main_( + metal::raytracing::instance_acceleration_structure acc_struct [[user(fake0)]] +) { + _RayQuery rq = {}; + RayDesc desc = RayDesc {4u, 255u, 34.0, 38.0, metal::float3(46.0), metal::float3(58.0, 62.0, 74.0)}; + rq.intersector.assume_geometry_type(metal::raytracing::geometry_type::triangle); + rq.intersector.set_opacity_cull_mode((desc.flags & 64) != 0 ? metal::raytracing::opacity_cull_mode::opaque : (desc.flags & 128) != 0 ? metal::raytracing::opacity_cull_mode::non_opaque : metal::raytracing::opacity_cull_mode::none); + rq.intersector.force_opacity((desc.flags & 1) != 0 ? metal::raytracing::forced_opacity::opaque : (desc.flags & 2) != 0 ? metal::raytracing::forced_opacity::non_opaque : metal::raytracing::forced_opacity::none); + rq.intersector.accept_any_intersection((desc.flags & 4) != 0); + rq.intersection = rq.intersector.intersect(metal::raytracing::ray(desc.origin, desc.dir, desc.tmin, desc.tmax), acc_struct, desc.cull_mask); rq.ready = true; + while(true) { + bool _e31 = rq.ready; + rq.ready = false; + if (_e31) { + } else { + break; + } + } + return; +} diff --git a/naga/tests/out/spv/overrides-ray-query.main.spvasm b/naga/tests/out/spv/overrides-ray-query.main.spvasm new file mode 100644 index 0000000000..a341393468 --- /dev/null +++ b/naga/tests/out/spv/overrides-ray-query.main.spvasm @@ -0,0 +1,77 @@ +; SPIR-V +; Version: 1.4 +; Generator: rspirv +; Bound: 46 +OpCapability Shader +OpCapability RayQueryKHR +OpExtension "SPV_KHR_ray_query" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %13 "main" %10 +OpExecutionMode %13 LocalSize 1 1 1 +OpMemberDecorate %8 0 Offset 0 +OpMemberDecorate %8 1 Offset 4 +OpMemberDecorate %8 2 Offset 8 +OpMemberDecorate %8 3 Offset 12 +OpMemberDecorate %8 4 Offset 16 +OpMemberDecorate %8 5 Offset 32 +OpDecorate %10 DescriptorSet 0 +OpDecorate %10 Binding 0 +%2 = OpTypeVoid +%3 = OpTypeFloat 32 +%4 = OpTypeAccelerationStructureNV +%5 = OpTypeRayQueryKHR +%6 = OpTypeInt 32 0 +%7 = OpTypeVector %3 3 +%8 = OpTypeStruct %6 %6 %3 %3 %7 %7 +%9 = OpConstant %3 2.0 +%11 = OpTypePointer UniformConstant %4 +%10 = OpVariable %11 UniformConstant +%14 = OpTypeFunction %2 +%16 = OpConstant %6 4 +%17 = OpConstant %6 255 +%18 = OpConstant %3 34.0 +%19 = OpConstant %3 38.0 +%20 = OpConstant %3 46.0 +%21 = OpConstantComposite %7 %20 %20 %20 +%22 = OpConstant %3 58.0 +%23 = OpConstant %3 62.0 +%24 = OpConstant %3 74.0 +%25 = OpConstantComposite %7 %22 %23 %24 +%26 = OpConstantComposite %8 %16 %17 %18 %19 %21 %25 +%28 = OpTypePointer Function %5 +%41 = OpTypeBool +%13 = OpFunction %2 None %14 +%12 = OpLabel +%27 = OpVariable %28 Function +%15 = OpLoad %4 %10 +OpBranch %29 +%29 = OpLabel +%30 = OpCompositeExtract %6 %26 0 +%31 = OpCompositeExtract %6 %26 1 +%32 = OpCompositeExtract %3 %26 2 +%33 = OpCompositeExtract %3 %26 3 +%34 = OpCompositeExtract %7 %26 4 +%35 = OpCompositeExtract %7 %26 5 +OpRayQueryInitializeKHR %27 %15 %30 %31 %34 %32 %35 %33 +OpBranch %36 +%36 = OpLabel +OpLoopMerge %37 %39 None +OpBranch %38 +%38 = OpLabel +%40 = OpRayQueryProceedKHR %41 %27 +OpSelectionMerge %42 None +OpBranchConditional %40 %42 %43 +%43 = OpLabel +OpBranch %37 +%42 = OpLabel +OpBranch %44 +%44 = OpLabel +OpBranch %45 +%45 = OpLabel +OpBranch %39 +%39 = OpLabel +OpBranch %36 +%37 = OpLabel +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/snapshots.rs b/naga/tests/snapshots.rs index 151e8b3da3..94c50c7975 100644 --- a/naga/tests/snapshots.rs +++ b/naga/tests/snapshots.rs @@ -466,6 +466,7 @@ fn write_output_spv( ); } } else { + assert!(pipeline_constants.is_empty()); write_output_spv_inner(input, module, info, &options, None, "spvasm"); } } @@ -857,6 +858,10 @@ fn convert_wgsl() { "overrides-atomicCompareExchangeWeak", Targets::IR | Targets::SPIRV, ), + ( + "overrides-ray-query", + Targets::IR | Targets::SPIRV | Targets::METAL, + ), ]; for &(name, targets) in inputs.iter() {