Skip to content

Commit 2717644

Browse files
committed
feat(typed_ast): add built-in operator traits with feature flag
Type System (typed_ast): - Add BuiltinTraitIds struct for well-known trait IDs - Implement register_builtin_traits() to register all operator traits: - Arithmetic: Add, Sub, Mul, Div, Mod, MatMul, Neg - Comparison: Eq, Ord (lt, gt, le, ge) - Bitwise: BitAnd, BitOr, BitXor, Not - Indexing: Index, IndexMut - Formatting: Display, Debug - Lifecycle: Clone, Drop - Iteration: Iterator, IntoIterator - Add helper methods for operator->trait and operator->method mapping - Gate all trait code behind "operator_traits" feature flag Compiler (SSA lowering): - Enable operator_traits feature for zyntax_typed_ast - Add try_operator_trait_dispatch() for binary operators - Check if left operand type is opaque (ZRTL-backed) - Generate method call to $TypeName$method instead of binary instruction - Add is_trait_dispatchable_type() and get_type_symbol_prefix() helpers ZRTL Tensor Plugin: - Add arithmetic operators: tensor_add, tensor_sub, tensor_mul, tensor_div - Add tensor_mod for modulo operation - Add tensor_neg for unary negation - Add tensor_dot for @ (matrix multiplication/dot product) - Register all operators as $Tensor$<method> symbols
1 parent 3df5b24 commit 2717644

6 files changed

Lines changed: 1126 additions & 4 deletions

File tree

crates/compiler/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ description = "Layered compilation system for Zyntax language"
77

88
[dependencies]
99
# Parser dependency
10-
zyntax_typed_ast = { path = "../typed_ast" }
10+
zyntax_typed_ast = { path = "../typed_ast", features = ["operator_traits"] }
1111

1212
# Core data structures
1313
dashmap = { workspace = true }

crates/compiler/src/ssa.rs

Lines changed: 138 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,7 +1378,13 @@ impl SsaBuilder {
13781378
return Ok(result);
13791379
}
13801380

1381-
// Regular binary operations
1381+
// Check if this is a non-primitive type that might use operator overloading
1382+
// For opaque/named types, try to dispatch to trait method
1383+
if let Some(trait_call) = self.try_operator_trait_dispatch(block_id, op, left, right, &expr.ty)? {
1384+
return Ok(trait_call);
1385+
}
1386+
1387+
// Regular binary operations for primitive types
13821388
let left_val = self.translate_expression(block_id, left)?;
13831389
let right_val = self.translate_expression(block_id, right)?;
13841390
let result_type = self.convert_type(&expr.ty);
@@ -2984,14 +2990,143 @@ impl SsaBuilder {
29842990
fn convert_unary_op(&self, op: &zyntax_typed_ast::typed_ast::UnaryOp) -> crate::hir::UnaryOp {
29852991
use zyntax_typed_ast::typed_ast::UnaryOp as FrontendOp;
29862992
use crate::hir::UnaryOp as HirOp;
2987-
2993+
29882994
match op {
29892995
FrontendOp::Minus => HirOp::Neg,
29902996
FrontendOp::Not => HirOp::Not,
29912997
_ => HirOp::Neg, // Default
29922998
}
29932999
}
2994-
3000+
3001+
/// Try to dispatch a binary operator to a trait method call.
3002+
/// Returns Some(result_value) if the operator should be dispatched via trait,
3003+
/// or None if the regular binary instruction should be used.
3004+
fn try_operator_trait_dispatch(
3005+
&mut self,
3006+
block_id: HirId,
3007+
op: &zyntax_typed_ast::typed_ast::BinaryOp,
3008+
left: &zyntax_typed_ast::TypedNode<zyntax_typed_ast::typed_ast::TypedExpression>,
3009+
right: &zyntax_typed_ast::TypedNode<zyntax_typed_ast::typed_ast::TypedExpression>,
3010+
result_type: &Type,
3011+
) -> CompilerResult<Option<HirId>> {
3012+
use zyntax_typed_ast::typed_ast::BinaryOp as FrontendOp;
3013+
3014+
// Only consider trait dispatch for non-primitive types
3015+
let left_type = &left.ty;
3016+
if !self.is_trait_dispatchable_type(left_type) {
3017+
return Ok(None);
3018+
}
3019+
3020+
// Get the method name for this operator
3021+
let method_name = match op {
3022+
FrontendOp::Add => "add",
3023+
FrontendOp::Sub => "sub",
3024+
FrontendOp::Mul => "mul",
3025+
FrontendOp::Div => "div",
3026+
FrontendOp::Rem => "mod",
3027+
FrontendOp::Eq => "eq",
3028+
FrontendOp::Ne => "ne",
3029+
FrontendOp::Lt => "lt",
3030+
FrontendOp::Le => "le",
3031+
FrontendOp::Gt => "gt",
3032+
FrontendOp::Ge => "ge",
3033+
FrontendOp::BitAnd => "bitand",
3034+
FrontendOp::BitOr => "bitor",
3035+
FrontendOp::BitXor => "bitxor",
3036+
_ => return Ok(None), // No trait method for this operator
3037+
};
3038+
3039+
// Get the type name for constructing the method symbol
3040+
let type_name = self.get_type_symbol_prefix(left_type);
3041+
if type_name.is_none() {
3042+
return Ok(None);
3043+
}
3044+
let type_name = type_name.unwrap();
3045+
3046+
// Construct the method symbol name: $TypeName$method
3047+
let method_symbol = format!("${}${}", type_name, method_name);
3048+
log::debug!("[SSA] Operator trait dispatch: {} for type {}", method_symbol, type_name);
3049+
3050+
// Translate arguments
3051+
let left_val = self.translate_expression(block_id, left)?;
3052+
let right_val = self.translate_expression(block_id, right)?;
3053+
let hir_result_type = self.convert_type(result_type);
3054+
3055+
// Create call instruction to the trait method
3056+
let result = if hir_result_type != HirType::Void {
3057+
Some(self.create_value(hir_result_type.clone(), HirValueKind::Instruction))
3058+
} else {
3059+
None
3060+
};
3061+
3062+
let inst = HirInstruction::Call {
3063+
result,
3064+
callee: crate::hir::HirCallable::Symbol(method_symbol),
3065+
args: vec![left_val, right_val],
3066+
type_args: vec![],
3067+
const_args: vec![],
3068+
is_tail: false,
3069+
};
3070+
3071+
self.add_instruction(block_id, inst);
3072+
self.add_use(left_val, result.unwrap_or(left_val));
3073+
self.add_use(right_val, result.unwrap_or(right_val));
3074+
3075+
Ok(Some(result.unwrap_or_else(|| self.create_undef(HirType::Void))))
3076+
}
3077+
3078+
/// Check if a type should use trait dispatch for operators
3079+
fn is_trait_dispatchable_type(&self, ty: &Type) -> bool {
3080+
match ty {
3081+
// Named types - check if the HIR conversion results in an opaque type
3082+
Type::Named { id, .. } => {
3083+
// Convert to HIR type and check if it's opaque
3084+
let hir_ty = self.convert_type(ty);
3085+
matches!(hir_ty, HirType::Opaque(_))
3086+
}
3087+
// Primitive types - use built-in operations
3088+
Type::Primitive(_) => false,
3089+
// Other types - might need trait dispatch
3090+
_ => false,
3091+
}
3092+
}
3093+
3094+
/// Get the symbol prefix for a type (e.g., "Tensor" for $Tensor$add)
3095+
fn get_type_symbol_prefix(&self, ty: &Type) -> Option<String> {
3096+
match ty {
3097+
Type::Named { id, .. } => {
3098+
if let Some(type_def) = self.type_registry.get_type_by_id(*id) {
3099+
// Use the type name
3100+
let name = type_def.name.resolve_global()
3101+
.unwrap_or_else(|| {
3102+
let arena = self.arena.lock().unwrap();
3103+
arena.resolve_string(type_def.name)
3104+
.map(|s| s.to_string())
3105+
.unwrap_or_default()
3106+
});
3107+
Some(name)
3108+
} else {
3109+
// If not in registry, try to get from HIR type
3110+
let hir_ty = self.convert_type(ty);
3111+
if let HirType::Opaque(name) = hir_ty {
3112+
let name_str = name.resolve_global()
3113+
.unwrap_or_else(|| {
3114+
let arena = self.arena.lock().unwrap();
3115+
arena.resolve_string(name)
3116+
.map(|s| s.to_string())
3117+
.unwrap_or_default()
3118+
});
3119+
// Remove $ prefix if present
3120+
Some(name_str.trim_start_matches('$').to_string())
3121+
} else {
3122+
None
3123+
}
3124+
}
3125+
}
3126+
_ => None,
3127+
}
3128+
}
3129+
29953130
/// Translate literal to constant
29963131
fn translate_literal(&self, lit: &zyntax_typed_ast::typed_ast::TypedLiteral) -> crate::hir::HirConstant {
29973132
use zyntax_typed_ast::typed_ast::TypedLiteral;

crates/typed_ast/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ derive_builder = { version = "0.20.2", optional = true }
2727
[features]
2828
default = []
2929
builders = ["derive_builder"] # Enable builder pattern helpers
30+
operator_traits = [] # Enable built-in operator traits (Add, Sub, Mul, etc.)
3031

3132

3233
[[example]]

crates/typed_ast/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ pub use type_registry::{
110110
ConstBinaryOp, ConstUnaryOp, ConstVarId, Kind,
111111
};
112112

113+
// Operator traits feature exports
114+
#[cfg(feature = "operator_traits")]
115+
pub use type_registry::BuiltinTraitIds;
116+
113117
pub use typed_ast::{
114118
TypedNode, TypedProgram, TypedDeclaration, TypedFunction, TypedVariable,
115119
TypedStatement, TypedExpression, TypedLiteral, BinaryOp, UnaryOp,

0 commit comments

Comments
 (0)