Skip to content

Commit 3dd5933

Browse files
committed
runtime: validate compute yield context and keep int literals inferred
1 parent 8b0a6e2 commit 3dd5933

3 files changed

Lines changed: 135 additions & 24 deletions

File tree

crates/compiler/src/ssa.rs

Lines changed: 82 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ pub struct SsaBuilder {
6969
/// External function link names (alias -> ZRTL symbol)
7070
/// e.g., "tensor_add" -> "$Tensor$add"
7171
extern_link_names: IndexMap<InternedString, String>,
72+
/// Captured yield values for active compute-expression lowering contexts.
73+
/// Empty outside compute expression translation.
74+
compute_yield_stack: Vec<Vec<HirId>>,
7275
}
7376

7477
/// Context for pattern matching
@@ -314,6 +317,7 @@ impl SsaBuilder {
314317
address_taken_vars: HashSet::new(),
315318
stack_slots: IndexMap::new(),
316319
extern_link_names: IndexMap::new(),
320+
compute_yield_stack: Vec::new(),
317321
}
318322
}
319323

@@ -1368,8 +1372,16 @@ impl SsaBuilder {
13681372
self.translate_expression(block_id, expr)?;
13691373
}
13701374
TypedStatement::Yield(expr) => {
1371-
// Yield behaves like an expression statement at the SSA level.
1372-
self.translate_expression(block_id, expr)?;
1375+
if self.compute_yield_stack.is_empty() {
1376+
return Err(crate::CompilerError::Analysis(
1377+
"`yield` is only valid inside compute expression bodies".to_string(),
1378+
));
1379+
}
1380+
1381+
let yielded = self.translate_expression(block_id, expr)?;
1382+
if let Some(active_yields) = self.compute_yield_stack.last_mut() {
1383+
active_yields.push(yielded);
1384+
}
13731385
}
13741386

13751387
TypedStatement::Match(match_stmt) => {
@@ -1494,6 +1506,31 @@ impl SsaBuilder {
14941506
Ok(())
14951507
}
14961508

1509+
fn translate_compute_dispatch_call(
1510+
&mut self,
1511+
block_id: HirId,
1512+
expr: &TypedNode<TypedExpression>,
1513+
compute: &zyntax_typed_ast::typed_ast::TypedComputeExpr,
1514+
) -> CompilerResult<HirId> {
1515+
let compute_name = InternedString::new_global(INTERNAL_COMPUTE_ALIAS);
1516+
let lowered = zyntax_typed_ast::typed_ast::TypedCall {
1517+
callee: Box::new(zyntax_typed_ast::typed_node(
1518+
TypedExpression::Variable(compute_name),
1519+
Type::Any,
1520+
expr.span,
1521+
)),
1522+
positional_args: compute.args.clone(),
1523+
named_args: vec![],
1524+
type_args: vec![],
1525+
};
1526+
let lowered_expr = zyntax_typed_ast::typed_node(
1527+
TypedExpression::Call(lowered),
1528+
expr.ty.clone(),
1529+
expr.span,
1530+
);
1531+
self.translate_expression(block_id, &lowered_expr)
1532+
}
1533+
14971534
/// Translate expression to SSA value
14981535
fn translate_expression(
14991536
&mut self,
@@ -1941,25 +1978,49 @@ impl SsaBuilder {
19411978
}
19421979

19431980
TypedExpression::Compute(compute) => {
1944-
// Lower compute expressions through the existing call pipeline for now.
1945-
// This preserves call-based execution while keeping typed compute structure.
1946-
let compute_name = InternedString::new_global(INTERNAL_COMPUTE_ALIAS);
1947-
let lowered = zyntax_typed_ast::typed_ast::TypedCall {
1948-
callee: Box::new(zyntax_typed_ast::typed_node(
1949-
TypedExpression::Variable(compute_name),
1950-
Type::Any,
1951-
expr.span,
1952-
)),
1953-
positional_args: compute.args.clone(),
1954-
named_args: vec![],
1955-
type_args: vec![],
1956-
};
1957-
let lowered_expr = zyntax_typed_ast::typed_node(
1958-
TypedExpression::Call(lowered),
1959-
expr.ty.clone(),
1960-
expr.span,
1961-
);
1962-
self.translate_expression(block_id, &lowered_expr)
1981+
let has_direct_yield = compute.body.statements.iter().any(|stmt| {
1982+
matches!(
1983+
stmt.node,
1984+
zyntax_typed_ast::typed_ast::TypedStatement::Yield(_)
1985+
)
1986+
});
1987+
1988+
if !has_direct_yield {
1989+
// Preserve legacy runtime behavior for non-reduction compute blocks.
1990+
return self.translate_compute_dispatch_call(block_id, expr, compute);
1991+
}
1992+
1993+
// CPU fallback path for reduction-style compute blocks.
1994+
// Evaluate explicit compute args for side effects, then execute compute body.
1995+
for arg in &compute.args {
1996+
self.translate_expression(block_id, arg)?;
1997+
}
1998+
1999+
self.compute_yield_stack.push(Vec::new());
2000+
let mut current_block = block_id;
2001+
for stmt in &compute.body.statements {
2002+
match &stmt.node {
2003+
zyntax_typed_ast::typed_ast::TypedStatement::Let(_)
2004+
| zyntax_typed_ast::typed_ast::TypedStatement::Expression(_)
2005+
| zyntax_typed_ast::typed_ast::TypedStatement::Yield(_) => {
2006+
current_block = self.process_statement(current_block, stmt)?;
2007+
}
2008+
_ => {
2009+
self.compute_yield_stack.pop();
2010+
return Err(crate::CompilerError::Analysis(
2011+
"compute body fallback currently supports only let/expression/yield statements"
2012+
.to_string(),
2013+
));
2014+
}
2015+
}
2016+
}
2017+
2018+
let yields = self.compute_yield_stack.pop().unwrap_or_default();
2019+
if let Some(last_yield) = yields.last().copied() {
2020+
Ok(last_yield)
2021+
} else {
2022+
self.translate_compute_dispatch_call(block_id, expr, compute)
2023+
}
19632024
}
19642025

19652026
TypedExpression::Field(field_access) => {

crates/zyn_peg/src/runtime2/interpreter.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -809,11 +809,11 @@ impl<'g> GrammarInterpreter<'g> {
809809
_ => return Err(format!("unknown TypedExpression variant: {}", variant)),
810810
};
811811

812-
// Determine the type based on the expression variant
813-
// Default to I64 for integers (64-bit system) and F64 for floats
812+
// Determine the type based on the expression variant.
813+
// Integer literals default to I32; assignment/context-driven inference can refine later.
814814
let ty = match &expr {
815815
TypedExpression::Literal(TypedLiteral::Integer(_)) => {
816-
Type::Primitive(PrimitiveType::I64)
816+
Type::Primitive(PrimitiveType::I32)
817817
}
818818
TypedExpression::Literal(TypedLiteral::Float(_)) => Type::Primitive(PrimitiveType::F64),
819819
TypedExpression::Literal(TypedLiteral::String(_)) => {
@@ -824,6 +824,7 @@ impl<'g> GrammarInterpreter<'g> {
824824
// Use Type::Any to signal that lowering should infer the type
825825
TypedExpression::Call(_) => Type::Any,
826826
TypedExpression::Variable(_) => Type::Any,
827+
TypedExpression::Compute(_) => Type::Any,
827828
// Struct literal gets its type from the struct name - use Unresolved for compiler to resolve
828829
TypedExpression::Struct(lit) => Type::Unresolved(lit.name),
829830
_ => Type::Primitive(PrimitiveType::Unit),

crates/zynml/tests/e2e_tests.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3590,6 +3590,55 @@ mod execution {
35903590
}
35913591
}
35923592

3593+
#[test]
3594+
fn test_execute_compute_yield_fallback() {
3595+
let Some(mut zynml) = create_runtime_with_plugins() else {
3596+
println!("Skipping: plugins not available");
3597+
return;
3598+
};
3599+
3600+
let source = r#"
3601+
fn compute_fallback() {
3602+
let x = 41
3603+
let _y = compute(x) @device("cpu") {
3604+
yield x
3605+
}
3606+
}
3607+
"#;
3608+
3609+
let functions = zynml
3610+
.load_source(source)
3611+
.expect("compute fallback program should compile");
3612+
assert!(
3613+
functions.iter().any(|f| f == "compute_fallback"),
3614+
"compute_fallback should be exported"
3615+
);
3616+
}
3617+
3618+
#[test]
3619+
fn test_execute_yield_outside_compute_rejected() {
3620+
let Some(mut zynml) = create_runtime_with_plugins() else {
3621+
println!("Skipping: plugins not available");
3622+
return;
3623+
};
3624+
3625+
let source = r#"
3626+
fn invalid_yield() {
3627+
yield 1
3628+
}
3629+
"#;
3630+
3631+
let functions = zynml
3632+
.load_source(source)
3633+
.expect("compiler should continue while skipping invalid function");
3634+
let message = format!("{:?}", functions);
3635+
assert!(
3636+
functions.is_empty(),
3637+
"yield outside compute should prevent function export, got: {}",
3638+
message
3639+
);
3640+
}
3641+
35933642
// Regression coverage for the full hello example (prelude + tensor + dynamic println).
35943643
// Keep panic isolation via catch_unwind so runtime regressions produce actionable test output.
35953644
#[test]

0 commit comments

Comments
 (0)