Skip to content

Commit 25f810a

Browse files
committed
fix: binary operators and method return types for ZynML tensors
- Add fold_left_ops and make_pair helpers to Grammar2 interpreter for proper left-associative binary expression construction - Fix operator trait dispatch to avoid double $ prefix in symbol names - Add resolve_actual_type to look up variable types from var_typed_ast_types - Set binary operation result type from left operand type (Tensor + Tensor = Tensor) - Infer F32 return type for reduction methods (sum, mean, max, min, std, var) - Add builtins field to LoweringConfig and pass grammar builtins through hello_basic.zynml now runs correctly with tensor arithmetic and method calls.
1 parent fef7bdb commit 25f810a

7 files changed

Lines changed: 285 additions & 54 deletions

File tree

crates/compiler/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1268,6 +1268,7 @@ pub fn compile_to_hir(
12681268
hot_reload: config.hot_reload,
12691269
strict_mode: false, // Default to non-strict
12701270
import_resolver: config.import_resolver.clone(),
1271+
builtins: indexmap::IndexMap::new(), // Empty - callers with grammar should use runtime directly
12711272
};
12721273

12731274
// Create arena for string interning (needed for async transformation)

crates/compiler/src/lowering.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,9 @@ pub struct LoweringConfig {
300300
pub strict_mode: bool,
301301
/// Optional import resolver for resolving import statements
302302
pub import_resolver: Option<Arc<dyn ImportResolver>>,
303+
/// Builtin function mappings (e.g., "tensor_sum_f32" -> "$Tensor$sum_f32")
304+
/// These are added to extern_link_names for resolving extern calls
305+
pub builtins: indexmap::IndexMap<String, String>,
303306
}
304307

305308
impl std::fmt::Debug for LoweringConfig {
@@ -324,6 +327,7 @@ impl Default for LoweringConfig {
324327
hot_reload: false,
325328
strict_mode: false,
326329
import_resolver: None,
330+
builtins: indexmap::IndexMap::new(),
327331
}
328332
}
329333
}
@@ -373,11 +377,22 @@ impl LoweringContext {
373377
arena: Arc<Mutex<AstArena>>,
374378
config: LoweringConfig,
375379
) -> Self {
380+
// Initialize symbol table with builtins from config
381+
let mut symbols = SymbolTable::default();
382+
if !config.builtins.is_empty() {
383+
log::debug!("[LOWERING] Populating extern_link_names with {} builtins", config.builtins.len());
384+
}
385+
for (alias, target) in &config.builtins {
386+
let alias_interned = InternedString::new_global(alias);
387+
symbols.extern_link_names.insert(alias_interned, target.clone());
388+
log::trace!("[LOWERING] Added builtin: '{}' -> '{}'", alias, target);
389+
}
390+
376391
Self {
377392
module: HirModule::new(module_name),
378393
type_registry,
379394
arena,
380-
symbols: SymbolTable::default(),
395+
symbols,
381396
diagnostics: Vec::new(),
382397
config,
383398
vtable_registry: crate::vtable_registry::VtableRegistry::new(),

crates/compiler/src/ssa.rs

Lines changed: 101 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1460,8 +1460,19 @@ impl SsaBuilder {
14601460

14611461
// Check if this is a non-primitive type that might use operator overloading
14621462
// For opaque/named types, try to dispatch to trait method
1463-
eprintln!("[DEBUG SSA] Binary op {:?}, left.ty={:?}, right.ty={:?}", op, left.ty, right.ty);
1464-
if let Some(trait_call) = self.try_operator_trait_dispatch(block_id, op, left, right, &expr.ty)? {
1463+
// First, resolve the actual type from the variable if the expression is a Variable
1464+
let left_actual_ty = self.resolve_actual_type(&left.node, &left.ty);
1465+
let right_actual_ty = self.resolve_actual_type(&right.node, &right.ty);
1466+
eprintln!("[DEBUG SSA] Binary op {:?}, left.ty={:?} (actual: {:?}), right.ty={:?} (actual: {:?})",
1467+
op, left.ty, left_actual_ty, right.ty, right_actual_ty);
1468+
1469+
// Create a modified left/right with resolved types for trait dispatch
1470+
let mut left_with_type = left.clone();
1471+
let mut right_with_type = right.clone();
1472+
left_with_type.ty = left_actual_ty;
1473+
right_with_type.ty = right_actual_ty;
1474+
1475+
if let Some(trait_call) = self.try_operator_trait_dispatch(block_id, op, &left_with_type, &right_with_type, &expr.ty)? {
14651476
eprintln!("[DEBUG SSA] Using trait dispatch for binary op");
14661477
return Ok(trait_call);
14671478
}
@@ -1541,8 +1552,17 @@ impl SsaBuilder {
15411552
(crate::hir::HirCallable::Function(func_id), None)
15421553
} else if let Some(link_name) = self.extern_link_names.get(func_name) {
15431554
// External function with link_name (e.g., tensor_add -> $Tensor$add)
1555+
log::debug!("[SSA] Resolved extern call '{}' -> '{}'", name_str, link_name);
15441556
(crate::hir::HirCallable::Symbol(link_name.clone()), None)
15451557
} else {
1558+
// Debug: Log what we're looking for and what's available
1559+
if !self.extern_link_names.is_empty() {
1560+
let available: Vec<_> = self.extern_link_names.keys()
1561+
.filter_map(|k| k.resolve_global())
1562+
.collect();
1563+
log::debug!("[SSA] Function '{}' not in extern_link_names ({} entries). Sample: {:?}",
1564+
name_str, self.extern_link_names.len(), &available[..available.len().min(10)]);
1565+
}
15461566
// Variable lookup (function pointer)
15471567
let callee_val = self.translate_expression(block_id, callee)?;
15481568
(crate::hir::HirCallable::Indirect(callee_val), Some(callee_val))
@@ -2056,44 +2076,57 @@ impl SsaBuilder {
20562076
));
20572077
};
20582078

2059-
// Determine result type: if the method call has type Any, look up the trait method's return type
2060-
let result_type = if matches!(expr.ty, Type::Any) {
2061-
// Look up the trait method's return type from the type registry
2062-
let receiver_type_id = match &receiver_type {
2063-
Type::Named { id, .. } => *id,
2064-
_ => return Err(crate::CompilerError::Analysis(
2065-
format!("Method receiver is not a named type: {:?}", receiver_type)
2066-
)),
2067-
};
2068-
2069-
// Find the trait implementation for this type
2070-
let mut method_return_type = None;
2071-
for (_trait_id, impls) in self.type_registry.iter_implementations() {
2072-
for impl_def in impls {
2073-
if let Type::Named { id: impl_type_id, .. } = &impl_def.for_type {
2074-
if *impl_type_id == receiver_type_id {
2075-
// Find the method in this impl
2076-
for method in &impl_def.methods {
2077-
if method.signature.name == method_call.method {
2078-
method_return_type = Some(method.signature.return_type.clone());
2079-
break;
2079+
// Determine result type based on the mangled function name
2080+
// For extern methods like $Tensor$sum_f32, parse the return type from the suffix
2081+
let result_type = if matches!(expr.ty, Type::Any | Type::Unknown) {
2082+
let mangled_str = mangled_name.resolve_global().unwrap_or_default();
2083+
2084+
// Check common return type suffixes for Tensor methods
2085+
// Methods like sum_f32, mean_f32 return f32
2086+
// Methods like zeros, ones, arange return Tensor (opaque ptr)
2087+
let hir_type = if mangled_str.ends_with("_f32") || mangled_str.contains("$sum") || mangled_str.contains("$mean")
2088+
|| mangled_str.contains("$max") || mangled_str.contains("$min") || mangled_str.contains("$std") || mangled_str.contains("$var") {
2089+
// Reduction methods that return f32
2090+
log::debug!("[METHOD_CALL] Inferred F32 return type for '{}'", mangled_str);
2091+
HirType::F32
2092+
} else if mangled_str.contains("$ndim") || mangled_str.contains("$numel") {
2093+
// Methods that return i64
2094+
log::debug!("[METHOD_CALL] Inferred I64 return type for '{}'", mangled_str);
2095+
HirType::I64
2096+
} else if let Type::Named { id, .. } = &receiver_type {
2097+
// Fall back to looking up in trait implementations for named types
2098+
let receiver_type_id = *id;
2099+
let mut method_return_type = None;
2100+
for (_trait_id, impls) in self.type_registry.iter_implementations() {
2101+
for impl_def in impls {
2102+
if let Type::Named { id: impl_type_id, .. } = &impl_def.for_type {
2103+
if *impl_type_id == receiver_type_id {
2104+
for method in &impl_def.methods {
2105+
if method.signature.name == method_call.method {
2106+
method_return_type = Some(method.signature.return_type.clone());
2107+
break;
2108+
}
20802109
}
20812110
}
20822111
}
2112+
if method_return_type.is_some() {
2113+
break;
2114+
}
20832115
}
20842116
if method_return_type.is_some() {
20852117
break;
20862118
}
20872119
}
2088-
if method_return_type.is_some() {
2089-
break;
2090-
}
2091-
}
2092-
2093-
let typed_return_type = method_return_type.unwrap_or(Type::Primitive(zyntax_typed_ast::PrimitiveType::I32));
2094-
self.convert_type(&typed_return_type)
2120+
let typed_return_type = method_return_type.unwrap_or(Type::Primitive(zyntax_typed_ast::PrimitiveType::I64));
2121+
self.convert_type(&typed_return_type)
2122+
} else {
2123+
// For extern types, assume opaque return (returns same type as receiver)
2124+
log::debug!("[METHOD_CALL] Assuming opaque return type for extern method '{}'", mangled_str);
2125+
self.convert_type(&receiver_type)
2126+
};
2127+
hir_type
20952128
} else {
2096-
// Otherwise use the annotated type
2129+
// Use the annotated type from the expression
20972130
self.convert_type(&expr.ty)
20982131
};
20992132

@@ -3555,7 +3588,15 @@ impl SsaBuilder {
35553588
let type_name = type_name.unwrap();
35563589

35573590
// Construct the method function name: TypeName$method
3558-
let method_symbol = format!("{}${}", type_name, method_name);
3591+
// For extern types, the type_name already starts with '$' (e.g., "$Tensor")
3592+
// so we just append $method to get "$Tensor$add"
3593+
let method_symbol = if type_name.starts_with('$') {
3594+
// Already has $ prefix (extern type)
3595+
format!("{}${}", type_name, method_name)
3596+
} else {
3597+
// Regular type, add $ prefix for ZRTL compatibility
3598+
format!("${}${}", type_name, method_name)
3599+
};
35593600
let method_name_interned = InternedString::new_global(&method_symbol);
35603601
log::debug!("[SSA] Operator trait dispatch: {} for type {}", method_symbol, type_name);
35613602

@@ -3573,7 +3614,11 @@ impl SsaBuilder {
35733614
// Translate arguments
35743615
let left_val = self.translate_expression(block_id, left)?;
35753616
let right_val = self.translate_expression(block_id, right)?;
3576-
let hir_result_type = self.convert_type(result_type);
3617+
3618+
// For binary operations, the result type should be the same as the operand type
3619+
// (e.g., Tensor + Tensor = Tensor). Use left operand's type instead of expression type.
3620+
let hir_result_type = self.convert_type(left_type);
3621+
log::debug!("[SSA] Operator trait dispatch result type: {:?} (from left type: {:?})", hir_result_type, left_type);
35773622

35783623
// Create call instruction to the trait method
35793624
let result = if hir_result_type != HirType::Void {
@@ -3598,6 +3643,29 @@ impl SsaBuilder {
35983643
Ok(Some(result.unwrap_or_else(|| self.create_undef(HirType::Void))))
35993644
}
36003645

3646+
/// Resolve the actual type for an expression, looking up variable types if needed.
3647+
/// This is needed because expression nodes may have type `Any` even when the
3648+
/// variable has a more specific type recorded in var_typed_ast_types.
3649+
fn resolve_actual_type(&self, expr: &zyntax_typed_ast::typed_ast::TypedExpression, fallback: &Type) -> Type {
3650+
use zyntax_typed_ast::typed_ast::TypedExpression;
3651+
3652+
match expr {
3653+
TypedExpression::Variable(name) => {
3654+
// Look up the variable's actual type from our tracking
3655+
if let Some(var_ty) = self.var_typed_ast_types.get(name) {
3656+
if !matches!(var_ty, Type::Any | Type::Unknown) {
3657+
log::debug!("[resolve_actual_type] Variable '{}' has type {:?}",
3658+
name.resolve_global().unwrap_or_default(), var_ty);
3659+
return var_ty.clone();
3660+
}
3661+
}
3662+
// Fall back to the expression's type
3663+
fallback.clone()
3664+
}
3665+
_ => fallback.clone(),
3666+
}
3667+
}
3668+
36013669
/// Check if a type should use trait dispatch for operators
36023670
fn is_trait_dispatchable_type(&self, ty: &Type) -> bool {
36033671
match ty {

crates/zyn_peg/src/runtime2/interpreter.rs

Lines changed: 102 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -928,9 +928,13 @@ impl<'g> GrammarInterpreter<'g> {
928928
let name = self.get_field_as_interned("name", fields, state)?;
929929
let type_params = self.get_field_as_type_param_list("type_params", fields, state)?;
930930

931+
// Runtime prefix is $TypeName to match ZRTL symbol convention
932+
let name_str = name.resolve_global().unwrap_or_default();
933+
let runtime_prefix = zyntax_typed_ast::InternedString::new_global(&format!("${}", name_str));
934+
931935
TypedDeclaration::Extern(TypedExtern::Struct(TypedExternStruct {
932936
name,
933-
runtime_prefix: name, // Use name as default runtime prefix
937+
runtime_prefix,
934938
type_params,
935939
}))
936940
}
@@ -1425,12 +1429,16 @@ impl<'g> GrammarInterpreter<'g> {
14251429
ParameterKind::Regular
14261430
};
14271431

1432+
// Parse default_value if present
1433+
let default_value = self.get_field_optional_expr("default_value", fields, state)?
1434+
.map(|expr| Box::new(expr));
1435+
14281436
Ok(ParsedValue::Parameter(TypedParameter {
14291437
name,
14301438
ty,
14311439
mutability: Mutability::Immutable,
14321440
kind,
1433-
default_value: None,
1441+
default_value,
14341442
attributes: vec![],
14351443
span,
14361444
}))
@@ -1661,6 +1669,98 @@ impl<'g> GrammarInterpreter<'g> {
16611669
}
16621670
Ok(ParsedValue::List(result))
16631671
}
1672+
"make_pair" => {
1673+
// make_pair(a, b) - create a two-element list [a, b]
1674+
// Useful for building operator-operand pairs in binary expression parsing
1675+
if args.len() != 2 {
1676+
return Err("make_pair() requires exactly 2 arguments".to_string());
1677+
}
1678+
let first = self.eval_expr(&args[0], state)?;
1679+
let second = self.eval_expr(&args[1], state)?;
1680+
Ok(ParsedValue::List(vec![first, second]))
1681+
}
1682+
"fold_left_ops" => {
1683+
// fold_left_ops(first, rest) - fold binary operations with left associativity
1684+
// first: the first operand expression
1685+
// rest: list of [op, operand, op, operand, ...] pairs or [[op, operand], [op, operand], ...]
1686+
// Returns nested Binary expressions folded left-to-right
1687+
// e.g., fold_left_ops(a, [["+", b], ["-", c]]) -> Binary(Binary(a, +, b), -, c)
1688+
if args.len() != 2 {
1689+
return Err("fold_left_ops() requires exactly 2 arguments (first, rest)".to_string());
1690+
}
1691+
let first = self.eval_expr(&args[0], state)?;
1692+
let rest = self.eval_expr(&args[1], state)?;
1693+
1694+
// Get the rest as a list
1695+
let rest_list = match rest {
1696+
ParsedValue::List(items) => items,
1697+
ParsedValue::Optional(None) | ParsedValue::None => vec![],
1698+
ParsedValue::Optional(Some(inner)) => {
1699+
match *inner {
1700+
ParsedValue::List(items) => items,
1701+
other => vec![other],
1702+
}
1703+
}
1704+
other => vec![other],
1705+
};
1706+
1707+
// If no rest, return first as-is
1708+
if rest_list.is_empty() {
1709+
return Ok(first);
1710+
}
1711+
1712+
// Flatten nested lists from the repeat pattern
1713+
let mut flat_items: Vec<ParsedValue> = Vec::new();
1714+
for item in rest_list {
1715+
match item {
1716+
ParsedValue::List(inner) => {
1717+
flat_items.extend(inner);
1718+
}
1719+
other => flat_items.push(other),
1720+
}
1721+
}
1722+
1723+
// Now we have [op, operand, op, operand, ...]
1724+
// Fold left: accumulator op operand -> new accumulator
1725+
let mut acc = first;
1726+
let mut i = 0;
1727+
while i + 1 < flat_items.len() {
1728+
let op_val = &flat_items[i];
1729+
let operand_val = flat_items[i + 1].clone();
1730+
1731+
// Get operator string
1732+
let op_str = match op_val {
1733+
ParsedValue::Text(s) => s.clone(),
1734+
ParsedValue::Interned(s) => s.resolve_global().unwrap_or_else(|| "+".to_string()),
1735+
other => {
1736+
log::warn!("[fold_left_ops] Unexpected operator value: {:?}", other);
1737+
"+".to_string()
1738+
}
1739+
};
1740+
1741+
// Convert accumulator to expression
1742+
let left_expr = self.parsed_value_to_expr(acc, state)?;
1743+
1744+
// Convert operand to expression
1745+
let right_expr = self.parsed_value_to_expr(operand_val, state)?;
1746+
1747+
// Create Binary expression
1748+
let binary_op = self.string_to_binary_op(&op_str)?;
1749+
acc = ParsedValue::Expression(Box::new(typed_node(
1750+
TypedExpression::Binary(zyntax_typed_ast::TypedBinary {
1751+
op: binary_op,
1752+
left: Box::new(left_expr),
1753+
right: Box::new(right_expr),
1754+
}),
1755+
Type::Unknown,
1756+
span,
1757+
)));
1758+
1759+
i += 2;
1760+
}
1761+
1762+
Ok(acc)
1763+
}
16641764
"fold_postfix" => {
16651765
// fold_postfix(base, suffixes) - fold postfix operations into nested expressions
16661766
// base: the base expression (TypedExpression)

0 commit comments

Comments
 (0)