Skip to content

Commit

Permalink
Rollup merge of rust-lang#119666 - compiler-errors:construct-coroutin…
Browse files Browse the repository at this point in the history
…e-info-immediately, r=cjgillot

Populate `yield` and `resume` types in MIR body while body is being initialized

I found it weird that we went back and populated these types *after* the body was constructed. Let's just do it all at once.
  • Loading branch information
compiler-errors authored Jan 7, 2024
2 parents 5117dc7 + 5e2b66f commit 854d113
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 78 deletions.
2 changes: 1 addition & 1 deletion compiler/rustc_const_eval/src/transform/promote_consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -969,7 +969,7 @@ pub fn promote_candidates<'tcx>(
0,
vec![],
body.span,
body.coroutine_kind(),
None,
body.tainted_by_errors,
);
promoted.phase = MirPhase::Analysis(AnalysisPhase::Initial);
Expand Down
29 changes: 19 additions & 10 deletions compiler/rustc_middle/src/mir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,23 @@ pub struct CoroutineInfo<'tcx> {
pub coroutine_kind: CoroutineKind,
}

impl<'tcx> CoroutineInfo<'tcx> {
// Sets up `CoroutineInfo` for a pre-coroutine-transform MIR body.
pub fn initial(
coroutine_kind: CoroutineKind,
yield_ty: Ty<'tcx>,
resume_ty: Ty<'tcx>,
) -> CoroutineInfo<'tcx> {
CoroutineInfo {
coroutine_kind,
yield_ty: Some(yield_ty),
resume_ty: Some(resume_ty),
coroutine_drop: None,
coroutine_layout: None,
}
}
}

/// The lowered representation of a single function.
#[derive(Clone, TyEncodable, TyDecodable, Debug, HashStable, TypeFoldable, TypeVisitable)]
pub struct Body<'tcx> {
Expand Down Expand Up @@ -367,7 +384,7 @@ impl<'tcx> Body<'tcx> {
arg_count: usize,
var_debug_info: Vec<VarDebugInfo<'tcx>>,
span: Span,
coroutine_kind: Option<CoroutineKind>,
coroutine: Option<Box<CoroutineInfo<'tcx>>>,
tainted_by_errors: Option<ErrorGuaranteed>,
) -> Self {
// We need `arg_count` locals, and one for the return place.
Expand All @@ -384,15 +401,7 @@ impl<'tcx> Body<'tcx> {
source,
basic_blocks: BasicBlocks::new(basic_blocks),
source_scopes,
coroutine: coroutine_kind.map(|coroutine_kind| {
Box::new(CoroutineInfo {
yield_ty: None,
resume_ty: None,
coroutine_drop: None,
coroutine_layout: None,
coroutine_kind,
})
}),
coroutine,
local_decls,
user_type_annotations,
arg_count,
Expand Down
2 changes: 0 additions & 2 deletions compiler/rustc_middle/src/thir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,6 @@ macro_rules! thir_with_elements {
}
}

pub const UPVAR_ENV_PARAM: ParamId = ParamId::from_u32(0);

thir_with_elements! {
body_type: BodyTy<'tcx>,

Expand Down
120 changes: 60 additions & 60 deletions compiler/rustc_mir_build/src/build/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use rustc_errors::ErrorGuaranteed;
use rustc_hir as hir;
use rustc_hir::def::DefKind;
use rustc_hir::def_id::{DefId, LocalDefId};
use rustc_hir::{CoroutineKind, Node};
use rustc_hir::Node;
use rustc_index::bit_set::GrowableBitSet;
use rustc_index::{Idx, IndexSlice, IndexVec};
use rustc_infer::infer::{InferCtxt, TyCtxtInferExt};
Expand Down Expand Up @@ -177,7 +177,7 @@ struct Builder<'a, 'tcx> {
check_overflow: bool,
fn_span: Span,
arg_count: usize,
coroutine_kind: Option<CoroutineKind>,
coroutine: Option<Box<CoroutineInfo<'tcx>>>,

/// The current set of scopes, updated as we traverse;
/// see the `scope` module for more details.
Expand Down Expand Up @@ -458,7 +458,6 @@ fn construct_fn<'tcx>(
) -> Body<'tcx> {
let span = tcx.def_span(fn_def);
let fn_id = tcx.local_def_id_to_hir_id(fn_def);
let coroutine_kind = tcx.coroutine_kind(fn_def);

// The representation of thir for `-Zunpretty=thir-tree` relies on
// the entry expression being the last element of `thir.exprs`.
Expand Down Expand Up @@ -488,17 +487,15 @@ fn construct_fn<'tcx>(

let arguments = &thir.params;

let (resume_ty, yield_ty, return_ty) = if coroutine_kind.is_some() {
let coroutine_ty = arguments[thir::UPVAR_ENV_PARAM].ty;
let coroutine_sig = match coroutine_ty.kind() {
ty::Coroutine(_, gen_args, ..) => gen_args.as_coroutine().sig(),
_ => {
span_bug!(span, "coroutine w/o coroutine type: {:?}", coroutine_ty)
}
};
(Some(coroutine_sig.resume_ty), Some(coroutine_sig.yield_ty), coroutine_sig.return_ty)
} else {
(None, None, fn_sig.output())
let return_ty = fn_sig.output();
let coroutine = match tcx.type_of(fn_def).instantiate_identity().kind() {
ty::Coroutine(_, args) => Some(Box::new(CoroutineInfo::initial(
tcx.coroutine_kind(fn_def).unwrap(),
args.as_coroutine().yield_ty(),
args.as_coroutine().resume_ty(),
))),
ty::Closure(..) | ty::FnDef(..) => None,
ty => span_bug!(span_with_body, "unexpected type of body: {ty:?}"),
};

if let Some(custom_mir_attr) =
Expand Down Expand Up @@ -529,7 +526,7 @@ fn construct_fn<'tcx>(
safety,
return_ty,
return_ty_span,
coroutine_kind,
coroutine,
);

let call_site_scope =
Expand Down Expand Up @@ -563,11 +560,6 @@ fn construct_fn<'tcx>(
None
};

if coroutine_kind.is_some() {
body.coroutine.as_mut().unwrap().yield_ty = yield_ty;
body.coroutine.as_mut().unwrap().resume_ty = resume_ty;
}

body
}

Expand Down Expand Up @@ -632,47 +624,62 @@ fn construct_const<'a, 'tcx>(
fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -> Body<'_> {
let span = tcx.def_span(def_id);
let hir_id = tcx.local_def_id_to_hir_id(def_id);
let coroutine_kind = tcx.coroutine_kind(def_id);

let (inputs, output, resume_ty, yield_ty) = match tcx.def_kind(def_id) {
let (inputs, output, coroutine) = match tcx.def_kind(def_id) {
DefKind::Const
| DefKind::AssocConst
| DefKind::AnonConst
| DefKind::InlineConst
| DefKind::Static(_) => (vec![], tcx.type_of(def_id).instantiate_identity(), None, None),
| DefKind::Static(_) => (vec![], tcx.type_of(def_id).instantiate_identity(), None),
DefKind::Ctor(..) | DefKind::Fn | DefKind::AssocFn => {
let sig = tcx.liberate_late_bound_regions(
def_id.to_def_id(),
tcx.fn_sig(def_id).instantiate_identity(),
);
(sig.inputs().to_vec(), sig.output(), None, None)
}
DefKind::Closure if coroutine_kind.is_some() => {
let coroutine_ty = tcx.type_of(def_id).instantiate_identity();
let ty::Coroutine(_, args) = coroutine_ty.kind() else {
bug!("expected type of coroutine-like closure to be a coroutine")
};
let args = args.as_coroutine();
let resume_ty = args.resume_ty();
let yield_ty = args.yield_ty();
let return_ty = args.return_ty();
(vec![coroutine_ty, args.resume_ty()], return_ty, Some(resume_ty), Some(yield_ty))
(sig.inputs().to_vec(), sig.output(), None)
}
DefKind::Closure => {
let closure_ty = tcx.type_of(def_id).instantiate_identity();
let ty::Closure(_, args) = closure_ty.kind() else {
bug!("expected type of closure to be a closure")
};
let args = args.as_closure();
let sig = tcx.liberate_late_bound_regions(def_id.to_def_id(), args.sig());
let self_ty = match args.kind() {
ty::ClosureKind::Fn => Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, closure_ty),
ty::ClosureKind::FnMut => Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, closure_ty),
ty::ClosureKind::FnOnce => closure_ty,
};
([self_ty].into_iter().chain(sig.inputs().to_vec()).collect(), sig.output(), None, None)
match closure_ty.kind() {
ty::Closure(_, args) => {
let args = args.as_closure();
let sig = tcx.liberate_late_bound_regions(def_id.to_def_id(), args.sig());
let self_ty = match args.kind() {
ty::ClosureKind::Fn => {
Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, closure_ty)
}
ty::ClosureKind::FnMut => {
Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, closure_ty)
}
ty::ClosureKind::FnOnce => closure_ty,
};
(
[self_ty].into_iter().chain(sig.inputs().to_vec()).collect(),
sig.output(),
None,
)
}
ty::Coroutine(_, args) => {
let args = args.as_coroutine();
let resume_ty = args.resume_ty();
let yield_ty = args.yield_ty();
let return_ty = args.return_ty();
(
vec![closure_ty, args.resume_ty()],
return_ty,
Some(Box::new(CoroutineInfo::initial(
tcx.coroutine_kind(def_id).unwrap(),
yield_ty,
resume_ty,
))),
)
}
_ => {
span_bug!(span, "expected type of closure body to be a closure or coroutine");
}
}
}
dk => bug!("{:?} is not a body: {:?}", def_id, dk),
dk => span_bug!(span, "{:?} is not a body: {:?}", def_id, dk),
};

let source_info = SourceInfo { span, scope: OUTERMOST_SOURCE_SCOPE };
Expand All @@ -696,7 +703,7 @@ fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -

cfg.terminate(START_BLOCK, source_info, TerminatorKind::Unreachable);

let mut body = Body::new(
Body::new(
MirSource::item(def_id.to_def_id()),
cfg.basic_blocks,
source_scopes,
Expand All @@ -705,16 +712,9 @@ fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -
inputs.len(),
vec![],
span,
coroutine_kind,
coroutine,
Some(guar),
);

body.coroutine.as_mut().map(|gen| {
gen.yield_ty = yield_ty;
gen.resume_ty = resume_ty;
});

body
)
}

impl<'a, 'tcx> Builder<'a, 'tcx> {
Expand All @@ -728,7 +728,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
safety: Safety,
return_ty: Ty<'tcx>,
return_span: Span,
coroutine_kind: Option<CoroutineKind>,
coroutine: Option<Box<CoroutineInfo<'tcx>>>,
) -> Builder<'a, 'tcx> {
let tcx = infcx.tcx;
let attrs = tcx.hir().attrs(hir_id);
Expand Down Expand Up @@ -759,7 +759,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
cfg: CFG { basic_blocks: IndexVec::new() },
fn_span: span,
arg_count,
coroutine_kind,
coroutine,
scopes: scope::Scopes::new(),
block_context: BlockContext::new(),
source_scopes: IndexVec::new(),
Expand Down Expand Up @@ -803,7 +803,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
self.arg_count,
self.var_debug_info,
self.fn_span,
self.coroutine_kind,
self.coroutine,
None,
)
}
Expand Down
10 changes: 5 additions & 5 deletions compiler/rustc_mir_build/src/build/scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
// If we are emitting a `drop` statement, we need to have the cached
// diverge cleanup pads ready in case that drop panics.
let needs_cleanup = self.scopes.scopes.last().is_some_and(|scope| scope.needs_cleanup());
let is_coroutine = self.coroutine_kind.is_some();
let is_coroutine = self.coroutine.is_some();
let unwind_to = if needs_cleanup { self.diverge_cleanup() } else { DropIdx::MAX };

let scope = self.scopes.scopes.last().expect("leave_top_scope called with no scopes");
Expand Down Expand Up @@ -960,7 +960,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
// path, we only need to invalidate the cache for drops that happen on
// the unwind or coroutine drop paths. This means that for
// non-coroutines we don't need to invalidate caches for `DropKind::Storage`.
let invalidate_caches = needs_drop || self.coroutine_kind.is_some();
let invalidate_caches = needs_drop || self.coroutine.is_some();
for scope in self.scopes.scopes.iter_mut().rev() {
if invalidate_caches {
scope.invalidate_cache();
Expand Down Expand Up @@ -1073,7 +1073,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
return cached_drop;
}

let is_coroutine = self.coroutine_kind.is_some();
let is_coroutine = self.coroutine.is_some();
for scope in &mut self.scopes.scopes[uncached_scope..=target] {
for drop in &scope.drops {
if is_coroutine || drop.kind == DropKind::Value {
Expand Down Expand Up @@ -1318,7 +1318,7 @@ impl<'a, 'tcx: 'a> Builder<'a, 'tcx> {
blocks[ROOT_NODE] = continue_block;

drops.build_mir::<ExitScopes>(&mut self.cfg, &mut blocks);
let is_coroutine = self.coroutine_kind.is_some();
let is_coroutine = self.coroutine.is_some();

// Link the exit drop tree to unwind drop tree.
if drops.drops.iter().any(|(drop, _)| drop.kind == DropKind::Value) {
Expand Down Expand Up @@ -1355,7 +1355,7 @@ impl<'a, 'tcx: 'a> Builder<'a, 'tcx> {

/// Build the unwind and coroutine drop trees.
pub(crate) fn build_drop_trees(&mut self) {
if self.coroutine_kind.is_some() {
if self.coroutine.is_some() {
self.build_coroutine_drop_trees();
} else {
Self::build_unwind_tree(
Expand Down

0 comments on commit 854d113

Please sign in to comment.