Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check projection types before inlining MIR #100571

Merged
merged 3 commits into from
Aug 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 14 additions & 16 deletions compiler/rustc_const_eval/src/transform/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,22 +89,20 @@ pub fn equal_up_to_regions<'tcx>(

// Normalize lifetimes away on both sides, then compare.
let normalize = |ty: Ty<'tcx>| {
tcx.normalize_erasing_regions(
param_env,
ty.fold_with(&mut BottomUpFolder {
tcx,
// FIXME: We erase all late-bound lifetimes, but this is not fully correct.
// If you have a type like `<for<'a> fn(&'a u32) as SomeTrait>::Assoc`,
// this is not necessarily equivalent to `<fn(&'static u32) as SomeTrait>::Assoc`,
// since one may have an `impl SomeTrait for fn(&32)` and
// `impl SomeTrait for fn(&'static u32)` at the same time which
// specify distinct values for Assoc. (See also #56105)
lt_op: |_| tcx.lifetimes.re_erased,
// Leave consts and types unchanged.
ct_op: |ct| ct,
ty_op: |ty| ty,
}),
)
let ty = ty.fold_with(&mut BottomUpFolder {
tcx,
// FIXME: We erase all late-bound lifetimes, but this is not fully correct.
// If you have a type like `<for<'a> fn(&'a u32) as SomeTrait>::Assoc`,
Comment on lines +92 to +95
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does this relate to the changes in #100121?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This changes exist to avoid ICEing if normalize_erasing_region fails. I think it's orthogonal to your PR.

// this is not necessarily equivalent to `<fn(&'static u32) as SomeTrait>::Assoc`,
// since one may have an `impl SomeTrait for fn(&32)` and
// `impl SomeTrait for fn(&'static u32)` at the same time which
// specify distinct values for Assoc. (See also #56105)
lt_op: |_| tcx.lifetimes.re_erased,
// Leave consts and types unchanged.
ct_op: |ct| ct,
ty_op: |ty| ty,
});
tcx.try_normalize_erasing_regions(param_env, ty).unwrap_or(ty)
};
tcx.infer_ctxt().enter(|infcx| infcx.can_eq(param_env, normalize(src), normalize(dest)).is_ok())
}
Expand Down
314 changes: 222 additions & 92 deletions compiler/rustc_mir_transform/src/inline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use rustc_middle::ty::subst::Subst;
use rustc_middle::ty::{self, ConstKind, Instance, InstanceDef, ParamEnv, Ty, TyCtxt};
use rustc_session::config::OptLevel;
use rustc_span::{hygiene::ExpnKind, ExpnData, LocalExpnId, Span};
use rustc_target::abi::VariantIdx;
use rustc_target::spec::abi::Abi;

use super::simplify::{remove_dead_blocks, CfgSimplifier};
Expand Down Expand Up @@ -409,118 +410,60 @@ impl<'tcx> Inliner<'tcx> {
debug!(" final inline threshold = {}", threshold);

// FIXME: Give a bonus to functions with only a single caller
let mut first_block = true;
let mut cost = 0;
let diverges = matches!(
callee_body.basic_blocks()[START_BLOCK].terminator().kind,
TerminatorKind::Unreachable | TerminatorKind::Call { target: None, .. }
);
if diverges && !matches!(callee_attrs.inline, InlineAttr::Always) {
return Err("callee diverges unconditionally");
}

let mut checker = CostChecker {
tcx: self.tcx,
param_env: self.param_env,
instance: callsite.callee,
callee_body,
cost: 0,
validation: Ok(()),
};

// Traverse the MIR manually so we can account for the effects of
// inlining on the CFG.
// Traverse the MIR manually so we can account for the effects of inlining on the CFG.
let mut work_list = vec![START_BLOCK];
let mut visited = BitSet::new_empty(callee_body.basic_blocks().len());
while let Some(bb) = work_list.pop() {
if !visited.insert(bb.index()) {
continue;
}

let blk = &callee_body.basic_blocks()[bb];
checker.visit_basic_block_data(bb, blk);

for stmt in &blk.statements {
// Don't count StorageLive/StorageDead in the inlining cost.
match stmt.kind {
StatementKind::StorageLive(_)
| StatementKind::StorageDead(_)
| StatementKind::Deinit(_)
| StatementKind::Nop => {}
_ => cost += INSTR_COST,
}
}
let term = blk.terminator();
let mut is_drop = false;
match term.kind {
TerminatorKind::Drop { ref place, target, unwind }
| TerminatorKind::DropAndReplace { ref place, target, unwind, .. } => {
is_drop = true;
work_list.push(target);
// If the place doesn't actually need dropping, treat it like
// a regular goto.
let ty = callsite.callee.subst_mir(self.tcx, &place.ty(callee_body, tcx).ty);
if ty.needs_drop(tcx, self.param_env) {
cost += CALL_PENALTY;
if let Some(unwind) = unwind {
cost += LANDINGPAD_PENALTY;
work_list.push(unwind);
}
} else {
cost += INSTR_COST;
}
}

TerminatorKind::Unreachable | TerminatorKind::Call { target: None, .. }
if first_block =>
{
// If the function always diverges, don't inline
// unless the cost is zero
threshold = 0;
}
wesleywiser marked this conversation as resolved.
Show resolved Hide resolved

TerminatorKind::Call { func: Operand::Constant(ref f), cleanup, .. } => {
if let ty::FnDef(def_id, _) =
*callsite.callee.subst_mir(self.tcx, &f.literal.ty()).kind()
{
// Don't give intrinsics the extra penalty for calls
if tcx.is_intrinsic(def_id) {
cost += INSTR_COST;
} else {
cost += CALL_PENALTY;
}
} else {
cost += CALL_PENALTY;
}
if cleanup.is_some() {
cost += LANDINGPAD_PENALTY;
}
}
TerminatorKind::Assert { cleanup, .. } => {
cost += CALL_PENALTY;

if cleanup.is_some() {
cost += LANDINGPAD_PENALTY;
}
}
TerminatorKind::Resume => cost += RESUME_PENALTY,
TerminatorKind::InlineAsm { cleanup, .. } => {
cost += INSTR_COST;
if let TerminatorKind::Drop { ref place, target, unwind }
| TerminatorKind::DropAndReplace { ref place, target, unwind, .. } = term.kind
{
work_list.push(target);

if cleanup.is_some() {
cost += LANDINGPAD_PENALTY;
// If the place doesn't actually need dropping, treat it like a regular goto.
let ty = callsite.callee.subst_mir(self.tcx, &place.ty(callee_body, tcx).ty);
if ty.needs_drop(tcx, self.param_env) && let Some(unwind) = unwind {
work_list.push(unwind);
}
}
_ => cost += INSTR_COST,
}

if !is_drop {
for succ in term.successors() {
work_list.push(succ);
}
} else {
work_list.extend(term.successors())
}

first_block = false;
}

// Count up the cost of local variables and temps, if we know the size
// use that, otherwise we use a moderately-large dummy cost.

let ptr_size = tcx.data_layout.pointer_size.bytes();

for v in callee_body.vars_and_temps_iter() {
let ty = callsite.callee.subst_mir(self.tcx, &callee_body.local_decls[v].ty);
// Cost of the var is the size in machine-words, if we know
// it.
if let Some(size) = type_size_of(tcx, self.param_env, ty) {
cost += ((size + ptr_size - 1) / ptr_size) as usize;
} else {
cost += UNKNOWN_SIZE_COST;
}
checker.visit_local_decl(v, &callee_body.local_decls[v]);
}

// Abort if type validation found anything fishy.
checker.validation?;

let cost = checker.cost;
if let InlineAttr::Always = callee_attrs.inline {
debug!("INLINING {:?} because inline(always) [cost={}]", callsite, cost);
Ok(())
Expand Down Expand Up @@ -790,6 +733,193 @@ fn type_size_of<'tcx>(
tcx.layout_of(param_env.and(ty)).ok().map(|layout| layout.size.bytes())
}

/// Verify that the callee body is compatible with the caller.
///
/// This visitor mostly computes the inlining cost,
/// but also needs to verify that types match because of normalization failure.
struct CostChecker<'b, 'tcx> {
tcx: TyCtxt<'tcx>,
param_env: ParamEnv<'tcx>,
cost: usize,
callee_body: &'b Body<'tcx>,
instance: ty::Instance<'tcx>,
validation: Result<(), &'static str>,
}

impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
// Don't count StorageLive/StorageDead in the inlining cost.
match statement.kind {
StatementKind::StorageLive(_)
| StatementKind::StorageDead(_)
| StatementKind::Deinit(_)
| StatementKind::Nop => {}
_ => self.cost += INSTR_COST,
}

self.super_statement(statement, location);
}

fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
let tcx = self.tcx;
match terminator.kind {
TerminatorKind::Drop { ref place, unwind, .. }
| TerminatorKind::DropAndReplace { ref place, unwind, .. } => {
// If the place doesn't actually need dropping, treat it like a regular goto.
let ty = self.instance.subst_mir(tcx, &place.ty(self.callee_body, tcx).ty);
if ty.needs_drop(tcx, self.param_env) {
self.cost += CALL_PENALTY;
if unwind.is_some() {
self.cost += LANDINGPAD_PENALTY;
}
} else {
self.cost += INSTR_COST;
}
}
TerminatorKind::Call { func: Operand::Constant(ref f), cleanup, .. } => {
let fn_ty = self.instance.subst_mir(tcx, &f.literal.ty());
self.cost += if let ty::FnDef(def_id, _) = *fn_ty.kind() && tcx.is_intrinsic(def_id) {
// Don't give intrinsics the extra penalty for calls
INSTR_COST
} else {
CALL_PENALTY
};
if cleanup.is_some() {
self.cost += LANDINGPAD_PENALTY;
}
}
TerminatorKind::Assert { cleanup, .. } => {
self.cost += CALL_PENALTY;
if cleanup.is_some() {
self.cost += LANDINGPAD_PENALTY;
}
}
TerminatorKind::Resume => self.cost += RESUME_PENALTY,
TerminatorKind::InlineAsm { cleanup, .. } => {
self.cost += INSTR_COST;
if cleanup.is_some() {
self.cost += LANDINGPAD_PENALTY;
}
}
_ => self.cost += INSTR_COST,
}

self.super_terminator(terminator, location);
}

/// Count up the cost of local variables and temps, if we know the size
/// use that, otherwise we use a moderately-large dummy cost.
fn visit_local_decl(&mut self, local: Local, local_decl: &LocalDecl<'tcx>) {
let tcx = self.tcx;
let ptr_size = tcx.data_layout.pointer_size.bytes();

let ty = self.instance.subst_mir(tcx, &local_decl.ty);
// Cost of the var is the size in machine-words, if we know
// it.
if let Some(size) = type_size_of(tcx, self.param_env, ty) {
self.cost += ((size + ptr_size - 1) / ptr_size) as usize;
} else {
self.cost += UNKNOWN_SIZE_COST;
}

self.super_local_decl(local, local_decl)
}

/// This method duplicates code from MIR validation in an attempt to detect type mismatches due
/// to normalization failure.
fn visit_projection_elem(
&mut self,
local: Local,
proj_base: &[PlaceElem<'tcx>],
elem: PlaceElem<'tcx>,
context: PlaceContext,
location: Location,
) {
if let ProjectionElem::Field(f, ty) = elem {
let parent = Place { local, projection: self.tcx.intern_place_elems(proj_base) };
let parent_ty = parent.ty(&self.callee_body.local_decls, self.tcx);
let check_equal = |this: &mut Self, f_ty| {
if !equal_up_to_regions(this.tcx, this.param_env, ty, f_ty) {
trace!(?ty, ?f_ty);
this.validation = Err("failed to normalize projection type");
return;
}
};

let kind = match parent_ty.ty.kind() {
&ty::Opaque(def_id, substs) => {
self.tcx.bound_type_of(def_id).subst(self.tcx, substs).kind()
}
kind => kind,
};

match kind {
ty::Tuple(fields) => {
let Some(f_ty) = fields.get(f.as_usize()) else {
self.validation = Err("malformed MIR");
return;
};
check_equal(self, *f_ty);
}
ty::Adt(adt_def, substs) => {
let var = parent_ty.variant_index.unwrap_or(VariantIdx::from_u32(0));
let Some(field) = adt_def.variant(var).fields.get(f.as_usize()) else {
self.validation = Err("malformed MIR");
return;
};
check_equal(self, field.ty(self.tcx, substs));
}
ty::Closure(_, substs) => {
let substs = substs.as_closure();
let Some(f_ty) = substs.upvar_tys().nth(f.as_usize()) else {
self.validation = Err("malformed MIR");
return;
};
check_equal(self, f_ty);
}
&ty::Generator(def_id, substs, _) => {
let f_ty = if let Some(var) = parent_ty.variant_index {
let gen_body = if def_id == self.callee_body.source.def_id() {
self.callee_body
} else {
self.tcx.optimized_mir(def_id)
};

let Some(layout) = gen_body.generator_layout() else {
self.validation = Err("malformed MIR");
return;
};

let Some(&local) = layout.variant_fields[var].get(f) else {
self.validation = Err("malformed MIR");
return;
};

let Some(&f_ty) = layout.field_tys.get(local) else {
self.validation = Err("malformed MIR");
return;
};

f_ty
} else {
let Some(f_ty) = substs.as_generator().prefix_tys().nth(f.index()) else {
self.validation = Err("malformed MIR");
return;
};

f_ty
};

check_equal(self, f_ty);
}
_ => self.validation = Err("malformed MIR"),
}
}

self.super_projection_elem(local, proj_base, elem, context, location);
}
}

/**
* Integrator.
*
Expand Down
Loading