Skip to content

Commit 9badd2e

Browse files
committed
Complete Grammar2 semantic pass for imports, compute, and matmul dispatch
1 parent dfdcda1 commit 9badd2e

7 files changed

Lines changed: 589 additions & 67 deletions

File tree

crates/compiler/src/ssa.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1363,6 +1363,10 @@ impl SsaBuilder {
13631363
// Evaluate expression for side effects
13641364
self.translate_expression(block_id, expr)?;
13651365
}
1366+
TypedStatement::Yield(expr) => {
1367+
// Yield behaves like an expression statement at the SSA level.
1368+
self.translate_expression(block_id, expr)?;
1369+
}
13661370

13671371
TypedStatement::Match(match_stmt) => {
13681372
// Handle match statement: evaluate scrutinee
@@ -1655,6 +1659,14 @@ impl SsaBuilder {
16551659
}
16561660
eprintln!("[DEBUG SSA] No trait dispatch, using native binary op");
16571661

1662+
// `@` must dispatch through MatMul::matmul; do not silently alias to numeric `mul`.
1663+
if matches!(op, FrontendOp::MatMul) {
1664+
return Err(crate::CompilerError::Analysis(format!(
1665+
"matrix multiplication '@' requires MatMul::matmul implementation for lhs type {:?}",
1666+
left_with_type.ty
1667+
)));
1668+
}
1669+
16581670
// Regular binary operations for primitive types
16591671
let left_val = self.translate_expression(block_id, left)?;
16601672
let right_val = self.translate_expression(block_id, right)?;
@@ -1836,6 +1848,28 @@ impl SsaBuilder {
18361848
Ok(result_or_void)
18371849
}
18381850

1851+
TypedExpression::Compute(compute) => {
1852+
// Lower compute expressions through the existing call pipeline for now.
1853+
// This preserves call-based execution while keeping typed compute structure.
1854+
let compute_name = InternedString::new_global("compute");
1855+
let lowered = zyntax_typed_ast::typed_ast::TypedCall {
1856+
callee: Box::new(zyntax_typed_ast::typed_node(
1857+
TypedExpression::Variable(compute_name),
1858+
Type::Any,
1859+
expr.span,
1860+
)),
1861+
positional_args: compute.args.clone(),
1862+
named_args: vec![],
1863+
type_args: vec![],
1864+
};
1865+
let lowered_expr = zyntax_typed_ast::typed_node(
1866+
TypedExpression::Call(lowered),
1867+
expr.ty.clone(),
1868+
expr.span,
1869+
);
1870+
self.translate_expression(block_id, &lowered_expr)
1871+
}
1872+
18391873
TypedExpression::Field(field_access) => {
18401874
let object = &field_access.object;
18411875
let field = &field_access.field;
@@ -4097,6 +4131,7 @@ impl SsaBuilder {
40974131
FrontendOp::Add => HirOp::Add,
40984132
FrontendOp::Sub => HirOp::Sub,
40994133
FrontendOp::Mul => HirOp::Mul,
4134+
FrontendOp::MatMul => HirOp::Mul,
41004135
FrontendOp::Div => HirOp::Div,
41014136
FrontendOp::Rem => HirOp::Rem,
41024137
FrontendOp::BitAnd => HirOp::And,
@@ -4150,6 +4185,7 @@ impl SsaBuilder {
41504185
FrontendOp::Add => "add",
41514186
FrontendOp::Sub => "sub",
41524187
FrontendOp::Mul => "mul",
4188+
FrontendOp::MatMul => "matmul",
41534189
FrontendOp::Div => "div",
41544190
FrontendOp::Rem => "mod",
41554191
FrontendOp::Eq => "eq",
@@ -6206,6 +6242,12 @@ impl SsaBuilder {
62066242
result_ty,
62076243
)?;
62086244

6245+
if matches!(bin.op, zyntax_typed_ast::typed_ast::BinaryOp::MatMul) {
6246+
return Err(crate::CompilerError::Analysis(
6247+
"matrix multiplication '@' requires trait dispatch and is not supported in lambda const lowering".to_string(),
6248+
));
6249+
}
6250+
62096251
let hir_op = self.convert_binary_op(&bin.op);
62106252
let result_id = HirId::new();
62116253
func.values.insert(

crates/typed_ast/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@ pub use typed_ast::{
156156
TypedCast,
157157
// Class/Struct/Enum types
158158
TypedClass,
159+
TypedComputeExpr,
160+
TypedComputeModifier,
159161
TypedDeclaration,
160162
// Defer type
161163
TypedDefer,
@@ -191,6 +193,7 @@ pub use typed_ast::{
191193
TypedImportModifier,
192194
TypedIndex,
193195
TypedInterface,
196+
TypedKernelAttr,
194197
// Lambda types
195198
TypedLambda,
196199
TypedLambdaBody,
@@ -209,6 +212,7 @@ pub use typed_ast::{
209212
TypedMethodCall,
210213
TypedMethodParam,
211214
TypedModule,
215+
TypedNamedArg,
212216
TypedNode,
213217
// Additional types for parser generation
214218
TypedParameter,

crates/typed_ast/src/type_checker.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,6 +1096,13 @@ impl TypeChecker {
10961096

10971097
Ok(Type::Never)
10981098
}
1099+
TypedStatement::Yield(expr_node) => {
1100+
// Yield is currently validated as an expression-producing statement.
1101+
// Context-specific validation (e.g., only within compute reductions)
1102+
// happens in lowering/runtime passes.
1103+
self.check_expression(&expr_node.node)?;
1104+
Ok(Type::Primitive(PrimitiveType::Unit))
1105+
}
10991106
TypedStatement::If(if_stmt) => self.check_if_statement(if_stmt),
11001107
TypedStatement::While(while_stmt) => {
11011108
self.check_while_statement(while_stmt)?;
@@ -1243,7 +1250,8 @@ impl TypeChecker {
12431250
| TypedExpression::Range(_)
12441251
| TypedExpression::Block(_)
12451252
| TypedExpression::ListComprehension(_)
1246-
| TypedExpression::Slice(_) => {
1253+
| TypedExpression::Slice(_)
1254+
| TypedExpression::Compute(_) => {
12471255
// Placeholder for now
12481256
Ok(self.inference.fresh_type_var())
12491257
}
@@ -1281,7 +1289,7 @@ impl TypeChecker {
12811289
use BinaryOp::*;
12821290
match bin.op {
12831291
// Arithmetic operators
1284-
Add | Sub | Mul | Div | Rem => {
1292+
Add | Sub | Mul | MatMul | Div | Rem => {
12851293
self.inference.unify(left_ty.clone(), right_ty)?;
12861294
Ok(left_ty)
12871295
}

crates/typed_ast/src/typed_ast.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,8 @@ pub enum TypedStatement {
430430
/// Let with pattern destructuring: let (x, y) = expr
431431
LetPattern(TypedLetPattern),
432432
Return(Option<Box<TypedNode<TypedExpression>>>),
433+
/// Yield from compute/reduction contexts
434+
Yield(Box<TypedNode<TypedExpression>>),
433435
If(TypedIf),
434436
While(TypedWhile),
435437
Block(TypedBlock),
@@ -524,6 +526,8 @@ pub enum TypedExpression {
524526
Slice(TypedSlice),
525527
/// Import modifier expression: import loader("path") as Type
526528
ImportModifier(TypedImportModifier),
529+
/// Compute expression: compute(args) @modifier { ... }
530+
Compute(TypedComputeExpr),
527531
/// Path expression: Type::method or module::function
528532
Path(TypedPath),
529533
}
@@ -565,6 +569,7 @@ pub enum BinaryOp {
565569
Add,
566570
Sub,
567571
Mul,
572+
MatMul,
568573
Div,
569574
Rem,
570575
// Comparison
@@ -986,6 +991,41 @@ pub struct TypedPath {
986991
pub segments: Vec<InternedString>,
987992
}
988993

994+
/// Compute expression: `compute(args...) @modifier { body }`
995+
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
996+
pub struct TypedComputeExpr {
997+
/// Positional compute inputs
998+
pub args: Vec<TypedNode<TypedExpression>>,
999+
/// Compute-level modifiers (e.g., @device("cpu"), @async)
1000+
#[serde(default)]
1001+
pub modifiers: Vec<TypedComputeModifier>,
1002+
/// Kernel-specific attributes (e.g., @kernel(matmul), @workgroup(16,16))
1003+
#[serde(default)]
1004+
pub kernel_attrs: Vec<TypedKernelAttr>,
1005+
/// Compute body block
1006+
pub body: TypedBlock,
1007+
}
1008+
1009+
/// Generic compute modifier
1010+
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1011+
pub struct TypedComputeModifier {
1012+
pub name: InternedString,
1013+
#[serde(default)]
1014+
pub positional_args: Vec<TypedNode<TypedExpression>>,
1015+
#[serde(default)]
1016+
pub named_args: Vec<TypedNamedArg>,
1017+
}
1018+
1019+
/// Kernel-specific attribute
1020+
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1021+
pub struct TypedKernelAttr {
1022+
pub name: InternedString,
1023+
#[serde(default)]
1024+
pub positional_args: Vec<TypedNode<TypedExpression>>,
1025+
#[serde(default)]
1026+
pub named_args: Vec<TypedNamedArg>,
1027+
}
1028+
9891029
/// Method call with enhanced argument support
9901030
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
9911031
pub struct TypedMethodCall {

0 commit comments

Comments
 (0)