Skip to content

Commit c8066eb

Browse files
Merge #8201
8201: Fix recursive macro statements expansion r=edwin0cheng a=edwin0cheng This PR attempts to properly handle macro statement expansion by implementing the following: 1. Merge macro expanded statements to parent scope statements. 2. Add a new hir `Expr::MacroStmts` for handle tail expression infer. PS : The scope of macro expanded statements are so strange that it took more time than I thought to understand and implement it :( Fixes #8171 Co-authored-by: Edwin Cheng <[email protected]>
2 parents 9c9376c + 8ce15b0 commit c8066eb

File tree

7 files changed

+119
-70
lines changed

7 files changed

+119
-70
lines changed

crates/hir_def/src/body/lower.rs

Lines changed: 33 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ pub(super) fn lower(
7474
_c: Count::new(),
7575
},
7676
expander,
77+
statements_in_scope: Vec::new(),
7778
}
7879
.collect(params, body)
7980
}
@@ -83,6 +84,7 @@ struct ExprCollector<'a> {
8384
expander: Expander,
8485
body: Body,
8586
source_map: BodySourceMap,
87+
statements_in_scope: Vec<Statement>,
8688
}
8789

8890
impl ExprCollector<'_> {
@@ -533,15 +535,13 @@ impl ExprCollector<'_> {
533535
ids[0]
534536
}
535537
ast::Expr::MacroStmts(e) => {
536-
// FIXME: these statements should be held by some hir containter
537-
for stmt in e.statements() {
538-
self.collect_stmt(stmt);
539-
}
540-
if let Some(expr) = e.expr() {
541-
self.collect_expr(expr)
542-
} else {
543-
self.alloc_expr(Expr::Missing, syntax_ptr)
544-
}
538+
e.statements().for_each(|s| self.collect_stmt(s));
539+
let tail = e
540+
.expr()
541+
.map(|e| self.collect_expr(e))
542+
.unwrap_or_else(|| self.alloc_expr(Expr::Missing, syntax_ptr.clone()));
543+
544+
self.alloc_expr(Expr::MacroStmts { tail }, syntax_ptr)
545545
}
546546
})
547547
}
@@ -618,58 +618,54 @@ impl ExprCollector<'_> {
618618
}
619619
}
620620

621-
fn collect_stmt(&mut self, s: ast::Stmt) -> Option<Vec<Statement>> {
622-
let stmt = match s {
621+
fn collect_stmt(&mut self, s: ast::Stmt) {
622+
match s {
623623
ast::Stmt::LetStmt(stmt) => {
624-
self.check_cfg(&stmt)?;
625-
624+
if self.check_cfg(&stmt).is_none() {
625+
return;
626+
}
626627
let pat = self.collect_pat_opt(stmt.pat());
627628
let type_ref = stmt.ty().map(|it| TypeRef::from_ast(&self.ctx(), it));
628629
let initializer = stmt.initializer().map(|e| self.collect_expr(e));
629-
vec![Statement::Let { pat, type_ref, initializer }]
630+
self.statements_in_scope.push(Statement::Let { pat, type_ref, initializer });
630631
}
631632
ast::Stmt::ExprStmt(stmt) => {
632-
self.check_cfg(&stmt)?;
633+
if self.check_cfg(&stmt).is_none() {
634+
return;
635+
}
633636

634637
// Note that macro could be expended to multiple statements
635638
if let Some(ast::Expr::MacroCall(m)) = stmt.expr() {
636639
let syntax_ptr = AstPtr::new(&stmt.expr().unwrap());
637-
let mut stmts = vec![];
638640

639641
self.collect_macro_call(m, syntax_ptr.clone(), false, |this, expansion| {
640642
match expansion {
641643
Some(expansion) => {
642644
let statements: ast::MacroStmts = expansion;
643645

644-
statements.statements().for_each(|stmt| {
645-
if let Some(mut r) = this.collect_stmt(stmt) {
646-
stmts.append(&mut r);
647-
}
648-
});
646+
statements.statements().for_each(|stmt| this.collect_stmt(stmt));
649647
if let Some(expr) = statements.expr() {
650-
stmts.push(Statement::Expr(this.collect_expr(expr)));
648+
let expr = this.collect_expr(expr);
649+
this.statements_in_scope.push(Statement::Expr(expr));
651650
}
652651
}
653652
None => {
654-
stmts.push(Statement::Expr(
655-
this.alloc_expr(Expr::Missing, syntax_ptr.clone()),
656-
));
653+
let expr = this.alloc_expr(Expr::Missing, syntax_ptr.clone());
654+
this.statements_in_scope.push(Statement::Expr(expr));
657655
}
658656
}
659657
});
660-
stmts
661658
} else {
662-
vec![Statement::Expr(self.collect_expr_opt(stmt.expr()))]
659+
let expr = self.collect_expr_opt(stmt.expr());
660+
self.statements_in_scope.push(Statement::Expr(expr));
663661
}
664662
}
665663
ast::Stmt::Item(item) => {
666-
self.check_cfg(&item)?;
667-
668-
return None;
664+
if self.check_cfg(&item).is_none() {
665+
return;
666+
}
669667
}
670-
};
671-
672-
Some(stmt)
668+
}
673669
}
674670

675671
fn collect_block(&mut self, block: ast::BlockExpr) -> ExprId {
@@ -685,10 +681,12 @@ impl ExprCollector<'_> {
685681
let module = if has_def_map { def_map.root() } else { self.expander.module };
686682
let prev_def_map = mem::replace(&mut self.expander.def_map, def_map);
687683
let prev_local_module = mem::replace(&mut self.expander.module, module);
684+
let prev_statements = std::mem::take(&mut self.statements_in_scope);
685+
686+
block.statements().for_each(|s| self.collect_stmt(s));
688687

689-
let statements =
690-
block.statements().filter_map(|s| self.collect_stmt(s)).flatten().collect();
691688
let tail = block.tail_expr().map(|e| self.collect_expr(e));
689+
let statements = std::mem::replace(&mut self.statements_in_scope, prev_statements);
692690
let syntax_node_ptr = AstPtr::new(&block.into());
693691
let expr_id = self.alloc_expr(
694692
Expr::Block { id: block_id, statements, tail, label: None },

crates/hir_def/src/expr.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,9 @@ pub enum Expr {
171171
Unsafe {
172172
body: ExprId,
173173
},
174+
MacroStmts {
175+
tail: ExprId,
176+
},
174177
Array(Array),
175178
Literal(Literal),
176179
}
@@ -357,6 +360,7 @@ impl Expr {
357360
f(*repeat)
358361
}
359362
},
363+
Expr::MacroStmts { tail } => f(*tail),
360364
Expr::Literal(_) => {}
361365
}
362366
}

crates/hir_def/src/item_tree.rs

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -110,15 +110,6 @@ impl ItemTree {
110110
// still need to collect inner items.
111111
ctx.lower_inner_items(e.syntax())
112112
},
113-
ast::ExprStmt(stmt) => {
114-
// Macros can expand to stmt. We return an empty item tree in this case, but
115-
// still need to collect inner items.
116-
ctx.lower_inner_items(stmt.syntax())
117-
},
118-
ast::Item(item) => {
119-
// Macros can expand to stmt and other item, and we add it as top level item
120-
ctx.lower_single_item(item)
121-
},
122113
_ => {
123114
panic!("cannot create item tree from {:?} {}", syntax, syntax);
124115
},

crates/hir_def/src/item_tree/lower.rs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,6 @@ impl Ctx {
8787
self.tree
8888
}
8989

90-
pub(super) fn lower_single_item(mut self, item: ast::Item) -> ItemTree {
91-
self.tree.top_level = self
92-
.lower_mod_item(&item, false)
93-
.map(|item| item.0)
94-
.unwrap_or_else(|| Default::default());
95-
self.tree
96-
}
97-
9890
pub(super) fn lower_inner_items(mut self, within: &SyntaxNode) -> ItemTree {
9991
self.collect_inner_items(within);
10092
self.tree

crates/hir_expand/src/db.rs

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,13 @@ use std::sync::Arc;
55
use base_db::{salsa, SourceDatabase};
66
use mbe::{ExpandError, ExpandResult, MacroRules};
77
use parser::FragmentKind;
8-
use syntax::{algo::diff, ast::NameOwner, AstNode, GreenNode, Parse, SyntaxKind::*, SyntaxNode};
8+
use syntax::{
9+
algo::diff,
10+
ast::{MacroStmts, NameOwner},
11+
AstNode, GreenNode, Parse,
12+
SyntaxKind::*,
13+
SyntaxNode,
14+
};
915

1016
use crate::{
1117
ast_id_map::AstIdMap, hygiene::HygieneFrame, BuiltinDeriveExpander, BuiltinFnLikeExpander,
@@ -340,13 +346,19 @@ fn parse_macro_with_arg(
340346
None => return ExpandResult { value: None, err: result.err },
341347
};
342348

343-
log::debug!("expanded = {}", tt.as_debug_string());
344-
345349
let fragment_kind = to_fragment_kind(db, macro_call_id);
346350

351+
log::debug!("expanded = {}", tt.as_debug_string());
352+
log::debug!("kind = {:?}", fragment_kind);
353+
347354
let (parse, rev_token_map) = match mbe::token_tree_to_syntax_node(&tt, fragment_kind) {
348355
Ok(it) => it,
349356
Err(err) => {
357+
log::debug!(
358+
"failed to parse expanstion to {:?} = {}",
359+
fragment_kind,
360+
tt.as_debug_string()
361+
);
350362
return ExpandResult::only_err(err);
351363
}
352364
};
@@ -362,15 +374,34 @@ fn parse_macro_with_arg(
362374
return ExpandResult::only_err(err);
363375
}
364376
};
365-
366-
if !diff(&node, &call_node.value).is_empty() {
367-
ExpandResult { value: Some((parse, Arc::new(rev_token_map))), err: Some(err) }
368-
} else {
377+
if is_self_replicating(&node, &call_node.value) {
369378
return ExpandResult::only_err(err);
379+
} else {
380+
ExpandResult { value: Some((parse, Arc::new(rev_token_map))), err: Some(err) }
381+
}
382+
}
383+
None => {
384+
log::debug!("parse = {:?}", parse.syntax_node().kind());
385+
ExpandResult { value: Some((parse, Arc::new(rev_token_map))), err: None }
386+
}
387+
}
388+
}
389+
390+
fn is_self_replicating(from: &SyntaxNode, to: &SyntaxNode) -> bool {
391+
if diff(from, to).is_empty() {
392+
return true;
393+
}
394+
if let Some(stmts) = MacroStmts::cast(from.clone()) {
395+
if stmts.statements().any(|stmt| diff(stmt.syntax(), to).is_empty()) {
396+
return true;
397+
}
398+
if let Some(expr) = stmts.expr() {
399+
if diff(expr.syntax(), to).is_empty() {
400+
return true;
370401
}
371402
}
372-
None => ExpandResult { value: Some((parse, Arc::new(rev_token_map))), err: None },
373403
}
404+
false
374405
}
375406

376407
fn hygiene_frame(db: &dyn AstDatabase, file_id: HirFileId) -> Arc<HygieneFrame> {
@@ -390,21 +421,15 @@ fn to_fragment_kind(db: &dyn AstDatabase, id: MacroCallId) -> FragmentKind {
390421

391422
let parent = match syn.parent() {
392423
Some(it) => it,
393-
None => {
394-
// FIXME:
395-
// If it is root, which means the parent HirFile
396-
// MacroKindFile must be non-items
397-
// return expr now.
398-
return FragmentKind::Expr;
399-
}
424+
None => return FragmentKind::Statements,
400425
};
401426

402427
match parent.kind() {
403428
MACRO_ITEMS | SOURCE_FILE => FragmentKind::Items,
404-
MACRO_STMTS => FragmentKind::Statement,
429+
MACRO_STMTS => FragmentKind::Statements,
405430
ITEM_LIST => FragmentKind::Items,
406431
LET_STMT => {
407-
// FIXME: Handle Pattern
432+
// FIXME: Handle LHS Pattern
408433
FragmentKind::Expr
409434
}
410435
EXPR_STMT => FragmentKind::Statements,

crates/hir_ty/src/infer/expr.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -767,6 +767,7 @@ impl<'a> InferenceContext<'a> {
767767
None => self.table.new_float_var(),
768768
},
769769
},
770+
Expr::MacroStmts { tail } => self.infer_expr(*tail, expected),
770771
};
771772
// use a new type variable if we got unknown here
772773
let ty = self.insert_type_vars_shallow(ty);

crates/hir_ty/src/tests/macros.rs

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,11 +226,48 @@ fn expr_macro_expanded_in_stmts() {
226226
"#,
227227
expect![[r#"
228228
!0..8 'leta=();': ()
229+
!0..8 'leta=();': ()
230+
!3..4 'a': ()
231+
!5..7 '()': ()
229232
57..84 '{ ...); } }': ()
230233
"#]],
231234
);
232235
}
233236

237+
#[test]
238+
fn recurisve_macro_expanded_in_stmts() {
239+
check_infer(
240+
r#"
241+
macro_rules! ng {
242+
([$($tts:tt)*]) => {
243+
$($tts)*;
244+
};
245+
([$($tts:tt)*] $head:tt $($rest:tt)*) => {
246+
ng! {
247+
[$($tts)* $head] $($rest)*
248+
}
249+
};
250+
}
251+
fn foo() {
252+
ng!([] let a = 3);
253+
let b = a;
254+
}
255+
"#,
256+
expect![[r#"
257+
!0..7 'leta=3;': {unknown}
258+
!0..7 'leta=3;': {unknown}
259+
!0..13 'ng!{[leta=3]}': {unknown}
260+
!0..13 'ng!{[leta=]3}': {unknown}
261+
!0..13 'ng!{[leta]=3}': {unknown}
262+
!3..4 'a': i32
263+
!5..6 '3': i32
264+
196..237 '{ ...= a; }': ()
265+
229..230 'b': i32
266+
233..234 'a': i32
267+
"#]],
268+
);
269+
}
270+
234271
#[test]
235272
fn recursive_inner_item_macro_rules() {
236273
check_infer(
@@ -246,7 +283,8 @@ fn recursive_inner_item_macro_rules() {
246283
"#,
247284
expect![[r#"
248285
!0..1 '1': i32
249-
!0..7 'mac!($)': {unknown}
286+
!0..26 'macro_...>{1};}': {unknown}
287+
!0..26 'macro_...>{1};}': {unknown}
250288
107..143 '{ ...!(); }': ()
251289
129..130 'a': i32
252290
"#]],

0 commit comments

Comments
 (0)