@@ -18,7 +18,8 @@ use zyntax_typed_ast::typed_ast::{
1818} ;
1919use zyntax_typed_ast:: {
2020 type_registry:: {
21- CallingConvention , ConstValue , Mutability , NullabilityKind , PrimitiveType , Type , Visibility ,
21+ CallingConvention , ConstValue , Mutability , NullabilityKind , PrimitiveType , Type ,
22+ TypeMetadata , Visibility ,
2223 } ,
2324 typed_node, ParameterKind , Span , TypedAnnotation , TypedAnnotationArg , TypedAnnotationValue ,
2425 TypedBlock , TypedCall , TypedDeclaration , TypedExpression , TypedExtern , TypedExternStruct ,
@@ -1147,12 +1148,14 @@ impl<'g> GrammarInterpreter<'g> {
11471148 span : Span ,
11481149 ) -> Result < ParsedValue , String > {
11491150 let declarations = self . get_field_as_decl_list ( "declarations" , fields, state) ?;
1151+ // Preserve any placeholder or inferred type entries created while parsing.
1152+ let type_registry = state. type_registry ( ) . clone ( ) ;
11501153
11511154 Ok ( ParsedValue :: Program ( Box :: new ( TypedProgram {
11521155 declarations,
11531156 span,
11541157 source_files : vec ! [ ] ,
1155- type_registry : zyntax_typed_ast :: type_registry :: TypeRegistry :: new ( ) ,
1158+ type_registry,
11561159 } ) ) )
11571160 }
11581161
@@ -1162,13 +1165,68 @@ impl<'g> GrammarInterpreter<'g> {
11621165 variant : & str ,
11631166 fields : & [ ( String , ExprIR ) ] ,
11641167 state : & mut ParserState < ' a > ,
1165- _span : Span ,
1168+ span : Span ,
11661169 ) -> Result < ParsedValue , String > {
11671170 let ty = match variant {
11681171 "Unit" => Type :: Primitive ( PrimitiveType :: Unit ) ,
11691172 "Named" => {
11701173 let name = self . get_field_as_interned ( "name" , fields, state) ?;
1171- Type :: Unresolved ( name)
1174+ let mut type_args = self
1175+ . get_field_as_type_list_optional ( "type_args" , fields, state) ?
1176+ . unwrap_or_default ( ) ;
1177+ let mut const_args = self . get_field_as_const_list ( "const_args" , fields, state) ?;
1178+
1179+ // Tensor shape/dtype sugar from tensor[...] syntax.
1180+ let tensor_items =
1181+ self . get_field_as_interned_list ( "tensor_items" , fields, state) ?;
1182+ if !tensor_items. is_empty ( ) {
1183+ let mut tokens: Vec < String > = tensor_items
1184+ . into_iter ( )
1185+ . map ( |item| {
1186+ item. resolve_global ( ) . or_else ( || {
1187+ state
1188+ . builder ( )
1189+ . arena ( )
1190+ . resolve_string ( item)
1191+ . map ( |s| s. to_string ( ) )
1192+ } )
1193+ } )
1194+ . collect :: < Option < Vec < String > > > ( )
1195+ . unwrap_or_default ( ) ;
1196+
1197+ if let Some ( last) = tokens. last ( ) {
1198+ if let Some ( dtype_prim) = self . primitive_type_from_name ( last) {
1199+ type_args. push ( Type :: Primitive ( dtype_prim) ) ;
1200+ tokens. pop ( ) ;
1201+ }
1202+ }
1203+
1204+ for token in tokens {
1205+ if let Ok ( v) = token. parse :: < i64 > ( ) {
1206+ const_args. push ( ConstValue :: Int ( v) ) ;
1207+ } else {
1208+ const_args. push ( ConstValue :: Variable (
1209+ zyntax_typed_ast:: InternedString :: new_global ( & token) ,
1210+ ) ) ;
1211+ }
1212+ }
1213+ }
1214+
1215+ let type_id = if let Some ( type_def) = state. type_registry ( ) . get_type_by_name ( name) {
1216+ type_def. id
1217+ } else {
1218+ state
1219+ . type_registry ( )
1220+ . register_atomic_type ( name, TypeMetadata :: default ( ) , span)
1221+ } ;
1222+
1223+ Type :: Named {
1224+ id : type_id,
1225+ type_args,
1226+ const_args,
1227+ variance : vec ! [ ] ,
1228+ nullability : NullabilityKind :: NonNull ,
1229+ }
11721230 }
11731231 "Primitive" => {
11741232 // Parse primitive type from name
@@ -1180,21 +1238,9 @@ impl<'g> GrammarInterpreter<'g> {
11801238 if name_str == "type" {
11811239 return Ok ( ParsedValue :: Type ( Type :: Any ) ) ;
11821240 }
1183- let prim = match name_str. as_str ( ) {
1184- "i8" => PrimitiveType :: I8 ,
1185- "i16" => PrimitiveType :: I16 ,
1186- "i32" => PrimitiveType :: I32 ,
1187- "i64" => PrimitiveType :: I64 ,
1188- "u8" => PrimitiveType :: U8 ,
1189- "u16" => PrimitiveType :: U16 ,
1190- "u32" => PrimitiveType :: U32 ,
1191- "u64" => PrimitiveType :: U64 ,
1192- "f32" => PrimitiveType :: F32 ,
1193- "f64" => PrimitiveType :: F64 ,
1194- "bool" => PrimitiveType :: Bool ,
1195- "void" => PrimitiveType :: Unit ,
1196- _ => return Err ( format ! ( "unknown primitive type: {}" , name_str) ) ,
1197- } ;
1241+ let prim = self
1242+ . primitive_type_from_name ( & name_str)
1243+ . ok_or_else ( || format ! ( "unknown primitive type: {}" , name_str) ) ?;
11981244 Type :: Primitive ( prim)
11991245 }
12001246 "Pointer" => {
@@ -1834,6 +1880,83 @@ impl<'g> GrammarInterpreter<'g> {
18341880 }
18351881 }
18361882
1883+ fn get_field_as_const_list < ' a > (
1884+ & self ,
1885+ name : & str ,
1886+ fields : & [ ( String , ExprIR ) ] ,
1887+ state : & mut ParserState < ' a > ,
1888+ ) -> Result < Vec < ConstValue > , String > {
1889+ let Some ( expr) = self . get_field ( name, fields) else {
1890+ return Ok ( vec ! [ ] ) ;
1891+ } ;
1892+
1893+ let val = self . eval_expr ( expr, state) ?;
1894+ match val {
1895+ ParsedValue :: List ( items) => items
1896+ . into_iter ( )
1897+ . map ( |item| self . parsed_value_to_const ( item, state) )
1898+ . collect ( ) ,
1899+ ParsedValue :: Optional ( None ) | ParsedValue :: None => Ok ( vec ! [ ] ) ,
1900+ ParsedValue :: Optional ( Some ( inner) ) => match * inner {
1901+ ParsedValue :: List ( items) => items
1902+ . into_iter ( )
1903+ . map ( |item| self . parsed_value_to_const ( item, state) )
1904+ . collect ( ) ,
1905+ single => Ok ( vec ! [ self . parsed_value_to_const( single, state) ?] ) ,
1906+ } ,
1907+ single => Ok ( vec ! [ self . parsed_value_to_const( single, state) ?] ) ,
1908+ }
1909+ }
1910+
1911+ fn parsed_value_to_const < ' a > (
1912+ & self ,
1913+ value : ParsedValue ,
1914+ state : & mut ParserState < ' a > ,
1915+ ) -> Result < ConstValue , String > {
1916+ match value {
1917+ ParsedValue :: Int ( i) => Ok ( ConstValue :: Int ( i) ) ,
1918+ ParsedValue :: Bool ( b) => Ok ( ConstValue :: Bool ( b) ) ,
1919+ ParsedValue :: Text ( s) => Ok ( ConstValue :: String ( state. intern ( & s) ) ) ,
1920+ ParsedValue :: Interned ( s) => Ok ( ConstValue :: Variable ( s) ) ,
1921+ ParsedValue :: Literal ( TypedLiteral :: Integer ( i) ) => Ok ( ConstValue :: Int ( i as i64 ) ) ,
1922+ ParsedValue :: Literal ( TypedLiteral :: Bool ( b) ) => Ok ( ConstValue :: Bool ( b) ) ,
1923+ ParsedValue :: Literal ( TypedLiteral :: String ( s) ) => Ok ( ConstValue :: String ( s) ) ,
1924+ ParsedValue :: Literal ( TypedLiteral :: Char ( c) ) => Ok ( ConstValue :: Char ( c) ) ,
1925+ ParsedValue :: Optional ( Some ( inner) ) => self . parsed_value_to_const ( * inner, state) ,
1926+ ParsedValue :: Optional ( None ) | ParsedValue :: None => {
1927+ Err ( "cannot convert empty optional to const value" . to_string ( ) )
1928+ }
1929+ other => Err ( format ! (
1930+ "cannot convert value to const argument: {:?}" ,
1931+ other
1932+ ) ) ,
1933+ }
1934+ }
1935+
1936+ fn primitive_type_from_name ( & self , name : & str ) -> Option < PrimitiveType > {
1937+ match name {
1938+ "i8" => Some ( PrimitiveType :: I8 ) ,
1939+ "i16" => Some ( PrimitiveType :: I16 ) ,
1940+ "i32" => Some ( PrimitiveType :: I32 ) ,
1941+ "i64" => Some ( PrimitiveType :: I64 ) ,
1942+ "i128" => Some ( PrimitiveType :: I128 ) ,
1943+ "u8" => Some ( PrimitiveType :: U8 ) ,
1944+ "u16" => Some ( PrimitiveType :: U16 ) ,
1945+ "u32" => Some ( PrimitiveType :: U32 ) ,
1946+ "u64" => Some ( PrimitiveType :: U64 ) ,
1947+ "u128" => Some ( PrimitiveType :: U128 ) ,
1948+ "f32" => Some ( PrimitiveType :: F32 ) ,
1949+ "f64" => Some ( PrimitiveType :: F64 ) ,
1950+ "bool" => Some ( PrimitiveType :: Bool ) ,
1951+ "char" => Some ( PrimitiveType :: Char ) ,
1952+ "str" | "String" => Some ( PrimitiveType :: String ) ,
1953+ "isize" => Some ( PrimitiveType :: ISize ) ,
1954+ "usize" => Some ( PrimitiveType :: USize ) ,
1955+ "void" => Some ( PrimitiveType :: Unit ) ,
1956+ _ => None ,
1957+ }
1958+ }
1959+
18371960 /// Execute a helper function call
18381961 fn execute_helper_call < ' a > (
18391962 & self ,
0 commit comments