@@ -10,6 +10,7 @@ use rustc_hir as hir;
10
10
use rustc_hir:: def:: Res ;
11
11
use rustc_span:: source_map:: { respan, DesugaringKind , Span , Spanned } ;
12
12
use rustc_span:: symbol:: { sym, Symbol } ;
13
+ use rustc_span:: DUMMY_SP ;
13
14
14
15
impl < ' hir > LoweringContext < ' _ , ' hir > {
15
16
fn lower_exprs ( & mut self , exprs : & [ AstP < Expr > ] ) -> & ' hir [ hir:: Expr < ' hir > ] {
@@ -483,14 +484,44 @@ impl<'hir> LoweringContext<'_, 'hir> {
483
484
Some ( ty) => FnRetTy :: Ty ( ty) ,
484
485
None => FnRetTy :: Default ( span) ,
485
486
} ;
486
- let ast_decl = FnDecl { inputs : vec ! [ ] , output } ;
487
+
488
+ let task_context_id = self . resolver . next_node_id ( ) ;
489
+ let task_context_hid = self . lower_node_id ( task_context_id) ;
490
+ let ast_decl = FnDecl {
491
+ inputs : vec ! [ Param {
492
+ attrs: AttrVec :: new( ) ,
493
+ ty: AstP ( Ty {
494
+ id: self . resolver. next_node_id( ) ,
495
+ kind: TyKind :: Infer ,
496
+ span: DUMMY_SP ,
497
+ } ) ,
498
+ pat: AstP ( Pat {
499
+ id: task_context_id,
500
+ kind: PatKind :: Ident (
501
+ BindingMode :: ByValue ( Mutability :: Mut ) ,
502
+ Ident :: with_dummy_span( sym:: _task_context) ,
503
+ None ,
504
+ ) ,
505
+ span: DUMMY_SP ,
506
+ } ) ,
507
+ id: self . resolver. next_node_id( ) ,
508
+ span: DUMMY_SP ,
509
+ is_placeholder: false ,
510
+ } ] ,
511
+ output,
512
+ } ;
487
513
let decl = self . lower_fn_decl ( & ast_decl, None , /* impl trait allowed */ false , None ) ;
488
514
let body_id = self . lower_fn_body ( & ast_decl, |this| {
489
515
this. generator_kind = Some ( hir:: GeneratorKind :: Async ( async_gen_kind) ) ;
490
- body ( this)
516
+
517
+ let old_ctx = this. task_context ;
518
+ this. task_context = Some ( task_context_hid) ;
519
+ let res = body ( this) ;
520
+ this. task_context = old_ctx;
521
+ res
491
522
} ) ;
492
523
493
- // `static || -> <ret_ty> { body }`:
524
+ // `static |task_context | -> <ret_ty> { body }`:
494
525
let generator_kind = hir:: ExprKind :: Closure (
495
526
capture_clause,
496
527
decl,
@@ -523,9 +554,10 @@ impl<'hir> LoweringContext<'_, 'hir> {
523
554
/// ```rust
524
555
/// match <expr> {
525
556
/// mut pinned => loop {
526
- /// match ::std::future::poll_with_tls_context(unsafe {
527
- /// <::std::pin::Pin>::new_unchecked(&mut pinned)
528
- /// }) {
557
+ /// match unsafe { ::std::future::poll_with_context(
558
+ /// <::std::pin::Pin>::new_unchecked(&mut pinned),
559
+ /// task_context,
560
+ /// ) } {
529
561
/// ::std::task::Poll::Ready(result) => break result,
530
562
/// ::std::task::Poll::Pending => {}
531
563
/// }
@@ -561,12 +593,23 @@ impl<'hir> LoweringContext<'_, 'hir> {
561
593
let ( pinned_pat, pinned_pat_hid) =
562
594
self . pat_ident_binding_mode ( span, pinned_ident, hir:: BindingAnnotation :: Mutable ) ;
563
595
564
- // ::std::future::poll_with_tls_context(unsafe {
565
- // ::std::pin::Pin::new_unchecked(&mut pinned)
566
- // })`
596
+ let task_context_ident = Ident :: with_dummy_span ( sym:: _task_context) ;
597
+
598
+ // unsafe {
599
+ // ::std::future::poll_with_context(
600
+ // ::std::pin::Pin::new_unchecked(&mut pinned),
601
+ // task_context,
602
+ // )
603
+ // }
567
604
let poll_expr = {
568
605
let pinned = self . expr_ident ( span, pinned_ident, pinned_pat_hid) ;
569
606
let ref_mut_pinned = self . expr_mut_addr_of ( span, pinned) ;
607
+ let task_context = if let Some ( task_context_hid) = self . task_context {
608
+ self . expr_ident_mut ( span, task_context_ident, task_context_hid)
609
+ } else {
610
+ // Use of `await` outside of an async context, we cannot use `task_context` here.
611
+ self . expr_err ( span)
612
+ } ;
570
613
let pin_ty_id = self . next_id ( ) ;
571
614
let new_unchecked_expr_kind = self . expr_call_std_assoc_fn (
572
615
pin_ty_id,
@@ -575,14 +618,13 @@ impl<'hir> LoweringContext<'_, 'hir> {
575
618
"new_unchecked" ,
576
619
arena_vec ! [ self ; ref_mut_pinned] ,
577
620
) ;
578
- let new_unchecked =
579
- self . arena . alloc ( self . expr ( span, new_unchecked_expr_kind, ThinVec :: new ( ) ) ) ;
580
- let unsafe_expr = self . expr_unsafe ( new_unchecked) ;
581
- self . expr_call_std_path (
621
+ let new_unchecked = self . expr ( span, new_unchecked_expr_kind, ThinVec :: new ( ) ) ;
622
+ let call = self . expr_call_std_path (
582
623
gen_future_span,
583
- & [ sym:: future, sym:: poll_with_tls_context] ,
584
- arena_vec ! [ self ; unsafe_expr] ,
585
- )
624
+ & [ sym:: future, sym:: poll_with_context] ,
625
+ arena_vec ! [ self ; new_unchecked, task_context] ,
626
+ ) ;
627
+ self . arena . alloc ( self . expr_unsafe ( call) )
586
628
} ;
587
629
588
630
// `::std::task::Poll::Ready(result) => break result`
@@ -629,12 +671,27 @@ impl<'hir> LoweringContext<'_, 'hir> {
629
671
hir:: ExprKind :: Yield ( unit, hir:: YieldSource :: Await ) ,
630
672
ThinVec :: new ( ) ,
631
673
) ;
632
- self . stmt_expr ( span, yield_expr)
674
+ let yield_expr = self . arena . alloc ( yield_expr) ;
675
+
676
+ if let Some ( task_context_hid) = self . task_context {
677
+ let lhs = self . expr_ident ( span, task_context_ident, task_context_hid) ;
678
+ let assign = self . expr (
679
+ span,
680
+ hir:: ExprKind :: Assign ( lhs, yield_expr, span) ,
681
+ AttrVec :: new ( ) ,
682
+ ) ;
683
+ self . stmt_expr ( span, assign)
684
+ } else {
685
+ // Use of `await` outside of an async context. Return `yield_expr` so that we can
686
+ // proceed with type checking.
687
+ self . stmt ( span, hir:: StmtKind :: Semi ( yield_expr) )
688
+ }
633
689
} ;
634
690
635
- let loop_block = self . block_all ( span, arena_vec ! [ self ; inner_match_stmt, yield_stmt] , None ) ;
691
+ let loop_block =
692
+ self . block_all ( span, arena_vec ! [ self ; inner_match_stmt, yield_stmt] , None ) ;
636
693
637
- // loop { .. }
694
+ // loop { ...; task_context = yield (); }
638
695
let loop_expr = self . arena . alloc ( hir:: Expr {
639
696
hir_id : loop_hir_id,
640
697
kind : hir:: ExprKind :: Loop ( loop_block, None , hir:: LoopSource :: Loop ) ,
0 commit comments