Skip to content

Commit 085a311

Browse files
committed
Auto merge of #15000 - lowr:fix/builtin-derive-bound-for-assoc, r=HKalbasi
fix: only generate trait bound for associated types in field types Given the following definitions: ```rust trait Trait { type A; type B; type C; } #[derive(Clone)] struct S<T: Trait> where T::A: Send, { qualified: <T as Trait>::B, shorthand: T::C, } ``` we currently expand the derive macro to: ```rust impl<T> Clone for S<T> where T: Trait + Clone, T::A: Clone, T::B: Clone, T::C: Clone, { /* ... */ } ``` This does not match how rustc expands it. Specifically, `Clone` bounds for `T::A` and `T::B` should not be generated. The criteria for associated types to get bound seem to be 1) the associated type appears as part of field types AND 2) it's written in the shorthand form. I have no idea why rustc doesn't consider qualified associated types (there's even a comment that suggests they should be considered; see rust-lang/rust#50730), but it's important to follow rustc.
2 parents 1c25885 + 4f0c6fa commit 085a311

File tree

3 files changed

+139
-71
lines changed

3 files changed

+139
-71
lines changed

crates/hir-def/src/macro_expansion_tests/builtin_derive_macro.rs

+60
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,66 @@ impl <A: core::clone::Clone, B: core::clone::Clone, > core::clone::Clone for Com
114114
);
115115
}
116116

117+
#[test]
118+
fn test_clone_expand_with_associated_types() {
119+
check(
120+
r#"
121+
//- minicore: derive, clone
122+
trait Trait {
123+
type InWc;
124+
type InFieldQualified;
125+
type InFieldShorthand;
126+
type InGenericArg;
127+
}
128+
trait Marker {}
129+
struct Vec<T>(T);
130+
131+
#[derive(Clone)]
132+
struct Foo<T: Trait>
133+
where
134+
<T as Trait>::InWc: Marker,
135+
{
136+
qualified: <T as Trait>::InFieldQualified,
137+
shorthand: T::InFieldShorthand,
138+
generic: Vec<T::InGenericArg>,
139+
}
140+
"#,
141+
expect![[r#"
142+
trait Trait {
143+
type InWc;
144+
type InFieldQualified;
145+
type InFieldShorthand;
146+
type InGenericArg;
147+
}
148+
trait Marker {}
149+
struct Vec<T>(T);
150+
151+
#[derive(Clone)]
152+
struct Foo<T: Trait>
153+
where
154+
<T as Trait>::InWc: Marker,
155+
{
156+
qualified: <T as Trait>::InFieldQualified,
157+
shorthand: T::InFieldShorthand,
158+
generic: Vec<T::InGenericArg>,
159+
}
160+
161+
impl <T: core::clone::Clone, > core::clone::Clone for Foo<T, > where T: Trait, T::InFieldShorthand: core::clone::Clone, T::InGenericArg: core::clone::Clone, {
162+
fn clone(&self ) -> Self {
163+
match self {
164+
Foo {
165+
qualified: qualified, shorthand: shorthand, generic: generic,
166+
}
167+
=>Foo {
168+
qualified: qualified.clone(), shorthand: shorthand.clone(), generic: generic.clone(),
169+
}
170+
,
171+
}
172+
}
173+
}"#]],
174+
);
175+
}
176+
117177
#[test]
118178
fn test_clone_expand_with_const_generics() {
119179
check(

crates/hir-expand/src/builtin_derive_macro.rs

+74-66
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,16 @@ use ::tt::Ident;
44
use base_db::{CrateOrigin, LangCrateOrigin};
55
use itertools::izip;
66
use mbe::TokenMap;
7-
use std::collections::HashSet;
7+
use rustc_hash::FxHashSet;
88
use stdx::never;
99
use tracing::debug;
1010

11-
use crate::tt::{self, TokenId};
12-
use syntax::{
13-
ast::{
14-
self, AstNode, FieldList, HasAttrs, HasGenericParams, HasModuleItem, HasName,
15-
HasTypeBounds, PathType,
16-
},
17-
match_ast,
11+
use crate::{
12+
name::{AsName, Name},
13+
tt::{self, TokenId},
14+
};
15+
use syntax::ast::{
16+
self, AstNode, FieldList, HasAttrs, HasGenericParams, HasModuleItem, HasName, HasTypeBounds,
1817
};
1918

2019
use crate::{db::ExpandDatabase, name, quote, ExpandError, ExpandResult, MacroCallId};
@@ -201,41 +200,54 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
201200
debug!("no module item parsed");
202201
ExpandError::Other("no item found".into())
203202
})?;
204-
let node = item.syntax();
205-
let (name, params, shape) = match_ast! {
206-
match node {
207-
ast::Struct(it) => (it.name(), it.generic_param_list(), AdtShape::Struct(VariantShape::from(it.field_list(), &token_map)?)),
208-
ast::Enum(it) => {
209-
let default_variant = it.variant_list().into_iter().flat_map(|x| x.variants()).position(|x| x.attrs().any(|x| x.simple_name() == Some("default".into())));
210-
(
211-
it.name(),
212-
it.generic_param_list(),
213-
AdtShape::Enum {
214-
default_variant,
215-
variants: it.variant_list()
216-
.into_iter()
217-
.flat_map(|x| x.variants())
218-
.map(|x| Ok((name_to_token(&token_map,x.name())?, VariantShape::from(x.field_list(), &token_map)?))).collect::<Result<_, ExpandError>>()?
219-
}
220-
)
221-
},
222-
ast::Union(it) => (it.name(), it.generic_param_list(), AdtShape::Union),
223-
_ => {
224-
debug!("unexpected node is {:?}", node);
225-
return Err(ExpandError::Other("expected struct, enum or union".into()))
226-
},
203+
let adt = ast::Adt::cast(item.syntax().clone()).ok_or_else(|| {
204+
debug!("expected adt, found: {:?}", item);
205+
ExpandError::Other("expected struct, enum or union".into())
206+
})?;
207+
let (name, generic_param_list, shape) = match &adt {
208+
ast::Adt::Struct(it) => (
209+
it.name(),
210+
it.generic_param_list(),
211+
AdtShape::Struct(VariantShape::from(it.field_list(), &token_map)?),
212+
),
213+
ast::Adt::Enum(it) => {
214+
let default_variant = it
215+
.variant_list()
216+
.into_iter()
217+
.flat_map(|x| x.variants())
218+
.position(|x| x.attrs().any(|x| x.simple_name() == Some("default".into())));
219+
(
220+
it.name(),
221+
it.generic_param_list(),
222+
AdtShape::Enum {
223+
default_variant,
224+
variants: it
225+
.variant_list()
226+
.into_iter()
227+
.flat_map(|x| x.variants())
228+
.map(|x| {
229+
Ok((
230+
name_to_token(&token_map, x.name())?,
231+
VariantShape::from(x.field_list(), &token_map)?,
232+
))
233+
})
234+
.collect::<Result<_, ExpandError>>()?,
235+
},
236+
)
227237
}
238+
ast::Adt::Union(it) => (it.name(), it.generic_param_list(), AdtShape::Union),
228239
};
229-
let mut param_type_set: HashSet<String> = HashSet::new();
230-
let param_types = params
240+
241+
let mut param_type_set: FxHashSet<Name> = FxHashSet::default();
242+
let param_types = generic_param_list
231243
.into_iter()
232244
.flat_map(|param_list| param_list.type_or_const_params())
233245
.map(|param| {
234246
let name = {
235247
let this = param.name();
236248
match this {
237249
Some(x) => {
238-
param_type_set.insert(x.to_string());
250+
param_type_set.insert(x.as_name());
239251
mbe::syntax_node_to_token_tree(x.syntax()).0
240252
}
241253
None => tt::Subtree::empty(),
@@ -259,37 +271,33 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
259271
(name, ty, bounds)
260272
})
261273
.collect();
262-
let is_associated_type = |p: &PathType| {
263-
if let Some(p) = p.path() {
264-
if let Some(parent) = p.qualifier() {
265-
if let Some(x) = parent.segment() {
266-
if let Some(x) = x.path_type() {
267-
if let Some(x) = x.path() {
268-
if let Some(pname) = x.as_single_name_ref() {
269-
if param_type_set.contains(&pname.to_string()) {
270-
// <T as Trait>::Assoc
271-
return true;
272-
}
273-
}
274-
}
275-
}
276-
}
277-
if let Some(pname) = parent.as_single_name_ref() {
278-
if param_type_set.contains(&pname.to_string()) {
279-
// T::Assoc
280-
return true;
281-
}
282-
}
283-
}
284-
}
285-
false
274+
275+
// For a generic parameter `T`, when shorthand associated type `T::Assoc` appears in field
276+
// types (of any variant for enums), we generate trait bound for it. It sounds reasonable to
277+
// also generate trait bound for qualified associated type `<T as Trait>::Assoc`, but rustc
278+
// does not do that for some unknown reason.
279+
//
280+
// See the analogous function in rustc [find_type_parameters()] and rust-lang/rust#50730.
281+
// [find_type_parameters()]: https://github.com/rust-lang/rust/blob/1.70.0/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs#L378
282+
283+
// It's cumbersome to deal with the distinct structures of ADTs, so let's just get untyped
284+
// `SyntaxNode` that contains fields and look for descendant `ast::PathType`s. Of note is that
285+
// we should not inspect `ast::PathType`s in parameter bounds and where clauses.
286+
let field_list = match adt {
287+
ast::Adt::Enum(it) => it.variant_list().map(|list| list.syntax().clone()),
288+
ast::Adt::Struct(it) => it.field_list().map(|list| list.syntax().clone()),
289+
ast::Adt::Union(it) => it.record_field_list().map(|list| list.syntax().clone()),
286290
};
287-
let associated_types = node
288-
.descendants()
289-
.filter_map(PathType::cast)
290-
.filter(is_associated_type)
291+
let associated_types = field_list
292+
.into_iter()
293+
.flat_map(|it| it.descendants())
294+
.filter_map(ast::PathType::cast)
295+
.filter_map(|p| {
296+
let name = p.path()?.qualifier()?.as_single_name_ref()?.as_name();
297+
param_type_set.contains(&name).then_some(p)
298+
})
291299
.map(|x| mbe::syntax_node_to_token_tree(x.syntax()).0)
292-
.collect::<Vec<_>>();
300+
.collect();
293301
let name_token = name_to_token(&token_map, name)?;
294302
Ok(BasicAdtInfo { name: name_token, shape, param_types, associated_types })
295303
}
@@ -334,18 +342,18 @@ fn name_to_token(token_map: &TokenMap, name: Option<ast::Name>) -> Result<tt::Id
334342
/// }
335343
/// ```
336344
///
337-
/// where B1, ..., BN are the bounds given by `bounds_paths`.'. Z is a phantom type, and
345+
/// where B1, ..., BN are the bounds given by `bounds_paths`. Z is a phantom type, and
338346
/// therefore does not get bound by the derived trait.
339347
fn expand_simple_derive(
340348
tt: &tt::Subtree,
341349
trait_path: tt::Subtree,
342-
trait_body: impl FnOnce(&BasicAdtInfo) -> tt::Subtree,
350+
make_trait_body: impl FnOnce(&BasicAdtInfo) -> tt::Subtree,
343351
) -> ExpandResult<tt::Subtree> {
344352
let info = match parse_adt(tt) {
345353
Ok(info) => info,
346354
Err(e) => return ExpandResult::new(tt::Subtree::empty(), e),
347355
};
348-
let trait_body = trait_body(&info);
356+
let trait_body = make_trait_body(&info);
349357
let mut where_block = vec![];
350358
let (params, args): (Vec<_>, Vec<_>) = info
351359
.param_types

crates/hir-ty/src/tests/traits.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -4335,8 +4335,9 @@ fn derive_macro_bounds() {
43354335
#[derive(Clone)]
43364336
struct AssocGeneric<T: Tr>(T::Assoc);
43374337
4338-
#[derive(Clone)]
4339-
struct AssocGeneric2<T: Tr>(<T as Tr>::Assoc);
4338+
// Currently rustc does not accept this.
4339+
// #[derive(Clone)]
4340+
// struct AssocGeneric2<T: Tr>(<T as Tr>::Assoc);
43404341
43414342
#[derive(Clone)]
43424343
struct AssocGeneric3<T: Tr>(Generic<T::Assoc>);
@@ -4361,9 +4362,8 @@ fn derive_macro_bounds() {
43614362
let x: &AssocGeneric<Copy> = &AssocGeneric(NotCopy);
43624363
let x = x.clone();
43634364
//^ &AssocGeneric<Copy>
4364-
let x: &AssocGeneric2<Copy> = &AssocGeneric2(NotCopy);
4365-
let x = x.clone();
4366-
//^ &AssocGeneric2<Copy>
4365+
// let x: &AssocGeneric2<Copy> = &AssocGeneric2(NotCopy);
4366+
// let x = x.clone();
43674367
let x: &AssocGeneric3<Copy> = &AssocGeneric3(Generic(NotCopy));
43684368
let x = x.clone();
43694369
//^ &AssocGeneric3<Copy>

0 commit comments

Comments
 (0)