From 92a434a33904608f5659cf7b5d4df3d2a99bd5bd Mon Sep 17 00:00:00 2001 From: Philip Herron Date: Wed, 23 Jun 2021 17:22:12 +0100 Subject: [PATCH] Add support for nested functions We missed that stmts in rust can be items like functions. This adds support for resolution and compilation of nested functions. Rust allows nested functions which are distinct to closures. Nested functions are not allowed to encapsulate the enclosing scope so they can be extracted as normal functions. --- gcc/rust/backend/rust-compile-base.h | 3 + gcc/rust/backend/rust-compile-implitem.h | 44 ++-------- gcc/rust/backend/rust-compile-item.h | 22 +---- gcc/rust/backend/rust-compile.cc | 51 ++++++++--- gcc/rust/hir/rust-ast-lower-stmt.h | 85 +++++++++++++++++++ gcc/rust/resolve/rust-ast-resolve-stmt.h | 55 ++++++++++++ gcc/rust/resolve/rust-ast-resolve.cc | 5 +- gcc/rust/typecheck/rust-hir-type-check-stmt.h | 80 +++++++++++++++++ gcc/testsuite/rust/compile/lookup_err1.rs | 7 ++ .../rust/compile/torture/nested_fn1.rs | 10 +++ .../rust/compile/torture/nested_fn2.rs | 11 +++ 11 files changed, 299 insertions(+), 74 deletions(-) create mode 100644 gcc/testsuite/rust/compile/lookup_err1.rs create mode 100644 gcc/testsuite/rust/compile/torture/nested_fn1.rs create mode 100644 gcc/testsuite/rust/compile/torture/nested_fn2.rs diff --git a/gcc/rust/backend/rust-compile-base.h b/gcc/rust/backend/rust-compile-base.h index ed33515e9e0c..c346af596262 100644 --- a/gcc/rust/backend/rust-compile-base.h +++ b/gcc/rust/backend/rust-compile-base.h @@ -210,6 +210,9 @@ class HIRCompileBase : public HIR::HIRVisitor void compile_function_body (Bfunction *fndecl, std::unique_ptr &function_body, bool has_return_type); + + bool compile_locals_for_block (Resolver::Rib &rib, Bfunction *fndecl, + std::vector &locals); }; } // namespace Compile diff --git a/gcc/rust/backend/rust-compile-implitem.h b/gcc/rust/backend/rust-compile-implitem.h index d6698d1800ac..70f76b7d873a 100644 --- a/gcc/rust/backend/rust-compile-implitem.h +++ b/gcc/rust/backend/rust-compile-implitem.h @@ -183,26 +183,10 @@ class CompileInherentImplItem : public HIRCompileBase } std::vector locals; - rib->iterate_decls ([&] (NodeId n, Location) mutable -> bool { - Resolver::Definition d; - bool ok = ctx->get_resolver ()->lookup_definition (n, &d); - rust_assert (ok); - - HIR::Stmt *decl = nullptr; - ok = ctx->get_mappings ()->resolve_nodeid_to_stmt (d.parent, &decl); - rust_assert (ok); - - Bvariable *compiled = CompileVarDecl::compile (fndecl, decl, ctx); - locals.push_back (compiled); - - return true; - }); - - bool toplevel_item - = function.get_mappings ().get_local_defid () != UNKNOWN_LOCAL_DEFID; - Bblock *enclosing_scope - = toplevel_item ? NULL : ctx->peek_enclosing_scope (); + bool ok = compile_locals_for_block (*rib, fndecl, locals); + rust_assert (ok); + Bblock *enclosing_scope = NULL; HIR::BlockExpr *function_body = function.get_definition ().get (); Location start_location = function_body->get_locus (); Location end_location = function_body->get_closing_locus (); @@ -409,26 +393,10 @@ class CompileInherentImplItem : public HIRCompileBase } std::vector locals; - rib->iterate_decls ([&] (NodeId n, Location) mutable -> bool { - Resolver::Definition d; - bool ok = ctx->get_resolver ()->lookup_definition (n, &d); - rust_assert (ok); - - HIR::Stmt *decl = nullptr; - ok = ctx->get_mappings ()->resolve_nodeid_to_stmt (d.parent, &decl); - rust_assert (ok); - - Bvariable *compiled = CompileVarDecl::compile (fndecl, decl, ctx); - locals.push_back (compiled); - - return true; - }); - - bool toplevel_item - = method.get_mappings ().get_local_defid () != UNKNOWN_LOCAL_DEFID; - Bblock *enclosing_scope - = toplevel_item ? NULL : ctx->peek_enclosing_scope (); + bool ok = compile_locals_for_block (*rib, fndecl, locals); + rust_assert (ok); + Bblock *enclosing_scope = NULL; HIR::BlockExpr *function_body = method.get_function_body ().get (); Location start_location = function_body->get_locus (); Location end_location = function_body->get_closing_locus (); diff --git a/gcc/rust/backend/rust-compile-item.h b/gcc/rust/backend/rust-compile-item.h index 8a521e714fcb..eacfda90a79c 100644 --- a/gcc/rust/backend/rust-compile-item.h +++ b/gcc/rust/backend/rust-compile-item.h @@ -213,26 +213,10 @@ class CompileItem : public HIRCompileBase } std::vector locals; - rib->iterate_decls ([&] (NodeId n, Location) mutable -> bool { - Resolver::Definition d; - bool ok = ctx->get_resolver ()->lookup_definition (n, &d); - rust_assert (ok); - - HIR::Stmt *decl = nullptr; - ok = ctx->get_mappings ()->resolve_nodeid_to_stmt (d.parent, &decl); - rust_assert (ok); - - Bvariable *compiled = CompileVarDecl::compile (fndecl, decl, ctx); - locals.push_back (compiled); - - return true; - }); - - bool toplevel_item - = function.get_mappings ().get_local_defid () != UNKNOWN_LOCAL_DEFID; - Bblock *enclosing_scope - = toplevel_item ? NULL : ctx->peek_enclosing_scope (); + bool ok = compile_locals_for_block (*rib, fndecl, locals); + rust_assert (ok); + Bblock *enclosing_scope = NULL; HIR::BlockExpr *function_body = function.get_definition ().get (); Location start_location = function_body->get_locus (); Location end_location = function_body->get_closing_locus (); diff --git a/gcc/rust/backend/rust-compile.cc b/gcc/rust/backend/rust-compile.cc index 351271c91091..5ffd11a422c8 100644 --- a/gcc/rust/backend/rust-compile.cc +++ b/gcc/rust/backend/rust-compile.cc @@ -212,20 +212,8 @@ CompileBlock::visit (HIR::BlockExpr &expr) } std::vector locals; - rib->iterate_decls ([&] (NodeId n, Location) mutable -> bool { - Resolver::Definition d; - bool ok = ctx->get_resolver ()->lookup_definition (n, &d); - rust_assert (ok); - - HIR::Stmt *decl = nullptr; - ok = ctx->get_mappings ()->resolve_nodeid_to_stmt (d.parent, &decl); - rust_assert (ok); - - Bvariable *compiled = CompileVarDecl::compile (fndecl, decl, ctx); - locals.push_back (compiled); - - return true; - }); + bool ok = compile_locals_for_block (*rib, fndecl, locals); + rust_assert (ok); Bblock *enclosing_scope = ctx->peek_enclosing_scope (); Bblock *new_block @@ -415,6 +403,41 @@ HIRCompileBase::compile_function_body ( } } +bool +HIRCompileBase::compile_locals_for_block (Resolver::Rib &rib, Bfunction *fndecl, + std::vector &locals) +{ + rib.iterate_decls ([&] (NodeId n, Location) mutable -> bool { + Resolver::Definition d; + bool ok = ctx->get_resolver ()->lookup_definition (n, &d); + rust_assert (ok); + + HIR::Stmt *decl = nullptr; + ok = ctx->get_mappings ()->resolve_nodeid_to_stmt (d.parent, &decl); + rust_assert (ok); + + // if its a function we extract this out side of this fn context + // and it is not a local to this function + bool is_item = ctx->get_mappings ()->lookup_hir_item ( + decl->get_mappings ().get_crate_num (), + decl->get_mappings ().get_hirid ()) + != nullptr; + if (is_item) + { + HIR::Item *item = static_cast (decl); + CompileItem::compile (item, ctx, true); + return true; + } + + Bvariable *compiled = CompileVarDecl::compile (fndecl, decl, ctx); + locals.push_back (compiled); + + return true; + }); + + return true; +} + // Mr Mangle time static const std::string kMangledSymbolPrefix = "_ZN"; diff --git a/gcc/rust/hir/rust-ast-lower-stmt.h b/gcc/rust/hir/rust-ast-lower-stmt.h index c495932497e2..1dd8a10425ed 100644 --- a/gcc/rust/hir/rust-ast-lower-stmt.h +++ b/gcc/rust/hir/rust-ast-lower-stmt.h @@ -230,6 +230,91 @@ class ASTLoweringStmt : public ASTLoweringBase empty.get_locus ()); } + void visit (AST::Function &function) override + { + // ignore for now and leave empty + std::vector > where_clause_items; + HIR::WhereClause where_clause (std::move (where_clause_items)); + HIR::FunctionQualifiers qualifiers ( + HIR::FunctionQualifiers::AsyncConstStatus::NONE, false); + HIR::Visibility vis = HIR::Visibility::create_public (); + + // need + std::vector > generic_params; + if (function.has_generics ()) + { + generic_params = lower_generic_params (function.get_generic_params ()); + } + + Identifier function_name = function.get_function_name (); + Location locus = function.get_locus (); + + std::unique_ptr return_type + = function.has_return_type () ? std::unique_ptr ( + ASTLoweringType::translate (function.get_return_type ().get ())) + : nullptr; + + std::vector function_params; + for (auto ¶m : function.get_function_params ()) + { + auto translated_pattern = std::unique_ptr ( + ASTLoweringPattern::translate (param.get_pattern ().get ())); + auto translated_type = std::unique_ptr ( + ASTLoweringType::translate (param.get_type ().get ())); + + auto crate_num = mappings->get_current_crate (); + Analysis::NodeMapping mapping (crate_num, param.get_node_id (), + mappings->get_next_hir_id (crate_num), + UNKNOWN_LOCAL_DEFID); + + auto hir_param + = HIR::FunctionParam (mapping, std::move (translated_pattern), + std::move (translated_type), + param.get_locus ()); + function_params.push_back (hir_param); + } + + bool terminated = false; + std::unique_ptr function_body + = std::unique_ptr ( + ASTLoweringBlock::translate (function.get_definition ().get (), + &terminated)); + + auto crate_num = mappings->get_current_crate (); + Analysis::NodeMapping mapping (crate_num, function.get_node_id (), + mappings->get_next_hir_id (crate_num), + UNKNOWN_LOCAL_DEFID); + + mappings->insert_location (crate_num, + function_body->get_mappings ().get_hirid (), + function.get_locus ()); + + auto fn + = new HIR::Function (mapping, std::move (function_name), + std::move (qualifiers), std::move (generic_params), + std::move (function_params), std::move (return_type), + std::move (where_clause), std::move (function_body), + std::move (vis), function.get_outer_attrs (), locus); + + mappings->insert_hir_item (mapping.get_crate_num (), mapping.get_hirid (), + fn); + mappings->insert_hir_stmt (mapping.get_crate_num (), mapping.get_hirid (), + fn); + mappings->insert_location (crate_num, mapping.get_hirid (), + function.get_locus ()); + + // add the mappings for the function params at the end + for (auto ¶m : fn->get_function_params ()) + { + mappings->insert_hir_param (mapping.get_crate_num (), + param.get_mappings ().get_hirid (), ¶m); + mappings->insert_location (crate_num, mapping.get_hirid (), + param.get_locus ()); + } + + translated = fn; + } + private: ASTLoweringStmt () : translated (nullptr), terminated (false) {} diff --git a/gcc/rust/resolve/rust-ast-resolve-stmt.h b/gcc/rust/resolve/rust-ast-resolve-stmt.h index 3fd1cfa841f1..e68e7b93f3f1 100644 --- a/gcc/rust/resolve/rust-ast-resolve-stmt.h +++ b/gcc/rust/resolve/rust-ast-resolve-stmt.h @@ -129,6 +129,61 @@ class ResolveStmt : public ResolverBase resolver->get_type_scope ().pop (); } + void visit (AST::Function &function) override + { + auto path = ResolveFunctionItemToCanonicalPath::resolve (function); + resolver->get_name_scope ().insert ( + path, function.get_node_id (), function.get_locus (), false, + [&] (const CanonicalPath &, NodeId, Location locus) -> void { + RichLocation r (function.get_locus ()); + r.add_range (locus); + rust_error_at (r, "redefined multiple times"); + }); + resolver->insert_new_definition (function.get_node_id (), + Definition{function.get_node_id (), + function.get_node_id ()}); + + NodeId scope_node_id = function.get_node_id (); + resolver->get_name_scope ().push (scope_node_id); + resolver->get_type_scope ().push (scope_node_id); + resolver->get_label_scope ().push (scope_node_id); + resolver->push_new_name_rib (resolver->get_name_scope ().peek ()); + resolver->push_new_type_rib (resolver->get_type_scope ().peek ()); + resolver->push_new_label_rib (resolver->get_type_scope ().peek ()); + + if (function.has_generics ()) + { + for (auto &generic : function.get_generic_params ()) + ResolveGenericParam::go (generic.get (), function.get_node_id ()); + } + + if (function.has_return_type ()) + ResolveType::go (function.get_return_type ().get (), + function.get_node_id ()); + + // we make a new scope so the names of parameters are resolved and shadowed + // correctly + for (auto ¶m : function.get_function_params ()) + { + ResolveType::go (param.get_type ().get (), param.get_node_id ()); + PatternDeclaration::go (param.get_pattern ().get (), + param.get_node_id ()); + + // the mutability checker needs to verify for immutable decls the number + // of assignments are <1. This marks an implicit assignment + resolver->mark_assignment_to_decl (param.get_pattern ()->get_node_id (), + param.get_node_id ()); + } + + // resolve the function body + ResolveExpr::go (function.get_definition ().get (), + function.get_node_id ()); + + resolver->get_name_scope ().pop (); + resolver->get_type_scope ().pop (); + resolver->get_label_scope ().pop (); + } + private: ResolveStmt (NodeId parent) : ResolverBase (parent) {} }; diff --git a/gcc/rust/resolve/rust-ast-resolve.cc b/gcc/rust/resolve/rust-ast-resolve.cc index e03a745b0240..fae3f77930ba 100644 --- a/gcc/rust/resolve/rust-ast-resolve.cc +++ b/gcc/rust/resolve/rust-ast-resolve.cc @@ -499,9 +499,8 @@ ResolvePath::resolve_path (AST::PathInExpression *expr) else { rust_error_at (expr->get_locus (), - "unknown root segment in path %s lookup %s", - expr->as_string ().c_str (), - root_ident_seg.as_string ().c_str ()); + "Cannot find path %<%s%> in this scope", + expr->as_string ().c_str ()); return; } diff --git a/gcc/rust/typecheck/rust-hir-type-check-stmt.h b/gcc/rust/typecheck/rust-hir-type-check-stmt.h index 0e55df839723..3655d968dfec 100644 --- a/gcc/rust/typecheck/rust-hir-type-check-stmt.h +++ b/gcc/rust/typecheck/rust-hir-type-check-stmt.h @@ -216,6 +216,86 @@ class TypeCheckStmt : public TypeCheckBase infered = type; } + void visit (HIR::Function &function) override + { + std::vector substitutions; + if (function.has_generics ()) + { + for (auto &generic_param : function.get_generic_params ()) + { + switch (generic_param.get ()->get_kind ()) + { + case HIR::GenericParam::GenericKind::LIFETIME: + // Skipping Lifetime completely until better handling. + break; + + case HIR::GenericParam::GenericKind::TYPE: { + auto param_type + = TypeResolveGenericParam::Resolve (generic_param.get ()); + context->insert_type (generic_param->get_mappings (), + param_type); + + substitutions.push_back (TyTy::SubstitutionParamMapping ( + static_cast (*generic_param), + param_type)); + } + break; + } + } + } + + TyTy::BaseType *ret_type = nullptr; + if (!function.has_function_return_type ()) + ret_type = new TyTy::TupleType (function.get_mappings ().get_hirid ()); + else + { + auto resolved + = TypeCheckType::Resolve (function.get_return_type ().get ()); + if (resolved == nullptr) + { + rust_error_at (function.get_locus (), + "failed to resolve return type"); + return; + } + + ret_type = resolved->clone (); + ret_type->set_ref ( + function.get_return_type ()->get_mappings ().get_hirid ()); + } + + std::vector > params; + for (auto ¶m : function.get_function_params ()) + { + // get the name as well required for later on + auto param_tyty = TypeCheckType::Resolve (param.get_type ()); + params.push_back ( + std::pair (param.get_param_name (), + param_tyty)); + + context->insert_type (param.get_mappings (), param_tyty); + } + + auto fnType = new TyTy::FnType (function.get_mappings ().get_hirid (), + function.get_function_name (), false, + std::move (params), ret_type, + std::move (substitutions)); + context->insert_type (function.get_mappings (), fnType); + + TyTy::FnType *resolved_fn_type = fnType; + auto expected_ret_tyty = resolved_fn_type->get_return_type (); + context->push_return_type (expected_ret_tyty); + + auto block_expr_ty + = TypeCheckExpr::Resolve (function.get_definition ().get (), false); + + context->pop_return_type (); + + if (block_expr_ty->get_kind () != TyTy::NEVER) + expected_ret_tyty->unify (block_expr_ty); + + infered = fnType; + } + private: TypeCheckStmt (bool inside_loop) : TypeCheckBase (), infered (nullptr), inside_loop (inside_loop) diff --git a/gcc/testsuite/rust/compile/lookup_err1.rs b/gcc/testsuite/rust/compile/lookup_err1.rs new file mode 100644 index 000000000000..4a96f9ff1403 --- /dev/null +++ b/gcc/testsuite/rust/compile/lookup_err1.rs @@ -0,0 +1,7 @@ +fn test() { + fn nested() {} +} + +fn main() { + nested(); // { dg-error "Cannot find path .nested. in this scope" } +} diff --git a/gcc/testsuite/rust/compile/torture/nested_fn1.rs b/gcc/testsuite/rust/compile/torture/nested_fn1.rs new file mode 100644 index 000000000000..075b5dba8e04 --- /dev/null +++ b/gcc/testsuite/rust/compile/torture/nested_fn1.rs @@ -0,0 +1,10 @@ +pub fn main() { + let a = 123; + + fn test(x: i32) -> i32 { + x + 456 + } + + let b; + b = test(a); +} diff --git a/gcc/testsuite/rust/compile/torture/nested_fn2.rs b/gcc/testsuite/rust/compile/torture/nested_fn2.rs new file mode 100644 index 000000000000..7040c862e75e --- /dev/null +++ b/gcc/testsuite/rust/compile/torture/nested_fn2.rs @@ -0,0 +1,11 @@ +pub fn main() { + fn test(x: T) -> T { + x + } + + let mut a = 123; + a = test(a); + + let mut b = 456f32; + b = test(b); +}