Skip to content

Commit 97df0d3

Browse files
committed
Desugar for await loops
1 parent 27d6539 commit 97df0d3

File tree

8 files changed

+125
-30
lines changed

8 files changed

+125
-30
lines changed

compiler/rustc_ast_lowering/src/expr.rs

+92-28
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
337337
),
338338
ExprKind::Try(sub_expr) => self.lower_expr_try(e.span, sub_expr),
339339

340-
ExprKind::Paren(_) | ExprKind::ForLoop{..} => {
340+
ExprKind::Paren(_) | ExprKind::ForLoop { .. } => {
341341
unreachable!("already handled")
342342
}
343343

@@ -874,6 +874,17 @@ impl<'hir> LoweringContext<'_, 'hir> {
874874
/// }
875875
/// ```
876876
fn lower_expr_await(&mut self, await_kw_span: Span, expr: &Expr) -> hir::ExprKind<'hir> {
877+
let expr = self.arena.alloc(self.lower_expr_mut(expr));
878+
self.make_lowered_await(await_kw_span, expr, FutureKind::Future)
879+
}
880+
881+
/// Takes an expr that has already been lowered and generates a desugared await loop around it
882+
fn make_lowered_await(
883+
&mut self,
884+
await_kw_span: Span,
885+
expr: &'hir hir::Expr<'hir>,
886+
await_kind: FutureKind,
887+
) -> hir::ExprKind<'hir> {
877888
let full_span = expr.span.to(await_kw_span);
878889

879890
let is_async_gen = match self.coroutine_kind {
@@ -887,13 +898,16 @@ impl<'hir> LoweringContext<'_, 'hir> {
887898
}
888899
};
889900

890-
let span = self.mark_span_with_reason(DesugaringKind::Await, await_kw_span, None);
901+
let features = match await_kind {
902+
FutureKind::Future => None,
903+
FutureKind::AsyncIterator => Some(self.allow_for_await.clone()),
904+
};
905+
let span = self.mark_span_with_reason(DesugaringKind::Await, await_kw_span, features);
891906
let gen_future_span = self.mark_span_with_reason(
892907
DesugaringKind::Await,
893908
full_span,
894909
Some(self.allow_gen_future.clone()),
895910
);
896-
let expr = self.lower_expr_mut(expr);
897911
let expr_hir_id = expr.hir_id;
898912

899913
// Note that the name of this binding must not be changed to something else because
@@ -933,11 +947,18 @@ impl<'hir> LoweringContext<'_, 'hir> {
933947
hir::LangItem::GetContext,
934948
arena_vec![self; task_context],
935949
);
936-
let call = self.expr_call_lang_item_fn(
937-
span,
938-
hir::LangItem::FuturePoll,
939-
arena_vec![self; new_unchecked, get_context],
940-
);
950+
let call = match await_kind {
951+
FutureKind::Future => self.expr_call_lang_item_fn(
952+
span,
953+
hir::LangItem::FuturePoll,
954+
arena_vec![self; new_unchecked, get_context],
955+
),
956+
FutureKind::AsyncIterator => self.expr_call_lang_item_fn(
957+
span,
958+
hir::LangItem::AsyncIteratorPollNext,
959+
arena_vec![self; new_unchecked, get_context],
960+
),
961+
};
941962
self.arena.alloc(self.expr_unsafe(call))
942963
};
943964

@@ -1021,11 +1042,16 @@ impl<'hir> LoweringContext<'_, 'hir> {
10211042
let awaitee_arm = self.arm(awaitee_pat, loop_expr);
10221043

10231044
// `match ::std::future::IntoFuture::into_future(<expr>) { ... }`
1024-
let into_future_expr = self.expr_call_lang_item_fn(
1025-
span,
1026-
hir::LangItem::IntoFutureIntoFuture,
1027-
arena_vec![self; expr],
1028-
);
1045+
let into_future_expr = match await_kind {
1046+
FutureKind::Future => self.expr_call_lang_item_fn(
1047+
span,
1048+
hir::LangItem::IntoFutureIntoFuture,
1049+
arena_vec![self; *expr],
1050+
),
1051+
// Not needed for `for await` because we expect to have already called
1052+
// `IntoAsyncIterator::into_async_iter` on it.
1053+
FutureKind::AsyncIterator => expr,
1054+
};
10291055

10301056
// match <into_future_expr> {
10311057
// mut __awaitee => loop { .. }
@@ -1673,7 +1699,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
16731699
head: &Expr,
16741700
body: &Block,
16751701
opt_label: Option<Label>,
1676-
_loop_kind: ForLoopKind,
1702+
loop_kind: ForLoopKind,
16771703
) -> hir::Expr<'hir> {
16781704
let head = self.lower_expr_mut(head);
16791705
let pat = self.lower_pat(pat);
@@ -1702,17 +1728,41 @@ impl<'hir> LoweringContext<'_, 'hir> {
17021728
let (iter_pat, iter_pat_nid) =
17031729
self.pat_ident_binding_mode(head_span, iter, hir::BindingAnnotation::MUT);
17041730

1705-
// `match Iterator::next(&mut iter) { ... }`
17061731
let match_expr = {
17071732
let iter = self.expr_ident(head_span, iter, iter_pat_nid);
1708-
let ref_mut_iter = self.expr_mut_addr_of(head_span, iter);
1709-
let next_expr = self.expr_call_lang_item_fn(
1710-
head_span,
1711-
hir::LangItem::IteratorNext,
1712-
arena_vec![self; ref_mut_iter],
1713-
);
1733+
let next_expr = match loop_kind {
1734+
ForLoopKind::For => {
1735+
// `Iterator::next(&mut iter)`
1736+
let ref_mut_iter = self.expr_mut_addr_of(head_span, iter);
1737+
self.expr_call_lang_item_fn(
1738+
head_span,
1739+
hir::LangItem::IteratorNext,
1740+
arena_vec![self; ref_mut_iter],
1741+
)
1742+
}
1743+
ForLoopKind::ForAwait => {
1744+
// we'll generate `unsafe { Pin::new_unchecked(&mut iter) })` and then pass this
1745+
// to make_lowered_await with `FutureKind::AsyncIterator` which will generator
1746+
// calls to `poll_next`. In user code, this would probably be a call to
1747+
// `Pin::as_mut` but here it's easy enough to do `new_unchecked`.
1748+
1749+
// `&mut iter`
1750+
let iter = self.expr_mut_addr_of(head_span, iter);
1751+
// `Pin::new_unchecked(...)`
1752+
let iter = self.arena.alloc(self.expr_call_lang_item_fn_mut(
1753+
head_span,
1754+
hir::LangItem::PinNewUnchecked,
1755+
arena_vec![self; iter],
1756+
));
1757+
// `unsafe { ... }`
1758+
let iter = self.arena.alloc(self.expr_unsafe(iter));
1759+
let kind = self.make_lowered_await(head_span, iter, FutureKind::AsyncIterator);
1760+
self.arena.alloc(hir::Expr { hir_id: self.next_id(), kind, span: head_span })
1761+
}
1762+
};
17141763
let arms = arena_vec![self; none_arm, some_arm];
17151764

1765+
// `match $next_expr { ... }`
17161766
self.expr_match(head_span, next_expr, arms, hir::MatchSource::ForLoopDesugar)
17171767
};
17181768
let match_stmt = self.stmt_expr(for_span, match_expr);
@@ -1732,13 +1782,16 @@ impl<'hir> LoweringContext<'_, 'hir> {
17321782
// `mut iter => { ... }`
17331783
let iter_arm = self.arm(iter_pat, loop_expr);
17341784

1735-
// `match ::std::iter::IntoIterator::into_iter(<head>) { ... }`
1736-
let into_iter_expr = {
1737-
self.expr_call_lang_item_fn(
1738-
head_span,
1739-
hir::LangItem::IntoIterIntoIter,
1740-
arena_vec![self; head],
1741-
)
1785+
let into_iter_expr = match loop_kind {
1786+
ForLoopKind::For => {
1787+
// `::std::iter::IntoIterator::into_iter(<head>)`
1788+
self.expr_call_lang_item_fn(
1789+
head_span,
1790+
hir::LangItem::IntoIterIntoIter,
1791+
arena_vec![self; head],
1792+
)
1793+
}
1794+
ForLoopKind::ForAwait => self.arena.alloc(head),
17421795
};
17431796

17441797
let match_expr = self.arena.alloc(self.expr_match(
@@ -2141,3 +2194,14 @@ impl<'hir> LoweringContext<'_, 'hir> {
21412194
}
21422195
}
21432196
}
2197+
2198+
/// Used by [`LoweringContext::make_lowered_await`] to customize the desugaring based on what kind
2199+
/// of future we are awaiting.
2200+
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
2201+
enum FutureKind {
2202+
/// We are awaiting a normal future
2203+
Future,
2204+
/// We are awaiting something that's known to be an AsyncIterator (i.e. we are in the header of
2205+
/// a `for await` loop)
2206+
AsyncIterator,
2207+
}

compiler/rustc_ast_lowering/src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ struct LoweringContext<'a, 'hir> {
130130
allow_try_trait: Lrc<[Symbol]>,
131131
allow_gen_future: Lrc<[Symbol]>,
132132
allow_async_iterator: Lrc<[Symbol]>,
133+
allow_for_await: Lrc<[Symbol]>,
133134

134135
/// Mapping from generics `def_id`s to TAIT generics `def_id`s.
135136
/// For each captured lifetime (e.g., 'a), we create a new lifetime parameter that is a generic
@@ -174,6 +175,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
174175
} else {
175176
[sym::gen_future].into()
176177
},
178+
allow_for_await: [sym::async_iterator].into(),
177179
// FIXME(gen_blocks): how does `closure_track_caller`/`async_fn_track_caller`
178180
// interact with `gen`/`async gen` blocks
179181
allow_async_iterator: [sym::gen_future, sym::async_iterator].into(),

compiler/rustc_builtin_macros/src/assert/context.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ impl<'cx, 'a> Context<'cx, 'a> {
303303
| ExprKind::Continue(_)
304304
| ExprKind::Err
305305
| ExprKind::Field(_, _)
306-
| ExprKind::ForLoop {..}
306+
| ExprKind::ForLoop { .. }
307307
| ExprKind::FormatArgs(_)
308308
| ExprKind::IncludedBytes(..)
309309
| ExprKind::InlineAsm(_)

compiler/rustc_feature/src/unstable.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ declare_features! (
358358
/// Allows `#[track_caller]` on async functions.
359359
(unstable, async_fn_track_caller, "1.73.0", Some(110011)),
360360
/// Allows `for await` loops.
361-
(unstable, async_for_loop, "CURRENT_RUSTC_VERSION", None),
361+
(unstable, async_for_loop, "CURRENT_RUSTC_VERSION", Some(118898)),
362362
/// Allows builtin # foo() syntax
363363
(unstable, builtin_syntax, "1.71.0", Some(110680)),
364364
/// Treat `extern "C"` function as nounwind.

compiler/rustc_hir/src/lang_items.rs

+2
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,8 @@ language_item_table! {
307307
Context, sym::Context, context, Target::Struct, GenericRequirement::None;
308308
FuturePoll, sym::poll, future_poll_fn, Target::Method(MethodKind::Trait { body: false }), GenericRequirement::None;
309309

310+
AsyncIteratorPollNext, sym::async_iterator_poll_next, async_iterator_poll_next, Target::Method(MethodKind::Trait { body: false }), GenericRequirement::Exact(0);
311+
310312
Option, sym::Option, option_type, Target::Enum, GenericRequirement::None;
311313
OptionSome, sym::Some, option_some_variant, Target::Variant, GenericRequirement::None;
312314
OptionNone, sym::None, option_none_variant, Target::Variant, GenericRequirement::None;

compiler/rustc_span/src/symbol.rs

+2
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,7 @@ symbols! {
428428
async_fn_track_caller,
429429
async_for_loop,
430430
async_iterator,
431+
async_iterator_poll_next,
431432
atomic,
432433
atomic_mod,
433434
atomics,
@@ -894,6 +895,7 @@ symbols! {
894895
instruction_set,
895896
integer_: "integer", // underscore to avoid clashing with the function `sym::integer` below
896897
integral,
898+
into_async_iter_into_iter,
897899
into_future,
898900
into_iter,
899901
intra_doc_pointers,

library/core/src/async_iter/async_iter.rs

+1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ pub trait AsyncIterator {
4747
/// Rust's usual rules apply: calls must never cause undefined behavior
4848
/// (memory corruption, incorrect use of `unsafe` functions, or the like),
4949
/// regardless of the async iterator's state.
50+
#[cfg_attr(not(bootstrap), lang = "async_iterator_poll_next")]
5051
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>;
5152

5253
/// Returns the bounds on the remaining length of the async iterator.

tests/ui/async-await/for-await.rs

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// run-pass
2+
// edition: 2021
3+
#![feature(async_iterator, async_iter_from_iter, const_waker, async_for_loop, noop_waker)]
4+
5+
use std::future::Future;
6+
7+
// make sure a simple for await loop works
8+
async fn real_main() {
9+
let iter = core::async_iter::from_iter(0..3);
10+
let mut count = 0;
11+
for await i in iter {
12+
assert_eq!(i, count);
13+
count += 1;
14+
}
15+
assert_eq!(count, 3);
16+
}
17+
18+
fn main() {
19+
let future = real_main();
20+
let waker = std::task::Waker::noop();
21+
let mut cx = &mut core::task::Context::from_waker(&waker);
22+
let mut future = core::pin::pin!(future);
23+
while let core::task::Poll::Pending = future.as_mut().poll(&mut cx) {}
24+
}

0 commit comments

Comments
 (0)