Skip to content

Commit d977735

Browse files
jamienicoljimblandy
authored andcommitted
[naga const_eval] Ensure eval_zero_value_and_splat() lowers a Splat of a ZeroValue correctly
eval_zero_value_and_splat() is called to lower ZeroValue and Splat expressions into Literal and Compose expressions. However, in its current form it either calls splat() *or* eval_zero_value_impl() depending on the expression type. splat() will lower a Splat of a scalar ZeroValue to a vector ZeroValue, which means eval_zero_value_and_splat() can still return a ZeroValue. Its callers, such as binary_op(), are unable to handle this ZeroValue, so cannot proceed with const evaluation. This patch makes it so that eval_zero_value_and_splat() will first call splat(), *and then* call eval_zero_value_impl(), which will lower the vector ZeroValue returned by splat() into a Compose of Literals. Callers such as binary_op() are perfectly able to handle this Compose, so can now proceed with const evaluation.
1 parent 2a456f5 commit d977735

File tree

1 file changed

+90
-5
lines changed

1 file changed

+90
-5
lines changed

naga/src/proc/constant_evaluator.rs

Lines changed: 90 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1398,14 +1398,19 @@ impl<'a> ConstantEvaluator<'a> {
13981398
/// [`Compose`]: Expression::Compose
13991399
fn eval_zero_value_and_splat(
14001400
&mut self,
1401-
expr: Handle<Expression>,
1401+
mut expr: Handle<Expression>,
14021402
span: Span,
14031403
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1404-
match self.expressions[expr] {
1405-
Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
1406-
Expression::Splat { size, value } => self.splat(value, size, span),
1407-
_ => Ok(expr),
1404+
// The result of the splat() for a Splat of a scalar ZeroValue is a
1405+
// vector ZeroValue, so we must call eval_zero_value_impl() after
1406+
// splat() in order to ensure we have no ZeroValues remaining.
1407+
if let Expression::Splat { size, value } = self.expressions[expr] {
1408+
expr = self.splat(value, size, span)?;
14081409
}
1410+
if let Expression::ZeroValue(ty) = self.expressions[expr] {
1411+
expr = self.eval_zero_value_impl(ty, span)?;
1412+
}
1413+
Ok(expr)
14091414
}
14101415

14111416
/// Lower [`ZeroValue`] expressions to [`Literal`] and [`Compose`] expressions.
@@ -2976,4 +2981,84 @@ mod tests {
29762981
panic!("unexpected evaluation result")
29772982
}
29782983
}
2984+
2985+
#[test]
2986+
fn splat_of_zero_value() {
2987+
let mut types = UniqueArena::new();
2988+
let constants = Arena::new();
2989+
let overrides = Arena::new();
2990+
let mut global_expressions = Arena::new();
2991+
2992+
let f32_ty = types.insert(
2993+
Type {
2994+
name: None,
2995+
inner: TypeInner::Scalar(crate::Scalar::F32),
2996+
},
2997+
Default::default(),
2998+
);
2999+
3000+
let vec2_f32_ty = types.insert(
3001+
Type {
3002+
name: None,
3003+
inner: TypeInner::Vector {
3004+
size: VectorSize::Bi,
3005+
scalar: crate::Scalar::F32,
3006+
},
3007+
},
3008+
Default::default(),
3009+
);
3010+
3011+
let five =
3012+
global_expressions.append(Expression::Literal(Literal::F32(5.0)), Default::default());
3013+
let five_splat = global_expressions.append(
3014+
Expression::Splat {
3015+
size: VectorSize::Bi,
3016+
value: five,
3017+
},
3018+
Default::default(),
3019+
);
3020+
let zero = global_expressions.append(Expression::ZeroValue(f32_ty), Default::default());
3021+
let zero_splat = global_expressions.append(
3022+
Expression::Splat {
3023+
size: VectorSize::Bi,
3024+
value: zero,
3025+
},
3026+
Default::default(),
3027+
);
3028+
3029+
let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
3030+
let mut solver = ConstantEvaluator {
3031+
behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
3032+
types: &mut types,
3033+
constants: &constants,
3034+
overrides: &overrides,
3035+
expressions: &mut global_expressions,
3036+
expression_kind_tracker,
3037+
};
3038+
3039+
let solved_add = solver
3040+
.try_eval_and_append(
3041+
Expression::Binary {
3042+
op: crate::BinaryOperator::Add,
3043+
left: zero_splat,
3044+
right: five_splat,
3045+
},
3046+
Default::default(),
3047+
)
3048+
.unwrap();
3049+
3050+
let pass = match global_expressions[solved_add] {
3051+
Expression::Compose { ty, ref components } => {
3052+
ty == vec2_f32_ty
3053+
&& components.iter().all(|&component| {
3054+
let component = &global_expressions[component];
3055+
matches!(*component, Expression::Literal(Literal::F32(5.0)))
3056+
})
3057+
}
3058+
_ => false,
3059+
};
3060+
if !pass {
3061+
panic!("unexpected evaluation result")
3062+
}
3063+
}
29793064
}

0 commit comments

Comments
 (0)