Skip to content

Commit

Permalink
Auto merge of rust-lang#122387 - DianQK:re-enable-early-otherwise-bra…
Browse files Browse the repository at this point in the history
…nch, r=cjgillot

Re-enable the early otherwise branch optimization

Closes rust-lang#95162. Fixes rust-lang#119014.

This is the first part of rust-lang#121397.

An invalid enum discriminant can come from anywhere. We have to check to see if all successors contain the discriminant statement. This should have a pass to hoist instructions.

r? cjgillot
  • Loading branch information
bors committed Apr 9, 2024
2 parents b234e44 + 166bb1b commit 59c808f
Show file tree
Hide file tree
Showing 16 changed files with 616 additions and 278 deletions.
172 changes: 53 additions & 119 deletions compiler/rustc_mir_transform/src/early_otherwise_branch.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use rustc_middle::mir::patch::MirPatch;
use rustc_middle::mir::*;
use rustc_middle::ty::{self, Ty, TyCtxt};
use rustc_middle::ty::{Ty, TyCtxt};
use std::fmt::Debug;

use super::simplify::simplify_cfg;
Expand All @@ -11,6 +11,7 @@ use super::simplify::simplify_cfg;
/// let y: Option<()>;
/// match (x,y) {
/// (Some(_), Some(_)) => {0},
/// (None, None) => {2},
/// _ => {1}
/// }
/// ```
Expand All @@ -23,10 +24,10 @@ use super::simplify::simplify_cfg;
/// if discriminant_x == discriminant_y {
/// match x {
/// Some(_) => 0,
/// _ => 1, // <----
/// } // | Actually the same bb
/// } else { // |
/// 1 // <--------------
/// None => 2,
/// }
/// } else {
/// 1
/// }
/// ```
///
Expand All @@ -47,18 +48,18 @@ use super::simplify::simplify_cfg;
/// | | |
/// ================= | | |
/// | BBU | <-| | | ============================
/// |---------------| | \-------> | BBD |
/// |---------------| | | |--------------------------|
/// | unreachable | | | | _dl = discriminant(P) |
/// ================= | | |--------------------------|
/// | | | switchInt(_dl) |
/// ================= | | | d | ---> BBD.2
/// |---------------| \-------> | BBD |
/// |---------------| | |--------------------------|
/// | unreachable | | | _dl = discriminant(P) |
/// ================= | |--------------------------|
/// | | switchInt(_dl) |
/// ================= | | d | ---> BBD.2
/// | BB9 | <--------------- | otherwise |
/// |---------------| ============================
/// | ... |
/// =================
/// ```
/// Where the `otherwise` branch on `BB1` is permitted to either go to `BBU` or to `BB9`. In the
/// Where the `otherwise` branch on `BB1` is permitted to either go to `BBU`. In the
/// code:
/// - `BB1` is `parent` and `BBC, BBD` are children
/// - `P` is `child_place`
Expand All @@ -78,7 +79,7 @@ use super::simplify::simplify_cfg;
/// |---------------------| | | switchInt(Q) |
/// | switchInt(_t) | | | c | ---> BBC.2
/// | false | --------/ | d | ---> BBD.2
/// | otherwise | ---------------- | otherwise |
/// | otherwise | /--------- | otherwise |
/// ======================= | ============================
/// |
/// ================= |
Expand All @@ -87,16 +88,11 @@ use super::simplify::simplify_cfg;
/// | ... |
/// =================
/// ```
///
/// This is only correct for some `P`, since `P` is now computed outside the original `switchInt`.
/// The filter on which `P` are allowed (together with discussion of its correctness) is found in
/// `may_hoist`.
pub struct EarlyOtherwiseBranch;

impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
// unsound: https://github.com/rust-lang/rust/issues/95162
sess.mir_opt_level() >= 3 && sess.opts.unstable_opts.unsound_mir_opts
sess.mir_opt_level() >= 2
}

fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
Expand Down Expand Up @@ -172,7 +168,8 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
};
(value, targets.target_for_value(value))
});
let eq_targets = SwitchTargets::new(eq_new_targets, opt_data.destination);
// The otherwise either is the same target branch or an unreachable.
let eq_targets = SwitchTargets::new(eq_new_targets, parent_targets.otherwise());

// Create `bbEq` in example above
let eq_switch = BasicBlockData::new(Some(Terminator {
Expand Down Expand Up @@ -217,85 +214,6 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
}
}

/// Returns true if computing the discriminant of `place` may be hoisted out of the branch
fn may_hoist<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, place: Place<'tcx>) -> bool {
// FIXME(JakobDegen): This is unsound. Someone could write code like this:
// ```rust
// let Q = val;
// if discriminant(P) == otherwise {
// let ptr = &mut Q as *mut _ as *mut u8;
// unsafe { *ptr = 10; } // Any invalid value for the type
// }
//
// match P {
// A => match Q {
// A => {
// // code
// }
// _ => {
// // don't use Q
// }
// }
// _ => {
// // don't use Q
// }
// };
// ```
//
// Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
// invalid value, which is UB.
//
// In order to fix this, we would either need to show that the discriminant computation of
// `place` is computed in all branches, including the `otherwise` branch, or we would need
// another analysis pass to determine that the place is fully initialized. It might even be best
// to have the hoisting be performed in a different pass and just do the CFG changing in this
// pass.
for (place, proj) in place.iter_projections() {
match proj {
// Dereferencing in the computation of `place` might cause issues from one of two
// categories. First, the referent might be invalid. We protect against this by
// dereferencing references only (not pointers). Second, the use of a reference may
// invalidate other references that are used later (for aliasing reasons). Consider
// where such an invalidated reference may appear:
// - In `Q`: Not possible since `Q` is used as the operand of a `SwitchInt` and so
// cannot contain referenced data.
// - In `BBU`: Not possible since that block contains only the `unreachable` terminator
// - In `BBC.2, BBD.2`: Not possible, since `discriminant(P)` was computed prior to
// reaching that block in the input to our transformation, and so any data
// invalidated by that computation could not have been used there.
// - In `BB9`: Not possible since control flow might have reached `BB9` via the
// `otherwise` branch in `BBC, BBD` in the input to our transformation, which would
// have invalidated the data when computing `discriminant(P)`
// So dereferencing here is correct.
ProjectionElem::Deref => match place.ty(body.local_decls(), tcx).ty.kind() {
ty::Ref(..) => {}
_ => return false,
},
// Field projections are always valid
ProjectionElem::Field(..) => {}
// We cannot allow
// downcasts either, since the correctness of the downcast may depend on the parent
// branch being taken. An easy example of this is
// ```
// Q = discriminant(_3)
// P = (_3 as Variant)
// ```
// However, checking if the child and parent place are the same and only erroring then
// is not sufficient either, since the `discriminant(_3) == 1` (or whatever) check may
// be replaced by another optimization pass with any other condition that can be proven
// equivalent.
ProjectionElem::Downcast(..) => {
return false;
}
// We cannot allow indexing since the index may be out of bounds.
_ => {
return false;
}
}
}
true
}

#[derive(Debug)]
struct OptimizationData<'tcx> {
destination: BasicBlock,
Expand All @@ -315,18 +233,40 @@ fn evaluate_candidate<'tcx>(
return None;
};
let parent_ty = parent_discr.ty(body.local_decls(), tcx);
let parent_dest = {
let poss = targets.otherwise();
// If the fallthrough on the parent is trivially unreachable, we can let the
// children choose the destination
if bbs[poss].statements.len() == 0
&& bbs[poss].terminator().kind == TerminatorKind::Unreachable
{
None
} else {
Some(poss)
}
};
if !bbs[targets.otherwise()].is_empty_unreachable() {
// Someone could write code like this:
// ```rust
// let Q = val;
// if discriminant(P) == otherwise {
// let ptr = &mut Q as *mut _ as *mut u8;
// // It may be difficult for us to effectively determine whether values are valid.
// // Invalid values can come from all sorts of corners.
// unsafe { *ptr = 10; }
// }
//
// match P {
// A => match Q {
// A => {
// // code
// }
// _ => {
// // don't use Q
// }
// }
// _ => {
// // don't use Q
// }
// };
// ```
//
// Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
// invalid value, which is UB.
// In order to fix this, **we would either need to show that the discriminant computation of
// `place` is computed in all branches**.
// FIXME(#95162) For the moment, we adopt a conservative approach and
// consider only the `otherwise` branch has no statements and an unreachable terminator.
return None;
}
let (_, child) = targets.iter().next()?;
let child_terminator = &bbs[child].terminator();
let TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr } =
Expand All @@ -344,13 +284,7 @@ fn evaluate_candidate<'tcx>(
let (_, Rvalue::Discriminant(child_place)) = &**boxed else {
return None;
};
let destination = parent_dest.unwrap_or(child_targets.otherwise());

// Verify that the optimization is legal in general
// We can hoist evaluating the child discriminant out of the branch
if !may_hoist(tcx, body, *child_place) {
return None;
}
let destination = child_targets.otherwise();

// Verify that the optimization is legal for each branch
for (value, child) in targets.iter() {
Expand Down Expand Up @@ -411,5 +345,5 @@ fn verify_candidate_branch<'tcx>(
if let Some(_) = iter.next() {
return false;
}
return true;
true
}
26 changes: 26 additions & 0 deletions tests/codegen/enum/enum-early-otherwise-branch.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//@ compile-flags: -O
//@ min-llvm-version: 18

#![crate_type = "lib"]

pub enum Enum {
A(u32),
B(u32),
C(u32),
}

#[no_mangle]
pub fn foo(lhs: &Enum, rhs: &Enum) -> bool {
// CHECK-LABEL: define{{.*}}i1 @foo(
// CHECK-NOT: switch
// CHECK-NOT: br
// CHECK: [[SELECT:%.*]] = select
// CHECK-NEXT: ret i1 [[SELECT]]
// CHECK-NEXT: }
match (lhs, rhs) {
(Enum::A(lhs), Enum::A(rhs)) => lhs == rhs,
(Enum::B(lhs), Enum::B(rhs)) => lhs == rhs,
(Enum::C(lhs), Enum::C(rhs)) => lhs == rhs,
_ => false,
}
}
39 changes: 13 additions & 26 deletions tests/mir-opt/early_otherwise_branch.opt1.EarlyOtherwiseBranch.diff
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
let mut _7: isize;
let _8: u32;
let _9: u32;
+ let mut _10: isize;
+ let mut _11: bool;
scope 1 {
debug a => _8;
debug b => _9;
Expand All @@ -29,48 +27,37 @@
StorageDead(_5);
StorageDead(_4);
_7 = discriminant((_3.0: std::option::Option<u32>));
- switchInt(move _7) -> [1: bb2, otherwise: bb1];
+ StorageLive(_10);
+ _10 = discriminant((_3.1: std::option::Option<u32>));
+ StorageLive(_11);
+ _11 = Ne(_7, move _10);
+ StorageDead(_10);
+ switchInt(move _11) -> [0: bb4, otherwise: bb1];
switchInt(move _7) -> [1: bb2, 0: bb1, otherwise: bb5];
}

bb1: {
+ StorageDead(_11);
_0 = const 1_u32;
- goto -> bb4;
+ goto -> bb3;
goto -> bb4;
}

bb2: {
- _6 = discriminant((_3.1: std::option::Option<u32>));
- switchInt(move _6) -> [1: bb3, otherwise: bb1];
- }
-
- bb3: {
_6 = discriminant((_3.1: std::option::Option<u32>));
switchInt(move _6) -> [1: bb3, 0: bb1, otherwise: bb5];
}
bb3: {
StorageLive(_8);
_8 = (((_3.0: std::option::Option<u32>) as Some).0: u32);
StorageLive(_9);
_9 = (((_3.1: std::option::Option<u32>) as Some).0: u32);
_0 = const 0_u32;
StorageDead(_9);
StorageDead(_8);
- goto -> bb4;
+ goto -> bb3;
goto -> bb4;
}

- bb4: {
+ bb3: {
bb4: {
StorageDead(_3);
return;
+ }
+
+ bb4: {
+ StorageDead(_11);
+ switchInt(_7) -> [1: bb2, otherwise: bb1];
}

bb5: {
unreachable;
}
}

Loading

0 comments on commit 59c808f

Please sign in to comment.