Skip to content

Commit 6853e1f

Browse files
committed
Lower logical && and || with CFG short-circuit semantics
1 parent cd79201 commit 6853e1f

2 files changed

Lines changed: 333 additions & 18 deletions

File tree

crates/compiler/src/ssa.rs

Lines changed: 108 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1607,33 +1607,123 @@ impl SsaBuilder {
16071607
return self.translate_assignment(block_id, left, right);
16081608
}
16091609

1610-
// Logical AND/OR operators - use bitwise operations for now
1611-
// TODO: Implement proper short-circuit evaluation
1612-
// For now, both sides are evaluated (no short-circuiting)
1610+
// Logical AND/OR operators require short-circuit control flow.
16131611
if matches!(op, FrontendOp::And | FrontendOp::Or) {
16141612
let left_val = self.translate_expression(block_id, left)?;
1615-
let right_val = self.translate_expression(block_id, right)?;
1616-
let result_type = self.convert_type(&expr.ty);
1613+
let left_block_id = self.continuation_block.take().unwrap_or(block_id);
1614+
1615+
let rhs_block_id = HirId::new();
1616+
let short_block_id = HirId::new();
1617+
let merge_block_id = HirId::new();
16171618

1618-
let hir_op = match op {
1619-
FrontendOp::And => crate::hir::BinaryOp::And, // Bitwise AND for now
1620-
FrontendOp::Or => crate::hir::BinaryOp::Or, // Bitwise OR for now
1619+
self.function
1620+
.blocks
1621+
.insert(rhs_block_id, HirBlock::new(rhs_block_id));
1622+
self.function
1623+
.blocks
1624+
.insert(short_block_id, HirBlock::new(short_block_id));
1625+
self.function
1626+
.blocks
1627+
.insert(merge_block_id, HirBlock::new(merge_block_id));
1628+
1629+
self.definitions.insert(rhs_block_id, IndexMap::new());
1630+
self.definitions.insert(short_block_id, IndexMap::new());
1631+
self.definitions.insert(merge_block_id, IndexMap::new());
1632+
1633+
let (true_target, false_target, short_value) = match op {
1634+
FrontendOp::And => (rhs_block_id, short_block_id, false),
1635+
FrontendOp::Or => (short_block_id, rhs_block_id, true),
16211636
_ => unreachable!(),
16221637
};
16231638

1624-
let result = self.create_value(result_type.clone(), HirValueKind::Instruction);
1625-
let inst = HirInstruction::Binary {
1626-
op: hir_op,
1627-
result,
1628-
ty: result_type,
1629-
left: left_val,
1630-
right: right_val,
1639+
self.function
1640+
.blocks
1641+
.get_mut(&left_block_id)
1642+
.unwrap()
1643+
.terminator = HirTerminator::CondBranch {
1644+
condition: left_val,
1645+
true_target,
1646+
false_target,
16311647
};
1648+
self.function
1649+
.blocks
1650+
.get_mut(&left_block_id)
1651+
.unwrap()
1652+
.successors = vec![true_target, false_target];
1653+
self.function
1654+
.blocks
1655+
.get_mut(&rhs_block_id)
1656+
.unwrap()
1657+
.predecessors
1658+
.push(left_block_id);
1659+
self.function
1660+
.blocks
1661+
.get_mut(&short_block_id)
1662+
.unwrap()
1663+
.predecessors
1664+
.push(left_block_id);
16321665

1633-
self.add_instruction(block_id, inst);
1634-
self.add_use(left_val, result);
1635-
self.add_use(right_val, result);
1666+
let short_const = self.create_value(
1667+
HirType::Bool,
1668+
HirValueKind::Constant(crate::hir::HirConstant::Bool(short_value)),
1669+
);
1670+
self.function
1671+
.blocks
1672+
.get_mut(&short_block_id)
1673+
.unwrap()
1674+
.terminator = HirTerminator::Branch {
1675+
target: merge_block_id,
1676+
};
1677+
self.function
1678+
.blocks
1679+
.get_mut(&short_block_id)
1680+
.unwrap()
1681+
.successors = vec![merge_block_id];
1682+
self.function
1683+
.blocks
1684+
.get_mut(&merge_block_id)
1685+
.unwrap()
1686+
.predecessors
1687+
.push(short_block_id);
1688+
1689+
let rhs_value = self.translate_expression(rhs_block_id, right)?;
1690+
let rhs_exit_block_id = self.continuation_block.take().unwrap_or(rhs_block_id);
1691+
self.function
1692+
.blocks
1693+
.get_mut(&rhs_exit_block_id)
1694+
.unwrap()
1695+
.terminator = HirTerminator::Branch {
1696+
target: merge_block_id,
1697+
};
1698+
self.function
1699+
.blocks
1700+
.get_mut(&rhs_exit_block_id)
1701+
.unwrap()
1702+
.successors = vec![merge_block_id];
1703+
self.function
1704+
.blocks
1705+
.get_mut(&merge_block_id)
1706+
.unwrap()
1707+
.predecessors
1708+
.push(rhs_exit_block_id);
1709+
1710+
let result_type = self.convert_type(&expr.ty);
1711+
let result = self.create_value(result_type.clone(), HirValueKind::Instruction);
1712+
self.function
1713+
.blocks
1714+
.get_mut(&merge_block_id)
1715+
.unwrap()
1716+
.phis
1717+
.push(HirPhi {
1718+
result,
1719+
ty: result_type,
1720+
incoming: vec![
1721+
(short_const, short_block_id),
1722+
(rhs_value, rhs_exit_block_id),
1723+
],
1724+
});
16361725

1726+
self.continuation_block = Some(merge_block_id);
16371727
return Ok(result);
16381728
}
16391729

crates/compiler/tests/expression_lowering_tests.rs

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,41 @@ fn create_test_program(arena: &mut AstArena, func_name: &str, body: TypedBlock)
7575
}
7676
}
7777

78+
/// Helper to create a typed program with one bool-returning function
79+
fn create_bool_test_program(
80+
arena: &mut AstArena,
81+
func_name: &str,
82+
body: TypedBlock,
83+
) -> TypedProgram {
84+
let name = arena.intern_string(func_name);
85+
let function = TypedFunction {
86+
name,
87+
params: vec![],
88+
type_params: vec![],
89+
return_type: Type::Primitive(PrimitiveType::Bool),
90+
body: Some(body),
91+
visibility: Visibility::Public,
92+
is_async: false,
93+
is_external: false,
94+
calling_convention: CallingConvention::Default,
95+
link_name: None,
96+
annotations: vec![],
97+
effects: vec![],
98+
is_pure: false,
99+
};
100+
101+
TypedProgram {
102+
declarations: vec![typed_node(
103+
TypedDeclaration::Function(function),
104+
Type::Primitive(PrimitiveType::Unit),
105+
test_span(),
106+
)],
107+
span: test_span(),
108+
source_files: vec![],
109+
type_registry: TypeRegistry::new(),
110+
}
111+
}
112+
78113
#[test]
79114
fn test_literal_lowering() {
80115
let mut arena = test_arena();
@@ -192,6 +227,196 @@ fn test_binary_operation_lowering() {
192227
)
193228
}
194229

230+
#[test]
231+
fn test_logical_and_short_circuit_lowering() {
232+
let mut arena = test_arena();
233+
234+
let left = typed_node(
235+
TypedExpression::Literal(TypedLiteral::Bool(false)),
236+
Type::Primitive(PrimitiveType::Bool),
237+
test_span(),
238+
);
239+
let right = typed_node(
240+
TypedExpression::Literal(TypedLiteral::Bool(true)),
241+
Type::Primitive(PrimitiveType::Bool),
242+
test_span(),
243+
);
244+
245+
let expr = typed_node(
246+
TypedExpression::Binary(TypedBinary {
247+
op: BinaryOp::And,
248+
left: Box::new(left),
249+
right: Box::new(right),
250+
}),
251+
Type::Primitive(PrimitiveType::Bool),
252+
test_span(),
253+
);
254+
255+
let return_stmt = typed_node(
256+
TypedStatement::Return(Some(Box::new(expr))),
257+
Type::Primitive(PrimitiveType::Unit),
258+
test_span(),
259+
);
260+
let body = TypedBlock {
261+
statements: vec![return_stmt],
262+
span: test_span(),
263+
};
264+
265+
let mut program = create_bool_test_program(&mut arena, "test_and_short_circuit", body);
266+
267+
let type_registry = Arc::new(TypeRegistry::new());
268+
let config = LoweringConfig::default();
269+
let module_name = arena.intern_string("test_module");
270+
let arena = Arc::new(Mutex::new(arena));
271+
let mut ctx = LoweringContext::new(module_name, type_registry, arena, config);
272+
let result = ctx.lower_program(&mut program);
273+
assert!(
274+
result.is_ok(),
275+
"Failed to lower logical and expression: {:?}",
276+
result.err()
277+
);
278+
279+
let module = result.unwrap();
280+
let func = module.functions.values().next().unwrap();
281+
assert!(
282+
func.blocks.len() >= 4,
283+
"Expected short-circuit CFG blocks for &&, got {}",
284+
func.blocks.len()
285+
);
286+
287+
// Ensure logical && is not lowered as a plain Binary And instruction.
288+
let has_plain_and = func.blocks.values().any(|block| {
289+
block.instructions.iter().any(|inst| {
290+
matches!(
291+
inst,
292+
HirInstruction::Binary {
293+
op: zyntax_compiler::hir::BinaryOp::And,
294+
..
295+
}
296+
)
297+
})
298+
});
299+
assert!(
300+
!has_plain_and,
301+
"&& should use short-circuit CFG, not binary And"
302+
);
303+
304+
let phi = func
305+
.blocks
306+
.values()
307+
.flat_map(|b| b.phis.iter())
308+
.find(|p| p.incoming.len() == 2)
309+
.expect("Expected merge phi for && short-circuit");
310+
311+
let has_false_short_path = phi.incoming.iter().any(|(val, _)| {
312+
matches!(
313+
func.values.get(val).map(|v| &v.kind),
314+
Some(zyntax_compiler::hir::HirValueKind::Constant(
315+
zyntax_compiler::hir::HirConstant::Bool(false)
316+
))
317+
)
318+
});
319+
assert!(
320+
has_false_short_path,
321+
"&& phi should contain constant false short-circuit path"
322+
);
323+
}
324+
325+
#[test]
326+
fn test_logical_or_short_circuit_lowering() {
327+
let mut arena = test_arena();
328+
329+
let left = typed_node(
330+
TypedExpression::Literal(TypedLiteral::Bool(true)),
331+
Type::Primitive(PrimitiveType::Bool),
332+
test_span(),
333+
);
334+
let right = typed_node(
335+
TypedExpression::Literal(TypedLiteral::Bool(false)),
336+
Type::Primitive(PrimitiveType::Bool),
337+
test_span(),
338+
);
339+
340+
let expr = typed_node(
341+
TypedExpression::Binary(TypedBinary {
342+
op: BinaryOp::Or,
343+
left: Box::new(left),
344+
right: Box::new(right),
345+
}),
346+
Type::Primitive(PrimitiveType::Bool),
347+
test_span(),
348+
);
349+
350+
let return_stmt = typed_node(
351+
TypedStatement::Return(Some(Box::new(expr))),
352+
Type::Primitive(PrimitiveType::Unit),
353+
test_span(),
354+
);
355+
let body = TypedBlock {
356+
statements: vec![return_stmt],
357+
span: test_span(),
358+
};
359+
360+
let mut program = create_bool_test_program(&mut arena, "test_or_short_circuit", body);
361+
362+
let type_registry = Arc::new(TypeRegistry::new());
363+
let config = LoweringConfig::default();
364+
let module_name = arena.intern_string("test_module");
365+
let arena = Arc::new(Mutex::new(arena));
366+
let mut ctx = LoweringContext::new(module_name, type_registry, arena, config);
367+
let result = ctx.lower_program(&mut program);
368+
assert!(
369+
result.is_ok(),
370+
"Failed to lower logical or expression: {:?}",
371+
result.err()
372+
);
373+
374+
let module = result.unwrap();
375+
let func = module.functions.values().next().unwrap();
376+
assert!(
377+
func.blocks.len() >= 4,
378+
"Expected short-circuit CFG blocks for ||, got {}",
379+
func.blocks.len()
380+
);
381+
382+
// Ensure logical || is not lowered as a plain Binary Or instruction.
383+
let has_plain_or = func.blocks.values().any(|block| {
384+
block.instructions.iter().any(|inst| {
385+
matches!(
386+
inst,
387+
HirInstruction::Binary {
388+
op: zyntax_compiler::hir::BinaryOp::Or,
389+
..
390+
}
391+
)
392+
})
393+
});
394+
assert!(
395+
!has_plain_or,
396+
"|| should use short-circuit CFG, not binary Or"
397+
);
398+
399+
let phi = func
400+
.blocks
401+
.values()
402+
.flat_map(|b| b.phis.iter())
403+
.find(|p| p.incoming.len() == 2)
404+
.expect("Expected merge phi for || short-circuit");
405+
406+
let has_true_short_path = phi.incoming.iter().any(|(val, _)| {
407+
matches!(
408+
func.values.get(val).map(|v| &v.kind),
409+
Some(zyntax_compiler::hir::HirValueKind::Constant(
410+
zyntax_compiler::hir::HirConstant::Bool(true)
411+
))
412+
)
413+
});
414+
assert!(
415+
has_true_short_path,
416+
"|| phi should contain constant true short-circuit path"
417+
);
418+
}
419+
195420
#[test]
196421
fn test_unary_operation_lowering() {
197422
let mut arena = test_arena();

0 commit comments

Comments
 (0)