Skip to content

Commit d0c3e40

Browse files
committed
[naga] Support MathFunction overloads correctly.
Define a new trait, `proc::builtins::OverloadSet`, for types that represent a Naga IR builtin function's set of overloads. The `OverloadSet` trait includes operations needed to validate calls, choose automatic type conversions, and generate diagnostics. Add a new function, `ir::MathFunction::overloads`, which returns the given `MathFunction`'s set of overloads as an `impl OverloadSet` value. Use this in the WGSL front end, the validator, and the typifier. To support `MathFunction::overloads`, provide two implementations of `OverloadSet`: - `List` is flexible but verbose. - `Regular` is concise but more restrictive. Some snapshot output is affected because `TypeResolution::Handle` values turn into `TypeResolution::Value`, since the function database constructs the return type directly. To work around gfx-rs#7405, avoid offering abstract-typed overloads of some functions. This addresses gfx-rs#6443 for `MathFunction`, although that issue covers other categories of operations as well.
1 parent 81f7fdf commit d0c3e40

20 files changed

+2575
-895
lines changed

naga/src/common/diagnostic_display.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//! Displaying Naga IR terms in diagnostic output.
22
3-
use crate::proc::GlobalCtx;
3+
use crate::proc::{GlobalCtx, Rule};
44
use crate::{Handle, Scalar, Type, TypeInner};
55

66
#[cfg(any(feature = "wgsl-in", feature = "wgsl-out"))]
@@ -83,6 +83,23 @@ impl fmt::Display for DiagnosticDisplay<(&TypeInner, GlobalCtx<'_>)> {
8383
}
8484
}
8585

86+
impl fmt::Display for DiagnosticDisplay<(&str, &Rule, GlobalCtx<'_>)> {
87+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88+
let (name, rule, ref ctx) = self.0;
89+
90+
#[cfg(any(feature = "wgsl-in", feature = "wgsl-out"))]
91+
ctx.write_type_rule(name, rule, f)?;
92+
93+
#[cfg(not(any(feature = "wgsl-in", feature = "wgsl-out")))]
94+
{
95+
let _ = ctx;
96+
write!(f, "{name}({:?}) -> {:?}", rule.arguments, rule.conclusion)?;
97+
}
98+
99+
Ok(())
100+
}
101+
}
102+
86103
impl fmt::Display for DiagnosticDisplay<Scalar> {
87104
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88105
let scalar = self.0;

naga/src/common/wgsl/types.rs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,37 @@ pub trait TypeContext {
133133
}
134134
}
135135

136+
fn write_type_conclusion<W: Write>(
137+
&self,
138+
conclusion: &crate::proc::Conclusion,
139+
out: &mut W,
140+
) -> core::fmt::Result {
141+
use crate::proc::Conclusion as Co;
142+
143+
match *conclusion {
144+
Co::Value(ref inner) => self.write_type_inner(inner, out),
145+
Co::Predeclared(ref predeclared) => out.write_str(&predeclared.struct_name()),
146+
}
147+
}
148+
149+
fn write_type_rule<W: Write>(
150+
&self,
151+
name: &str,
152+
rule: &crate::proc::Rule,
153+
out: &mut W,
154+
) -> core::fmt::Result {
155+
write!(out, "fn {name}(")?;
156+
for (i, arg) in rule.arguments.iter().enumerate() {
157+
if i > 0 {
158+
out.write_str(", ")?;
159+
}
160+
self.write_type_resolution(arg, out)?
161+
}
162+
out.write_str(") -> ")?;
163+
self.write_type_conclusion(&rule.conclusion, out)?;
164+
Ok(())
165+
}
166+
136167
fn type_to_string(&self, handle: Handle<crate::Type>) -> String {
137168
let mut buf = String::new();
138169
self.write_type(handle, &mut buf).unwrap();
@@ -150,6 +181,12 @@ pub trait TypeContext {
150181
self.write_type_resolution(resolution, &mut buf).unwrap();
151182
buf
152183
}
184+
185+
fn type_rule_to_string(&self, name: &str, rule: &crate::proc::Rule) -> String {
186+
let mut buf = String::new();
187+
self.write_type_rule(name, rule, &mut buf).unwrap();
188+
buf
189+
}
153190
}
154191

155192
fn try_write_type_inner<C, W>(ctx: &C, inner: &TypeInner, out: &mut W) -> Result<(), WriteTypeError>

naga/src/front/wgsl/error.rs

Lines changed: 207 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,110 @@ pub(crate) enum Error<'a> {
270270
expected: Range<u32>,
271271
found: u32,
272272
},
273+
274+
/// No overload of this function accepts this many arguments.
275+
TooManyArguments {
276+
/// The name of the function being called.
277+
function: String,
278+
279+
/// The function name in the call expression.
280+
call_span: Span,
281+
282+
/// The first argument that is unacceptable.
283+
arg_span: Span,
284+
285+
/// Maximum number of arguments accepted by any overload of
286+
/// this function.
287+
max_arguments: u32,
288+
},
289+
290+
/// Given the types of the prior arguments, no remaining overload
291+
/// accepts this many arguments.
292+
WrongArgumentCountForOverloads {
293+
/// The name of the function being called.
294+
function: String,
295+
296+
/// The function name in the call expression.
297+
call_span: Span,
298+
299+
/// The first argument that is unacceptable.
300+
arg_span: Span,
301+
302+
/// The prior argument whose type made the `arg_span` argument
303+
/// unacceptable.
304+
prior_span: Span,
305+
306+
/// The index of the `prior_arg_span` argument.
307+
prior_index: u32,
308+
309+
/// The type of the `prior_arg_span` argument.
310+
prior_ty: String,
311+
312+
/// Maximum number of arguments accepted by any overload of
313+
/// this function.
314+
max_arguments: u32,
315+
},
316+
317+
/// A value passed to a builtin function has a type that is not
318+
/// accepted by any overload of the function.
319+
WrongArgumentType {
320+
/// The name of the function being called.
321+
function: String,
322+
323+
/// The function name in the call expression.
324+
call_span: Span,
325+
326+
/// The first argument whose type is unacceptable.
327+
arg_span: Span,
328+
329+
/// The index of the first argument whose type is unacceptable.
330+
arg_index: u32,
331+
332+
/// That argument's actual type.
333+
arg_ty: String,
334+
335+
/// The set of argument types that would have been accepted for
336+
/// this argument, given the prior arguments.
337+
allowed: Vec<String>,
338+
},
339+
340+
/// A value passed to a builtin function has a type that is not
341+
/// accepted, given the earlier arguments' types.
342+
InconsistentArgumentType {
343+
/// The name of the function being called.
344+
function: String,
345+
346+
/// The function name in the call expression.
347+
call_span: Span,
348+
349+
/// The first unacceptable argument.
350+
arg_span: Span,
351+
352+
/// The index of the first unacceptable argument.
353+
arg_index: u32,
354+
355+
/// The actual type of the first unacceptable argument.
356+
arg_ty: String,
357+
358+
/// The prior argument whose type made the `arg_span` argument
359+
/// unacceptable.
360+
prior_span: Span,
361+
362+
/// The index of the `prior_arg_span` argument.
363+
prior_index: u32,
364+
365+
/// The type of the `prior_arg_span` argument.
366+
prior_ty: String,
367+
368+
/// The types that would have been accepted instead of the
369+
/// first unacceptable argument.
370+
allowed: Vec<String>,
371+
},
372+
373+
AmbiguousCall {
374+
call_span: Span,
375+
alternatives: Vec<String>,
376+
},
273377
FunctionReturnsVoid(Span),
274378
FunctionMustUseUnused(Span),
275379
FunctionMustUseReturnsVoid(Span, Span),
@@ -401,7 +505,8 @@ impl<'a> Error<'a> {
401505
"workgroup size separator (`,`) or a closing parenthesis".to_string()
402506
}
403507
ExpectedToken::GlobalItem => concat!(
404-
"global item (`struct`, `const`, `var`, `alias`, `fn`, `diagnostic`, `enable`, `requires`, `;`) ",
508+
"global item (`struct`, `const`, `var`, `alias`, ",
509+
"`fn`, `diagnostic`, `enable`, `requires`, `;`) ",
405510
"or the end of the file"
406511
)
407512
.to_string(),
@@ -825,6 +930,107 @@ impl<'a> Error<'a> {
825930
labels: vec![(span, "wrong number of arguments".into())],
826931
notes: vec![],
827932
},
933+
Error::TooManyArguments {
934+
ref function,
935+
call_span,
936+
arg_span,
937+
max_arguments,
938+
} => ParseError {
939+
message: format!("Too many arguments passed to `{function}`"),
940+
labels: vec![
941+
(call_span, "The function is called here".into()),
942+
(arg_span, "This is the first excess argument".into())
943+
],
944+
notes: vec![
945+
format!("The `{function}` function accepts at most {max_arguments} argument(s)")
946+
],
947+
},
948+
Error::WrongArgumentCountForOverloads {
949+
ref function,
950+
call_span,
951+
arg_span,
952+
prior_span,
953+
prior_index,
954+
ref prior_ty,
955+
max_arguments,
956+
} => {
957+
let message = format!("Too many arguments in this call to to `{function}`");
958+
let labels = vec![
959+
(call_span, "The function is called here".into()),
960+
(arg_span, "This is the first excess argument".into()),
961+
(prior_span, format!("This argument has type `{prior_ty}`.").into()),
962+
];
963+
let notes = vec![
964+
format!("Because argument #{} has type `{prior_ty}`,", prior_index + 1),
965+
format!("`{function}` accepts only {max_arguments} argument(s)."),
966+
];
967+
968+
ParseError { message, labels, notes }
969+
}
970+
Error::WrongArgumentType {
971+
ref function,
972+
call_span,
973+
arg_span,
974+
arg_index,
975+
ref arg_ty,
976+
ref allowed,
977+
} => {
978+
let message = format!(
979+
"Wrong type passed as argument #{} to `{function}`",
980+
arg_index + 1,
981+
);
982+
let labels = vec![
983+
(call_span, "The function is called here".into()),
984+
(arg_span, format!("This argument has type `{arg_ty}`").into())
985+
];
986+
987+
let mut notes = vec![];
988+
notes.push(format!("`{function}` accepts the following types for argument #{}:", arg_index + 1));
989+
notes.extend(allowed.iter().map(|ty| format!("allowed type: {ty}")));
990+
991+
ParseError { message, labels, notes }
992+
},
993+
Error::InconsistentArgumentType {
994+
ref function,
995+
call_span,
996+
arg_span,
997+
arg_index,
998+
ref arg_ty,
999+
prior_span,
1000+
prior_index,
1001+
ref prior_ty,
1002+
ref allowed
1003+
} => {
1004+
let message = format!(
1005+
"Inconsistent type passed as argument #{} to `{function}`",
1006+
arg_index + 1,
1007+
);
1008+
let labels = vec![
1009+
(call_span, "The function is called here".into()),
1010+
(arg_span, format!("This argument has type {arg_ty}").into()),
1011+
(prior_span, format!(
1012+
"This argument has type {prior_ty}, which constrains subsequent arguments"
1013+
).into()),
1014+
];
1015+
let mut notes = vec![
1016+
format!("Because argument #{} has type {prior_ty}, only the following types", prior_index + 1),
1017+
format!("(or types that automatically convert to them) are accepted for argument #{}:", arg_index + 1),
1018+
];
1019+
notes.extend(allowed.iter().map(|ty| format!("allowed type: {ty}")));
1020+
1021+
ParseError { message, labels, notes }
1022+
}
1023+
Error::AmbiguousCall { call_span, ref alternatives } => {
1024+
let message = "Function call is ambiguous: more than one overload could apply".into();
1025+
let labels = vec![
1026+
(call_span, "More than one overload of this function could apply to these arguments".into()),
1027+
];
1028+
let mut notes = vec![
1029+
"All of the following overloads could apply, but no one overload is clearly preferable:".into()
1030+
];
1031+
notes.extend(alternatives.iter().map(|alt| format!("possible overload: {alt}")));
1032+
ParseError { message, labels, notes }
1033+
},
8281034
Error::FunctionReturnsVoid(span) => ParseError {
8291035
message: "function does not return any value".to_string(),
8301036
labels: vec![(span, "".into())],

0 commit comments

Comments
 (0)