Skip to content

Commit 7d27b33

Browse files
committed
fix: List<T> field registration and f-string interpolation for opaque types
Fix three issues preventing f-strings from working with tensor types: 1. List<T> field registration: Generic types pre-registered during parsing had 0 fields, and register_struct_declarations skipped them. Now updates existing types when they have empty fields but new declarations have real fields, preserving the existing TypeId. 2. F-string closure approach: Grammar now produces __fstring__(parts...) call nodes instead of nested concat() chains. SSA intercepts println(__fstring__(...)) and emits individual print_dynamic/ println_dynamic calls, preserving Display trait dispatch for each part. 3. Block expression support: Added TypedExpression::Block handler in SSA for general block expression evaluation.
1 parent 1b67471 commit 7d27b33

5 files changed

Lines changed: 284 additions & 53 deletions

File tree

crates/compiler/src/ssa.rs

Lines changed: 153 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2651,6 +2651,99 @@ impl SsaBuilder {
26512651
let callee = &call.callee;
26522652
let args = &call.positional_args;
26532653

2654+
// F-string closure inlining: println(f"text {expr}") desugars to
2655+
// println(__fstring__("text", expr)). Intercept this pattern and
2656+
// emit individual print_dynamic() calls for each part, which properly
2657+
// handle DynamicBox wrapping with Display trait dispatch for opaque types.
2658+
if let TypedExpression::Variable(func_name) = &callee.node {
2659+
let outer_name = func_name.resolve_global().unwrap_or_default();
2660+
if (outer_name == "println" || outer_name == "print"
2661+
|| outer_name == "eprintln" || outer_name == "eprint")
2662+
&& args.len() == 1
2663+
{
2664+
if let TypedExpression::Call(inner_call) = &args[0].node {
2665+
if let TypedExpression::Variable(inner_name) = &inner_call.callee.node {
2666+
let inner_str = inner_name.resolve_global().unwrap_or_default();
2667+
if inner_str == "__fstring__" {
2668+
// Flatten: emit print(part) for each f-string part
2669+
let is_err = outer_name.starts_with('e');
2670+
let print_symbol = if is_err {
2671+
"$IO$eprint_dynamic".to_string()
2672+
} else {
2673+
"$IO$print_dynamic".to_string()
2674+
};
2675+
let println_symbol = if is_err {
2676+
"$IO$eprintln_dynamic".to_string()
2677+
} else {
2678+
"$IO$println_dynamic".to_string()
2679+
};
2680+
2681+
let parts = &inner_call.positional_args;
2682+
let last_idx = parts.len().saturating_sub(1);
2683+
2684+
for (i, part) in parts.iter().enumerate() {
2685+
let val = self.translate_expression(block_id, part)?;
2686+
// Last part of println: use println_dynamic (adds newline)
2687+
// All other parts: use print_dynamic (no newline)
2688+
let symbol = if i == last_idx
2689+
&& (outer_name == "println" || outer_name == "eprintln")
2690+
{
2691+
&println_symbol
2692+
} else {
2693+
&print_symbol
2694+
};
2695+
2696+
let result = self.create_value(
2697+
HirType::Void,
2698+
HirValueKind::Instruction,
2699+
);
2700+
let call_inst = HirInstruction::Call {
2701+
result: Some(result),
2702+
callee: crate::hir::HirCallable::Symbol(
2703+
symbol.clone(),
2704+
),
2705+
args: vec![val],
2706+
type_args: vec![],
2707+
const_args: vec![],
2708+
is_tail: false,
2709+
};
2710+
self.add_instruction(block_id, call_inst);
2711+
}
2712+
2713+
// If println with no parts (empty f-string), or println
2714+
// where the last part was handled by print (not println),
2715+
// emit a final newline
2716+
if parts.is_empty()
2717+
&& (outer_name == "println" || outer_name == "eprintln")
2718+
{
2719+
let newline_interned =
2720+
zyntax_typed_ast::InternedString::new_global("");
2721+
let newline_val =
2722+
self.create_string_global(newline_interned);
2723+
let result = self.create_value(
2724+
HirType::Void,
2725+
HirValueKind::Instruction,
2726+
);
2727+
let call_inst = HirInstruction::Call {
2728+
result: Some(result),
2729+
callee: crate::hir::HirCallable::Symbol(
2730+
println_symbol,
2731+
),
2732+
args: vec![newline_val],
2733+
type_args: vec![],
2734+
const_args: vec![],
2735+
is_tail: false,
2736+
};
2737+
self.add_instruction(block_id, call_inst);
2738+
}
2739+
2740+
return Ok(self.create_undef(HirType::Void));
2741+
}
2742+
}
2743+
}
2744+
}
2745+
}
2746+
26542747
// Check if callee is a function name (direct call) vs expression (indirect call)
26552748
// Path expressions should be resolved to Variable during lowering/type resolution
26562749
let (hir_callable, indirect_callee_val) = if let TypedExpression::Variable(
@@ -3858,6 +3951,36 @@ impl SsaBuilder {
38583951

38593952
TypedExpression::Lambda(lambda) => self.translate_closure(block_id, lambda, &expr.ty),
38603953

3954+
TypedExpression::Block(block) => {
3955+
// Block expression: evaluate all statements, return value of last expression.
3956+
// Used by f-string desugaring (closure approach): the block contains
3957+
// print() calls for each f-string part and ends with an empty string.
3958+
let mut last_val = self.create_undef(HirType::Void);
3959+
for stmt in &block.statements {
3960+
match &stmt.node {
3961+
zyntax_typed_ast::typed_ast::TypedStatement::Expression(e) => {
3962+
last_val = self.translate_expression(block_id, e)?;
3963+
}
3964+
zyntax_typed_ast::typed_ast::TypedStatement::Let(let_stmt) => {
3965+
if let Some(init) = &let_stmt.initializer {
3966+
let val = self.translate_expression(block_id, init)?;
3967+
self.write_variable(let_stmt.name, block_id, val);
3968+
self.var_types
3969+
.insert(let_stmt.name, self.convert_type(&let_stmt.ty));
3970+
}
3971+
}
3972+
zyntax_typed_ast::typed_ast::TypedStatement::Return(ret_expr) => {
3973+
if let Some(ret) = ret_expr {
3974+
let val = self.translate_expression(block_id, ret)?;
3975+
last_val = val;
3976+
}
3977+
}
3978+
_ => {}
3979+
}
3980+
}
3981+
Ok(last_val)
3982+
}
3983+
38613984
_ => {
38623985
// Fallback for any remaining unhandled expressions
38633986
Ok(self.create_undef(self.convert_type(&expr.ty)))
@@ -4977,6 +5100,13 @@ impl SsaBuilder {
49775100
use crate::hir::HirStructType;
49785101
use zyntax_typed_ast::type_registry::TypeKind;
49795102

5103+
// Extern/opaque types (ZRTL-backed like Tensor) → Ptr(Opaque)
5104+
if let Some(zyntax_typed_ast::type_registry::Type::Extern { name: extern_name, .. }) =
5105+
self.type_registry.resolve_alias(type_def.name)
5106+
{
5107+
return HirType::Ptr(Box::new(HirType::Opaque(*extern_name)));
5108+
}
5109+
49805110
// Abstract types are zero-cost wrappers with struct layout
49815111
// They must be treated as structs for field access to work
49825112
// The backend will optimize away the wrapper at codegen time
@@ -5033,6 +5163,14 @@ impl SsaBuilder {
50335163
type_def.kind
50345164
);
50355165

5166+
// Extern/opaque types (ZRTL-backed like Tensor) → Ptr(Opaque)
5167+
// Extern types are registered with TypeKind::Atomic and an alias to Type::Extern
5168+
if let Some(zyntax_typed_ast::type_registry::Type::Extern { name: extern_name, .. }) =
5169+
self.type_registry.resolve_alias(type_def.name)
5170+
{
5171+
return HirType::Ptr(Box::new(HirType::Opaque(*extern_name)));
5172+
}
5173+
50365174
// Abstract types are zero-cost wrappers with struct layout
50375175
if let TypeKind::Abstract { .. } = &type_def.kind {
50385176
let hir_fields: Vec<HirType> = type_def
@@ -5507,6 +5645,21 @@ impl SsaBuilder {
55075645
}
55085646
Type::Named { id, .. } => {
55095647
if let Some(type_def) = self.type_registry.get_type_by_id(*id) {
5648+
// Check if this is an extern type with a runtime prefix alias
5649+
if let Some(Type::Extern { name: extern_name, .. }) =
5650+
self.type_registry.resolve_alias(type_def.name)
5651+
{
5652+
let extern_str = extern_name.resolve_global().unwrap_or_else(|| {
5653+
let arena = self.arena.lock().unwrap();
5654+
arena
5655+
.resolve_string(*extern_name)
5656+
.map(|s| s.to_string())
5657+
.unwrap_or_default()
5658+
});
5659+
log::debug!("[trait_dispatch] Type::Named (extern alias) prefix: '{}'", extern_str);
5660+
return Some(extern_str);
5661+
}
5662+
55105663
// Use the type name
55115664
let name = type_def.name.resolve_global().unwrap_or_else(|| {
55125665
let arena = self.arena.lock().unwrap();
@@ -5694,7 +5847,6 @@ impl SsaBuilder {
56945847
}
56955848
}
56965849

5697-
// Field not found
56985850
Err(crate::CompilerError::Analysis(format!(
56995851
"Field {:?} not found in type {:?}",
57005852
field_name, type_def.name

crates/zyn_peg/src/runtime2/interpreter.rs

Lines changed: 49 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2257,9 +2257,17 @@ impl<'g> GrammarInterpreter<'g> {
22572257
Ok(acc)
22582258
}
22592259
"fold_concat" => {
2260-
// fold_concat(parts) - fold string parts into nested concat calls
2261-
// e.g., fold_concat(["a", "b", "c"]) -> concat(concat("a", "b"), "c")
2262-
// Used by f-string desugaring
2260+
// fold_concat(parts) - f-string desugaring using "closure" approach
2261+
//
2262+
// Produces: __fstring__(part1, part2, ...) — a special call node that
2263+
// the SSA intercepts. When println(f"...") is encountered, SSA emits
2264+
// individual print_dynamic() calls for each part (with DynamicBox
2265+
// wrapping that respects Display traits for opaque types like Tensor).
2266+
//
2267+
// f"text {expr}" → __fstring__("text", expr)
2268+
//
2269+
// The __fstring_format__ wrappers on interpolated expressions are
2270+
// stripped since print_dynamic already handles Display trait dispatch.
22632271
if args.len() != 1 {
22642272
return Err(
22652273
"fold_concat() requires exactly 1 argument (parts list)".to_string()
@@ -2297,32 +2305,48 @@ impl<'g> GrammarInterpreter<'g> {
22972305
.map(|e| ParsedValue::Expression(Box::new(e)));
22982306
}
22992307

2300-
// Fold left: concat(concat(a, b), c)
2301-
let concat_name = state.intern("concat");
2302-
let mut iter = parts_list.into_iter();
2303-
let first = iter.next().unwrap();
2304-
let mut acc = self.parsed_value_to_expr(first, state)?;
2308+
// Build __fstring__(part1, part2, ...) call.
2309+
// Strip __fstring_format__ wrappers — the SSA handles Display dispatch.
2310+
let fstring_name = state.intern("__fstring__");
2311+
let mut fstring_args = Vec::new();
23052312

2306-
for part in iter {
2313+
for part in parts_list {
23072314
let part_expr = self.parsed_value_to_expr(part, state)?;
2308-
// Create concat(acc, part)
2309-
acc = typed_node(
2310-
TypedExpression::Call(TypedCall {
2311-
callee: Box::new(typed_node(
2312-
TypedExpression::Variable(concat_name),
2313-
Type::Unknown,
2314-
span,
2315-
)),
2316-
positional_args: vec![acc, part_expr],
2317-
named_args: vec![],
2318-
type_args: vec![],
2319-
}),
2320-
Type::Primitive(zyntax_typed_ast::PrimitiveType::String),
2321-
span,
2322-
);
2315+
2316+
// Strip __fstring_format__ wrapper: __fstring_format__(expr) → expr
2317+
let unwrapped_expr = if let TypedExpression::Call(ref call) = part_expr.node {
2318+
if let TypedExpression::Variable(callee_name) = &call.callee.node {
2319+
let callee_str = callee_name.resolve_global().unwrap_or_default();
2320+
if callee_str == "__fstring_format__" && call.positional_args.len() == 1
2321+
{
2322+
call.positional_args[0].clone()
2323+
} else {
2324+
part_expr
2325+
}
2326+
} else {
2327+
part_expr
2328+
}
2329+
} else {
2330+
part_expr
2331+
};
2332+
2333+
fstring_args.push(unwrapped_expr);
23232334
}
23242335

2325-
Ok(ParsedValue::Expression(Box::new(acc)))
2336+
Ok(ParsedValue::Expression(Box::new(typed_node(
2337+
TypedExpression::Call(TypedCall {
2338+
callee: Box::new(typed_node(
2339+
TypedExpression::Variable(fstring_name),
2340+
Type::Unknown,
2341+
span,
2342+
)),
2343+
positional_args: fstring_args,
2344+
named_args: vec![],
2345+
type_args: vec![],
2346+
}),
2347+
Type::Primitive(zyntax_typed_ast::PrimitiveType::String),
2348+
span,
2349+
))))
23262350
}
23272351
_ => Err(format!("unknown helper function: {}", function)),
23282352
}
Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Simple hello.zynml without f-strings
1+
// ZynML hello world with tensor operations and f-strings
22
import tensor
33

44
def main() {
@@ -8,42 +8,32 @@ def main() {
88
let a = Tensor::arange(1.0, 5.0, 1.0) // [1.0, 2.0, 3.0, 4.0]
99
let b = Tensor::arange(5.0, 9.0, 1.0) // [5.0, 6.0, 7.0, 8.0]
1010

11-
println("Tensor a:")
12-
println(a)
13-
14-
println("Tensor b:")
15-
println(b)
11+
println(f"Tensor a: {a}")
12+
println(f"Tensor b: {b}")
1613

1714
// Arithmetic operations using operator overloading
1815
let sum = a + b
19-
println("a + b:")
20-
println(sum)
16+
println(f"a + b: {sum}")
2117

2218
let diff = b - a
23-
println("b - a:")
24-
println(diff)
19+
println(f"b - a: {diff}")
2520

2621
let prod = a * b
27-
println("a * b (element-wise):")
28-
println(prod)
22+
println(f"a * b (element-wise): {prod}")
2923

3024
// Reduction operations
3125
let total = a.sum()
32-
println("Sum of a:")
33-
println(total)
26+
println(f"Sum of a: {total}")
3427

3528
let avg = a.mean()
36-
println("Mean of a:")
37-
println(avg)
29+
println(f"Mean of a: {avg}")
3830

3931
// Create matrices
4032
let matrix = Tensor::zeros([2, 3])
41-
println("2x3 zeros matrix:")
42-
println(matrix)
33+
println(f"2x3 zeros matrix: {matrix}")
4334

4435
let ones = Tensor::ones([3, 2])
45-
println("3x2 ones matrix:")
46-
println(ones)
36+
println(f"3x2 ones matrix: {ones}")
4737

4838
println("Done!")
4939
}

crates/zynml/stdlib/tensor.zynml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,18 +83,24 @@ impl Tensor {
8383

8484
// Create tensor filled with zeros
8585
// Usage: Tensor::zeros([3, 4])
86-
// ZRTL expects (ptr, ndim) - wrapper unpacks List struct fields
8786
def zeros(shape: List<i64>): Tensor {
8887
extern tensor_zeros(shape.data, shape.len)
8988
}
9089

9190
// Create tensor filled with ones
9291
// Usage: Tensor::ones([3, 4])
93-
// ZRTL expects (ptr, ndim) - wrapper unpacks List struct fields
9492
def ones(shape: List<i64>): Tensor {
9593
extern tensor_ones(shape.data, shape.len)
9694
}
9795

96+
// Dimension-specific convenience constructors
97+
extern def zeros_1d(n: i64): Tensor
98+
extern def zeros_2d(rows: i64, cols: i64): Tensor
99+
extern def zeros_3d(d0: i64, d1: i64, d2: i64): Tensor
100+
extern def ones_1d(n: i64): Tensor
101+
extern def ones_2d(rows: i64, cols: i64): Tensor
102+
extern def ones_3d(d0: i64, d1: i64, d2: i64): Tensor
103+
98104
// Create tensor with values in range [start, end) with step
99105
// Usage: Tensor::arange(0.0, 10.0, 1.0)
100106
// ZRTL signature: (f64, f64, f64) -> opaque

0 commit comments

Comments
 (0)