Skip to content

Commit 1ef7881

Browse files
committed
feat: implement DynamicBox auto-boxing infrastructure for ZRTL plugins
Added automatic DynamicBox wrapping for opaque type parameters when calling ZRTL plugin functions that expect DynamicBox arguments. Changes: - Added symbol_signatures HashMap to CraneliftBackend to track parameter types - Implemented register_symbol_signatures() to populate signatures from ZRTL plugins - Added param_needs_boxing() helper to check if parameter expects DynamicBox - Pre-compute boxing requirements before FunctionBuilder to avoid borrow issues - Generate DynamicBox wrapping code for parameters marked as dynamic - Updated ZPack to expose runtime_symbols_with_signatures() - Updated ZrtlRegistry with collect_symbols_with_signatures() - Modified zyntax_embed Runtime to register signatures when loading plugins - Updated cranelift_jit compile_jit() to pass signatures to backend DynamicBox structure (24 bytes): - tag: u32 (TypeCategory::Opaque = 0x12) - size: u32 (8 bytes for pointer) - data: i64 (pointer to opaque value) - dropper: i64 (null for now) Current status: - Boxing detection works correctly (println_dynamic param 0 shows needs_boxing=true) - DynamicBox structure is created on stack with proper layout - Remaining issue: ABI calling convention for passing struct by value vs pointer Need to investigate proper struct passing in Cranelift
1 parent bd5dc64 commit 1ef7881

9 files changed

Lines changed: 168 additions & 23 deletions

File tree

crates/compiler/src/cranelift_backend.rs

Lines changed: 89 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ pub struct CraneliftBackend {
5151
exported_symbols: HashMap<String, *const u8>,
5252
/// Runtime symbols registered for external linking
5353
runtime_symbols: Vec<(String, *const u8)>,
54+
/// Symbol signatures for auto-boxing support (symbol_name → signature)
55+
symbol_signatures: HashMap<String, crate::zrtl::ZrtlSymbolSig>,
5456
}
5557

5658
/// Hot-reload state management
@@ -146,9 +148,32 @@ impl CraneliftBackend {
146148
},
147149
exported_symbols: HashMap::new(),
148150
runtime_symbols,
151+
symbol_signatures: HashMap::new(),
149152
})
150153
}
151154

155+
/// Register symbol signatures for auto-boxing support
156+
pub fn register_symbol_signatures(&mut self, symbols: &[crate::zrtl::RuntimeSymbolInfo]) {
157+
log::info!("[DynamicBox] Registering {} symbol signatures", symbols.len());
158+
for sym in symbols {
159+
if let Some(sig) = &sym.sig {
160+
log::debug!("[DynamicBox] Registering signature for {}: params={}, dynamic_params={:?}",
161+
sym.name, sig.param_count,
162+
(0..sig.param_count).filter(|&i| sig.param_is_dynamic(i as usize)).collect::<Vec<_>>());
163+
self.symbol_signatures.insert(sym.name.to_string(), sig.clone());
164+
}
165+
}
166+
log::info!("[DynamicBox] Registered {} signatures total", self.symbol_signatures.len());
167+
}
168+
169+
/// Check if a symbol parameter expects DynamicBox
170+
fn param_needs_boxing(&self, symbol_name: &str, param_index: usize) -> bool {
171+
self.symbol_signatures
172+
.get(symbol_name)
173+
.map(|sig| sig.param_is_dynamic(param_index))
174+
.unwrap_or(false)
175+
}
176+
152177
/// Compile a HIR module to native code
153178
pub fn compile_module(&mut self, module: &HirModule) -> CompilerResult<()> {
154179
// Process globals first (including vtables)
@@ -498,6 +523,21 @@ impl CraneliftBackend {
498523
}
499524
}
500525

526+
// Pre-compute symbol parameter boxing requirements to avoid borrow checker issues
527+
let mut symbol_boxing: HashMap<(String, usize), bool> = HashMap::new();
528+
for block in function.blocks.values() {
529+
for inst in &block.instructions {
530+
if let HirInstruction::Call { callee: HirCallable::Symbol(symbol_name), args, .. } = inst {
531+
for param_index in 0..args.len() {
532+
let key = (symbol_name.clone(), param_index);
533+
if !symbol_boxing.contains_key(&key) {
534+
symbol_boxing.insert(key, self.param_needs_boxing(symbol_name, param_index));
535+
}
536+
}
537+
}
538+
}
539+
}
540+
501541
// Phase 2: Create builder and all Cranelift blocks
502542
let mut builder = FunctionBuilder::new(
503543
&mut self.codegen_context.func,
@@ -930,9 +970,55 @@ impl CraneliftBackend {
930970
}
931971
HirCallable::Symbol(symbol_name) => {
932972
// Call external runtime symbol by name (e.g., "$Image$load")
933-
// Create signature based on argument types
973+
// Check if any parameters need DynamicBox wrapping
974+
let mut boxed_args = Vec::new();
975+
for (param_index, &arg_val) in arg_values.iter().enumerate() {
976+
// Look up boxing requirement from pre-computed map
977+
let needs_boxing = symbol_boxing
978+
.get(&(symbol_name.clone(), param_index))
979+
.copied()
980+
.unwrap_or(false);
981+
log::debug!("[DynamicBox] Symbol: {}, param {}: needs_boxing = {}", symbol_name, param_index, needs_boxing);
982+
if needs_boxing {
983+
// This parameter expects DynamicBox - wrap it
984+
// For opaque types (i64 pointer), we need to create a DynamicBox struct
985+
// DynamicBox layout: { tag: u32, size: u32, data: i64, dropper: i64 }
986+
987+
// Allocate stack space for DynamicBox (24 bytes on 64-bit)
988+
let slot = builder.create_sized_stack_slot(cranelift_codegen::ir::StackSlotData::new(
989+
cranelift_codegen::ir::StackSlotKind::ExplicitSlot,
990+
24,
991+
));
992+
let box_addr = builder.ins().stack_addr(types::I64, slot, 0);
993+
994+
// Set tag (TypeCategory::Opaque = 0x12, type_id = 0)
995+
let tag_val = builder.ins().iconst(types::I32, 0x12); // Opaque category
996+
builder.ins().store(cranelift_codegen::ir::MemFlags::new(), tag_val, box_addr, 0);
997+
998+
// Set size (8 bytes for pointer)
999+
let size_val = builder.ins().iconst(types::I32, 8);
1000+
builder.ins().store(cranelift_codegen::ir::MemFlags::new(), size_val, box_addr, 4);
1001+
1002+
// Set data pointer (the opaque value itself)
1003+
builder.ins().store(cranelift_codegen::ir::MemFlags::new(), arg_val, box_addr, 8);
1004+
1005+
// Set dropper to null (0)
1006+
let null_dropper = builder.ins().iconst(types::I64, 0);
1007+
builder.ins().store(cranelift_codegen::ir::MemFlags::new(), null_dropper, box_addr, 16);
1008+
1009+
// Pass the box by value (load struct fields and pass as args)
1010+
// Actually, DynamicBox is passed by value, so we need to load the struct
1011+
// For now, pass the pointer to the struct (this might need ABI adjustment)
1012+
boxed_args.push(box_addr);
1013+
} else {
1014+
// No boxing needed - use value as-is
1015+
boxed_args.push(arg_val);
1016+
}
1017+
}
1018+
1019+
// Create signature based on (possibly boxed) argument types
9341020
let mut sig = self.module.make_signature();
935-
for arg_val in &arg_values {
1021+
for arg_val in &boxed_args {
9361022
let arg_ty = builder.func.dfg.value_type(*arg_val);
9371023
sig.params.push(cranelift_codegen::ir::AbiParam::new(arg_ty));
9381024
}
@@ -971,7 +1057,7 @@ impl CraneliftBackend {
9711057
let local_func = self.module
9721058
.declare_func_in_func(func, builder.func);
9731059

974-
let call = builder.ins().call(local_func, &arg_values);
1060+
let call = builder.ins().call(local_func, &boxed_args);
9751061

9761062
if let Some(result_id) = result {
9771063
if let Some(&ret_val) = builder.inst_results(call).first() {

crates/compiler/src/ssa.rs

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -698,18 +698,19 @@ impl SsaBuilder {
698698
}
699699

700700
TypedTerminator::Unreachable => {
701-
// For void functions, convert Unreachable to Return(None)
702-
// This handles functions that don't have an explicit return statement
703-
let is_void_return = match &self.original_return_type {
704-
Some(Type::Primitive(zyntax_typed_ast::PrimitiveType::Unit)) => true,
705-
None => false, // No return type info - keep as unreachable
706-
_ => false,
707-
};
708-
709-
if is_void_return {
710-
HirTerminator::Return { values: vec![] }
711-
} else {
712-
HirTerminator::Unreachable
701+
// Handle implicit returns for functions without explicit return statements
702+
match &self.original_return_type {
703+
Some(Type::Primitive(zyntax_typed_ast::PrimitiveType::Unit)) | None => {
704+
// Void/Unit return or no return type specified - implicit void return
705+
HirTerminator::Return { values: vec![] }
706+
}
707+
Some(_return_ty) => {
708+
// Non-void function: check if last statement in the block is an expression to implicitly return
709+
// We need to look at the TypedBasicBlock to see if there's an expression statement at the end
710+
// that we should implicitly return
711+
// For now, keep as Unreachable - the type checker should catch missing returns
712+
HirTerminator::Unreachable
713+
}
713714
}
714715
}
715716
};

crates/compiler/src/typed_cfg.rs

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -175,13 +175,12 @@ impl TypedCfgBuilder {
175175
for stmt in &block.statements {
176176
match &stmt.node {
177177
TypedStatement::Return(expr) => {
178-
// Return terminates the block
178+
// Explicit return terminates the block
179179
terminator = TypedTerminator::Return(expr.clone());
180180
break; // No more statements after return
181181
}
182182

183-
// For now, treat all other statements as non-terminating
184-
// TODO: Handle If, While, Match, etc. that create control flow
183+
// For all other statements, treat as non-terminating
185184
_ => {
186185
statements.push(stmt.clone());
187186
}
@@ -1040,12 +1039,29 @@ impl TypedCfgBuilder {
10401039
current_block_id, current_statements.len(), exit_id);
10411040
if !all_blocks.iter().any(|b| b.id == current_block_id) {
10421041
log::debug!("[CFG] End: creating final block with {} statements", current_statements.len());
1042+
1043+
// Special case: If the function body has exactly one statement and it's an expression,
1044+
// treat it as an implicit return. This handles cases like:
1045+
// fn add(self, rhs: Tensor) -> Tensor { extern tensor_add(self, rhs) }
1046+
// where the single expression should be returned.
1047+
let (final_statements, terminator) = if current_statements.len() == 1 {
1048+
if let TypedStatement::Expression(expr) = &current_statements[0].node {
1049+
// Single expression - implicitly return it
1050+
(vec![], TypedTerminator::Return(Some(Box::new((**expr).clone()))))
1051+
} else {
1052+
(current_statements, TypedTerminator::Unreachable)
1053+
}
1054+
} else {
1055+
// Multiple statements or no statements - keep as unreachable
1056+
(current_statements, TypedTerminator::Unreachable)
1057+
};
1058+
10431059
all_blocks.push(TypedBasicBlock {
10441060
id: current_block_id,
10451061
label: None,
1046-
statements: current_statements,
1047-
terminator: TypedTerminator::Unreachable,
1048-
pattern_check: None,
1062+
statements: final_statements,
1063+
terminator,
1064+
pattern_check: None,
10491065
});
10501066
exit_id = current_block_id;
10511067
} else {

crates/compiler/src/zpack.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,14 @@ impl ZPack {
392392
.unwrap_or_default()
393393
}
394394

395+
/// Get runtime symbols with signature information (for auto-boxing)
396+
pub fn runtime_symbols_with_signatures(&self) -> &[crate::zrtl::RuntimeSymbolInfo] {
397+
self.runtime
398+
.as_ref()
399+
.map(|r| r.symbols_with_signatures())
400+
.unwrap_or(&[])
401+
}
402+
395403
/// Get a module's HIR by path
396404
pub fn get_module(&self, path: &str) -> Option<&HirModule> {
397405
self.modules.get(path)

crates/compiler/src/zrtl.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1211,6 +1211,15 @@ impl ZrtlRegistry {
12111211
symbols
12121212
}
12131213

1214+
/// Get all symbols with signature information from all loaded plugins
1215+
pub fn collect_symbols_with_signatures(&self) -> Vec<RuntimeSymbolInfo> {
1216+
let mut symbols = Vec::new();
1217+
for plugin in &self.plugins {
1218+
symbols.extend_from_slice(plugin.symbols_with_signatures());
1219+
}
1220+
symbols
1221+
}
1222+
12141223
/// List all loaded plugin names
12151224
pub fn list_plugins(&self) -> Vec<&str> {
12161225
self.plugins.iter().map(|p| p.name()).collect()

crates/zynml/examples/basic_tensor.zynml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,5 +51,16 @@ fn test_chained_ops(x: Tensor, y: Tensor, z: Tensor) {
5151
}
5252

5353
fn main() {
54+
// Create test tensors using arange
55+
let a = arange(1.0, 4.0, 1.0) // [1.0, 2.0, 3.0]
56+
let b = arange(2.0, 5.0, 1.0) // [2.0, 3.0, 4.0]
57+
let c = arange(0.5, 1.6, 0.5) // [0.5, 1.0, 1.5]
58+
59+
// Test arithmetic operations
60+
test_arithmetic_ops(a, b)
61+
62+
// Test chained operations
63+
test_chained_ops(a, b, c)
64+
5465
42
5566
}

crates/zyntax_cli/src/backends/cranelift_jit.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@ use zyntax_compiler::hir::HirModule;
1313
/// - `entry_candidates`: Pre-resolved entry point candidates from EntryPointResolver
1414
/// These are in order of preference (most likely first).
1515
/// - `pack_symbols`: Runtime symbols from loaded ZPack archives
16+
/// - `pack_symbols_with_sigs`: Runtime symbols with signature information (for auto-boxing)
1617
pub fn compile_jit(
1718
module: HirModule,
1819
_opt_level: u8,
1920
run: bool,
2021
entry_candidates: &[String],
2122
pack_symbols: &[(&'static str, *const u8)],
23+
pack_symbols_with_sigs: &[zyntax_compiler::zrtl::RuntimeSymbolInfo],
2224
verbose: bool,
2325
) -> Result<(), Box<dyn std::error::Error>> {
2426
// Runtime symbols come exclusively from ZPack archives
@@ -40,6 +42,9 @@ pub fn compile_jit(
4042
let mut backend = CraneliftBackend::with_runtime_symbols(&runtime_symbols)
4143
.map_err(|e| format!("Failed to initialize backend: {}", e))?;
4244

45+
// Register symbol signatures for auto-boxing support
46+
backend.register_symbol_signatures(pack_symbols_with_sigs);
47+
4348
if verbose {
4449
println!("{} Compiling functions...", "info:".blue());
4550
}

crates/zyntax_cli/src/backends/mod.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,10 @@ pub fn compile(
5959

6060
// Collect runtime symbols from all packs (for JIT mode)
6161
let mut pack_symbols: Vec<(&'static str, *const u8)> = Vec::new();
62+
let mut pack_symbols_with_sigs: Vec<zyntax_compiler::zrtl::RuntimeSymbolInfo> = Vec::new();
6263
for pack in packs.iter() {
6364
pack_symbols.extend(pack.runtime_symbols());
65+
pack_symbols_with_sigs.extend_from_slice(pack.runtime_symbols_with_signatures());
6466
}
6567

6668
if verbose && !pack_symbols.is_empty() {
@@ -69,13 +71,13 @@ pub fn compile(
6971

7072
match (backend, jit) {
7173
// JIT execution - uses dynamic runtime symbols from ZPack
72-
(Backend::Cranelift, true) => compile_jit(module, opt_level, true, &candidates, &pack_symbols, verbose),
74+
(Backend::Cranelift, true) => compile_jit(module, opt_level, true, &candidates, &pack_symbols, &pack_symbols_with_sigs, verbose),
7375
(Backend::Llvm, true) => {
7476
compile_and_run_llvm(module, opt_level, Some(entry), &pack_symbols, verbose)?;
7577
Ok(())
7678
}
7779
// AOT compilation - users link static libraries directly
78-
(Backend::Cranelift, false) => compile_jit(module, opt_level, false, &candidates, &pack_symbols, verbose),
80+
(Backend::Cranelift, false) => compile_jit(module, opt_level, false, &candidates, &pack_symbols, &pack_symbols_with_sigs, verbose),
7981
(Backend::Llvm, false) => {
8082
// Static libraries are provided directly by the user via --lib flag
8183
// ZPack is for JIT only - AOT users link their runtime libraries directly

crates/zyntax_embed/src/runtime.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1398,6 +1398,9 @@ impl ZyntaxRuntime {
13981398
self.register_function(name, *ptr, 0); // Arity unknown without type info
13991399
}
14001400

1401+
// Register symbol signatures for auto-boxing support
1402+
self.backend.register_symbol_signatures(plugin.symbols_with_signatures());
1403+
14011404
// Rebuild the JIT module to include the new symbols
14021405
self.backend.rebuild_with_accumulated_symbols()?;
14031406

@@ -1423,6 +1426,10 @@ impl ZyntaxRuntime {
14231426
self.register_function(name, ptr, 0);
14241427
}
14251428

1429+
// Register symbol signatures for auto-boxing support
1430+
let symbols_with_sigs = registry.collect_symbols_with_signatures();
1431+
self.backend.register_symbol_signatures(&symbols_with_sigs);
1432+
14261433
// Rebuild the JIT module to include all the new symbols
14271434
self.backend.rebuild_with_accumulated_symbols()?;
14281435

0 commit comments

Comments
 (0)