Skip to content

Commit

Permalink
cache the world
Browse files Browse the repository at this point in the history
  • Loading branch information
lcnr committed Sep 25, 2024
1 parent 5a95672 commit 0dac974
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 23 deletions.
11 changes: 10 additions & 1 deletion compiler/rustc_infer/src/infer/relate/type_relating.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use rustc_data_structures::fx::FxHashSet;
use rustc_middle::traits::solve::Goal;
use rustc_middle::ty::relate::{
Relate, RelateResult, TypeRelation, relate_args_invariantly, relate_args_with_variances,
Expand All @@ -16,6 +17,7 @@ pub struct TypeRelating<'combine, 'a, 'tcx> {
fields: &'combine mut CombineFields<'a, 'tcx>,
structurally_relate_aliases: StructurallyRelateAliases,
ambient_variance: ty::Variance,
cache: FxHashSet<(ty::Variance, Ty<'tcx>, Ty<'tcx>)>,
}

impl<'combine, 'infcx, 'tcx> TypeRelating<'combine, 'infcx, 'tcx> {
Expand All @@ -24,7 +26,7 @@ impl<'combine, 'infcx, 'tcx> TypeRelating<'combine, 'infcx, 'tcx> {
structurally_relate_aliases: StructurallyRelateAliases,
ambient_variance: ty::Variance,
) -> TypeRelating<'combine, 'infcx, 'tcx> {
TypeRelating { fields: f, structurally_relate_aliases, ambient_variance }
TypeRelating { fields: f, structurally_relate_aliases, ambient_variance, cache: Default::default() }
}
}

Expand Down Expand Up @@ -74,9 +76,14 @@ impl<'tcx> TypeRelation<TyCtxt<'tcx>> for TypeRelating<'_, '_, 'tcx> {
return Ok(a);
}


let infcx = self.fields.infcx;
let a = infcx.shallow_resolve(a);
let b = infcx.shallow_resolve(b);

if self.cache.contains(&(self.ambient_variance, a, b)) {
return Ok(a);
}

match (a.kind(), b.kind()) {
(&ty::Infer(TyVar(a_id)), &ty::Infer(TyVar(b_id))) => {
Expand Down Expand Up @@ -160,6 +167,8 @@ impl<'tcx> TypeRelation<TyCtxt<'tcx>> for TypeRelating<'_, '_, 'tcx> {
}
}

assert!(self.cache.insert((self.ambient_variance, a, b)));

Ok(a)
}

Expand Down
15 changes: 12 additions & 3 deletions compiler/rustc_infer/src/infer/resolve.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use rustc_data_structures::fx::FxHashMap;
use rustc_middle::bug;
use rustc_middle::ty::fold::{FallibleTypeFolder, TypeFolder, TypeSuperFoldable};
use rustc_middle::ty::visit::TypeVisitableExt;
Expand All @@ -15,12 +16,13 @@ use super::{FixupError, FixupResult, InferCtxt};
/// points for correctness.
pub struct OpportunisticVarResolver<'a, 'tcx> {
infcx: &'a InferCtxt<'tcx>,
cache: FxHashMap<Ty<'tcx>, Ty<'tcx>>,
}

impl<'a, 'tcx> OpportunisticVarResolver<'a, 'tcx> {
#[inline]
pub fn new(infcx: &'a InferCtxt<'tcx>) -> Self {
OpportunisticVarResolver { infcx }
OpportunisticVarResolver { infcx, cache: Default::default() }
}
}

Expand All @@ -31,12 +33,19 @@ impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for OpportunisticVarResolver<'a, 'tcx> {

#[inline]
fn fold_ty(&mut self, t: Ty<'tcx>) -> Ty<'tcx> {
if !t.has_non_region_infer() {
if let Some(&ty) = self.cache.get(&t) {
return ty;
}

let res = if !t.has_non_region_infer() {
t // micro-optimize -- if there is nothing in this type that this fold affects...
} else {
let t = self.infcx.shallow_resolve(t);
t.super_fold_with(self)
}
};

assert!(self.cache.insert(t, res).is_none());
res
}

fn fold_const(&mut self, ct: Const<'tcx>) -> Const<'tcx> {
Expand Down
17 changes: 13 additions & 4 deletions compiler/rustc_middle/src/ty/fold.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use rustc_data_structures::fx::FxIndexMap;
use rustc_data_structures::fx::{FxHashMap, FxIndexMap};
use rustc_hir::def_id::DefId;
pub use rustc_type_ir::fold::{
FallibleTypeFolder, TypeFoldable, TypeFolder, TypeSuperFoldable, shift_region, shift_vars,
Expand Down Expand Up @@ -164,11 +164,13 @@ struct BoundVarReplacer<'tcx, D> {
current_index: ty::DebruijnIndex,

delegate: D,

cache: FxHashMap<(ty::DebruijnIndex, Ty<'tcx>), Ty<'tcx>>,
}

impl<'tcx, D: BoundVarReplacerDelegate<'tcx>> BoundVarReplacer<'tcx, D> {
fn new(tcx: TyCtxt<'tcx>, delegate: D) -> Self {
BoundVarReplacer { tcx, current_index: ty::INNERMOST, delegate }
BoundVarReplacer { tcx, current_index: ty::INNERMOST, delegate, cache: Default::default() }
}
}

Expand All @@ -191,15 +193,22 @@ where
}

fn fold_ty(&mut self, t: Ty<'tcx>) -> Ty<'tcx> {
match *t.kind() {
if let Some(&ty) = self.cache.get(&(self.current_index, t)) {
return ty;
}

let res = match *t.kind() {
ty::Bound(debruijn, bound_ty) if debruijn == self.current_index => {
let ty = self.delegate.replace_ty(bound_ty);
debug_assert!(!ty.has_vars_bound_above(ty::INNERMOST));
ty::fold::shift_vars(self.tcx, ty, self.current_index.as_u32())
}
_ if t.has_vars_bound_at_or_above(self.current_index) => t.super_fold_with(self),
_ => t,
}
};

assert!(self.cache.insert((self.current_index, t), res).is_none());
res
}

fn fold_region(&mut self, r: ty::Region<'tcx>) -> ty::Region<'tcx> {
Expand Down
15 changes: 12 additions & 3 deletions compiler/rustc_next_trait_solver/src/resolve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use rustc_type_ir::fold::{TypeFoldable, TypeFolder, TypeSuperFoldable};
use rustc_type_ir::inherent::*;
use rustc_type_ir::visit::TypeVisitableExt;
use rustc_type_ir::{self as ty, InferCtxtLike, Interner};
use rustc_type_ir::data_structures::HashMap;

use crate::delegate::SolverDelegate;

Expand All @@ -15,11 +16,12 @@ where
I: Interner,
{
delegate: &'a D,
cache: HashMap<I::Ty, I::Ty>,
}

impl<'a, D: SolverDelegate> EagerResolver<'a, D> {
pub fn new(delegate: &'a D) -> Self {
EagerResolver { delegate }
EagerResolver { delegate, cache: Default::default() }
}
}

Expand All @@ -29,7 +31,11 @@ impl<D: SolverDelegate<Interner = I>, I: Interner> TypeFolder<I> for EagerResolv
}

fn fold_ty(&mut self, t: I::Ty) -> I::Ty {
match t.kind() {
if let Some(&ty) = self.cache.get(&t) {
return ty;
}

let res = match t.kind() {
ty::Infer(ty::TyVar(vid)) => {
let resolved = self.delegate.opportunistic_resolve_ty_var(vid);
if t != resolved && resolved.has_infer() {
Expand All @@ -47,7 +53,10 @@ impl<D: SolverDelegate<Interner = I>, I: Interner> TypeFolder<I> for EagerResolv
t
}
}
}
};

assert!(self.cache.insert(t, res).is_none());
res
}

fn fold_region(&mut self, r: I::Region) -> I::Region {
Expand Down
45 changes: 35 additions & 10 deletions compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::ops::ControlFlow;
use derive_where::derive_where;
#[cfg(feature = "nightly")]
use rustc_macros::{HashStable_NoContext, TyDecodable, TyEncodable};
use rustc_type_ir::data_structures::ensure_sufficient_stack;
use rustc_type_ir::data_structures::{HashMap, HashSet, ensure_sufficient_stack};
use rustc_type_ir::fold::{TypeFoldable, TypeFolder, TypeSuperFoldable};
use rustc_type_ir::inherent::*;
use rustc_type_ir::relate::Relate;
Expand Down Expand Up @@ -587,7 +587,7 @@ where
pub(super) fn add_normalizes_to_goal(&mut self, mut goal: Goal<I, ty::NormalizesTo<I>>) {
goal.predicate = goal
.predicate
.fold_with(&mut ReplaceAliasWithInfer { ecx: self, param_env: goal.param_env });
.fold_with(&mut ReplaceAliasWithInfer::new(self, goal.param_env));
self.inspect.add_normalizes_to_goal(self.delegate, self.max_input_universe, goal);
self.nested_goals.normalizes_to_goals.push(goal);
}
Expand All @@ -596,7 +596,7 @@ where
pub(super) fn add_goal(&mut self, source: GoalSource, mut goal: Goal<I, I::Predicate>) {
goal.predicate = goal
.predicate
.fold_with(&mut ReplaceAliasWithInfer { ecx: self, param_env: goal.param_env });
.fold_with(&mut ReplaceAliasWithInfer::new(self, goal.param_env));
self.inspect.add_goal(self.delegate, self.max_input_universe, source, goal);
self.nested_goals.goals.push((source, goal));
}
Expand Down Expand Up @@ -660,6 +660,7 @@ where
term: I::Term,
universe_of_term: ty::UniverseIndex,
delegate: &'a D,
cache: HashSet<I::Ty>,
}

impl<D: SolverDelegate<Interner = I>, I: Interner> ContainsTermOrNotNameable<'_, D, I> {
Expand All @@ -677,6 +678,10 @@ where
{
type Result = ControlFlow<()>;
fn visit_ty(&mut self, t: I::Ty) -> Self::Result {
if self.cache.contains(&t) {
return ControlFlow::Continue(());
}

match t.kind() {
ty::Infer(ty::TyVar(vid)) => {
if let ty::TermKind::Ty(term) = self.term.kind() {
Expand All @@ -689,17 +694,18 @@ where
}
}

self.check_nameable(self.delegate.universe_of_ty(vid).unwrap())
self.check_nameable(self.delegate.universe_of_ty(vid).unwrap())?;
}
ty::Placeholder(p) => self.check_nameable(p.universe()),
ty::Placeholder(p) => self.check_nameable(p.universe())?,
_ => {
if t.has_non_region_infer() || t.has_placeholders() {
t.super_visit_with(self)
} else {
ControlFlow::Continue(())
t.super_visit_with(self)?
}
}
}

assert!(self.cache.insert(t));
ControlFlow::Continue(())
}

fn visit_const(&mut self, c: I::Const) -> Self::Result {
Expand Down Expand Up @@ -734,6 +740,7 @@ where
delegate: self.delegate,
universe_of_term,
term: goal.predicate.term,
cache: Default::default(),
};
goal.predicate.alias.visit_with(&mut visitor).is_continue()
&& goal.param_env.visit_with(&mut visitor).is_continue()
Expand Down Expand Up @@ -1021,6 +1028,17 @@ where
{
ecx: &'me mut EvalCtxt<'a, D>,
param_env: I::ParamEnv,
cache: HashMap<I::Ty, I::Ty>,
}

impl<'me, 'a, D, I> ReplaceAliasWithInfer<'me, 'a, D, I>
where
D: SolverDelegate<Interner = I>,
I: Interner,
{
fn new(ecx: &'me mut EvalCtxt<'a, D>, param_env: I::ParamEnv) -> Self {
ReplaceAliasWithInfer { ecx, param_env, cache: Default::default() }
}
}

impl<D, I> TypeFolder<I> for ReplaceAliasWithInfer<'_, '_, D, I>
Expand All @@ -1033,7 +1051,11 @@ where
}

fn fold_ty(&mut self, ty: I::Ty) -> I::Ty {
match ty.kind() {
if let Some(&entry) = self.cache.get(&ty) {
return entry;
}

let res = match ty.kind() {
ty::Alias(..) if !ty.has_escaping_bound_vars() => {
let infer_ty = self.ecx.next_ty_infer();
let normalizes_to = ty::PredicateKind::AliasRelate(
Expand All @@ -1048,7 +1070,10 @@ where
infer_ty
}
_ => ty.super_fold_with(self),
}
};

assert!(self.cache.insert(ty, res).is_none());
res
}

fn fold_const(&mut self, ct: I::Const) -> I::Const {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
//@ check-pass
//@ revisions: current next
//[next]@ compile-flags: -Znext-solver
//@ revisions: ai_current ai_next ia_current ia_next ii_current ii_next
//@[ai_next] compile-flags: -Znext-solver
//@[ia_next] compile-flags: -Znext-solver
//@[ii_next] compile-flags: -Znext-solver
// check-pass

// Regression test for nalgebra hang <https://github.com/rust-lang/rust/issues/130056>.

Expand All @@ -15,7 +18,12 @@ trait Trait {
type Assoc: ?Sized;
}
impl<T: ?Sized + Trait> Trait for W<T, T> {
#[cfg(any(ai_current, ai_next))]
type Assoc = W<T::Assoc, Id<T::Assoc>>;
#[cfg(any(ia_current, ia_next))]
type Assoc = W<Id<T::Assoc>, T::Assoc>;
#[cfg(any(ii_current, ii_next))]
type Assoc = W<Id<T::Assoc>, Id<T::Assoc>>;
}

trait Overlap<T: ?Sized> {}
Expand Down

0 comments on commit 0dac974

Please sign in to comment.