@@ -4,6 +4,7 @@ use alloc::{
44 string:: { String , ToString } ,
55 vec:: Vec ,
66} ;
7+ use arrayvec:: ArrayVec ;
78use core:: num:: NonZeroU32 ;
89
910use crate :: common:: wgsl:: { TryToWgsl , TypeContext } ;
@@ -2465,7 +2466,21 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
24652466
24662467 ir:: Expression :: Derivative { axis, ctrl, expr }
24672468 } else if let Some ( fun) = conv:: map_standard_fun ( function. name ) {
2468- self . math_function_helper ( span, fun, arguments, ctx) ?
2469+ let lowered_arguments = self . function_helper (
2470+ span,
2471+ fun,
2472+ |_args, _ctx| fun. overloads ( ) ,
2473+ ArrayVec :: < _ , 4 > :: from_iter ( arguments. iter ( ) . copied ( ) ) ,
2474+ ctx,
2475+ ) ?;
2476+
2477+ ir:: Expression :: Math {
2478+ fun,
2479+ arg : lowered_arguments[ 0 ] ,
2480+ arg1 : lowered_arguments. get ( 1 ) . cloned ( ) ,
2481+ arg2 : lowered_arguments. get ( 2 ) . cloned ( ) ,
2482+ arg3 : lowered_arguments. get ( 3 ) . cloned ( ) ,
2483+ }
24692484 } else if let Some ( fun) = Texture :: map ( function. name ) {
24702485 self . texture_sample_helper ( fun, arguments, span, ctx) ?
24712486 } else if let Some ( ( op, cop) ) = conv:: map_subgroup_operation ( function. name ) {
@@ -2952,31 +2967,33 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
29522967 }
29532968 }
29542969
2955- /// Generate a Naga IR [`Math`] expression.
2956- ///
2957- /// Generate Naga IR for a call to the [`MathFunction`] `fun`, whose
2970+ /// Generate a Naga IR expression for a call to `fun`, whose
29582971 /// unlowered arguments are `ast_arguments`.
29592972 ///
29602973 /// The `span` argument should give the span of the function name in the
29612974 /// call expression.
2962- ///
2963- /// [`Math`]: ir::Expression::Math
2964- /// [`MathFunction`]: ir::MathFunction
2965- fn math_function_helper (
2975+ fn function_helper < const NUM_ARGS : usize , F , O , R > (
29662976 & mut self ,
29672977 span : Span ,
2968- fun : ir:: MathFunction ,
2969- ast_arguments : & [ Handle < ast:: Expression < ' source > > ] ,
2978+ fun : F ,
2979+ resolve_overloads : R ,
2980+ ast_arguments : ArrayVec < Handle < ast:: Expression < ' source > > , { NUM_ARGS } > ,
29702981 ctx : & mut ExpressionContext < ' source , ' _ , ' _ > ,
2971- ) -> Result < ' source , ir:: Expression > {
2972- let mut lowered_arguments = Vec :: with_capacity ( ast_arguments. len ( ) ) ;
2973- for & arg in ast_arguments {
2982+ ) -> Result < ' source , ArrayVec < Handle < ir:: Expression > , { NUM_ARGS } > >
2983+ where
2984+ F : TryToWgsl + core:: fmt:: Debug + Copy ,
2985+ O : proc:: OverloadSet ,
2986+ R : FnOnce ( & [ Handle < ir:: Expression > ] , & mut ExpressionContext < ' source , ' _ , ' _ > ) -> O ,
2987+ {
2988+ let mut lowered_arguments = ArrayVec :: < _ , { NUM_ARGS } > :: new ( ) ;
2989+
2990+ for & arg in ast_arguments. iter ( ) {
29742991 let lowered = self . expression_for_abstract ( arg, ctx) ?;
29752992 ctx. grow_types ( lowered) ?;
29762993 lowered_arguments. push ( lowered) ;
29772994 }
29782995
2979- let fun_overloads = fun . overloads ( ) ;
2996+ let fun_overloads = resolve_overloads ( & lowered_arguments , ctx ) ;
29802997 let rule = self . resolve_overloads ( span, fun, fun_overloads, & lowered_arguments, ctx) ?;
29812998 self . apply_automatic_conversions_for_call ( & rule, & mut lowered_arguments, ctx) ?;
29822999
@@ -2987,13 +3004,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
29873004 ctx. module . generate_predeclared_type ( predeclared) ;
29883005 }
29893006
2990- Ok ( ir:: Expression :: Math {
2991- fun,
2992- arg : lowered_arguments[ 0 ] ,
2993- arg1 : lowered_arguments. get ( 1 ) . cloned ( ) ,
2994- arg2 : lowered_arguments. get ( 2 ) . cloned ( ) ,
2995- arg3 : lowered_arguments. get ( 3 ) . cloned ( ) ,
2996- } )
3007+ Ok ( lowered_arguments)
29973008 }
29983009
29993010 /// Choose the right overload for a function call.
0 commit comments