@@ -4,6 +4,7 @@ use alloc::{
44 string:: { String , ToString } ,
55 vec:: Vec ,
66} ;
7+ use arrayvec:: ArrayVec ;
78use core:: num:: NonZeroU32 ;
89
910use crate :: common:: wgsl:: { TryToWgsl , TypeContext } ;
@@ -2465,7 +2466,21 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
24652466
24662467 ir:: Expression :: Derivative { axis, ctrl, expr }
24672468 } else if let Some ( fun) = conv:: map_standard_fun ( function. name ) {
2468- self . math_function_helper ( span, fun, arguments, ctx) ?
2469+ let lowered_arguments = self . function_helper (
2470+ span,
2471+ fun,
2472+ |_args, _ctx| fun. overloads ( ) ,
2473+ ArrayVec :: < _ , 4 > :: from_iter ( arguments. iter ( ) . copied ( ) ) ,
2474+ ctx,
2475+ ) ?;
2476+
2477+ ir:: Expression :: Math {
2478+ fun,
2479+ arg : lowered_arguments[ 0 ] ,
2480+ arg1 : lowered_arguments. get ( 1 ) . cloned ( ) ,
2481+ arg2 : lowered_arguments. get ( 2 ) . cloned ( ) ,
2482+ arg3 : lowered_arguments. get ( 3 ) . cloned ( ) ,
2483+ }
24692484 } else if let Some ( fun) = Texture :: map ( function. name ) {
24702485 self . texture_sample_helper ( fun, arguments, span, ctx) ?
24712486 } else if let Some ( ( op, cop) ) = conv:: map_subgroup_operation ( function. name ) {
@@ -2944,31 +2959,33 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
29442959 }
29452960 }
29462961
2947- /// Generate a Naga IR [`Math`] expression.
2948- ///
2949- /// Generate Naga IR for a call to the [`MathFunction`] `fun`, whose
2962+ /// Generate a Naga IR expression for a call to `fun`, whose
29502963 /// unlowered arguments are `ast_arguments`.
29512964 ///
29522965 /// The `span` argument should give the span of the function name in the
29532966 /// call expression.
2954- ///
2955- /// [`Math`]: ir::Expression::Math
2956- /// [`MathFunction`]: ir::MathFunction
2957- fn math_function_helper (
2967+ fn function_helper < const NUM_ARGS : usize , F , O , R > (
29582968 & mut self ,
29592969 span : Span ,
2960- fun : ir:: MathFunction ,
2961- ast_arguments : & [ Handle < ast:: Expression < ' source > > ] ,
2970+ fun : F ,
2971+ resolve_overloads : R ,
2972+ ast_arguments : ArrayVec < Handle < ast:: Expression < ' source > > , { NUM_ARGS } > ,
29622973 ctx : & mut ExpressionContext < ' source , ' _ , ' _ > ,
2963- ) -> Result < ' source , ir:: Expression > {
2964- let mut lowered_arguments = Vec :: with_capacity ( ast_arguments. len ( ) ) ;
2965- for & arg in ast_arguments {
2974+ ) -> Result < ' source , ArrayVec < Handle < ir:: Expression > , { NUM_ARGS } > >
2975+ where
2976+ F : TryToWgsl + core:: fmt:: Debug + Copy ,
2977+ O : proc:: OverloadSet ,
2978+ R : FnOnce ( & [ Handle < ir:: Expression > ] , & mut ExpressionContext < ' source , ' _ , ' _ > ) -> O ,
2979+ {
2980+ let mut lowered_arguments = ArrayVec :: < _ , { NUM_ARGS } > :: new ( ) ;
2981+
2982+ for & arg in ast_arguments. iter ( ) {
29662983 let lowered = self . expression_for_abstract ( arg, ctx) ?;
29672984 ctx. grow_types ( lowered) ?;
29682985 lowered_arguments. push ( lowered) ;
29692986 }
29702987
2971- let fun_overloads = fun . overloads ( ) ;
2988+ let fun_overloads = resolve_overloads ( & lowered_arguments , ctx ) ;
29722989 let rule = self . resolve_overloads ( span, fun, fun_overloads, & lowered_arguments, ctx) ?;
29732990 self . apply_automatic_conversions_for_call ( & rule, & mut lowered_arguments, ctx) ?;
29742991
@@ -2979,13 +2996,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
29792996 ctx. module . generate_predeclared_type ( predeclared) ;
29802997 }
29812998
2982- Ok ( ir:: Expression :: Math {
2983- fun,
2984- arg : lowered_arguments[ 0 ] ,
2985- arg1 : lowered_arguments. get ( 1 ) . cloned ( ) ,
2986- arg2 : lowered_arguments. get ( 2 ) . cloned ( ) ,
2987- arg3 : lowered_arguments. get ( 3 ) . cloned ( ) ,
2988- } )
2999+ Ok ( lowered_arguments)
29893000 }
29903001
29913002 /// Choose the right overload for a function call.
0 commit comments