@@ -5,6 +5,7 @@ use wit_parser::*;
55#[ derive( Default ) ]
66pub struct Types {
77 type_info : HashMap < TypeId , TypeInfo > ,
8+ equal_types : UnionFind ,
89}
910
1011#[ derive( Default , Clone , Copy , Debug ) ]
@@ -93,6 +94,22 @@ impl Types {
9394 }
9495 }
9596 }
97+ pub fn collect_equal_types ( & mut self , resolve : & Resolve ) {
98+ for ( i, ( ty, _) ) in resolve. types . iter ( ) . enumerate ( ) {
99+ // TODO: we could define a hash function for TypeDefKind to prevent the inner loop.
100+ for ( earlier, _) in resolve. types . iter ( ) . take ( i) {
101+ if self . equal_types . find ( ty) == self . equal_types . find ( earlier) {
102+ continue ;
103+ }
104+ // The correctness of is_structurally_equal relies on the fact
105+ // that resolve.types.iter() is in topological order.
106+ if self . is_structurally_equal ( resolve, ty, earlier) {
107+ self . equal_types . union ( ty, earlier) ;
108+ break ;
109+ }
110+ }
111+ }
112+ }
96113
97114 fn type_info_func ( & mut self , resolve : & Resolve , func : & Function , import : bool ) {
98115 let mut live = LiveTypes :: default ( ) ;
@@ -233,4 +250,185 @@ impl Types {
233250 None => TypeInfo :: default ( ) ,
234251 }
235252 }
253+
254+ fn is_structurally_equal ( & mut self , resolve : & Resolve , a : TypeId , b : TypeId ) -> bool {
255+ let a_def = & resolve. types [ a] . kind ;
256+ let b_def = & resolve. types [ b] . kind ;
257+ if self . equal_types . find ( a) == self . equal_types . find ( b) {
258+ return true ;
259+ }
260+ match ( a_def, b_def) {
261+ // Peel off typedef layers and continue recursing.
262+ ( TypeDefKind :: Type ( a) , _) => self . type_id_equal_to_type ( resolve, b, a) ,
263+ ( _, TypeDefKind :: Type ( b) ) => self . type_id_equal_to_type ( resolve, a, b) ,
264+
265+ ( TypeDefKind :: Record ( ra) , TypeDefKind :: Record ( rb) ) => {
266+ ra. fields . len ( ) == rb. fields . len ( )
267+ // Fields are ordered in WIT, so record {a: T, b: U} is different from {b: U, a: T}
268+ && ra. fields . iter ( ) . zip ( rb. fields . iter ( ) ) . all ( |( fa, fb) | {
269+ fa. name == fb. name && self . types_equal ( resolve, & fa. ty , & fb. ty )
270+ } )
271+ }
272+ ( TypeDefKind :: Record ( _) , _) => false ,
273+ ( TypeDefKind :: Variant ( va) , TypeDefKind :: Variant ( vb) ) => {
274+ va. cases . len ( ) == vb. cases . len ( )
275+ && va. cases . iter ( ) . zip ( vb. cases . iter ( ) ) . all ( |( ca, cb) | {
276+ ca. name == cb. name && self . optional_types_equal ( resolve, & ca. ty , & cb. ty )
277+ } )
278+ }
279+ ( TypeDefKind :: Variant ( _) , _) => false ,
280+ ( TypeDefKind :: Enum ( ea) , TypeDefKind :: Enum ( eb) ) => {
281+ ea. cases . len ( ) == eb. cases . len ( )
282+ && ea
283+ . cases
284+ . iter ( )
285+ . zip ( eb. cases . iter ( ) )
286+ . all ( |( ca, cb) | ca. name == cb. name )
287+ }
288+ ( TypeDefKind :: Enum ( _) , _) => false ,
289+ ( TypeDefKind :: Flags ( fa) , TypeDefKind :: Flags ( fb) ) => {
290+ fa. flags . len ( ) == fb. flags . len ( )
291+ && fa
292+ . flags
293+ . iter ( )
294+ . zip ( fb. flags . iter ( ) )
295+ . all ( |( fa, fb) | fa. name == fb. name )
296+ }
297+ ( TypeDefKind :: Flags ( _) , _) => false ,
298+ ( TypeDefKind :: Tuple ( ta) , TypeDefKind :: Tuple ( tb) ) => {
299+ ta. types . len ( ) == tb. types . len ( )
300+ && ta
301+ . types
302+ . iter ( )
303+ . zip ( tb. types . iter ( ) )
304+ . all ( |( a, b) | self . types_equal ( resolve, a, b) )
305+ }
306+ ( TypeDefKind :: Tuple ( _) , _) => false ,
307+ ( TypeDefKind :: List ( la) , TypeDefKind :: List ( lb) ) => self . types_equal ( resolve, la, lb) ,
308+ ( TypeDefKind :: List ( _) , _) => false ,
309+ ( TypeDefKind :: FixedLengthList ( ta, sa) , TypeDefKind :: FixedLengthList ( tb, sb) ) => {
310+ sa == sb && self . types_equal ( resolve, ta, tb)
311+ }
312+ ( TypeDefKind :: FixedLengthList ( ..) , _) => false ,
313+ ( TypeDefKind :: Option ( oa) , TypeDefKind :: Option ( ob) ) => self . types_equal ( resolve, oa, ob) ,
314+ ( TypeDefKind :: Option ( _) , _) => false ,
315+ ( TypeDefKind :: Result ( ra) , TypeDefKind :: Result ( rb) ) => {
316+ self . optional_types_equal ( resolve, & ra. ok , & rb. ok )
317+ && self . optional_types_equal ( resolve, & ra. err , & rb. err )
318+ }
319+ ( TypeDefKind :: Result ( _) , _) => false ,
320+ ( TypeDefKind :: Map ( ak, av) , TypeDefKind :: Map ( bk, bv) ) => {
321+ self . types_equal ( resolve, ak, bk) && self . types_equal ( resolve, av, bv)
322+ }
323+ ( TypeDefKind :: Map ( ..) , _) => false ,
324+ ( TypeDefKind :: Future ( a) , TypeDefKind :: Future ( b) ) => {
325+ self . optional_types_equal ( resolve, a, b)
326+ }
327+ ( TypeDefKind :: Future ( ..) , _) => false ,
328+ ( TypeDefKind :: Stream ( a) , TypeDefKind :: Stream ( b) ) => {
329+ self . optional_types_equal ( resolve, a, b)
330+ }
331+ ( TypeDefKind :: Stream ( ..) , _) => false ,
332+ ( TypeDefKind :: Handle ( a) , TypeDefKind :: Handle ( b) ) => match ( a, b) {
333+ ( Handle :: Own ( a) , Handle :: Own ( b) ) | ( Handle :: Borrow ( a) , Handle :: Borrow ( b) ) => {
334+ self . is_structurally_equal ( resolve, * a, * b)
335+ }
336+ ( Handle :: Own ( _) | Handle :: Borrow ( _) , _) => false ,
337+ } ,
338+ ( TypeDefKind :: Handle ( _) , _) => false ,
339+ ( TypeDefKind :: Unknown , _) => unreachable ! ( ) ,
340+
341+ // TODO: for now consider all resources not-equal to each other.
342+ // This is because the same type id can be used for both an imported
343+ // and exported resource where those should be distinct types.
344+ ( TypeDefKind :: Resource , _) => false ,
345+ }
346+ }
347+
348+ fn types_equal ( & mut self , resolve : & Resolve , a : & Type , b : & Type ) -> bool {
349+ match ( a, b) {
350+ // Peel off typedef layers and continue recursing.
351+ ( Type :: Id ( a) , b) => self . type_id_equal_to_type ( resolve, * a, b) ,
352+ ( a, Type :: Id ( b) ) => self . type_id_equal_to_type ( resolve, * b, a) ,
353+
354+ // When both a and b are primitives, they're only equal of
355+ // the primitives are the same.
356+ (
357+ Type :: Bool
358+ | Type :: U8
359+ | Type :: S8
360+ | Type :: U16
361+ | Type :: S16
362+ | Type :: U32
363+ | Type :: S32
364+ | Type :: U64
365+ | Type :: S64
366+ | Type :: F32
367+ | Type :: F64
368+ | Type :: Char
369+ | Type :: String
370+ | Type :: ErrorContext ,
371+ _,
372+ ) => a == b,
373+ }
374+ }
375+
376+ fn type_id_equal_to_type ( & mut self , resolve : & Resolve , a : TypeId , b : & Type ) -> bool {
377+ let ak = & resolve. types [ a] . kind ;
378+ match ( ak, b) {
379+ ( TypeDefKind :: Type ( a) , b) => self . types_equal ( resolve, a, b) ,
380+ ( _, Type :: Id ( b) ) => self . is_structurally_equal ( resolve, a, * b) ,
381+
382+ // Type `a` isn't a typedef, and type `b` is a primitive, so it's no
383+ // longer possible for them to be equal.
384+ _ => false ,
385+ }
386+ }
387+
388+ fn optional_types_equal (
389+ & mut self ,
390+ resolve : & Resolve ,
391+ a : & Option < Type > ,
392+ b : & Option < Type > ,
393+ ) -> bool {
394+ match ( a, b) {
395+ ( Some ( a) , Some ( b) ) => self . types_equal ( resolve, a, b) ,
396+ ( Some ( _) , None ) | ( None , Some ( _) ) => false ,
397+ ( None , None ) => true ,
398+ }
399+ }
400+
401+ pub fn get_representative_type ( & mut self , id : TypeId ) -> TypeId {
402+ self . equal_types . find ( id)
403+ }
404+ }
405+
406+ #[ derive( Default ) ]
407+ pub struct UnionFind {
408+ parent : HashMap < TypeId , TypeId > ,
409+ }
410+ impl UnionFind {
411+ fn find ( & mut self , id : TypeId ) -> TypeId {
412+ // Path compression
413+ let parent = self . parent . get ( & id) . copied ( ) . unwrap_or ( id) ;
414+ if parent != id {
415+ let root = self . find ( parent) ;
416+ self . parent . insert ( id, root) ;
417+ root
418+ } else {
419+ id
420+ }
421+ }
422+ fn union ( & mut self , a : TypeId , b : TypeId ) {
423+ let ra = self . find ( a) ;
424+ let rb = self . find ( b) ;
425+ if ra != rb {
426+ // Use smaller id as root for determinism
427+ if ra < rb {
428+ self . parent . insert ( rb, ra) ;
429+ } else {
430+ self . parent . insert ( ra, rb) ;
431+ }
432+ }
433+ }
236434}
0 commit comments