Skip to content

Commit cd79201

Browse files
committed
Preserve tensor type metadata and alias targets in Grammar2
1 parent 32999cc commit cd79201

4 files changed

Lines changed: 227 additions & 25 deletions

File tree

crates/zyn_peg/src/runtime2/interpreter.rs

Lines changed: 142 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ use zyntax_typed_ast::typed_ast::{
1818
};
1919
use zyntax_typed_ast::{
2020
type_registry::{
21-
CallingConvention, ConstValue, Mutability, NullabilityKind, PrimitiveType, Type, Visibility,
21+
CallingConvention, ConstValue, Mutability, NullabilityKind, PrimitiveType, Type,
22+
TypeMetadata, Visibility,
2223
},
2324
typed_node, ParameterKind, Span, TypedAnnotation, TypedAnnotationArg, TypedAnnotationValue,
2425
TypedBlock, TypedCall, TypedDeclaration, TypedExpression, TypedExtern, TypedExternStruct,
@@ -1147,12 +1148,14 @@ impl<'g> GrammarInterpreter<'g> {
11471148
span: Span,
11481149
) -> Result<ParsedValue, String> {
11491150
let declarations = self.get_field_as_decl_list("declarations", fields, state)?;
1151+
// Preserve any placeholder or inferred type entries created while parsing.
1152+
let type_registry = state.type_registry().clone();
11501153

11511154
Ok(ParsedValue::Program(Box::new(TypedProgram {
11521155
declarations,
11531156
span,
11541157
source_files: vec![],
1155-
type_registry: zyntax_typed_ast::type_registry::TypeRegistry::new(),
1158+
type_registry,
11561159
})))
11571160
}
11581161

@@ -1162,13 +1165,68 @@ impl<'g> GrammarInterpreter<'g> {
11621165
variant: &str,
11631166
fields: &[(String, ExprIR)],
11641167
state: &mut ParserState<'a>,
1165-
_span: Span,
1168+
span: Span,
11661169
) -> Result<ParsedValue, String> {
11671170
let ty = match variant {
11681171
"Unit" => Type::Primitive(PrimitiveType::Unit),
11691172
"Named" => {
11701173
let name = self.get_field_as_interned("name", fields, state)?;
1171-
Type::Unresolved(name)
1174+
let mut type_args = self
1175+
.get_field_as_type_list_optional("type_args", fields, state)?
1176+
.unwrap_or_default();
1177+
let mut const_args = self.get_field_as_const_list("const_args", fields, state)?;
1178+
1179+
// Tensor shape/dtype sugar from tensor[...] syntax.
1180+
let tensor_items =
1181+
self.get_field_as_interned_list("tensor_items", fields, state)?;
1182+
if !tensor_items.is_empty() {
1183+
let mut tokens: Vec<String> = tensor_items
1184+
.into_iter()
1185+
.map(|item| {
1186+
item.resolve_global().or_else(|| {
1187+
state
1188+
.builder()
1189+
.arena()
1190+
.resolve_string(item)
1191+
.map(|s| s.to_string())
1192+
})
1193+
})
1194+
.collect::<Option<Vec<String>>>()
1195+
.unwrap_or_default();
1196+
1197+
if let Some(last) = tokens.last() {
1198+
if let Some(dtype_prim) = self.primitive_type_from_name(last) {
1199+
type_args.push(Type::Primitive(dtype_prim));
1200+
tokens.pop();
1201+
}
1202+
}
1203+
1204+
for token in tokens {
1205+
if let Ok(v) = token.parse::<i64>() {
1206+
const_args.push(ConstValue::Int(v));
1207+
} else {
1208+
const_args.push(ConstValue::Variable(
1209+
zyntax_typed_ast::InternedString::new_global(&token),
1210+
));
1211+
}
1212+
}
1213+
}
1214+
1215+
let type_id = if let Some(type_def) = state.type_registry().get_type_by_name(name) {
1216+
type_def.id
1217+
} else {
1218+
state
1219+
.type_registry()
1220+
.register_atomic_type(name, TypeMetadata::default(), span)
1221+
};
1222+
1223+
Type::Named {
1224+
id: type_id,
1225+
type_args,
1226+
const_args,
1227+
variance: vec![],
1228+
nullability: NullabilityKind::NonNull,
1229+
}
11721230
}
11731231
"Primitive" => {
11741232
// Parse primitive type from name
@@ -1180,21 +1238,9 @@ impl<'g> GrammarInterpreter<'g> {
11801238
if name_str == "type" {
11811239
return Ok(ParsedValue::Type(Type::Any));
11821240
}
1183-
let prim = match name_str.as_str() {
1184-
"i8" => PrimitiveType::I8,
1185-
"i16" => PrimitiveType::I16,
1186-
"i32" => PrimitiveType::I32,
1187-
"i64" => PrimitiveType::I64,
1188-
"u8" => PrimitiveType::U8,
1189-
"u16" => PrimitiveType::U16,
1190-
"u32" => PrimitiveType::U32,
1191-
"u64" => PrimitiveType::U64,
1192-
"f32" => PrimitiveType::F32,
1193-
"f64" => PrimitiveType::F64,
1194-
"bool" => PrimitiveType::Bool,
1195-
"void" => PrimitiveType::Unit,
1196-
_ => return Err(format!("unknown primitive type: {}", name_str)),
1197-
};
1241+
let prim = self
1242+
.primitive_type_from_name(&name_str)
1243+
.ok_or_else(|| format!("unknown primitive type: {}", name_str))?;
11981244
Type::Primitive(prim)
11991245
}
12001246
"Pointer" => {
@@ -1834,6 +1880,83 @@ impl<'g> GrammarInterpreter<'g> {
18341880
}
18351881
}
18361882

1883+
fn get_field_as_const_list<'a>(
1884+
&self,
1885+
name: &str,
1886+
fields: &[(String, ExprIR)],
1887+
state: &mut ParserState<'a>,
1888+
) -> Result<Vec<ConstValue>, String> {
1889+
let Some(expr) = self.get_field(name, fields) else {
1890+
return Ok(vec![]);
1891+
};
1892+
1893+
let val = self.eval_expr(expr, state)?;
1894+
match val {
1895+
ParsedValue::List(items) => items
1896+
.into_iter()
1897+
.map(|item| self.parsed_value_to_const(item, state))
1898+
.collect(),
1899+
ParsedValue::Optional(None) | ParsedValue::None => Ok(vec![]),
1900+
ParsedValue::Optional(Some(inner)) => match *inner {
1901+
ParsedValue::List(items) => items
1902+
.into_iter()
1903+
.map(|item| self.parsed_value_to_const(item, state))
1904+
.collect(),
1905+
single => Ok(vec![self.parsed_value_to_const(single, state)?]),
1906+
},
1907+
single => Ok(vec![self.parsed_value_to_const(single, state)?]),
1908+
}
1909+
}
1910+
1911+
fn parsed_value_to_const<'a>(
1912+
&self,
1913+
value: ParsedValue,
1914+
state: &mut ParserState<'a>,
1915+
) -> Result<ConstValue, String> {
1916+
match value {
1917+
ParsedValue::Int(i) => Ok(ConstValue::Int(i)),
1918+
ParsedValue::Bool(b) => Ok(ConstValue::Bool(b)),
1919+
ParsedValue::Text(s) => Ok(ConstValue::String(state.intern(&s))),
1920+
ParsedValue::Interned(s) => Ok(ConstValue::Variable(s)),
1921+
ParsedValue::Literal(TypedLiteral::Integer(i)) => Ok(ConstValue::Int(i as i64)),
1922+
ParsedValue::Literal(TypedLiteral::Bool(b)) => Ok(ConstValue::Bool(b)),
1923+
ParsedValue::Literal(TypedLiteral::String(s)) => Ok(ConstValue::String(s)),
1924+
ParsedValue::Literal(TypedLiteral::Char(c)) => Ok(ConstValue::Char(c)),
1925+
ParsedValue::Optional(Some(inner)) => self.parsed_value_to_const(*inner, state),
1926+
ParsedValue::Optional(None) | ParsedValue::None => {
1927+
Err("cannot convert empty optional to const value".to_string())
1928+
}
1929+
other => Err(format!(
1930+
"cannot convert value to const argument: {:?}",
1931+
other
1932+
)),
1933+
}
1934+
}
1935+
1936+
fn primitive_type_from_name(&self, name: &str) -> Option<PrimitiveType> {
1937+
match name {
1938+
"i8" => Some(PrimitiveType::I8),
1939+
"i16" => Some(PrimitiveType::I16),
1940+
"i32" => Some(PrimitiveType::I32),
1941+
"i64" => Some(PrimitiveType::I64),
1942+
"i128" => Some(PrimitiveType::I128),
1943+
"u8" => Some(PrimitiveType::U8),
1944+
"u16" => Some(PrimitiveType::U16),
1945+
"u32" => Some(PrimitiveType::U32),
1946+
"u64" => Some(PrimitiveType::U64),
1947+
"u128" => Some(PrimitiveType::U128),
1948+
"f32" => Some(PrimitiveType::F32),
1949+
"f64" => Some(PrimitiveType::F64),
1950+
"bool" => Some(PrimitiveType::Bool),
1951+
"char" => Some(PrimitiveType::Char),
1952+
"str" | "String" => Some(PrimitiveType::String),
1953+
"isize" => Some(PrimitiveType::ISize),
1954+
"usize" => Some(PrimitiveType::USize),
1955+
"void" => Some(PrimitiveType::Unit),
1956+
_ => None,
1957+
}
1958+
}
1959+
18371960
/// Execute a helper function call
18381961
fn execute_helper_call<'a>(
18391962
&self,

crates/zynml/ml.zyn

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -764,9 +764,10 @@ type_opaque = { "@" ~ "opaque" ~ "(" ~ string_literal ~ ")" ~ "type" ~ name:iden
764764
}
765765

766766
// Type alias: type Scalar = f32
767-
type_alias = { "type" ~ name:identifier ~ "=" ~ type_expr }
767+
type_alias = { "type" ~ name:identifier ~ "=" ~ target:type_expr }
768768
-> TypedDeclaration::TypeAlias {
769769
name: intern(name),
770+
target: target,
770771
}
771772

772773
// Struct type: type Point = { x: f32, y: f32 }
@@ -1367,6 +1368,7 @@ optional_type = { "?" ~ inner:type_expr_non_optional }
13671368
-> Type::Named {
13681369
name: intern("Option"),
13691370
type_args: [inner],
1371+
const_args: [],
13701372
}
13711373

13721374
// Type expression without the optional prefix (to avoid infinite recursion)
@@ -1378,9 +1380,12 @@ type_expr_non_optional = { ty:fn_type | ty:tuple_type | ty:tensor_type | ty:gene
13781380
// tensor[2, 3]
13791381
// tensor[batch, seq, hidden]
13801382
// tensor[1, 3, 224, 224, f32]
1381-
tensor_type = { "tensor" ~ "[" ~ tensor_type_items? ~ "]" }
1383+
tensor_type = { "tensor" ~ "[" ~ items:tensor_type_items? ~ "]" }
13821384
-> Type::Named {
13831385
name: intern("Tensor"),
1386+
tensor_items: items,
1387+
type_args: [],
1388+
const_args: [],
13841389
}
13851390

13861391
tensor_type_items = { first:tensor_type_item ~ rest:tensor_type_item_comma* }
@@ -1389,7 +1394,7 @@ tensor_type_items = { first:tensor_type_item ~ rest:tensor_type_item_comma* }
13891394
tensor_type_item_comma = { "," ~ item:tensor_type_item }
13901395
-> item
13911396

1392-
tensor_type_item = { identifier | "?" | ASCII_DIGIT+ }
1397+
tensor_type_item = @{ identifier | "?" | ASCII_DIGIT+ }
13931398
-> text()
13941399

13951400
// Primitive types must be matched BEFORE simple_type to avoid treating them as named types
@@ -1453,16 +1458,22 @@ prim_unit = { "()" }
14531458
simple_type = { name:identifier }
14541459
-> Type::Named {
14551460
name: intern(name),
1461+
type_args: [],
1462+
const_args: [],
14561463
}
14571464

14581465
array_type = { "Array" ~ "<" ~ elem:type_expr ~ ">" }
14591466
-> Type::Named {
14601467
name: intern("Array"),
1468+
type_args: [elem],
1469+
const_args: [],
14611470
}
14621471

1463-
generic_type = { name:identifier ~ "<" ~ type_args ~ ">" }
1472+
generic_type = { name:identifier ~ "<" ~ args:type_args ~ ">" }
14641473
-> Type::Named {
14651474
name: intern(name),
1475+
type_args: args,
1476+
const_args: [],
14661477
}
14671478

14681479
// Type arguments: can be regular types or associated type bindings (Item=T)
@@ -1505,6 +1516,8 @@ where_bounds = { bound:where_bound ~ ("," ~ bound:where_bound)* }
15051516
where_bound = { name:identifier ~ ":" ~ bounds:where_trait_bounds }
15061517
-> Type::Named {
15071518
name: intern(name),
1519+
type_args: [],
1520+
const_args: [],
15081521
}
15091522

15101523
// Multiple trait bounds separated by +: Clone + Debug + Iterator<Item=T>
@@ -1519,12 +1532,16 @@ where_trait_bound = { bound:trait_bound_generic | bound:trait_bound_fn | bound:t
15191532
trait_bound_generic = { name:identifier ~ "<" ~ args:type_args ~ ">" }
15201533
-> Type::Named {
15211534
name: intern(name),
1535+
type_args: args,
1536+
const_args: [],
15221537
}
15231538

15241539
// Simple trait bound: Clone, Debug
15251540
trait_bound_simple = { name:identifier }
15261541
-> Type::Named {
15271542
name: intern(name),
1543+
type_args: [],
1544+
const_args: [],
15281545
}
15291546

15301547
// Function type as trait bound: (T) => U

crates/zynml/tests/e2e_tests.rs

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,13 +188,60 @@ mod type_system {
188188
#[test]
189189
fn test_parse_type_alias_tensor() {
190190
let grammar = get_grammar();
191-
// Current grammar doesn't support tensor[shape, dtype] syntax in type position
192-
let result = grammar.parse_to_json("type Embedding = Tensor");
191+
let result = grammar.parse_to_json("type Embedding = tensor[2, 3, f32]");
193192
assert!(
194193
result.is_ok(),
195194
"Should parse tensor type alias: {:?}",
196195
result.err()
197196
);
197+
198+
let parsed: serde_json::Value =
199+
serde_json::from_str(&result.unwrap()).expect("JSON should deserialize");
200+
let named = &parsed["declarations"][0]["node"]["TypeAlias"]["target"]["Named"];
201+
let const_args = named["const_args"]
202+
.as_array()
203+
.expect("const_args should be an array");
204+
assert_eq!(
205+
const_args.len(),
206+
2,
207+
"tensor shape should have two dimensions"
208+
);
209+
assert_eq!(const_args[0]["Int"], 2);
210+
assert_eq!(const_args[1]["Int"], 3);
211+
assert_eq!(
212+
named["type_args"][0]["Primitive"], "F32",
213+
"tensor dtype should map to Primitive::F32"
214+
);
215+
}
216+
217+
#[test]
218+
fn test_parse_type_alias_tensor_symbolic_shape() {
219+
let grammar = get_grammar();
220+
let result = grammar.parse_to_json("type HiddenState = tensor[batch, seq, hidden]");
221+
assert!(
222+
result.is_ok(),
223+
"Should parse tensor type alias with symbolic shape: {:?}",
224+
result.err()
225+
);
226+
227+
let parsed: serde_json::Value =
228+
serde_json::from_str(&result.unwrap()).expect("JSON should deserialize");
229+
let named = &parsed["declarations"][0]["node"]["TypeAlias"]["target"]["Named"];
230+
let const_args = named["const_args"]
231+
.as_array()
232+
.expect("const_args should be an array");
233+
assert_eq!(const_args.len(), 3);
234+
assert_eq!(const_args[0]["Variable"], "batch");
235+
assert_eq!(const_args[1]["Variable"], "seq");
236+
assert_eq!(const_args[2]["Variable"], "hidden");
237+
assert_eq!(
238+
named["type_args"]
239+
.as_array()
240+
.expect("type_args should be an array")
241+
.len(),
242+
0,
243+
"symbolic tensor shape should not imply dtype when omitted"
244+
);
198245
}
199246

200247
// --- Struct Definitions ---

crates/zyntax_embed/src/runtime.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1669,6 +1669,21 @@ impl ZyntaxRuntime {
16691669
);
16701670
}
16711671
}
1672+
Type::Named { id, type_args, .. } => {
1673+
for type_arg in type_args {
1674+
Self::resolve_in_type(type_arg, type_registry);
1675+
}
1676+
1677+
// Canonicalize Named IDs by name in case this ID came from a placeholder
1678+
// registry entry created before imports/externs were merged.
1679+
if let Some(type_def) = type_registry.get_type_by_id(*id) {
1680+
if let Some(canonical) = type_registry.get_type_by_name(type_def.name) {
1681+
if canonical.id != *id {
1682+
*id = canonical.id;
1683+
}
1684+
}
1685+
}
1686+
}
16721687
// Recursively resolve nested types
16731688
Type::Reference { ty: inner, .. } => {
16741689
Self::resolve_in_type(inner, type_registry);

0 commit comments

Comments
 (0)