Skip to content

Commit c4b4387

Browse files
committed
[naga wgsl-in] Support abstract operands to binary operators.
1 parent f2828ac commit c4b4387

11 files changed

+498
-28
lines changed

CHANGELOG.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,10 @@ This feature allowed you to call `global_id` on any wgpu opaque handle to get a
6868

6969
#### Naga
7070

71-
- Naga'sn WGSL front and back ends now have experimental support for 64-bit floating-point literals: `1.0lf` denotes an `f64` value. There has been experimental support for an `f64` type for a while, but until now there was no syntax for writing literals with that type. As before, Naga module validation rejects `f64` values unless `naga::valid::Capabilities::FLOAT64` is requested. By @jimblandy in [#4747](https://github.com/gfx-rs/wgpu/pull/4747).
71+
- Naga's WGSL front end now allows binary operators to produce values with abstract types, rather than concretizing thir operands. By @jimblandy in [#4850](https://github.com/gfx-rs/wgpu/pull/4850).
72+
73+
- Naga's WGSL front and back ends now have experimental support for 64-bit floating-point literals: `1.0lf` denotes an `f64` value. There has been experimental support for an `f64` type for a while, but until now there was no syntax for writing literals with that type. As before, Naga module validation rejects `f64` values unless `naga::valid::Capabilities::FLOAT64` is requested. By @jimblandy in [#4747](https://github.com/gfx-rs/wgpu/pull/4747).
74+
7275
- Naga constant evaluation can now process binary operators whose operands are both vectors. By @jimblandy in [#4861](https://github.com/gfx-rs/wgpu/pull/4861).
7376

7477
### Changes

naga/src/front/wgsl/error.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,12 @@ pub enum Error<'a> {
257257
source_span: Span,
258258
source_type: String,
259259
},
260+
AutoConversionLeafScalar {
261+
dest_span: Span,
262+
dest_scalar: String,
263+
source_span: Span,
264+
source_type: String,
265+
},
260266
ConcretizationFailed {
261267
expr_span: Span,
262268
expr_type: String,
@@ -738,6 +744,20 @@ impl<'a> Error<'a> {
738744
],
739745
notes: vec![],
740746
},
747+
Error::AutoConversionLeafScalar { dest_span, ref dest_scalar, source_span, ref source_type } => ParseError {
748+
message: format!("automatic conversions cannot convert elements of `{source_type}` to `{dest_scalar}`"),
749+
labels: vec![
750+
(
751+
dest_span,
752+
format!("a value with elements of type {dest_scalar} is required here").into(),
753+
),
754+
(
755+
source_span,
756+
format!("this expression has type {source_type}").into(),
757+
)
758+
],
759+
notes: vec![],
760+
},
741761
Error::ConcretizationFailed { expr_span, ref expr_type, ref scalar, ref inner } => ParseError {
742762
message: format!("failed to convert expression to a concrete type: {}", inner),
743763
labels: vec![

naga/src/front/wgsl/lower/conversion.rs

Lines changed: 72 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,21 +51,80 @@ impl<'source, 'temp, 'out> super::ExpressionContext<'source, 'temp, 'out> {
5151
}
5252
};
5353

54-
let converted = if let crate::TypeInner::Array { .. } = *goal_inner {
55-
let span = self.get_expression_span(expr);
54+
self.convert_leaf_scalar(expr, expr_span, goal_scalar)
55+
}
56+
57+
/// Try to convert `expr`'s leaf scalar to `goal` using automatic conversions.
58+
///
59+
/// If no conversions are necessary, return `expr` unchanged.
60+
///
61+
/// If automatic conversions cannot convert `expr` to `goal_scalar`, return
62+
/// an [`AutoConversionLeafScalar`] error.
63+
///
64+
/// Although the Load Rule is one of the automatic conversions, this
65+
/// function assumes it has already been applied if appropriate, as
66+
/// indicated by the fact that the Rust type of `expr` is not `Typed<_>`.
67+
///
68+
/// [`AutoConversionLeafScalar`]: super::Error::AutoConversionLeafScalar
69+
pub fn try_automatic_conversion_for_leaf_scalar(
70+
&mut self,
71+
expr: Handle<crate::Expression>,
72+
goal_scalar: crate::Scalar,
73+
goal_span: Span,
74+
) -> Result<Handle<crate::Expression>, super::Error<'source>> {
75+
let expr_span = self.get_expression_span(expr);
76+
let expr_resolution = super::resolve!(self, expr);
77+
let types = &self.module.types;
78+
let expr_inner = expr_resolution.inner_with(types);
79+
80+
let make_error = || {
81+
let gctx = &self.module.to_ctx();
82+
let source_type = expr_resolution.to_wgsl(gctx);
83+
super::Error::AutoConversionLeafScalar {
84+
dest_span: goal_span,
85+
dest_scalar: goal_scalar.to_wgsl(),
86+
source_span: expr_span,
87+
source_type,
88+
}
89+
};
90+
91+
let expr_scalar = match expr_inner.scalar() {
92+
Some(scalar) => scalar,
93+
None => return Err(make_error()),
94+
};
95+
96+
if expr_scalar == goal_scalar {
97+
return Ok(expr);
98+
}
99+
100+
if !expr_scalar.automatically_converts_to(goal_scalar) {
101+
return Err(make_error());
102+
}
103+
104+
assert!(expr_scalar.is_abstract());
105+
106+
self.convert_leaf_scalar(expr, expr_span, goal_scalar)
107+
}
108+
109+
fn convert_leaf_scalar(
110+
&mut self,
111+
expr: Handle<crate::Expression>,
112+
expr_span: Span,
113+
goal_scalar: crate::Scalar,
114+
) -> Result<Handle<crate::Expression>, super::Error<'source>> {
115+
let expr_inner = super::resolve_inner!(self, expr);
116+
if let crate::TypeInner::Array { .. } = *expr_inner {
56117
self.as_const_evaluator()
57-
.cast_array(expr, goal_scalar, span)
58-
.map_err(|err| super::Error::ConstantEvaluatorError(err, span))?
118+
.cast_array(expr, goal_scalar, expr_span)
119+
.map_err(|err| super::Error::ConstantEvaluatorError(err, expr_span))
59120
} else {
60121
let cast = crate::Expression::As {
61122
expr,
62123
kind: goal_scalar.kind,
63124
convert: Some(goal_scalar.width),
64125
};
65-
self.append_expression(cast, expr_span)?
66-
};
67-
68-
Ok(converted)
126+
self.append_expression(cast, expr_span)
127+
}
69128
}
70129

71130
/// Try to convert `exprs` to `goal_ty` using WGSL's automatic conversions.
@@ -428,6 +487,11 @@ impl crate::Scalar {
428487
}
429488
}
430489

490+
/// Return `true` if automatic conversions will covert `self` to `goal`.
491+
pub fn automatically_converts_to(self, goal: Self) -> bool {
492+
self.automatic_conversion_combine(goal) == Some(goal)
493+
}
494+
431495
const fn concretize(self) -> Self {
432496
use crate::ScalarKind as Sk;
433497
match self.kind {

naga/src/front/wgsl/lower/mod.rs

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1602,11 +1602,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
16021602
return Ok(Typed::Reference(pointer));
16031603
}
16041604
ast::Expression::Binary { op, left, right } => {
1605-
// Load both operands.
1606-
let mut left = self.expression(left, ctx)?;
1607-
let mut right = self.expression(right, ctx)?;
1608-
ctx.binary_op_splat(op, &mut left, &mut right)?;
1609-
Typed::Plain(crate::Expression::Binary { op, left, right })
1605+
self.binary(op, left, right, span, ctx)?
16101606
}
16111607
ast::Expression::Call {
16121608
ref function,
@@ -1737,6 +1733,52 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
17371733
expr.try_map(|handle| ctx.append_expression(handle, span))
17381734
}
17391735

1736+
fn binary(
1737+
&mut self,
1738+
op: crate::BinaryOperator,
1739+
left: Handle<ast::Expression<'source>>,
1740+
right: Handle<ast::Expression<'source>>,
1741+
span: Span,
1742+
ctx: &mut ExpressionContext<'source, '_, '_>,
1743+
) -> Result<Typed<crate::Expression>, Error<'source>> {
1744+
// Load both operands.
1745+
let mut left = self.expression_for_abstract(left, ctx)?;
1746+
let mut right = self.expression_for_abstract(right, ctx)?;
1747+
1748+
// Convert `scalar op vector` to `vector op vector` by introducing
1749+
// `Splat` expressions.
1750+
ctx.binary_op_splat(op, &mut left, &mut right)?;
1751+
1752+
// Apply automatic conversions.
1753+
match op {
1754+
// Shift operators require the right operand to be `u32` or
1755+
// `vecN<u32>`. We can let the validator sort out vector length
1756+
// issues, but the right operand must be, or convert to, a u32 leaf
1757+
// scalar.
1758+
crate::BinaryOperator::ShiftLeft | crate::BinaryOperator::ShiftRight => {
1759+
right =
1760+
ctx.try_automatic_conversion_for_leaf_scalar(right, crate::Scalar::U32, span)?;
1761+
}
1762+
1763+
// All other operators follow the same pattern: reconcile the
1764+
// scalar leaf types. If there's no reconciliation possible,
1765+
// leave the expressions as they are: validation will report the
1766+
// problem.
1767+
_ => {
1768+
ctx.grow_types(left)?;
1769+
ctx.grow_types(right)?;
1770+
if let Ok(consensus_scalar) =
1771+
ctx.automatic_conversion_consensus([left, right].iter())
1772+
{
1773+
ctx.convert_to_leaf_scalar(&mut left, consensus_scalar)?;
1774+
ctx.convert_to_leaf_scalar(&mut right, consensus_scalar)?;
1775+
}
1776+
}
1777+
}
1778+
1779+
Ok(Typed::Plain(crate::Expression::Binary { op, left, right }))
1780+
}
1781+
17401782
/// Generate Naga IR for call expressions and statements, and type
17411783
/// constructor expressions.
17421784
///

naga/src/proc/constant_evaluator.rs

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ pub enum ConstantEvaluatorError {
141141
InvalidAccessIndexTy,
142142
#[error("Constants don't support array length expressions")]
143143
ArrayLength,
144-
#[error("Cannot cast type `{from}` to `{to}`")]
144+
#[error("Cannot cast scalar components of expression `{from}` to type `{to}`")]
145145
InvalidCastArg { from: String, to: String },
146146
#[error("Cannot apply the unary op to the argument")]
147147
InvalidUnaryOpArg,
@@ -989,15 +989,11 @@ impl<'a> ConstantEvaluator<'a> {
989989
let expr = self.eval_zero_value(expr, span)?;
990990

991991
let make_error = || -> Result<_, ConstantEvaluatorError> {
992-
let ty = self.resolve_type(expr)?;
992+
let from = format!("{:?} {:?}", expr, self.expressions[expr]);
993993

994-
#[cfg(feature = "wgsl-in")]
995-
let from = ty.to_wgsl(&self.to_ctx());
996994
#[cfg(feature = "wgsl-in")]
997995
let to = target.to_wgsl();
998996

999-
#[cfg(not(feature = "wgsl-in"))]
1000-
let from = format!("{ty:?}");
1001997
#[cfg(not(feature = "wgsl-in"))]
1002998
let to = format!("{target:?}");
1003999

@@ -1325,6 +1321,47 @@ impl<'a> ConstantEvaluator<'a> {
13251321
BinaryOperator::Modulo => a % b,
13261322
_ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
13271323
}),
1324+
(Literal::AbstractInt(a), Literal::AbstractInt(b)) => {
1325+
Literal::AbstractInt(match op {
1326+
BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
1327+
ConstantEvaluatorError::Overflow("addition".into())
1328+
})?,
1329+
BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
1330+
ConstantEvaluatorError::Overflow("subtraction".into())
1331+
})?,
1332+
BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
1333+
ConstantEvaluatorError::Overflow("multiplication".into())
1334+
})?,
1335+
BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
1336+
if b == 0 {
1337+
ConstantEvaluatorError::DivisionByZero
1338+
} else {
1339+
ConstantEvaluatorError::Overflow("division".into())
1340+
}
1341+
})?,
1342+
BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
1343+
if b == 0 {
1344+
ConstantEvaluatorError::RemainderByZero
1345+
} else {
1346+
ConstantEvaluatorError::Overflow("remainder".into())
1347+
}
1348+
})?,
1349+
BinaryOperator::And => a & b,
1350+
BinaryOperator::ExclusiveOr => a ^ b,
1351+
BinaryOperator::InclusiveOr => a | b,
1352+
_ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
1353+
})
1354+
}
1355+
(Literal::AbstractFloat(a), Literal::AbstractFloat(b)) => {
1356+
Literal::AbstractFloat(match op {
1357+
BinaryOperator::Add => a + b,
1358+
BinaryOperator::Subtract => a - b,
1359+
BinaryOperator::Multiply => a * b,
1360+
BinaryOperator::Divide => a / b,
1361+
BinaryOperator::Modulo => a % b,
1362+
_ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
1363+
})
1364+
}
13281365
(Literal::Bool(a), Literal::Bool(b)) => Literal::Bool(match op {
13291366
BinaryOperator::LogicalAnd => a && b,
13301367
BinaryOperator::LogicalOr => a || b,
@@ -1550,7 +1587,10 @@ impl<'a> ConstantEvaluator<'a> {
15501587
};
15511588
Tr::Value(TypeInner::Vector { scalar, size })
15521589
}
1553-
_ => return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant),
1590+
_ => {
1591+
log::debug!("resolve_type: SubexpressionsAreNotConstant");
1592+
return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
1593+
}
15541594
};
15551595

15561596
Ok(resolution)
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
const plus_fafaf: f32 = 1.0 + 2.0;
2+
const plus_fafai: f32 = 1.0 + 2;
3+
const plus_faf_f: f32 = 1.0 + 2f;
4+
const plus_faiaf: f32 = 1 + 2.0;
5+
const plus_faiai: f32 = 1 + 2;
6+
const plus_fai_f: f32 = 1 + 2f;
7+
const plus_f_faf: f32 = 1f + 2.0;
8+
const plus_f_fai: f32 = 1f + 2;
9+
const plus_f_f_f: f32 = 1f + 2f;
10+
11+
const plus_iaiai: i32 = 1 + 2;
12+
const plus_iai_i: i32 = 1 + 2i;
13+
const plus_i_iai: i32 = 1i + 2;
14+
const plus_i_i_i: i32 = 1i + 2i;
15+
16+
const plus_uaiai: u32 = 1 + 2;
17+
const plus_uai_u: u32 = 1 + 2u;
18+
const plus_u_uai: u32 = 1u + 2;
19+
const plus_u_u_u: u32 = 1u + 2u;
20+
21+
fn runtime_values() {
22+
var f: f32 = 42;
23+
var i: i32 = 43;
24+
var u: u32 = 44;
25+
26+
var plus_fafaf: f32 = 1.0 + 2.0;
27+
var plus_fafai: f32 = 1.0 + 2;
28+
var plus_faf_f: f32 = 1.0 + f;
29+
var plus_faiaf: f32 = 1 + 2.0;
30+
var plus_faiai: f32 = 1 + 2;
31+
var plus_fai_f: f32 = 1 + f;
32+
var plus_f_faf: f32 = f + 2.0;
33+
var plus_f_fai: f32 = f + 2;
34+
var plus_f_f_f: f32 = f + f;
35+
36+
var plus_iaiai: i32 = 1 + 2;
37+
var plus_iai_i: i32 = 1 + i;
38+
var plus_i_iai: i32 = i + 2;
39+
var plus_i_i_i: i32 = i + i;
40+
41+
var plus_uaiai: u32 = 1 + 2;
42+
var plus_uai_u: u32 = 1 + u;
43+
var plus_u_uai: u32 = u + 2;
44+
var plus_u_u_u: u32 = u + u;
45+
}

0 commit comments

Comments
 (0)