Skip to content

Commit 5d880d7

Browse files
committed
Harden MatMul trait dispatch and add lowering regressions
1 parent 6853e1f commit 5d880d7

2 files changed

Lines changed: 310 additions & 42 deletions

File tree

crates/compiler/src/ssa.rs

Lines changed: 56 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -4272,23 +4272,23 @@ impl SsaBuilder {
42724272
return Ok(None);
42734273
}
42744274

4275-
// Get the method name for this operator
4276-
let method_name = match op {
4277-
FrontendOp::Add => "add",
4278-
FrontendOp::Sub => "sub",
4279-
FrontendOp::Mul => "mul",
4280-
FrontendOp::MatMul => "matmul",
4281-
FrontendOp::Div => "div",
4282-
FrontendOp::Rem => "mod",
4283-
FrontendOp::Eq => "eq",
4284-
FrontendOp::Ne => "ne",
4285-
FrontendOp::Lt => "lt",
4286-
FrontendOp::Le => "le",
4287-
FrontendOp::Gt => "gt",
4288-
FrontendOp::Ge => "ge",
4289-
FrontendOp::BitAnd => "bitand",
4290-
FrontendOp::BitOr => "bitor",
4291-
FrontendOp::BitXor => "bitxor",
4275+
// Get the method and trait names for this operator.
4276+
let (method_name, trait_name) = match op {
4277+
FrontendOp::Add => ("add", "Add"),
4278+
FrontendOp::Sub => ("sub", "Sub"),
4279+
FrontendOp::Mul => ("mul", "Mul"),
4280+
FrontendOp::MatMul => ("matmul", "MatMul"),
4281+
FrontendOp::Div => ("div", "Div"),
4282+
FrontendOp::Rem => ("mod", "Mod"),
4283+
FrontendOp::Eq => ("eq", "Eq"),
4284+
FrontendOp::Ne => ("ne", "Eq"),
4285+
FrontendOp::Lt => ("lt", "Ord"),
4286+
FrontendOp::Le => ("le", "Ord"),
4287+
FrontendOp::Gt => ("gt", "Ord"),
4288+
FrontendOp::Ge => ("ge", "Ord"),
4289+
FrontendOp::BitAnd => ("bitand", "BitAnd"),
4290+
FrontendOp::BitOr => ("bitor", "BitOr"),
4291+
FrontendOp::BitXor => ("bitxor", "BitXor"),
42924292
_ => return Ok(None), // No trait method for this operator
42934293
};
42944294

@@ -4299,35 +4299,53 @@ impl SsaBuilder {
42994299
}
43004300
let type_name = type_name.unwrap();
43014301

4302-
// Construct the method function name: TypeName$method
4303-
// For extern types, the type_name already starts with '$' (e.g., "$Tensor")
4304-
// so we just append $method to get "$Tensor$add"
4305-
let method_symbol = if type_name.starts_with('$') {
4306-
// Already has $ prefix (extern type)
4302+
// Build candidate names:
4303+
// 1) Type$method (standalone helper)
4304+
// 2) Type$Trait$method (impl-lowered method)
4305+
// 3) $Type$method (runtime symbol for extern-backed types)
4306+
let mut function_candidates = Vec::with_capacity(2);
4307+
let runtime_symbol = if let Some(stripped) = type_name.strip_prefix('$') {
4308+
function_candidates.push(format!("{}${}", stripped, method_name));
4309+
function_candidates.push(format!("{}${}${}", stripped, trait_name, method_name));
43074310
format!("{}${}", type_name, method_name)
43084311
} else {
4309-
// Regular type, add $ prefix for ZRTL compatibility
4312+
function_candidates.push(format!("{}${}", type_name, method_name));
4313+
function_candidates.push(format!("{}${}${}", type_name, trait_name, method_name));
43104314
format!("${}${}", type_name, method_name)
43114315
};
4312-
let method_name_interned = InternedString::new_global(&method_symbol);
4313-
log::debug!(
4314-
"[SSA] Operator trait dispatch: {} for type {}",
4315-
method_symbol,
4316-
type_name
4317-
);
43184316

4319-
// Try to look up the function ID for this method
4320-
// If it's a compiled ZynML function (like Duration$add), we need HirCallable::Function
4321-
// If it's an external plugin function (like $Tensor$add), we need HirCallable::Symbol
4322-
let callee = if let Some(&func_id) = self.function_symbols.get(&method_name_interned) {
4323-
log::debug!("[SSA] Found function ID for {}", method_symbol);
4317+
let mut matched_function: Option<(HirId, String)> = None;
4318+
for candidate in &function_candidates {
4319+
let candidate_interned = InternedString::new_global(candidate);
4320+
if let Some(&func_id) = self.function_symbols.get(&candidate_interned) {
4321+
matched_function = Some((func_id, candidate.clone()));
4322+
break;
4323+
}
4324+
}
4325+
4326+
let callee = if let Some((func_id, matched_name)) = matched_function {
4327+
log::debug!(
4328+
"[SSA] Operator trait dispatch: using function '{}' for type {}",
4329+
matched_name,
4330+
type_name
4331+
);
43244332
crate::hir::HirCallable::Function(func_id)
4325-
} else {
4333+
} else if type_name.starts_with('$') {
4334+
// Extern-backed types dispatch to runtime symbols.
43264335
log::debug!(
4327-
"[SSA] No function ID found, using external symbol for {}",
4328-
method_symbol
4336+
"[SSA] Operator trait dispatch: using runtime symbol '{}' for type {}",
4337+
runtime_symbol,
4338+
type_name
43294339
);
4330-
crate::hir::HirCallable::Symbol(method_symbol)
4340+
crate::hir::HirCallable::Symbol(runtime_symbol)
4341+
} else if matches!(op, FrontendOp::MatMul) {
4342+
// MatMul must not silently fall back.
4343+
return Err(crate::CompilerError::Analysis(format!(
4344+
"matrix multiplication '@' requires MatMul::matmul implementation for type {:?} (expected '{}' or '{}')",
4345+
left_type, function_candidates[0], function_candidates[1]
4346+
)));
4347+
} else {
4348+
return Ok(None);
43314349
};
43324350

43334351
// Translate arguments

crates/compiler/tests/expression_lowering_tests.rs

Lines changed: 254 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,15 @@
2323
2424
use std::sync::{Arc, Mutex};
2525
use zyntax_compiler::{
26-
hir::{HirInstruction, HirTerminator},
26+
hir::{HirCallable, HirInstruction, HirTerminator},
2727
lowering::{AstLowering, LoweringConfig, LoweringContext},
2828
};
2929
use zyntax_typed_ast::{
3030
arena::AstArena,
31-
typed_ast::{TypedBinary, TypedBlock, TypedIfExpr, TypedLet, TypedUnary},
31+
typed_ast::{ParameterKind, TypedBinary, TypedBlock, TypedIfExpr, TypedLet, TypedUnary},
3232
typed_node, BinaryOp, CallingConvention, Mutability, PrimitiveType, Span, Type, TypeRegistry,
33-
TypedDeclaration, TypedExpression, TypedFunction, TypedLiteral, TypedProgram, TypedStatement,
34-
UnaryOp, Visibility,
33+
TypedDeclaration, TypedExpression, TypedFunction, TypedLiteral, TypedParameter, TypedProgram,
34+
TypedStatement, UnaryOp, Visibility,
3535
};
3636

3737
/// Helper to create a test arena
@@ -44,6 +44,19 @@ fn test_span() -> Span {
4444
Span::new(0, 10)
4545
}
4646

47+
struct SkipTypeCheckGuard;
48+
49+
impl Drop for SkipTypeCheckGuard {
50+
fn drop(&mut self) {
51+
std::env::remove_var("SKIP_TYPE_CHECK");
52+
}
53+
}
54+
55+
fn skip_type_check() -> SkipTypeCheckGuard {
56+
std::env::set_var("SKIP_TYPE_CHECK", "1");
57+
SkipTypeCheckGuard
58+
}
59+
4760
/// Helper to create a simple typed program with one function
4861
fn create_test_program(arena: &mut AstArena, func_name: &str, body: TypedBlock) -> TypedProgram {
4962
let name = arena.intern_string(func_name);
@@ -417,6 +430,243 @@ fn test_logical_or_short_circuit_lowering() {
417430
);
418431
}
419432

433+
#[test]
434+
fn test_matmul_dispatch_uses_named_type_function() {
435+
let _skip_type_check = skip_type_check();
436+
let mut arena = test_arena();
437+
let mut type_registry = TypeRegistry::new();
438+
439+
let mat_name = arena.intern_string("Mat");
440+
let mat_id = type_registry.register_struct_type(
441+
mat_name,
442+
vec![],
443+
vec![],
444+
vec![],
445+
vec![],
446+
zyntax_typed_ast::TypeMetadata::default(),
447+
test_span(),
448+
);
449+
let mat_ty = Type::Named {
450+
id: mat_id,
451+
type_args: vec![],
452+
const_args: vec![],
453+
variance: vec![],
454+
nullability: zyntax_typed_ast::NullabilityKind::NonNull,
455+
};
456+
457+
let lhs_name = arena.intern_string("lhs");
458+
let rhs_name = arena.intern_string("rhs");
459+
460+
let matmul_impl = TypedFunction {
461+
name: arena.intern_string("Mat$matmul"),
462+
params: vec![
463+
TypedParameter {
464+
name: lhs_name,
465+
ty: mat_ty.clone(),
466+
mutability: Mutability::Immutable,
467+
kind: ParameterKind::Regular,
468+
default_value: None,
469+
attributes: vec![],
470+
span: test_span(),
471+
},
472+
TypedParameter {
473+
name: rhs_name,
474+
ty: mat_ty.clone(),
475+
mutability: Mutability::Immutable,
476+
kind: ParameterKind::Regular,
477+
default_value: None,
478+
attributes: vec![],
479+
span: test_span(),
480+
},
481+
],
482+
type_params: vec![],
483+
return_type: mat_ty.clone(),
484+
body: None,
485+
visibility: Visibility::Public,
486+
is_async: false,
487+
is_external: true,
488+
calling_convention: CallingConvention::Default,
489+
link_name: None,
490+
annotations: vec![],
491+
effects: vec![],
492+
is_pure: false,
493+
};
494+
495+
let matmul_expr = typed_node(
496+
TypedExpression::Binary(TypedBinary {
497+
op: BinaryOp::MatMul,
498+
left: Box::new(typed_node(
499+
TypedExpression::Variable(lhs_name),
500+
mat_ty.clone(),
501+
test_span(),
502+
)),
503+
right: Box::new(typed_node(
504+
TypedExpression::Variable(rhs_name),
505+
mat_ty.clone(),
506+
test_span(),
507+
)),
508+
}),
509+
mat_ty.clone(),
510+
test_span(),
511+
);
512+
let entry_body = TypedBlock {
513+
statements: vec![typed_node(
514+
TypedStatement::Return(Some(Box::new(matmul_expr))),
515+
Type::Primitive(PrimitiveType::Unit),
516+
test_span(),
517+
)],
518+
span: test_span(),
519+
};
520+
let entry_fn = TypedFunction {
521+
name: arena.intern_string("entry"),
522+
params: vec![
523+
TypedParameter {
524+
name: lhs_name,
525+
ty: mat_ty.clone(),
526+
mutability: Mutability::Immutable,
527+
kind: ParameterKind::Regular,
528+
default_value: None,
529+
attributes: vec![],
530+
span: test_span(),
531+
},
532+
TypedParameter {
533+
name: rhs_name,
534+
ty: mat_ty.clone(),
535+
mutability: Mutability::Immutable,
536+
kind: ParameterKind::Regular,
537+
default_value: None,
538+
attributes: vec![],
539+
span: test_span(),
540+
},
541+
],
542+
type_params: vec![],
543+
return_type: mat_ty.clone(),
544+
body: Some(entry_body),
545+
visibility: Visibility::Public,
546+
is_async: false,
547+
is_external: false,
548+
calling_convention: CallingConvention::Default,
549+
link_name: None,
550+
annotations: vec![],
551+
effects: vec![],
552+
is_pure: false,
553+
};
554+
555+
let mut program = TypedProgram {
556+
declarations: vec![
557+
typed_node(
558+
TypedDeclaration::Function(matmul_impl),
559+
Type::Primitive(PrimitiveType::Unit),
560+
test_span(),
561+
),
562+
typed_node(
563+
TypedDeclaration::Function(entry_fn),
564+
Type::Primitive(PrimitiveType::Unit),
565+
test_span(),
566+
),
567+
],
568+
span: test_span(),
569+
source_files: vec![],
570+
type_registry: type_registry.clone(),
571+
};
572+
573+
let type_registry = Arc::new(type_registry);
574+
let config = LoweringConfig::default();
575+
let module_name = arena.intern_string("test_module");
576+
let arena = Arc::new(Mutex::new(arena));
577+
let mut ctx = LoweringContext::new(module_name, type_registry, arena, config);
578+
579+
let result = ctx.lower_program(&mut program);
580+
assert!(
581+
result.is_ok(),
582+
"Failed to lower matmul dispatch program: {:?}",
583+
result.err()
584+
);
585+
586+
let module = result.unwrap();
587+
let entry = module
588+
.functions
589+
.values()
590+
.find(|f| f.name.resolve_global().as_deref() == Some("entry"))
591+
.expect("entry function should exist");
592+
593+
let call_callee = entry
594+
.blocks
595+
.values()
596+
.flat_map(|b| b.instructions.iter())
597+
.find_map(|inst| match inst {
598+
HirInstruction::Call { callee, .. } => Some(callee),
599+
_ => None,
600+
})
601+
.expect("entry should contain a call for matmul dispatch");
602+
603+
assert!(
604+
matches!(call_callee, &HirCallable::Function(_)),
605+
"MatMul on named type should dispatch to compiled function, got {:?}",
606+
call_callee
607+
);
608+
}
609+
610+
#[test]
611+
fn test_matmul_missing_impl_reports_clear_error() {
612+
let _skip_type_check = skip_type_check();
613+
let mut arena = test_arena();
614+
615+
let left = typed_node(
616+
TypedExpression::Literal(TypedLiteral::Integer(2)),
617+
Type::Primitive(PrimitiveType::I32),
618+
test_span(),
619+
);
620+
let right = typed_node(
621+
TypedExpression::Literal(TypedLiteral::Integer(3)),
622+
Type::Primitive(PrimitiveType::I32),
623+
test_span(),
624+
);
625+
let expr = typed_node(
626+
TypedExpression::Binary(TypedBinary {
627+
op: BinaryOp::MatMul,
628+
left: Box::new(left),
629+
right: Box::new(right),
630+
}),
631+
Type::Primitive(PrimitiveType::I32),
632+
test_span(),
633+
);
634+
let body = TypedBlock {
635+
statements: vec![typed_node(
636+
TypedStatement::Return(Some(Box::new(expr))),
637+
Type::Primitive(PrimitiveType::Unit),
638+
test_span(),
639+
)],
640+
span: test_span(),
641+
};
642+
643+
let mut program = create_test_program(&mut arena, "matmul_missing_impl", body);
644+
645+
let type_registry = Arc::new(TypeRegistry::new());
646+
let config = LoweringConfig::default();
647+
let module_name = arena.intern_string("test_module");
648+
let arena = Arc::new(Mutex::new(arena));
649+
let mut ctx = LoweringContext::new(module_name, type_registry, arena, config);
650+
651+
let result = ctx.lower_program(&mut program);
652+
assert!(
653+
result.is_ok(),
654+
"Lowering should complete while skipping invalid functions: {:?}",
655+
result.err()
656+
);
657+
658+
let module = result.unwrap();
659+
let matmul_fn_present = module
660+
.functions
661+
.values()
662+
.any(|f| f.name.resolve_global().as_deref() == Some("matmul_missing_impl"));
663+
assert!(
664+
!matmul_fn_present,
665+
"Invalid matmul function should be dropped from lowered module"
666+
);
667+
668+
}
669+
420670
#[test]
421671
fn test_unary_operation_lowering() {
422672
let mut arena = test_arena();

0 commit comments

Comments
 (0)