@@ -66,9 +66,9 @@ use rustc_index::{Idx, IndexVec};
66
66
use rustc_middle:: mir:: dump_mir;
67
67
use rustc_middle:: mir:: visit:: { MutVisitor , PlaceContext , Visitor } ;
68
68
use rustc_middle:: mir:: * ;
69
+ use rustc_middle:: ty:: CoroutineArgs ;
69
70
use rustc_middle:: ty:: InstanceDef ;
70
- use rustc_middle:: ty:: { self , AdtDef , Ty , TyCtxt } ;
71
- use rustc_middle:: ty:: { CoroutineArgs , GenericArgsRef } ;
71
+ use rustc_middle:: ty:: { self , Ty , TyCtxt } ;
72
72
use rustc_mir_dataflow:: impls:: {
73
73
MaybeBorrowedLocals , MaybeLiveLocals , MaybeRequiresStorage , MaybeStorageLive ,
74
74
} ;
@@ -225,8 +225,6 @@ struct SuspensionPoint<'tcx> {
225
225
struct TransformVisitor < ' tcx > {
226
226
tcx : TyCtxt < ' tcx > ,
227
227
coroutine_kind : hir:: CoroutineKind ,
228
- state_adt_ref : AdtDef < ' tcx > ,
229
- state_args : GenericArgsRef < ' tcx > ,
230
228
231
229
// The type of the discriminant in the coroutine struct
232
230
discr_ty : Ty < ' tcx > ,
@@ -245,21 +243,34 @@ struct TransformVisitor<'tcx> {
245
243
always_live_locals : BitSet < Local > ,
246
244
247
245
// The original RETURN_PLACE local
248
- new_ret_local : Local ,
246
+ old_ret_local : Local ,
247
+
248
+ old_yield_ty : Ty < ' tcx > ,
249
+
250
+ old_ret_ty : Ty < ' tcx > ,
249
251
}
250
252
251
253
impl < ' tcx > TransformVisitor < ' tcx > {
252
254
fn insert_none_ret_block ( & self , body : & mut Body < ' tcx > ) -> BasicBlock {
253
- let block = BasicBlock :: new ( body . basic_blocks . len ( ) ) ;
255
+ assert ! ( matches! ( self . coroutine_kind , CoroutineKind :: Gen ( _ ) ) ) ;
254
256
257
+ let block = BasicBlock :: new ( body. basic_blocks . len ( ) ) ;
255
258
let source_info = SourceInfo :: outermost ( body. span ) ;
259
+ let option_def_id = self . tcx . require_lang_item ( LangItem :: Option , None ) ;
256
260
257
- let ( kind, idx) = self . coroutine_state_adt_and_variant_idx ( true ) ;
258
- assert_eq ! ( self . state_adt_ref. variant( idx) . fields. len( ) , 0 ) ;
259
261
let statements = vec ! [ Statement {
260
262
kind: StatementKind :: Assign ( Box :: new( (
261
263
Place :: return_place( ) ,
262
- Rvalue :: Aggregate ( Box :: new( kind) , IndexVec :: new( ) ) ,
264
+ Rvalue :: Aggregate (
265
+ Box :: new( AggregateKind :: Adt (
266
+ option_def_id,
267
+ VariantIdx :: from_usize( 0 ) ,
268
+ self . tcx. mk_args( & [ self . old_yield_ty. into( ) ] ) ,
269
+ None ,
270
+ None ,
271
+ ) ) ,
272
+ IndexVec :: new( ) ,
273
+ ) ,
263
274
) ) ) ,
264
275
source_info,
265
276
} ] ;
@@ -273,23 +284,6 @@ impl<'tcx> TransformVisitor<'tcx> {
273
284
block
274
285
}
275
286
276
- fn coroutine_state_adt_and_variant_idx (
277
- & self ,
278
- is_return : bool ,
279
- ) -> ( AggregateKind < ' tcx > , VariantIdx ) {
280
- let idx = VariantIdx :: new ( match ( is_return, self . coroutine_kind ) {
281
- ( true , hir:: CoroutineKind :: Coroutine ) => 1 , // CoroutineState::Complete
282
- ( false , hir:: CoroutineKind :: Coroutine ) => 0 , // CoroutineState::Yielded
283
- ( true , hir:: CoroutineKind :: Async ( _) ) => 0 , // Poll::Ready
284
- ( false , hir:: CoroutineKind :: Async ( _) ) => 1 , // Poll::Pending
285
- ( true , hir:: CoroutineKind :: Gen ( _) ) => 0 , // Option::None
286
- ( false , hir:: CoroutineKind :: Gen ( _) ) => 1 , // Option::Some
287
- } ) ;
288
-
289
- let kind = AggregateKind :: Adt ( self . state_adt_ref . did ( ) , idx, self . state_args , None , None ) ;
290
- ( kind, idx)
291
- }
292
-
293
287
// Make a `CoroutineState` or `Poll` variant assignment.
294
288
//
295
289
// `core::ops::CoroutineState` only has single element tuple variants,
@@ -302,51 +296,99 @@ impl<'tcx> TransformVisitor<'tcx> {
302
296
is_return : bool ,
303
297
statements : & mut Vec < Statement < ' tcx > > ,
304
298
) {
305
- let ( kind, idx) = self . coroutine_state_adt_and_variant_idx ( is_return) ;
306
-
307
- match self . coroutine_kind {
308
- // `Poll::Pending`
299
+ let rvalue = match self . coroutine_kind {
309
300
CoroutineKind :: Async ( _) => {
310
- if !is_return {
311
- assert_eq ! ( self . state_adt_ref. variant( idx) . fields. len( ) , 0 ) ;
312
-
313
- // FIXME(swatinem): assert that `val` is indeed unit?
314
- statements. push ( Statement {
315
- kind : StatementKind :: Assign ( Box :: new ( (
316
- Place :: return_place ( ) ,
317
- Rvalue :: Aggregate ( Box :: new ( kind) , IndexVec :: new ( ) ) ,
318
- ) ) ) ,
319
- source_info,
320
- } ) ;
321
- return ;
301
+ let poll_def_id = self . tcx . require_lang_item ( LangItem :: Poll , None ) ;
302
+ let args = self . tcx . mk_args ( & [ self . old_ret_ty . into ( ) ] ) ;
303
+ if is_return {
304
+ // Poll::Ready(val)
305
+ Rvalue :: Aggregate (
306
+ Box :: new ( AggregateKind :: Adt (
307
+ poll_def_id,
308
+ VariantIdx :: from_usize ( 0 ) ,
309
+ args,
310
+ None ,
311
+ None ,
312
+ ) ) ,
313
+ IndexVec :: from_raw ( vec ! [ val] ) ,
314
+ )
315
+ } else {
316
+ // Poll::Pending
317
+ Rvalue :: Aggregate (
318
+ Box :: new ( AggregateKind :: Adt (
319
+ poll_def_id,
320
+ VariantIdx :: from_usize ( 1 ) ,
321
+ args,
322
+ None ,
323
+ None ,
324
+ ) ) ,
325
+ IndexVec :: new ( ) ,
326
+ )
322
327
}
323
328
}
324
- // `Option::None`
325
329
CoroutineKind :: Gen ( _) => {
330
+ let option_def_id = self . tcx . require_lang_item ( LangItem :: Option , None ) ;
331
+ let args = self . tcx . mk_args ( & [ self . old_yield_ty . into ( ) ] ) ;
326
332
if is_return {
327
- assert_eq ! ( self . state_adt_ref. variant( idx) . fields. len( ) , 0 ) ;
328
-
329
- statements. push ( Statement {
330
- kind : StatementKind :: Assign ( Box :: new ( (
331
- Place :: return_place ( ) ,
332
- Rvalue :: Aggregate ( Box :: new ( kind) , IndexVec :: new ( ) ) ,
333
- ) ) ) ,
334
- source_info,
335
- } ) ;
336
- return ;
333
+ // None
334
+ Rvalue :: Aggregate (
335
+ Box :: new ( AggregateKind :: Adt (
336
+ option_def_id,
337
+ VariantIdx :: from_usize ( 0 ) ,
338
+ args,
339
+ None ,
340
+ None ,
341
+ ) ) ,
342
+ IndexVec :: new ( ) ,
343
+ )
344
+ } else {
345
+ // Some(val)
346
+ Rvalue :: Aggregate (
347
+ Box :: new ( AggregateKind :: Adt (
348
+ option_def_id,
349
+ VariantIdx :: from_usize ( 1 ) ,
350
+ args,
351
+ None ,
352
+ None ,
353
+ ) ) ,
354
+ IndexVec :: from_raw ( vec ! [ val] ) ,
355
+ )
337
356
}
338
357
}
339
- CoroutineKind :: Coroutine => { }
340
- }
341
-
342
- // else: `Poll::Ready(x)`, `CoroutineState::Yielded(x)`, `CoroutineState::Complete(x)`, or `Option::Some(x)`
343
- assert_eq ! ( self . state_adt_ref. variant( idx) . fields. len( ) , 1 ) ;
358
+ CoroutineKind :: Coroutine => {
359
+ let coroutine_state_def_id =
360
+ self . tcx . require_lang_item ( LangItem :: CoroutineState , None ) ;
361
+ let args = self . tcx . mk_args ( & [ self . old_yield_ty . into ( ) , self . old_ret_ty . into ( ) ] ) ;
362
+ if is_return {
363
+ // CoroutineState::Complete(val)
364
+ Rvalue :: Aggregate (
365
+ Box :: new ( AggregateKind :: Adt (
366
+ coroutine_state_def_id,
367
+ VariantIdx :: from_usize ( 1 ) ,
368
+ args,
369
+ None ,
370
+ None ,
371
+ ) ) ,
372
+ IndexVec :: from_raw ( vec ! [ val] ) ,
373
+ )
374
+ } else {
375
+ // CoroutineState::Yielded(val)
376
+ Rvalue :: Aggregate (
377
+ Box :: new ( AggregateKind :: Adt (
378
+ coroutine_state_def_id,
379
+ VariantIdx :: from_usize ( 0 ) ,
380
+ args,
381
+ None ,
382
+ None ,
383
+ ) ) ,
384
+ IndexVec :: from_raw ( vec ! [ val] ) ,
385
+ )
386
+ }
387
+ }
388
+ } ;
344
389
345
390
statements. push ( Statement {
346
- kind : StatementKind :: Assign ( Box :: new ( (
347
- Place :: return_place ( ) ,
348
- Rvalue :: Aggregate ( Box :: new ( kind) , [ val] . into ( ) ) ,
349
- ) ) ) ,
391
+ kind : StatementKind :: Assign ( Box :: new ( ( Place :: return_place ( ) , rvalue) ) ) ,
350
392
source_info,
351
393
} ) ;
352
394
}
@@ -420,7 +462,7 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
420
462
421
463
let ret_val = match data. terminator ( ) . kind {
422
464
TerminatorKind :: Return => {
423
- Some ( ( true , None , Operand :: Move ( Place :: from ( self . new_ret_local ) ) , None ) )
465
+ Some ( ( true , None , Operand :: Move ( Place :: from ( self . old_ret_local ) ) , None ) )
424
466
}
425
467
TerminatorKind :: Yield { ref value, resume, resume_arg, drop } => {
426
468
Some ( ( false , Some ( ( resume, resume_arg) ) , value. clone ( ) , drop) )
@@ -1493,10 +1535,11 @@ pub(crate) fn mir_coroutine_witnesses<'tcx>(
1493
1535
1494
1536
impl < ' tcx > MirPass < ' tcx > for StateTransform {
1495
1537
fn run_pass ( & self , tcx : TyCtxt < ' tcx > , body : & mut Body < ' tcx > ) {
1496
- let Some ( yield_ty ) = body. yield_ty ( ) else {
1538
+ let Some ( old_yield_ty ) = body. yield_ty ( ) else {
1497
1539
// This only applies to coroutines
1498
1540
return ;
1499
1541
} ;
1542
+ let old_ret_ty = body. return_ty ( ) ;
1500
1543
1501
1544
assert ! ( body. coroutine_drop( ) . is_none( ) ) ;
1502
1545
@@ -1520,34 +1563,33 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
1520
1563
1521
1564
let is_async_kind = matches ! ( body. coroutine_kind( ) , Some ( CoroutineKind :: Async ( _) ) ) ;
1522
1565
let is_gen_kind = matches ! ( body. coroutine_kind( ) , Some ( CoroutineKind :: Gen ( _) ) ) ;
1523
- let ( state_adt_ref , state_args ) = match body. coroutine_kind ( ) . unwrap ( ) {
1566
+ let new_ret_ty = match body. coroutine_kind ( ) . unwrap ( ) {
1524
1567
CoroutineKind :: Async ( _) => {
1525
1568
// Compute Poll<return_ty>
1526
1569
let poll_did = tcx. require_lang_item ( LangItem :: Poll , None ) ;
1527
1570
let poll_adt_ref = tcx. adt_def ( poll_did) ;
1528
- let poll_args = tcx. mk_args ( & [ body . return_ty ( ) . into ( ) ] ) ;
1529
- ( poll_adt_ref, poll_args)
1571
+ let poll_args = tcx. mk_args ( & [ old_ret_ty . into ( ) ] ) ;
1572
+ Ty :: new_adt ( tcx , poll_adt_ref, poll_args)
1530
1573
}
1531
1574
CoroutineKind :: Gen ( _) => {
1532
1575
// Compute Option<yield_ty>
1533
1576
let option_did = tcx. require_lang_item ( LangItem :: Option , None ) ;
1534
1577
let option_adt_ref = tcx. adt_def ( option_did) ;
1535
- let option_args = tcx. mk_args ( & [ body . yield_ty ( ) . unwrap ( ) . into ( ) ] ) ;
1536
- ( option_adt_ref, option_args)
1578
+ let option_args = tcx. mk_args ( & [ old_yield_ty . into ( ) ] ) ;
1579
+ Ty :: new_adt ( tcx , option_adt_ref, option_args)
1537
1580
}
1538
1581
CoroutineKind :: Coroutine => {
1539
1582
// Compute CoroutineState<yield_ty, return_ty>
1540
1583
let state_did = tcx. require_lang_item ( LangItem :: CoroutineState , None ) ;
1541
1584
let state_adt_ref = tcx. adt_def ( state_did) ;
1542
- let state_args = tcx. mk_args ( & [ yield_ty . into ( ) , body . return_ty ( ) . into ( ) ] ) ;
1543
- ( state_adt_ref, state_args)
1585
+ let state_args = tcx. mk_args ( & [ old_yield_ty . into ( ) , old_ret_ty . into ( ) ] ) ;
1586
+ Ty :: new_adt ( tcx , state_adt_ref, state_args)
1544
1587
}
1545
1588
} ;
1546
- let ret_ty = Ty :: new_adt ( tcx, state_adt_ref, state_args) ;
1547
1589
1548
- // We rename RETURN_PLACE which has type mir.return_ty to new_ret_local
1590
+ // We rename RETURN_PLACE which has type mir.return_ty to old_ret_local
1549
1591
// RETURN_PLACE then is a fresh unused local with type ret_ty.
1550
- let new_ret_local = replace_local ( RETURN_PLACE , ret_ty , body, tcx) ;
1592
+ let old_ret_local = replace_local ( RETURN_PLACE , new_ret_ty , body, tcx) ;
1551
1593
1552
1594
// Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies.
1553
1595
if is_async_kind {
@@ -1564,17 +1606,18 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
1564
1606
} else {
1565
1607
body. local_decls [ resume_local] . ty
1566
1608
} ;
1567
- let new_resume_local = replace_local ( resume_local, resume_ty, body, tcx) ;
1609
+ let old_resume_local = replace_local ( resume_local, resume_ty, body, tcx) ;
1568
1610
1569
- // When first entering the coroutine, move the resume argument into its new local.
1611
+ // When first entering the coroutine, move the resume argument into its old local
1612
+ // (which is now a generator interior).
1570
1613
let source_info = SourceInfo :: outermost ( body. span ) ;
1571
1614
let stmts = & mut body. basic_blocks_mut ( ) [ START_BLOCK ] . statements ;
1572
1615
stmts. insert (
1573
1616
0 ,
1574
1617
Statement {
1575
1618
source_info,
1576
1619
kind : StatementKind :: Assign ( Box :: new ( (
1577
- new_resume_local . into ( ) ,
1620
+ old_resume_local . into ( ) ,
1578
1621
Rvalue :: Use ( Operand :: Move ( resume_local. into ( ) ) ) ,
1579
1622
) ) ) ,
1580
1623
} ,
@@ -1610,14 +1653,14 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
1610
1653
let mut transform = TransformVisitor {
1611
1654
tcx,
1612
1655
coroutine_kind : body. coroutine_kind ( ) . unwrap ( ) ,
1613
- state_adt_ref,
1614
- state_args,
1615
1656
remap,
1616
1657
storage_liveness,
1617
1658
always_live_locals,
1618
1659
suspension_points : Vec :: new ( ) ,
1619
- new_ret_local ,
1660
+ old_ret_local ,
1620
1661
discr_ty,
1662
+ old_ret_ty,
1663
+ old_yield_ty,
1621
1664
} ;
1622
1665
transform. visit_body ( body) ;
1623
1666
0 commit comments