Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify fulfill_implication #133319

Merged
merged 1 commit into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ use rustc_hir as hir;
use rustc_hir::def_id::{DefId, LocalDefId};
use rustc_infer::infer::TyCtxtInferExt;
use rustc_infer::infer::outlives::env::OutlivesEnvironment;
use rustc_infer::traits::ObligationCause;
use rustc_infer::traits::specialization_graph::Node;
use rustc_middle::ty::trait_def::TraitSpecializationKind;
use rustc_middle::ty::{
Expand Down Expand Up @@ -210,13 +211,7 @@ fn get_impl_args(
impl1_def_id.to_def_id(),
impl1_args,
impl2_node,
|_, span| {
traits::ObligationCause::new(
impl1_span,
impl1_def_id,
traits::ObligationCauseCode::WhereClause(impl2_node.def_id(), span),
)
},
&ObligationCause::misc(impl1_span, impl1_def_id),
);

let errors = ocx.select_all_or_error();
Expand Down
191 changes: 100 additions & 91 deletions compiler/rustc_trait_selection/src/traits/specialize/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,24 @@ use rustc_data_structures::fx::FxIndexSet;
use rustc_errors::codes::*;
use rustc_errors::{Diag, EmissionGuarantee};
use rustc_hir::def_id::{DefId, LocalDefId};
use rustc_infer::infer::DefineOpaqueTypes;
use rustc_middle::bug;
use rustc_middle::query::LocalCrate;
use rustc_middle::ty::print::PrintTraitRefExt as _;
use rustc_middle::ty::{
self, GenericArgsRef, ImplSubject, Ty, TyCtxt, TypeVisitableExt, TypingMode,
};
use rustc_middle::ty::{self, GenericArgsRef, Ty, TyCtxt, TypeVisitableExt, TypingMode};
use rustc_session::lint::builtin::{COHERENCE_LEAK_CHECK, ORDER_DEPENDENT_TRAIT_OBJECTS};
use rustc_span::{DUMMY_SP, ErrorGuaranteed, Span, sym};
use rustc_type_ir::solve::NoSolution;
use specialization_graph::GraphExt;
use tracing::{debug, instrument};

use super::{SelectionContext, util};
use crate::error_reporting::traits::to_pretty_impl_header;
use crate::errors::NegativePositiveConflict;
use crate::infer::{InferCtxt, InferOk, TyCtxtInferExt};
use crate::infer::{InferCtxt, TyCtxtInferExt};
use crate::traits::select::IntercrateAmbiguityCause;
use crate::traits::{FutureCompatOverlapErrorKind, ObligationCause, ObligationCtxt, coherence};
use crate::traits::{
FutureCompatOverlapErrorKind, ObligationCause, ObligationCtxt, coherence,
predicates_for_generics,
};

/// Information pertinent to an overlapping impl error.
#[derive(Debug)]
Expand Down Expand Up @@ -87,9 +87,14 @@ pub fn translate_args<'tcx>(
source_args: GenericArgsRef<'tcx>,
target_node: specialization_graph::Node,
) -> GenericArgsRef<'tcx> {
translate_args_with_cause(infcx, param_env, source_impl, source_args, target_node, |_, _| {
ObligationCause::dummy()
})
translate_args_with_cause(
infcx,
param_env,
source_impl,
source_args,
target_node,
&ObligationCause::dummy(),
)
}

/// Like [translate_args], but obligations from the parent implementation
Expand All @@ -104,7 +109,7 @@ pub fn translate_args_with_cause<'tcx>(
source_impl: DefId,
source_args: GenericArgsRef<'tcx>,
target_node: specialization_graph::Node,
cause: impl Fn(usize, Span) -> ObligationCause<'tcx>,
cause: &ObligationCause<'tcx>,
) -> GenericArgsRef<'tcx> {
debug!(
"translate_args({:?}, {:?}, {:?}, {:?})",
Expand All @@ -123,7 +128,7 @@ pub fn translate_args_with_cause<'tcx>(
}

fulfill_implication(infcx, param_env, source_trait_ref, source_impl, target_impl, cause)
.unwrap_or_else(|()| {
.unwrap_or_else(|_| {
bug!(
"When translating generic parameters from {source_impl:?} to \
{target_impl:?}, the expected specialization failed to hold"
Expand All @@ -137,6 +142,84 @@ pub fn translate_args_with_cause<'tcx>(
source_args.rebase_onto(infcx.tcx, source_impl, target_args)
}

/// Attempt to fulfill all obligations of `target_impl` after unification with
/// `source_trait_ref`. If successful, returns the generic parameters for *all* the
/// generics of `target_impl`, including both those needed to unify with
/// `source_trait_ref` and those whose identity is determined via a where
/// clause in the impl.
fn fulfill_implication<'tcx>(
infcx: &InferCtxt<'tcx>,
param_env: ty::ParamEnv<'tcx>,
source_trait_ref: ty::TraitRef<'tcx>,
source_impl: DefId,
target_impl: DefId,
cause: &ObligationCause<'tcx>,
) -> Result<GenericArgsRef<'tcx>, NoSolution> {
debug!(
"fulfill_implication({:?}, trait_ref={:?} |- {:?} applies)",
param_env, source_trait_ref, target_impl
);

let ocx = ObligationCtxt::new(infcx);
let source_trait_ref = ocx.normalize(cause, param_env, source_trait_ref);

if !ocx.select_all_or_error().is_empty() {
infcx.dcx().span_delayed_bug(
infcx.tcx.def_span(source_impl),
format!("failed to fully normalize {source_trait_ref}"),
);
return Err(NoSolution);
}

let target_args = infcx.fresh_args_for_item(DUMMY_SP, target_impl);
let target_trait_ref = ocx.normalize(
cause,
param_env,
infcx
.tcx
.impl_trait_ref(target_impl)
.expect("expected source impl to be a trait impl")
.instantiate(infcx.tcx, target_args),
);

// do the impls unify? If not, no specialization.
ocx.eq(cause, param_env, source_trait_ref, target_trait_ref)?;

// Now check that the source trait ref satisfies all the where clauses of the target impl.
// This is not just for correctness; we also need this to constrain any params that may
// only be referenced via projection predicates.
let predicates = ocx.normalize(
cause,
param_env,
infcx.tcx.predicates_of(target_impl).instantiate(infcx.tcx, target_args),
);
let obligations = predicates_for_generics(|_, _| cause.clone(), param_env, predicates);
ocx.register_obligations(obligations);

let errors = ocx.select_all_or_error();
if !errors.is_empty() {
// no dice!
debug!(
"fulfill_implication: for impls on {:?} and {:?}, \
could not fulfill: {:?} given {:?}",
source_trait_ref,
target_trait_ref,
errors,
param_env.caller_bounds()
);
return Err(NoSolution);
}

debug!(
"fulfill_implication: an impl for {:?} specializes {:?}",
source_trait_ref, target_trait_ref
);

// Now resolve the *generic parameters* we built for the target earlier, replacing
// the inference variables inside with whatever we got from fulfillment.
Ok(infcx.resolve_vars_if_possible(target_args))
}

pub(super) fn specialization_enabled_in(tcx: TyCtxt<'_>, _: LocalCrate) -> bool {
tcx.features().specialization() || tcx.features().min_specialization()
}
Expand Down Expand Up @@ -182,99 +265,25 @@ pub(super) fn specializes(tcx: TyCtxt<'_>, (impl1_def_id, impl2_def_id): (DefId,
return false;
}

// create a parameter environment corresponding to a (placeholder) instantiation of impl1
let penv = tcx.param_env(impl1_def_id);
// create a parameter environment corresponding to an identity instantiation of impl1,
// i.e. the most generic instantiation of impl1.
let param_env = tcx.param_env(impl1_def_id);

// Create an infcx, taking the predicates of impl1 as assumptions:
let infcx = tcx.infer_ctxt().build(TypingMode::non_body_analysis());

// Attempt to prove that impl2 applies, given all of the above.
fulfill_implication(
&infcx,
penv,
param_env,
impl1_trait_header.trait_ref.instantiate_identity(),
impl1_def_id,
impl2_def_id,
|_, _| ObligationCause::dummy(),
&ObligationCause::dummy(),
)
.is_ok()
}

/// Attempt to fulfill all obligations of `target_impl` after unification with
/// `source_trait_ref`. If successful, returns the generic parameters for *all* the
/// generics of `target_impl`, including both those needed to unify with
/// `source_trait_ref` and those whose identity is determined via a where
/// clause in the impl.
fn fulfill_implication<'tcx>(
infcx: &InferCtxt<'tcx>,
param_env: ty::ParamEnv<'tcx>,
source_trait_ref: ty::TraitRef<'tcx>,
source_impl: DefId,
target_impl: DefId,
error_cause: impl Fn(usize, Span) -> ObligationCause<'tcx>,
) -> Result<GenericArgsRef<'tcx>, ()> {
debug!(
"fulfill_implication({:?}, trait_ref={:?} |- {:?} applies)",
param_env, source_trait_ref, target_impl
);

let ocx = ObligationCtxt::new(infcx);
let source_trait_ref = ocx.normalize(&ObligationCause::dummy(), param_env, source_trait_ref);

if !ocx.select_all_or_error().is_empty() {
infcx.dcx().span_delayed_bug(
infcx.tcx.def_span(source_impl),
format!("failed to fully normalize {source_trait_ref}"),
);
}

let source_trait_ref = infcx.resolve_vars_if_possible(source_trait_ref);
let source_trait = ImplSubject::Trait(source_trait_ref);

let selcx = SelectionContext::new(infcx);
let target_args = infcx.fresh_args_for_item(DUMMY_SP, target_impl);
let (target_trait, obligations) =
util::impl_subject_and_oblig(&selcx, param_env, target_impl, target_args, error_cause);

// do the impls unify? If not, no specialization.
let Ok(InferOk { obligations: more_obligations, .. }) = infcx
.at(&ObligationCause::dummy(), param_env)
// Ok to use `Yes`, as all the generic params are already replaced by inference variables,
// which will match the opaque type no matter if it is defining or not.
// Any concrete type that would match the opaque would already be handled by coherence rules,
// and thus either be ok to match here and already have errored, or it won't match, in which
// case there is no issue anyway.
.eq(DefineOpaqueTypes::Yes, source_trait, target_trait)
else {
debug!("fulfill_implication: {:?} does not unify with {:?}", source_trait, target_trait);
return Err(());
};

// attempt to prove all of the predicates for impl2 given those for impl1
// (which are packed up in penv)
ocx.register_obligations(obligations.chain(more_obligations));

let errors = ocx.select_all_or_error();
if !errors.is_empty() {
// no dice!
debug!(
"fulfill_implication: for impls on {:?} and {:?}, \
could not fulfill: {:?} given {:?}",
source_trait,
target_trait,
errors,
param_env.caller_bounds()
);
return Err(());
}

debug!("fulfill_implication: an impl for {:?} specializes {:?}", source_trait, target_trait);

// Now resolve the *generic parameters* we built for the target earlier, replacing
// the inference variables inside with whatever we got from fulfillment.
Ok(infcx.resolve_vars_if_possible(target_args))
}

/// Query provider for `specialization_graph_of`.
pub(super) fn specialization_graph_provider(
tcx: TyCtxt<'_>,
Expand Down
35 changes: 2 additions & 33 deletions compiler/rustc_trait_selection/src/traits/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,16 @@ use std::collections::BTreeMap;
use rustc_data_structures::fx::FxIndexMap;
use rustc_errors::Diag;
use rustc_hir::def_id::DefId;
use rustc_infer::infer::{InferCtxt, InferOk};
use rustc_infer::infer::InferCtxt;
pub use rustc_infer::traits::util::*;
use rustc_middle::bug;
use rustc_middle::ty::{
self, GenericArgsRef, ImplSubject, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable,
TypeVisitableExt, Upcast,
self, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitableExt, Upcast,
};
use rustc_span::Span;
use smallvec::{SmallVec, smallvec};
use tracing::debug;

use super::{NormalizeExt, ObligationCause, PredicateObligation, SelectionContext};

///////////////////////////////////////////////////////////////////////////
// `TraitAliasExpander` iterator
///////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -166,34 +163,6 @@ impl<'tcx> Iterator for TraitAliasExpander<'tcx> {
// Other
///////////////////////////////////////////////////////////////////////////

/// Instantiate all bound parameters of the impl subject with the given args,
/// returning the resulting subject and all obligations that arise.
/// The obligations are closed under normalization.
pub(crate) fn impl_subject_and_oblig<'a, 'tcx>(
selcx: &SelectionContext<'a, 'tcx>,
param_env: ty::ParamEnv<'tcx>,
impl_def_id: DefId,
impl_args: GenericArgsRef<'tcx>,
cause: impl Fn(usize, Span) -> ObligationCause<'tcx>,
) -> (ImplSubject<'tcx>, impl Iterator<Item = PredicateObligation<'tcx>>) {
let subject = selcx.tcx().impl_subject(impl_def_id);
let subject = subject.instantiate(selcx.tcx(), impl_args);

let InferOk { value: subject, obligations: normalization_obligations1 } =
selcx.infcx.at(&ObligationCause::dummy(), param_env).normalize(subject);

let predicates = selcx.tcx().predicates_of(impl_def_id);
let predicates = predicates.instantiate(selcx.tcx(), impl_args);
let InferOk { value: predicates, obligations: normalization_obligations2 } =
selcx.infcx.at(&ObligationCause::dummy(), param_env).normalize(predicates);
let impl_obligations = super::predicates_for_generics(cause, param_env, predicates);

let impl_obligations =
impl_obligations.chain(normalization_obligations1).chain(normalization_obligations2);

(subject, impl_obligations)
}

/// Casts a trait reference into a reference to one of its super
/// traits; returns `None` if `target_trait_def_id` is not a
/// supertrait.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,15 @@ error[E0477]: the type `&'a i32` does not fulfill the required lifetime
LL | impl<'a> Tr for &'a i32 {
| ^^^^^^^^^^^^^^^^^^^^^^^
|
note: type must satisfy the static lifetime as required by this binding
--> $DIR/specialize_with_generalize_lifetimes.rs:12:15
|
LL | impl<T: Any + 'static> Tr for T {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not care about specialization enough to keep this span around.

| ^^^^^^^
= note: type must satisfy the static lifetime

error[E0477]: the type `Wrapper<'a>` does not fulfill the required lifetime
--> $DIR/specialize_with_generalize_lifetimes.rs:31:1
|
LL | impl<'a> Tr for Wrapper<'a> {
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
note: type must satisfy the static lifetime as required by this binding
--> $DIR/specialize_with_generalize_lifetimes.rs:12:15
|
LL | impl<T: Any + 'static> Tr for T {
| ^^^^^^^
= note: type must satisfy the static lifetime

error: aborting due to 2 previous errors

Expand Down
Loading