@@ -20,34 +20,38 @@ pub fn expand(item: &syn::Item) -> Result<TokenStream> {
2020 }
2121
2222 for item in & item. items {
23- methods. push ( prepare_impl_method ( & item) ?) ;
23+ methods. push ( prepare_impl_method ( item) ?) ;
2424 }
2525
2626 let ty = item. self_ty . as_ref ( ) . clone ( ) ;
2727 ( ty, & item. generics , None )
2828 }
2929 syn:: Item :: Trait ( item) => {
30- for param in item. generics . params . iter ( ) {
31- bail ! ( param , "tracked traits cannot be generic" )
30+ if let Some ( first ) = item. generics . params . first ( ) {
31+ bail ! ( first , "tracked traits cannot be generic" )
3232 }
3333
3434 for item in & item. items {
35- methods. push ( prepare_trait_method ( & item) ?) ;
35+ methods. push ( prepare_trait_method ( item) ?) ;
3636 }
3737
3838 let name = & item. ident ;
3939 let ty = parse_quote ! { dyn #name + ' __comemo_dynamic } ;
40- ( ty, & item. generics , Some ( name . clone ( ) ) )
40+ ( ty, & item. generics , Some ( item . ident . clone ( ) ) )
4141 }
4242 _ => bail ! ( item, "`track` can only be applied to impl blocks and traits" ) ,
4343 } ;
4444
4545 // Produce the necessary items for the type to become trackable.
46+ let variants = create_variants ( & methods) ;
4647 let scope = create ( & ty, generics, trait_, & methods) ?;
4748
4849 Ok ( quote ! {
4950 #item
50- const _: ( ) = { #scope } ;
51+ const _: ( ) = {
52+ #variants
53+ #scope
54+ } ;
5155 } )
5256}
5357
@@ -175,6 +179,43 @@ fn prepare_method(vis: syn::Visibility, sig: &syn::Signature) -> Result<Method>
175179 } )
176180}
177181
182+ /// Produces the variants for the constraint.
183+ fn create_variants ( methods : & [ Method ] ) -> TokenStream {
184+ let variants = methods. iter ( ) . map ( create_variant) ;
185+ let is_mutable_variants = methods. iter ( ) . map ( |m| {
186+ let name = & m. sig . ident ;
187+ let mutable = m. mutable ;
188+ quote ! { __ComemoVariant:: #name( ..) => #mutable }
189+ } ) ;
190+
191+ let is_mutable = ( !methods. is_empty ( ) )
192+ . then ( || {
193+ quote ! {
194+ match & self . 0 {
195+ #( #is_mutable_variants) , *
196+ }
197+ }
198+ } )
199+ . unwrap_or_else ( || quote ! { false } ) ;
200+
201+ quote ! {
202+ #[ derive( Clone , PartialEq , Hash ) ]
203+ pub struct __ComemoCall( __ComemoVariant) ;
204+
205+ impl :: comemo:: internal:: Call for __ComemoCall {
206+ fn is_mutable( & self ) -> bool {
207+ #is_mutable
208+ }
209+ }
210+
211+ #[ derive( Clone , PartialEq , Hash ) ]
212+ #[ allow( non_camel_case_types) ]
213+ enum __ComemoVariant {
214+ #( #variants, ) *
215+ }
216+ }
217+ }
218+
178219/// Produce the necessary items for a type to become trackable.
179220fn create (
180221 ty : & syn:: Type ,
@@ -229,26 +270,32 @@ fn create(
229270 } ;
230271
231272 // Prepare replying.
273+ let immutable = methods. iter ( ) . all ( |m| !m. mutable ) ;
232274 let replays = methods. iter ( ) . map ( create_replay) ;
233- let replay = methods . iter ( ) . any ( |m| m . mutable ) . then ( || {
275+ let replay = ( !immutable ) . then ( || {
234276 quote ! {
235277 constraint. replay( |call| match & call. 0 { #( #replays, ) * } ) ;
236278 }
237279 } ) ;
238280
239281 // Prepare variants and wrapper methods.
240- let variants = methods. iter ( ) . map ( create_variant) ;
241282 let wrapper_methods = methods
242283 . iter ( )
243284 . filter ( |m| !m. mutable )
244285 . map ( |m| create_wrapper ( m, false ) ) ;
245286 let wrapper_methods_mut = methods. iter ( ) . map ( |m| create_wrapper ( m, true ) ) ;
246287
288+ let constraint = if immutable {
289+ quote ! { ImmutableConstraint }
290+ } else {
291+ quote ! { MutableConstraint }
292+ } ;
293+
247294 Ok ( quote ! {
248- impl #impl_params :: comemo:: Track for #ty #where_clause { }
295+ impl #impl_params :: comemo:: Track for #ty #where_clause { }
249296
250- impl #impl_params :: comemo:: Validate for #ty #where_clause {
251- type Constraint = :: comemo:: internal:: Constraint <__ComemoCall>;
297+ impl #impl_params :: comemo:: Validate for #ty #where_clause {
298+ type Constraint = :: comemo:: internal:: #constraint <__ComemoCall>;
252299
253300 #[ inline]
254301 fn validate( & self , constraint: & Self :: Constraint ) -> bool {
@@ -267,15 +314,6 @@ fn create(
267314 }
268315 }
269316
270- #[ derive( Clone , PartialEq , Hash ) ]
271- pub struct __ComemoCall( __ComemoVariant) ;
272-
273- #[ derive( Clone , PartialEq , Hash ) ]
274- #[ allow( non_camel_case_types) ]
275- enum __ComemoVariant {
276- #( #variants, ) *
277- }
278-
279317 #[ doc( hidden) ]
280318 impl #impl_params :: comemo:: internal:: Surfaces for #ty #where_clause {
281319 type Surface <#t> = __ComemoSurface #type_params_t where Self : #t;
@@ -323,7 +361,6 @@ fn create(
323361 impl #impl_params_t #prefix __ComemoSurfaceMut #type_params_t {
324362 #( #wrapper_methods_mut) *
325363 }
326-
327364 } )
328365}
329366
@@ -370,10 +407,9 @@ fn create_wrapper(method: &Method, tracked_mut: bool) -> TokenStream {
370407 let vis = & method. vis ;
371408 let sig = & method. sig ;
372409 let args = & method. args ;
373- let mutable = method. mutable ;
374410 let to_parts = if !tracked_mut {
375411 quote ! { to_parts_ref( self . 0 ) }
376- } else if !mutable {
412+ } else if !method . mutable {
377413 quote ! { to_parts_mut_ref( & self . 0 ) }
378414 } else {
379415 quote ! { to_parts_mut_mut( & mut self . 0 ) }
@@ -389,7 +425,6 @@ fn create_wrapper(method: &Method, tracked_mut: bool) -> TokenStream {
389425 constraint. push(
390426 __ComemoCall( __comemo_variant) ,
391427 :: comemo:: internal:: hash( & output) ,
392- #mutable,
393428 ) ;
394429 }
395430 output
0 commit comments