From f2ef88ba064cd6799922e85ae0748d298ad436d1 Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Fri, 19 Jan 2024 21:28:37 +0000 Subject: [PATCH] Consolidate logic around resolving built-in coroutine trait impls --- compiler/rustc_hir/src/lang_items.rs | 3 ++ compiler/rustc_middle/src/ty/instance.rs | 50 ++++++++++++++++++++ compiler/rustc_span/src/symbol.rs | 1 + compiler/rustc_ty_utils/src/instance.rs | 59 +----------------------- library/core/src/ops/coroutine.rs | 1 + 5 files changed, 56 insertions(+), 58 deletions(-) diff --git a/compiler/rustc_hir/src/lang_items.rs b/compiler/rustc_hir/src/lang_items.rs index 1cc1f11b3c858..85d10872b3d1d 100644 --- a/compiler/rustc_hir/src/lang_items.rs +++ b/compiler/rustc_hir/src/lang_items.rs @@ -213,8 +213,11 @@ language_item_table! { Iterator, sym::iterator, iterator_trait, Target::Trait, GenericRequirement::Exact(0); Future, sym::future_trait, future_trait, Target::Trait, GenericRequirement::Exact(0); AsyncIterator, sym::async_iterator, async_iterator_trait, Target::Trait, GenericRequirement::Exact(0); + CoroutineState, sym::coroutine_state, coroutine_state, Target::Enum, GenericRequirement::None; Coroutine, sym::coroutine, coroutine_trait, Target::Trait, GenericRequirement::Minimum(1); + CoroutineResume, sym::coroutine_resume, coroutine_resume, Target::Method(MethodKind::Trait { body: false }), GenericRequirement::None; + Unpin, sym::unpin, unpin_trait, Target::Trait, GenericRequirement::None; Pin, sym::pin, pin_type, Target::Struct, GenericRequirement::None; diff --git a/compiler/rustc_middle/src/ty/instance.rs b/compiler/rustc_middle/src/ty/instance.rs index dd41cb5a61f44..b6c3c34078f76 100644 --- a/compiler/rustc_middle/src/ty/instance.rs +++ b/compiler/rustc_middle/src/ty/instance.rs @@ -3,6 +3,7 @@ use crate::ty::print::{FmtPrinter, Printer}; use crate::ty::{self, Ty, TyCtxt, TypeFoldable, TypeSuperFoldable}; use crate::ty::{EarlyBinder, GenericArgs, GenericArgsRef, TypeVisitableExt}; use rustc_errors::ErrorGuaranteed; +use rustc_hir as hir; use rustc_hir::def::Namespace; use rustc_hir::def_id::{CrateNum, DefId}; use rustc_hir::lang_items::LangItem; @@ -11,6 +12,7 @@ use rustc_macros::HashStable; use rustc_middle::ty::normalize_erasing_regions::NormalizationError; use rustc_span::Symbol; +use std::assert_matches::assert_matches; use std::fmt; /// A monomorphized `InstanceDef`. @@ -572,6 +574,54 @@ impl<'tcx> Instance<'tcx> { Some(Instance { def, args }) } + pub fn try_resolve_item_for_coroutine( + tcx: TyCtxt<'tcx>, + trait_item_id: DefId, + trait_id: DefId, + rcvr_args: ty::GenericArgsRef<'tcx>, + ) -> Option> { + let ty::Coroutine(coroutine_def_id, args) = *rcvr_args.type_at(0).kind() else { + return None; + }; + let coroutine_kind = tcx.coroutine_kind(coroutine_def_id).unwrap(); + + let lang_items = tcx.lang_items(); + let coroutine_callable_item = if Some(trait_id) == lang_items.future_trait() { + assert_matches!( + coroutine_kind, + hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _) + ); + hir::LangItem::FuturePoll + } else if Some(trait_id) == lang_items.iterator_trait() { + assert_matches!( + coroutine_kind, + hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _) + ); + hir::LangItem::IteratorNext + } else if Some(trait_id) == lang_items.async_iterator_trait() { + assert_matches!( + coroutine_kind, + hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::AsyncGen, _) + ); + hir::LangItem::AsyncIteratorPollNext + } else if Some(trait_id) == lang_items.coroutine_trait() { + assert_matches!(coroutine_kind, hir::CoroutineKind::Coroutine(_)); + hir::LangItem::CoroutineResume + } else { + return None; + }; + + if tcx.lang_items().get(coroutine_callable_item) == Some(trait_item_id) { + Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args: args }) + } else { + // All other methods should be defaulted methods of the built-in trait. + // This is important for `Iterator`'s combinators, but also useful for + // adding future default methods to `Future`, for instance. + debug_assert!(tcx.defaultness(trait_item_id).has_value()); + Some(Instance::new(trait_item_id, rcvr_args)) + } + } + /// Depending on the kind of `InstanceDef`, the MIR body associated with an /// instance is expressed in terms of the generic parameters of `self.def_id()`, and in other /// cases the MIR body is expressed in terms of the types found in the substitution array. diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index 44795022cbab3..bebeecebbb643 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -600,6 +600,7 @@ symbols! { core_panic_macro, coroutine, coroutine_clone, + coroutine_resume, coroutine_state, coroutines, cosf32, diff --git a/compiler/rustc_ty_utils/src/instance.rs b/compiler/rustc_ty_utils/src/instance.rs index 81d5304b81265..e5e31f7caaa2b 100644 --- a/compiler/rustc_ty_utils/src/instance.rs +++ b/compiler/rustc_ty_utils/src/instance.rs @@ -245,63 +245,6 @@ fn resolve_associated_item<'tcx>( span: tcx.def_span(trait_item_id), }) } - } else if Some(trait_ref.def_id) == lang_items.future_trait() { - let ty::Coroutine(coroutine_def_id, args) = *rcvr_args.type_at(0).kind() else { - bug!() - }; - if Some(trait_item_id) == tcx.lang_items().future_poll_fn() { - // `Future::poll` is generated by the compiler. - Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args: args }) - } else { - // All other methods are default methods of the `Future` trait. - // (this assumes that `ImplSource::Builtin` is only used for methods on `Future`) - debug_assert!(tcx.defaultness(trait_item_id).has_value()); - Some(Instance::new(trait_item_id, rcvr_args)) - } - } else if Some(trait_ref.def_id) == lang_items.iterator_trait() { - let ty::Coroutine(coroutine_def_id, args) = *rcvr_args.type_at(0).kind() else { - bug!() - }; - if Some(trait_item_id) == tcx.lang_items().next_fn() { - // `Iterator::next` is generated by the compiler. - Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args }) - } else { - // All other methods are default methods of the `Iterator` trait. - // (this assumes that `ImplSource::Builtin` is only used for methods on `Iterator`) - debug_assert!(tcx.defaultness(trait_item_id).has_value()); - Some(Instance::new(trait_item_id, rcvr_args)) - } - } else if Some(trait_ref.def_id) == lang_items.async_iterator_trait() { - let ty::Coroutine(coroutine_def_id, args) = *rcvr_args.type_at(0).kind() else { - bug!() - }; - - if cfg!(debug_assertions) && tcx.item_name(trait_item_id) != sym::poll_next { - span_bug!( - tcx.def_span(coroutine_def_id), - "no definition for `{trait_ref}::{}` for built-in coroutine type", - tcx.item_name(trait_item_id) - ) - } - - // `AsyncIterator::poll_next` is generated by the compiler. - Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args }) - } else if Some(trait_ref.def_id) == lang_items.coroutine_trait() { - let ty::Coroutine(coroutine_def_id, args) = *rcvr_args.type_at(0).kind() else { - bug!() - }; - if cfg!(debug_assertions) && tcx.item_name(trait_item_id) != sym::resume { - // For compiler developers who'd like to add new items to `Coroutine`, - // you either need to generate a shim body, or perhaps return - // `InstanceDef::Item` pointing to a trait default method body if - // it is given a default implementation by the trait. - span_bug!( - tcx.def_span(coroutine_def_id), - "no definition for `{trait_ref}::{}` for built-in coroutine type", - tcx.item_name(trait_item_id) - ) - } - Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args }) } else if tcx.fn_trait_kind_from_def_id(trait_ref.def_id).is_some() { // FIXME: This doesn't check for malformed libcore that defines, e.g., // `trait Fn { fn call_once(&self) { .. } }`. This is mostly for extension @@ -334,7 +277,7 @@ fn resolve_associated_item<'tcx>( ), } } else { - None + Instance::try_resolve_item_for_coroutine(tcx, trait_item_id, trait_id, rcvr_args) } } traits::ImplSource::Param(..) diff --git a/library/core/src/ops/coroutine.rs b/library/core/src/ops/coroutine.rs index e58c9068af85c..6faded76a4a49 100644 --- a/library/core/src/ops/coroutine.rs +++ b/library/core/src/ops/coroutine.rs @@ -111,6 +111,7 @@ pub trait Coroutine { /// been returned previously. While coroutine literals in the language are /// guaranteed to panic on resuming after `Complete`, this is not guaranteed /// for all implementations of the `Coroutine` trait. + #[cfg_attr(not(bootstrap), lang = "coroutine_resume")] fn resume(self: Pin<&mut Self>, arg: R) -> CoroutineState; }