diff --git a/CHANGELOG.md b/CHANGELOG.md index 78075b4a06..602799ed08 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -95,6 +95,27 @@ By @stefnotch in [#5410](https://github.com/gfx-rs/wgpu/pull/5410) - Implement `WGSL`'s `unpack4xI8`,`unpack4xU8`,`pack4xI8` and `pack4xU8`. By @VlaDexa in [#5424](https://github.com/gfx-rs/wgpu/pull/5424) - Began work adding support for atomics to the SPIR-V frontend. Tracking issue is [here](https://github.com/gfx-rs/wgpu/issues/4489). By @schell in [#5702](https://github.com/gfx-rs/wgpu/pull/5702). +- Compile shaders with a "base module". This allows for cleanly implementing imports, or shadertoy-like environments. By @stefnotch in [#5791](https://github.com/gfx-rs/wgpu/pull/5791). + ```rust + use naga::front::wgsl::Frontend; + + let base_module = Frontend::new() + .parse(" + fn main_image(frag_coord: vec2f) -> vec4f { + return vec4f(sin(frag_coord.x), cos(frag_coord.y), 0.0, 1.0); + }").unwrap(); + // A full shadertoy implementation would rename all globals from base_module, except for `main_image`. + let result = Frontend::new() + .parse_to_ast(" + @fragment + fn fs_main(@builtin(position) pos : vec4f, @location(0) uv: vec2f) -> vec4f { + return main_image(uv); + }") + .unwrap() + .to_module(Some(&base_module)) + .unwrap(); + ``` + ### Changes diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index e7cce17723..201c104d0a 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -913,8 +913,9 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { pub fn lower( &mut self, tu: &'temp ast::TranslationUnit<'source>, + base_module: Option<&'source crate::Module>, ) -> Result> { - let mut module = crate::Module::default(); + let mut module = base_module.map(|v| v.to_owned()).unwrap_or_default(); let mut ctx = GlobalContext { ast_expressions: &tu.expressions, @@ -925,6 +926,70 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker::new(), }; + if let Some(base_module) = base_module { + // The handles for base_module are equal to the handles for ctx.module, because we just cloned the arenas. + for (handle, f) in base_module.functions.iter() { + if let Some(name) = f.name.as_ref() { + ctx.globals + .insert(name, LoweredGlobalDecl::Function(handle)); + } + } + for (handle, v) in base_module.global_variables.iter() { + if let Some(name) = v.name.as_ref() { + ctx.globals.insert(name, LoweredGlobalDecl::Var(handle)); + } + } + for (handle, c) in base_module.constants.iter() { + if let Some(name) = c.name.as_ref() { + ctx.globals.insert(name, LoweredGlobalDecl::Const(handle)); + } + } + for (handle, o) in base_module.overrides.iter() { + if let Some(name) = o.name.as_ref() { + ctx.globals + .insert(name, LoweredGlobalDecl::Override(handle)); + } + } + for (handle, t) in base_module.types.iter() { + if let Some(name) = t.name.as_ref() { + ctx.globals.insert(name, LoweredGlobalDecl::Type(handle)); + } + } + for entry_point in base_module.entry_points.iter() { + ctx.globals + .insert(entry_point.name.as_str(), LoweredGlobalDecl::EntryPoint); + } + *ctx.global_expression_kind_tracker = + crate::proc::ExpressionKindTracker::from_arena(&ctx.module.global_expressions); + } + + // check for redefinitions + for (_, decl) in tu.decls.iter() { + let ident = match decl.kind { + ast::GlobalDeclKind::Fn(ref f) => f.name, + ast::GlobalDeclKind::Var(ref v) => v.name, + ast::GlobalDeclKind::Const(ref c) => c.name, + ast::GlobalDeclKind::Override(ref o) => o.name, + ast::GlobalDeclKind::Struct(ref s) => s.name, + ast::GlobalDeclKind::Type(ref t) => t.name, + }; + if let Some(old) = ctx.globals.get(ident.name) { + let span = match *old { + LoweredGlobalDecl::Function(handle) => ctx.module.functions.get_span(handle), + LoweredGlobalDecl::Var(handle) => ctx.module.global_variables.get_span(handle), + LoweredGlobalDecl::Const(handle) => ctx.module.constants.get_span(handle), + LoweredGlobalDecl::Override(handle) => ctx.module.overrides.get_span(handle), + LoweredGlobalDecl::Type(handle) => ctx.module.types.get_span(handle), + // We don't have good spans for entry points + LoweredGlobalDecl::EntryPoint => Default::default(), + }; + return Err(Error::Redefinition { + previous: span, + current: ident.span, + }); + } + } + for decl_handle in self.index.visit_ordered() { let span = tu.decls.get_span(decl_handle); let decl = &tu.decls[decl_handle]; diff --git a/naga/src/front/wgsl/mod.rs b/naga/src/front/wgsl/mod.rs index aec1e657fc..919d199140 100644 --- a/naga/src/front/wgsl/mod.rs +++ b/naga/src/front/wgsl/mod.rs @@ -20,6 +20,8 @@ pub use crate::front::wgsl::error::ParseError; use crate::front::wgsl::lower::Lowerer; use crate::Scalar; +use self::parse::ast::TranslationUnit; + pub struct Frontend { parser: Parser, } @@ -32,18 +34,46 @@ impl Frontend { } pub fn parse(&mut self, source: &str) -> Result { - self.inner(source).map_err(|x| x.as_parse_error(source)) + self.parse_to_ast(source)?.to_module(None) } - fn inner<'a>(&mut self, source: &'a str) -> Result> { - let tu = self.parser.parse(source)?; - let index = index::Index::generate(&tu)?; - let module = Lowerer::new(&index).lower(&tu)?; + /// Two-step module conversion, can be used to compile with a "base module". + pub fn parse_to_ast<'a>(&mut self, source: &'a str) -> Result, ParseError> { + self.inner_to_ast(source) + .map_err(|x| x.as_parse_error(source)) + } - Ok(module) + fn inner_to_ast<'a>(&mut self, source: &'a str) -> Result, Error<'a>> { + let translation_unit = self.parser.parse(source)?; + let index = index::Index::generate(&translation_unit)?; + Ok(ParsedWgsl { + source, + translation_unit, + index, + }) } } +pub struct ParsedWgsl<'a> { + source: &'a str, + translation_unit: TranslationUnit<'a>, + index: index::Index<'a>, +} +impl<'a> ParsedWgsl<'a> { + pub fn to_module( + &self, + base_module: Option<&crate::Module>, + ) -> Result { + self.inner_to_module(base_module) + .map_err(|x| x.as_parse_error(self.source)) + } + fn inner_to_module( + &self, + base_module: Option<&'a crate::Module>, + ) -> Result> { + Lowerer::new(&self.index).lower(&self.translation_unit, base_module) + } +} ///
// NOTE: Keep this in sync with `wgpu::Device::create_shader_module`! // NOTE: Keep this in sync with `wgpu_core::Global::device_create_shader_module`! diff --git a/naga/src/front/wgsl/tests.rs b/naga/src/front/wgsl/tests.rs index cc3d858317..31697d40fd 100644 --- a/naga/src/front/wgsl/tests.rs +++ b/naga/src/front/wgsl/tests.rs @@ -614,7 +614,9 @@ fn parse_repeated_attributes() { let span_end = span_start + name_length; let expected_span = Span::new(span_start, span_end); - let result = Frontend::new().inner(&shader); + let result = Frontend::new() + .inner_to_ast(&shader) + .and_then(|v| v.inner_to_module(None)); assert!(matches!( result.unwrap_err(), Error::RepeatedAttribute(span) if span == expected_span @@ -630,9 +632,248 @@ fn parse_missing_workgroup_size() { }; let shader = "@compute fn vs() -> vec4 { return vec4(0.0); }"; - let result = Frontend::new().inner(shader); + let result = Frontend::new() + .inner_to_ast(shader) + .and_then(|v| v.inner_to_module(None)); assert!(matches!( result.unwrap_err(), Error::MissingWorkgroupSize(span) if span == Span::new(1, 8) )); } + +#[test] +fn parse_base_module_constant() { + use crate::front::wgsl::Frontend; + + let base_module = parse_str("const some_constant_value: f32 = 1.0;").unwrap(); + let shader = "fn foo() -> vec4 { return vec4(some_constant_value); }"; + Frontend::new() + .parse_to_ast(shader) + .unwrap() + .to_module(Some(&base_module)) + .unwrap(); +} + +/* Add this test once abstract numerics are supported. +Although, depending on the implementation, this may not work anyways. +#[test] +fn parse_base_module_abstract_numerics() { + use crate::front::wgsl::Frontend; + + let base_module = + parse_str("const two = 2;") + .unwrap(); + let shader = "const four: u32 = 2 * two; const signed_two: i32 = two;"; + Frontend::new() + .parse_to_ast(shader) + .unwrap() + .to_module(Some(&base_module)) + .unwrap(); +} +*/ +#[test] +fn parse_structs_with_base_module_structs() { + use crate::front::wgsl::Frontend; + + let base_module = parse_str("struct Bar { a: vec4f, b: Foo }; struct Foo { c: f32 }").unwrap(); + let shader = "fn foo(foo: Foo) -> Bar { return Bar(vec4(1.0), foo); }"; + Frontend::new() + .parse_to_ast(shader) + .unwrap() + .to_module(Some(&base_module)) + .unwrap(); +} + +#[test] +fn parse_fn_with_base_module() { + use crate::front::wgsl::Frontend; + + let base_module = + parse_str("fn cat() -> Foo { return Foo(1.0); }; struct Foo { c: f32 }").unwrap(); + let shader = "fn foo() -> f32 { return cat().c; }"; + Frontend::new() + .parse_to_ast(shader) + .unwrap() + .to_module(Some(&base_module)) + .unwrap(); +} + +#[test] +fn parse_fn_conflict_with_base_module() { + use crate::front::wgsl::{error::Error, Frontend}; + + let base_module = parse_str("fn cat() -> f32 { return 1.0; }").unwrap(); + let shader = "fn cat() -> f32 { return 2.0; }"; + let result = Frontend::new() + .inner_to_ast(shader) + .and_then(|v| v.inner_to_module(Some(&base_module))); + assert!(matches!(result, Err(Error::Redefinition { .. }))); +} + +#[test] +fn parse_base_module_alias() { + use crate::front::wgsl::Frontend; + + let base_module = + parse_str("alias number = u32; struct Bar { a: number, b: Foo }; alias Foo = i32;") + .unwrap(); + let shader = "fn foo(a: u32) -> Bar { return Bar(a, -1i); }"; + Frontend::new() + .parse_to_ast(shader) + .unwrap() + .to_module(Some(&base_module)) + .unwrap(); +} + +#[test] +fn parse_base_module_alias_usage() { + use crate::front::wgsl::Frontend; + + let base_module = parse_str("alias number = u32; struct Bar { a: u32 };").unwrap(); + let shader = "fn foo(a: number) -> Bar { return Bar(a); }"; + Frontend::new() + .parse_to_ast(shader) + .unwrap() + .to_module(Some(&base_module)) + .unwrap(); +} + +/* Add this once https://github.com/gfx-rs/wgpu/issues/5786 is fixed +#[test] +fn parse_base_module_alias_predeclared() { + use crate::front::wgsl::Frontend; + + let base_module = + parse_str("alias vec4f = u32; struct Bar { a: vec4f, b: Foo }; alias Foo = i32;").unwrap(); + let shader = "fn foo(a: u32) -> Bar { return Bar(a, -1i); }"; + Frontend::new() + .parse_to_ast(shader) + .unwrap() + .to_module(Some(&base_module)) + .unwrap(); +} + */ + +#[test] +fn parse_base_module_function_predefined() { + use crate::front::wgsl::Frontend; + + // Changing a predefined function should affect the shader + let base_module = parse_str( + "fn bar(a: f32) -> f32 { return cos(vec3f(a)).z; } fn cos(a: vec3f) -> vec3f { return a.xyz; }", + ) + .unwrap(); + let shader = "fn foo(a: f32) -> f32 { return cos(1.0); }"; + Frontend::new() + .parse_to_ast(shader) + .unwrap() + .to_module(Some(&base_module)) + .unwrap(); +} + +#[test] +fn parse_base_module_function_predefined_no_leak() { + use crate::front::wgsl::Frontend; + + // But changing a predefined function in the shader should not affect the base module + let base_module = parse_str("fn bar(a: f32) -> f32 { return cos(a); }").unwrap(); + let shader = + "fn foo(a: f32) -> f32 { return cos(vec3f(a)).z + bar(a); } fn cos(a: vec3f) -> vec3f { return a.xyz; }"; + Frontend::new() + .parse_to_ast(shader) + .unwrap() + .to_module(Some(&base_module)) + .unwrap(); +} + +#[test] +fn parse_base_module_twice() { + use crate::front::wgsl::Frontend; + + let base_module_a = parse_str("fn bar(a: f32) -> f32 { return cos(a); }").unwrap(); + let shader_a = "fn foo(a: f32) -> f32 { return bar(a); }"; + let base_module_b = Frontend::new() + .parse_to_ast(shader_a) + .unwrap() + .to_module(Some(&base_module_a)) + .unwrap(); + let shader_b = "fn foobar(a: f32) -> f32 { return bar(a) + foo(a); }"; + Frontend::new() + .parse_to_ast(shader_b) + .unwrap() + .to_module(Some(&base_module_b)) + .unwrap(); +} + +#[test] +fn parse_base_module_const_conflict() { + use crate::front::wgsl::{error::Error, Frontend}; + + let base_module = parse_str("const foo: f32 = 1.0;").unwrap(); + let shader = "fn foo() -> f32 { return 2.0; }"; + let result = Frontend::new() + .inner_to_ast(shader) + .and_then(|v| v.inner_to_module(Some(&base_module))); + assert!(matches!(result, Err(Error::Redefinition { .. }))); +} + +#[test] +fn parse_base_module_const_local() { + use crate::front::wgsl::Frontend; + // Const in base module with same name as a local variable in the actual module shouldn't cause conflicts + + let base_module = parse_str("const foo: vec3f = vec3f(1.0);").unwrap(); + let shader = "fn bar() -> f32 { let foo: f32 = 2.0; return foo; }"; + Frontend::new() + .parse_to_ast(shader) + .unwrap() + .to_module(Some(&base_module)) + .unwrap(); +} + +#[test] +fn parse_base_module_entry_points() { + use crate::front::wgsl::Frontend; + + let base_module = + parse_str("@vertex fn vs() -> @builtin(position) vec4f { return vec4(0.0); }") + .unwrap(); + let shader = "@fragment fn fs() -> @location(0) vec4f { return vec4(1.0); }"; + let result = Frontend::new() + .parse_to_ast(shader) + .unwrap() + .to_module(Some(&base_module)) + .unwrap(); + assert_eq!(result.entry_points.len(), 2); +} + +#[test] +fn parse_base_module_pipeline_overridable_constants() { + use crate::front::wgsl::Frontend; + + let base_module = parse_str("override diffuse_param: f32 = 2.3;").unwrap(); + let shader = "override specular_param: f32 = 0.1; fn foo() -> f32 { return diffuse_param + specular_param; }"; + Frontend::new() + .parse_to_ast(shader) + .unwrap() + .to_module(Some(&base_module)) + .unwrap(); +} + +#[test] +fn parse_base_module_storage_buffers() { + use crate::front::wgsl::Frontend; + + let base_module = parse_str( + "@group(0) @binding(0) + var foo: array;", + ) + .unwrap(); + let shader = "@group(0) @binding(1) + var bar: array;"; + Frontend::new() + .parse_to_ast(shader) + .unwrap() + .to_module(Some(&base_module)) + .unwrap(); +}