Skip to content

Commit eab71ed

Browse files
committed
feat(compiler): implement operator overloading with trait dispatch
Add Type::Extern to TypedAST for opaque/FFI types that converts to HirType::Ptr(Opaque) at HIR level. Implement trait dispatch in SSA lowering to transform binary ops (a + b) into method calls ($Type$add). - Add @types directive to ZynPEG for declaring opaque types and returns - Update ZynML grammar with comprehensive type declarations - SSA lowering detects Type::Extern and dispatches to trait methods - Add exhaustive pattern matching for Type::Extern in constraint solver
1 parent 2717644 commit eab71ed

8 files changed

Lines changed: 257 additions & 14 deletions

File tree

crates/compiler/src/ssa.rs

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2956,6 +2956,11 @@ impl SsaBuilder {
29562956
packed: false,
29572957
})
29582958
},
2959+
Type::Extern { name, .. } => {
2960+
// Extern/Opaque types are pointers to opaque structs at the HIR level
2961+
// The name is used for trait dispatch (e.g., $Tensor -> $Tensor$add)
2962+
HirType::Ptr(Box::new(HirType::Opaque(*name)))
2963+
},
29592964
_ => HirType::I64, // Default for complex types
29602965
}
29612966
}
@@ -3044,7 +3049,8 @@ impl SsaBuilder {
30443049
let type_name = type_name.unwrap();
30453050

30463051
// Construct the method symbol name: $TypeName$method
3047-
let method_symbol = format!("${}${}", type_name, method_name);
3052+
// Note: type_name already includes the $ prefix (e.g., "$Tensor")
3053+
let method_symbol = format!("{}${}", type_name, method_name);
30483054
log::debug!("[SSA] Operator trait dispatch: {} for type {}", method_symbol, type_name);
30493055

30503056
// Translate arguments
@@ -3078,9 +3084,33 @@ impl SsaBuilder {
30783084
/// Check if a type should use trait dispatch for operators
30793085
fn is_trait_dispatchable_type(&self, ty: &Type) -> bool {
30803086
match ty {
3081-
// Named types - check if the HIR conversion results in an opaque type
3087+
// Extern types - ZRTL-backed opaque types, always use trait dispatch
3088+
Type::Extern { name, .. } => {
3089+
let name_str = name.resolve_global()
3090+
.unwrap_or_else(|| {
3091+
let arena = self.arena.lock().unwrap();
3092+
arena.resolve_string(*name)
3093+
.map(|s| s.to_string())
3094+
.unwrap_or_default()
3095+
});
3096+
log::debug!("[trait_dispatch] Type::Extern name='{}' -> dispatchable", name_str);
3097+
true
3098+
}
3099+
// Named types - check if it's a ZRTL-backed type (name starts with $)
30823100
Type::Named { id, .. } => {
3083-
// Convert to HIR type and check if it's opaque
3101+
// Check the type name from registry
3102+
if let Some(type_def) = self.type_registry.get_type_by_id(*id) {
3103+
let name = type_def.name.resolve_global()
3104+
.unwrap_or_else(|| {
3105+
let arena = self.arena.lock().unwrap();
3106+
arena.resolve_string(type_def.name)
3107+
.map(|s| s.to_string())
3108+
.unwrap_or_default()
3109+
});
3110+
// ZRTL opaque types start with $
3111+
return name.starts_with('$');
3112+
}
3113+
// Also check HIR conversion
30843114
let hir_ty = self.convert_type(ty);
30853115
matches!(hir_ty, HirType::Opaque(_))
30863116
}
@@ -3091,9 +3121,21 @@ impl SsaBuilder {
30913121
}
30923122
}
30933123

3094-
/// Get the symbol prefix for a type (e.g., "Tensor" for $Tensor$add)
3124+
/// Get the symbol prefix for a type (e.g., "$Tensor" for $Tensor$add)
30953125
fn get_type_symbol_prefix(&self, ty: &Type) -> Option<String> {
30963126
match ty {
3127+
// Extern types have the name directly
3128+
Type::Extern { name, .. } => {
3129+
let name_str = name.resolve_global()
3130+
.unwrap_or_else(|| {
3131+
let arena = self.arena.lock().unwrap();
3132+
arena.resolve_string(*name)
3133+
.map(|s| s.to_string())
3134+
.unwrap_or_default()
3135+
});
3136+
log::debug!("[trait_dispatch] Type::Extern prefix: '{}'", name_str);
3137+
Some(name_str)
3138+
}
30973139
Type::Named { id, .. } => {
30983140
if let Some(type_def) = self.type_registry.get_type_by_id(*id) {
30993141
// Use the type name
@@ -3116,8 +3158,7 @@ impl SsaBuilder {
31163158
.map(|s| s.to_string())
31173159
.unwrap_or_default()
31183160
});
3119-
// Remove $ prefix if present
3120-
Some(name_str.trim_start_matches('$').to_string())
3161+
Some(name_str)
31213162
} else {
31223163
None
31233164
}

crates/typed_ast/src/constraint_solver.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,8 @@ impl Substitution {
247247
| Type::Never
248248
| Type::Any
249249
| Type::Error
250-
| Type::HigherKinded { .. } => ty.clone(),
250+
| Type::HigherKinded { .. }
251+
| Type::Extern { .. } => ty.clone(),
251252
Type::Nullable(inner_ty) => {
252253
// For nullable types, substitute the inner type and maintain nullability
253254
let substituted_inner = self.apply(inner_ty);
@@ -2639,6 +2640,10 @@ impl ConstraintSolver {
26392640
associated_types,
26402641
super_traits,
26412642
} => todo!(),
2643+
Type::Extern { name, .. } => {
2644+
// Format extern/opaque type by name
2645+
name.resolve_global().unwrap_or_else(|| format!("extern_{}", name.symbol().to_usize()))
2646+
}
26422647
}
26432648
}
26442649

@@ -3394,7 +3399,8 @@ impl ConstraintSolver {
33943399
| Type::Any
33953400
| Type::Error
33963401
| Type::SelfType
3397-
| Type::HigherKinded { .. } => ty.clone(),
3402+
| Type::HigherKinded { .. }
3403+
| Type::Extern { .. } => ty.clone(),
33983404
Type::Nullable(inner_ty) => {
33993405
// For nullable types, substitute the inner type and maintain nullability
34003406
let substituted_inner = self.resolve_associated_types(inner_ty, receiver_type);
@@ -3452,6 +3458,7 @@ impl ConstraintSolver {
34523458
associated_types,
34533459
super_traits,
34543460
} => todo!(),
3461+
Type::Extern { .. } => ty.clone(),
34553462
}
34563463
}
34573464

@@ -3563,7 +3570,8 @@ impl ConstraintSolver {
35633570
| Type::Any
35643571
| Type::Error
35653572
| Type::Associated { .. }
3566-
| Type::HigherKinded { .. } => ty.clone(),
3573+
| Type::HigherKinded { .. }
3574+
| Type::Extern { .. } => ty.clone(),
35673575
Type::Nullable(inner_ty) => {
35683576
// For nullable types, substitute the inner type and maintain nullability
35693577
let substituted_inner = self.substitute_self_type(inner_ty, receiver_type);

crates/typed_ast/src/type_registry.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,20 @@ pub enum Type {
429429
associated_types: Vec<(InternedString, Type)>,
430430
super_traits: Vec<Type>,
431431
},
432+
433+
/// External/Opaque type (for FFI, ZRTL plugins)
434+
/// At the HIR level, this becomes a pointer to an opaque type
435+
Extern {
436+
name: InternedString,
437+
layout: Option<ExternLayout>,
438+
},
439+
}
440+
441+
/// Layout information for extern types (optional)
442+
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
443+
pub struct ExternLayout {
444+
pub size: usize,
445+
pub align: usize,
432446
}
433447

434448
/// Type variable for inference

crates/zyn_peg/src/ast.rs

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//! AST types for parsed .zyn grammar files
22
33
use pest::iterators::Pair;
4-
use crate::{Rule, ZynGrammar, LanguageInfo, Imports, ContextVar, TypeHelpers, BuiltinMappings, RuleDef, RuleModifier, ActionBlock, ActionField};
4+
use crate::{Rule, ZynGrammar, LanguageInfo, Imports, ContextVar, TypeHelpers, BuiltinMappings, TypeDeclarations, RuleDef, RuleModifier, ActionBlock, ActionField};
55

66
/// Build a ZynGrammar from parsed pest pairs
77
pub fn build_grammar(pairs: pest::iterators::Pairs<Rule>) -> Result<ZynGrammar, String> {
@@ -39,6 +39,7 @@ fn process_top_level(grammar: &mut ZynGrammar, pair: Pair<Rule>) -> Result<(), S
3939
Rule::context_directive => grammar.context = build_context(pair)?,
4040
Rule::type_helpers_directive => grammar.type_helpers = build_type_helpers(pair)?,
4141
Rule::builtin_directive => grammar.builtins = build_builtins(pair)?,
42+
Rule::types_directive => grammar.types = build_types(pair)?,
4243
Rule::EOI => {}
4344
_ => {}
4445
}
@@ -53,6 +54,7 @@ fn process_directive(grammar: &mut ZynGrammar, pair: Pair<Rule>) -> Result<(), S
5354
Rule::context_directive => grammar.context = build_context(inner)?,
5455
Rule::type_helpers_directive => grammar.type_helpers = build_type_helpers(inner)?,
5556
Rule::builtin_directive => grammar.builtins = build_builtins(inner)?,
57+
Rule::types_directive => grammar.types = build_types(inner)?,
5658
Rule::error_messages_directive => {
5759
// TODO: Parse error messages
5860
}
@@ -205,6 +207,45 @@ fn build_builtins(pair: Pair<Rule>) -> Result<BuiltinMappings, String> {
205207
Ok(builtins)
206208
}
207209

210+
/// Parse @types { opaque: [$Tensor, $Audio], returns: { tensor: $Tensor } } directive
211+
///
212+
/// Declares:
213+
/// - opaque: List of opaque type names that are pointer types backed by ZRTL plugins
214+
/// - returns: Map of function name -> return type for proper type tracking
215+
fn build_types(pair: Pair<Rule>) -> Result<TypeDeclarations, String> {
216+
let mut types = TypeDeclarations::default();
217+
218+
for inner in pair.into_inner() {
219+
if inner.as_rule() == Rule::types_def {
220+
let def_text = inner.as_str().trim();
221+
222+
if def_text.starts_with("opaque") {
223+
// Parse opaque type list: opaque: [$Tensor, $Audio]
224+
for type_inner in inner.into_inner() {
225+
if type_inner.as_rule() == Rule::type_name {
226+
let type_name = type_inner.as_str().to_string();
227+
types.opaque_types.push(type_name);
228+
}
229+
}
230+
} else if def_text.starts_with("returns") {
231+
// Parse return type mapping: returns: { tensor: $Tensor, audio_load: $Audio }
232+
for return_inner in inner.into_inner() {
233+
if return_inner.as_rule() == Rule::return_def {
234+
let mut parts = return_inner.into_inner();
235+
if let (Some(fn_name), Some(type_name)) = (parts.next(), parts.next()) {
236+
let fn_name_str = fn_name.as_str().to_string();
237+
let type_name_str = type_name.as_str().to_string();
238+
types.function_returns.insert(fn_name_str, type_name_str);
239+
}
240+
}
241+
}
242+
}
243+
}
244+
}
245+
246+
Ok(types)
247+
}
248+
208249
fn build_rule_def(pair: Pair<Rule>) -> Result<RuleDef, String> {
209250
let mut name = String::new();
210251
let mut modifier = None;

crates/zyn_peg/src/lib.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,21 @@ pub struct BuiltinMappings {
9696
pub operators: std::collections::HashMap<String, Vec<String>>,
9797
}
9898

99+
/// Type declarations from @types directive
100+
/// Declares opaque types and function return types for proper type tracking
101+
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
102+
pub struct TypeDeclarations {
103+
/// List of opaque type names (ZRTL-backed types)
104+
/// e.g., ["$Tensor", "$Audio"] - these are pointer types backed by plugins
105+
#[serde(default)]
106+
pub opaque_types: Vec<String>,
107+
/// Map of function name -> return type
108+
/// e.g., "tensor" -> "$Tensor", "audio_load" -> "$Audio"
109+
/// Used during lowering to assign correct types to call expressions
110+
#[serde(default)]
111+
pub function_returns: std::collections::HashMap<String, String>,
112+
}
113+
99114
/// A parsed .zyn grammar file
100115
#[derive(Debug, Clone, Default)]
101116
pub struct ZynGrammar {
@@ -105,6 +120,8 @@ pub struct ZynGrammar {
105120
pub type_helpers: TypeHelpers,
106121
/// Built-in function mappings from @builtin directive
107122
pub builtins: BuiltinMappings,
123+
/// Type declarations from @types directive
124+
pub types: TypeDeclarations,
108125
pub rules: Vec<RuleDef>,
109126
}
110127

crates/zyn_peg/src/runtime.rs

Lines changed: 62 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ use std::collections::HashMap;
4242
use std::path::Path;
4343

4444
use crate::error::{Result, ZynPegError};
45-
use crate::{ZynGrammar, BuiltinMappings};
45+
use crate::{ZynGrammar, BuiltinMappings, TypeDeclarations};
4646

4747
// Re-export types from typed_ast for host function implementations
4848
pub use zyntax_typed_ast::{
@@ -88,6 +88,10 @@ pub struct ZpegMetadata {
8888
/// - Operators: "$*" -> "vec_dot" (x * y -> vec_dot(x, y))
8989
#[serde(default)]
9090
pub builtins: BuiltinMappings,
91+
/// Type declarations for opaque types and function return types
92+
/// Used for proper type tracking during lowering for operator trait dispatch
93+
#[serde(default)]
94+
pub types: TypeDeclarations,
9195
}
9296

9397
/// Commands for a single grammar rule
@@ -302,23 +306,42 @@ pub trait AstHostFunctions {
302306
/// Create a function call expression
303307
fn create_call(&mut self, callee: NodeHandle, args: Vec<NodeHandle>) -> NodeHandle;
304308

305-
/// Create a function call expression with builtin resolution
309+
/// Create a function call expression with explicit return type
310+
/// Used when we know the return type from @types directive
311+
fn create_call_with_return_type(&mut self, callee: NodeHandle, args: Vec<NodeHandle>, return_type: Option<&str>) -> NodeHandle {
312+
// Default implementation ignores return type
313+
self.create_call(callee, args)
314+
}
315+
316+
/// Create a function call expression with builtin and type resolution
306317
/// If the callee is a simple identifier matching a builtin function, use the runtime symbol instead
318+
/// Also looks up return types from @types directive for proper opaque type tracking
307319
fn create_call_with_builtin_resolution(
308320
&mut self,
309321
callee: NodeHandle,
310322
args: Vec<NodeHandle>,
311323
builtins: &BuiltinMappings,
324+
types: &TypeDeclarations,
312325
) -> NodeHandle {
313326
// Check if callee is an identifier that matches a builtin function
314327
if let Some(name) = self.get_identifier_name(callee) {
315328
log::trace!("[builtin resolution] callee name='{}', checking {} builtins", name, builtins.functions.len());
329+
330+
// Look up return type from @types directive
331+
let return_type = types.function_returns.get(&name).map(|s| s.as_str());
332+
if return_type.is_some() {
333+
log::trace!("[builtin resolution] found return type for '{}': {:?}", name, return_type);
334+
}
335+
316336
if let Some(symbol) = builtins.functions.get(&name) {
317337
log::trace!("[builtin resolution] found builtin '{}' -> '{}'", name, symbol);
318338
// Create a new identifier with the runtime symbol name
319339
let resolved_callee = self.create_identifier(symbol);
320-
return self.create_call(resolved_callee, args);
340+
return self.create_call_with_return_type(resolved_callee, args, return_type);
321341
}
342+
343+
// Not a builtin, but might have return type info
344+
return self.create_call_with_return_type(callee, args, return_type);
322345
} else {
323346
log::trace!("[builtin resolution] callee is not an identifier");
324347
}
@@ -543,6 +566,7 @@ impl ZpegCompiler {
543566
entry_point: grammar.language.entry_point.clone(),
544567
zpeg_version: env!("CARGO_PKG_VERSION").to_string(),
545568
builtins: grammar.builtins.clone(),
569+
types: grammar.types.clone(),
546570
};
547571

548572
// Generate pest grammar
@@ -1483,6 +1507,10 @@ impl AstHostFunctions for TypedAstBuilder {
14831507
}
14841508

14851509
fn create_call(&mut self, callee: NodeHandle, args: Vec<NodeHandle>) -> NodeHandle {
1510+
self.create_call_with_return_type(callee, args, None)
1511+
}
1512+
1513+
fn create_call_with_return_type(&mut self, callee: NodeHandle, args: Vec<NodeHandle>, return_type: Option<&str>) -> NodeHandle {
14861514
let span = self.default_span();
14871515

14881516
let callee_expr = self.get_expr(callee)
@@ -1492,7 +1520,35 @@ impl AstHostFunctions for TypedAstBuilder {
14921520
.filter_map(|h| self.get_expr(*h))
14931521
.collect();
14941522

1495-
let expr = self.inner.call_positional(callee_expr, arg_exprs, Type::Primitive(PrimitiveType::I32), span);
1523+
// Convert return type string to Type
1524+
// Opaque types (starting with $) become Extern types which are pointers at the HIR level
1525+
let ty = match return_type {
1526+
Some(type_name) if type_name.starts_with('$') => {
1527+
// This is an opaque type - create an Extern type
1528+
// The type name without $ is used as the extern type name
1529+
log::trace!("[create_call] opaque return type: {}", type_name);
1530+
Type::Extern {
1531+
name: InternedString::new_global(type_name),
1532+
layout: None,
1533+
}
1534+
}
1535+
Some(type_name) => {
1536+
// Regular named type - try to parse it
1537+
log::trace!("[create_call] named return type: {}", type_name);
1538+
match type_name {
1539+
"i32" | "I32" => Type::Primitive(PrimitiveType::I32),
1540+
"i64" | "I64" => Type::Primitive(PrimitiveType::I64),
1541+
"f32" | "F32" => Type::Primitive(PrimitiveType::F32),
1542+
"f64" | "F64" => Type::Primitive(PrimitiveType::F64),
1543+
"bool" | "Bool" => Type::Primitive(PrimitiveType::Bool),
1544+
"void" | "Void" | "()" => Type::Primitive(PrimitiveType::Unit),
1545+
_ => Type::Primitive(PrimitiveType::I32), // Default for unknown types
1546+
}
1547+
}
1548+
None => Type::Primitive(PrimitiveType::I32), // Default when no return type specified
1549+
};
1550+
1551+
let expr = self.inner.call_positional(callee_expr, arg_exprs, ty, span);
14961552
self.store_expr(expr)
14971553
}
14981554

@@ -3066,6 +3122,7 @@ impl<'a, H: AstHostFunctions> CommandInterpreter<'a, H> {
30663122
base_h.clone(),
30673123
call_args,
30683124
&self.module.metadata.builtins,
3125+
&self.module.metadata.types,
30693126
);
30703127
result = RuntimeValue::Node(new_node);
30713128
} else if op_info.starts_with("field:") {
@@ -3983,6 +4040,7 @@ impl<'a, H: AstHostFunctions> CommandInterpreter<'a, H> {
39834040
callee,
39844041
call_args,
39854042
&self.module.metadata.builtins,
4043+
&self.module.metadata.types,
39864044
);
39874045
Ok(RuntimeValue::Node(handle))
39884046
}

0 commit comments

Comments
 (0)