Skip to content

Commit a0cbc16

Browse files
Rework coroutine transform to be more flexible in preparation for async generators
1 parent ae612be commit a0cbc16

File tree

1 file changed

+123
-80
lines changed

1 file changed

+123
-80
lines changed

compiler/rustc_mir_transform/src/coroutine.rs

Lines changed: 123 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,9 @@ use rustc_index::{Idx, IndexVec};
6666
use rustc_middle::mir::dump_mir;
6767
use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
6868
use rustc_middle::mir::*;
69+
use rustc_middle::ty::CoroutineArgs;
6970
use rustc_middle::ty::InstanceDef;
70-
use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt};
71-
use rustc_middle::ty::{CoroutineArgs, GenericArgsRef};
71+
use rustc_middle::ty::{self, Ty, TyCtxt};
7272
use rustc_mir_dataflow::impls::{
7373
MaybeBorrowedLocals, MaybeLiveLocals, MaybeRequiresStorage, MaybeStorageLive,
7474
};
@@ -225,8 +225,6 @@ struct SuspensionPoint<'tcx> {
225225
struct TransformVisitor<'tcx> {
226226
tcx: TyCtxt<'tcx>,
227227
coroutine_kind: hir::CoroutineKind,
228-
state_adt_ref: AdtDef<'tcx>,
229-
state_args: GenericArgsRef<'tcx>,
230228

231229
// The type of the discriminant in the coroutine struct
232230
discr_ty: Ty<'tcx>,
@@ -245,21 +243,34 @@ struct TransformVisitor<'tcx> {
245243
always_live_locals: BitSet<Local>,
246244

247245
// The original RETURN_PLACE local
248-
new_ret_local: Local,
246+
old_ret_local: Local,
247+
248+
old_yield_ty: Ty<'tcx>,
249+
250+
old_ret_ty: Ty<'tcx>,
249251
}
250252

251253
impl<'tcx> TransformVisitor<'tcx> {
252254
fn insert_none_ret_block(&self, body: &mut Body<'tcx>) -> BasicBlock {
253-
let block = BasicBlock::new(body.basic_blocks.len());
255+
assert!(matches!(self.coroutine_kind, CoroutineKind::Gen(_)));
254256

257+
let block = BasicBlock::new(body.basic_blocks.len());
255258
let source_info = SourceInfo::outermost(body.span);
259+
let option_def_id = self.tcx.require_lang_item(LangItem::Option, None);
256260

257-
let (kind, idx) = self.coroutine_state_adt_and_variant_idx(true);
258-
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
259261
let statements = vec![Statement {
260262
kind: StatementKind::Assign(Box::new((
261263
Place::return_place(),
262-
Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
264+
Rvalue::Aggregate(
265+
Box::new(AggregateKind::Adt(
266+
option_def_id,
267+
VariantIdx::from_usize(0),
268+
self.tcx.mk_args(&[self.old_yield_ty.into()]),
269+
None,
270+
None,
271+
)),
272+
IndexVec::new(),
273+
),
263274
))),
264275
source_info,
265276
}];
@@ -273,23 +284,6 @@ impl<'tcx> TransformVisitor<'tcx> {
273284
block
274285
}
275286

276-
fn coroutine_state_adt_and_variant_idx(
277-
&self,
278-
is_return: bool,
279-
) -> (AggregateKind<'tcx>, VariantIdx) {
280-
let idx = VariantIdx::new(match (is_return, self.coroutine_kind) {
281-
(true, hir::CoroutineKind::Coroutine) => 1, // CoroutineState::Complete
282-
(false, hir::CoroutineKind::Coroutine) => 0, // CoroutineState::Yielded
283-
(true, hir::CoroutineKind::Async(_)) => 0, // Poll::Ready
284-
(false, hir::CoroutineKind::Async(_)) => 1, // Poll::Pending
285-
(true, hir::CoroutineKind::Gen(_)) => 0, // Option::None
286-
(false, hir::CoroutineKind::Gen(_)) => 1, // Option::Some
287-
});
288-
289-
let kind = AggregateKind::Adt(self.state_adt_ref.did(), idx, self.state_args, None, None);
290-
(kind, idx)
291-
}
292-
293287
// Make a `CoroutineState` or `Poll` variant assignment.
294288
//
295289
// `core::ops::CoroutineState` only has single element tuple variants,
@@ -302,51 +296,99 @@ impl<'tcx> TransformVisitor<'tcx> {
302296
is_return: bool,
303297
statements: &mut Vec<Statement<'tcx>>,
304298
) {
305-
let (kind, idx) = self.coroutine_state_adt_and_variant_idx(is_return);
306-
307-
match self.coroutine_kind {
308-
// `Poll::Pending`
299+
let rvalue = match self.coroutine_kind {
309300
CoroutineKind::Async(_) => {
310-
if !is_return {
311-
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
312-
313-
// FIXME(swatinem): assert that `val` is indeed unit?
314-
statements.push(Statement {
315-
kind: StatementKind::Assign(Box::new((
316-
Place::return_place(),
317-
Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
318-
))),
319-
source_info,
320-
});
321-
return;
301+
let poll_def_id = self.tcx.require_lang_item(LangItem::Poll, None);
302+
let args = self.tcx.mk_args(&[self.old_ret_ty.into()]);
303+
if is_return {
304+
// Poll::Ready(val)
305+
Rvalue::Aggregate(
306+
Box::new(AggregateKind::Adt(
307+
poll_def_id,
308+
VariantIdx::from_usize(0),
309+
args,
310+
None,
311+
None,
312+
)),
313+
IndexVec::from_raw(vec![val]),
314+
)
315+
} else {
316+
// Poll::Pending
317+
Rvalue::Aggregate(
318+
Box::new(AggregateKind::Adt(
319+
poll_def_id,
320+
VariantIdx::from_usize(1),
321+
args,
322+
None,
323+
None,
324+
)),
325+
IndexVec::new(),
326+
)
322327
}
323328
}
324-
// `Option::None`
325329
CoroutineKind::Gen(_) => {
330+
let option_def_id = self.tcx.require_lang_item(LangItem::Option, None);
331+
let args = self.tcx.mk_args(&[self.old_yield_ty.into()]);
326332
if is_return {
327-
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
328-
329-
statements.push(Statement {
330-
kind: StatementKind::Assign(Box::new((
331-
Place::return_place(),
332-
Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
333-
))),
334-
source_info,
335-
});
336-
return;
333+
// None
334+
Rvalue::Aggregate(
335+
Box::new(AggregateKind::Adt(
336+
option_def_id,
337+
VariantIdx::from_usize(0),
338+
args,
339+
None,
340+
None,
341+
)),
342+
IndexVec::new(),
343+
)
344+
} else {
345+
// Some(val)
346+
Rvalue::Aggregate(
347+
Box::new(AggregateKind::Adt(
348+
option_def_id,
349+
VariantIdx::from_usize(1),
350+
args,
351+
None,
352+
None,
353+
)),
354+
IndexVec::from_raw(vec![val]),
355+
)
337356
}
338357
}
339-
CoroutineKind::Coroutine => {}
340-
}
341-
342-
// else: `Poll::Ready(x)`, `CoroutineState::Yielded(x)`, `CoroutineState::Complete(x)`, or `Option::Some(x)`
343-
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 1);
358+
CoroutineKind::Coroutine => {
359+
let coroutine_state_def_id =
360+
self.tcx.require_lang_item(LangItem::CoroutineState, None);
361+
let args = self.tcx.mk_args(&[self.old_yield_ty.into(), self.old_ret_ty.into()]);
362+
if is_return {
363+
// CoroutineState::Complete(val)
364+
Rvalue::Aggregate(
365+
Box::new(AggregateKind::Adt(
366+
coroutine_state_def_id,
367+
VariantIdx::from_usize(1),
368+
args,
369+
None,
370+
None,
371+
)),
372+
IndexVec::from_raw(vec![val]),
373+
)
374+
} else {
375+
// CoroutineState::Yielded(val)
376+
Rvalue::Aggregate(
377+
Box::new(AggregateKind::Adt(
378+
coroutine_state_def_id,
379+
VariantIdx::from_usize(0),
380+
args,
381+
None,
382+
None,
383+
)),
384+
IndexVec::from_raw(vec![val]),
385+
)
386+
}
387+
}
388+
};
344389

345390
statements.push(Statement {
346-
kind: StatementKind::Assign(Box::new((
347-
Place::return_place(),
348-
Rvalue::Aggregate(Box::new(kind), [val].into()),
349-
))),
391+
kind: StatementKind::Assign(Box::new((Place::return_place(), rvalue))),
350392
source_info,
351393
});
352394
}
@@ -420,7 +462,7 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
420462

421463
let ret_val = match data.terminator().kind {
422464
TerminatorKind::Return => {
423-
Some((true, None, Operand::Move(Place::from(self.new_ret_local)), None))
465+
Some((true, None, Operand::Move(Place::from(self.old_ret_local)), None))
424466
}
425467
TerminatorKind::Yield { ref value, resume, resume_arg, drop } => {
426468
Some((false, Some((resume, resume_arg)), value.clone(), drop))
@@ -1493,10 +1535,11 @@ pub(crate) fn mir_coroutine_witnesses<'tcx>(
14931535

14941536
impl<'tcx> MirPass<'tcx> for StateTransform {
14951537
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
1496-
let Some(yield_ty) = body.yield_ty() else {
1538+
let Some(old_yield_ty) = body.yield_ty() else {
14971539
// This only applies to coroutines
14981540
return;
14991541
};
1542+
let old_ret_ty = body.return_ty();
15001543

15011544
assert!(body.coroutine_drop().is_none());
15021545

@@ -1520,34 +1563,33 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
15201563

15211564
let is_async_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Async(_)));
15221565
let is_gen_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Gen(_)));
1523-
let (state_adt_ref, state_args) = match body.coroutine_kind().unwrap() {
1566+
let new_ret_ty = match body.coroutine_kind().unwrap() {
15241567
CoroutineKind::Async(_) => {
15251568
// Compute Poll<return_ty>
15261569
let poll_did = tcx.require_lang_item(LangItem::Poll, None);
15271570
let poll_adt_ref = tcx.adt_def(poll_did);
1528-
let poll_args = tcx.mk_args(&[body.return_ty().into()]);
1529-
(poll_adt_ref, poll_args)
1571+
let poll_args = tcx.mk_args(&[old_ret_ty.into()]);
1572+
Ty::new_adt(tcx, poll_adt_ref, poll_args)
15301573
}
15311574
CoroutineKind::Gen(_) => {
15321575
// Compute Option<yield_ty>
15331576
let option_did = tcx.require_lang_item(LangItem::Option, None);
15341577
let option_adt_ref = tcx.adt_def(option_did);
1535-
let option_args = tcx.mk_args(&[body.yield_ty().unwrap().into()]);
1536-
(option_adt_ref, option_args)
1578+
let option_args = tcx.mk_args(&[old_yield_ty.into()]);
1579+
Ty::new_adt(tcx, option_adt_ref, option_args)
15371580
}
15381581
CoroutineKind::Coroutine => {
15391582
// Compute CoroutineState<yield_ty, return_ty>
15401583
let state_did = tcx.require_lang_item(LangItem::CoroutineState, None);
15411584
let state_adt_ref = tcx.adt_def(state_did);
1542-
let state_args = tcx.mk_args(&[yield_ty.into(), body.return_ty().into()]);
1543-
(state_adt_ref, state_args)
1585+
let state_args = tcx.mk_args(&[old_yield_ty.into(), old_ret_ty.into()]);
1586+
Ty::new_adt(tcx, state_adt_ref, state_args)
15441587
}
15451588
};
1546-
let ret_ty = Ty::new_adt(tcx, state_adt_ref, state_args);
15471589

1548-
// We rename RETURN_PLACE which has type mir.return_ty to new_ret_local
1590+
// We rename RETURN_PLACE which has type mir.return_ty to old_ret_local
15491591
// RETURN_PLACE then is a fresh unused local with type ret_ty.
1550-
let new_ret_local = replace_local(RETURN_PLACE, ret_ty, body, tcx);
1592+
let old_ret_local = replace_local(RETURN_PLACE, new_ret_ty, body, tcx);
15511593

15521594
// Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies.
15531595
if is_async_kind {
@@ -1564,17 +1606,18 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
15641606
} else {
15651607
body.local_decls[resume_local].ty
15661608
};
1567-
let new_resume_local = replace_local(resume_local, resume_ty, body, tcx);
1609+
let old_resume_local = replace_local(resume_local, resume_ty, body, tcx);
15681610

1569-
// When first entering the coroutine, move the resume argument into its new local.
1611+
// When first entering the coroutine, move the resume argument into its old local
1612+
// (which is now a generator interior).
15701613
let source_info = SourceInfo::outermost(body.span);
15711614
let stmts = &mut body.basic_blocks_mut()[START_BLOCK].statements;
15721615
stmts.insert(
15731616
0,
15741617
Statement {
15751618
source_info,
15761619
kind: StatementKind::Assign(Box::new((
1577-
new_resume_local.into(),
1620+
old_resume_local.into(),
15781621
Rvalue::Use(Operand::Move(resume_local.into())),
15791622
))),
15801623
},
@@ -1610,14 +1653,14 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
16101653
let mut transform = TransformVisitor {
16111654
tcx,
16121655
coroutine_kind: body.coroutine_kind().unwrap(),
1613-
state_adt_ref,
1614-
state_args,
16151656
remap,
16161657
storage_liveness,
16171658
always_live_locals,
16181659
suspension_points: Vec::new(),
1619-
new_ret_local,
1660+
old_ret_local,
16201661
discr_ty,
1662+
old_ret_ty,
1663+
old_yield_ty,
16211664
};
16221665
transform.visit_body(body);
16231666

0 commit comments

Comments
 (0)