Skip to content

Commit f719ebe

Browse files
committed
Add support for defaulted methods
1 parent 91bd9eb commit f719ebe

File tree

3 files changed

+132
-32
lines changed

3 files changed

+132
-32
lines changed

Cargo.lock

+16-16
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

trait-variant/examples/variant.rs

+16
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,25 @@ pub trait LocalIntFactory {
1717
Self: 'a;
1818

1919
async fn make(&self, x: u32, y: &str) -> i32;
20+
async fn make_mut(&mut self);
2021
fn stream(&self) -> impl Iterator<Item = i32>;
2122
fn call(&self) -> u32;
2223
fn another_async(&self, input: Result<(), &str>) -> Self::MyFut<'_>;
24+
async fn defaulted(&self) -> i32 {
25+
self.make(10, "10").await
26+
}
27+
async fn defaulted_mut(&mut self) -> i32 {
28+
self.make(10, "10").await
29+
}
30+
async fn defaulted_mut_2(&mut self) {
31+
self.make_mut().await
32+
}
33+
async fn defaulted_move(self) -> i32
34+
where
35+
Self: Sized,
36+
{
37+
self.make(10, "10").await
38+
}
2339
}
2440

2541
#[allow(dead_code)]

trait-variant/src/variant.rs

+100-16
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,18 @@
88

99
use std::iter;
1010

11-
use proc_macro2::TokenStream;
11+
use proc_macro2::{Span, TokenStream};
1212
use quote::quote;
1313
use syn::{
1414
parse::{Parse, ParseStream},
15-
parse_macro_input, parse_quote,
15+
parse_macro_input,
1616
punctuated::Punctuated,
1717
token::Plus,
18-
Error, FnArg, GenericParam, Ident, ItemTrait, Pat, PatType, Result, ReturnType, Signature,
19-
Token, TraitBound, TraitItem, TraitItemConst, TraitItemFn, TraitItemType, Type, TypeGenerics,
20-
TypeImplTrait, TypeParam, TypeParamBound,
18+
Error, FnArg, GenericParam, Ident, ItemTrait, Pat, PatType, Receiver, Result, ReturnType,
19+
Signature, Token, TraitBound, TraitItem, TraitItemConst, TraitItemFn, TraitItemType, Type,
20+
TypeGenerics, TypeImplTrait, TypeParam, TypeParamBound, WhereClause,
2121
};
22+
use syn::{parse_quote, TypeReference};
2223

2324
struct Attrs {
2425
variant: MakeVariant,
@@ -127,10 +128,10 @@ fn mk_variant(
127128

128129
// Transforms one item declaration within the definition if it has `async fn` and/or `-> impl Trait` return types by adding new bounds.
129130
fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
130-
let TraitItem::Fn(fn_item @ TraitItemFn { sig, .. }) = item else {
131+
let TraitItem::Fn(fn_item @ TraitItemFn { sig, default, .. }) = item else {
131132
return item.clone();
132133
};
133-
let (arrow, output) = if sig.asyncness.is_some() {
134+
let (sig, default) = if sig.asyncness.is_some() {
134135
let orig = match &sig.output {
135136
ReturnType::Default => quote! { () },
136137
ReturnType::Type(_, ty) => quote! { #ty },
@@ -142,7 +143,22 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
142143
.chain(bounds.iter().cloned())
143144
.collect(),
144145
});
145-
(syn::parse2(quote! { -> }).unwrap(), ty)
146+
let mut sig = sig.clone();
147+
if default.is_some() {
148+
add_receiver_bounds(&mut sig);
149+
}
150+
151+
(
152+
Signature {
153+
asyncness: None,
154+
output: ReturnType::Type(syn::parse2(quote! { -> }).unwrap(), Box::new(ty)),
155+
..sig.clone()
156+
},
157+
fn_item
158+
.default
159+
.as_ref()
160+
.map(|b| syn::parse2(quote! { { async move #b } }).unwrap()),
161+
)
146162
} else {
147163
match &sig.output {
148164
ReturnType::Type(arrow, ty) => match &**ty {
@@ -151,19 +167,22 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
151167
impl_token: it.impl_token,
152168
bounds: it.bounds.iter().chain(bounds).cloned().collect(),
153169
});
154-
(*arrow, ty)
170+
(
171+
Signature {
172+
output: ReturnType::Type(*arrow, Box::new(ty)),
173+
..sig.clone()
174+
},
175+
fn_item.default.clone(),
176+
)
155177
}
156178
_ => return item.clone(),
157179
},
158180
ReturnType::Default => return item.clone(),
159181
}
160182
};
161183
TraitItem::Fn(TraitItemFn {
162-
sig: Signature {
163-
asyncness: None,
164-
output: ReturnType::Type(arrow, Box::new(output)),
165-
..sig.clone()
166-
},
184+
sig,
185+
default,
167186
..fn_item.clone()
168187
})
169188
}
@@ -182,9 +201,29 @@ fn mk_blanket_impl(variant: &Ident, tr: &ItemTrait) -> TokenStream {
182201
blanket_generics
183202
.params
184203
.push(GenericParam::Type(blanket_bound));
185-
let (blanket_impl_generics, _ty, blanket_where_clause) = &blanket_generics.split_for_impl();
204+
let (blanket_impl_generics, _ty, blanket_where_clause) = &mut blanket_generics.split_for_impl();
205+
let self_is_sync = tr.items.iter().any(|item| {
206+
matches!(
207+
item,
208+
TraitItem::Fn(TraitItemFn {
209+
default: Some(_),
210+
..
211+
})
212+
)
213+
});
214+
215+
let mut blanket_where_clause = blanket_where_clause
216+
.map(|w| w.predicates.clone())
217+
.unwrap_or_default();
218+
219+
if self_is_sync {
220+
blanket_where_clause.push(parse_quote! { for<'s> &'s Self: Send });
221+
}
222+
186223
quote! {
187-
impl #blanket_impl_generics #orig #orig_ty_generics for #blanket #blanket_where_clause
224+
impl #blanket_impl_generics #orig #orig_ty_generics for #blanket
225+
where
226+
#blanket_where_clause
188227
{
189228
#(#items)*
190229
}
@@ -229,6 +268,7 @@ fn blanket_impl_item(
229268
} else {
230269
quote! {}
231270
};
271+
232272
quote! {
233273
#sig {
234274
<Self as #variant #trait_ty_generics>::#ident(#(#args),*)#maybe_await
@@ -246,3 +286,47 @@ fn blanket_impl_item(
246286
_ => Error::new_spanned(item, "unsupported item type").into_compile_error(),
247287
}
248288
}
289+
290+
fn add_receiver_bounds(sig: &mut Signature) {
291+
let Some(FnArg::Receiver(Receiver { ty, reference, .. })) = sig.inputs.first_mut() else {
292+
return;
293+
};
294+
let Type::Reference(
295+
recv_ty @ TypeReference {
296+
mutability: None, ..
297+
},
298+
) = &mut **ty
299+
else {
300+
return;
301+
};
302+
let Some((_and, lt)) = reference else {
303+
return;
304+
};
305+
306+
let lifetime = syn::Lifetime {
307+
apostrophe: Span::mixed_site(),
308+
ident: Ident::new("the_self_lt", Span::mixed_site()),
309+
};
310+
sig.generics.params.insert(
311+
0,
312+
syn::GenericParam::Lifetime(syn::LifetimeParam {
313+
lifetime: lifetime.clone(),
314+
colon_token: None,
315+
bounds: Default::default(),
316+
attrs: Default::default(),
317+
}),
318+
);
319+
recv_ty.lifetime = Some(lifetime.clone());
320+
*lt = Some(lifetime);
321+
let predicate = parse_quote! { #recv_ty: Send };
322+
323+
if let Some(wh) = &mut sig.generics.where_clause {
324+
wh.predicates.push(predicate);
325+
} else {
326+
let where_clause = WhereClause {
327+
where_token: Token![where](Span::mixed_site()),
328+
predicates: Punctuated::from_iter([predicate]),
329+
};
330+
sig.generics.where_clause = Some(where_clause);
331+
}
332+
}

0 commit comments

Comments
 (0)