Skip to content

Commit

Permalink
Allow to use super trait bounds in where clauses
Browse files Browse the repository at this point in the history
  • Loading branch information
spastorino committed Nov 27, 2020
1 parent 361543d commit 24dcf6f
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 23 deletions.
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ rustc_queries! {

/// To avoid cycles within the predicates of a single item we compute
/// per-type-parameter predicates for resolving `T::AssocTy`.
query type_param_predicates(key: (DefId, LocalDefId)) -> ty::GenericPredicates<'tcx> {
query type_param_predicates(key: (DefId, LocalDefId, rustc_span::symbol::Ident)) -> ty::GenericPredicates<'tcx> {
desc { |tcx| "computing the bounds for type parameter `{}`", {
let id = tcx.hir().local_def_id_to_hir_id(key.1);
tcx.hir().ty_param_name(id)
Expand Down
13 changes: 12 additions & 1 deletion compiler/rustc_middle/src/ty/query/keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::ty::subst::{GenericArg, SubstsRef};
use crate::ty::{self, Ty, TyCtxt};
use rustc_hir::def_id::{CrateNum, DefId, LocalDefId, LOCAL_CRATE};
use rustc_query_system::query::DefaultCacheSelector;
use rustc_span::symbol::Symbol;
use rustc_span::symbol::{Ident, Symbol};
use rustc_span::{Span, DUMMY_SP};

/// The `Key` trait controls what types can legally be used as the key
Expand Down Expand Up @@ -149,6 +149,17 @@ impl Key for (LocalDefId, DefId) {
}
}

impl Key for (DefId, LocalDefId, Ident) {
type CacheSelector = DefaultCacheSelector;

fn query_crate(&self) -> CrateNum {
self.0.krate
}
fn default_span(&self, tcx: TyCtxt<'_>) -> Span {
self.1.default_span(tcx)
}
}

impl Key for (CrateNum, DefId) {
type CacheSelector = DefaultCacheSelector;

Expand Down
19 changes: 13 additions & 6 deletions compiler/rustc_typeck/src/astconv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,10 @@ pub trait AstConv<'tcx> {

fn default_constness_for_trait_bounds(&self) -> Constness;

/// Returns predicates in scope of the form `X: Foo`, where `X` is
/// a type parameter `X` with the given id `def_id`. This is a
/// subset of the full set of predicates.
/// Returns predicates in scope of the form `X: Foo<T>`, where `X`
/// is a type parameter `X` with the given id `def_id` and T
/// matches assoc_name. This is a subset of the full set of
/// predicates.
///
/// This is used for one specific purpose: resolving "short-hand"
/// associated type references like `T::Item`. In principle, we
Expand All @@ -60,7 +61,12 @@ pub trait AstConv<'tcx> {
/// but this can lead to cycle errors. The problem is that we have
/// to do this resolution *in order to create the predicates in
/// the first place*. Hence, we have this "special pass".
fn get_type_parameter_bounds(&self, span: Span, def_id: DefId) -> ty::GenericPredicates<'tcx>;
fn get_type_parameter_bounds(
&self,
span: Span,
def_id: DefId,
assoc_name: Ident,
) -> ty::GenericPredicates<'tcx>;

/// Returns the lifetime to use when a lifetime is omitted (and not elided).
fn re_infer(&self, param: Option<&ty::GenericParamDef>, span: Span)
Expand Down Expand Up @@ -1361,8 +1367,9 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
ty_param_def_id, assoc_name, span,
);

let predicates =
&self.get_type_parameter_bounds(span, ty_param_def_id.to_def_id()).predicates;
let predicates = &self
.get_type_parameter_bounds(span, ty_param_def_id.to_def_id(), assoc_name)
.predicates;

debug!("find_bound_for_assoc_item: predicates={:#?}", predicates);

Expand Down
28 changes: 24 additions & 4 deletions compiler/rustc_typeck/src/check/fn_ctxt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use rustc_middle::ty::fold::TypeFoldable;
use rustc_middle::ty::subst::GenericArgKind;
use rustc_middle::ty::{self, Const, Ty, TyCtxt};
use rustc_session::Session;
use rustc_span::symbol::Ident;
use rustc_span::{self, Span};
use rustc_trait_selection::traits::{ObligationCause, ObligationCauseCode};

Expand Down Expand Up @@ -183,7 +184,12 @@ impl<'a, 'tcx> AstConv<'tcx> for FnCtxt<'a, 'tcx> {
}
}

fn get_type_parameter_bounds(&self, _: Span, def_id: DefId) -> ty::GenericPredicates<'tcx> {
fn get_type_parameter_bounds(
&self,
_: Span,
def_id: DefId,
assoc_name: Ident,
) -> ty::GenericPredicates<'tcx> {
let tcx = self.tcx;
let hir_id = tcx.hir().local_def_id_to_hir_id(def_id.expect_local());
let item_id = tcx.hir().ty_param_owner(hir_id);
Expand All @@ -196,9 +202,23 @@ impl<'a, 'tcx> AstConv<'tcx> for FnCtxt<'a, 'tcx> {
self.param_env.caller_bounds().iter().filter_map(|predicate| {
match predicate.skip_binders() {
ty::PredicateAtom::Trait(data, _) if data.self_ty().is_param(index) => {
// HACK(eddyb) should get the original `Span`.
let span = tcx.def_span(def_id);
Some((predicate, span))
let trait_did = data.def_id();
if tcx
.associated_items(trait_did)
.find_by_name_and_kind(
tcx,
assoc_name,
ty::AssocKind::Type,
trait_did,
)
.is_some()
{
// HACK(eddyb) should get the original `Span`.
let span = tcx.def_span(def_id);
Some((predicate, span))
} else {
None
}
}
_ => None,
}
Expand Down
113 changes: 102 additions & 11 deletions compiler/rustc_typeck/src/collect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,17 @@ impl AstConv<'tcx> for ItemCtxt<'tcx> {
}
}

fn get_type_parameter_bounds(&self, span: Span, def_id: DefId) -> ty::GenericPredicates<'tcx> {
self.tcx.at(span).type_param_predicates((self.item_def_id, def_id.expect_local()))
fn get_type_parameter_bounds(
&self,
span: Span,
def_id: DefId,
assoc_name: Ident,
) -> ty::GenericPredicates<'tcx> {
self.tcx.at(span).type_param_predicates((
self.item_def_id,
def_id.expect_local(),
assoc_name,
))
}

fn re_infer(&self, _: Option<&ty::GenericParamDef>, _: Span) -> Option<ty::Region<'tcx>> {
Expand Down Expand Up @@ -492,7 +501,7 @@ fn get_new_lifetime_name<'tcx>(
/// `X: Foo` where `X` is the type parameter `def_id`.
fn type_param_predicates(
tcx: TyCtxt<'_>,
(item_def_id, def_id): (DefId, LocalDefId),
(item_def_id, def_id, assoc_name): (DefId, LocalDefId, Ident),
) -> ty::GenericPredicates<'_> {
use rustc_hir::*;

Expand All @@ -517,7 +526,7 @@ fn type_param_predicates(
let mut result = parent
.map(|parent| {
let icx = ItemCtxt::new(tcx, parent);
icx.get_type_parameter_bounds(DUMMY_SP, def_id.to_def_id())
icx.get_type_parameter_bounds(DUMMY_SP, def_id.to_def_id(), assoc_name)
})
.unwrap_or_default();
let mut extend = None;
Expand Down Expand Up @@ -560,12 +569,18 @@ fn type_param_predicates(

let icx = ItemCtxt::new(tcx, item_def_id);
let extra_predicates = extend.into_iter().chain(
icx.type_parameter_bounds_in_generics(ast_generics, param_id, ty, OnlySelfBounds(true))
.into_iter()
.filter(|(predicate, _)| match predicate.skip_binders() {
ty::PredicateAtom::Trait(data, _) => data.self_ty().is_param(index),
_ => false,
}),
icx.type_parameter_bounds_in_generics(
ast_generics,
param_id,
ty,
OnlySelfBounds(true),
Some(assoc_name),
)
.into_iter()
.filter(|(predicate, _)| match predicate.skip_binders() {
ty::PredicateAtom::Trait(data, _) => data.self_ty().is_param(index),
_ => false,
}),
);
result.predicates =
tcx.arena.alloc_from_iter(result.predicates.iter().copied().chain(extra_predicates));
Expand All @@ -583,6 +598,7 @@ impl ItemCtxt<'tcx> {
param_id: hir::HirId,
ty: Ty<'tcx>,
only_self_bounds: OnlySelfBounds,
assoc_name: Option<Ident>,
) -> Vec<(ty::Predicate<'tcx>, Span)> {
let constness = self.default_constness_for_trait_bounds();
let from_ty_params = ast_generics
Expand All @@ -593,6 +609,10 @@ impl ItemCtxt<'tcx> {
_ => None,
})
.flat_map(|bounds| bounds.iter())
.filter(|b| match assoc_name {
Some(assoc_name) => self.bound_defines_assoc_item(b, assoc_name),
None => true,
})
.flat_map(|b| predicates_from_bound(self, ty, b, constness));

let from_where_clauses = ast_generics
Expand All @@ -611,12 +631,43 @@ impl ItemCtxt<'tcx> {
} else {
None
};
bp.bounds.iter().filter_map(move |b| bt.map(|bt| (bt, b)))
bp.bounds
.iter()
.filter(|b| match assoc_name {
Some(assoc_name) => self.bound_defines_assoc_item(b, assoc_name),
None => true,
})
.filter_map(move |b| bt.map(|bt| (bt, b)))
})
.flat_map(|(bt, b)| predicates_from_bound(self, bt, b, constness));

from_ty_params.chain(from_where_clauses).collect()
}

fn bound_defines_assoc_item(&self, b: &hir::GenericBound<'_>, assoc_name: Ident) -> bool {
debug!("bound_defines_assoc_item(b={:?}, assoc_name={:?})", b, assoc_name);

match b {
hir::GenericBound::Trait(poly_trait_ref, _) => {
let trait_ref = &poly_trait_ref.trait_ref;
let trait_did = trait_ref.trait_def_id().unwrap();
let traits_did = super_traits_of(self.tcx, trait_did);

traits_did.iter().any(|trait_did| {
self.tcx
.associated_items(*trait_did)
.find_by_name_and_kind(
self.tcx,
assoc_name,
ty::AssocKind::Type,
*trait_did,
)
.is_some()
})
}
_ => false,
}
}
}

/// Tests whether this is the AST for a reference to the type
Expand Down Expand Up @@ -1017,6 +1068,7 @@ fn super_predicates_of(tcx: TyCtxt<'_>, trait_def_id: DefId) -> ty::GenericPredi
item.hir_id,
self_param_ty,
OnlySelfBounds(!is_trait_alias),
None,
);

// Combine the two lists to form the complete set of superbounds:
Expand All @@ -1034,6 +1086,45 @@ fn super_predicates_of(tcx: TyCtxt<'_>, trait_def_id: DefId) -> ty::GenericPredi
ty::GenericPredicates { parent: None, predicates: superbounds }
}

pub fn super_traits_of(tcx: TyCtxt<'_>, trait_def_id: DefId) -> impl Iterator<Item = DefId> {
let mut set = FxHashSet::default();
let mut stack = vec![trait_def_id];
while let Some(trait_did) = stack.pop() {
if !set.insert(trait_did) {
continue;
}

if trait_did.is_local() {
let trait_hir_id = tcx.hir().local_def_id_to_hir_id(trait_did.expect_local());

let item = match tcx.hir().get(trait_hir_id) {
Node::Item(item) => item,
_ => bug!("super_trait_of {} is not an item", trait_hir_id),
};

let supertraits = match item.kind {
hir::ItemKind::Trait(.., ref supertraits, _) => supertraits,
hir::ItemKind::TraitAlias(_, ref supertraits) => supertraits,
_ => span_bug!(item.span, "super_trait_of invoked on non-trait"),
};

for supertrait in supertraits.iter() {
let trait_ref = supertrait.trait_ref();
if let Some(trait_did) = trait_ref.and_then(|trait_ref| trait_ref.trait_def_id()) {
stack.push(trait_did);
}
}
} else {
let generic_predicates = tcx.super_predicates_of(trait_did);
for (predicate, _) in generic_predicates.predicates {
if let ty::PredicateAtom::Trait(data, _) = predicate.skip_binders() {
stack.push(data.def_id());
}
}
}
}
}

fn trait_def(tcx: TyCtxt<'_>, def_id: DefId) -> ty::TraitDef {
let hir_id = tcx.hir().local_def_id_to_hir_id(def_id.expect_local());
let item = tcx.hir().expect_item(hir_id);
Expand Down
15 changes: 15 additions & 0 deletions src/test/ui/associated-type-bounds/super-trait-referencing.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// check-pass
trait Foo {
type Item;
}

trait Bar<T> {}

fn baz<T>()
where
T: Foo,
T: Bar<T::Item>,
{
}

fn main() {}

0 comments on commit 24dcf6f

Please sign in to comment.