2323
2424use std:: sync:: { Arc , Mutex } ;
2525use zyntax_compiler:: {
26- hir:: { HirInstruction , HirTerminator } ,
26+ hir:: { HirCallable , HirInstruction , HirTerminator } ,
2727 lowering:: { AstLowering , LoweringConfig , LoweringContext } ,
2828} ;
2929use zyntax_typed_ast:: {
3030 arena:: AstArena ,
31- typed_ast:: { TypedBinary , TypedBlock , TypedIfExpr , TypedLet , TypedUnary } ,
31+ typed_ast:: { ParameterKind , TypedBinary , TypedBlock , TypedIfExpr , TypedLet , TypedUnary } ,
3232 typed_node, BinaryOp , CallingConvention , Mutability , PrimitiveType , Span , Type , TypeRegistry ,
33- TypedDeclaration , TypedExpression , TypedFunction , TypedLiteral , TypedProgram , TypedStatement ,
34- UnaryOp , Visibility ,
33+ TypedDeclaration , TypedExpression , TypedFunction , TypedLiteral , TypedParameter , TypedProgram ,
34+ TypedStatement , UnaryOp , Visibility ,
3535} ;
3636
3737/// Helper to create a test arena
@@ -44,6 +44,19 @@ fn test_span() -> Span {
4444 Span :: new ( 0 , 10 )
4545}
4646
47+ struct SkipTypeCheckGuard ;
48+
49+ impl Drop for SkipTypeCheckGuard {
50+ fn drop ( & mut self ) {
51+ std:: env:: remove_var ( "SKIP_TYPE_CHECK" ) ;
52+ }
53+ }
54+
55+ fn skip_type_check ( ) -> SkipTypeCheckGuard {
56+ std:: env:: set_var ( "SKIP_TYPE_CHECK" , "1" ) ;
57+ SkipTypeCheckGuard
58+ }
59+
4760/// Helper to create a simple typed program with one function
4861fn create_test_program ( arena : & mut AstArena , func_name : & str , body : TypedBlock ) -> TypedProgram {
4962 let name = arena. intern_string ( func_name) ;
@@ -417,6 +430,243 @@ fn test_logical_or_short_circuit_lowering() {
417430 ) ;
418431}
419432
433+ #[ test]
434+ fn test_matmul_dispatch_uses_named_type_function ( ) {
435+ let _skip_type_check = skip_type_check ( ) ;
436+ let mut arena = test_arena ( ) ;
437+ let mut type_registry = TypeRegistry :: new ( ) ;
438+
439+ let mat_name = arena. intern_string ( "Mat" ) ;
440+ let mat_id = type_registry. register_struct_type (
441+ mat_name,
442+ vec ! [ ] ,
443+ vec ! [ ] ,
444+ vec ! [ ] ,
445+ vec ! [ ] ,
446+ zyntax_typed_ast:: TypeMetadata :: default ( ) ,
447+ test_span ( ) ,
448+ ) ;
449+ let mat_ty = Type :: Named {
450+ id : mat_id,
451+ type_args : vec ! [ ] ,
452+ const_args : vec ! [ ] ,
453+ variance : vec ! [ ] ,
454+ nullability : zyntax_typed_ast:: NullabilityKind :: NonNull ,
455+ } ;
456+
457+ let lhs_name = arena. intern_string ( "lhs" ) ;
458+ let rhs_name = arena. intern_string ( "rhs" ) ;
459+
460+ let matmul_impl = TypedFunction {
461+ name : arena. intern_string ( "Mat$matmul" ) ,
462+ params : vec ! [
463+ TypedParameter {
464+ name: lhs_name,
465+ ty: mat_ty. clone( ) ,
466+ mutability: Mutability :: Immutable ,
467+ kind: ParameterKind :: Regular ,
468+ default_value: None ,
469+ attributes: vec![ ] ,
470+ span: test_span( ) ,
471+ } ,
472+ TypedParameter {
473+ name: rhs_name,
474+ ty: mat_ty. clone( ) ,
475+ mutability: Mutability :: Immutable ,
476+ kind: ParameterKind :: Regular ,
477+ default_value: None ,
478+ attributes: vec![ ] ,
479+ span: test_span( ) ,
480+ } ,
481+ ] ,
482+ type_params : vec ! [ ] ,
483+ return_type : mat_ty. clone ( ) ,
484+ body : None ,
485+ visibility : Visibility :: Public ,
486+ is_async : false ,
487+ is_external : true ,
488+ calling_convention : CallingConvention :: Default ,
489+ link_name : None ,
490+ annotations : vec ! [ ] ,
491+ effects : vec ! [ ] ,
492+ is_pure : false ,
493+ } ;
494+
495+ let matmul_expr = typed_node (
496+ TypedExpression :: Binary ( TypedBinary {
497+ op : BinaryOp :: MatMul ,
498+ left : Box :: new ( typed_node (
499+ TypedExpression :: Variable ( lhs_name) ,
500+ mat_ty. clone ( ) ,
501+ test_span ( ) ,
502+ ) ) ,
503+ right : Box :: new ( typed_node (
504+ TypedExpression :: Variable ( rhs_name) ,
505+ mat_ty. clone ( ) ,
506+ test_span ( ) ,
507+ ) ) ,
508+ } ) ,
509+ mat_ty. clone ( ) ,
510+ test_span ( ) ,
511+ ) ;
512+ let entry_body = TypedBlock {
513+ statements : vec ! [ typed_node(
514+ TypedStatement :: Return ( Some ( Box :: new( matmul_expr) ) ) ,
515+ Type :: Primitive ( PrimitiveType :: Unit ) ,
516+ test_span( ) ,
517+ ) ] ,
518+ span : test_span ( ) ,
519+ } ;
520+ let entry_fn = TypedFunction {
521+ name : arena. intern_string ( "entry" ) ,
522+ params : vec ! [
523+ TypedParameter {
524+ name: lhs_name,
525+ ty: mat_ty. clone( ) ,
526+ mutability: Mutability :: Immutable ,
527+ kind: ParameterKind :: Regular ,
528+ default_value: None ,
529+ attributes: vec![ ] ,
530+ span: test_span( ) ,
531+ } ,
532+ TypedParameter {
533+ name: rhs_name,
534+ ty: mat_ty. clone( ) ,
535+ mutability: Mutability :: Immutable ,
536+ kind: ParameterKind :: Regular ,
537+ default_value: None ,
538+ attributes: vec![ ] ,
539+ span: test_span( ) ,
540+ } ,
541+ ] ,
542+ type_params : vec ! [ ] ,
543+ return_type : mat_ty. clone ( ) ,
544+ body : Some ( entry_body) ,
545+ visibility : Visibility :: Public ,
546+ is_async : false ,
547+ is_external : false ,
548+ calling_convention : CallingConvention :: Default ,
549+ link_name : None ,
550+ annotations : vec ! [ ] ,
551+ effects : vec ! [ ] ,
552+ is_pure : false ,
553+ } ;
554+
555+ let mut program = TypedProgram {
556+ declarations : vec ! [
557+ typed_node(
558+ TypedDeclaration :: Function ( matmul_impl) ,
559+ Type :: Primitive ( PrimitiveType :: Unit ) ,
560+ test_span( ) ,
561+ ) ,
562+ typed_node(
563+ TypedDeclaration :: Function ( entry_fn) ,
564+ Type :: Primitive ( PrimitiveType :: Unit ) ,
565+ test_span( ) ,
566+ ) ,
567+ ] ,
568+ span : test_span ( ) ,
569+ source_files : vec ! [ ] ,
570+ type_registry : type_registry. clone ( ) ,
571+ } ;
572+
573+ let type_registry = Arc :: new ( type_registry) ;
574+ let config = LoweringConfig :: default ( ) ;
575+ let module_name = arena. intern_string ( "test_module" ) ;
576+ let arena = Arc :: new ( Mutex :: new ( arena) ) ;
577+ let mut ctx = LoweringContext :: new ( module_name, type_registry, arena, config) ;
578+
579+ let result = ctx. lower_program ( & mut program) ;
580+ assert ! (
581+ result. is_ok( ) ,
582+ "Failed to lower matmul dispatch program: {:?}" ,
583+ result. err( )
584+ ) ;
585+
586+ let module = result. unwrap ( ) ;
587+ let entry = module
588+ . functions
589+ . values ( )
590+ . find ( |f| f. name . resolve_global ( ) . as_deref ( ) == Some ( "entry" ) )
591+ . expect ( "entry function should exist" ) ;
592+
593+ let call_callee = entry
594+ . blocks
595+ . values ( )
596+ . flat_map ( |b| b. instructions . iter ( ) )
597+ . find_map ( |inst| match inst {
598+ HirInstruction :: Call { callee, .. } => Some ( callee) ,
599+ _ => None ,
600+ } )
601+ . expect ( "entry should contain a call for matmul dispatch" ) ;
602+
603+ assert ! (
604+ matches!( call_callee, & HirCallable :: Function ( _) ) ,
605+ "MatMul on named type should dispatch to compiled function, got {:?}" ,
606+ call_callee
607+ ) ;
608+ }
609+
610+ #[ test]
611+ fn test_matmul_missing_impl_reports_clear_error ( ) {
612+ let _skip_type_check = skip_type_check ( ) ;
613+ let mut arena = test_arena ( ) ;
614+
615+ let left = typed_node (
616+ TypedExpression :: Literal ( TypedLiteral :: Integer ( 2 ) ) ,
617+ Type :: Primitive ( PrimitiveType :: I32 ) ,
618+ test_span ( ) ,
619+ ) ;
620+ let right = typed_node (
621+ TypedExpression :: Literal ( TypedLiteral :: Integer ( 3 ) ) ,
622+ Type :: Primitive ( PrimitiveType :: I32 ) ,
623+ test_span ( ) ,
624+ ) ;
625+ let expr = typed_node (
626+ TypedExpression :: Binary ( TypedBinary {
627+ op : BinaryOp :: MatMul ,
628+ left : Box :: new ( left) ,
629+ right : Box :: new ( right) ,
630+ } ) ,
631+ Type :: Primitive ( PrimitiveType :: I32 ) ,
632+ test_span ( ) ,
633+ ) ;
634+ let body = TypedBlock {
635+ statements : vec ! [ typed_node(
636+ TypedStatement :: Return ( Some ( Box :: new( expr) ) ) ,
637+ Type :: Primitive ( PrimitiveType :: Unit ) ,
638+ test_span( ) ,
639+ ) ] ,
640+ span : test_span ( ) ,
641+ } ;
642+
643+ let mut program = create_test_program ( & mut arena, "matmul_missing_impl" , body) ;
644+
645+ let type_registry = Arc :: new ( TypeRegistry :: new ( ) ) ;
646+ let config = LoweringConfig :: default ( ) ;
647+ let module_name = arena. intern_string ( "test_module" ) ;
648+ let arena = Arc :: new ( Mutex :: new ( arena) ) ;
649+ let mut ctx = LoweringContext :: new ( module_name, type_registry, arena, config) ;
650+
651+ let result = ctx. lower_program ( & mut program) ;
652+ assert ! (
653+ result. is_ok( ) ,
654+ "Lowering should complete while skipping invalid functions: {:?}" ,
655+ result. err( )
656+ ) ;
657+
658+ let module = result. unwrap ( ) ;
659+ let matmul_fn_present = module
660+ . functions
661+ . values ( )
662+ . any ( |f| f. name . resolve_global ( ) . as_deref ( ) == Some ( "matmul_missing_impl" ) ) ;
663+ assert ! (
664+ !matmul_fn_present,
665+ "Invalid matmul function should be dropped from lowered module"
666+ ) ;
667+
668+ }
669+
420670#[ test]
421671fn test_unary_operation_lowering ( ) {
422672 let mut arena = test_arena ( ) ;
0 commit comments