@@ -337,7 +337,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
337
337
) ,
338
338
ExprKind :: Try ( sub_expr) => self . lower_expr_try ( e. span , sub_expr) ,
339
339
340
- ExprKind :: Paren ( _) | ExprKind :: ForLoop { .. } => {
340
+ ExprKind :: Paren ( _) | ExprKind :: ForLoop { .. } => {
341
341
unreachable ! ( "already handled" )
342
342
}
343
343
@@ -874,6 +874,17 @@ impl<'hir> LoweringContext<'_, 'hir> {
874
874
/// }
875
875
/// ```
876
876
fn lower_expr_await ( & mut self , await_kw_span : Span , expr : & Expr ) -> hir:: ExprKind < ' hir > {
877
+ let expr = self . arena . alloc ( self . lower_expr_mut ( expr) ) ;
878
+ self . make_lowered_await ( await_kw_span, expr, FutureKind :: Future )
879
+ }
880
+
881
+ /// Takes an expr that has already been lowered and generates a desugared await loop around it
882
+ fn make_lowered_await (
883
+ & mut self ,
884
+ await_kw_span : Span ,
885
+ expr : & ' hir hir:: Expr < ' hir > ,
886
+ await_kind : FutureKind ,
887
+ ) -> hir:: ExprKind < ' hir > {
877
888
let full_span = expr. span . to ( await_kw_span) ;
878
889
879
890
let is_async_gen = match self . coroutine_kind {
@@ -887,13 +898,16 @@ impl<'hir> LoweringContext<'_, 'hir> {
887
898
}
888
899
} ;
889
900
890
- let span = self . mark_span_with_reason ( DesugaringKind :: Await , await_kw_span, None ) ;
901
+ let features = match await_kind {
902
+ FutureKind :: Future => None ,
903
+ FutureKind :: AsyncIterator => Some ( self . allow_for_await . clone ( ) ) ,
904
+ } ;
905
+ let span = self . mark_span_with_reason ( DesugaringKind :: Await , await_kw_span, features) ;
891
906
let gen_future_span = self . mark_span_with_reason (
892
907
DesugaringKind :: Await ,
893
908
full_span,
894
909
Some ( self . allow_gen_future . clone ( ) ) ,
895
910
) ;
896
- let expr = self . lower_expr_mut ( expr) ;
897
911
let expr_hir_id = expr. hir_id ;
898
912
899
913
// Note that the name of this binding must not be changed to something else because
@@ -933,11 +947,18 @@ impl<'hir> LoweringContext<'_, 'hir> {
933
947
hir:: LangItem :: GetContext ,
934
948
arena_vec ! [ self ; task_context] ,
935
949
) ;
936
- let call = self . expr_call_lang_item_fn (
937
- span,
938
- hir:: LangItem :: FuturePoll ,
939
- arena_vec ! [ self ; new_unchecked, get_context] ,
940
- ) ;
950
+ let call = match await_kind {
951
+ FutureKind :: Future => self . expr_call_lang_item_fn (
952
+ span,
953
+ hir:: LangItem :: FuturePoll ,
954
+ arena_vec ! [ self ; new_unchecked, get_context] ,
955
+ ) ,
956
+ FutureKind :: AsyncIterator => self . expr_call_lang_item_fn (
957
+ span,
958
+ hir:: LangItem :: AsyncIteratorPollNext ,
959
+ arena_vec ! [ self ; new_unchecked, get_context] ,
960
+ ) ,
961
+ } ;
941
962
self . arena . alloc ( self . expr_unsafe ( call) )
942
963
} ;
943
964
@@ -1021,11 +1042,16 @@ impl<'hir> LoweringContext<'_, 'hir> {
1021
1042
let awaitee_arm = self . arm ( awaitee_pat, loop_expr) ;
1022
1043
1023
1044
// `match ::std::future::IntoFuture::into_future(<expr>) { ... }`
1024
- let into_future_expr = self . expr_call_lang_item_fn (
1025
- span,
1026
- hir:: LangItem :: IntoFutureIntoFuture ,
1027
- arena_vec ! [ self ; expr] ,
1028
- ) ;
1045
+ let into_future_expr = match await_kind {
1046
+ FutureKind :: Future => self . expr_call_lang_item_fn (
1047
+ span,
1048
+ hir:: LangItem :: IntoFutureIntoFuture ,
1049
+ arena_vec ! [ self ; * expr] ,
1050
+ ) ,
1051
+ // Not needed for `for await` because we expect to have already called
1052
+ // `IntoAsyncIterator::into_async_iter` on it.
1053
+ FutureKind :: AsyncIterator => expr,
1054
+ } ;
1029
1055
1030
1056
// match <into_future_expr> {
1031
1057
// mut __awaitee => loop { .. }
@@ -1673,7 +1699,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
1673
1699
head : & Expr ,
1674
1700
body : & Block ,
1675
1701
opt_label : Option < Label > ,
1676
- _loop_kind : ForLoopKind ,
1702
+ loop_kind : ForLoopKind ,
1677
1703
) -> hir:: Expr < ' hir > {
1678
1704
let head = self . lower_expr_mut ( head) ;
1679
1705
let pat = self . lower_pat ( pat) ;
@@ -1702,17 +1728,41 @@ impl<'hir> LoweringContext<'_, 'hir> {
1702
1728
let ( iter_pat, iter_pat_nid) =
1703
1729
self . pat_ident_binding_mode ( head_span, iter, hir:: BindingAnnotation :: MUT ) ;
1704
1730
1705
- // `match Iterator::next(&mut iter) { ... }`
1706
1731
let match_expr = {
1707
1732
let iter = self . expr_ident ( head_span, iter, iter_pat_nid) ;
1708
- let ref_mut_iter = self . expr_mut_addr_of ( head_span, iter) ;
1709
- let next_expr = self . expr_call_lang_item_fn (
1710
- head_span,
1711
- hir:: LangItem :: IteratorNext ,
1712
- arena_vec ! [ self ; ref_mut_iter] ,
1713
- ) ;
1733
+ let next_expr = match loop_kind {
1734
+ ForLoopKind :: For => {
1735
+ // `Iterator::next(&mut iter)`
1736
+ let ref_mut_iter = self . expr_mut_addr_of ( head_span, iter) ;
1737
+ self . expr_call_lang_item_fn (
1738
+ head_span,
1739
+ hir:: LangItem :: IteratorNext ,
1740
+ arena_vec ! [ self ; ref_mut_iter] ,
1741
+ )
1742
+ }
1743
+ ForLoopKind :: ForAwait => {
1744
+ // we'll generate `unsafe { Pin::new_unchecked(&mut iter) })` and then pass this
1745
+ // to make_lowered_await with `FutureKind::AsyncIterator` which will generator
1746
+ // calls to `poll_next`. In user code, this would probably be a call to
1747
+ // `Pin::as_mut` but here it's easy enough to do `new_unchecked`.
1748
+
1749
+ // `&mut iter`
1750
+ let iter = self . expr_mut_addr_of ( head_span, iter) ;
1751
+ // `Pin::new_unchecked(...)`
1752
+ let iter = self . arena . alloc ( self . expr_call_lang_item_fn_mut (
1753
+ head_span,
1754
+ hir:: LangItem :: PinNewUnchecked ,
1755
+ arena_vec ! [ self ; iter] ,
1756
+ ) ) ;
1757
+ // `unsafe { ... }`
1758
+ let iter = self . arena . alloc ( self . expr_unsafe ( iter) ) ;
1759
+ let kind = self . make_lowered_await ( head_span, iter, FutureKind :: AsyncIterator ) ;
1760
+ self . arena . alloc ( hir:: Expr { hir_id : self . next_id ( ) , kind, span : head_span } )
1761
+ }
1762
+ } ;
1714
1763
let arms = arena_vec ! [ self ; none_arm, some_arm] ;
1715
1764
1765
+ // `match $next_expr { ... }`
1716
1766
self . expr_match ( head_span, next_expr, arms, hir:: MatchSource :: ForLoopDesugar )
1717
1767
} ;
1718
1768
let match_stmt = self . stmt_expr ( for_span, match_expr) ;
@@ -1732,13 +1782,16 @@ impl<'hir> LoweringContext<'_, 'hir> {
1732
1782
// `mut iter => { ... }`
1733
1783
let iter_arm = self . arm ( iter_pat, loop_expr) ;
1734
1784
1735
- // `match ::std::iter::IntoIterator::into_iter(<head>) { ... }`
1736
- let into_iter_expr = {
1737
- self . expr_call_lang_item_fn (
1738
- head_span,
1739
- hir:: LangItem :: IntoIterIntoIter ,
1740
- arena_vec ! [ self ; head] ,
1741
- )
1785
+ let into_iter_expr = match loop_kind {
1786
+ ForLoopKind :: For => {
1787
+ // `::std::iter::IntoIterator::into_iter(<head>)`
1788
+ self . expr_call_lang_item_fn (
1789
+ head_span,
1790
+ hir:: LangItem :: IntoIterIntoIter ,
1791
+ arena_vec ! [ self ; head] ,
1792
+ )
1793
+ }
1794
+ ForLoopKind :: ForAwait => self . arena . alloc ( head) ,
1742
1795
} ;
1743
1796
1744
1797
let match_expr = self . arena . alloc ( self . expr_match (
@@ -2141,3 +2194,14 @@ impl<'hir> LoweringContext<'_, 'hir> {
2141
2194
}
2142
2195
}
2143
2196
}
2197
+
2198
+ /// Used by [`LoweringContext::make_lowered_await`] to customize the desugaring based on what kind
2199
+ /// of future we are awaiting.
2200
+ #[ derive( Copy , Clone , Debug , PartialEq , Eq ) ]
2201
+ enum FutureKind {
2202
+ /// We are awaiting a normal future
2203
+ Future ,
2204
+ /// We are awaiting something that's known to be an AsyncIterator (i.e. we are in the header of
2205
+ /// a `for await` loop)
2206
+ AsyncIterator ,
2207
+ }
0 commit comments