@@ -69,6 +69,9 @@ pub struct SsaBuilder {
6969 /// External function link names (alias -> ZRTL symbol)
7070 /// e.g., "tensor_add" -> "$Tensor$add"
7171 extern_link_names : IndexMap < InternedString , String > ,
72+ /// Captured yield values for active compute-expression lowering contexts.
73+ /// Empty outside compute expression translation.
74+ compute_yield_stack : Vec < Vec < HirId > > ,
7275}
7376
7477/// Context for pattern matching
@@ -314,6 +317,7 @@ impl SsaBuilder {
314317 address_taken_vars : HashSet :: new ( ) ,
315318 stack_slots : IndexMap :: new ( ) ,
316319 extern_link_names : IndexMap :: new ( ) ,
320+ compute_yield_stack : Vec :: new ( ) ,
317321 }
318322 }
319323
@@ -1368,8 +1372,16 @@ impl SsaBuilder {
13681372 self . translate_expression ( block_id, expr) ?;
13691373 }
13701374 TypedStatement :: Yield ( expr) => {
1371- // Yield behaves like an expression statement at the SSA level.
1372- self . translate_expression ( block_id, expr) ?;
1375+ if self . compute_yield_stack . is_empty ( ) {
1376+ return Err ( crate :: CompilerError :: Analysis (
1377+ "`yield` is only valid inside compute expression bodies" . to_string ( ) ,
1378+ ) ) ;
1379+ }
1380+
1381+ let yielded = self . translate_expression ( block_id, expr) ?;
1382+ if let Some ( active_yields) = self . compute_yield_stack . last_mut ( ) {
1383+ active_yields. push ( yielded) ;
1384+ }
13731385 }
13741386
13751387 TypedStatement :: Match ( match_stmt) => {
@@ -1494,6 +1506,31 @@ impl SsaBuilder {
14941506 Ok ( ( ) )
14951507 }
14961508
1509+ fn translate_compute_dispatch_call (
1510+ & mut self ,
1511+ block_id : HirId ,
1512+ expr : & TypedNode < TypedExpression > ,
1513+ compute : & zyntax_typed_ast:: typed_ast:: TypedComputeExpr ,
1514+ ) -> CompilerResult < HirId > {
1515+ let compute_name = InternedString :: new_global ( INTERNAL_COMPUTE_ALIAS ) ;
1516+ let lowered = zyntax_typed_ast:: typed_ast:: TypedCall {
1517+ callee : Box :: new ( zyntax_typed_ast:: typed_node (
1518+ TypedExpression :: Variable ( compute_name) ,
1519+ Type :: Any ,
1520+ expr. span ,
1521+ ) ) ,
1522+ positional_args : compute. args . clone ( ) ,
1523+ named_args : vec ! [ ] ,
1524+ type_args : vec ! [ ] ,
1525+ } ;
1526+ let lowered_expr = zyntax_typed_ast:: typed_node (
1527+ TypedExpression :: Call ( lowered) ,
1528+ expr. ty . clone ( ) ,
1529+ expr. span ,
1530+ ) ;
1531+ self . translate_expression ( block_id, & lowered_expr)
1532+ }
1533+
14971534 /// Translate expression to SSA value
14981535 fn translate_expression (
14991536 & mut self ,
@@ -1941,25 +1978,49 @@ impl SsaBuilder {
19411978 }
19421979
19431980 TypedExpression :: Compute ( compute) => {
1944- // Lower compute expressions through the existing call pipeline for now.
1945- // This preserves call-based execution while keeping typed compute structure.
1946- let compute_name = InternedString :: new_global ( INTERNAL_COMPUTE_ALIAS ) ;
1947- let lowered = zyntax_typed_ast:: typed_ast:: TypedCall {
1948- callee : Box :: new ( zyntax_typed_ast:: typed_node (
1949- TypedExpression :: Variable ( compute_name) ,
1950- Type :: Any ,
1951- expr. span ,
1952- ) ) ,
1953- positional_args : compute. args . clone ( ) ,
1954- named_args : vec ! [ ] ,
1955- type_args : vec ! [ ] ,
1956- } ;
1957- let lowered_expr = zyntax_typed_ast:: typed_node (
1958- TypedExpression :: Call ( lowered) ,
1959- expr. ty . clone ( ) ,
1960- expr. span ,
1961- ) ;
1962- self . translate_expression ( block_id, & lowered_expr)
1981+ let has_direct_yield = compute. body . statements . iter ( ) . any ( |stmt| {
1982+ matches ! (
1983+ stmt. node,
1984+ zyntax_typed_ast:: typed_ast:: TypedStatement :: Yield ( _)
1985+ )
1986+ } ) ;
1987+
1988+ if !has_direct_yield {
1989+ // Preserve legacy runtime behavior for non-reduction compute blocks.
1990+ return self . translate_compute_dispatch_call ( block_id, expr, compute) ;
1991+ }
1992+
1993+ // CPU fallback path for reduction-style compute blocks.
1994+ // Evaluate explicit compute args for side effects, then execute compute body.
1995+ for arg in & compute. args {
1996+ self . translate_expression ( block_id, arg) ?;
1997+ }
1998+
1999+ self . compute_yield_stack . push ( Vec :: new ( ) ) ;
2000+ let mut current_block = block_id;
2001+ for stmt in & compute. body . statements {
2002+ match & stmt. node {
2003+ zyntax_typed_ast:: typed_ast:: TypedStatement :: Let ( _)
2004+ | zyntax_typed_ast:: typed_ast:: TypedStatement :: Expression ( _)
2005+ | zyntax_typed_ast:: typed_ast:: TypedStatement :: Yield ( _) => {
2006+ current_block = self . process_statement ( current_block, stmt) ?;
2007+ }
2008+ _ => {
2009+ self . compute_yield_stack . pop ( ) ;
2010+ return Err ( crate :: CompilerError :: Analysis (
2011+ "compute body fallback currently supports only let/expression/yield statements"
2012+ . to_string ( ) ,
2013+ ) ) ;
2014+ }
2015+ }
2016+ }
2017+
2018+ let yields = self . compute_yield_stack . pop ( ) . unwrap_or_default ( ) ;
2019+ if let Some ( last_yield) = yields. last ( ) . copied ( ) {
2020+ Ok ( last_yield)
2021+ } else {
2022+ self . translate_compute_dispatch_call ( block_id, expr, compute)
2023+ }
19632024 }
19642025
19652026 TypedExpression :: Field ( field_access) => {
0 commit comments