@@ -1460,8 +1460,19 @@ impl SsaBuilder {
14601460
14611461 // Check if this is a non-primitive type that might use operator overloading
14621462 // For opaque/named types, try to dispatch to trait method
1463- eprintln ! ( "[DEBUG SSA] Binary op {:?}, left.ty={:?}, right.ty={:?}" , op, left. ty, right. ty) ;
1464- if let Some ( trait_call) = self . try_operator_trait_dispatch ( block_id, op, left, right, & expr. ty ) ? {
1463+ // First, resolve the actual type from the variable if the expression is a Variable
1464+ let left_actual_ty = self . resolve_actual_type ( & left. node , & left. ty ) ;
1465+ let right_actual_ty = self . resolve_actual_type ( & right. node , & right. ty ) ;
1466+ eprintln ! ( "[DEBUG SSA] Binary op {:?}, left.ty={:?} (actual: {:?}), right.ty={:?} (actual: {:?})" ,
1467+ op, left. ty, left_actual_ty, right. ty, right_actual_ty) ;
1468+
1469+ // Create a modified left/right with resolved types for trait dispatch
1470+ let mut left_with_type = left. clone ( ) ;
1471+ let mut right_with_type = right. clone ( ) ;
1472+ left_with_type. ty = left_actual_ty;
1473+ right_with_type. ty = right_actual_ty;
1474+
1475+ if let Some ( trait_call) = self . try_operator_trait_dispatch ( block_id, op, & left_with_type, & right_with_type, & expr. ty ) ? {
14651476 eprintln ! ( "[DEBUG SSA] Using trait dispatch for binary op" ) ;
14661477 return Ok ( trait_call) ;
14671478 }
@@ -1541,8 +1552,17 @@ impl SsaBuilder {
15411552 ( crate :: hir:: HirCallable :: Function ( func_id) , None )
15421553 } else if let Some ( link_name) = self . extern_link_names . get ( func_name) {
15431554 // External function with link_name (e.g., tensor_add -> $Tensor$add)
1555+ log:: debug!( "[SSA] Resolved extern call '{}' -> '{}'" , name_str, link_name) ;
15441556 ( crate :: hir:: HirCallable :: Symbol ( link_name. clone ( ) ) , None )
15451557 } else {
1558+ // Debug: Log what we're looking for and what's available
1559+ if !self . extern_link_names . is_empty ( ) {
1560+ let available: Vec < _ > = self . extern_link_names . keys ( )
1561+ . filter_map ( |k| k. resolve_global ( ) )
1562+ . collect ( ) ;
1563+ log:: debug!( "[SSA] Function '{}' not in extern_link_names ({} entries). Sample: {:?}" ,
1564+ name_str, self . extern_link_names. len( ) , & available[ ..available. len( ) . min( 10 ) ] ) ;
1565+ }
15461566 // Variable lookup (function pointer)
15471567 let callee_val = self . translate_expression ( block_id, callee) ?;
15481568 ( crate :: hir:: HirCallable :: Indirect ( callee_val) , Some ( callee_val) )
@@ -2056,44 +2076,57 @@ impl SsaBuilder {
20562076 ) ) ;
20572077 } ;
20582078
2059- // Determine result type: if the method call has type Any, look up the trait method's return type
2060- let result_type = if matches ! ( expr. ty, Type :: Any ) {
2061- // Look up the trait method's return type from the type registry
2062- let receiver_type_id = match & receiver_type {
2063- Type :: Named { id, .. } => * id,
2064- _ => return Err ( crate :: CompilerError :: Analysis (
2065- format ! ( "Method receiver is not a named type: {:?}" , receiver_type)
2066- ) ) ,
2067- } ;
2068-
2069- // Find the trait implementation for this type
2070- let mut method_return_type = None ;
2071- for ( _trait_id, impls) in self . type_registry . iter_implementations ( ) {
2072- for impl_def in impls {
2073- if let Type :: Named { id : impl_type_id, .. } = & impl_def. for_type {
2074- if * impl_type_id == receiver_type_id {
2075- // Find the method in this impl
2076- for method in & impl_def. methods {
2077- if method. signature . name == method_call. method {
2078- method_return_type = Some ( method. signature . return_type . clone ( ) ) ;
2079- break ;
2079+ // Determine result type based on the mangled function name
2080+ // For extern methods like $Tensor$sum_f32, parse the return type from the suffix
2081+ let result_type = if matches ! ( expr. ty, Type :: Any | Type :: Unknown ) {
2082+ let mangled_str = mangled_name. resolve_global ( ) . unwrap_or_default ( ) ;
2083+
2084+ // Check common return type suffixes for Tensor methods
2085+ // Methods like sum_f32, mean_f32 return f32
2086+ // Methods like zeros, ones, arange return Tensor (opaque ptr)
2087+ let hir_type = if mangled_str. ends_with ( "_f32" ) || mangled_str. contains ( "$sum" ) || mangled_str. contains ( "$mean" )
2088+ || mangled_str. contains ( "$max" ) || mangled_str. contains ( "$min" ) || mangled_str. contains ( "$std" ) || mangled_str. contains ( "$var" ) {
2089+ // Reduction methods that return f32
2090+ log:: debug!( "[METHOD_CALL] Inferred F32 return type for '{}'" , mangled_str) ;
2091+ HirType :: F32
2092+ } else if mangled_str. contains ( "$ndim" ) || mangled_str. contains ( "$numel" ) {
2093+ // Methods that return i64
2094+ log:: debug!( "[METHOD_CALL] Inferred I64 return type for '{}'" , mangled_str) ;
2095+ HirType :: I64
2096+ } else if let Type :: Named { id, .. } = & receiver_type {
2097+ // Fall back to looking up in trait implementations for named types
2098+ let receiver_type_id = * id;
2099+ let mut method_return_type = None ;
2100+ for ( _trait_id, impls) in self . type_registry . iter_implementations ( ) {
2101+ for impl_def in impls {
2102+ if let Type :: Named { id : impl_type_id, .. } = & impl_def. for_type {
2103+ if * impl_type_id == receiver_type_id {
2104+ for method in & impl_def. methods {
2105+ if method. signature . name == method_call. method {
2106+ method_return_type = Some ( method. signature . return_type . clone ( ) ) ;
2107+ break ;
2108+ }
20802109 }
20812110 }
20822111 }
2112+ if method_return_type. is_some ( ) {
2113+ break ;
2114+ }
20832115 }
20842116 if method_return_type. is_some ( ) {
20852117 break ;
20862118 }
20872119 }
2088- if method_return_type. is_some ( ) {
2089- break ;
2090- }
2091- }
2092-
2093- let typed_return_type = method_return_type. unwrap_or ( Type :: Primitive ( zyntax_typed_ast:: PrimitiveType :: I32 ) ) ;
2094- self . convert_type ( & typed_return_type)
2120+ let typed_return_type = method_return_type. unwrap_or ( Type :: Primitive ( zyntax_typed_ast:: PrimitiveType :: I64 ) ) ;
2121+ self . convert_type ( & typed_return_type)
2122+ } else {
2123+ // For extern types, assume opaque return (returns same type as receiver)
2124+ log:: debug!( "[METHOD_CALL] Assuming opaque return type for extern method '{}'" , mangled_str) ;
2125+ self . convert_type ( & receiver_type)
2126+ } ;
2127+ hir_type
20952128 } else {
2096- // Otherwise use the annotated type
2129+ // Use the annotated type from the expression
20972130 self . convert_type ( & expr. ty )
20982131 } ;
20992132
@@ -3555,7 +3588,15 @@ impl SsaBuilder {
35553588 let type_name = type_name. unwrap ( ) ;
35563589
35573590 // Construct the method function name: TypeName$method
3558- let method_symbol = format ! ( "{}${}" , type_name, method_name) ;
3591+ // For extern types, the type_name already starts with '$' (e.g., "$Tensor")
3592+ // so we just append $method to get "$Tensor$add"
3593+ let method_symbol = if type_name. starts_with ( '$' ) {
3594+ // Already has $ prefix (extern type)
3595+ format ! ( "{}${}" , type_name, method_name)
3596+ } else {
3597+ // Regular type, add $ prefix for ZRTL compatibility
3598+ format ! ( "${}${}" , type_name, method_name)
3599+ } ;
35593600 let method_name_interned = InternedString :: new_global ( & method_symbol) ;
35603601 log:: debug!( "[SSA] Operator trait dispatch: {} for type {}" , method_symbol, type_name) ;
35613602
@@ -3573,7 +3614,11 @@ impl SsaBuilder {
35733614 // Translate arguments
35743615 let left_val = self . translate_expression ( block_id, left) ?;
35753616 let right_val = self . translate_expression ( block_id, right) ?;
3576- let hir_result_type = self . convert_type ( result_type) ;
3617+
3618+ // For binary operations, the result type should be the same as the operand type
3619+ // (e.g., Tensor + Tensor = Tensor). Use left operand's type instead of expression type.
3620+ let hir_result_type = self . convert_type ( left_type) ;
3621+ log:: debug!( "[SSA] Operator trait dispatch result type: {:?} (from left type: {:?})" , hir_result_type, left_type) ;
35773622
35783623 // Create call instruction to the trait method
35793624 let result = if hir_result_type != HirType :: Void {
@@ -3598,6 +3643,29 @@ impl SsaBuilder {
35983643 Ok ( Some ( result. unwrap_or_else ( || self . create_undef ( HirType :: Void ) ) ) )
35993644 }
36003645
3646+ /// Resolve the actual type for an expression, looking up variable types if needed.
3647+ /// This is needed because expression nodes may have type `Any` even when the
3648+ /// variable has a more specific type recorded in var_typed_ast_types.
3649+ fn resolve_actual_type ( & self , expr : & zyntax_typed_ast:: typed_ast:: TypedExpression , fallback : & Type ) -> Type {
3650+ use zyntax_typed_ast:: typed_ast:: TypedExpression ;
3651+
3652+ match expr {
3653+ TypedExpression :: Variable ( name) => {
3654+ // Look up the variable's actual type from our tracking
3655+ if let Some ( var_ty) = self . var_typed_ast_types . get ( name) {
3656+ if !matches ! ( var_ty, Type :: Any | Type :: Unknown ) {
3657+ log:: debug!( "[resolve_actual_type] Variable '{}' has type {:?}" ,
3658+ name. resolve_global( ) . unwrap_or_default( ) , var_ty) ;
3659+ return var_ty. clone ( ) ;
3660+ }
3661+ }
3662+ // Fall back to the expression's type
3663+ fallback. clone ( )
3664+ }
3665+ _ => fallback. clone ( ) ,
3666+ }
3667+ }
3668+
36013669 /// Check if a type should use trait dispatch for operators
36023670 fn is_trait_dispatchable_type ( & self , ty : & Type ) -> bool {
36033671 match ty {
0 commit comments