From 970276b5594a6205709b0f2d0f6af9b6a0121683 Mon Sep 17 00:00:00 2001 From: iDawer Date: Tue, 26 Apr 2022 19:25:10 +0500 Subject: [PATCH] 'inference': collect RPIT obligations Collect obligations from RPITs (Return Position `impl Trait`) of a function which is being inferred. This allows inferring {unknown}s from RPIT bounds. --- crates/hir-ty/src/infer.rs | 52 ++++++++++++++++++++++++++----- crates/hir-ty/src/tests/traits.rs | 36 ++++++++++++++++++--- 2 files changed, 75 insertions(+), 13 deletions(-) diff --git a/crates/hir-ty/src/infer.rs b/crates/hir-ty/src/infer.rs index 7f5ad415096db..2a11c9d9bf15c 100644 --- a/crates/hir-ty/src/infer.rs +++ b/crates/hir-ty/src/infer.rs @@ -19,7 +19,7 @@ use std::sync::Arc; use chalk_ir::{cast::Cast, ConstValue, DebruijnIndex, Mutability, Safety, Scalar, TypeFlags}; use hir_def::{ body::Body, - data::{ConstData, FunctionData, StaticData}, + data::{ConstData, StaticData}, expr::{BindingAnnotation, ExprId, PatId}, lang_item::LangItemTarget, path::{path, Path}, @@ -32,12 +32,13 @@ use hir_expand::name::{name, Name}; use itertools::Either; use la_arena::ArenaMap; use rustc_hash::FxHashMap; -use stdx::impl_from; +use stdx::{always, impl_from}; use crate::{ - db::HirDatabase, fold_tys_and_consts, infer::coerce::CoerceMany, lower::ImplTraitLoweringMode, - to_assoc_type_id, AliasEq, AliasTy, Const, DomainGoal, GenericArg, Goal, InEnvironment, - Interner, ProjectionTy, Substitution, TraitEnvironment, TraitRef, Ty, TyBuilder, TyExt, TyKind, + db::HirDatabase, fold_tys, fold_tys_and_consts, infer::coerce::CoerceMany, + lower::ImplTraitLoweringMode, to_assoc_type_id, AliasEq, AliasTy, Const, DomainGoal, + GenericArg, Goal, ImplTraitId, InEnvironment, Interner, ProjectionTy, Substitution, + TraitEnvironment, TraitRef, Ty, TyBuilder, TyExt, TyKind, }; // This lint has a false positive here. See the link below for details. @@ -64,7 +65,7 @@ pub(crate) fn infer_query(db: &dyn HirDatabase, def: DefWithBodyId) -> Arc ctx.collect_const(&db.const_data(c)), - DefWithBodyId::FunctionId(f) => ctx.collect_fn(&db.function_data(f)), + DefWithBodyId::FunctionId(f) => ctx.collect_fn(f), DefWithBodyId::StaticId(s) => ctx.collect_static(&db.static_data(s)), } @@ -457,7 +458,8 @@ impl<'a> InferenceContext<'a> { self.return_ty = self.make_ty(&data.type_ref); } - fn collect_fn(&mut self, data: &FunctionData) { + fn collect_fn(&mut self, func: FunctionId) { + let data = self.db.function_data(func); let ctx = crate::lower::TyLoweringContext::new(self.db, &self.resolver) .with_impl_trait_mode(ImplTraitLoweringMode::Param); let param_tys = @@ -474,8 +476,42 @@ impl<'a> InferenceContext<'a> { } else { &*data.ret_type }; - let return_ty = self.make_ty_with_mode(return_ty, ImplTraitLoweringMode::Disallowed); // FIXME implement RPIT + let return_ty = self.make_ty_with_mode(return_ty, ImplTraitLoweringMode::Opaque); self.return_ty = return_ty; + + if let Some(rpits) = self.db.return_type_impl_traits(func) { + // RPIT opaque types use substitution of their parent function. + let fn_placeholders = TyBuilder::placeholder_subst(self.db, func); + self.return_ty = fold_tys( + self.return_ty.clone(), + |ty, _| { + let opaque_ty_id = match ty.kind(Interner) { + TyKind::OpaqueType(opaque_ty_id, _) => *opaque_ty_id, + _ => return ty, + }; + let idx = match self.db.lookup_intern_impl_trait_id(opaque_ty_id.into()) { + ImplTraitId::ReturnTypeImplTrait(_, idx) => idx, + _ => unreachable!(), + }; + let bounds = (*rpits).map_ref(|rpits| { + rpits.impl_traits[idx as usize].bounds.map_ref(|it| it.into_iter()) + }); + let var = self.table.new_type_var(); + let var_subst = Substitution::from1(Interner, var.clone()); + for bound in bounds { + let predicate = + bound.map(|it| it.cloned()).substitute(Interner, &fn_placeholders); + let (var_predicate, binders) = predicate + .substitute(Interner, &var_subst) + .into_value_and_skipped_binders(); + always!(binders.len(Interner) == 0); // quantified where clauses not yet handled + self.push_obligation(var_predicate.cast(Interner)); + } + var + }, + DebruijnIndex::INNERMOST, + ); + } } fn infer_body(&mut self) { diff --git a/crates/hir-ty/src/tests/traits.rs b/crates/hir-ty/src/tests/traits.rs index 5e58d5ad8381c..0b08aa4711ca3 100644 --- a/crates/hir-ty/src/tests/traits.rs +++ b/crates/hir-ty/src/tests/traits.rs @@ -1255,6 +1255,32 @@ fn test() { ); } +#[test] +fn infer_from_return_pos_impl_trait() { + check_infer_with_mismatches( + r#" +//- minicore: fn, sized +trait Trait {} +struct Bar(T); +impl Trait for Bar {} +fn foo() -> (impl FnOnce(&str, T), impl Trait) { + (|input, t| {}, Bar(C)) +} +"#, + expect![[r#" + 134..165 '{ ...(C)) }': (|&str, T| -> (), Bar) + 140..163 '(|inpu...ar(C))': (|&str, T| -> (), Bar) + 141..154 '|input, t| {}': |&str, T| -> () + 142..147 'input': &str + 149..150 't': T + 152..154 '{}': () + 156..159 'Bar': Bar(u8) -> Bar + 156..162 'Bar(C)': Bar + 160..161 'C': u8 + "#]], + ); +} + #[test] fn dyn_trait() { check_infer( @@ -2392,7 +2418,7 @@ fn test() -> impl Trait { 171..182 '{ loop {} }': T 173..180 'loop {}': ! 178..180 '{}': () - 213..309 '{ ...t()) }': S<{unknown}> + 213..309 '{ ...t()) }': S 223..225 's1': S 228..229 'S': S(u32) -> S 228..240 'S(default())': S @@ -2408,10 +2434,10 @@ fn test() -> impl Trait { 276..288 'S(default())': S 278..285 'default': fn default() -> i32 278..287 'default()': i32 - 295..296 'S': S<{unknown}>({unknown}) -> S<{unknown}> - 295..307 'S(default())': S<{unknown}> - 297..304 'default': fn default<{unknown}>() -> {unknown} - 297..306 'default()': {unknown} + 295..296 'S': S(i32) -> S + 295..307 'S(default())': S + 297..304 'default': fn default() -> i32 + 297..306 'default()': i32 "#]], ); }