Skip to content

Commit fbe0eb2

Browse files
committed
Add support for defaulted methods
1 parent 90c80bd commit fbe0eb2

File tree

2 files changed

+106
-15
lines changed

2 files changed

+106
-15
lines changed

trait-variant/examples/variant.rs

+16
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,25 @@ pub trait LocalIntFactory {
1717
Self: 'a;
1818

1919
async fn make(&self, x: u32, y: &str) -> i32;
20+
async fn make_mut(&mut self);
2021
fn stream(&self) -> impl Iterator<Item = i32>;
2122
fn call(&self) -> u32;
2223
fn another_async(&self, input: Result<(), &str>) -> Self::MyFut<'_>;
24+
async fn defaulted(&self) -> i32 {
25+
self.make(10, "10").await
26+
}
27+
async fn defaulted_mut(&mut self) -> i32 {
28+
self.make(10, "10").await
29+
}
30+
async fn defaulted_mut_2(&mut self) {
31+
self.make_mut().await
32+
}
33+
async fn defaulted_move(self) -> i32
34+
where
35+
Self: Sized,
36+
{
37+
self.make(10, "10").await
38+
}
2339
}
2440

2541
#[allow(dead_code)]

trait-variant/src/variant.rs

+90-15
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,16 @@
88

99
use std::iter;
1010

11-
use proc_macro2::TokenStream;
11+
use proc_macro2::{Span, TokenStream};
1212
use quote::{quote, ToTokens};
1313
use syn::{
1414
parse::{Parse, ParseStream},
15-
parse_macro_input,
15+
parse_macro_input, parse_quote,
1616
punctuated::Punctuated,
1717
token::{Comma, Plus},
18-
Error, FnArg, GenericParam, Generics, Ident, ItemTrait, Lifetime, Pat, PatType, Result,
19-
ReturnType, Signature, Token, TraitBound, TraitItem, TraitItemConst, TraitItemFn,
20-
TraitItemType, Type, TypeImplTrait, TypeParamBound,
18+
Error, FnArg, GenericParam, Generics, Ident, ItemTrait, Lifetime, Pat, PatType, Receiver,
19+
Result, ReturnType, Signature, Token, TraitBound, TraitItem, TraitItemConst, TraitItemFn,
20+
TraitItemType, Type, TypeImplTrait, TypeParamBound, WhereClause,
2121
};
2222

2323
struct Attrs {
@@ -119,10 +119,10 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
119119
// fn stream(&self) -> impl Iterator<Item = i32> + Send;
120120
// fn call(&self) -> u32;
121121
// }
122-
let TraitItem::Fn(fn_item @ TraitItemFn { sig, .. }) = item else {
122+
let TraitItem::Fn(fn_item @ TraitItemFn { sig, default, .. }) = item else {
123123
return item.clone();
124124
};
125-
let (arrow, output) = if sig.asyncness.is_some() {
125+
let (sig, default) = if sig.asyncness.is_some() {
126126
let orig = match &sig.output {
127127
ReturnType::Default => quote! { () },
128128
ReturnType::Type(_, ty) => quote! { #ty },
@@ -134,7 +134,22 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
134134
.chain(bounds.iter().cloned())
135135
.collect(),
136136
});
137-
(syn::parse2(quote! { -> }).unwrap(), ty)
137+
let mut sig = sig.clone();
138+
if default.is_some() {
139+
add_receiver_bounds(&mut sig);
140+
}
141+
142+
(
143+
Signature {
144+
asyncness: None,
145+
output: ReturnType::Type(syn::parse2(quote! { -> }).unwrap(), Box::new(ty)),
146+
..sig.clone()
147+
},
148+
fn_item
149+
.default
150+
.as_ref()
151+
.map(|b| syn::parse2(quote! { { async move #b } }).unwrap()),
152+
)
138153
} else {
139154
match &sig.output {
140155
ReturnType::Type(arrow, ty) => match &**ty {
@@ -143,19 +158,22 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
143158
impl_token: it.impl_token,
144159
bounds: it.bounds.iter().chain(bounds).cloned().collect(),
145160
});
146-
(*arrow, ty)
161+
(
162+
Signature {
163+
output: ReturnType::Type(*arrow, Box::new(ty)),
164+
..sig.clone()
165+
},
166+
fn_item.default.clone(),
167+
)
147168
}
148169
_ => return item.clone(),
149170
},
150171
ReturnType::Default => return item.clone(),
151172
}
152173
};
153174
TraitItem::Fn(TraitItemFn {
154-
sig: Signature {
155-
asyncness: None,
156-
output: ReturnType::Type(arrow, Box::new(output)),
157-
..sig.clone()
158-
},
175+
sig,
176+
default,
159177
..fn_item.clone()
160178
})
161179
}
@@ -184,7 +202,26 @@ fn mk_blanket_impl(attrs: &Attrs, tr: &ItemTrait) -> TokenStream {
184202
.items
185203
.iter()
186204
.map(|item| blanket_impl_item(item, variant, &generic_names));
187-
let where_clauses = tr.generics.where_clause.as_ref().map(|wh| &wh.predicates);
205+
let mut where_clauses = tr
206+
.generics
207+
.where_clause
208+
.as_ref()
209+
.map(|wh| wh.predicates.clone())
210+
.unwrap_or_default();
211+
let self_is_sync = tr.items.iter().any(|item| {
212+
matches!(
213+
item,
214+
TraitItem::Fn(TraitItemFn {
215+
default: Some(_),
216+
..
217+
})
218+
)
219+
});
220+
221+
if self_is_sync {
222+
where_clauses.push(parse_quote! { Self: Sync });
223+
}
224+
188225
quote! {
189226
impl<#generics #trailing_comma TraitVariantBlanketType> #orig<#generic_names>
190227
for TraitVariantBlanketType
@@ -249,6 +286,7 @@ fn blanket_impl_item(
249286
} else {
250287
quote! {}
251288
};
289+
252290
quote! {
253291
#sig {
254292
<Self as #variant<#generic_names>>::#ident(#(#args),*)#maybe_await
@@ -272,3 +310,40 @@ fn blanket_impl_item(
272310
_ => Error::new_spanned(item, "unsupported item type").into_compile_error(),
273311
}
274312
}
313+
314+
fn add_receiver_bounds(sig: &mut Signature) {
315+
if let Some(FnArg::Receiver(Receiver { ty, reference, .. })) = sig.inputs.first_mut() {
316+
let predicate =
317+
if let (Type::Reference(reference), Some((_and, lt))) = (&mut **ty, reference) {
318+
let lifetime = syn::Lifetime {
319+
apostrophe: Span::mixed_site(),
320+
ident: Ident::new("the_self_lt", Span::mixed_site()),
321+
};
322+
sig.generics.params.insert(
323+
0,
324+
syn::GenericParam::Lifetime(syn::LifetimeParam {
325+
lifetime: lifetime.clone(),
326+
colon_token: None,
327+
bounds: Default::default(),
328+
attrs: Default::default(),
329+
}),
330+
);
331+
reference.lifetime = Some(lifetime.clone());
332+
let predicate = parse_quote! { #reference: Send };
333+
*lt = Some(lifetime);
334+
predicate
335+
} else {
336+
parse_quote! { #ty: Send }
337+
};
338+
339+
if let Some(wh) = &mut sig.generics.where_clause {
340+
wh.predicates.push(predicate);
341+
} else {
342+
let where_clause = WhereClause {
343+
where_token: Token![where](Span::mixed_site()),
344+
predicates: Punctuated::from_iter([predicate]),
345+
};
346+
sig.generics.where_clause = Some(where_clause);
347+
}
348+
}
349+
}

0 commit comments

Comments
 (0)