Skip to content

Commit e5b7df3

Browse files
teoxoyjimblandy
andcommitted
[wgsl-in] add support for override declarations (#4793)
Co-authored-by: Jim Blandy <[email protected]>
1 parent 747fb30 commit e5b7df3

37 files changed

+515
-28
lines changed

naga/src/back/dot/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,7 @@ fn write_function_expressions(
404404
let (label, color_id) = match *expression {
405405
E::Literal(_) => ("Literal".into(), 2),
406406
E::Constant(_) => ("Constant".into(), 2),
407+
E::Override(_) => ("Override".into(), 2),
407408
E::ZeroValue(_) => ("ZeroValue".into(), 2),
408409
E::Compose { ref components, .. } => {
409410
payload = Some(Payload::Arguments(components));

naga/src/back/glsl/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2538,6 +2538,7 @@ impl<'a, W: Write> Writer<'a, W> {
25382538
|writer, expr| writer.write_expr(expr, ctx),
25392539
)?;
25402540
}
2541+
Expression::Override(_) => return Err(Error::Custom("overrides are WIP".into())),
25412542
// `Access` is applied to arrays, vectors and matrices and is written as indexing
25422543
Expression::Access { base, index } => {
25432544
self.write_expr(base, ctx)?;

naga/src/back/hlsl/writer.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2141,6 +2141,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
21412141
|writer, expr| writer.write_expr(module, expr, func_ctx),
21422142
)?;
21432143
}
2144+
Expression::Override(_) => {
2145+
return Err(Error::Unimplemented("overrides are WIP".into()))
2146+
}
21442147
// All of the multiplication can be expressed as `mul`,
21452148
// except vector * vector, which needs to use the "*" operator.
21462149
Expression::Binary {

naga/src/back/msl/writer.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1431,6 +1431,9 @@ impl<W: Write> Writer<W> {
14311431
|writer, context, expr| writer.put_expression(expr, context, true),
14321432
)?;
14331433
}
1434+
crate::Expression::Override(_) => {
1435+
return Err(Error::FeatureNotImplemented("overrides are WIP".into()))
1436+
}
14341437
crate::Expression::Access { base, .. }
14351438
| crate::Expression::AccessIndex { base, .. } => {
14361439
// This is an acceptable place to generate a `ReadZeroSkipWrite` check.

naga/src/back/spv/block.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,9 @@ impl<'w> BlockContext<'w> {
239239
let init = self.ir_module.constants[handle].init;
240240
self.writer.constant_ids[init.index()]
241241
}
242+
crate::Expression::Override(_) => {
243+
return Err(Error::FeatureNotImplemented("overrides are WIP"))
244+
}
242245
crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id),
243246
crate::Expression::Compose { ty, ref components } => {
244247
self.temp_list.clear();

naga/src/back/wgsl/writer.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1199,6 +1199,9 @@ impl<W: Write> Writer<W> {
11991199
|writer, expr| writer.write_expr(module, expr, func_ctx),
12001200
)?;
12011201
}
1202+
Expression::Override(_) => {
1203+
return Err(Error::Unimplemented("overrides are WIP".into()))
1204+
}
12021205
Expression::FunctionArgument(pos) => {
12031206
let name_key = func_ctx.argument_key(pos);
12041207
let name = &self.names[&name_key];

naga/src/compact/expressions.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use crate::arena::{Arena, Handle};
33

44
pub struct ExpressionTracer<'tracer> {
55
pub constants: &'tracer Arena<crate::Constant>,
6+
pub overrides: &'tracer Arena<crate::Override>,
67

78
/// The arena in which we are currently tracing expressions.
89
pub expressions: &'tracer Arena<crate::Expression>,
@@ -88,6 +89,11 @@ impl<'tracer> ExpressionTracer<'tracer> {
8889
None => self.expressions_used.insert(init),
8990
}
9091
}
92+
Ex::Override(_) => {
93+
// All overrides are considered used by definition. We mark
94+
// their types and initialization expressions as used in
95+
// `compact::compact`, so we have no more work to do here.
96+
}
9197
Ex::ZeroValue(ty) => self.types_used.insert(ty),
9298
Ex::Compose { ty, ref components } => {
9399
self.types_used.insert(ty);
@@ -219,6 +225,9 @@ impl ModuleMap {
219225
| Ex::CallResult(_)
220226
| Ex::RayQueryProceedResult => {}
221227

228+
// All overrides are retained, so their handles never change.
229+
Ex::Override(_) => {}
230+
222231
// Expressions that contain handles that need to be adjusted.
223232
Ex::Constant(ref mut constant) => self.constants.adjust(constant),
224233
Ex::ZeroValue(ref mut ty) => self.types.adjust(ty),

naga/src/compact/functions.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use super::{FunctionMap, ModuleMap};
44
pub struct FunctionTracer<'a> {
55
pub function: &'a crate::Function,
66
pub constants: &'a crate::Arena<crate::Constant>,
7+
pub overrides: &'a crate::Arena<crate::Override>,
78

89
pub types_used: &'a mut HandleSet<crate::Type>,
910
pub constants_used: &'a mut HandleSet<crate::Constant>,
@@ -47,6 +48,7 @@ impl<'a> FunctionTracer<'a> {
4748
fn as_expression(&mut self) -> super::expressions::ExpressionTracer {
4849
super::expressions::ExpressionTracer {
4950
constants: self.constants,
51+
overrides: self.overrides,
5052
expressions: &self.function.expressions,
5153

5254
types_used: self.types_used,

naga/src/compact/mod.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,14 @@ pub fn compact(module: &mut crate::Module) {
5454
}
5555
}
5656

57+
// We treat all overrides as used by definition.
58+
for (_, override_) in module.overrides.iter() {
59+
module_tracer.types_used.insert(override_.ty);
60+
if let Some(init) = override_.init {
61+
module_tracer.const_expressions_used.insert(init);
62+
}
63+
}
64+
5765
// We assume that all functions are used.
5866
//
5967
// Observe which types, constant expressions, constants, and
@@ -158,6 +166,15 @@ pub fn compact(module: &mut crate::Module) {
158166
}
159167
});
160168

169+
// Adjust override types and initializers.
170+
log::trace!("adjusting overrides");
171+
for (_, override_) in module.overrides.iter_mut() {
172+
module_map.types.adjust(&mut override_.ty);
173+
if let Some(init) = override_.init.as_mut() {
174+
module_map.const_expressions.adjust(init);
175+
}
176+
}
177+
161178
// Adjust global variables' types and initializers.
162179
log::trace!("adjusting global variables");
163180
for (_, global) in module.global_variables.iter_mut() {
@@ -235,6 +252,7 @@ impl<'module> ModuleTracer<'module> {
235252
expressions::ExpressionTracer {
236253
expressions: &self.module.const_expressions,
237254
constants: &self.module.constants,
255+
overrides: &self.module.overrides,
238256
types_used: &mut self.types_used,
239257
constants_used: &mut self.constants_used,
240258
expressions_used: &mut self.const_expressions_used,
@@ -249,6 +267,7 @@ impl<'module> ModuleTracer<'module> {
249267
FunctionTracer {
250268
function,
251269
constants: &self.module.constants,
270+
overrides: &self.module.overrides,
252271
types_used: &mut self.types_used,
253272
constants_used: &mut self.constants_used,
254273
const_expressions_used: &mut self.const_expressions_used,

naga/src/front/spv/function.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ impl<I: Iterator<Item = u32>> super::Frontend<I> {
128128
expressions: &mut fun.expressions,
129129
local_arena: &mut fun.local_variables,
130130
const_arena: &mut module.constants,
131+
overrides: &mut module.overrides,
131132
const_expressions: &mut module.const_expressions,
132133
type_arena: &module.types,
133134
global_arena: &module.global_variables,
@@ -581,6 +582,7 @@ impl<'function> BlockContext<'function> {
581582
crate::proc::GlobalCtx {
582583
types: self.type_arena,
583584
constants: self.const_arena,
585+
overrides: self.overrides,
584586
const_expressions: self.const_expressions,
585587
}
586588
}

naga/src/front/spv/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,7 @@ struct BlockContext<'function> {
531531
local_arena: &'function mut Arena<crate::LocalVariable>,
532532
/// Constants arena of the module being processed
533533
const_arena: &'function mut Arena<crate::Constant>,
534+
overrides: &'function mut Arena<crate::Override>,
534535
const_expressions: &'function mut Arena<crate::Expression>,
535536
/// Type arena of the module being processed
536537
type_arena: &'function UniqueArena<crate::Type>,
@@ -3934,7 +3935,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
39343935
Op::TypeImage => self.parse_type_image(inst, &mut module),
39353936
Op::TypeSampledImage => self.parse_type_sampled_image(inst),
39363937
Op::TypeSampler => self.parse_type_sampler(inst, &mut module),
3937-
Op::Constant | Op::SpecConstant => self.parse_constant(inst, &mut module),
3938+
Op::Constant => self.parse_constant(inst, &mut module),
39383939
Op::ConstantComposite => self.parse_composite_constant(inst, &mut module),
39393940
Op::ConstantNull | Op::Undef => self.parse_null_constant(inst, &mut module),
39403941
Op::ConstantTrue => self.parse_bool_constant(inst, true, &mut module),

naga/src/front/wgsl/error.rs

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ pub enum Error<'a> {
190190
expected: String,
191191
got: String,
192192
},
193-
MissingType(Span),
193+
DeclMissingTypeAndInit(Span),
194194
MissingAttribute(&'static str, Span),
195195
InvalidAtomicPointer(Span),
196196
InvalidAtomicOperandType(Span),
@@ -273,6 +273,7 @@ pub enum Error<'a> {
273273
span: Span,
274274
limit: u8,
275275
},
276+
PipelineConstantIDValue(Span),
276277
}
277278

278279
impl<'a> Error<'a> {
@@ -522,11 +523,11 @@ impl<'a> Error<'a> {
522523
notes: vec![],
523524
}
524525
}
525-
Error::MissingType(name_span) => ParseError {
526-
message: format!("variable `{}` needs a type", &source[name_span]),
526+
Error::DeclMissingTypeAndInit(name_span) => ParseError {
527+
message: format!("declaration of `{}` needs a type specifier or initializer", &source[name_span]),
527528
labels: vec![(
528529
name_span,
529-
format!("definition of `{}`", &source[name_span]).into(),
530+
"needs a type specifier or initializer".into(),
530531
)],
531532
notes: vec![],
532533
},
@@ -781,6 +782,14 @@ impl<'a> Error<'a> {
781782
format!("nesting limit is currently set to {limit}"),
782783
],
783784
},
785+
Error::PipelineConstantIDValue(span) => ParseError {
786+
message: "pipeline constant ID must be between 0 and 65535 inclusive".to_string(),
787+
labels: vec![(
788+
span,
789+
"must be between 0 and 65535 inclusive".into(),
790+
)],
791+
notes: vec![],
792+
},
784793
}
785794
}
786795
}

naga/src/front/wgsl/index.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ const fn decl_ident<'a>(decl: &ast::GlobalDecl<'a>) -> ast::Ident<'a> {
187187
ast::GlobalDeclKind::Fn(ref f) => f.name,
188188
ast::GlobalDeclKind::Var(ref v) => v.name,
189189
ast::GlobalDeclKind::Const(ref c) => c.name,
190+
ast::GlobalDeclKind::Override(ref o) => o.name,
190191
ast::GlobalDeclKind::Struct(ref s) => s.name,
191192
ast::GlobalDeclKind::Type(ref t) => t.name,
192193
}

naga/src/front/wgsl/lower/mod.rs

Lines changed: 66 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -786,6 +786,7 @@ enum LoweredGlobalDecl {
786786
Function(Handle<crate::Function>),
787787
Var(Handle<crate::GlobalVariable>),
788788
Const(Handle<crate::Constant>),
789+
Override(Handle<crate::Override>),
789790
Type(Handle<crate::Type>),
790791
EntryPoint,
791792
}
@@ -965,6 +966,65 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
965966
ctx.globals
966967
.insert(c.name.name, LoweredGlobalDecl::Const(handle));
967968
}
969+
ast::GlobalDeclKind::Override(ref o) => {
970+
let init = o
971+
.init
972+
.map(|init| self.expression(init, &mut ctx.as_const()))
973+
.transpose()?;
974+
let inferred_type = init
975+
.map(|init| ctx.as_const().register_type(init))
976+
.transpose()?;
977+
978+
let explicit_ty =
979+
o.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx))
980+
.transpose()?;
981+
982+
let id =
983+
o.id.map(|id| self.const_u32(id, &mut ctx.as_const()))
984+
.transpose()?;
985+
986+
let id = if let Some((id, id_span)) = id {
987+
Some(
988+
u16::try_from(id)
989+
.map_err(|_| Error::PipelineConstantIDValue(id_span))?,
990+
)
991+
} else {
992+
None
993+
};
994+
995+
let ty = match (explicit_ty, inferred_type) {
996+
(Some(explicit_ty), Some(inferred_type)) => {
997+
if explicit_ty == inferred_type {
998+
explicit_ty
999+
} else {
1000+
let gctx = ctx.module.to_ctx();
1001+
return Err(Error::InitializationTypeMismatch {
1002+
name: o.name.span,
1003+
expected: explicit_ty.to_wgsl(&gctx),
1004+
got: inferred_type.to_wgsl(&gctx),
1005+
});
1006+
}
1007+
}
1008+
(Some(explicit_ty), None) => explicit_ty,
1009+
(None, Some(inferred_type)) => inferred_type,
1010+
(None, None) => {
1011+
return Err(Error::DeclMissingTypeAndInit(o.name.span));
1012+
}
1013+
};
1014+
1015+
let handle = ctx.module.overrides.append(
1016+
crate::Override {
1017+
name: Some(o.name.name.to_string()),
1018+
id,
1019+
ty,
1020+
init,
1021+
},
1022+
span,
1023+
);
1024+
1025+
ctx.globals
1026+
.insert(o.name.name, LoweredGlobalDecl::Override(handle));
1027+
}
9681028
ast::GlobalDeclKind::Struct(ref s) => {
9691029
let handle = self.r#struct(s, span, &mut ctx)?;
9701030
ctx.globals
@@ -1202,7 +1262,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
12021262
ty = explicit_ty;
12031263
initializer = None;
12041264
}
1205-
(None, None) => return Err(Error::MissingType(v.name.span)),
1265+
(None, None) => return Err(Error::DeclMissingTypeAndInit(v.name.span)),
12061266
}
12071267

12081268
let (const_initializer, initializer) = {
@@ -1818,9 +1878,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
18181878
)?;
18191879
Ok(Some(handle))
18201880
}
1821-
Some(&LoweredGlobalDecl::Const(_) | &LoweredGlobalDecl::Var(_)) => {
1822-
Err(Error::Unexpected(function.span, ExpectedToken::Function))
1823-
}
1881+
Some(
1882+
&LoweredGlobalDecl::Const(_)
1883+
| &LoweredGlobalDecl::Override(_)
1884+
| &LoweredGlobalDecl::Var(_),
1885+
) => Err(Error::Unexpected(function.span, ExpectedToken::Function)),
18241886
Some(&LoweredGlobalDecl::EntryPoint) => Err(Error::CalledEntryPoint(function.span)),
18251887
Some(&LoweredGlobalDecl::Function(function)) => {
18261888
let arguments = arguments

naga/src/front/wgsl/parse/ast.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ pub enum GlobalDeclKind<'a> {
8282
Fn(Function<'a>),
8383
Var(GlobalVariable<'a>),
8484
Const(Const<'a>),
85+
Override(Override<'a>),
8586
Struct(Struct<'a>),
8687
Type(TypeAlias<'a>),
8788
}
@@ -200,6 +201,14 @@ pub struct Const<'a> {
200201
pub init: Handle<Expression<'a>>,
201202
}
202203

204+
#[derive(Debug)]
205+
pub struct Override<'a> {
206+
pub name: Ident<'a>,
207+
pub id: Option<Handle<Expression<'a>>>,
208+
pub ty: Option<Handle<Type<'a>>>,
209+
pub init: Option<Handle<Expression<'a>>>,
210+
}
211+
203212
/// The size of an [`Array`] or [`BindingArray`].
204213
///
205214
/// [`Array`]: Type::Array

0 commit comments

Comments
 (0)