8
8
9
9
use std:: iter;
10
10
11
- use proc_macro2:: TokenStream ;
11
+ use proc_macro2:: { Span , TokenStream } ;
12
12
use quote:: quote;
13
13
use syn:: {
14
14
parse:: { Parse , ParseStream } ,
15
- parse_macro_input, parse_quote ,
15
+ parse_macro_input,
16
16
punctuated:: Punctuated ,
17
17
token:: Plus ,
18
- Error , FnArg , GenericParam , Ident , ItemTrait , Pat , PatType , Result , ReturnType , Signature ,
19
- Token , TraitBound , TraitItem , TraitItemConst , TraitItemFn , TraitItemType , Type , TypeGenerics ,
20
- TypeImplTrait , TypeParam , TypeParamBound ,
18
+ Error , FnArg , GenericParam , Ident , ItemTrait , Pat , PatType , Receiver , Result , ReturnType ,
19
+ Signature , Token , TraitBound , TraitItem , TraitItemConst , TraitItemFn , TraitItemType , Type ,
20
+ TypeGenerics , TypeImplTrait , TypeParam , TypeParamBound , WhereClause ,
21
21
} ;
22
+ use syn:: { parse_quote, TypeReference } ;
22
23
23
24
struct Attrs {
24
25
variant : MakeVariant ,
@@ -127,10 +128,10 @@ fn mk_variant(
127
128
128
129
// Transforms one item declaration within the definition if it has `async fn` and/or `-> impl Trait` return types by adding new bounds.
129
130
fn transform_item ( item : & TraitItem , bounds : & Vec < TypeParamBound > ) -> TraitItem {
130
- let TraitItem :: Fn ( fn_item @ TraitItemFn { sig, .. } ) = item else {
131
+ let TraitItem :: Fn ( fn_item @ TraitItemFn { sig, default , .. } ) = item else {
131
132
return item. clone ( ) ;
132
133
} ;
133
- let ( arrow , output ) = if sig. asyncness . is_some ( ) {
134
+ let ( sig , default ) = if sig. asyncness . is_some ( ) {
134
135
let orig = match & sig. output {
135
136
ReturnType :: Default => quote ! { ( ) } ,
136
137
ReturnType :: Type ( _, ty) => quote ! { #ty } ,
@@ -142,7 +143,22 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
142
143
. chain ( bounds. iter ( ) . cloned ( ) )
143
144
. collect ( ) ,
144
145
} ) ;
145
- ( syn:: parse2 ( quote ! { -> } ) . unwrap ( ) , ty)
146
+ let mut sig = sig. clone ( ) ;
147
+ if default. is_some ( ) {
148
+ add_receiver_bounds ( & mut sig) ;
149
+ }
150
+
151
+ (
152
+ Signature {
153
+ asyncness : None ,
154
+ output : ReturnType :: Type ( syn:: parse2 ( quote ! { -> } ) . unwrap ( ) , Box :: new ( ty) ) ,
155
+ ..sig. clone ( )
156
+ } ,
157
+ fn_item
158
+ . default
159
+ . as_ref ( )
160
+ . map ( |b| syn:: parse2 ( quote ! { { async move #b } } ) . unwrap ( ) ) ,
161
+ )
146
162
} else {
147
163
match & sig. output {
148
164
ReturnType :: Type ( arrow, ty) => match & * * ty {
@@ -151,19 +167,22 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
151
167
impl_token : it. impl_token ,
152
168
bounds : it. bounds . iter ( ) . chain ( bounds) . cloned ( ) . collect ( ) ,
153
169
} ) ;
154
- ( * arrow, ty)
170
+ (
171
+ Signature {
172
+ output : ReturnType :: Type ( * arrow, Box :: new ( ty) ) ,
173
+ ..sig. clone ( )
174
+ } ,
175
+ fn_item. default . clone ( ) ,
176
+ )
155
177
}
156
178
_ => return item. clone ( ) ,
157
179
} ,
158
180
ReturnType :: Default => return item. clone ( ) ,
159
181
}
160
182
} ;
161
183
TraitItem :: Fn ( TraitItemFn {
162
- sig : Signature {
163
- asyncness : None ,
164
- output : ReturnType :: Type ( arrow, Box :: new ( output) ) ,
165
- ..sig. clone ( )
166
- } ,
184
+ sig,
185
+ default,
167
186
..fn_item. clone ( )
168
187
} )
169
188
}
@@ -182,9 +201,29 @@ fn mk_blanket_impl(variant: &Ident, tr: &ItemTrait) -> TokenStream {
182
201
blanket_generics
183
202
. params
184
203
. push ( GenericParam :: Type ( blanket_bound) ) ;
185
- let ( blanket_impl_generics, _ty, blanket_where_clause) = & blanket_generics. split_for_impl ( ) ;
204
+ let ( blanket_impl_generics, _ty, blanket_where_clause) = & mut blanket_generics. split_for_impl ( ) ;
205
+ let self_is_sync = tr. items . iter ( ) . any ( |item| {
206
+ matches ! (
207
+ item,
208
+ TraitItem :: Fn ( TraitItemFn {
209
+ default : Some ( _) ,
210
+ ..
211
+ } )
212
+ )
213
+ } ) ;
214
+
215
+ let mut blanket_where_clause = blanket_where_clause
216
+ . map ( |w| w. predicates . clone ( ) )
217
+ . unwrap_or_default ( ) ;
218
+
219
+ if self_is_sync {
220
+ blanket_where_clause. push ( parse_quote ! { for <' s> & ' s Self : Send } ) ;
221
+ }
222
+
186
223
quote ! {
187
- impl #blanket_impl_generics #orig #orig_ty_generics for #blanket #blanket_where_clause
224
+ impl #blanket_impl_generics #orig #orig_ty_generics for #blanket
225
+ where
226
+ #blanket_where_clause
188
227
{
189
228
#( #items) *
190
229
}
@@ -229,6 +268,7 @@ fn blanket_impl_item(
229
268
} else {
230
269
quote ! { }
231
270
} ;
271
+
232
272
quote ! {
233
273
#sig {
234
274
<Self as #variant #trait_ty_generics>:: #ident( #( #args) , * ) #maybe_await
@@ -246,3 +286,47 @@ fn blanket_impl_item(
246
286
_ => Error :: new_spanned ( item, "unsupported item type" ) . into_compile_error ( ) ,
247
287
}
248
288
}
289
+
290
+ fn add_receiver_bounds ( sig : & mut Signature ) {
291
+ let Some ( FnArg :: Receiver ( Receiver { ty, reference, .. } ) ) = sig. inputs . first_mut ( ) else {
292
+ return ;
293
+ } ;
294
+ let Type :: Reference (
295
+ recv_ty @ TypeReference {
296
+ mutability : None , ..
297
+ } ,
298
+ ) = & mut * * ty
299
+ else {
300
+ return ;
301
+ } ;
302
+ let Some ( ( _and, lt) ) = reference else {
303
+ return ;
304
+ } ;
305
+
306
+ let lifetime = syn:: Lifetime {
307
+ apostrophe : Span :: mixed_site ( ) ,
308
+ ident : Ident :: new ( "the_self_lt" , Span :: mixed_site ( ) ) ,
309
+ } ;
310
+ sig. generics . params . insert (
311
+ 0 ,
312
+ syn:: GenericParam :: Lifetime ( syn:: LifetimeParam {
313
+ lifetime : lifetime. clone ( ) ,
314
+ colon_token : None ,
315
+ bounds : Default :: default ( ) ,
316
+ attrs : Default :: default ( ) ,
317
+ } ) ,
318
+ ) ;
319
+ recv_ty. lifetime = Some ( lifetime. clone ( ) ) ;
320
+ * lt = Some ( lifetime) ;
321
+ let predicate = parse_quote ! { #recv_ty: Send } ;
322
+
323
+ if let Some ( wh) = & mut sig. generics . where_clause {
324
+ wh. predicates . push ( predicate) ;
325
+ } else {
326
+ let where_clause = WhereClause {
327
+ where_token : Token ! [ where ] ( Span :: mixed_site ( ) ) ,
328
+ predicates : Punctuated :: from_iter ( [ predicate] ) ,
329
+ } ;
330
+ sig. generics . where_clause = Some ( where_clause) ;
331
+ }
332
+ }
0 commit comments