Skip to content

Commit 4f6d5ed

Browse files
committed
Make add_function generate functions in other modules via qualified path
1 parent f11236e commit 4f6d5ed

File tree

3 files changed

+206
-25
lines changed

3 files changed

+206
-25
lines changed

crates/ra_assists/src/handlers/add_function.rs

Lines changed: 195 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use ra_syntax::{
44
};
55

66
use crate::{Assist, AssistCtx, AssistId};
7-
use ast::{edit::IndentLevel, ArgListOwner, CallExpr, Expr};
7+
use ast::{edit::IndentLevel, ArgListOwner, ModuleItemOwner};
88
use hir::HirDisplay;
99
use rustc_hash::{FxHashMap, FxHashSet};
1010

@@ -16,7 +16,7 @@ use rustc_hash::{FxHashMap, FxHashSet};
1616
// struct Baz;
1717
// fn baz() -> Baz { Baz }
1818
// fn foo() {
19-
// bar<|>("", baz());
19+
// bar<|>("", baz());
2020
// }
2121
//
2222
// ```
@@ -25,7 +25,7 @@ use rustc_hash::{FxHashMap, FxHashSet};
2525
// struct Baz;
2626
// fn baz() -> Baz { Baz }
2727
// fn foo() {
28-
// bar("", baz());
28+
// bar("", baz());
2929
// }
3030
//
3131
// fn bar(arg: &str, baz: Baz) {
@@ -38,16 +38,24 @@ pub(crate) fn add_function(ctx: AssistCtx) -> Option<Assist> {
3838
let call = path_expr.syntax().parent().and_then(ast::CallExpr::cast)?;
3939
let path = path_expr.path()?;
4040

41-
if path.qualifier().is_some() {
42-
return None;
43-
}
44-
4541
if ctx.sema.resolve_path(&path).is_some() {
4642
// The function call already resolves, no need to add a function
4743
return None;
4844
}
4945

50-
let function_builder = FunctionBuilder::from_call(&ctx, &call)?;
46+
let target_module = if let Some(qualifier) = path.qualifier() {
47+
if let Some(hir::PathResolution::Def(hir::ModuleDef::Module(resolved))) =
48+
ctx.sema.resolve_path(&qualifier)
49+
{
50+
Some(resolved.definition_source(ctx.sema.db).value)
51+
} else {
52+
return None;
53+
}
54+
} else {
55+
None
56+
};
57+
58+
let function_builder = FunctionBuilder::from_call(&ctx, &call, &path, target_module)?;
5159

5260
ctx.add_assist(AssistId("add_function"), "Add function", |edit| {
5361
edit.target(call.syntax().text_range());
@@ -66,26 +74,54 @@ struct FunctionTemplate {
6674
}
6775

6876
struct FunctionBuilder {
69-
append_fn_at: SyntaxNode,
77+
target: GeneratedFunctionTarget,
7078
fn_name: ast::Name,
7179
type_params: Option<ast::TypeParamList>,
7280
params: ast::ParamList,
7381
}
7482

7583
impl FunctionBuilder {
76-
fn from_call(ctx: &AssistCtx, call: &ast::CallExpr) -> Option<Self> {
77-
let append_fn_at = next_space_for_fn(&call)?;
78-
let fn_name = fn_name(&call)?;
84+
/// Prepares a generated function that matches `call` in `generate_in`
85+
/// (or as close to `call` as possible, if `generate_in` is `None`)
86+
fn from_call(
87+
ctx: &AssistCtx,
88+
call: &ast::CallExpr,
89+
path: &ast::Path,
90+
generate_in: Option<hir::ModuleSource>,
91+
) -> Option<Self> {
92+
let target = if let Some(generate_in_module) = generate_in {
93+
next_space_for_fn_in_module(generate_in_module)?
94+
} else {
95+
next_space_for_fn_after_call_site(&call)?
96+
};
97+
let fn_name = fn_name(&path)?;
7998
let (type_params, params) = fn_args(ctx, &call)?;
80-
Some(Self { append_fn_at, fn_name, type_params, params })
99+
Some(Self { target, fn_name, type_params, params })
81100
}
82101
fn render(self) -> Option<FunctionTemplate> {
83102
let placeholder_expr = ast::make::expr_todo();
84103
let fn_body = ast::make::block_expr(vec![], Some(placeholder_expr));
85104
let fn_def = ast::make::fn_def(self.fn_name, self.type_params, self.params, fn_body);
86-
let fn_def = ast::make::add_newlines(2, fn_def);
87-
let fn_def = IndentLevel::from_node(&self.append_fn_at).increase_indent(fn_def);
88-
let insert_offset = self.append_fn_at.text_range().end();
105+
106+
let (fn_def, insert_offset) = match self.target {
107+
GeneratedFunctionTarget::BehindItem(it) => {
108+
let with_leading_blank_line = ast::make::add_leading_newlines(2, fn_def);
109+
let indented = IndentLevel::from_node(&it).increase_indent(with_leading_blank_line);
110+
(indented, it.text_range().end())
111+
}
112+
GeneratedFunctionTarget::InEmptyItemList(it) => {
113+
let with_leading_newline = ast::make::add_leading_newlines(1, fn_def);
114+
let indent = IndentLevel::from_node(it.syntax()).indented();
115+
let mut indented = indent.increase_indent(with_leading_newline);
116+
if !item_list_has_whitespace(&it) {
117+
// In this case we want to make sure there's a newline between the closing
118+
// function brace and the closing module brace (so it doesn't end in `}}`).
119+
indented = ast::make::add_trailing_newlines(1, indented);
120+
}
121+
(indented, it.syntax().text_range().start() + TextUnit::from_usize(1))
122+
}
123+
};
124+
89125
let cursor_offset_from_fn_start = fn_def
90126
.syntax()
91127
.descendants()
@@ -98,15 +134,25 @@ impl FunctionBuilder {
98134
}
99135
}
100136

101-
fn fn_name(call: &CallExpr) -> Option<ast::Name> {
102-
let name = call.expr()?.syntax().to_string();
137+
/// Returns true if the given ItemList contains whitespace.
138+
fn item_list_has_whitespace(it: &ast::ItemList) -> bool {
139+
it.syntax().descendants_with_tokens().find(|it| it.kind() == SyntaxKind::WHITESPACE).is_some()
140+
}
141+
142+
enum GeneratedFunctionTarget {
143+
BehindItem(SyntaxNode),
144+
InEmptyItemList(ast::ItemList),
145+
}
146+
147+
fn fn_name(call: &ast::Path) -> Option<ast::Name> {
148+
let name = call.segment()?.syntax().to_string();
103149
Some(ast::make::name(&name))
104150
}
105151

106152
/// Computes the type variables and arguments required for the generated function
107153
fn fn_args(
108154
ctx: &AssistCtx,
109-
call: &CallExpr,
155+
call: &ast::CallExpr,
110156
) -> Option<(Option<ast::TypeParamList>, ast::ParamList)> {
111157
let mut arg_names = Vec::new();
112158
let mut arg_types = Vec::new();
@@ -158,9 +204,9 @@ fn deduplicate_arg_names(arg_names: &mut Vec<String>) {
158204
}
159205
}
160206

161-
fn fn_arg_name(fn_arg: &Expr) -> Option<String> {
207+
fn fn_arg_name(fn_arg: &ast::Expr) -> Option<String> {
162208
match fn_arg {
163-
Expr::CastExpr(cast_expr) => fn_arg_name(&cast_expr.expr()?),
209+
ast::Expr::CastExpr(cast_expr) => fn_arg_name(&cast_expr.expr()?),
164210
_ => Some(
165211
fn_arg
166212
.syntax()
@@ -172,7 +218,7 @@ fn fn_arg_name(fn_arg: &Expr) -> Option<String> {
172218
}
173219
}
174220

175-
fn fn_arg_type(ctx: &AssistCtx, fn_arg: &Expr) -> Option<String> {
221+
fn fn_arg_type(ctx: &AssistCtx, fn_arg: &ast::Expr) -> Option<String> {
176222
let ty = ctx.sema.type_of_expr(fn_arg)?;
177223
if ty.is_unknown() {
178224
return None;
@@ -184,7 +230,7 @@ fn fn_arg_type(ctx: &AssistCtx, fn_arg: &Expr) -> Option<String> {
184230
/// directly after the current block
185231
/// We want to write the generated function directly after
186232
/// fns, impls or macro calls, but inside mods
187-
fn next_space_for_fn(expr: &CallExpr) -> Option<SyntaxNode> {
233+
fn next_space_for_fn_after_call_site(expr: &ast::CallExpr) -> Option<GeneratedFunctionTarget> {
188234
let mut ancestors = expr.syntax().ancestors().peekable();
189235
let mut last_ancestor: Option<SyntaxNode> = None;
190236
while let Some(next_ancestor) = ancestors.next() {
@@ -201,7 +247,26 @@ fn next_space_for_fn(expr: &CallExpr) -> Option<SyntaxNode> {
201247
}
202248
last_ancestor = Some(next_ancestor);
203249
}
204-
last_ancestor
250+
last_ancestor.map(GeneratedFunctionTarget::BehindItem)
251+
}
252+
253+
fn next_space_for_fn_in_module(module: hir::ModuleSource) -> Option<GeneratedFunctionTarget> {
254+
match module {
255+
hir::ModuleSource::SourceFile(it) => {
256+
if let Some(last_item) = it.items().last() {
257+
Some(GeneratedFunctionTarget::BehindItem(last_item.syntax().clone()))
258+
} else {
259+
Some(GeneratedFunctionTarget::BehindItem(it.syntax().clone()))
260+
}
261+
}
262+
hir::ModuleSource::Module(it) => {
263+
if let Some(last_item) = it.item_list().and_then(|it| it.items().last()) {
264+
Some(GeneratedFunctionTarget::BehindItem(last_item.syntax().clone()))
265+
} else {
266+
it.item_list().map(GeneratedFunctionTarget::InEmptyItemList)
267+
}
268+
}
269+
}
205270
}
206271

207272
#[cfg(test)]
@@ -713,6 +778,112 @@ fn bar(baz_1: Baz, baz_2: Baz, arg_1: &str, arg_2: &str) {
713778
)
714779
}
715780

781+
#[test]
782+
fn add_function_in_module() {
783+
check_assist(
784+
add_function,
785+
r"
786+
mod bar {}
787+
788+
fn foo() {
789+
bar::my_fn<|>()
790+
}
791+
",
792+
r"
793+
mod bar {
794+
fn my_fn() {
795+
<|>todo!()
796+
}
797+
}
798+
799+
fn foo() {
800+
bar::my_fn()
801+
}
802+
",
803+
);
804+
check_assist(
805+
add_function,
806+
r"
807+
mod bar {
808+
}
809+
810+
fn foo() {
811+
bar::my_fn<|>()
812+
}
813+
",
814+
r"
815+
mod bar {
816+
fn my_fn() {
817+
<|>todo!()
818+
}
819+
}
820+
821+
fn foo() {
822+
bar::my_fn()
823+
}
824+
",
825+
)
826+
}
827+
828+
#[test]
829+
fn add_function_in_module_containing_other_items() {
830+
check_assist(
831+
add_function,
832+
r"
833+
mod bar {
834+
fn something_else() {}
835+
}
836+
837+
fn foo() {
838+
bar::my_fn<|>()
839+
}
840+
",
841+
r"
842+
mod bar {
843+
fn something_else() {}
844+
845+
fn my_fn() {
846+
<|>todo!()
847+
}
848+
}
849+
850+
fn foo() {
851+
bar::my_fn()
852+
}
853+
",
854+
)
855+
}
856+
857+
#[test]
858+
fn add_function_in_nested_module() {
859+
check_assist(
860+
add_function,
861+
r"
862+
mod bar {
863+
mod baz {
864+
}
865+
}
866+
867+
fn foo() {
868+
bar::baz::my_fn<|>()
869+
}
870+
",
871+
r"
872+
mod bar {
873+
mod baz {
874+
fn my_fn() {
875+
<|>todo!()
876+
}
877+
}
878+
}
879+
880+
fn foo() {
881+
bar::baz::my_fn()
882+
}
883+
",
884+
)
885+
}
886+
716887
#[test]
717888
fn add_function_not_applicable_if_function_already_exists() {
718889
check_assist_not_applicable(

crates/ra_syntax/src/ast/edit.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,11 @@ impl IndentLevel {
449449
IndentLevel(0)
450450
}
451451

452+
pub fn indented(mut self) -> Self {
453+
self.0 += 1;
454+
self
455+
}
456+
452457
pub fn increase_indent<N: AstNode>(self, node: N) -> N {
453458
N::cast(self._increase_indent(node.syntax().clone())).unwrap()
454459
}

crates/ra_syntax/src/ast/make.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,11 +293,16 @@ pub fn fn_def(
293293
ast_from_text(&format!("fn {}{}{} {}", fn_name, type_params, params, body))
294294
}
295295

296-
pub fn add_newlines(amount_of_newlines: usize, t: impl AstNode) -> ast::SourceFile {
296+
pub fn add_leading_newlines(amount_of_newlines: usize, t: impl AstNode) -> ast::SourceFile {
297297
let newlines = "\n".repeat(amount_of_newlines);
298298
ast_from_text(&format!("{}{}", newlines, t.syntax()))
299299
}
300300

301+
pub fn add_trailing_newlines(amount_of_newlines: usize, t: impl AstNode) -> ast::SourceFile {
302+
let newlines = "\n".repeat(amount_of_newlines);
303+
ast_from_text(&format!("{}{}", t.syntax(), newlines))
304+
}
305+
301306
fn ast_from_text<N: AstNode>(text: &str) -> N {
302307
let parse = SourceFile::parse(text);
303308
let node = parse.tree().syntax().descendants().find_map(N::cast).unwrap();

0 commit comments

Comments
 (0)