Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable MIR inlining for generators too #99782

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions compiler/rustc_borrowck/src/type_check/input_output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,28 +115,26 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
}
}

let body_yield_ty = body.yield_ty(self.tcx());
debug!(
"equate_inputs_and_outputs: body.yield_ty {:?}, universal_regions.yield_ty {:?}",
body.yield_ty(),
universal_regions.yield_ty
body_yield_ty, universal_regions.yield_ty
);

// We will not have a universal_regions.yield_ty if we yield (by accident)
// outside of a generator and return an `impl Trait`, so emit a delay_span_bug
// because we don't want to panic in an assert here if we've already got errors.
if body.yield_ty().is_some() != universal_regions.yield_ty.is_some() {
if body_yield_ty.is_some() != universal_regions.yield_ty.is_some() {
self.tcx().sess.delay_span_bug(
body.span,
&format!(
"Expected body to have yield_ty ({:?}) iff we have a UR yield_ty ({:?})",
body.yield_ty(),
universal_regions.yield_ty,
body_yield_ty, universal_regions.yield_ty,
),
);
}

if let (Some(mir_yield_ty), Some(ur_yield_ty)) =
(body.yield_ty(), universal_regions.yield_ty)
if let (Some(mir_yield_ty), Some(ur_yield_ty)) = (body_yield_ty, universal_regions.yield_ty)
{
let yield_span = body.local_decls[RETURN_PLACE].source_info.span;
self.equate_normalized_input_or_output(ur_yield_ty, mir_yield_ty, yield_span);
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_borrowck/src/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1473,7 +1473,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
self.check_operand(value, term_location);

let value_ty = value.ty(body, tcx);
match body.yield_ty() {
match body.yield_ty(tcx) {
None => span_mirbug!(self, term, "yield in non-generator"),
Some(ty) => {
if let Err(terr) = self.sub_types(
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_llvm/src/debuginfo/metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1235,8 +1235,8 @@ fn generator_layout_and_saved_local_names<'tcx>(
tcx: TyCtxt<'tcx>,
def_id: DefId,
) -> (&'tcx GeneratorLayout<'tcx>, IndexVec<mir::GeneratorSavedLocal, Option<Symbol>>) {
let generator_layout = &tcx.mir_generator_info(def_id).generator_layout;
let body = tcx.optimized_mir(def_id);
let generator_layout = body.generator_layout().unwrap();
let mut generator_saved_local_names = IndexVec::from_elem(None, &generator_layout.field_tys);

let state_arg = mir::Local::new(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ impl<'mir, 'tcx> Checker<'mir, 'tcx> {

// `async` functions cannot be `const fn`. This is checked during AST lowering, so there's
// no need to emit duplicate errors here.
if self.ccx.is_async() || body.generator.is_some() {
if self.ccx.is_async() || tcx.generator_kind(def_id).is_some() {
tcx.sess.delay_span_bug(body.span, "`async` functions cannot be `const fn`");
return;
}
Expand Down
1 change: 0 additions & 1 deletion compiler/rustc_const_eval/src/transform/promote_consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -973,7 +973,6 @@ pub fn promote_candidates<'tcx>(
0,
vec![],
body.span,
body.generator_kind(),
body.tainted_by_errors,
);
promoted.phase = MirPhase::Analysis(AnalysisPhase::Initial);
Expand Down
22 changes: 8 additions & 14 deletions compiler/rustc_const_eval/src/transform/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ impl<'tcx> MirPass<'tcx> for Validator {
storage_liveness,
place_cache: Vec::new(),
value_cache: Vec::new(),
is_generator: tcx.generator_kind(def_id).is_some(),
}
.visit_body(body);
}
Expand Down Expand Up @@ -117,6 +118,7 @@ struct TypeChecker<'a, 'tcx> {
storage_liveness: ResultsCursor<'a, 'tcx, MaybeStorageLive>,
place_cache: Vec<PlaceRef<'tcx>>,
value_cache: Vec<u128>,
is_generator: bool,
}

impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
Expand Down Expand Up @@ -323,16 +325,8 @@ impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> {
}
&ty::Generator(def_id, substs, _) => {
let f_ty = if let Some(var) = parent_ty.variant_index {
let gen_body = if def_id == self.body.source.def_id() {
self.body
} else {
self.tcx.optimized_mir(def_id)
};

let Some(layout) = gen_body.generator_layout() else {
self.fail(location, format!("No generator layout for {:?}", parent_ty));
return;
};
let generator_info = self.tcx.mir_generator_info(def_id);
let layout = &generator_info.generator_layout;

let Some(&local) = layout.variant_fields[var].get(f) else {
fail_out_of_bounds(self, location);
Expand Down Expand Up @@ -836,10 +830,10 @@ impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> {
}
}
TerminatorKind::Yield { resume, drop, .. } => {
if self.body.generator.is_none() {
if !self.is_generator {
self.fail(location, "`Yield` cannot appear outside generator bodies");
}
if self.mir_phase >= MirPhase::Runtime(RuntimePhase::Initial) {
if self.mir_phase >= MirPhase::Runtime(RuntimePhase::GeneratorsLowered) {
self.fail(location, "`Yield` should have been replaced by generator lowering");
}
self.check_edge(location, *resume, EdgeKind::Normal);
Expand Down Expand Up @@ -878,10 +872,10 @@ impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> {
}
}
TerminatorKind::GeneratorDrop => {
if self.body.generator.is_none() {
if !self.is_generator {
self.fail(location, "`GeneratorDrop` cannot appear outside generator bodies");
}
if self.mir_phase >= MirPhase::Runtime(RuntimePhase::Initial) {
if self.mir_phase >= MirPhase::Runtime(RuntimePhase::GeneratorsLowered) {
self.fail(
location,
"`GeneratorDrop` should have been replaced by generator lowering",
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_metadata/src/rmeta/decoder/cstore_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ provide! { tcx, def_id, other, cdata,
thir_abstract_const => { table }
optimized_mir => { table }
mir_for_ctfe => { table }
mir_generator_info => { table }
promoted_mir => { table }
def_span => { table }
def_ident_span => { table }
Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_metadata/src/rmeta/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1387,6 +1387,10 @@ impl<'a, 'tcx> EncodeContext<'a, 'tcx> {
debug!("EntryBuilder::encode_mir({:?})", def_id);
if encode_opt {
record!(self.tables.optimized_mir[def_id.to_def_id()] <- tcx.optimized_mir(def_id));

if let DefKind::Generator = self.tcx.def_kind(def_id) {
record!(self.tables.mir_generator_info[def_id.to_def_id()] <- tcx.mir_generator_info(def_id));
}
}
if encode_const {
record!(self.tables.mir_for_ctfe[def_id.to_def_id()] <- tcx.mir_for_ctfe(def_id));
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_metadata/src/rmeta/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ define_tables! {
object_lifetime_default: Table<DefIndex, LazyValue<ObjectLifetimeDefault>>,
optimized_mir: Table<DefIndex, LazyValue<mir::Body<'static>>>,
mir_for_ctfe: Table<DefIndex, LazyValue<mir::Body<'static>>>,
mir_generator_info: Table<DefIndex, LazyValue<mir::GeneratorInfo<'static>>>,
promoted_mir: Table<DefIndex, LazyValue<IndexVec<mir::Promoted, mir::Body<'static>>>>,
// FIXME(compiler-errors): Why isn't this a LazyArray?
thir_abstract_const: Table<DefIndex, LazyValue<&'static [ty::abstract_const::Node<'static>]>>,
Expand Down
5 changes: 5 additions & 0 deletions compiler/rustc_middle/src/arena.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ macro_rules! arena_types {
[] steal_thir: rustc_data_structures::steal::Steal<rustc_middle::thir::Thir<'tcx>>,
[] steal_mir: rustc_data_structures::steal::Steal<rustc_middle::mir::Body<'tcx>>,
[decode] mir: rustc_middle::mir::Body<'tcx>,
[decode] generator_info: rustc_middle::mir::GeneratorInfo<'tcx>,
[] mir_generator_lowered: (
rustc_data_structures::steal::Steal<rustc_middle::mir::Body<'tcx>>,
Option<rustc_middle::mir::GeneratorInfo<'tcx>>,
),
[] steal_promoted:
rustc_data_structures::steal::Steal<
rustc_index::vec::IndexVec<
Expand Down
50 changes: 11 additions & 39 deletions compiler/rustc_middle/src/mir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,18 +179,11 @@ impl<'tcx> MirSource<'tcx> {

#[derive(Clone, TyEncodable, TyDecodable, Debug, HashStable, TypeFoldable, TypeVisitable)]
pub struct GeneratorInfo<'tcx> {
/// The yield type of the function, if it is a generator.
pub yield_ty: Option<Ty<'tcx>>,

/// Generator drop glue.
pub generator_drop: Option<Body<'tcx>>,
pub generator_drop: Body<'tcx>,

/// The layout of a generator. Produced by the state transformation.
pub generator_layout: Option<GeneratorLayout<'tcx>>,

/// If this is a generator then record the type of source expression that caused this generator
/// to be created.
pub generator_kind: GeneratorKind,
pub generator_layout: GeneratorLayout<'tcx>,
}

/// The lowered representation of a single function.
Expand All @@ -213,8 +206,6 @@ pub struct Body<'tcx> {
/// and used for debuginfo. Indexed by a `SourceScope`.
pub source_scopes: IndexVec<SourceScope, SourceScopeData<'tcx>>,

pub generator: Option<Box<GeneratorInfo<'tcx>>>,

/// Declarations of locals.
///
/// The first local is the return value pointer, followed by `arg_count`
Expand Down Expand Up @@ -279,7 +270,6 @@ impl<'tcx> Body<'tcx> {
arg_count: usize,
var_debug_info: Vec<VarDebugInfo<'tcx>>,
span: Span,
generator_kind: Option<GeneratorKind>,
tainted_by_errors: Option<ErrorGuaranteed>,
) -> Self {
// We need `arg_count` locals, and one for the return place.
Expand All @@ -295,14 +285,6 @@ impl<'tcx> Body<'tcx> {
source,
basic_blocks: BasicBlocks::new(basic_blocks),
source_scopes,
generator: generator_kind.map(|generator_kind| {
Box::new(GeneratorInfo {
yield_ty: None,
generator_drop: None,
generator_layout: None,
generator_kind,
})
}),
local_decls,
user_type_annotations,
arg_count,
Expand All @@ -328,7 +310,6 @@ impl<'tcx> Body<'tcx> {
source: MirSource::item(CRATE_DEF_ID.to_def_id()),
basic_blocks: BasicBlocks::new(basic_blocks),
source_scopes: IndexVec::new(),
generator: None,
local_decls: IndexVec::new(),
user_type_annotations: IndexVec::new(),
arg_count: 0,
Expand Down Expand Up @@ -460,24 +441,15 @@ impl<'tcx> Body<'tcx> {
.unwrap_or_else(|| Either::Right(block_data.terminator()))
}

#[inline]
pub fn yield_ty(&self) -> Option<Ty<'tcx>> {
self.generator.as_ref().and_then(|generator| generator.yield_ty)
}

#[inline]
pub fn generator_layout(&self) -> Option<&GeneratorLayout<'tcx>> {
self.generator.as_ref().and_then(|generator| generator.generator_layout.as_ref())
}

#[inline]
pub fn generator_drop(&self) -> Option<&Body<'tcx>> {
self.generator.as_ref().and_then(|generator| generator.generator_drop.as_ref())
}

#[inline]
pub fn generator_kind(&self) -> Option<GeneratorKind> {
self.generator.as_ref().map(|generator| generator.generator_kind)
pub fn yield_ty(&self, tcx: TyCtxt<'_>) -> Option<Ty<'tcx>> {
if tcx.generator_kind(self.source.def_id()).is_none() {
return None;
};
let gen_ty = self.local_decls.raw[1].ty;
match *gen_ty.kind() {
ty::Generator(_, substs, _) => Some(substs.as_generator().sig().yield_ty),
_ => None,
}
}
}

Expand Down
7 changes: 5 additions & 2 deletions compiler/rustc_middle/src/mir/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,10 @@ fn dump_matched_mir_node<'tcx, F>(
Some(promoted) => write!(file, "::{:?}`", promoted)?,
}
writeln!(file, " {} {}", disambiguator, pass_name)?;
if let Some(ref layout) = body.generator_layout() {
// Trying to fetch the layout before it has been computed would create a query cycle.
if body.phase >= MirPhase::Runtime(RuntimePhase::GeneratorsLowered)
&& let Some(layout) = tcx.generator_layout(body.source.def_id())
{
writeln!(file, "/* generator_layout = {:#?} */", layout)?;
}
writeln!(file)?;
Expand Down Expand Up @@ -1003,7 +1006,7 @@ fn write_mir_sig(tcx: TyCtxt<'_>, body: &Body<'_>, w: &mut dyn Write) -> io::Res
write!(w, ": {} =", body.return_ty())?;
}

if let Some(yield_ty) = body.yield_ty() {
if let Some(yield_ty) = body.yield_ty(tcx) {
writeln!(w)?;
writeln!(w, "yields {}", yield_ty)?;
}
Expand Down
10 changes: 7 additions & 3 deletions compiler/rustc_middle/src/mir/syntax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,6 @@ pub enum RuntimePhase {
/// In addition to the semantic changes, beginning with this phase, the following variants are
/// disallowed:
/// * [`TerminatorKind::DropAndReplace`]
/// * [`TerminatorKind::Yield`]
/// * [`TerminatorKind::GeneratorDrop`]
/// * [`Rvalue::Aggregate`] for any `AggregateKind` except `Array`
///
/// And the following variants are allowed:
Expand All @@ -126,7 +124,13 @@ pub enum RuntimePhase {
/// Beginning with this phase, the following variant is disallowed:
/// * [`ProjectionElem::Deref`] of `Box`
PostCleanup = 1,
Optimized = 2,
/// Beginning with this phase, the following variant is disallowed:
/// * [`TerminatorKind::Yield`]
/// * [`TerminatorKind::GeneratorDrop`]
///
/// Furthermore, `Copy` operands are allowed for non-`Copy` types.
GeneratorsLowered = 2,
Optimized = 3,
}

///////////////////////////////////////////////////////////////////////////
Expand Down
10 changes: 0 additions & 10 deletions compiler/rustc_middle/src/mir/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -979,16 +979,6 @@ macro_rules! extra_body_methods {

macro_rules! super_body {
($self:ident, $body:ident, $($mutability:ident, $invalidate:tt)?) => {
let span = $body.span;
if let Some(gen) = &$($mutability)? $body.generator {
if let Some(yield_ty) = $(& $mutability)? gen.yield_ty {
$self.visit_ty(
yield_ty,
TyContext::YieldTy(SourceInfo::outermost(span))
);
}
}

for (bb, data) in basic_blocks_iter!($body, $($mutability, $invalidate)?) {
$self.visit_basic_block_data(bb, data);
}
Expand Down
15 changes: 15 additions & 0 deletions compiler/rustc_middle/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,21 @@ rustc_queries! {
}
}

query mir_generator_lowered(key: LocalDefId) -> &'tcx (
Steal<mir::Body<'tcx>>,
Option<mir::GeneratorInfo<'tcx>>,

) {
no_hash
desc { |tcx| "computing generator MIR for `{}`", tcx.def_path_str(key.to_def_id()) }
}

query mir_generator_info(key: DefId) -> &'tcx mir::GeneratorInfo<'tcx> {
desc { |tcx| "generator glue MIR for `{}`", tcx.def_path_str(key) }
cache_on_disk_if { key.is_local() }
separate_provide_extern
}

/// MIR after our optimization passes have run. This is MIR that is ready
/// for codegen. This is also the only query that can fetch non-local MIR, at present.
query optimized_mir(key: DefId) -> &'tcx mir::Body<'tcx> {
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_middle/src/ty/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ impl_decodable_via_ref! {
&'tcx ty::List<ty::Binder<'tcx, ty::ExistentialPredicate<'tcx>>>,
&'tcx traits::ImplSource<'tcx, ()>,
&'tcx mir::Body<'tcx>,
&'tcx mir::GeneratorInfo<'tcx>,
&'tcx mir::UnsafetyCheckResult,
&'tcx mir::BorrowCheckResult<'tcx>,
&'tcx mir::coverage::CodeRegion,
Expand Down
6 changes: 5 additions & 1 deletion compiler/rustc_middle/src/ty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2291,7 +2291,11 @@ impl<'tcx> TyCtxt<'tcx> {
/// Returns layout of a generator. Layout might be unavailable if the
/// generator is tainted by errors.
pub fn generator_layout(self, def_id: DefId) -> Option<&'tcx GeneratorLayout<'tcx>> {
self.optimized_mir(def_id).generator_layout()
if self.generator_kind(def_id).is_some() {
Some(&self.mir_generator_info(def_id).generator_layout)
} else {
None
}
}

/// Given the `DefId` of an impl, returns the `DefId` of the trait it implements.
Expand Down
3 changes: 2 additions & 1 deletion compiler/rustc_middle/src/ty/parameterized.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use rustc_hir::def_id::DefId;
use rustc_index::vec::{Idx, IndexVec};

use crate::middle::exported_symbols::ExportedSymbol;
use crate::mir::Body;
use crate::mir::{Body, GeneratorInfo};
use crate::ty::abstract_const::Node;
use crate::ty::{
self, Const, FnSig, GeneratorDiagnosticData, GenericPredicates, Predicate, TraitRef, Ty,
Expand Down Expand Up @@ -117,6 +117,7 @@ parameterized_over_tcx! {
Predicate,
GeneratorDiagnosticData,
Body,
GeneratorInfo,
Node,
ExportedSymbol,
}
Loading