Skip to content

Commit 90104c5

Browse files
refactor(naga): extract Validator::validate_func_call_with_overloads
1 parent ebf3bf0 commit 90104c5

1 file changed

Lines changed: 69 additions & 39 deletions

File tree

naga/src/valid/expression.rs

Lines changed: 69 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1-
use alloc::{format, string::String};
1+
use alloc::{
2+
format,
3+
string::{String, ToString},
4+
};
25

36
use super::{compose::validate_compose, FunctionInfo, ModuleInfo, ShaderStages, TypeFlags};
47
use crate::arena::UniqueArena;
58
use crate::{
69
arena::Handle,
710
proc,
8-
proc::OverloadSet as _,
11+
proc::OverloadSet,
912
proc::{IndexableLengthError, ResolveError},
1013
};
1114

@@ -237,6 +240,61 @@ impl super::Validator {
237240
Ok(())
238241
}
239242

243+
pub(super) fn validate_func_call_with_overloads<'a, F, O, A>(
244+
&self,
245+
module: &crate::Module,
246+
fun: F,
247+
overloads: O,
248+
actuals: A,
249+
) -> Result<(), ExpressionError>
250+
where
251+
F: core::fmt::Display + Copy,
252+
O: OverloadSet,
253+
A: Iterator<Item = (Handle<crate::Expression>, &'a crate::TypeInner)> + ExactSizeIterator,
254+
{
255+
// Start with the set of all overloads available for `fun`.
256+
let mut overloads = overloads;
257+
log::debug!(
258+
"initial overloads for {}: {:#?}",
259+
fun,
260+
overloads.for_debug(&module.types)
261+
);
262+
263+
// If any argument is not a constant expression, then no
264+
// overloads that accept abstract values should be considered.
265+
// `OverloadSet::concrete_only` is supposed to help impose this
266+
// restriction. However, no `MathFunction` accepts a mix of
267+
// abstract and concrete arguments, so we don't need to worry
268+
// about that here.
269+
270+
let actuals_len = actuals.len();
271+
272+
for (i, (expr, ty)) in actuals.into_iter().enumerate() {
273+
// Remove overloads that cannot accept an `i`'th
274+
// argument arguments of type `ty`.
275+
overloads = overloads.arg(i, ty, &module.types);
276+
log::debug!(
277+
"overloads after arg {i}: {:#?}",
278+
overloads.for_debug(&module.types)
279+
);
280+
281+
if overloads.is_empty() {
282+
log::debug!("all overloads eliminated");
283+
return Err(ExpressionError::InvalidArgumentType(
284+
fun.to_string(),
285+
i as u32,
286+
expr,
287+
));
288+
}
289+
}
290+
291+
if actuals_len < overloads.min_arguments() {
292+
return Err(ExpressionError::WrongArgumentCount(fun.to_string()));
293+
}
294+
295+
Ok(())
296+
}
297+
240298
#[allow(clippy::too_many_arguments)]
241299
pub(super) fn validate_expression(
242300
&self,
@@ -1027,43 +1085,15 @@ impl super::Validator {
10271085
_ => unreachable!(),
10281086
};
10291087

1030-
// Start with the set of all overloads available for `fun`.
1031-
let mut overloads = fun.overloads();
1032-
log::debug!(
1033-
"initial overloads for {:?}: {:#?}",
1034-
fun,
1035-
overloads.for_debug(&module.types)
1036-
);
1037-
1038-
// If any argument is not a constant expression, then no
1039-
// overloads that accept abstract values should be considered.
1040-
// `OverloadSet::concrete_only` is supposed to help impose this
1041-
// restriction. However, no `MathFunction` accepts a mix of
1042-
// abstract and concrete arguments, so we don't need to worry
1043-
// about that here.
1044-
1045-
for (i, (&expr, &ty)) in actuals.iter().zip(actual_types).enumerate() {
1046-
// Remove overloads that cannot accept an `i`'th
1047-
// argument arguments of type `ty`.
1048-
overloads = overloads.arg(i, ty, &module.types);
1049-
log::debug!(
1050-
"overloads after arg {i}: {:#?}",
1051-
overloads.for_debug(&module.types)
1052-
);
1053-
1054-
if overloads.is_empty() {
1055-
log::debug!("all overloads eliminated");
1056-
return Err(ExpressionError::InvalidArgumentType(
1057-
format!("{fun:?}"),
1058-
i as u32,
1059-
expr,
1060-
));
1061-
}
1062-
}
1063-
1064-
if actuals.len() < overloads.min_arguments() {
1065-
return Err(ExpressionError::WrongArgumentCount(format!("{fun:?}")));
1066-
}
1088+
self.validate_func_call_with_overloads(
1089+
module,
1090+
format_args!("{fun:?}"),
1091+
fun.overloads(),
1092+
actuals
1093+
.iter()
1094+
.zip(actual_types.iter())
1095+
.map(|(&val, &ty)| (val, ty)),
1096+
)?;
10671097

10681098
ShaderStages::all()
10691099
}

0 commit comments

Comments
 (0)