Skip to content

Commit 9a0ed3f

Browse files
refactor(naga): make math_function_helper a more generic function_helper
1 parent 1e382d2 commit 9a0ed3f

1 file changed

Lines changed: 32 additions & 21 deletions

File tree

  • naga/src/front/wgsl/lower

naga/src/front/wgsl/lower/mod.rs

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use alloc::{
44
string::{String, ToString},
55
vec::Vec,
66
};
7+
use arrayvec::ArrayVec;
78
use core::num::NonZeroU32;
89

910
use crate::common::wgsl::{TryToWgsl, TypeContext};
@@ -2465,7 +2466,21 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
24652466

24662467
ir::Expression::Derivative { axis, ctrl, expr }
24672468
} else if let Some(fun) = conv::map_standard_fun(function.name) {
2468-
self.math_function_helper(span, fun, arguments, ctx)?
2469+
let lowered_arguments = self.function_helper(
2470+
span,
2471+
fun,
2472+
|_args, _ctx| fun.overloads(),
2473+
ArrayVec::<_, 4>::from_iter(arguments.iter().copied()),
2474+
ctx,
2475+
)?;
2476+
2477+
ir::Expression::Math {
2478+
fun,
2479+
arg: lowered_arguments[0],
2480+
arg1: lowered_arguments.get(1).cloned(),
2481+
arg2: lowered_arguments.get(2).cloned(),
2482+
arg3: lowered_arguments.get(3).cloned(),
2483+
}
24692484
} else if let Some(fun) = Texture::map(function.name) {
24702485
self.texture_sample_helper(fun, arguments, span, ctx)?
24712486
} else if let Some((op, cop)) = conv::map_subgroup_operation(function.name) {
@@ -2952,31 +2967,33 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
29522967
}
29532968
}
29542969

2955-
/// Generate a Naga IR [`Math`] expression.
2956-
///
2957-
/// Generate Naga IR for a call to the [`MathFunction`] `fun`, whose
2970+
/// Generate a Naga IR expression for a call to `fun`, whose
29582971
/// unlowered arguments are `ast_arguments`.
29592972
///
29602973
/// The `span` argument should give the span of the function name in the
29612974
/// call expression.
2962-
///
2963-
/// [`Math`]: ir::Expression::Math
2964-
/// [`MathFunction`]: ir::MathFunction
2965-
fn math_function_helper(
2975+
fn function_helper<const NUM_ARGS: usize, F, O, R>(
29662976
&mut self,
29672977
span: Span,
2968-
fun: ir::MathFunction,
2969-
ast_arguments: &[Handle<ast::Expression<'source>>],
2978+
fun: F,
2979+
resolve_overloads: R,
2980+
ast_arguments: ArrayVec<Handle<ast::Expression<'source>>, { NUM_ARGS }>,
29702981
ctx: &mut ExpressionContext<'source, '_, '_>,
2971-
) -> Result<'source, ir::Expression> {
2972-
let mut lowered_arguments = Vec::with_capacity(ast_arguments.len());
2973-
for &arg in ast_arguments {
2982+
) -> Result<'source, ArrayVec<Handle<ir::Expression>, { NUM_ARGS }>>
2983+
where
2984+
F: TryToWgsl + core::fmt::Debug + Copy,
2985+
O: proc::OverloadSet,
2986+
R: FnOnce(&[Handle<ir::Expression>], &mut ExpressionContext<'source, '_, '_>) -> O,
2987+
{
2988+
let mut lowered_arguments = ArrayVec::<_, { NUM_ARGS }>::new();
2989+
2990+
for &arg in ast_arguments.iter() {
29742991
let lowered = self.expression_for_abstract(arg, ctx)?;
29752992
ctx.grow_types(lowered)?;
29762993
lowered_arguments.push(lowered);
29772994
}
29782995

2979-
let fun_overloads = fun.overloads();
2996+
let fun_overloads = resolve_overloads(&lowered_arguments, ctx);
29802997
let rule = self.resolve_overloads(span, fun, fun_overloads, &lowered_arguments, ctx)?;
29812998
self.apply_automatic_conversions_for_call(&rule, &mut lowered_arguments, ctx)?;
29822999

@@ -2987,13 +3004,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
29873004
ctx.module.generate_predeclared_type(predeclared);
29883005
}
29893006

2990-
Ok(ir::Expression::Math {
2991-
fun,
2992-
arg: lowered_arguments[0],
2993-
arg1: lowered_arguments.get(1).cloned(),
2994-
arg2: lowered_arguments.get(2).cloned(),
2995-
arg3: lowered_arguments.get(3).cloned(),
2996-
})
3007+
Ok(lowered_arguments)
29973008
}
29983009

29993010
/// Choose the right overload for a function call.

0 commit comments

Comments
 (0)