Skip to content

[wgsl-in] add support for override declarations #4793

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions naga/src/back/dot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ fn write_function_expressions(
let (label, color_id) = match *expression {
E::Literal(_) => ("Literal".into(), 2),
E::Constant(_) => ("Constant".into(), 2),
E::Override(_) => ("Override".into(), 2),
E::ZeroValue(_) => ("ZeroValue".into(), 2),
E::Compose { ref components, .. } => {
payload = Some(Payload::Arguments(components));
Expand Down
1 change: 1 addition & 0 deletions naga/src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2525,6 +2525,7 @@ impl<'a, W: Write> Writer<'a, W> {
|writer, expr| writer.write_expr(expr, ctx),
)?;
}
Expression::Override(_) => return Err(Error::Custom("overrides are WIP".into())),
// `Access` is applied to arrays, vectors and matrices and is written as indexing
Expression::Access { base, index } => {
self.write_expr(base, ctx)?;
Expand Down
3 changes: 3 additions & 0 deletions naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2151,6 +2151,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|writer, expr| writer.write_expr(module, expr, func_ctx),
)?;
}
Expression::Override(_) => {
return Err(Error::Unimplemented("overrides are WIP".into()))
}
// All of the multiplication can be expressed as `mul`,
// except vector * vector, which needs to use the "*" operator.
Expression::Binary {
Expand Down
3 changes: 3 additions & 0 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1394,6 +1394,9 @@ impl<W: Write> Writer<W> {
|writer, context, expr| writer.put_expression(expr, context, true),
)?;
}
crate::Expression::Override(_) => {
return Err(Error::FeatureNotImplemented("overrides are WIP".into()))
}
crate::Expression::Access { base, .. }
| crate::Expression::AccessIndex { base, .. } => {
// This is an acceptable place to generate a `ReadZeroSkipWrite` check.
Expand Down
3 changes: 3 additions & 0 deletions naga/src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,9 @@ impl<'w> BlockContext<'w> {
let init = self.ir_module.constants[handle].init;
self.writer.constant_ids[init.index()]
}
crate::Expression::Override(_) => {
return Err(Error::FeatureNotImplemented("overrides are WIP"))
}
crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id),
crate::Expression::Compose { ty, ref components } => {
self.temp_list.clear();
Expand Down
3 changes: 3 additions & 0 deletions naga/src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1180,6 +1180,9 @@ impl<W: Write> Writer<W> {
|writer, expr| writer.write_expr(module, expr, func_ctx),
)?;
}
Expression::Override(_) => {
return Err(Error::Unimplemented("overrides are WIP".into()))
}
Expression::FunctionArgument(pos) => {
let name_key = func_ctx.argument_key(pos);
let name = &self.names[&name_key];
Expand Down
9 changes: 9 additions & 0 deletions naga/src/compact/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::arena::{Arena, Handle};

pub struct ExpressionTracer<'tracer> {
pub constants: &'tracer Arena<crate::Constant>,
pub overrides: &'tracer Arena<crate::Override>,

/// The arena in which we are currently tracing expressions.
pub expressions: &'tracer Arena<crate::Expression>,
Expand Down Expand Up @@ -88,6 +89,11 @@ impl<'tracer> ExpressionTracer<'tracer> {
None => self.expressions_used.insert(init),
}
}
Ex::Override(_) => {
// All overrides are considered used by definition. We mark
// their types and initialization expressions as used in
// `compact::compact`, so we have no more work to do here.
}
Ex::ZeroValue(ty) => self.types_used.insert(ty),
Ex::Compose { ty, ref components } => {
self.types_used.insert(ty);
Expand Down Expand Up @@ -219,6 +225,9 @@ impl ModuleMap {
| Ex::CallResult(_)
| Ex::RayQueryProceedResult => {}

// All overrides are retained, so their handles never change.
Ex::Override(_) => {}

// Expressions that contain handles that need to be adjusted.
Ex::Constant(ref mut constant) => self.constants.adjust(constant),
Ex::ZeroValue(ref mut ty) => self.types.adjust(ty),
Expand Down
2 changes: 2 additions & 0 deletions naga/src/compact/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use super::{FunctionMap, ModuleMap};
pub struct FunctionTracer<'a> {
pub function: &'a crate::Function,
pub constants: &'a crate::Arena<crate::Constant>,
pub overrides: &'a crate::Arena<crate::Override>,

pub types_used: &'a mut HandleSet<crate::Type>,
pub constants_used: &'a mut HandleSet<crate::Constant>,
Expand Down Expand Up @@ -47,6 +48,7 @@ impl<'a> FunctionTracer<'a> {
fn as_expression(&mut self) -> super::expressions::ExpressionTracer {
super::expressions::ExpressionTracer {
constants: self.constants,
overrides: self.overrides,
expressions: &self.function.expressions,

types_used: self.types_used,
Expand Down
19 changes: 19 additions & 0 deletions naga/src/compact/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ pub fn compact(module: &mut crate::Module) {
}
}

// We treat all overrides as used by definition.
for (_, override_) in module.overrides.iter() {
module_tracer.types_used.insert(override_.ty);
if let Some(init) = override_.init {
module_tracer.const_expressions_used.insert(init);
}
}

// We assume that all functions are used.
//
// Observe which types, constant expressions, constants, and
Expand Down Expand Up @@ -158,6 +166,15 @@ pub fn compact(module: &mut crate::Module) {
}
});

// Adjust override types and initializers.
log::trace!("adjusting overrides");
for (_, override_) in module.overrides.iter_mut() {
module_map.types.adjust(&mut override_.ty);
if let Some(init) = override_.init.as_mut() {
module_map.const_expressions.adjust(init);
}
}

// Adjust global variables' types and initializers.
log::trace!("adjusting global variables");
for (_, global) in module.global_variables.iter_mut() {
Expand Down Expand Up @@ -235,6 +252,7 @@ impl<'module> ModuleTracer<'module> {
expressions::ExpressionTracer {
expressions: &self.module.const_expressions,
constants: &self.module.constants,
overrides: &self.module.overrides,
types_used: &mut self.types_used,
constants_used: &mut self.constants_used,
expressions_used: &mut self.const_expressions_used,
Expand All @@ -249,6 +267,7 @@ impl<'module> ModuleTracer<'module> {
FunctionTracer {
function,
constants: &self.module.constants,
overrides: &self.module.overrides,
types_used: &mut self.types_used,
constants_used: &mut self.constants_used,
const_expressions_used: &mut self.const_expressions_used,
Expand Down
2 changes: 2 additions & 0 deletions naga/src/front/spv/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ impl<I: Iterator<Item = u32>> super::Frontend<I> {
expressions: &mut fun.expressions,
local_arena: &mut fun.local_variables,
const_arena: &mut module.constants,
overrides: &mut module.overrides,
const_expressions: &mut module.const_expressions,
type_arena: &module.types,
global_arena: &module.global_variables,
Expand Down Expand Up @@ -573,6 +574,7 @@ impl<'function> BlockContext<'function> {
crate::proc::GlobalCtx {
types: self.type_arena,
constants: self.const_arena,
overrides: self.overrides,
const_expressions: self.const_expressions,
}
}
Expand Down
3 changes: 2 additions & 1 deletion naga/src/front/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,7 @@ struct BlockContext<'function> {
local_arena: &'function mut Arena<crate::LocalVariable>,
/// Constants arena of the module being processed
const_arena: &'function mut Arena<crate::Constant>,
overrides: &'function mut Arena<crate::Override>,
const_expressions: &'function mut Arena<crate::Expression>,
/// Type arena of the module being processed
type_arena: &'function UniqueArena<crate::Type>,
Expand Down Expand Up @@ -3933,7 +3934,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
Op::TypeImage => self.parse_type_image(inst, &mut module),
Op::TypeSampledImage => self.parse_type_sampled_image(inst),
Op::TypeSampler => self.parse_type_sampler(inst, &mut module),
Op::Constant | Op::SpecConstant => self.parse_constant(inst, &mut module),
Op::Constant => self.parse_constant(inst, &mut module),
Op::ConstantComposite => self.parse_composite_constant(inst, &mut module),
Op::ConstantNull | Op::Undef => self.parse_null_constant(inst, &mut module),
Op::ConstantTrue => self.parse_bool_constant(inst, true, &mut module),
Expand Down
17 changes: 13 additions & 4 deletions naga/src/front/wgsl/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ pub enum Error<'a> {
expected: String,
got: String,
},
MissingType(Span),
DeclMissingTypeAndInit(Span),
MissingAttribute(&'static str, Span),
InvalidAtomicPointer(Span),
InvalidAtomicOperandType(Span),
Expand Down Expand Up @@ -251,6 +251,7 @@ pub enum Error<'a> {
ExpectedPositiveArrayLength(Span),
MissingWorkgroupSize(Span),
ConstantEvaluatorError(ConstantEvaluatorError, Span),
PipelineConstantIDValue(Span),
}

impl<'a> Error<'a> {
Expand Down Expand Up @@ -500,11 +501,11 @@ impl<'a> Error<'a> {
notes: vec![],
}
}
Error::MissingType(name_span) => ParseError {
message: format!("variable `{}` needs a type", &source[name_span]),
Error::DeclMissingTypeAndInit(name_span) => ParseError {
message: format!("declaration of `{}` needs a type specifier or initializer", &source[name_span]),
labels: vec![(
name_span,
format!("definition of `{}`", &source[name_span]).into(),
"needs a type specifier or initializer".into(),
)],
notes: vec![],
},
Expand Down Expand Up @@ -712,6 +713,14 @@ impl<'a> Error<'a> {
)],
notes: vec![],
},
Error::PipelineConstantIDValue(span) => ParseError {
message: "pipeline constant ID must be between 0 and 65535 inclusive".to_string(),
labels: vec![(
span,
"must be between 0 and 65535 inclusive".into(),
)],
notes: vec![],
},
}
}
}
1 change: 1 addition & 0 deletions naga/src/front/wgsl/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ const fn decl_ident<'a>(decl: &ast::GlobalDecl<'a>) -> ast::Ident<'a> {
ast::GlobalDeclKind::Fn(ref f) => f.name,
ast::GlobalDeclKind::Var(ref v) => v.name,
ast::GlobalDeclKind::Const(ref c) => c.name,
ast::GlobalDeclKind::Override(ref o) => o.name,
ast::GlobalDeclKind::Struct(ref s) => s.name,
ast::GlobalDeclKind::Type(ref t) => t.name,
}
Expand Down
70 changes: 66 additions & 4 deletions naga/src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -784,6 +784,7 @@ enum LoweredGlobalDecl {
Function(Handle<crate::Function>),
Var(Handle<crate::GlobalVariable>),
Const(Handle<crate::Constant>),
Override(Handle<crate::Override>),
Type(Handle<crate::Type>),
EntryPoint,
}
Expand Down Expand Up @@ -933,6 +934,65 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
ctx.globals
.insert(c.name.name, LoweredGlobalDecl::Const(handle));
}
ast::GlobalDeclKind::Override(ref o) => {
let init = o
.init
.map(|init| self.expression(init, &mut ctx.as_const()))
.transpose()?;
let inferred_type = init
.map(|init| ctx.as_const().register_type(init))
.transpose()?;

let explicit_ty =
o.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx))
.transpose()?;

let id =
o.id.map(|id| self.const_u32(id, &mut ctx.as_const()))
.transpose()?;

let id = if let Some((id, id_span)) = id {
Some(
u16::try_from(id)
.map_err(|_| Error::PipelineConstantIDValue(id_span))?,
)
} else {
None
};

let ty = match (explicit_ty, inferred_type) {
(Some(explicit_ty), Some(inferred_type)) => {
if explicit_ty == inferred_type {
explicit_ty
} else {
let gctx = ctx.module.to_ctx();
return Err(Error::InitializationTypeMismatch {
name: o.name.span,
expected: explicit_ty.to_wgsl(&gctx),
got: inferred_type.to_wgsl(&gctx),
});
}
}
(Some(explicit_ty), None) => explicit_ty,
(None, Some(inferred_type)) => inferred_type,
(None, None) => {
return Err(Error::DeclMissingTypeAndInit(o.name.span));
}
};

let handle = ctx.module.overrides.append(
crate::Override {
name: Some(o.name.name.to_string()),
id,
ty,
init,
},
span,
);

ctx.globals
.insert(o.name.name, LoweredGlobalDecl::Override(handle));
}
ast::GlobalDeclKind::Struct(ref s) => {
let handle = self.r#struct(s, span, &mut ctx)?;
ctx.globals
Expand Down Expand Up @@ -1160,7 +1220,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.as_expression(block, &mut emitter)
.register_type(initializer)?,
(None, None) => {
return Err(Error::MissingType(v.name.span));
return Err(Error::DeclMissingTypeAndInit(v.name.span));
}
};

Expand Down Expand Up @@ -1715,9 +1775,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
)?;
Ok(Some(handle))
}
Some(&LoweredGlobalDecl::Const(_) | &LoweredGlobalDecl::Var(_)) => {
Err(Error::Unexpected(function.span, ExpectedToken::Function))
}
Some(
&LoweredGlobalDecl::Const(_)
| &LoweredGlobalDecl::Override(_)
| &LoweredGlobalDecl::Var(_),
) => Err(Error::Unexpected(function.span, ExpectedToken::Function)),
Some(&LoweredGlobalDecl::EntryPoint) => Err(Error::CalledEntryPoint(function.span)),
Some(&LoweredGlobalDecl::Function(function)) => {
let arguments = arguments
Expand Down
9 changes: 9 additions & 0 deletions naga/src/front/wgsl/parse/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ pub enum GlobalDeclKind<'a> {
Fn(Function<'a>),
Var(GlobalVariable<'a>),
Const(Const<'a>),
Override(Override<'a>),
Struct(Struct<'a>),
Type(TypeAlias<'a>),
}
Expand Down Expand Up @@ -200,6 +201,14 @@ pub struct Const<'a> {
pub init: Handle<Expression<'a>>,
}

#[derive(Debug)]
pub struct Override<'a> {
pub name: Ident<'a>,
pub id: Option<Handle<Expression<'a>>>,
pub ty: Option<Handle<Type<'a>>>,
pub init: Option<Handle<Expression<'a>>>,
}

/// The size of an [`Array`] or [`BindingArray`].
///
/// [`Array`]: Type::Array
Expand Down
Loading