@@ -4,17 +4,16 @@ use ::tt::Ident;
4
4
use base_db:: { CrateOrigin , LangCrateOrigin } ;
5
5
use itertools:: izip;
6
6
use mbe:: TokenMap ;
7
- use std :: collections :: HashSet ;
7
+ use rustc_hash :: FxHashSet ;
8
8
use stdx:: never;
9
9
use tracing:: debug;
10
10
11
- use crate :: tt:: { self , TokenId } ;
12
- use syntax:: {
13
- ast:: {
14
- self , AstNode , FieldList , HasAttrs , HasGenericParams , HasModuleItem , HasName ,
15
- HasTypeBounds , PathType ,
16
- } ,
17
- match_ast,
11
+ use crate :: {
12
+ name:: { AsName , Name } ,
13
+ tt:: { self , TokenId } ,
14
+ } ;
15
+ use syntax:: ast:: {
16
+ self , AstNode , FieldList , HasAttrs , HasGenericParams , HasModuleItem , HasName , HasTypeBounds ,
18
17
} ;
19
18
20
19
use crate :: { db:: ExpandDatabase , name, quote, ExpandError , ExpandResult , MacroCallId } ;
@@ -201,41 +200,54 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
201
200
debug ! ( "no module item parsed" ) ;
202
201
ExpandError :: Other ( "no item found" . into ( ) )
203
202
} ) ?;
204
- let node = item. syntax ( ) ;
205
- let ( name, params, shape) = match_ast ! {
206
- match node {
207
- ast:: Struct ( it) => ( it. name( ) , it. generic_param_list( ) , AdtShape :: Struct ( VariantShape :: from( it. field_list( ) , & token_map) ?) ) ,
208
- ast:: Enum ( it) => {
209
- let default_variant = it. variant_list( ) . into_iter( ) . flat_map( |x| x. variants( ) ) . position( |x| x. attrs( ) . any( |x| x. simple_name( ) == Some ( "default" . into( ) ) ) ) ;
210
- (
211
- it. name( ) ,
212
- it. generic_param_list( ) ,
213
- AdtShape :: Enum {
214
- default_variant,
215
- variants: it. variant_list( )
216
- . into_iter( )
217
- . flat_map( |x| x. variants( ) )
218
- . map( |x| Ok ( ( name_to_token( & token_map, x. name( ) ) ?, VariantShape :: from( x. field_list( ) , & token_map) ?) ) ) . collect:: <Result <_, ExpandError >>( ) ?
219
- }
220
- )
221
- } ,
222
- ast:: Union ( it) => ( it. name( ) , it. generic_param_list( ) , AdtShape :: Union ) ,
223
- _ => {
224
- debug!( "unexpected node is {:?}" , node) ;
225
- return Err ( ExpandError :: Other ( "expected struct, enum or union" . into( ) ) )
226
- } ,
203
+ let adt = ast:: Adt :: cast ( item. syntax ( ) . clone ( ) ) . ok_or_else ( || {
204
+ debug ! ( "expected adt, found: {:?}" , item) ;
205
+ ExpandError :: Other ( "expected struct, enum or union" . into ( ) )
206
+ } ) ?;
207
+ let ( name, generic_param_list, shape) = match & adt {
208
+ ast:: Adt :: Struct ( it) => (
209
+ it. name ( ) ,
210
+ it. generic_param_list ( ) ,
211
+ AdtShape :: Struct ( VariantShape :: from ( it. field_list ( ) , & token_map) ?) ,
212
+ ) ,
213
+ ast:: Adt :: Enum ( it) => {
214
+ let default_variant = it
215
+ . variant_list ( )
216
+ . into_iter ( )
217
+ . flat_map ( |x| x. variants ( ) )
218
+ . position ( |x| x. attrs ( ) . any ( |x| x. simple_name ( ) == Some ( "default" . into ( ) ) ) ) ;
219
+ (
220
+ it. name ( ) ,
221
+ it. generic_param_list ( ) ,
222
+ AdtShape :: Enum {
223
+ default_variant,
224
+ variants : it
225
+ . variant_list ( )
226
+ . into_iter ( )
227
+ . flat_map ( |x| x. variants ( ) )
228
+ . map ( |x| {
229
+ Ok ( (
230
+ name_to_token ( & token_map, x. name ( ) ) ?,
231
+ VariantShape :: from ( x. field_list ( ) , & token_map) ?,
232
+ ) )
233
+ } )
234
+ . collect :: < Result < _ , ExpandError > > ( ) ?,
235
+ } ,
236
+ )
227
237
}
238
+ ast:: Adt :: Union ( it) => ( it. name ( ) , it. generic_param_list ( ) , AdtShape :: Union ) ,
228
239
} ;
229
- let mut param_type_set: HashSet < String > = HashSet :: new ( ) ;
230
- let param_types = params
240
+
241
+ let mut param_type_set: FxHashSet < Name > = FxHashSet :: default ( ) ;
242
+ let param_types = generic_param_list
231
243
. into_iter ( )
232
244
. flat_map ( |param_list| param_list. type_or_const_params ( ) )
233
245
. map ( |param| {
234
246
let name = {
235
247
let this = param. name ( ) ;
236
248
match this {
237
249
Some ( x) => {
238
- param_type_set. insert ( x. to_string ( ) ) ;
250
+ param_type_set. insert ( x. as_name ( ) ) ;
239
251
mbe:: syntax_node_to_token_tree ( x. syntax ( ) ) . 0
240
252
}
241
253
None => tt:: Subtree :: empty ( ) ,
@@ -259,37 +271,33 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
259
271
( name, ty, bounds)
260
272
} )
261
273
. collect ( ) ;
262
- let is_associated_type = |p : & PathType | {
263
- if let Some ( p) = p. path ( ) {
264
- if let Some ( parent) = p. qualifier ( ) {
265
- if let Some ( x) = parent. segment ( ) {
266
- if let Some ( x) = x. path_type ( ) {
267
- if let Some ( x) = x. path ( ) {
268
- if let Some ( pname) = x. as_single_name_ref ( ) {
269
- if param_type_set. contains ( & pname. to_string ( ) ) {
270
- // <T as Trait>::Assoc
271
- return true ;
272
- }
273
- }
274
- }
275
- }
276
- }
277
- if let Some ( pname) = parent. as_single_name_ref ( ) {
278
- if param_type_set. contains ( & pname. to_string ( ) ) {
279
- // T::Assoc
280
- return true ;
281
- }
282
- }
283
- }
284
- }
285
- false
274
+
275
+ // For a generic parameter `T`, when shorthand associated type `T::Assoc` appears in field
276
+ // types (of any variant for enums), we generate trait bound for it. It sounds reasonable to
277
+ // also generate trait bound for qualified associated type `<T as Trait>::Assoc`, but rustc
278
+ // does not do that for some unknown reason.
279
+ //
280
+ // See the analogous function in rustc [find_type_parameters()] and rust-lang/rust#50730.
281
+ // [find_type_parameters()]: https://github.com/rust-lang/rust/blob/1.70.0/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs#L378
282
+
283
+ // It's cumbersome to deal with the distinct structures of ADTs, so let's just get untyped
284
+ // `SyntaxNode` that contains fields and look for descendant `ast::PathType`s. Of note is that
285
+ // we should not inspect `ast::PathType`s in parameter bounds and where clauses.
286
+ let field_list = match adt {
287
+ ast:: Adt :: Enum ( it) => it. variant_list ( ) . map ( |list| list. syntax ( ) . clone ( ) ) ,
288
+ ast:: Adt :: Struct ( it) => it. field_list ( ) . map ( |list| list. syntax ( ) . clone ( ) ) ,
289
+ ast:: Adt :: Union ( it) => it. record_field_list ( ) . map ( |list| list. syntax ( ) . clone ( ) ) ,
286
290
} ;
287
- let associated_types = node
288
- . descendants ( )
289
- . filter_map ( PathType :: cast)
290
- . filter ( is_associated_type)
291
+ let associated_types = field_list
292
+ . into_iter ( )
293
+ . flat_map ( |it| it. descendants ( ) )
294
+ . filter_map ( ast:: PathType :: cast)
295
+ . filter_map ( |p| {
296
+ let name = p. path ( ) ?. qualifier ( ) ?. as_single_name_ref ( ) ?. as_name ( ) ;
297
+ param_type_set. contains ( & name) . then_some ( p)
298
+ } )
291
299
. map ( |x| mbe:: syntax_node_to_token_tree ( x. syntax ( ) ) . 0 )
292
- . collect :: < Vec < _ > > ( ) ;
300
+ . collect ( ) ;
293
301
let name_token = name_to_token ( & token_map, name) ?;
294
302
Ok ( BasicAdtInfo { name : name_token, shape, param_types, associated_types } )
295
303
}
@@ -334,18 +342,18 @@ fn name_to_token(token_map: &TokenMap, name: Option<ast::Name>) -> Result<tt::Id
334
342
/// }
335
343
/// ```
336
344
///
337
- /// where B1, ..., BN are the bounds given by `bounds_paths`.'. Z is a phantom type, and
345
+ /// where B1, ..., BN are the bounds given by `bounds_paths`. Z is a phantom type, and
338
346
/// therefore does not get bound by the derived trait.
339
347
fn expand_simple_derive (
340
348
tt : & tt:: Subtree ,
341
349
trait_path : tt:: Subtree ,
342
- trait_body : impl FnOnce ( & BasicAdtInfo ) -> tt:: Subtree ,
350
+ make_trait_body : impl FnOnce ( & BasicAdtInfo ) -> tt:: Subtree ,
343
351
) -> ExpandResult < tt:: Subtree > {
344
352
let info = match parse_adt ( tt) {
345
353
Ok ( info) => info,
346
354
Err ( e) => return ExpandResult :: new ( tt:: Subtree :: empty ( ) , e) ,
347
355
} ;
348
- let trait_body = trait_body ( & info) ;
356
+ let trait_body = make_trait_body ( & info) ;
349
357
let mut where_block = vec ! [ ] ;
350
358
let ( params, args) : ( Vec < _ > , Vec < _ > ) = info
351
359
. param_types
0 commit comments