Skip to content

Commit de24caa

Browse files
committed
Add support for defaulted methods
1 parent 90c80bd commit de24caa

File tree

2 files changed

+113
-15
lines changed

2 files changed

+113
-15
lines changed

trait-variant/examples/variant.rs

+16
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,25 @@ pub trait LocalIntFactory {
1717
Self: 'a;
1818

1919
async fn make(&self, x: u32, y: &str) -> i32;
20+
async fn make_mut(&mut self);
2021
fn stream(&self) -> impl Iterator<Item = i32>;
2122
fn call(&self) -> u32;
2223
fn another_async(&self, input: Result<(), &str>) -> Self::MyFut<'_>;
24+
async fn defaulted(&self) -> i32 {
25+
self.make(10, "10").await
26+
}
27+
async fn defaulted_mut(&mut self) -> i32 {
28+
self.make(10, "10").await
29+
}
30+
async fn defaulted_mut_2(&mut self) {
31+
self.make_mut().await
32+
}
33+
async fn defaulted_move(self) -> i32
34+
where
35+
Self: Sized,
36+
{
37+
self.make(10, "10").await
38+
}
2339
}
2440

2541
#[allow(dead_code)]

trait-variant/src/variant.rs

+97-15
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,16 @@
88

99
use std::iter;
1010

11-
use proc_macro2::TokenStream;
11+
use proc_macro2::{Span, TokenStream};
1212
use quote::{quote, ToTokens};
1313
use syn::{
1414
parse::{Parse, ParseStream},
15-
parse_macro_input,
15+
parse_macro_input, parse_quote,
1616
punctuated::Punctuated,
1717
token::{Comma, Plus},
18-
Error, FnArg, GenericParam, Generics, Ident, ItemTrait, Lifetime, Pat, PatType, Result,
19-
ReturnType, Signature, Token, TraitBound, TraitItem, TraitItemConst, TraitItemFn,
20-
TraitItemType, Type, TypeImplTrait, TypeParamBound,
18+
Error, FnArg, GenericParam, Generics, Ident, ItemTrait, Lifetime, Pat, PatType, Receiver,
19+
Result, ReturnType, Signature, Token, TraitBound, TraitItem, TraitItemConst, TraitItemFn,
20+
TraitItemType, Type, TypeImplTrait, TypeParamBound, TypeReference, WhereClause,
2121
};
2222

2323
struct Attrs {
@@ -119,10 +119,10 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
119119
// fn stream(&self) -> impl Iterator<Item = i32> + Send;
120120
// fn call(&self) -> u32;
121121
// }
122-
let TraitItem::Fn(fn_item @ TraitItemFn { sig, .. }) = item else {
122+
let TraitItem::Fn(fn_item @ TraitItemFn { sig, default, .. }) = item else {
123123
return item.clone();
124124
};
125-
let (arrow, output) = if sig.asyncness.is_some() {
125+
let (sig, default) = if sig.asyncness.is_some() {
126126
let orig = match &sig.output {
127127
ReturnType::Default => quote! { () },
128128
ReturnType::Type(_, ty) => quote! { #ty },
@@ -134,7 +134,22 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
134134
.chain(bounds.iter().cloned())
135135
.collect(),
136136
});
137-
(syn::parse2(quote! { -> }).unwrap(), ty)
137+
let mut sig = sig.clone();
138+
if default.is_some() {
139+
add_receiver_bounds(&mut sig);
140+
}
141+
142+
(
143+
Signature {
144+
asyncness: None,
145+
output: ReturnType::Type(syn::parse2(quote! { -> }).unwrap(), Box::new(ty)),
146+
..sig.clone()
147+
},
148+
fn_item
149+
.default
150+
.as_ref()
151+
.map(|b| syn::parse2(quote! { { async move #b } }).unwrap()),
152+
)
138153
} else {
139154
match &sig.output {
140155
ReturnType::Type(arrow, ty) => match &**ty {
@@ -143,19 +158,22 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
143158
impl_token: it.impl_token,
144159
bounds: it.bounds.iter().chain(bounds).cloned().collect(),
145160
});
146-
(*arrow, ty)
161+
(
162+
Signature {
163+
output: ReturnType::Type(*arrow, Box::new(ty)),
164+
..sig.clone()
165+
},
166+
fn_item.default.clone(),
167+
)
147168
}
148169
_ => return item.clone(),
149170
},
150171
ReturnType::Default => return item.clone(),
151172
}
152173
};
153174
TraitItem::Fn(TraitItemFn {
154-
sig: Signature {
155-
asyncness: None,
156-
output: ReturnType::Type(arrow, Box::new(output)),
157-
..sig.clone()
158-
},
175+
sig,
176+
default,
159177
..fn_item.clone()
160178
})
161179
}
@@ -184,7 +202,26 @@ fn mk_blanket_impl(attrs: &Attrs, tr: &ItemTrait) -> TokenStream {
184202
.items
185203
.iter()
186204
.map(|item| blanket_impl_item(item, variant, &generic_names));
187-
let where_clauses = tr.generics.where_clause.as_ref().map(|wh| &wh.predicates);
205+
let mut where_clauses = tr
206+
.generics
207+
.where_clause
208+
.as_ref()
209+
.map(|wh| wh.predicates.clone())
210+
.unwrap_or_default();
211+
let self_is_sync = tr.items.iter().any(|item| {
212+
matches!(
213+
item,
214+
TraitItem::Fn(TraitItemFn {
215+
default: Some(_),
216+
..
217+
})
218+
)
219+
});
220+
221+
if self_is_sync {
222+
where_clauses.push(parse_quote! { for<'s> &'s Self: Send });
223+
}
224+
188225
quote! {
189226
impl<#generics #trailing_comma TraitVariantBlanketType> #orig<#generic_names>
190227
for TraitVariantBlanketType
@@ -249,6 +286,7 @@ fn blanket_impl_item(
249286
} else {
250287
quote! {}
251288
};
289+
252290
quote! {
253291
#sig {
254292
<Self as #variant<#generic_names>>::#ident(#(#args),*)#maybe_await
@@ -272,3 +310,47 @@ fn blanket_impl_item(
272310
_ => Error::new_spanned(item, "unsupported item type").into_compile_error(),
273311
}
274312
}
313+
314+
fn add_receiver_bounds(sig: &mut Signature) {
315+
let Some(FnArg::Receiver(Receiver { ty, reference, .. })) = sig.inputs.first_mut() else {
316+
return;
317+
};
318+
let Type::Reference(
319+
recv_ty @ TypeReference {
320+
mutability: None, ..
321+
},
322+
) = &mut **ty
323+
else {
324+
return;
325+
};
326+
let Some((_and, lt)) = reference else {
327+
return;
328+
};
329+
330+
let lifetime = syn::Lifetime {
331+
apostrophe: Span::mixed_site(),
332+
ident: Ident::new("the_self_lt", Span::mixed_site()),
333+
};
334+
sig.generics.params.insert(
335+
0,
336+
syn::GenericParam::Lifetime(syn::LifetimeParam {
337+
lifetime: lifetime.clone(),
338+
colon_token: None,
339+
bounds: Default::default(),
340+
attrs: Default::default(),
341+
}),
342+
);
343+
recv_ty.lifetime = Some(lifetime.clone());
344+
*lt = Some(lifetime);
345+
let predicate = parse_quote! { #recv_ty: Send };
346+
347+
if let Some(wh) = &mut sig.generics.where_clause {
348+
wh.predicates.push(predicate);
349+
} else {
350+
let where_clause = WhereClause {
351+
where_token: Token![where](Span::mixed_site()),
352+
predicates: Punctuated::from_iter([predicate]),
353+
};
354+
sig.generics.where_clause = Some(where_clause);
355+
}
356+
}

0 commit comments

Comments
 (0)