Skip to content

Commit e39fadc

Browse files
bors[bot]philberty
andauthored
Merge #1611
1611: Initial state capture for closures r=philberty a=philberty This patch set adds the initial support closure captures, move semantics are not handled here. We track what variables are being captured by a closure during name resolution so that when a VAR_DECL is resolved, we check if we are inside a closure context node_id which is the same id as its associated rib id. So when we resolve a name that resides in an outermost rib we can add this to set of node-id's that are captured by this closure. There is a gap here for the case where we need to check if it is inside a nested function and that function contains closures which could wrongly capture variables in the enclosing function. This will also be a problem for nested functions in general. Fixes #195 Co-authored-by: Philip Herron <[email protected]>
2 parents 22329b0 + 3053ec3 commit e39fadc

19 files changed

+498
-100
lines changed

gcc/rust/backend/rust-compile-context.cc

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,5 +142,52 @@ Context::type_hasher (tree type)
142142
return hstate.end ();
143143
}
144144

145+
void
146+
Context::push_closure_context (HirId id)
147+
{
148+
auto it = closure_bindings.find (id);
149+
rust_assert (it == closure_bindings.end ());
150+
151+
closure_bindings.insert ({id, {}});
152+
closure_scope_bindings.push_back (id);
153+
}
154+
155+
void
156+
Context::pop_closure_context ()
157+
{
158+
rust_assert (!closure_scope_bindings.empty ());
159+
160+
HirId ref = closure_scope_bindings.back ();
161+
closure_scope_bindings.pop_back ();
162+
closure_bindings.erase (ref);
163+
}
164+
165+
void
166+
Context::insert_closure_binding (HirId id, tree expr)
167+
{
168+
rust_assert (!closure_scope_bindings.empty ());
169+
170+
HirId ref = closure_scope_bindings.back ();
171+
closure_bindings[ref].insert ({id, expr});
172+
}
173+
174+
bool
175+
Context::lookup_closure_binding (HirId id, tree *expr)
176+
{
177+
if (closure_scope_bindings.empty ())
178+
return false;
179+
180+
HirId ref = closure_scope_bindings.back ();
181+
auto it = closure_bindings.find (ref);
182+
rust_assert (it != closure_bindings.end ());
183+
184+
auto iy = it->second.find (id);
185+
if (iy == it->second.end ())
186+
return false;
187+
188+
*expr = iy->second;
189+
return true;
190+
}
191+
145192
} // namespace Compile
146193
} // namespace Rust

gcc/rust/backend/rust-compile-context.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,11 @@ class Context
345345
return mangler.mangle_item (ty, path);
346346
}
347347

348+
void push_closure_context (HirId id);
349+
void pop_closure_context ();
350+
void insert_closure_binding (HirId id, tree expr);
351+
bool lookup_closure_binding (HirId id, tree *expr);
352+
348353
std::vector<tree> &get_type_decls () { return type_decls; }
349354
std::vector<::Bvariable *> &get_var_decls () { return var_decls; }
350355
std::vector<tree> &get_const_decls () { return const_decls; }
@@ -377,6 +382,10 @@ class Context
377382
std::map<HirId, tree> implicit_pattern_bindings;
378383
std::map<hashval_t, tree> main_variants;
379384

385+
// closure bindings
386+
std::vector<HirId> closure_scope_bindings;
387+
std::map<HirId, std::map<HirId, tree>> closure_bindings;
388+
380389
// To GCC middle-end
381390
std::vector<tree> type_decls;
382391
std::vector<::Bvariable *> var_decls;

gcc/rust/backend/rust-compile-expr.cc

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2824,10 +2824,25 @@ CompileExpr::visit (HIR::ClosureExpr &expr)
28242824

28252825
// lets ignore state capture for now we need to instantiate the struct anyway
28262826
// then generate the function
2827-
28282827
std::vector<tree> vals;
2829-
// TODO
2830-
// setup argument captures based on the mode?
2828+
for (const auto &capture : closure_tyty->get_captures ())
2829+
{
2830+
// lookup the HirId
2831+
HirId ref = UNKNOWN_HIRID;
2832+
bool ok = ctx->get_mappings ()->lookup_node_to_hir (capture, &ref);
2833+
rust_assert (ok);
2834+
2835+
// lookup the var decl
2836+
Bvariable *var = nullptr;
2837+
bool found = ctx->lookup_var_decl (ref, &var);
2838+
rust_assert (found);
2839+
2840+
// FIXME
2841+
// this should bes based on the closure move-ability
2842+
tree var_expr = var->get_tree (expr.get_locus ());
2843+
tree val = address_expression (var_expr, expr.get_locus ());
2844+
vals.push_back (val);
2845+
}
28312846

28322847
translated
28332848
= ctx->get_backend ()->constructor_expression (compiled_closure_tyty, false,
@@ -2874,8 +2889,29 @@ CompileExpr::generate_closure_function (HIR::ClosureExpr &expr,
28742889
DECL_ARTIFICIAL (self_param->get_decl ()) = 1;
28752890
param_vars.push_back (self_param);
28762891

2892+
// push a new context
2893+
ctx->push_closure_context (expr.get_mappings ().get_hirid ());
2894+
28772895
// setup the implicit argument captures
2878-
// TODO
2896+
size_t idx = 0;
2897+
for (const auto &capture : closure_tyty.get_captures ())
2898+
{
2899+
// lookup the HirId
2900+
HirId ref = UNKNOWN_HIRID;
2901+
bool ok = ctx->get_mappings ()->lookup_node_to_hir (capture, &ref);
2902+
rust_assert (ok);
2903+
2904+
// get the assessor
2905+
tree binding = ctx->get_backend ()->struct_field_expression (
2906+
self_param->get_tree (expr.get_locus ()), idx, expr.get_locus ());
2907+
tree indirection = indirect_expression (binding, expr.get_locus ());
2908+
2909+
// insert bindings
2910+
ctx->insert_closure_binding (ref, indirection);
2911+
2912+
// continue
2913+
idx++;
2914+
}
28792915

28802916
// args tuple
28812917
tree args_type
@@ -2905,7 +2941,10 @@ CompileExpr::generate_closure_function (HIR::ClosureExpr &expr,
29052941
}
29062942

29072943
if (!ctx->get_backend ()->function_set_parameters (fndecl, param_vars))
2908-
return error_mark_node;
2944+
{
2945+
ctx->pop_closure_context ();
2946+
return error_mark_node;
2947+
}
29092948

29102949
// lookup locals
29112950
HIR::Expr *function_body = expr.get_expr ().get ();
@@ -2972,6 +3011,7 @@ CompileExpr::generate_closure_function (HIR::ClosureExpr &expr,
29723011
gcc_assert (TREE_CODE (bind_tree) == BIND_EXPR);
29733012
DECL_SAVED_TREE (fndecl) = bind_tree;
29743013

3014+
ctx->pop_closure_context ();
29753015
ctx->pop_fn ();
29763016
ctx->push_function (fndecl);
29773017

gcc/rust/backend/rust-compile-resolve-path.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,14 @@ ResolvePathRef::resolve (const HIR::PathIdentSegment &final_segment,
121121
return constant_expr;
122122
}
123123

124+
// maybe closure binding
125+
tree closure_binding = error_mark_node;
126+
if (ctx->lookup_closure_binding (ref, &closure_binding))
127+
{
128+
TREE_USED (closure_binding) = 1;
129+
return closure_binding;
130+
}
131+
124132
// this might be a variable reference or a function reference
125133
Bvariable *var = nullptr;
126134
if (ctx->lookup_var_decl (ref, &var))

gcc/rust/backend/rust-compile-type.cc

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "rust-compile-type.h"
2020
#include "rust-compile-expr.h"
2121
#include "rust-constexpr.h"
22+
#include "rust-gcc.h"
2223

2324
#include "tree.h"
2425

@@ -99,11 +100,39 @@ TyTyResolveCompile::visit (const TyTy::InferType &)
99100
void
100101
TyTyResolveCompile::visit (const TyTy::ClosureType &type)
101102
{
103+
auto mappings = ctx->get_mappings ();
104+
102105
std::vector<Backend::typed_identifier> fields;
106+
107+
size_t i = 0;
108+
for (const auto &capture : type.get_captures ())
109+
{
110+
// lookup the HirId
111+
HirId ref = UNKNOWN_HIRID;
112+
bool ok = mappings->lookup_node_to_hir (capture, &ref);
113+
rust_assert (ok);
114+
115+
// lookup the var decl type
116+
TyTy::BaseType *lookup = nullptr;
117+
bool found = ctx->get_tyctx ()->lookup_type (ref, &lookup);
118+
rust_assert (found);
119+
120+
// FIXME get the var pattern name
121+
std::string mappings_name = "capture_" + std::to_string (i);
122+
123+
// FIXME
124+
// this should be based on the closure move-ability
125+
tree decl_type = TyTyResolveCompile::compile (ctx, lookup);
126+
tree capture_type = build_reference_type (decl_type);
127+
fields.push_back (Backend::typed_identifier (mappings_name, capture_type,
128+
type.get_ident ().locus));
129+
}
130+
103131
tree type_record = ctx->get_backend ()->struct_type (fields);
104132
RS_CLOSURE_FLAG (type_record) = 1;
105133

106-
std::string named_struct_str = type.get_ident ().path.get () + "{{closure}}";
134+
std::string named_struct_str
135+
= type.get_ident ().path.get () + "::{{closure}}";
107136
translated = ctx->get_backend ()->named_type (named_struct_str, type_record,
108137
type.get_ident ().locus);
109138
}

gcc/rust/resolve/rust-ast-resolve-expr.cc

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ ResolveExpr::visit (AST::IfLetExpr &expr)
209209

210210
for (auto &pattern : expr.get_patterns ())
211211
{
212-
PatternDeclaration::go (pattern.get ());
212+
PatternDeclaration::go (pattern.get (), Rib::ItemType::Var);
213213
}
214214

215215
ResolveExpr::go (expr.get_if_block ().get (), prefix, canonical_prefix);
@@ -343,7 +343,7 @@ ResolveExpr::visit (AST::LoopExpr &expr)
343343
auto label_lifetime_node_id = label.get_lifetime ().get_node_id ();
344344
resolver->get_label_scope ().insert (
345345
CanonicalPath::new_seg (expr.get_node_id (), label_name),
346-
label_lifetime_node_id, label.get_locus (), false,
346+
label_lifetime_node_id, label.get_locus (), false, Rib::ItemType::Label,
347347
[&] (const CanonicalPath &, NodeId, Location locus) -> void {
348348
rust_error_at (label.get_locus (), "label redefined multiple times");
349349
rust_error_at (locus, "was defined here");
@@ -400,7 +400,7 @@ ResolveExpr::visit (AST::WhileLoopExpr &expr)
400400
auto label_lifetime_node_id = label.get_lifetime ().get_node_id ();
401401
resolver->get_label_scope ().insert (
402402
CanonicalPath::new_seg (label.get_node_id (), label_name),
403-
label_lifetime_node_id, label.get_locus (), false,
403+
label_lifetime_node_id, label.get_locus (), false, Rib::ItemType::Label,
404404
[&] (const CanonicalPath &, NodeId, Location locus) -> void {
405405
rust_error_at (label.get_locus (), "label redefined multiple times");
406406
rust_error_at (locus, "was defined here");
@@ -429,7 +429,7 @@ ResolveExpr::visit (AST::ForLoopExpr &expr)
429429
auto label_lifetime_node_id = label.get_lifetime ().get_node_id ();
430430
resolver->get_label_scope ().insert (
431431
CanonicalPath::new_seg (label.get_node_id (), label_name),
432-
label_lifetime_node_id, label.get_locus (), false,
432+
label_lifetime_node_id, label.get_locus (), false, Rib::ItemType::Label,
433433
[&] (const CanonicalPath &, NodeId, Location locus) -> void {
434434
rust_error_at (label.get_locus (), "label redefined multiple times");
435435
rust_error_at (locus, "was defined here");
@@ -446,7 +446,7 @@ ResolveExpr::visit (AST::ForLoopExpr &expr)
446446
resolver->push_new_label_rib (resolver->get_type_scope ().peek ());
447447

448448
// resolve the expression
449-
PatternDeclaration::go (expr.get_pattern ().get ());
449+
PatternDeclaration::go (expr.get_pattern ().get (), Rib::ItemType::Var);
450450
ResolveExpr::go (expr.get_iterator_expr ().get (), prefix, canonical_prefix);
451451
ResolveExpr::go (expr.get_loop_block ().get (), prefix, canonical_prefix);
452452

@@ -520,7 +520,7 @@ ResolveExpr::visit (AST::MatchExpr &expr)
520520
// insert any possible new patterns
521521
for (auto &pattern : arm.get_patterns ())
522522
{
523-
PatternDeclaration::go (pattern.get ());
523+
PatternDeclaration::go (pattern.get (), Rib::ItemType::Var);
524524
}
525525

526526
// resolve the body
@@ -581,9 +581,13 @@ ResolveExpr::visit (AST::ClosureExprInner &expr)
581581
resolve_closure_param (p);
582582
}
583583

584+
resolver->push_closure_context (expr.get_node_id ());
585+
584586
ResolveExpr::go (expr.get_definition_expr ().get (), prefix,
585587
canonical_prefix);
586588

589+
resolver->pop_closure_context ();
590+
587591
resolver->get_name_scope ().pop ();
588592
resolver->get_type_scope ().pop ();
589593
resolver->get_label_scope ().pop ();
@@ -606,9 +610,14 @@ ResolveExpr::visit (AST::ClosureExprInnerTyped &expr)
606610
}
607611

608612
ResolveType::go (expr.get_return_type ().get ());
613+
614+
resolver->push_closure_context (expr.get_node_id ());
615+
609616
ResolveExpr::go (expr.get_definition_block ().get (), prefix,
610617
canonical_prefix);
611618

619+
resolver->pop_closure_context ();
620+
612621
resolver->get_name_scope ().pop ();
613622
resolver->get_type_scope ().pop ();
614623
resolver->get_label_scope ().pop ();
@@ -617,7 +626,7 @@ ResolveExpr::visit (AST::ClosureExprInnerTyped &expr)
617626
void
618627
ResolveExpr::resolve_closure_param (AST::ClosureParam &param)
619628
{
620-
PatternDeclaration::go (param.get_pattern ().get ());
629+
PatternDeclaration::go (param.get_pattern ().get (), Rib::ItemType::Param);
621630

622631
if (param.has_type_given ())
623632
ResolveType::go (param.get_type ().get ());

gcc/rust/resolve/rust-ast-resolve-implitem.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class ResolveToplevelImplItem : public ResolverBase
5656
auto path = prefix.append (decl);
5757

5858
resolver->get_type_scope ().insert (
59-
path, type.get_node_id (), type.get_locus (), false,
59+
path, type.get_node_id (), type.get_locus (), false, Rib::ItemType::Type,
6060
[&] (const CanonicalPath &, NodeId, Location locus) -> void {
6161
RichLocation r (type.get_locus ());
6262
r.add_range (locus);
@@ -72,6 +72,7 @@ class ResolveToplevelImplItem : public ResolverBase
7272

7373
resolver->get_name_scope ().insert (
7474
path, constant.get_node_id (), constant.get_locus (), false,
75+
Rib::ItemType::Const,
7576
[&] (const CanonicalPath &, NodeId, Location locus) -> void {
7677
RichLocation r (constant.get_locus ());
7778
r.add_range (locus);
@@ -87,6 +88,7 @@ class ResolveToplevelImplItem : public ResolverBase
8788

8889
resolver->get_name_scope ().insert (
8990
path, function.get_node_id (), function.get_locus (), false,
91+
Rib::ItemType::Function,
9092
[&] (const CanonicalPath &, NodeId, Location locus) -> void {
9193
RichLocation r (function.get_locus ());
9294
r.add_range (locus);
@@ -102,6 +104,7 @@ class ResolveToplevelImplItem : public ResolverBase
102104

103105
resolver->get_name_scope ().insert (
104106
path, method.get_node_id (), method.get_locus (), false,
107+
Rib::ItemType::Function,
105108
[&] (const CanonicalPath &, NodeId, Location locus) -> void {
106109
RichLocation r (method.get_locus ());
107110
r.add_range (locus);
@@ -141,6 +144,7 @@ class ResolveTopLevelTraitItems : public ResolverBase
141144

142145
resolver->get_name_scope ().insert (
143146
path, function.get_node_id (), function.get_locus (), false,
147+
Rib::ItemType::Function,
144148
[&] (const CanonicalPath &, NodeId, Location locus) -> void {
145149
RichLocation r (function.get_locus ());
146150
r.add_range (locus);
@@ -159,6 +163,7 @@ class ResolveTopLevelTraitItems : public ResolverBase
159163

160164
resolver->get_name_scope ().insert (
161165
path, method.get_node_id (), method.get_locus (), false,
166+
Rib::ItemType::Function,
162167
[&] (const CanonicalPath &, NodeId, Location locus) -> void {
163168
RichLocation r (method.get_locus ());
164169
r.add_range (locus);
@@ -177,6 +182,7 @@ class ResolveTopLevelTraitItems : public ResolverBase
177182

178183
resolver->get_name_scope ().insert (
179184
path, constant.get_node_id (), constant.get_locus (), false,
185+
Rib::ItemType::Const,
180186
[&] (const CanonicalPath &, NodeId, Location locus) -> void {
181187
RichLocation r (constant.get_locus ());
182188
r.add_range (locus);
@@ -194,7 +200,7 @@ class ResolveTopLevelTraitItems : public ResolverBase
194200
auto cpath = canonical_prefix.append (decl);
195201

196202
resolver->get_type_scope ().insert (
197-
path, type.get_node_id (), type.get_locus (), false,
203+
path, type.get_node_id (), type.get_locus (), false, Rib::ItemType::Type,
198204
[&] (const CanonicalPath &, NodeId, Location locus) -> void {
199205
RichLocation r (type.get_locus ());
200206
r.add_range (locus);
@@ -233,6 +239,7 @@ class ResolveToplevelExternItem : public ResolverBase
233239

234240
resolver->get_name_scope ().insert (
235241
path, function.get_node_id (), function.get_locus (), false,
242+
Rib::ItemType::Function,
236243
[&] (const CanonicalPath &, NodeId, Location locus) -> void {
237244
RichLocation r (function.get_locus ());
238245
r.add_range (locus);
@@ -251,6 +258,7 @@ class ResolveToplevelExternItem : public ResolverBase
251258

252259
resolver->get_name_scope ().insert (
253260
path, item.get_node_id (), item.get_locus (), false,
261+
Rib::ItemType::Static,
254262
[&] (const CanonicalPath &, NodeId, Location locus) -> void {
255263
RichLocation r (item.get_locus ());
256264
r.add_range (locus);

0 commit comments

Comments
 (0)