diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index f4b58bc89a..be7c503a01 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -4,8 +4,8 @@ use arrayvec::ArrayVec; use crate::{ arena::{Arena, Handle, HandleVec, UniqueArena}, - ArraySize, BinaryOperator, Constant, Expression, Literal, Override, ScalarKind, Span, Type, - TypeInner, UnaryOperator, + ArraySize, BinaryOperator, Constant, Expression, Literal, Override, RelationalFunction, + ScalarKind, Span, Type, TypeInner, UnaryOperator, }; /// A macro that allows dollar signs (`$`) to be emitted by other macros. Useful for generating @@ -547,6 +547,8 @@ pub enum ConstantEvaluatorError { InvalidMathArg, #[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")] InvalidMathArgCount(crate::MathFunction, usize, usize), + #[error("Cannot apply relational function to type")] + InvalidRelationalArg(RelationalFunction), #[error("value of `low` is greater than `high` for clamp built-in function")] InvalidClamp, #[error("Splat is defined only on scalar values")] @@ -931,9 +933,10 @@ impl<'a> ConstantEvaluator<'a> { Expression::Select { .. } => Err(ConstantEvaluatorError::NotImplemented( "select built-in function".into(), )), - Expression::Relational { fun, .. } => Err(ConstantEvaluatorError::NotImplemented( - format!("{fun:?} built-in function"), - )), + Expression::Relational { fun, argument } => { + let argument = self.check_and_get(argument)?; + self.relational(fun, argument, span) + } Expression::ArrayLength(expr) => match self.behavior { Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength), Behavior::Glsl(_) => { @@ -2103,6 +2106,41 @@ impl<'a> ConstantEvaluator<'a> { Ok(Expression::Compose { ty, components }) } + fn relational( + &mut self, + fun: RelationalFunction, + arg: Handle, + span: Span, + ) -> Result, ConstantEvaluatorError> { + let arg = self.eval_zero_value_and_splat(arg, span)?; + match fun { + RelationalFunction::All | RelationalFunction::Any => match self.expressions[arg] { + Expression::Literal(Literal::Bool(_)) => Ok(arg), + Expression::Compose { ty, ref components } + if matches!(self.types[ty].inner, TypeInner::Vector { .. }) => + { + let components = + crate::proc::flatten_compose(ty, components, self.expressions, self.types) + .map(|component| match self.expressions[component] { + Expression::Literal(Literal::Bool(val)) => Ok(val), + _ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)), + }) + .collect::, _>>()?; + let result = match fun { + RelationalFunction::All => components.iter().all(|c| *c), + RelationalFunction::Any => components.iter().any(|c| *c), + _ => unreachable!(), + }; + self.register_evaluated_expr(Expression::Literal(Literal::Bool(result)), span) + } + _ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)), + }, + _ => Err(ConstantEvaluatorError::NotImplemented(format!( + "{fun:?} built-in function" + ))), + } + } + /// Deep copy `expr` from `expressions` into `self.expressions`. /// /// Return the root of the new copy. diff --git a/naga/tests/in/const-exprs.wgsl b/naga/tests/in/const-exprs.wgsl index 5165c49ceb..72e5726355 100644 --- a/naga/tests/in/const-exprs.wgsl +++ b/naga/tests/in/const-exprs.wgsl @@ -1,5 +1,7 @@ const TWO: u32 = 2u; const THREE: i32 = 3i; +const TRUE = true; +const FALSE = false; @compute @workgroup_size(TWO, THREE, TWO - 1u) fn main() { @@ -94,3 +96,16 @@ fn compose_vector_zero_val_binop() { var b = vec3(vec2i(), 0) + vec3(0, 1, 2); var c = vec3(vec2i(), 2) + vec3(1, vec2i()); } + +fn relational() { + // Test scalar and vector forms of any() and all(), with a mixture of + // consts, literals, zero-values, composes, and splats. + var scalar_any_false = any(false); + var scalar_any_true = any(true); + var scalar_all_false = all(false); + var scalar_all_true = all(true); + var vec_any_false = any(vec4()); + var vec_any_true = any(vec4(bool(), true, vec2(FALSE))); + var vec_all_false = all(vec4(vec3(vec2(), TRUE), false)); + var vec_all_true = all(vec4(true)); +} diff --git a/naga/tests/out/glsl/const-exprs.main.Compute.glsl b/naga/tests/out/glsl/const-exprs.main.Compute.glsl index 4b473bed7c..f2607ce0eb 100644 --- a/naga/tests/out/glsl/const-exprs.main.Compute.glsl +++ b/naga/tests/out/glsl/const-exprs.main.Compute.glsl @@ -7,6 +7,8 @@ layout(local_size_x = 2, local_size_y = 3, local_size_z = 1) in; const uint TWO = 2u; const int THREE = 3; +const bool TRUE = true; +const bool FALSE = false; const int FOUR = 4; const int FOUR_ALIAS = 4; const int TEST_CONSTANT_ADDITION = 8; @@ -93,6 +95,18 @@ void compose_vector_zero_val_binop() { return; } +void relational() { + bool scalar_any_false = false; + bool scalar_any_true = true; + bool scalar_all_false = false; + bool scalar_all_true = true; + bool vec_any_false = false; + bool vec_any_true = true; + bool vec_all_false = false; + bool vec_all_true = true; + return; +} + void main() { swizzle_of_compose(); index_of_compose(); diff --git a/naga/tests/out/hlsl/const-exprs.hlsl b/naga/tests/out/hlsl/const-exprs.hlsl index ab2160cb19..e0c4da5c77 100644 --- a/naga/tests/out/hlsl/const-exprs.hlsl +++ b/naga/tests/out/hlsl/const-exprs.hlsl @@ -1,5 +1,7 @@ static const uint TWO = 2u; static const int THREE = int(3); +static const bool TRUE = true; +static const bool FALSE = false; static const int FOUR = int(4); static const int FOUR_ALIAS = int(4); static const int TEST_CONSTANT_ADDITION = int(8); @@ -102,6 +104,20 @@ void compose_vector_zero_val_binop() return; } +void relational() +{ + bool scalar_any_false = false; + bool scalar_any_true = true; + bool scalar_all_false = false; + bool scalar_all_true = true; + bool vec_any_false = false; + bool vec_any_true = true; + bool vec_all_false = false; + bool vec_all_true = true; + + return; +} + [numthreads(2, 3, 1)] void main() { diff --git a/naga/tests/out/msl/const-exprs.msl b/naga/tests/out/msl/const-exprs.msl index dc0c394868..71dd868f4e 100644 --- a/naga/tests/out/msl/const-exprs.msl +++ b/naga/tests/out/msl/const-exprs.msl @@ -6,6 +6,8 @@ using metal::uint; constant uint TWO = 2u; constant int THREE = 3; +constant bool TRUE = true; +constant bool FALSE = false; constant int FOUR = 4; constant int FOUR_ALIAS = 4; constant int TEST_CONSTANT_ADDITION = 8; @@ -101,6 +103,19 @@ void compose_vector_zero_val_binop( return; } +void relational( +) { + bool scalar_any_false = false; + bool scalar_any_true = true; + bool scalar_all_false = false; + bool scalar_all_true = true; + bool vec_any_false = false; + bool vec_any_true = true; + bool vec_all_false = false; + bool vec_all_true = true; + return; +} + kernel void main_( ) { swizzle_of_compose(); diff --git a/naga/tests/out/spv/const-exprs.spvasm b/naga/tests/out/spv/const-exprs.spvasm index 63b256d5de..83777667d9 100644 --- a/naga/tests/out/spv/const-exprs.spvasm +++ b/naga/tests/out/spv/const-exprs.spvasm @@ -1,66 +1,67 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 120 +; Bound: 132 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint GLCompute %111 "main" -OpExecutionMode %111 LocalSize 2 3 1 +OpEntryPoint GLCompute %123 "main" +OpExecutionMode %123 LocalSize 2 3 1 %2 = OpTypeVoid %3 = OpTypeInt 32 0 %4 = OpTypeInt 32 1 -%5 = OpTypeVector %4 4 -%6 = OpTypeFloat 32 -%7 = OpTypeVector %6 4 -%8 = OpTypeVector %6 2 -%10 = OpTypeBool -%9 = OpTypeVector %10 2 +%5 = OpTypeBool +%6 = OpTypeVector %4 4 +%7 = OpTypeFloat 32 +%8 = OpTypeVector %7 4 +%9 = OpTypeVector %7 2 +%10 = OpTypeVector %5 2 %11 = OpTypeVector %4 3 %12 = OpConstant %3 2 %13 = OpConstant %4 3 -%14 = OpConstant %4 4 -%15 = OpConstant %4 8 -%16 = OpConstant %6 3.141 -%17 = OpConstant %6 6.282 -%18 = OpConstant %6 0.44444445 -%19 = OpConstant %6 0.0 -%20 = OpConstantComposite %7 %18 %19 %19 %19 -%21 = OpConstant %4 0 -%22 = OpConstant %4 1 -%23 = OpConstant %4 2 -%24 = OpConstant %6 4.0 -%25 = OpConstant %6 5.0 -%26 = OpConstantComposite %8 %24 %25 -%27 = OpConstantTrue %10 -%28 = OpConstantFalse %10 -%29 = OpConstantComposite %9 %27 %28 +%14 = OpConstantTrue %5 +%15 = OpConstantFalse %5 +%16 = OpConstant %4 4 +%17 = OpConstant %4 8 +%18 = OpConstant %7 3.141 +%19 = OpConstant %7 6.282 +%20 = OpConstant %7 0.44444445 +%21 = OpConstant %7 0.0 +%22 = OpConstantComposite %8 %20 %21 %21 %21 +%23 = OpConstant %4 0 +%24 = OpConstant %4 1 +%25 = OpConstant %4 2 +%26 = OpConstant %7 4.0 +%27 = OpConstant %7 5.0 +%28 = OpConstantComposite %9 %26 %27 +%29 = OpConstantComposite %10 %14 %15 %32 = OpTypeFunction %2 -%33 = OpConstantComposite %5 %14 %13 %23 %22 -%35 = OpTypePointer Function %5 +%33 = OpConstantComposite %6 %16 %13 %25 %24 +%35 = OpTypePointer Function %6 %40 = OpTypePointer Function %4 %44 = OpConstant %4 6 %49 = OpConstant %4 30 %50 = OpConstant %4 70 %53 = OpConstantNull %4 %55 = OpConstantNull %4 -%58 = OpConstantNull %5 +%58 = OpConstantNull %6 %69 = OpConstant %4 -4 -%70 = OpConstantComposite %5 %69 %69 %69 %69 -%79 = OpConstant %6 1.0 -%80 = OpConstant %6 2.0 -%81 = OpConstantComposite %7 %80 %79 %79 %79 -%83 = OpTypePointer Function %7 +%70 = OpConstantComposite %6 %69 %69 %69 %69 +%79 = OpConstant %7 1.0 +%80 = OpConstant %7 2.0 +%81 = OpConstantComposite %8 %80 %79 %79 %79 +%83 = OpTypePointer Function %8 %88 = OpTypeFunction %3 %4 %89 = OpConstant %3 10 %90 = OpConstant %3 20 %91 = OpConstant %3 30 %92 = OpConstant %3 0 %99 = OpConstantNull %3 -%102 = OpConstantComposite %11 %22 %22 %22 -%103 = OpConstantComposite %11 %21 %22 %23 -%104 = OpConstantComposite %11 %22 %21 %23 +%102 = OpConstantComposite %11 %24 %24 %24 +%103 = OpConstantComposite %11 %23 %24 %25 +%104 = OpConstantComposite %11 %24 %23 %25 %106 = OpTypePointer Function %11 +%113 = OpTypePointer Function %5 %31 = OpFunction %2 None %32 %30 = OpLabel %34 = OpVariable %35 Function %33 @@ -70,7 +71,7 @@ OpReturn OpFunctionEnd %38 = OpFunction %2 None %32 %37 = OpLabel -%39 = OpVariable %40 Function %23 +%39 = OpVariable %40 Function %25 OpBranch %41 %41 = OpLabel OpReturn @@ -99,7 +100,7 @@ OpStore %54 %61 %63 = OpLoad %4 %52 %64 = OpLoad %4 %54 %65 = OpLoad %4 %56 -%66 = OpCompositeConstruct %5 %62 %63 %64 %65 +%66 = OpCompositeConstruct %6 %62 %63 %64 %65 OpStore %57 %66 OpReturn OpFunctionEnd @@ -153,14 +154,28 @@ OpReturn OpFunctionEnd %111 = OpFunction %2 None %32 %110 = OpLabel -OpBranch %112 -%112 = OpLabel -%113 = OpFunctionCall %2 %31 -%114 = OpFunctionCall %2 %38 -%115 = OpFunctionCall %2 %43 -%116 = OpFunctionCall %2 %48 -%117 = OpFunctionCall %2 %68 -%118 = OpFunctionCall %2 %74 -%119 = OpFunctionCall %2 %78 +%119 = OpVariable %113 Function %15 +%116 = OpVariable %113 Function %14 +%112 = OpVariable %113 Function %15 +%120 = OpVariable %113 Function %14 +%117 = OpVariable %113 Function %15 +%114 = OpVariable %113 Function %14 +%118 = OpVariable %113 Function %14 +%115 = OpVariable %113 Function %15 +OpBranch %121 +%121 = OpLabel +OpReturn +OpFunctionEnd +%123 = OpFunction %2 None %32 +%122 = OpLabel +OpBranch %124 +%124 = OpLabel +%125 = OpFunctionCall %2 %31 +%126 = OpFunctionCall %2 %38 +%127 = OpFunctionCall %2 %43 +%128 = OpFunctionCall %2 %48 +%129 = OpFunctionCall %2 %68 +%130 = OpFunctionCall %2 %74 +%131 = OpFunctionCall %2 %78 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/out/wgsl/const-exprs.wgsl b/naga/tests/out/wgsl/const-exprs.wgsl index 4649807eb8..9fcbb588c9 100644 --- a/naga/tests/out/wgsl/const-exprs.wgsl +++ b/naga/tests/out/wgsl/const-exprs.wgsl @@ -1,5 +1,7 @@ const TWO: u32 = 2u; const THREE: i32 = 3i; +const TRUE: bool = true; +const FALSE: bool = false; const FOUR: i32 = 4i; const FOUR_ALIAS: i32 = 4i; const TEST_CONSTANT_ADDITION: i32 = 8i; @@ -93,6 +95,19 @@ fn compose_vector_zero_val_binop() { return; } +fn relational() { + var scalar_any_false: bool = false; + var scalar_any_true: bool = true; + var scalar_all_false: bool = false; + var scalar_all_true: bool = true; + var vec_any_false: bool = false; + var vec_any_true: bool = true; + var vec_all_false: bool = false; + var vec_all_true: bool = true; + + return; +} + @compute @workgroup_size(2, 3, 1) fn main() { swizzle_of_compose();