Skip to content

fix: Prevent wrong invocations of needs_parens_in with non-ancestral "parent"s #19324

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions crates/ide-assists/src/handlers/apply_demorgan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,9 @@ pub(crate) fn apply_demorgan(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti
let parent = neg_expr.syntax().parent();
editor = builder.make_editor(neg_expr.syntax());

if parent.is_some_and(|parent| demorganed.needs_parens_in(&parent)) {
if parent.is_some_and(|parent| {
demorganed.needs_parens_in_place_of(&parent, neg_expr.syntax())
}) {
cov_mark::hit!(demorgan_keep_parens_for_op_precedence2);
editor.replace(neg_expr.syntax(), make.expr_paren(demorganed).syntax());
} else {
Expand Down Expand Up @@ -392,15 +394,19 @@ fn f() { !(S <= S || S < S) }

#[test]
fn demorgan_keep_pars_for_op_precedence3() {
check_assist(apply_demorgan, "fn f() { (a || !(b &&$0 c); }", "fn f() { (a || !b || !c; }");
check_assist(
apply_demorgan,
"fn f() { (a || !(b &&$0 c); }",
"fn f() { (a || (!b || !c); }",
);
}

#[test]
fn demorgan_removes_pars_in_eq_precedence() {
fn demorgan_keeps_pars_in_eq_precedence() {
check_assist(
apply_demorgan,
"fn() { let x = a && !(!b |$0| !c); }",
"fn() { let x = a && b && c; }",
"fn() { let x = a && (b && c); }",
)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These test chages might be a regression, but I think that the changed ones are more consistent with other test cases

}

Expand Down
118 changes: 85 additions & 33 deletions crates/ide-assists/src/handlers/inline_local_variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use ide_db::{
EditionedFileId, RootDatabase,
};
use syntax::{
ast::{self, AstNode, AstToken, HasName},
ast::{self, syntax_factory::SyntaxFactory, AstNode, AstToken, HasName},
SyntaxElement, TextRange,
};

Expand Down Expand Up @@ -43,22 +43,6 @@ pub(crate) fn inline_local_variable(acc: &mut Assists, ctx: &AssistContext<'_>)
}?;
let initializer_expr = let_stmt.initializer()?;

let delete_range = delete_let.then(|| {
if let Some(whitespace) = let_stmt
.syntax()
.next_sibling_or_token()
.and_then(SyntaxElement::into_token)
.and_then(ast::Whitespace::cast)
{
TextRange::new(
let_stmt.syntax().text_range().start(),
whitespace.syntax().text_range().end(),
)
} else {
let_stmt.syntax().text_range()
}
});

let wrap_in_parens = references
.into_iter()
.filter_map(|FileReference { range, name, .. }| match name {
Expand All @@ -73,40 +57,60 @@ pub(crate) fn inline_local_variable(acc: &mut Assists, ctx: &AssistContext<'_>)
}
let usage_node =
name_ref.syntax().ancestors().find(|it| ast::PathExpr::can_cast(it.kind()));
let usage_parent_option = usage_node.and_then(|it| it.parent());
let usage_parent_option = usage_node.as_ref().and_then(|it| it.parent());
let usage_parent = match usage_parent_option {
Some(u) => u,
None => return Some((range, name_ref, false)),
None => return Some((name_ref, false)),
};
Some((range, name_ref, initializer_expr.needs_parens_in(&usage_parent)))
let should_wrap = initializer_expr
.needs_parens_in_place_of(&usage_parent, usage_node.as_ref().unwrap());
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let a = 123 < 456;
let b = !a;

In the above code, initializer_expr is 123 < 456 and usage_parent is let b = !a. So, inside needs_parens_in, it assumes that initializer_expr "comes before parent" because its text range is earlier, and decides that it needs no parentheses

Some((name_ref, should_wrap))
})
.collect::<Option<Vec<_>>>()?;

let init_str = initializer_expr.syntax().text().to_string();
let init_in_paren = format!("({init_str})");

let target = match target {
ast::NameOrNameRef::Name(it) => it.syntax().text_range(),
ast::NameOrNameRef::NameRef(it) => it.syntax().text_range(),
ast::NameOrNameRef::Name(it) => it.syntax().clone(),
ast::NameOrNameRef::NameRef(it) => it.syntax().clone(),
};

acc.add(
AssistId("inline_local_variable", AssistKind::RefactorInline),
"Inline variable",
target,
target.text_range(),
move |builder| {
if let Some(range) = delete_range {
builder.delete(range);
let mut editor = builder.make_editor(&target);
if delete_let {
editor.delete(let_stmt.syntax());
if let Some(whitespace) = let_stmt
.syntax()
.next_sibling_or_token()
.and_then(SyntaxElement::into_token)
.and_then(ast::Whitespace::cast)
{
editor.delete(whitespace.syntax());
}
}
for (range, name, should_wrap) in wrap_in_parens {
let replacement = if should_wrap { &init_in_paren } else { &init_str };
if ast::RecordExprField::for_field_name(&name).is_some() {

let make = SyntaxFactory::new();

for (name, should_wrap) in wrap_in_parens {
let replacement = if should_wrap {
make.expr_paren(initializer_expr.clone()).into()
} else {
initializer_expr.clone()
};

if let Some(record_field) = ast::RecordExprField::for_field_name(&name) {
cov_mark::hit!(inline_field_shorthand);
builder.insert(range.end(), format!(": {replacement}"));
let replacement = make.record_expr_field(name, Some(replacement));
editor.replace(record_field.syntax(), replacement.syntax());
} else {
builder.replace(range, replacement.clone())
editor.replace(name.syntax(), replacement.syntax());
}
}

editor.add_mappings(make.finish_with_mappings());
builder.add_file_edits(ctx.file_id(), editor);
},
)
}
Expand Down Expand Up @@ -939,6 +943,54 @@ fn main() {
fn main() {
let _ = (|| 2)();
}
"#,
);
}

#[test]
fn test_wrap_in_parens() {
check_assist(
inline_local_variable,
r#"
fn main() {
let $0a = 123 < 456;
let b = !a;
}
"#,
r#"
fn main() {
let b = !(123 < 456);
}
"#,
);
check_assist(
inline_local_variable,
r#"
trait Foo {
fn foo(&self);
}

impl Foo for bool {
fn foo(&self) {}
}

fn main() {
let $0a = 123 < 456;
let b = a.foo();
}
"#,
r#"
trait Foo {
fn foo(&self);
}

impl Foo for bool {
fn foo(&self) {}
}

fn main() {
let b = (123 < 456).foo();
}
"#,
);
}
Expand Down
63 changes: 55 additions & 8 deletions crates/syntax/src/ast/prec.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
//! Precedence representation.

use stdx::always;

use crate::{
ast::{self, BinaryOp, Expr, HasArgList, RangeItem},
match_ast, AstNode, SyntaxNode,
Expand Down Expand Up @@ -140,6 +142,22 @@ pub fn precedence(expr: &ast::Expr) -> ExprPrecedence {
}
}

fn check_ancestry(ancestor: &SyntaxNode, descendent: &SyntaxNode) -> bool {
let bail = || always!(false, "{} is not an ancestor of {}", ancestor, descendent);

if !ancestor.text_range().contains_range(descendent.text_range()) {
return bail();
}

for anc in descendent.ancestors() {
if anc == *ancestor {
return true;
}
}

bail()
}

impl Expr {
pub fn precedence(&self) -> ExprPrecedence {
precedence(self)
Expand All @@ -153,9 +171,19 @@ impl Expr {

/// Returns `true` if `self` would need to be wrapped in parentheses given that its parent is `parent`.
pub fn needs_parens_in(&self, parent: &SyntaxNode) -> bool {
self.needs_parens_in_place_of(parent, self.syntax())
}

/// Returns `true` if `self` would need to be wrapped in parentheses if it replaces `place_of`
/// given that `place_of`'s parent is `parent`.
pub fn needs_parens_in_place_of(&self, parent: &SyntaxNode, place_of: &SyntaxNode) -> bool {
if !check_ancestry(parent, place_of) {
return false;
}

match_ast! {
match parent {
ast::Expr(e) => self.needs_parens_in_expr(&e),
ast::Expr(e) => self.needs_parens_in_expr(&e, place_of),
ast::Stmt(e) => self.needs_parens_in_stmt(Some(&e)),
ast::StmtList(_) => self.needs_parens_in_stmt(None),
ast::ArgList(_) => false,
Expand All @@ -165,7 +193,7 @@ impl Expr {
}
}

fn needs_parens_in_expr(&self, parent: &Expr) -> bool {
fn needs_parens_in_expr(&self, parent: &Expr, place_of: &SyntaxNode) -> bool {
// Parentheses are necessary when calling a function-like pointer that is a member of a struct or union
// (e.g. `(a.f)()`).
let is_parent_call_expr = matches!(parent, ast::Expr::CallExpr(_));
Expand Down Expand Up @@ -199,13 +227,17 @@ impl Expr {

if self.is_paren_like()
|| parent.is_paren_like()
|| self.is_prefix() && (parent.is_prefix() || !self.is_ordered_before(parent))
|| self.is_postfix() && (parent.is_postfix() || self.is_ordered_before(parent))
|| self.is_prefix()
&& (parent.is_prefix()
|| !self.is_ordered_before_parent_in_place_of(parent, place_of))
|| self.is_postfix()
&& (parent.is_postfix()
|| self.is_ordered_before_parent_in_place_of(parent, place_of))
{
return false;
}

let (left, right, inv) = match self.is_ordered_before(parent) {
let (left, right, inv) = match self.is_ordered_before_parent_in_place_of(parent, place_of) {
true => (self, parent, false),
false => (parent, self, true),
};
Expand Down Expand Up @@ -413,13 +445,28 @@ impl Expr {
}
}

fn is_ordered_before(&self, other: &Expr) -> bool {
fn is_ordered_before_parent_in_place_of(&self, parent: &Expr, place_of: &SyntaxNode) -> bool {
use rowan::TextSize;
use Expr::*;

return order(self) < order(other);
let self_range = self.syntax().text_range();
let place_of_range = place_of.text_range();

let self_order_adjusted = order(self) - self_range.start() + place_of_range.start();

let parent_order = order(parent);
let parent_order_adjusted = if parent_order <= place_of_range.start() {
parent_order
} else if parent_order >= place_of_range.end() {
parent_order - place_of_range.len() + self_range.len()
} else {
return false;
};

return self_order_adjusted < parent_order_adjusted;

/// Returns text range that can be used to compare two expression for order (which goes first).
fn order(this: &Expr) -> rowan::TextSize {
fn order(this: &Expr) -> TextSize {
// For non-paren-like operators: get the operator itself
let token = match this {
RangeExpr(e) => e.op_token(),
Expand Down
21 changes: 21 additions & 0 deletions crates/syntax/src/ast/syntax_factory/constructors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,27 @@ impl SyntaxFactory {
ast
}

pub fn record_expr_field(
&self,
name: ast::NameRef,
expr: Option<ast::Expr>,
) -> ast::RecordExprField {
let ast = make::record_expr_field(name.clone(), expr.clone()).clone_for_update();

if let Some(mut mapping) = self.mappings() {
let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());

builder.map_node(name.syntax().clone(), ast.name_ref().unwrap().syntax().clone());
if let Some(expr) = expr {
builder.map_node(expr.syntax().clone(), ast.expr().unwrap().syntax().clone());
}

builder.finish(&mut mapping);
}

ast
}

pub fn record_field_list(
&self,
fields: impl IntoIterator<Item = ast::RecordField>,
Expand Down
6 changes: 3 additions & 3 deletions docs/book/src/assists_generated.md
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ fn main() {


### `apply_demorgan_iterator`
**Source:** [apply_demorgan.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/apply_demorgan.rs#L154)
**Source:** [apply_demorgan.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/apply_demorgan.rs#L156)

Apply [De Morgan's law](https://en.wikipedia.org/wiki/De_Morgan%27s_laws) to
`Iterator::all` and `Iterator::any`.
Expand Down Expand Up @@ -1070,7 +1070,7 @@ pub use foo::{Bar, Baz};


### `expand_record_rest_pattern`
**Source:** [expand_rest_pattern.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/expand_rest_pattern.rs#L24)
**Source:** [expand_rest_pattern.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/expand_rest_pattern.rs#L26)

Fills fields by replacing rest pattern in record patterns.

Expand All @@ -1094,7 +1094,7 @@ fn foo(bar: Bar) {


### `expand_tuple_struct_rest_pattern`
**Source:** [expand_rest_pattern.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/expand_rest_pattern.rs#L80)
**Source:** [expand_rest_pattern.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/expand_rest_pattern.rs#L82)

Fills fields by replacing rest pattern in tuple struct patterns.

Expand Down
Loading