|
1 | | -use alloc::{format, string::String}; |
| 1 | +use alloc::{ |
| 2 | + format, |
| 3 | + string::{String, ToString}, |
| 4 | +}; |
2 | 5 |
|
3 | 6 | use super::{compose::validate_compose, FunctionInfo, ModuleInfo, ShaderStages, TypeFlags}; |
4 | 7 | use crate::arena::UniqueArena; |
5 | 8 | use crate::{ |
6 | 9 | arena::Handle, |
7 | 10 | proc, |
8 | | - proc::OverloadSet as _, |
| 11 | + proc::OverloadSet, |
9 | 12 | proc::{IndexableLengthError, ResolveError}, |
10 | 13 | }; |
11 | 14 |
|
@@ -237,6 +240,61 @@ impl super::Validator { |
237 | 240 | Ok(()) |
238 | 241 | } |
239 | 242 |
|
| 243 | + pub(super) fn validate_func_call_with_overloads<'a, F, O, A>( |
| 244 | + &self, |
| 245 | + module: &crate::Module, |
| 246 | + fun: F, |
| 247 | + overloads: O, |
| 248 | + actuals: A, |
| 249 | + ) -> Result<(), ExpressionError> |
| 250 | + where |
| 251 | + F: core::fmt::Display + Copy, |
| 252 | + O: OverloadSet, |
| 253 | + A: Iterator<Item = (Handle<crate::Expression>, &'a crate::TypeInner)> + ExactSizeIterator, |
| 254 | + { |
| 255 | + // Start with the set of all overloads available for `fun`. |
| 256 | + let mut overloads = overloads; |
| 257 | + log::debug!( |
| 258 | + "initial overloads for {}: {:#?}", |
| 259 | + fun, |
| 260 | + overloads.for_debug(&module.types) |
| 261 | + ); |
| 262 | + |
| 263 | + // If any argument is not a constant expression, then no |
| 264 | + // overloads that accept abstract values should be considered. |
| 265 | + // `OverloadSet::concrete_only` is supposed to help impose this |
| 266 | + // restriction. However, no `MathFunction` accepts a mix of |
| 267 | + // abstract and concrete arguments, so we don't need to worry |
| 268 | + // about that here. |
| 269 | + |
| 270 | + let actuals_len = actuals.len(); |
| 271 | + |
| 272 | + for (i, (expr, ty)) in actuals.into_iter().enumerate() { |
| 273 | + // Remove overloads that cannot accept an `i`'th |
| 274 | + // argument arguments of type `ty`. |
| 275 | + overloads = overloads.arg(i, ty, &module.types); |
| 276 | + log::debug!( |
| 277 | + "overloads after arg {i}: {:#?}", |
| 278 | + overloads.for_debug(&module.types) |
| 279 | + ); |
| 280 | + |
| 281 | + if overloads.is_empty() { |
| 282 | + log::debug!("all overloads eliminated"); |
| 283 | + return Err(ExpressionError::InvalidArgumentType( |
| 284 | + fun.to_string(), |
| 285 | + i as u32, |
| 286 | + expr, |
| 287 | + )); |
| 288 | + } |
| 289 | + } |
| 290 | + |
| 291 | + if actuals_len < overloads.min_arguments() { |
| 292 | + return Err(ExpressionError::WrongArgumentCount(fun.to_string())); |
| 293 | + } |
| 294 | + |
| 295 | + Ok(()) |
| 296 | + } |
| 297 | + |
240 | 298 | #[allow(clippy::too_many_arguments)] |
241 | 299 | pub(super) fn validate_expression( |
242 | 300 | &self, |
@@ -1027,43 +1085,15 @@ impl super::Validator { |
1027 | 1085 | _ => unreachable!(), |
1028 | 1086 | }; |
1029 | 1087 |
|
1030 | | - // Start with the set of all overloads available for `fun`. |
1031 | | - let mut overloads = fun.overloads(); |
1032 | | - log::debug!( |
1033 | | - "initial overloads for {:?}: {:#?}", |
1034 | | - fun, |
1035 | | - overloads.for_debug(&module.types) |
1036 | | - ); |
1037 | | - |
1038 | | - // If any argument is not a constant expression, then no |
1039 | | - // overloads that accept abstract values should be considered. |
1040 | | - // `OverloadSet::concrete_only` is supposed to help impose this |
1041 | | - // restriction. However, no `MathFunction` accepts a mix of |
1042 | | - // abstract and concrete arguments, so we don't need to worry |
1043 | | - // about that here. |
1044 | | - |
1045 | | - for (i, (&expr, &ty)) in actuals.iter().zip(actual_types).enumerate() { |
1046 | | - // Remove overloads that cannot accept an `i`'th |
1047 | | - // argument arguments of type `ty`. |
1048 | | - overloads = overloads.arg(i, ty, &module.types); |
1049 | | - log::debug!( |
1050 | | - "overloads after arg {i}: {:#?}", |
1051 | | - overloads.for_debug(&module.types) |
1052 | | - ); |
1053 | | - |
1054 | | - if overloads.is_empty() { |
1055 | | - log::debug!("all overloads eliminated"); |
1056 | | - return Err(ExpressionError::InvalidArgumentType( |
1057 | | - format!("{fun:?}"), |
1058 | | - i as u32, |
1059 | | - expr, |
1060 | | - )); |
1061 | | - } |
1062 | | - } |
1063 | | - |
1064 | | - if actuals.len() < overloads.min_arguments() { |
1065 | | - return Err(ExpressionError::WrongArgumentCount(format!("{fun:?}"))); |
1066 | | - } |
| 1088 | + self.validate_func_call_with_overloads( |
| 1089 | + module, |
| 1090 | + format_args!("{fun:?}"), |
| 1091 | + fun.overloads(), |
| 1092 | + actuals |
| 1093 | + .iter() |
| 1094 | + .zip(actual_types.iter()) |
| 1095 | + .map(|(&val, &ty)| (val, ty)), |
| 1096 | + )?; |
1067 | 1097 |
|
1068 | 1098 | ShaderStages::all() |
1069 | 1099 | } |
|
0 commit comments