Skip to content

[naga wgsl-in] Implement any() and all() during const evaluation #7166

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 43 additions & 5 deletions naga/src/proc/constant_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use arrayvec::ArrayVec;

use crate::{
arena::{Arena, Handle, HandleVec, UniqueArena},
ArraySize, BinaryOperator, Constant, Expression, Literal, Override, ScalarKind, Span, Type,
TypeInner, UnaryOperator,
ArraySize, BinaryOperator, Constant, Expression, Literal, Override, RelationalFunction,
ScalarKind, Span, Type, TypeInner, UnaryOperator,
};

/// A macro that allows dollar signs (`$`) to be emitted by other macros. Useful for generating
Expand Down Expand Up @@ -547,6 +547,8 @@ pub enum ConstantEvaluatorError {
InvalidMathArg,
#[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")]
InvalidMathArgCount(crate::MathFunction, usize, usize),
#[error("Cannot apply relational function to type")]
InvalidRelationalArg(RelationalFunction),
#[error("value of `low` is greater than `high` for clamp built-in function")]
InvalidClamp,
#[error("Splat is defined only on scalar values")]
Expand Down Expand Up @@ -931,9 +933,10 @@ impl<'a> ConstantEvaluator<'a> {
Expression::Select { .. } => Err(ConstantEvaluatorError::NotImplemented(
"select built-in function".into(),
)),
Expression::Relational { fun, .. } => Err(ConstantEvaluatorError::NotImplemented(
format!("{fun:?} built-in function"),
)),
Expression::Relational { fun, argument } => {
let argument = self.check_and_get(argument)?;
self.relational(fun, argument, span)
}
Expression::ArrayLength(expr) => match self.behavior {
Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
Behavior::Glsl(_) => {
Expand Down Expand Up @@ -2103,6 +2106,41 @@ impl<'a> ConstantEvaluator<'a> {
Ok(Expression::Compose { ty, components })
}

fn relational(
&mut self,
fun: RelationalFunction,
arg: Handle<Expression>,
span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
let arg = self.eval_zero_value_and_splat(arg, span)?;
match fun {
RelationalFunction::All | RelationalFunction::Any => match self.expressions[arg] {
Expression::Literal(Literal::Bool(_)) => Ok(arg),
Expression::Compose { ty, ref components }
if matches!(self.types[ty].inner, TypeInner::Vector { .. }) =>
{
let components =
crate::proc::flatten_compose(ty, components, self.expressions, self.types)
.map(|component| match self.expressions[component] {
Expression::Literal(Literal::Bool(val)) => Ok(val),
_ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
})
.collect::<Result<ArrayVec<bool, { crate::VectorSize::MAX }>, _>>()?;
let result = match fun {
RelationalFunction::All => components.iter().all(|c| *c),
RelationalFunction::Any => components.iter().any(|c| *c),
_ => unreachable!(),
};
self.register_evaluated_expr(Expression::Literal(Literal::Bool(result)), span)
}
_ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
},
_ => Err(ConstantEvaluatorError::NotImplemented(format!(
"{fun:?} built-in function"
))),
}
}

/// Deep copy `expr` from `expressions` into `self.expressions`.
///
/// Return the root of the new copy.
Expand Down
15 changes: 15 additions & 0 deletions naga/tests/in/const-exprs.wgsl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
const TWO: u32 = 2u;
const THREE: i32 = 3i;
const TRUE = true;
const FALSE = false;

@compute @workgroup_size(TWO, THREE, TWO - 1u)
fn main() {
Expand Down Expand Up @@ -94,3 +96,16 @@ fn compose_vector_zero_val_binop() {
var b = vec3(vec2i(), 0) + vec3(0, 1, 2);
var c = vec3(vec2i(), 2) + vec3(1, vec2i());
}

fn relational() {
// Test scalar and vector forms of any() and all(), with a mixture of
// consts, literals, zero-values, composes, and splats.
var scalar_any_false = any(false);
var scalar_any_true = any(true);
var scalar_all_false = all(false);
var scalar_all_true = all(true);
var vec_any_false = any(vec4<bool>());
var vec_any_true = any(vec4(bool(), true, vec2(FALSE)));
var vec_all_false = all(vec4(vec3(vec2<bool>(), TRUE), false));
var vec_all_true = all(vec4(true));
}
14 changes: 14 additions & 0 deletions naga/tests/out/glsl/const-exprs.main.Compute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ layout(local_size_x = 2, local_size_y = 3, local_size_z = 1) in;

const uint TWO = 2u;
const int THREE = 3;
const bool TRUE = true;
const bool FALSE = false;
const int FOUR = 4;
const int FOUR_ALIAS = 4;
const int TEST_CONSTANT_ADDITION = 8;
Expand Down Expand Up @@ -93,6 +95,18 @@ void compose_vector_zero_val_binop() {
return;
}

void relational() {
bool scalar_any_false = false;
bool scalar_any_true = true;
bool scalar_all_false = false;
bool scalar_all_true = true;
bool vec_any_false = false;
bool vec_any_true = true;
bool vec_all_false = false;
bool vec_all_true = true;
return;
}

void main() {
swizzle_of_compose();
index_of_compose();
Expand Down
16 changes: 16 additions & 0 deletions naga/tests/out/hlsl/const-exprs.hlsl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
static const uint TWO = 2u;
static const int THREE = int(3);
static const bool TRUE = true;
static const bool FALSE = false;
static const int FOUR = int(4);
static const int FOUR_ALIAS = int(4);
static const int TEST_CONSTANT_ADDITION = int(8);
Expand Down Expand Up @@ -102,6 +104,20 @@ void compose_vector_zero_val_binop()
return;
}

void relational()
{
bool scalar_any_false = false;
bool scalar_any_true = true;
bool scalar_all_false = false;
bool scalar_all_true = true;
bool vec_any_false = false;
bool vec_any_true = true;
bool vec_all_false = false;
bool vec_all_true = true;

return;
}

[numthreads(2, 3, 1)]
void main()
{
Expand Down
15 changes: 15 additions & 0 deletions naga/tests/out/msl/const-exprs.msl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ using metal::uint;

constant uint TWO = 2u;
constant int THREE = 3;
constant bool TRUE = true;
constant bool FALSE = false;
constant int FOUR = 4;
constant int FOUR_ALIAS = 4;
constant int TEST_CONSTANT_ADDITION = 8;
Expand Down Expand Up @@ -101,6 +103,19 @@ void compose_vector_zero_val_binop(
return;
}

void relational(
) {
bool scalar_any_false = false;
bool scalar_any_true = true;
bool scalar_all_false = false;
bool scalar_all_true = true;
bool vec_any_false = false;
bool vec_any_true = true;
bool vec_all_false = false;
bool vec_all_true = true;
return;
}

kernel void main_(
) {
swizzle_of_compose();
Expand Down
109 changes: 62 additions & 47 deletions naga/tests/out/spv/const-exprs.spvasm
Original file line number Diff line number Diff line change
@@ -1,66 +1,67 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 120
; Bound: 132
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %111 "main"
OpExecutionMode %111 LocalSize 2 3 1
OpEntryPoint GLCompute %123 "main"
OpExecutionMode %123 LocalSize 2 3 1
%2 = OpTypeVoid
%3 = OpTypeInt 32 0
%4 = OpTypeInt 32 1
%5 = OpTypeVector %4 4
%6 = OpTypeFloat 32
%7 = OpTypeVector %6 4
%8 = OpTypeVector %6 2
%10 = OpTypeBool
%9 = OpTypeVector %10 2
%5 = OpTypeBool
%6 = OpTypeVector %4 4
%7 = OpTypeFloat 32
%8 = OpTypeVector %7 4
%9 = OpTypeVector %7 2
%10 = OpTypeVector %5 2
%11 = OpTypeVector %4 3
%12 = OpConstant %3 2
%13 = OpConstant %4 3
%14 = OpConstant %4 4
%15 = OpConstant %4 8
%16 = OpConstant %6 3.141
%17 = OpConstant %6 6.282
%18 = OpConstant %6 0.44444445
%19 = OpConstant %6 0.0
%20 = OpConstantComposite %7 %18 %19 %19 %19
%21 = OpConstant %4 0
%22 = OpConstant %4 1
%23 = OpConstant %4 2
%24 = OpConstant %6 4.0
%25 = OpConstant %6 5.0
%26 = OpConstantComposite %8 %24 %25
%27 = OpConstantTrue %10
%28 = OpConstantFalse %10
%29 = OpConstantComposite %9 %27 %28
%14 = OpConstantTrue %5
%15 = OpConstantFalse %5
%16 = OpConstant %4 4
%17 = OpConstant %4 8
%18 = OpConstant %7 3.141
%19 = OpConstant %7 6.282
%20 = OpConstant %7 0.44444445
%21 = OpConstant %7 0.0
%22 = OpConstantComposite %8 %20 %21 %21 %21
%23 = OpConstant %4 0
%24 = OpConstant %4 1
%25 = OpConstant %4 2
%26 = OpConstant %7 4.0
%27 = OpConstant %7 5.0
%28 = OpConstantComposite %9 %26 %27
%29 = OpConstantComposite %10 %14 %15
%32 = OpTypeFunction %2
%33 = OpConstantComposite %5 %14 %13 %23 %22
%35 = OpTypePointer Function %5
%33 = OpConstantComposite %6 %16 %13 %25 %24
%35 = OpTypePointer Function %6
%40 = OpTypePointer Function %4
%44 = OpConstant %4 6
%49 = OpConstant %4 30
%50 = OpConstant %4 70
%53 = OpConstantNull %4
%55 = OpConstantNull %4
%58 = OpConstantNull %5
%58 = OpConstantNull %6
%69 = OpConstant %4 -4
%70 = OpConstantComposite %5 %69 %69 %69 %69
%79 = OpConstant %6 1.0
%80 = OpConstant %6 2.0
%81 = OpConstantComposite %7 %80 %79 %79 %79
%83 = OpTypePointer Function %7
%70 = OpConstantComposite %6 %69 %69 %69 %69
%79 = OpConstant %7 1.0
%80 = OpConstant %7 2.0
%81 = OpConstantComposite %8 %80 %79 %79 %79
%83 = OpTypePointer Function %8
%88 = OpTypeFunction %3 %4
%89 = OpConstant %3 10
%90 = OpConstant %3 20
%91 = OpConstant %3 30
%92 = OpConstant %3 0
%99 = OpConstantNull %3
%102 = OpConstantComposite %11 %22 %22 %22
%103 = OpConstantComposite %11 %21 %22 %23
%104 = OpConstantComposite %11 %22 %21 %23
%102 = OpConstantComposite %11 %24 %24 %24
%103 = OpConstantComposite %11 %23 %24 %25
%104 = OpConstantComposite %11 %24 %23 %25
%106 = OpTypePointer Function %11
%113 = OpTypePointer Function %5
%31 = OpFunction %2 None %32
%30 = OpLabel
%34 = OpVariable %35 Function %33
Expand All @@ -70,7 +71,7 @@ OpReturn
OpFunctionEnd
%38 = OpFunction %2 None %32
%37 = OpLabel
%39 = OpVariable %40 Function %23
%39 = OpVariable %40 Function %25
OpBranch %41
%41 = OpLabel
OpReturn
Expand Down Expand Up @@ -99,7 +100,7 @@ OpStore %54 %61
%63 = OpLoad %4 %52
%64 = OpLoad %4 %54
%65 = OpLoad %4 %56
%66 = OpCompositeConstruct %5 %62 %63 %64 %65
%66 = OpCompositeConstruct %6 %62 %63 %64 %65
OpStore %57 %66
OpReturn
OpFunctionEnd
Expand Down Expand Up @@ -153,14 +154,28 @@ OpReturn
OpFunctionEnd
%111 = OpFunction %2 None %32
%110 = OpLabel
OpBranch %112
%112 = OpLabel
%113 = OpFunctionCall %2 %31
%114 = OpFunctionCall %2 %38
%115 = OpFunctionCall %2 %43
%116 = OpFunctionCall %2 %48
%117 = OpFunctionCall %2 %68
%118 = OpFunctionCall %2 %74
%119 = OpFunctionCall %2 %78
%119 = OpVariable %113 Function %15
%116 = OpVariable %113 Function %14
%112 = OpVariable %113 Function %15
%120 = OpVariable %113 Function %14
%117 = OpVariable %113 Function %15
%114 = OpVariable %113 Function %14
%118 = OpVariable %113 Function %14
%115 = OpVariable %113 Function %15
OpBranch %121
%121 = OpLabel
OpReturn
OpFunctionEnd
%123 = OpFunction %2 None %32
%122 = OpLabel
OpBranch %124
%124 = OpLabel
%125 = OpFunctionCall %2 %31
%126 = OpFunctionCall %2 %38
%127 = OpFunctionCall %2 %43
%128 = OpFunctionCall %2 %48
%129 = OpFunctionCall %2 %68
%130 = OpFunctionCall %2 %74
%131 = OpFunctionCall %2 %78
OpReturn
OpFunctionEnd
15 changes: 15 additions & 0 deletions naga/tests/out/wgsl/const-exprs.wgsl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
const TWO: u32 = 2u;
const THREE: i32 = 3i;
const TRUE: bool = true;
const FALSE: bool = false;
const FOUR: i32 = 4i;
const FOUR_ALIAS: i32 = 4i;
const TEST_CONSTANT_ADDITION: i32 = 8i;
Expand Down Expand Up @@ -93,6 +95,19 @@ fn compose_vector_zero_val_binop() {
return;
}

fn relational() {
var scalar_any_false: bool = false;
var scalar_any_true: bool = true;
var scalar_all_false: bool = false;
var scalar_all_true: bool = true;
var vec_any_false: bool = false;
var vec_any_true: bool = true;
var vec_all_false: bool = false;
var vec_all_true: bool = true;

return;
}

@compute @workgroup_size(2, 3, 1)
fn main() {
swizzle_of_compose();
Expand Down
Loading