Skip to content

Commit

Permalink
Add trait obligation tracking to FulfillCtxt and expose FnCtxt in rus…
Browse files Browse the repository at this point in the history
…tc_infer using callback.

Pass each obligation to an fn callback with its respective inference context. This avoids needing to keep around copies of obligations or inference contexts.
  • Loading branch information
gavinleroy committed Jan 18, 2024
1 parent 2457c02 commit 47d53cd
Show file tree
Hide file tree
Showing 7 changed files with 157 additions and 66 deletions.
20 changes: 17 additions & 3 deletions compiler/rustc_hir_typeck/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ use rustc_hir::{HirIdMap, Node};
use rustc_hir_analysis::astconv::AstConv;
use rustc_hir_analysis::check::check_abi;
use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
use rustc_infer::traits::ObligationInspector;
use rustc_middle::query::Providers;
use rustc_middle::traits;
use rustc_middle::ty::{self, Ty, TyCtxt};
Expand Down Expand Up @@ -139,7 +140,7 @@ fn used_trait_imports(tcx: TyCtxt<'_>, def_id: LocalDefId) -> &UnordSet<LocalDef

fn typeck<'tcx>(tcx: TyCtxt<'tcx>, def_id: LocalDefId) -> &ty::TypeckResults<'tcx> {
let fallback = move || tcx.type_of(def_id.to_def_id()).instantiate_identity();
typeck_with_fallback(tcx, def_id, fallback)
typeck_with_fallback(tcx, def_id, fallback, None)
}

/// Used only to get `TypeckResults` for type inference during error recovery.
Expand All @@ -149,14 +150,24 @@ fn diagnostic_only_typeck<'tcx>(tcx: TyCtxt<'tcx>, def_id: LocalDefId) -> &ty::T
let span = tcx.hir().span(tcx.local_def_id_to_hir_id(def_id));
Ty::new_error_with_message(tcx, span, "diagnostic only typeck table used")
};
typeck_with_fallback(tcx, def_id, fallback)
typeck_with_fallback(tcx, def_id, fallback, None)
}

#[instrument(level = "debug", skip(tcx, fallback), ret)]
pub fn inspect_typeck<'tcx>(
tcx: TyCtxt<'tcx>,
def_id: LocalDefId,
inspect: ObligationInspector<'tcx>,
) -> &'tcx ty::TypeckResults<'tcx> {
let fallback = move || tcx.type_of(def_id.to_def_id()).instantiate_identity();
typeck_with_fallback(tcx, def_id, fallback, Some(inspect))
}

#[instrument(level = "debug", skip(tcx, fallback, inspector), ret)]
fn typeck_with_fallback<'tcx>(
tcx: TyCtxt<'tcx>,
def_id: LocalDefId,
fallback: impl Fn() -> Ty<'tcx> + 'tcx,
inspector: Option<ObligationInspector<'tcx>>,
) -> &'tcx ty::TypeckResults<'tcx> {
// Closures' typeck results come from their outermost function,
// as they are part of the same "inference environment".
Expand All @@ -178,6 +189,9 @@ fn typeck_with_fallback<'tcx>(
let param_env = tcx.param_env(def_id);

let inh = Inherited::new(tcx, def_id);
if let Some(inspector) = inspector {
inh.infcx.attach_obligation_inspector(inspector);
}
let mut fcx = FnCtxt::new(&inh, param_env, def_id);

if let Some(hir::FnSig { header, decl, .. }) = fn_sig {
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_infer/src/infer/at.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ impl<'tcx> InferCtxt<'tcx> {
universe: self.universe.clone(),
intercrate,
next_trait_solver: self.next_trait_solver,
obligation_inspector: self.obligation_inspector.clone(),
}
}
}
Expand Down
15 changes: 14 additions & 1 deletion compiler/rustc_infer/src/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ use rustc_middle::infer::unify_key::{ConstVidKey, EffectVidKey};
use self::opaque_types::OpaqueTypeStorage;
pub(crate) use self::undo_log::{InferCtxtUndoLogs, Snapshot, UndoLog};

use crate::traits::{self, ObligationCause, PredicateObligations, TraitEngine, TraitEngineExt};
use crate::traits::{
self, ObligationCause, ObligationInspector, PredicateObligations, TraitEngine, TraitEngineExt,
};

use rustc_data_structures::fx::FxIndexMap;
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
Expand Down Expand Up @@ -334,6 +336,8 @@ pub struct InferCtxt<'tcx> {
pub intercrate: bool,

next_trait_solver: bool,

pub obligation_inspector: Cell<Option<ObligationInspector<'tcx>>>,
}

impl<'tcx> ty::InferCtxtLike for InferCtxt<'tcx> {
Expand Down Expand Up @@ -708,6 +712,7 @@ impl<'tcx> InferCtxtBuilder<'tcx> {
universe: Cell::new(ty::UniverseIndex::ROOT),
intercrate,
next_trait_solver,
obligation_inspector: Cell::new(None),
}
}
}
Expand Down Expand Up @@ -1724,6 +1729,14 @@ impl<'tcx> InferCtxt<'tcx> {
}
}
}

pub fn attach_obligation_inspector(&self, inspector: ObligationInspector<'tcx>) {
debug_assert!(
self.obligation_inspector.get().is_none(),
"shouldn't override a set obligation inspector"
);
self.obligation_inspector.set(Some(inspector));
}
}

impl<'tcx> TypeErrCtxt<'_, 'tcx> {
Expand Down
9 changes: 9 additions & 0 deletions compiler/rustc_infer/src/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@ use std::hash::{Hash, Hasher};

use hir::def_id::LocalDefId;
use rustc_hir as hir;
use rustc_middle::traits::query::NoSolution;
use rustc_middle::traits::solve::Certainty;
use rustc_middle::ty::error::{ExpectedFound, TypeError};
use rustc_middle::ty::{self, Const, ToPredicate, Ty, TyCtxt};
use rustc_span::Span;

pub use self::ImplSource::*;
pub use self::SelectionError::*;
use crate::infer::InferCtxt;

pub use self::engine::{TraitEngine, TraitEngineExt};
pub use self::project::MismatchedProjectionTypes;
Expand Down Expand Up @@ -116,6 +119,12 @@ pub type PredicateObligations<'tcx> = Vec<PredicateObligation<'tcx>>;

pub type Selection<'tcx> = ImplSource<'tcx, PredicateObligation<'tcx>>;

/// A callback that can be provided to `inspect_typeck`. Invoked on evaluation
/// of root obligations.
pub type ObligationInspector<'tcx> =
fn(&InferCtxt<'tcx>, &PredicateObligation<'tcx>, Result<Certainty, NoSolution>);

#[derive(Clone)]
pub struct FulfillmentError<'tcx> {
pub obligation: PredicateObligation<'tcx>,
pub code: FulfillmentErrorCode<'tcx>,
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_session/src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1914,6 +1914,9 @@ written to standard error output)"),
"for every macro invocation, print its name and arguments (default: no)"),
track_diagnostics: bool = (false, parse_bool, [UNTRACKED],
"tracks where in rustc a diagnostic was emitted"),
track_trait_obligations: bool = (false, parse_bool, [TRACKED],
"tracks evaluated obligations while trait solving, option is only \
valid when -Z next-solver=globally (default: no)"),
// Diagnostics are considered side-effects of a query (see `QuerySideEffects`) and are saved
// alongside query results and changes to translation options can affect diagnostics - so
// translation options should be tracked.
Expand Down
143 changes: 81 additions & 62 deletions compiler/rustc_trait_selection/src/solve/fulfill.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use rustc_middle::ty;
use rustc_middle::ty::error::{ExpectedFound, TypeError};

use super::eval_ctxt::GenerateProofTree;
use super::{Certainty, InferCtxtEvalExt};
use super::{Certainty, Goal, InferCtxtEvalExt};

/// A trait engine using the new trait solver.
///
Expand Down Expand Up @@ -43,6 +43,21 @@ impl<'tcx> FulfillmentCtxt<'tcx> {
);
FulfillmentCtxt { obligations: Vec::new(), usable_in_snapshot: infcx.num_open_snapshots() }
}

fn track_evaluated_obligation(
&self,
infcx: &InferCtxt<'tcx>,
obligation: &PredicateObligation<'tcx>,
result: &Result<(bool, Certainty, Vec<Goal<'tcx, ty::Predicate<'tcx>>>), NoSolution>,
) {
if let Some(inspector) = infcx.obligation_inspector.get() {
let result = match result {
Ok((_, c, _)) => Ok(*c),
Err(NoSolution) => Err(NoSolution),
};
(inspector)(infcx, &obligation, result);
}
}
}

impl<'tcx> TraitEngine<'tcx> for FulfillmentCtxt<'tcx> {
Expand All @@ -57,7 +72,8 @@ impl<'tcx> TraitEngine<'tcx> for FulfillmentCtxt<'tcx> {
}

fn collect_remaining_errors(&mut self, infcx: &InferCtxt<'tcx>) -> Vec<FulfillmentError<'tcx>> {
self.obligations
let errors = self
.obligations
.drain(..)
.map(|obligation| {
let code = infcx.probe(|_| {
Expand Down Expand Up @@ -86,7 +102,9 @@ impl<'tcx> TraitEngine<'tcx> for FulfillmentCtxt<'tcx> {
root_obligation: obligation,
}
})
.collect()
.collect();

errors
}

fn select_where_possible(&mut self, infcx: &InferCtxt<'tcx>) -> Vec<FulfillmentError<'tcx>> {
Expand All @@ -100,65 +118,66 @@ impl<'tcx> TraitEngine<'tcx> for FulfillmentCtxt<'tcx> {
let mut has_changed = false;
for obligation in mem::take(&mut self.obligations) {
let goal = obligation.clone().into();
let (changed, certainty, nested_goals) =
match infcx.evaluate_root_goal(goal, GenerateProofTree::IfEnabled).0 {
Ok(result) => result,
Err(NoSolution) => {
errors.push(FulfillmentError {
obligation: obligation.clone(),
code: match goal.predicate.kind().skip_binder() {
ty::PredicateKind::Clause(ty::ClauseKind::Projection(_)) => {
FulfillmentErrorCode::ProjectionError(
// FIXME: This could be a `Sorts` if the term is a type
MismatchedProjectionTypes { err: TypeError::Mismatch },
)
}
ty::PredicateKind::NormalizesTo(..) => {
FulfillmentErrorCode::ProjectionError(
MismatchedProjectionTypes { err: TypeError::Mismatch },
)
}
ty::PredicateKind::AliasRelate(_, _, _) => {
FulfillmentErrorCode::ProjectionError(
MismatchedProjectionTypes { err: TypeError::Mismatch },
)
}
ty::PredicateKind::Subtype(pred) => {
let (a, b) = infcx.instantiate_binder_with_placeholders(
goal.predicate.kind().rebind((pred.a, pred.b)),
);
let expected_found = ExpectedFound::new(true, a, b);
FulfillmentErrorCode::SubtypeError(
expected_found,
TypeError::Sorts(expected_found),
)
}
ty::PredicateKind::Coerce(pred) => {
let (a, b) = infcx.instantiate_binder_with_placeholders(
goal.predicate.kind().rebind((pred.a, pred.b)),
);
let expected_found = ExpectedFound::new(false, a, b);
FulfillmentErrorCode::SubtypeError(
expected_found,
TypeError::Sorts(expected_found),
)
}
ty::PredicateKind::Clause(_)
| ty::PredicateKind::ObjectSafe(_)
| ty::PredicateKind::Ambiguous => {
FulfillmentErrorCode::SelectionError(
SelectionError::Unimplemented,
)
}
ty::PredicateKind::ConstEquate(..) => {
bug!("unexpected goal: {goal:?}")
}
},
root_obligation: obligation,
});
continue;
}
};
let result = infcx.evaluate_root_goal(goal, GenerateProofTree::IfEnabled).0;
self.track_evaluated_obligation(infcx, &obligation, &result);
let (changed, certainty, nested_goals) = match result {
Ok(result) => result,
Err(NoSolution) => {
errors.push(FulfillmentError {
obligation: obligation.clone(),
code: match goal.predicate.kind().skip_binder() {
ty::PredicateKind::Clause(ty::ClauseKind::Projection(_)) => {
FulfillmentErrorCode::ProjectionError(
// FIXME: This could be a `Sorts` if the term is a type
MismatchedProjectionTypes { err: TypeError::Mismatch },
)
}
ty::PredicateKind::NormalizesTo(..) => {
FulfillmentErrorCode::ProjectionError(
MismatchedProjectionTypes { err: TypeError::Mismatch },
)
}
ty::PredicateKind::AliasRelate(_, _, _) => {
FulfillmentErrorCode::ProjectionError(
MismatchedProjectionTypes { err: TypeError::Mismatch },
)
}
ty::PredicateKind::Subtype(pred) => {
let (a, b) = infcx.instantiate_binder_with_placeholders(
goal.predicate.kind().rebind((pred.a, pred.b)),
);
let expected_found = ExpectedFound::new(true, a, b);
FulfillmentErrorCode::SubtypeError(
expected_found,
TypeError::Sorts(expected_found),
)
}
ty::PredicateKind::Coerce(pred) => {
let (a, b) = infcx.instantiate_binder_with_placeholders(
goal.predicate.kind().rebind((pred.a, pred.b)),
);
let expected_found = ExpectedFound::new(false, a, b);
FulfillmentErrorCode::SubtypeError(
expected_found,
TypeError::Sorts(expected_found),
)
}
ty::PredicateKind::Clause(_)
| ty::PredicateKind::ObjectSafe(_)
| ty::PredicateKind::Ambiguous => {
FulfillmentErrorCode::SelectionError(
SelectionError::Unimplemented,
)
}
ty::PredicateKind::ConstEquate(..) => {
bug!("unexpected goal: {goal:?}")
}
},
root_obligation: obligation,
});
continue;
}
};
// Push any nested goals that we get from unifying our canonical response
// with our obligation onto the fulfillment context.
self.obligations.extend(nested_goals.into_iter().map(|goal| {
Expand Down
32 changes: 32 additions & 0 deletions tests/ui/traits/track_trait_obligations.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// compile-flags: -Ztrack-trait-obligations
// run-pass

// Just making sure this flag is accepted and doesn't crash the compiler
use traits::IntoString;

fn does_impl_into_string<T: IntoString>(_: T) {}

fn main() {
let v = vec![(0, 1), (2, 3)];

does_impl_into_string(v);
}

mod traits {
pub trait IntoString {
fn to_string(&self) -> String;
}

impl IntoString for (i32, i32) {
fn to_string(&self) -> String {
format!("({}, {})", self.0, self.1)
}
}

impl<T: IntoString> IntoString for Vec<T> {
fn to_string(&self) -> String {
let s = self.iter().map(|v| v.to_string()).collect::<Vec<_>>().join(", ");
format!("[{s}]")
}
}
}

0 comments on commit 47d53cd

Please sign in to comment.