Skip to content

Commit ffb6fc9

Browse files
refactor(naga): make math_function_helper a more generic function_helper
1 parent 804ce8d commit ffb6fc9

1 file changed

Lines changed: 32 additions & 21 deletions

File tree

  • naga/src/front/wgsl/lower

naga/src/front/wgsl/lower/mod.rs

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use alloc::{
44
string::{String, ToString},
55
vec::Vec,
66
};
7+
use arrayvec::ArrayVec;
78
use core::num::NonZeroU32;
89

910
use crate::common::wgsl::{TryToWgsl, TypeContext};
@@ -2465,7 +2466,21 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
24652466

24662467
ir::Expression::Derivative { axis, ctrl, expr }
24672468
} else if let Some(fun) = conv::map_standard_fun(function.name) {
2468-
self.math_function_helper(span, fun, arguments, ctx)?
2469+
let lowered_arguments = self.function_helper(
2470+
span,
2471+
fun,
2472+
|_args, _ctx| fun.overloads(),
2473+
ArrayVec::<_, 4>::from_iter(arguments.iter().copied()),
2474+
ctx,
2475+
)?;
2476+
2477+
ir::Expression::Math {
2478+
fun,
2479+
arg: lowered_arguments[0],
2480+
arg1: lowered_arguments.get(1).cloned(),
2481+
arg2: lowered_arguments.get(2).cloned(),
2482+
arg3: lowered_arguments.get(3).cloned(),
2483+
}
24692484
} else if let Some(fun) = Texture::map(function.name) {
24702485
self.texture_sample_helper(fun, arguments, span, ctx)?
24712486
} else if let Some((op, cop)) = conv::map_subgroup_operation(function.name) {
@@ -2944,31 +2959,33 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
29442959
}
29452960
}
29462961

2947-
/// Generate a Naga IR [`Math`] expression.
2948-
///
2949-
/// Generate Naga IR for a call to the [`MathFunction`] `fun`, whose
2962+
/// Generate a Naga IR expression for a call to `fun`, whose
29502963
/// unlowered arguments are `ast_arguments`.
29512964
///
29522965
/// The `span` argument should give the span of the function name in the
29532966
/// call expression.
2954-
///
2955-
/// [`Math`]: ir::Expression::Math
2956-
/// [`MathFunction`]: ir::MathFunction
2957-
fn math_function_helper(
2967+
fn function_helper<const NUM_ARGS: usize, F, O, R>(
29582968
&mut self,
29592969
span: Span,
2960-
fun: ir::MathFunction,
2961-
ast_arguments: &[Handle<ast::Expression<'source>>],
2970+
fun: F,
2971+
resolve_overloads: R,
2972+
ast_arguments: ArrayVec<Handle<ast::Expression<'source>>, { NUM_ARGS }>,
29622973
ctx: &mut ExpressionContext<'source, '_, '_>,
2963-
) -> Result<'source, ir::Expression> {
2964-
let mut lowered_arguments = Vec::with_capacity(ast_arguments.len());
2965-
for &arg in ast_arguments {
2974+
) -> Result<'source, ArrayVec<Handle<ir::Expression>, { NUM_ARGS }>>
2975+
where
2976+
F: TryToWgsl + core::fmt::Debug + Copy,
2977+
O: proc::OverloadSet,
2978+
R: FnOnce(&[Handle<ir::Expression>], &mut ExpressionContext<'source, '_, '_>) -> O,
2979+
{
2980+
let mut lowered_arguments = ArrayVec::<_, { NUM_ARGS }>::new();
2981+
2982+
for &arg in ast_arguments.iter() {
29662983
let lowered = self.expression_for_abstract(arg, ctx)?;
29672984
ctx.grow_types(lowered)?;
29682985
lowered_arguments.push(lowered);
29692986
}
29702987

2971-
let fun_overloads = fun.overloads();
2988+
let fun_overloads = resolve_overloads(&lowered_arguments, ctx);
29722989
let rule = self.resolve_overloads(span, fun, fun_overloads, &lowered_arguments, ctx)?;
29732990
self.apply_automatic_conversions_for_call(&rule, &mut lowered_arguments, ctx)?;
29742991

@@ -2979,13 +2996,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
29792996
ctx.module.generate_predeclared_type(predeclared);
29802997
}
29812998

2982-
Ok(ir::Expression::Math {
2983-
fun,
2984-
arg: lowered_arguments[0],
2985-
arg1: lowered_arguments.get(1).cloned(),
2986-
arg2: lowered_arguments.get(2).cloned(),
2987-
arg3: lowered_arguments.get(3).cloned(),
2988-
})
2999+
Ok(lowered_arguments)
29893000
}
29903001

29913002
/// Choose the right overload for a function call.

0 commit comments

Comments
 (0)