Skip to content

Commit

Permalink
caching high hopes
Browse files Browse the repository at this point in the history
  • Loading branch information
lcnr committed Sep 27, 2024
1 parent 95b9dc1 commit 7bf456d
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 38 deletions.
10 changes: 4 additions & 6 deletions compiler/rustc_infer/src/infer/relate/type_relating.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use rustc_data_structures::sso::SsoHashSet;
use rustc_middle::traits::solve::Goal;
use rustc_middle::ty::relate::{
Relate, RelateResult, TypeRelation, relate_args_invariantly, relate_args_with_variances,
};
use rustc_middle::ty::{self, Ty, TyCtxt, TyVar};
use rustc_span::Span;
use rustc_type_ir::data_structures::DelayedSet;
use tracing::{debug, instrument};

use super::combine::CombineFields;
Expand All @@ -17,7 +17,7 @@ pub struct TypeRelating<'combine, 'a, 'tcx> {
fields: &'combine mut CombineFields<'a, 'tcx>,
structurally_relate_aliases: StructurallyRelateAliases,
ambient_variance: ty::Variance,
cache: SsoHashSet<(ty::Variance, Ty<'tcx>, Ty<'tcx>)>,
cache: DelayedSet<(ty::Variance, Ty<'tcx>, Ty<'tcx>)>,
}

impl<'combine, 'infcx, 'tcx> TypeRelating<'combine, 'infcx, 'tcx> {
Expand Down Expand Up @@ -85,7 +85,7 @@ impl<'tcx> TypeRelation<TyCtxt<'tcx>> for TypeRelating<'_, '_, 'tcx> {
let a = infcx.shallow_resolve(a);
let b = infcx.shallow_resolve(b);

if infcx.next_trait_solver() && self.cache.contains(&(self.ambient_variance, a, b)) {
if self.cache.contains(&(self.ambient_variance, a, b)) {
return Ok(a);
}

Expand Down Expand Up @@ -171,9 +171,7 @@ impl<'tcx> TypeRelation<TyCtxt<'tcx>> for TypeRelating<'_, '_, 'tcx> {
}
}

if infcx.next_trait_solver() {
assert!(self.cache.insert((self.ambient_variance, a, b)));
}
assert!(self.cache.insert((self.ambient_variance, a, b)));

Ok(a)
}
Expand Down
5 changes: 3 additions & 2 deletions compiler/rustc_infer/src/infer/resolve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use rustc_middle::bug;
use rustc_middle::ty::fold::{FallibleTypeFolder, TypeFolder, TypeSuperFoldable};
use rustc_middle::ty::visit::TypeVisitableExt;
use rustc_middle::ty::{self, Const, InferConst, Ty, TyCtxt, TypeFoldable};
use rustc_type_ir::data_structures::DelayedMap;

use super::{FixupError, FixupResult, InferCtxt};

Expand All @@ -16,7 +17,7 @@ use super::{FixupError, FixupResult, InferCtxt};
/// points for correctness.
pub struct OpportunisticVarResolver<'a, 'tcx> {
infcx: &'a InferCtxt<'tcx>,
cache: SsoHashMap<Ty<'tcx>, Ty<'tcx>>,
cache: DelayedMap<Ty<'tcx>, Ty<'tcx>>,
}

impl<'a, 'tcx> OpportunisticVarResolver<'a, 'tcx> {
Expand All @@ -40,7 +41,7 @@ impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for OpportunisticVarResolver<'a, 'tcx> {
} else {
let shallow = self.infcx.shallow_resolve(t);
let res = shallow.super_fold_with(self);
assert!(self.cache.insert(t, res).is_none());
assert!(self.cache.insert(t, res));
res
}
}
Expand Down
25 changes: 13 additions & 12 deletions compiler/rustc_middle/src/ty/fold.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use rustc_data_structures::fx::FxIndexMap;
use rustc_data_structures::sso::SsoHashMap;
use rustc_hir::def_id::DefId;
use rustc_type_ir::data_structures::DelayedMap;
pub use rustc_type_ir::fold::{
FallibleTypeFolder, TypeFoldable, TypeFolder, TypeSuperFoldable, shift_region, shift_vars,
};
Expand Down Expand Up @@ -166,7 +166,7 @@ struct BoundVarReplacer<'tcx, D> {

delegate: D,

cache: SsoHashMap<(ty::DebruijnIndex, Ty<'tcx>), Ty<'tcx>>,
cache: DelayedMap<(ty::DebruijnIndex, Ty<'tcx>), Ty<'tcx>>,
}

impl<'tcx, D: BoundVarReplacerDelegate<'tcx>> BoundVarReplacer<'tcx, D> {
Expand Down Expand Up @@ -194,22 +194,23 @@ where
}

fn fold_ty(&mut self, t: Ty<'tcx>) -> Ty<'tcx> {
if let Some(&ty) = self.cache.get(&(self.current_index, t)) {
return ty;
}

let res = match *t.kind() {
match *t.kind() {
ty::Bound(debruijn, bound_ty) if debruijn == self.current_index => {
let ty = self.delegate.replace_ty(bound_ty);
debug_assert!(!ty.has_vars_bound_above(ty::INNERMOST));
ty::fold::shift_vars(self.tcx, ty, self.current_index.as_u32())
}
_ if t.has_vars_bound_at_or_above(self.current_index) => t.super_fold_with(self),
_ => t,
};
_ if t.has_vars_bound_at_or_above(self.current_index) => {
if let Some(&ty) = self.cache.get(&(self.current_index, t)) {
return ty;
}

assert!(self.cache.insert((self.current_index, t), res).is_none());
res
let res = t.super_fold_with(self);
assert!(self.cache.insert((self.current_index, t), res));
res
}
_ => t,
}
}

fn fold_region(&mut self, r: ty::Region<'tcx>) -> ty::Region<'tcx> {
Expand Down
6 changes: 3 additions & 3 deletions compiler/rustc_next_trait_solver/src/resolve.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use rustc_type_ir::data_structures::HashMap;
use rustc_type_ir::data_structures::DelayedMap;
use rustc_type_ir::fold::{TypeFoldable, TypeFolder, TypeSuperFoldable};
use rustc_type_ir::inherent::*;
use rustc_type_ir::visit::TypeVisitableExt;
Expand All @@ -16,7 +16,7 @@ where
I: Interner,
{
delegate: &'a D,
cache: HashMap<I::Ty, I::Ty>,
cache: DelayedMap<I::Ty, I::Ty>,
}

impl<'a, D: SolverDelegate> EagerResolver<'a, D> {
Expand Down Expand Up @@ -48,7 +48,7 @@ impl<D: SolverDelegate<Interner = I>, I: Interner> TypeFolder<I> for EagerResolv
return ty;
}
let res = t.super_fold_with(self);
assert!(self.cache.insert(t, res).is_none());
assert!(self.cache.insert(t, res));
res
} else {
t
Expand Down
26 changes: 11 additions & 15 deletions compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,12 +362,6 @@ where
goal: Goal<I, I::Predicate>,
) -> Result<(NestedNormalizationGoals<I>, bool, Certainty), NoSolution> {
let (orig_values, canonical_goal) = self.canonicalize_goal(goal);

let num_orig_values = orig_values.iter().filter(|v| v.as_region().is_none()).count();
if num_orig_values >= 128 {
println!("bail due to too many orig_values: {num_orig_values}");
return Ok((Default::default(), false, Certainty::overflow(false)));
}
let mut goal_evaluation =
self.inspect.new_goal_evaluation(goal, &orig_values, goal_evaluation_kind);
let canonical_response = EvalCtxt::evaluate_canonical_goal(
Expand Down Expand Up @@ -1049,11 +1043,7 @@ where
}

fn fold_ty(&mut self, ty: I::Ty) -> I::Ty {
if let Some(&entry) = self.cache.get(&ty) {
return entry;
}

let res = match ty.kind() {
match ty.kind() {
ty::Alias(..) if !ty.has_escaping_bound_vars() => {
let infer_ty = self.ecx.next_ty_infer();
let normalizes_to = ty::PredicateKind::AliasRelate(
Expand All @@ -1067,11 +1057,17 @@ where
);
infer_ty
}
_ => ty.super_fold_with(self),
};
_ if ty.has_aliases() => {
if let Some(&entry) = self.cache.get(&ty) {
return entry;
}

assert!(self.cache.insert(ty, res).is_none());
res
let res = ty.super_fold_with(self);
assert!(self.cache.insert(ty, res).is_none());
res
}
_ => ty,
}
}

fn fold_const(&mut self, ct: I::Const) -> I::Const {
Expand Down
68 changes: 68 additions & 0 deletions compiler/rustc_type_ir/src/data_structures/delayed_map.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
use std::hash::Hash;

use crate::data_structures::{HashMap, HashSet};

const CACHE_CUTOFF: u32 = 16;

/// A hashmap which only starts hashing after ignoring the first few inputs.
///
/// This is used in type folders asin nearly all cases caching is not worth it
/// as nearly all folded types are tiny. However, there are very rare incredibly
/// large types for which caching is necessary to avoid hangs.
#[derive(Debug)]
pub struct DelayedMap<K, V> {
cache: HashMap<K, V>,
count: u32,
}

impl<K, V> Default for DelayedMap<K, V> {
fn default() -> Self {
DelayedMap { cache: Default::default(), count: 0 }
}
}

impl<K: Hash + Eq, V> DelayedMap<K, V> {
#[inline]
pub fn insert(&mut self, key: K, value: V) -> bool {
if self.count >= CACHE_CUTOFF {
self.cache.insert(key, value).is_none()
} else {
self.count += 1;
true
}
}

#[inline]
pub fn get(&self, key: &K) -> Option<&V> {
self.cache.get(key)
}
}

#[derive(Debug)]
pub struct DelayedSet<T> {
cache: HashSet<T>,
count: u32,
}

impl<T> Default for DelayedSet<T> {
fn default() -> Self {
DelayedSet { cache: Default::default(), count: 0 }
}
}

impl<T: Hash + Eq> DelayedSet<T> {
#[inline]
pub fn insert(&mut self, value: T) -> bool {
if self.count >= CACHE_CUTOFF {
self.cache.insert(value)
} else {
self.count += 1;
true
}
}

#[inline]
pub fn contains(&self, value: &T) -> bool {
self.cache.contains(value)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ pub use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
pub type IndexMap<K, V> = indexmap::IndexMap<K, V, BuildHasherDefault<FxHasher>>;
pub type IndexSet<V> = indexmap::IndexSet<V, BuildHasherDefault<FxHasher>>;

mod delayed_map;

#[cfg(feature = "nightly")]
mod impl_ {
pub use rustc_data_structures::sso::{SsoHashMap, SsoHashSet};
Expand All @@ -24,4 +26,5 @@ mod impl_ {
}
}

pub use delayed_map::{DelayedMap, DelayedSet};
pub use impl_::*;

0 comments on commit 7bf456d

Please sign in to comment.