From 63a0d557570a9cd176eb157631d3ff8168713953 Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Fri, 1 Mar 2024 15:58:12 +0000 Subject: [PATCH] Implement async closure signature deduction --- compiler/rustc_hir_typeck/src/closure.rs | 74 +++++++++++++++--------- 1 file changed, 47 insertions(+), 27 deletions(-) diff --git a/compiler/rustc_hir_typeck/src/closure.rs b/compiler/rustc_hir_typeck/src/closure.rs index 5bdd9412d0e51..b6bfc81bc1975 100644 --- a/compiler/rustc_hir_typeck/src/closure.rs +++ b/compiler/rustc_hir_typeck/src/closure.rs @@ -56,18 +56,12 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { // It's always helpful for inference if we know the kind of // closure sooner rather than later, so first examine the expected // type, and see if can glean a closure kind from there. - let (expected_sig, expected_kind) = match closure.kind { - hir::ClosureKind::Closure => match expected.to_option(self) { - Some(ty) => { - self.deduce_closure_signature(self.try_structurally_resolve_type(expr_span, ty)) - } - None => (None, None), - }, - // We don't want to deduce a signature from `Fn` bounds for coroutines - // or coroutine-closures, because the former does not implement `Fn` - // ever, and the latter's signature doesn't correspond to the coroutine - // type that it returns. - hir::ClosureKind::Coroutine(_) | hir::ClosureKind::CoroutineClosure(_) => (None, None), + let (expected_sig, expected_kind) = match expected.to_option(self) { + Some(ty) => self.deduce_closure_signature( + self.try_structurally_resolve_type(expr_span, ty), + closure.kind, + ), + None => (None, None), }; let ClosureSignatures { bound_sig, mut liberated_sig } = @@ -323,11 +317,13 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { fn deduce_closure_signature( &self, expected_ty: Ty<'tcx>, + closure_kind: hir::ClosureKind, ) -> (Option>, Option) { match *expected_ty.kind() { ty::Alias(ty::Opaque, ty::AliasTy { def_id, args, .. }) => self .deduce_closure_signature_from_predicates( expected_ty, + closure_kind, self.tcx .explicit_item_bounds(def_id) .iter_instantiated_copied(self.tcx, args) @@ -336,7 +332,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { ty::Dynamic(object_type, ..) => { let sig = object_type.projection_bounds().find_map(|pb| { let pb = pb.with_self_ty(self.tcx, self.tcx.types.trait_object_dummy_self); - self.deduce_sig_from_projection(None, pb) + self.deduce_sig_from_projection(None, closure_kind, pb) }); let kind = object_type .principal_def_id() @@ -345,12 +341,18 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { } ty::Infer(ty::TyVar(vid)) => self.deduce_closure_signature_from_predicates( Ty::new_var(self.tcx, self.root_var(vid)), + closure_kind, self.obligations_for_self_ty(vid).map(|obl| (obl.predicate, obl.cause.span)), ), - ty::FnPtr(sig) => { - let expected_sig = ExpectedSig { cause_span: None, sig }; - (Some(expected_sig), Some(ty::ClosureKind::Fn)) - } + ty::FnPtr(sig) => match closure_kind { + hir::ClosureKind::Closure => { + let expected_sig = ExpectedSig { cause_span: None, sig }; + (Some(expected_sig), Some(ty::ClosureKind::Fn)) + } + hir::ClosureKind::Coroutine(_) | hir::ClosureKind::CoroutineClosure(_) => { + (None, None) + } + }, _ => (None, None), } } @@ -358,6 +360,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { fn deduce_closure_signature_from_predicates( &self, expected_ty: Ty<'tcx>, + closure_kind: hir::ClosureKind, predicates: impl DoubleEndedIterator, Span)>, ) -> (Option>, Option) { let mut expected_sig = None; @@ -386,6 +389,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { span, self.deduce_sig_from_projection( Some(span), + closure_kind, bound_predicate.rebind(proj_predicate), ), ); @@ -422,13 +426,22 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { ty::PredicateKind::Clause(ty::ClauseKind::Trait(data)) => Some(data.def_id()), _ => None, }; - if let Some(closure_kind) = - trait_def_id.and_then(|def_id| self.tcx.fn_trait_kind_from_def_id(def_id)) - { - expected_kind = Some( - expected_kind - .map_or_else(|| closure_kind, |current| cmp::min(current, closure_kind)), - ); + + if let Some(trait_def_id) = trait_def_id { + let found_kind = match closure_kind { + hir::ClosureKind::Closure => self.tcx.fn_trait_kind_from_def_id(trait_def_id), + hir::ClosureKind::CoroutineClosure(hir::CoroutineDesugaring::Async) => { + self.tcx.async_fn_trait_kind_from_def_id(trait_def_id) + } + _ => None, + }; + + if let Some(found_kind) = found_kind { + expected_kind = Some( + expected_kind + .map_or_else(|| found_kind, |current| cmp::min(current, found_kind)), + ); + } } } @@ -445,14 +458,21 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { fn deduce_sig_from_projection( &self, cause_span: Option, + closure_kind: hir::ClosureKind, projection: ty::PolyProjectionPredicate<'tcx>, ) -> Option> { let tcx = self.tcx; let trait_def_id = projection.trait_def_id(tcx); - // For now, we only do signature deduction based off of the `Fn` traits. - if !tcx.is_fn_trait(trait_def_id) { - return None; + + // For now, we only do signature deduction based off of the `Fn` and `AsyncFn` traits, + // for closures and async closures, respectively. + match closure_kind { + hir::ClosureKind::Closure + if self.tcx.fn_trait_kind_from_def_id(trait_def_id).is_some() => {} + hir::ClosureKind::CoroutineClosure(hir::CoroutineDesugaring::Async) + if self.tcx.async_fn_trait_kind_from_def_id(trait_def_id).is_some() => {} + _ => return None, } let arg_param_ty = projection.skip_binder().projection_ty.args.type_at(1);