Skip to content

Commit

Permalink
very cool, 10/10
Browse files Browse the repository at this point in the history
  • Loading branch information
lcnr committed Sep 25, 2024
1 parent 1b5aa96 commit 5a95672
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 109 deletions.
244 changes: 135 additions & 109 deletions compiler/rustc_next_trait_solver/src/canonicalizer.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::cmp::Ordering;

use rustc_type_ir::data_structures::HashMap;
use rustc_type_ir::fold::{TypeFoldable, TypeFolder, TypeSuperFoldable};
use rustc_type_ir::inherent::*;
use rustc_type_ir::visit::TypeVisitableExt;
Expand Down Expand Up @@ -44,8 +45,12 @@ pub struct Canonicalizer<'a, D: SolverDelegate<Interner = I>, I: Interner> {
canonicalize_mode: CanonicalizeMode,

variables: &'a mut Vec<I::GenericArg>,
variable_lookup_table: HashMap<I::GenericArg, usize>,

primitive_var_infos: Vec<CanonicalVarInfo<I>>,
binder_index: ty::DebruijnIndex,

cache: HashMap<(ty::DebruijnIndex, I::Ty), I::Ty>,
}

impl<'a, D: SolverDelegate<Interner = I>, I: Interner> Canonicalizer<'a, D, I> {
Expand All @@ -60,12 +65,14 @@ impl<'a, D: SolverDelegate<Interner = I>, I: Interner> Canonicalizer<'a, D, I> {
canonicalize_mode,

variables,
variable_lookup_table: Default::default(),
primitive_var_infos: Vec::new(),
binder_index: ty::INNERMOST,

cache: Default::default(),
};

let value = value.fold_with(&mut canonicalizer);
// FIXME: Restore these assertions. Should we uplift type flags?
assert!(!value.has_infer(), "unexpected infer in {value:?}");
assert!(!value.has_placeholders(), "unexpected placeholders in {value:?}");

Expand All @@ -75,6 +82,37 @@ impl<'a, D: SolverDelegate<Interner = I>, I: Interner> Canonicalizer<'a, D, I> {
Canonical { defining_opaque_types, max_universe, variables, value }
}

fn get_or_insert_bound_var(
&mut self,
arg: impl Into<I::GenericArg>,
canonical_var_info: CanonicalVarInfo<I>,
) -> ty::BoundVar {
// FIXME: 16 is made up and arbitrary. We should look at some
// perf data here.
let arg = arg.into();
let idx = if self.variables.len() > 16 {
if self.variable_lookup_table.is_empty() {
self.variable_lookup_table.extend(self.variables.iter().copied().zip(0..));
}

*self.variable_lookup_table.entry(arg).or_insert_with(|| {
let var = self.variables.len();
self.variables.push(arg);
self.primitive_var_infos.push(canonical_var_info);
var
})
} else {
self.variables.iter().position(|&v| v == arg).unwrap_or_else(|| {
let var = self.variables.len();
self.variables.push(arg);
self.primitive_var_infos.push(canonical_var_info);
var
})
};

ty::BoundVar::from(idx)
}

fn finalize(self) -> (ty::UniverseIndex, I::CanonicalVars) {
let mut var_infos = self.primitive_var_infos;
// See the rustc-dev-guide section about how we deal with universes
Expand Down Expand Up @@ -124,8 +162,8 @@ impl<'a, D: SolverDelegate<Interner = I>, I: Interner> Canonicalizer<'a, D, I> {
// - var_infos: [E0, U1, E2, U1, E1, E6, U6], curr_compressed_uv: 2, next_orig_uv: 6
// - var_infos: [E0, U1, E1, U1, E1, E3, U3], curr_compressed_uv: 2, next_orig_uv: -
//
// This algorithm runs in `O()` where `n` is the number of different universe
// indices in the input. This should be fine as `n` is expected to be small.
// This algorithm runs in `O(mn)` where `n` is the number of different universes and
// `m` the number of variables. This should be fine as both are expected to be small.
let mut curr_compressed_uv = ty::UniverseIndex::ROOT;
let mut existential_in_new_uv = None;
let mut next_orig_uv = Some(ty::UniverseIndex::ROOT);
Expand Down Expand Up @@ -185,14 +223,16 @@ impl<'a, D: SolverDelegate<Interner = I>, I: Interner> Canonicalizer<'a, D, I> {
for var in var_infos.iter_mut() {
// We simply put all regions from the input into the highest
// compressed universe, so we only deal with them at the end.
if !var.is_region() && is_existential == var.is_existential() {
update_uv(var, orig_uv, is_existential)
if !var.is_region() {
if is_existential == var.is_existential() {
update_uv(var, orig_uv, is_existential)
}
}
}
}
}

// We uniquify regions and always put them into their own universe
// We put all regions into a separate universe.
let mut first_region = true;
for var in var_infos.iter_mut() {
if var.is_region() {
Expand All @@ -208,93 +248,8 @@ impl<'a, D: SolverDelegate<Interner = I>, I: Interner> Canonicalizer<'a, D, I> {
let var_infos = self.delegate.cx().mk_canonical_var_infos(&var_infos);
(curr_compressed_uv, var_infos)
}
}

impl<D: SolverDelegate<Interner = I>, I: Interner> TypeFolder<I> for Canonicalizer<'_, D, I> {
fn cx(&self) -> I {
self.delegate.cx()
}

fn fold_binder<T>(&mut self, t: ty::Binder<I, T>) -> ty::Binder<I, T>
where
T: TypeFoldable<I>,
{
self.binder_index.shift_in(1);
let t = t.super_fold_with(self);
self.binder_index.shift_out(1);
t
}

fn fold_region(&mut self, r: I::Region) -> I::Region {
let kind = match r.kind() {
ty::ReBound(..) => return r,

// We may encounter `ReStatic` in item signatures or the hidden type
// of an opaque. `ReErased` should only be encountered in the hidden
// type of an opaque for regions that are ignored for the purposes of
// captures.
//
// FIXME: We should investigate the perf implications of not uniquifying
// `ReErased`. We may be able to short-circuit registering region
// obligations if we encounter a `ReErased` on one side, for example.
ty::ReStatic | ty::ReErased | ty::ReError(_) => match self.canonicalize_mode {
CanonicalizeMode::Input => CanonicalVarKind::Region(ty::UniverseIndex::ROOT),
CanonicalizeMode::Response { .. } => return r,
},

ty::ReEarlyParam(_) | ty::ReLateParam(_) => match self.canonicalize_mode {
CanonicalizeMode::Input => CanonicalVarKind::Region(ty::UniverseIndex::ROOT),
CanonicalizeMode::Response { .. } => {
panic!("unexpected region in response: {r:?}")
}
},

ty::RePlaceholder(placeholder) => match self.canonicalize_mode {
// We canonicalize placeholder regions as existentials in query inputs.
CanonicalizeMode::Input => CanonicalVarKind::Region(ty::UniverseIndex::ROOT),
CanonicalizeMode::Response { max_input_universe } => {
// If we have a placeholder region inside of a query, it must be from
// a new universe.
if max_input_universe.can_name(placeholder.universe()) {
panic!("new placeholder in universe {max_input_universe:?}: {r:?}");
}
CanonicalVarKind::PlaceholderRegion(placeholder)
}
},

ty::ReVar(vid) => {
assert_eq!(
self.delegate.opportunistic_resolve_lt_var(vid),
r,
"region vid should have been resolved fully before canonicalization"
);
match self.canonicalize_mode {
CanonicalizeMode::Input => CanonicalVarKind::Region(ty::UniverseIndex::ROOT),
CanonicalizeMode::Response { .. } => {
CanonicalVarKind::Region(self.delegate.universe_of_lt(vid).unwrap())
}
}
}
};

let existing_bound_var = match self.canonicalize_mode {
CanonicalizeMode::Input => None,
CanonicalizeMode::Response { .. } => {
self.variables.iter().position(|&v| v == r.into()).map(ty::BoundVar::from)
}
};

let var = existing_bound_var.unwrap_or_else(|| {
let var = ty::BoundVar::from(self.variables.len());
self.variables.push(r.into());
self.primitive_var_infos.push(CanonicalVarInfo { kind });
var
});

Region::new_anon_bound(self.cx(), self.binder_index, var)
}

fn fold_ty(&mut self, t: I::Ty) -> I::Ty {
fn cached_fold_ty(&mut self, t: I::Ty) -> I::Ty {
let kind = match t.kind() {
ty::Infer(i) => match i {
ty::TyVar(vid) => {
Expand Down Expand Up @@ -368,20 +323,98 @@ impl<D: SolverDelegate<Interner = I>, I: Interner> TypeFolder<I> for Canonicaliz
| ty::Tuple(_)
| ty::Alias(_, _)
| ty::Bound(_, _)
| ty::Error(_) => return t.super_fold_with(self),
| ty::Error(_) => {
return t.super_fold_with(self);
}
};

let var = ty::BoundVar::from(
self.variables.iter().position(|&v| v == t.into()).unwrap_or_else(|| {
let var = self.variables.len();
self.variables.push(t.into());
self.primitive_var_infos.push(CanonicalVarInfo { kind });
var
}),
);
let var = self.get_or_insert_bound_var(t, CanonicalVarInfo { kind });

Ty::new_anon_bound(self.cx(), self.binder_index, var)
}
}

impl<D: SolverDelegate<Interner = I>, I: Interner> TypeFolder<I> for Canonicalizer<'_, D, I> {
fn cx(&self) -> I {
self.delegate.cx()
}

fn fold_binder<T>(&mut self, t: ty::Binder<I, T>) -> ty::Binder<I, T>
where
T: TypeFoldable<I>,
{
self.binder_index.shift_in(1);
let t = t.super_fold_with(self);
self.binder_index.shift_out(1);
t
}

fn fold_region(&mut self, r: I::Region) -> I::Region {
let kind = match r.kind() {
ty::ReBound(..) => return r,

// We may encounter `ReStatic` in item signatures or the hidden type
// of an opaque. `ReErased` should only be encountered in the hidden
// type of an opaque for regions that are ignored for the purposes of
// captures.
//
// FIXME: We should investigate the perf implications of not uniquifying
// `ReErased`. We may be able to short-circuit registering region
// obligations if we encounter a `ReErased` on one side, for example.
ty::ReStatic | ty::ReErased | ty::ReError(_) => match self.canonicalize_mode {
CanonicalizeMode::Input => CanonicalVarKind::Region(ty::UniverseIndex::ROOT),
CanonicalizeMode::Response { .. } => return r,
},

ty::ReEarlyParam(_) | ty::ReLateParam(_) => match self.canonicalize_mode {
CanonicalizeMode::Input => CanonicalVarKind::Region(ty::UniverseIndex::ROOT),
CanonicalizeMode::Response { .. } => {
panic!("unexpected region in response: {r:?}")
}
},

ty::RePlaceholder(placeholder) => match self.canonicalize_mode {
// We canonicalize placeholder regions as existentials in query inputs.
CanonicalizeMode::Input => CanonicalVarKind::Region(ty::UniverseIndex::ROOT),
CanonicalizeMode::Response { max_input_universe } => {
// If we have a placeholder region inside of a query, it must be from
// a new universe.
if max_input_universe.can_name(placeholder.universe()) {
panic!("new placeholder in universe {max_input_universe:?}: {r:?}");
}
CanonicalVarKind::PlaceholderRegion(placeholder)
}
},

ty::ReVar(vid) => {
assert_eq!(
self.delegate.opportunistic_resolve_lt_var(vid),
r,
"region vid should have been resolved fully before canonicalization"
);
match self.canonicalize_mode {
CanonicalizeMode::Input => CanonicalVarKind::Region(ty::UniverseIndex::ROOT),
CanonicalizeMode::Response { .. } => {
CanonicalVarKind::Region(self.delegate.universe_of_lt(vid).unwrap())
}
}
}
};

let var = self.get_or_insert_bound_var(r, CanonicalVarInfo { kind });

Region::new_anon_bound(self.cx(), self.binder_index, var)
}

fn fold_ty(&mut self, t: I::Ty) -> I::Ty {
if let Some(&ty) = self.cache.get(&(self.binder_index, t)) {
ty
} else {
let res = self.cached_fold_ty(t);
assert!(self.cache.insert((self.binder_index, t), res).is_none());
res
}
}

fn fold_const(&mut self, c: I::Const) -> I::Const {
let kind = match c.kind() {
Expand Down Expand Up @@ -419,14 +452,7 @@ impl<D: SolverDelegate<Interner = I>, I: Interner> TypeFolder<I> for Canonicaliz
| ty::ConstKind::Expr(_) => return c.super_fold_with(self),
};

let var = ty::BoundVar::from(
self.variables.iter().position(|&v| v == c.into()).unwrap_or_else(|| {
let var = self.variables.len();
self.variables.push(c.into());
self.primitive_var_infos.push(CanonicalVarInfo { kind });
var
}),
);
let var = self.get_or_insert_bound_var(c, CanonicalVarInfo { kind });

Const::new_anon_bound(self.cx(), self.binder_index, var)
}
Expand Down
6 changes: 6 additions & 0 deletions compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,12 @@ where
goal: Goal<I, I::Predicate>,
) -> Result<(NestedNormalizationGoals<I>, bool, Certainty), NoSolution> {
let (orig_values, canonical_goal) = self.canonicalize_goal(goal);

let num_orig_values = orig_values.iter().filter(|v| v.as_region().is_none()).count();
if num_orig_values >= 128 {
println!("bail due to too many orig_values: {num_orig_values}");
return Ok((Default::default(), false, Certainty::overflow(false)));
}
let mut goal_evaluation =
self.inspect.new_goal_evaluation(goal, &orig_values, goal_evaluation_kind);
let canonical_response = EvalCtxt::evaluate_canonical_goal(
Expand Down

0 comments on commit 5a95672

Please sign in to comment.