Skip to content

Commit 3e875d7

Browse files
committed
[naga wgsl] Add to_wgsl functions for some Naga IR types.
Add `to_wgsl` functions that return the WGSL representations of the following types as `&'static str` values: - `MathFunction` - `BuiltIn` - `Interpolation` - `Sampling` Use these functions in the WGSL backend. Also add `to_wgsl_debug` functions that return some legible string for all possible values, not just those that are defined by the WGSL standard.
1 parent b7d1f4c commit 3e875d7

File tree

3 files changed

+203
-144
lines changed

3 files changed

+203
-144
lines changed

naga/src/back/wgsl/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ pub enum Error {
1919
Custom(String),
2020
#[error("{0}")]
2121
Unimplemented(String), // TODO: Error used only during development
22-
#[error("Unsupported math function: {0:?}")]
22+
#[error("Unsupported math function: {}", .0.to_wgsl_debug())]
2323
UnsupportedMathFunction(crate::MathFunction),
2424
#[error("Unsupported relational function: {0:?}")]
2525
UnsupportedRelationalFunction(crate::RelationalFunction),

naga/src/back/wgsl/writer.rs

Lines changed: 26 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,12 @@ impl<W: Write> Writer<W> {
342342
Attribute::Location(id) => write!(self.out, "@location({id}) ")?,
343343
Attribute::SecondBlendSource => write!(self.out, "@second_blend_source ")?,
344344
Attribute::BuiltIn(builtin_attrib) => {
345-
let builtin = builtin_str(builtin_attrib)?;
345+
let builtin = builtin_attrib.to_wgsl().ok_or_else(|| {
346+
Error::Custom(format!(
347+
"Unsupported builtin `{}`",
348+
builtin_attrib.to_wgsl_debug()
349+
))
350+
})?;
346351
write!(self.out, "@builtin({builtin}) ")?;
347352
}
348353
Attribute::Stage(shader_stage) => {
@@ -368,20 +373,20 @@ impl<W: Write> Writer<W> {
368373
write!(
369374
self.out,
370375
"@interpolate({}, {}) ",
371-
interpolation_str(
372-
interpolation.unwrap_or(crate::Interpolation::Perspective)
373-
),
374-
sampling_str(sampling.unwrap_or(crate::Sampling::Center))
376+
interpolation
377+
.unwrap_or(crate::Interpolation::Perspective)
378+
.to_wgsl(),
379+
sampling.unwrap_or(crate::Sampling::Center).to_wgsl(),
375380
)?;
376381
} else if interpolation.is_some()
377382
&& interpolation != Some(crate::Interpolation::Perspective)
378383
{
379384
write!(
380385
self.out,
381386
"@interpolate({}) ",
382-
interpolation_str(
383-
interpolation.unwrap_or(crate::Interpolation::Perspective)
384-
)
387+
interpolation
388+
.unwrap_or(crate::Interpolation::Perspective)
389+
.to_wgsl(),
385390
)?;
386391
}
387392
}
@@ -1697,98 +1702,20 @@ impl<W: Write> Writer<W> {
16971702
InversePolyfill(InversePolyfill),
16981703
}
16991704

1700-
let function = match fun {
1701-
Mf::Abs => Function::Regular("abs"),
1702-
Mf::Min => Function::Regular("min"),
1703-
Mf::Max => Function::Regular("max"),
1704-
Mf::Clamp => Function::Regular("clamp"),
1705-
Mf::Saturate => Function::Regular("saturate"),
1706-
// trigonometry
1707-
Mf::Cos => Function::Regular("cos"),
1708-
Mf::Cosh => Function::Regular("cosh"),
1709-
Mf::Sin => Function::Regular("sin"),
1710-
Mf::Sinh => Function::Regular("sinh"),
1711-
Mf::Tan => Function::Regular("tan"),
1712-
Mf::Tanh => Function::Regular("tanh"),
1713-
Mf::Acos => Function::Regular("acos"),
1714-
Mf::Asin => Function::Regular("asin"),
1715-
Mf::Atan => Function::Regular("atan"),
1716-
Mf::Atan2 => Function::Regular("atan2"),
1717-
Mf::Asinh => Function::Regular("asinh"),
1718-
Mf::Acosh => Function::Regular("acosh"),
1719-
Mf::Atanh => Function::Regular("atanh"),
1720-
Mf::Radians => Function::Regular("radians"),
1721-
Mf::Degrees => Function::Regular("degrees"),
1722-
// decomposition
1723-
Mf::Ceil => Function::Regular("ceil"),
1724-
Mf::Floor => Function::Regular("floor"),
1725-
Mf::Round => Function::Regular("round"),
1726-
Mf::Fract => Function::Regular("fract"),
1727-
Mf::Trunc => Function::Regular("trunc"),
1728-
Mf::Modf => Function::Regular("modf"),
1729-
Mf::Frexp => Function::Regular("frexp"),
1730-
Mf::Ldexp => Function::Regular("ldexp"),
1731-
// exponent
1732-
Mf::Exp => Function::Regular("exp"),
1733-
Mf::Exp2 => Function::Regular("exp2"),
1734-
Mf::Log => Function::Regular("log"),
1735-
Mf::Log2 => Function::Regular("log2"),
1736-
Mf::Pow => Function::Regular("pow"),
1737-
// geometry
1738-
Mf::Dot => Function::Regular("dot"),
1739-
Mf::Cross => Function::Regular("cross"),
1740-
Mf::Distance => Function::Regular("distance"),
1741-
Mf::Length => Function::Regular("length"),
1742-
Mf::Normalize => Function::Regular("normalize"),
1743-
Mf::FaceForward => Function::Regular("faceForward"),
1744-
Mf::Reflect => Function::Regular("reflect"),
1745-
Mf::Refract => Function::Regular("refract"),
1746-
// computational
1747-
Mf::Sign => Function::Regular("sign"),
1748-
Mf::Fma => Function::Regular("fma"),
1749-
Mf::Mix => Function::Regular("mix"),
1750-
Mf::Step => Function::Regular("step"),
1751-
Mf::SmoothStep => Function::Regular("smoothstep"),
1752-
Mf::Sqrt => Function::Regular("sqrt"),
1753-
Mf::InverseSqrt => Function::Regular("inverseSqrt"),
1754-
Mf::Transpose => Function::Regular("transpose"),
1755-
Mf::Determinant => Function::Regular("determinant"),
1756-
Mf::QuantizeToF16 => Function::Regular("quantizeToF16"),
1757-
// bits
1758-
Mf::CountTrailingZeros => Function::Regular("countTrailingZeros"),
1759-
Mf::CountLeadingZeros => Function::Regular("countLeadingZeros"),
1760-
Mf::CountOneBits => Function::Regular("countOneBits"),
1761-
Mf::ReverseBits => Function::Regular("reverseBits"),
1762-
Mf::ExtractBits => Function::Regular("extractBits"),
1763-
Mf::InsertBits => Function::Regular("insertBits"),
1764-
Mf::FirstTrailingBit => Function::Regular("firstTrailingBit"),
1765-
Mf::FirstLeadingBit => Function::Regular("firstLeadingBit"),
1766-
// data packing
1767-
Mf::Pack4x8snorm => Function::Regular("pack4x8snorm"),
1768-
Mf::Pack4x8unorm => Function::Regular("pack4x8unorm"),
1769-
Mf::Pack2x16snorm => Function::Regular("pack2x16snorm"),
1770-
Mf::Pack2x16unorm => Function::Regular("pack2x16unorm"),
1771-
Mf::Pack2x16float => Function::Regular("pack2x16float"),
1772-
Mf::Pack4xI8 => Function::Regular("pack4xI8"),
1773-
Mf::Pack4xU8 => Function::Regular("pack4xU8"),
1774-
// data unpacking
1775-
Mf::Unpack4x8snorm => Function::Regular("unpack4x8snorm"),
1776-
Mf::Unpack4x8unorm => Function::Regular("unpack4x8unorm"),
1777-
Mf::Unpack2x16snorm => Function::Regular("unpack2x16snorm"),
1778-
Mf::Unpack2x16unorm => Function::Regular("unpack2x16unorm"),
1779-
Mf::Unpack2x16float => Function::Regular("unpack2x16float"),
1780-
Mf::Unpack4xI8 => Function::Regular("unpack4xI8"),
1781-
Mf::Unpack4xU8 => Function::Regular("unpack4xU8"),
1782-
Mf::Inverse => {
1783-
let typ = func_ctx.resolve_type(arg, &module.types);
1784-
1785-
let Some(overload) = InversePolyfill::find_overload(typ) else {
1786-
return Err(Error::UnsupportedMathFunction(fun));
1787-
};
1705+
let function = match fun.to_wgsl() {
1706+
Some(name) => Function::Regular(name),
1707+
None => match fun {
1708+
Mf::Inverse => {
1709+
let typ = func_ctx.resolve_type(arg, &module.types);
17881710

1789-
Function::InversePolyfill(overload)
1790-
}
1791-
Mf::Outer => return Err(Error::UnsupportedMathFunction(fun)),
1711+
let Some(overload) = InversePolyfill::find_overload(typ) else {
1712+
return Err(Error::UnsupportedMathFunction(fun));
1713+
};
1714+
1715+
Function::InversePolyfill(overload)
1716+
}
1717+
other => return Err(Error::UnsupportedMathFunction(other)),
1718+
},
17921719
};
17931720

17941721
match function {
@@ -1969,39 +1896,6 @@ impl<W: Write> Writer<W> {
19691896
}
19701897
}
19711898

1972-
fn builtin_str(built_in: crate::BuiltIn) -> Result<&'static str, Error> {
1973-
use crate::BuiltIn as Bi;
1974-
1975-
Ok(match built_in {
1976-
Bi::VertexIndex => "vertex_index",
1977-
Bi::InstanceIndex => "instance_index",
1978-
Bi::Position { .. } => "position",
1979-
Bi::FrontFacing => "front_facing",
1980-
Bi::FragDepth => "frag_depth",
1981-
Bi::LocalInvocationId => "local_invocation_id",
1982-
Bi::LocalInvocationIndex => "local_invocation_index",
1983-
Bi::GlobalInvocationId => "global_invocation_id",
1984-
Bi::WorkGroupId => "workgroup_id",
1985-
Bi::NumWorkGroups => "num_workgroups",
1986-
Bi::SampleIndex => "sample_index",
1987-
Bi::SampleMask => "sample_mask",
1988-
Bi::PrimitiveIndex => "primitive_index",
1989-
Bi::ViewIndex => "view_index",
1990-
Bi::NumSubgroups => "num_subgroups",
1991-
Bi::SubgroupId => "subgroup_id",
1992-
Bi::SubgroupSize => "subgroup_size",
1993-
Bi::SubgroupInvocationId => "subgroup_invocation_id",
1994-
Bi::BaseInstance
1995-
| Bi::BaseVertex
1996-
| Bi::ClipDistance
1997-
| Bi::CullDistance
1998-
| Bi::PointSize
1999-
| Bi::PointCoord
2000-
| Bi::WorkGroupSize
2001-
| Bi::DrawID => return Err(Error::Custom(format!("Unsupported builtin {built_in:?}"))),
2002-
})
2003-
}
2004-
20051899
const fn image_dimension_str(dim: crate::ImageDimension) -> &'static str {
20061900
use crate::ImageDimension as IDim;
20071901

@@ -2098,17 +1992,6 @@ const fn storage_format_str(format: crate::StorageFormat) -> &'static str {
20981992
}
20991993
}
21001994

2101-
/// Helper function that returns the string corresponding to the WGSL interpolation qualifier
2102-
const fn interpolation_str(interpolation: crate::Interpolation) -> &'static str {
2103-
use crate::Interpolation as I;
2104-
2105-
match interpolation {
2106-
I::Perspective => "perspective",
2107-
I::Linear => "linear",
2108-
I::Flat => "flat",
2109-
}
2110-
}
2111-
21121995
/// Return the WGSL auxiliary qualifier for the given sampling value.
21131996
const fn sampling_str(sampling: crate::Sampling) -> &'static str {
21141997
use crate::Sampling as S;

0 commit comments

Comments
 (0)