From 2ca4964db5d263a8f9222846bd70a7f26cf414cf Mon Sep 17 00:00:00 2001 From: Santiago Pastorino Date: Fri, 13 Nov 2020 14:01:16 -0300 Subject: [PATCH] Allow to self reference associated types in where clauses --- compiler/rustc_infer/src/traits/util.rs | 35 +++++ compiler/rustc_middle/src/query/mod.rs | 10 ++ compiler/rustc_middle/src/ty/query/keys.rs | 11 ++ .../rustc_trait_selection/src/traits/mod.rs | 3 +- compiler/rustc_typeck/src/astconv/mod.rs | 49 +++++- compiler/rustc_typeck/src/collect.rs | 141 +++++++++++------- .../rustc_typeck/src/collect/item_bounds.rs | 6 +- .../super-trait-referencing-self.rs | 12 ++ 8 files changed, 204 insertions(+), 63 deletions(-) create mode 100644 src/test/ui/associated-type-bounds/super-trait-referencing-self.rs diff --git a/compiler/rustc_infer/src/traits/util.rs b/compiler/rustc_infer/src/traits/util.rs index 8273c2d291d09..26f6fcd5dab68 100644 --- a/compiler/rustc_infer/src/traits/util.rs +++ b/compiler/rustc_infer/src/traits/util.rs @@ -4,6 +4,7 @@ use crate::traits::{Obligation, ObligationCause, PredicateObligation}; use rustc_data_structures::fx::FxHashSet; use rustc_middle::ty::outlives::Component; use rustc_middle::ty::{self, ToPredicate, TyCtxt, WithConstness}; +use rustc_span::symbol::Ident; pub fn anonymize_predicate<'tcx>( tcx: TyCtxt<'tcx>, @@ -89,6 +90,32 @@ pub fn elaborate_trait_refs<'tcx>( elaborate_predicates(tcx, predicates) } +pub fn elaborate_trait_refs_that_define_assoc_type<'tcx>( + tcx: TyCtxt<'tcx>, + trait_refs: impl Iterator>, + assoc_name: Ident, +) -> FxHashSet> { + let mut stack: Vec<_> = trait_refs.collect(); + let mut trait_refs = FxHashSet::default(); + + while let Some(trait_ref) = stack.pop() { + if trait_refs.insert(trait_ref) { + let super_predicates = + tcx.super_predicates_that_define_assoc_type((trait_ref.def_id(), Some(assoc_name))); + for (super_predicate, _) in super_predicates.predicates { + let bound_predicate = super_predicate.bound_atom(); + let subst_predicate = super_predicate + .subst_supertrait(tcx, &bound_predicate.rebind(trait_ref.skip_binder())); + if let Some(binder) = subst_predicate.to_opt_poly_trait_ref() { + stack.push(binder.value); + } + } + } + } + + trait_refs +} + pub fn elaborate_predicates<'tcx>( tcx: TyCtxt<'tcx>, predicates: impl Iterator>, @@ -287,6 +314,14 @@ pub fn transitive_bounds<'tcx>( elaborate_trait_refs(tcx, bounds).filter_to_traits() } +pub fn transitive_bounds_that_define_assoc_type<'tcx>( + tcx: TyCtxt<'tcx>, + bounds: impl Iterator>, + assoc_name: Ident, +) -> FxHashSet> { + elaborate_trait_refs_that_define_assoc_type(tcx, bounds, assoc_name) +} + /////////////////////////////////////////////////////////////////////////// // Other /////////////////////////////////////////////////////////////////////////// diff --git a/compiler/rustc_middle/src/query/mod.rs b/compiler/rustc_middle/src/query/mod.rs index ed032220b54a2..c19943e1f8ccd 100644 --- a/compiler/rustc_middle/src/query/mod.rs +++ b/compiler/rustc_middle/src/query/mod.rs @@ -436,6 +436,16 @@ rustc_queries! { desc { |tcx| "computing the supertraits of `{}`", tcx.def_path_str(key) } } + /// Maps from the `DefId` of a trait to the list of + /// super-predicates. This is a subset of the full list of + /// predicates. We store these in a separate map because we must + /// evaluate them even during type conversion, often before the + /// full predicates are available (note that supertraits have + /// additional acyclicity requirements). + query super_predicates_that_define_assoc_type(key: (DefId, Option)) -> ty::GenericPredicates<'tcx> { + desc { |tcx| "computing the supertraits of `{}`", tcx.def_path_str(key.0) } + } + /// To avoid cycles within the predicates of a single item we compute /// per-type-parameter predicates for resolving `T::AssocTy`. query type_param_predicates(key: (DefId, LocalDefId, rustc_span::symbol::Ident)) -> ty::GenericPredicates<'tcx> { diff --git a/compiler/rustc_middle/src/ty/query/keys.rs b/compiler/rustc_middle/src/ty/query/keys.rs index 339a068205c8e..3949c303f727a 100644 --- a/compiler/rustc_middle/src/ty/query/keys.rs +++ b/compiler/rustc_middle/src/ty/query/keys.rs @@ -149,6 +149,17 @@ impl Key for (LocalDefId, DefId) { } } +impl Key for (DefId, Option) { + type CacheSelector = DefaultCacheSelector; + + fn query_crate(&self) -> CrateNum { + self.0.krate + } + fn default_span(&self, tcx: TyCtxt<'_>) -> Span { + tcx.def_span(self.0) + } +} + impl Key for (DefId, LocalDefId, Ident) { type CacheSelector = DefaultCacheSelector; diff --git a/compiler/rustc_trait_selection/src/traits/mod.rs b/compiler/rustc_trait_selection/src/traits/mod.rs index 2d7df2ddd119d..d027628828441 100644 --- a/compiler/rustc_trait_selection/src/traits/mod.rs +++ b/compiler/rustc_trait_selection/src/traits/mod.rs @@ -65,7 +65,8 @@ pub use self::util::{ get_vtable_index_of_object_method, impl_item_is_final, predicate_for_trait_def, upcast_choices, }; pub use self::util::{ - supertrait_def_ids, supertraits, transitive_bounds, SupertraitDefIds, Supertraits, + supertrait_def_ids, supertraits, transitive_bounds, transitive_bounds_that_define_assoc_type, + SupertraitDefIds, Supertraits, }; pub use self::chalk_fulfill::FulfillmentContext as ChalkFulfillmentContext; diff --git a/compiler/rustc_typeck/src/astconv/mod.rs b/compiler/rustc_typeck/src/astconv/mod.rs index e891ea3403f47..d8fbcb633cd3e 100644 --- a/compiler/rustc_typeck/src/astconv/mod.rs +++ b/compiler/rustc_typeck/src/astconv/mod.rs @@ -6,6 +6,7 @@ mod errors; mod generics; use crate::bounds::Bounds; +use crate::collect::super_traits_of; use crate::collect::PlaceholderHirTyCollector; use crate::errors::{ AmbiguousLifetimeBound, MultipleRelaxedDefaultBounds, TraitObjectDeclaredWithNoTraits, @@ -768,7 +769,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o { } // Returns `true` if a bounds list includes `?Sized`. - pub fn is_unsized(&self, ast_bounds: &[hir::GenericBound<'_>], span: Span) -> bool { + pub fn is_unsized(&self, ast_bounds: &[&hir::GenericBound<'_>], span: Span) -> bool { let tcx = self.tcx(); // Try to find an unbound in bounds. @@ -826,7 +827,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o { fn add_bounds( &self, param_ty: Ty<'tcx>, - ast_bounds: &[hir::GenericBound<'_>], + ast_bounds: &[&hir::GenericBound<'_>], bounds: &mut Bounds<'tcx>, ) { let mut trait_bounds = Vec::new(); @@ -844,7 +845,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o { hir::GenericBound::Trait(_, hir::TraitBoundModifier::Maybe) => {} hir::GenericBound::LangItemTrait(lang_item, span, hir_id, args) => self .instantiate_lang_item_trait_ref( - lang_item, span, hir_id, args, param_ty, bounds, + *lang_item, *span, *hir_id, args, param_ty, bounds, ), hir::GenericBound::Outlives(ref l) => region_bounds.push(l), } @@ -878,7 +879,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o { pub fn compute_bounds( &self, param_ty: Ty<'tcx>, - ast_bounds: &[hir::GenericBound<'_>], + ast_bounds: &[&hir::GenericBound<'_>], sized_by_default: SizedByDefault, span: Span, ) -> Bounds<'tcx> { @@ -896,6 +897,39 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o { bounds } + pub fn compute_bounds_that_match_assoc_type( + &self, + param_ty: Ty<'tcx>, + ast_bounds: &[hir::GenericBound<'_>], + sized_by_default: SizedByDefault, + span: Span, + assoc_name: Ident, + ) -> Bounds<'tcx> { + let mut result = Vec::new(); + + for ast_bound in ast_bounds { + if let Some(trait_ref) = ast_bound.trait_ref() { + if let Some(trait_did) = trait_ref.trait_def_id() { + if super_traits_of(self.tcx(), trait_did).any(|trait_did| { + self.tcx() + .associated_items(trait_did) + .find_by_name_and_kind( + self.tcx(), + assoc_name, + ty::AssocKind::Type, + trait_did, + ) + .is_some() + }) { + result.push(ast_bound); + } + } + } + } + + self.compute_bounds(param_ty, &result, sized_by_default, span) + } + /// Given an HIR binding like `Item = Foo` or `Item: Foo`, pushes the corresponding predicates /// onto `bounds`. /// @@ -1050,7 +1084,8 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o { // Calling `skip_binder` is okay, because `add_bounds` expects the `param_ty` // parameter to have a skipped binder. let param_ty = tcx.mk_projection(assoc_ty.def_id, candidate.skip_binder().substs); - self.add_bounds(param_ty, ast_bounds, bounds); + let ast_bounds: Vec<_> = ast_bounds.iter().collect(); + self.add_bounds(param_ty, &ast_bounds, bounds); } } Ok(()) @@ -1377,12 +1412,14 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o { let param_name = tcx.hir().ty_param_name(param_hir_id); self.one_bound_for_assoc_type( || { - traits::transitive_bounds( + traits::transitive_bounds_that_define_assoc_type( tcx, predicates.iter().filter_map(|(p, _)| { p.to_opt_poly_trait_ref().map(|trait_ref| trait_ref.value) }), + assoc_name, ) + .into_iter() }, || param_name.to_string(), assoc_name, diff --git a/compiler/rustc_typeck/src/collect.rs b/compiler/rustc_typeck/src/collect.rs index 756147ca54c2c..8b457c7ceec27 100644 --- a/compiler/rustc_typeck/src/collect.rs +++ b/compiler/rustc_typeck/src/collect.rs @@ -1,3 +1,4 @@ +// ignore-tidy-filelength //! "Collection" is the process of determining the type and other external //! details of each item in Rust. Collection is specifically concerned //! with *inter-procedural* things -- for example, for a function @@ -79,6 +80,7 @@ pub fn provide(providers: &mut Providers) { projection_ty_from_predicates, explicit_predicates_of, super_predicates_of, + super_predicates_that_define_assoc_type, trait_explicit_predicates_and_bounds, type_param_predicates, trait_def, @@ -651,17 +653,10 @@ impl ItemCtxt<'tcx> { hir::GenericBound::Trait(poly_trait_ref, _) => { let trait_ref = &poly_trait_ref.trait_ref; let trait_did = trait_ref.trait_def_id().unwrap(); - let traits_did = super_traits_of(self.tcx, trait_did); - - traits_did.iter().any(|trait_did| { + super_traits_of(self.tcx, trait_did).any(|trait_did| { self.tcx - .associated_items(*trait_did) - .find_by_name_and_kind( - self.tcx, - assoc_name, - ty::AssocKind::Type, - *trait_did, - ) + .associated_items(trait_did) + .find_by_name_and_kind(self.tcx, assoc_name, ty::AssocKind::Type, trait_did) .is_some() }) } @@ -1035,55 +1030,91 @@ fn adt_def(tcx: TyCtxt<'_>, def_id: DefId) -> &ty::AdtDef { /// the transitive super-predicates are converted. fn super_predicates_of(tcx: TyCtxt<'_>, trait_def_id: DefId) -> ty::GenericPredicates<'_> { debug!("super_predicates(trait_def_id={:?})", trait_def_id); - let trait_hir_id = tcx.hir().local_def_id_to_hir_id(trait_def_id.expect_local()); + tcx.super_predicates_that_define_assoc_type((trait_def_id, None)) +} - let item = match tcx.hir().get(trait_hir_id) { - Node::Item(item) => item, - _ => bug!("trait_node_id {} is not an item", trait_hir_id), - }; +/// Ensures that the super-predicates of the trait with a `DefId` +/// of `trait_def_id` are converted and stored. This also ensures that +/// the transitive super-predicates are converted. +fn super_predicates_that_define_assoc_type( + tcx: TyCtxt<'_>, + (trait_def_id, assoc_name): (DefId, Option), +) -> ty::GenericPredicates<'_> { + debug!( + "super_predicates_that_define_assoc_type(trait_def_id={:?}, assoc_name={:?})", + trait_def_id, assoc_name + ); + if trait_def_id.is_local() { + debug!("super_predicates_that_define_assoc_type: local trait_def_id={:?}", trait_def_id); + let trait_hir_id = tcx.hir().local_def_id_to_hir_id(trait_def_id.expect_local()); - let (generics, bounds) = match item.kind { - hir::ItemKind::Trait(.., ref generics, ref supertraits, _) => (generics, supertraits), - hir::ItemKind::TraitAlias(ref generics, ref supertraits) => (generics, supertraits), - _ => span_bug!(item.span, "super_predicates invoked on non-trait"), - }; + let item = match tcx.hir().get(trait_hir_id) { + Node::Item(item) => item, + _ => bug!("trait_node_id {} is not an item", trait_hir_id), + }; - let icx = ItemCtxt::new(tcx, trait_def_id); - - // Convert the bounds that follow the colon, e.g., `Bar + Zed` in `trait Foo: Bar + Zed`. - let self_param_ty = tcx.types.self_param; - let superbounds1 = - AstConv::compute_bounds(&icx, self_param_ty, bounds, SizedByDefault::No, item.span); - - let superbounds1 = superbounds1.predicates(tcx, self_param_ty); - - // Convert any explicit superbounds in the where-clause, - // e.g., `trait Foo where Self: Bar`. - // In the case of trait aliases, however, we include all bounds in the where-clause, - // so e.g., `trait Foo = where u32: PartialEq` would include `u32: PartialEq` - // as one of its "superpredicates". - let is_trait_alias = tcx.is_trait_alias(trait_def_id); - let superbounds2 = icx.type_parameter_bounds_in_generics( - generics, - item.hir_id, - self_param_ty, - OnlySelfBounds(!is_trait_alias), - None, - ); + let (generics, bounds) = match item.kind { + hir::ItemKind::Trait(.., ref generics, ref supertraits, _) => (generics, supertraits), + hir::ItemKind::TraitAlias(ref generics, ref supertraits) => (generics, supertraits), + _ => span_bug!(item.span, "super_predicates invoked on non-trait"), + }; - // Combine the two lists to form the complete set of superbounds: - let superbounds = &*tcx.arena.alloc_from_iter(superbounds1.into_iter().chain(superbounds2)); + let icx = ItemCtxt::new(tcx, trait_def_id); - // Now require that immediate supertraits are converted, - // which will, in turn, reach indirect supertraits. - for &(pred, span) in superbounds { - debug!("superbound: {:?}", pred); - if let ty::PredicateAtom::Trait(bound, _) = pred.skip_binders() { - tcx.at(span).super_predicates_of(bound.def_id()); + // Convert the bounds that follow the colon, e.g., `Bar + Zed` in `trait Foo: Bar + Zed`. + let self_param_ty = tcx.types.self_param; + let superbounds1 = if let Some(assoc_name) = assoc_name { + AstConv::compute_bounds_that_match_assoc_type( + &icx, + self_param_ty, + &bounds, + SizedByDefault::No, + item.span, + assoc_name, + ) + } else { + let bounds: Vec<_> = bounds.iter().collect(); + AstConv::compute_bounds(&icx, self_param_ty, &bounds, SizedByDefault::No, item.span) + }; + + let superbounds1 = superbounds1.predicates(tcx, self_param_ty); + + // Convert any explicit superbounds in the where-clause, + // e.g., `trait Foo where Self: Bar`. + // In the case of trait aliases, however, we include all bounds in the where-clause, + // so e.g., `trait Foo = where u32: PartialEq` would include `u32: PartialEq` + // as one of its "superpredicates". + let is_trait_alias = tcx.is_trait_alias(trait_def_id); + let superbounds2 = icx.type_parameter_bounds_in_generics( + generics, + item.hir_id, + self_param_ty, + OnlySelfBounds(!is_trait_alias), + assoc_name, + ); + + // Combine the two lists to form the complete set of superbounds: + let superbounds = &*tcx.arena.alloc_from_iter(superbounds1.into_iter().chain(superbounds2)); + + // Now require that immediate supertraits are converted, + // which will, in turn, reach indirect supertraits. + if assoc_name.is_none() { + // FIXME: move this into the `super_predicates_of` query + for &(pred, span) in superbounds { + debug!("superbound: {:?}", pred); + if let ty::PredicateAtom::Trait(bound, _) = pred.skip_binders() { + tcx.at(span).super_predicates_of(bound.def_id()); + } + } } - } - ty::GenericPredicates { parent: None, predicates: superbounds } + ty::GenericPredicates { parent: None, predicates: superbounds } + } else { + // if `assoc_name` is None, then the query should've been redirected to an + // external provider + assert!(assoc_name.is_some()); + tcx.super_predicates_of(trait_def_id) + } } pub fn super_traits_of(tcx: TyCtxt<'_>, trait_def_id: DefId) -> impl Iterator { @@ -1123,6 +1154,8 @@ pub fn super_traits_of(tcx: TyCtxt<'_>, trait_def_id: DefId) -> impl Iterator, def_id: DefId) -> ty::TraitDef { @@ -1976,8 +2009,8 @@ fn gather_explicit_predicates_of(tcx: TyCtxt<'_>, def_id: DefId) -> ty::GenericP index += 1; let sized = SizedByDefault::Yes; - let bounds = - AstConv::compute_bounds(&icx, param_ty, ¶m.bounds, sized, param.span); + let bounds: Vec<_> = param.bounds.iter().collect(); + let bounds = AstConv::compute_bounds(&icx, param_ty, &bounds, sized, param.span); predicates.extend(bounds.predicates(tcx, param_ty)); } GenericParamKind::Const { .. } => { diff --git a/compiler/rustc_typeck/src/collect/item_bounds.rs b/compiler/rustc_typeck/src/collect/item_bounds.rs index e596dd1a396c9..62586d793b468 100644 --- a/compiler/rustc_typeck/src/collect/item_bounds.rs +++ b/compiler/rustc_typeck/src/collect/item_bounds.rs @@ -25,10 +25,11 @@ fn associated_type_bounds<'tcx>( InternalSubsts::identity_for_item(tcx, assoc_item_def_id), ); + let bounds: Vec<_> = bounds.iter().collect(); let bounds = AstConv::compute_bounds( &ItemCtxt::new(tcx, assoc_item_def_id), item_ty, - bounds, + &bounds, SizedByDefault::Yes, span, ); @@ -65,10 +66,11 @@ fn opaque_type_bounds<'tcx>( let item_ty = tcx.mk_opaque(opaque_def_id, InternalSubsts::identity_for_item(tcx, opaque_def_id)); + let bounds: Vec<_> = bounds.iter().collect(); let bounds = AstConv::compute_bounds( &ItemCtxt::new(tcx, opaque_def_id), item_ty, - bounds, + &bounds, SizedByDefault::Yes, span, ) diff --git a/src/test/ui/associated-type-bounds/super-trait-referencing-self.rs b/src/test/ui/associated-type-bounds/super-trait-referencing-self.rs new file mode 100644 index 0000000000000..c82ec01f4d61d --- /dev/null +++ b/src/test/ui/associated-type-bounds/super-trait-referencing-self.rs @@ -0,0 +1,12 @@ +// check-pass +trait Foo { + type Bar; +} +trait Qux: Foo + AsRef {} +trait Foo2 {} + +trait Qux2: Foo2 + AsRef { + type Bar; +} + +fn main() {}