diff --git a/Cargo.lock b/Cargo.lock
index 687a951ed9fe..3e4d62a7d033 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2898,6 +2898,7 @@ dependencies = [
  "tempfile",
  "thiserror",
  "tokio",
+ "tokio-stream",
  "tokio-util",
  "tracing",
  "url",
diff --git a/Cargo.toml b/Cargo.toml
index 0dd381a93725..aadafe932fa0 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -90,6 +90,7 @@ target-lexicon = { version = "0.12.13" }
 tempfile = { version = "3.9.0" }
 textwrap = { version = "0.15.2" }
 thiserror = { version = "1.0.56" }
+tokio-stream = { version = "0.1.14" }
 tl = { version = "0.7.7" }
 tokio = { version = "1.35.1", features = ["rt-multi-thread"] }
 tokio-tar = { version = "0.3.1" }
diff --git a/crates/puffin-client/src/cached_client.rs b/crates/puffin-client/src/cached_client.rs
index 2cb34f5c413f..4c82ae31b466 100644
--- a/crates/puffin-client/src/cached_client.rs
+++ b/crates/puffin-client/src/cached_client.rs
@@ -104,7 +104,7 @@ impl CachedClient {
     /// client.
     #[instrument(skip_all)]
     pub async fn get_cached_with_callback<
-        Payload: Serialize + DeserializeOwned + Send,
+        Payload: Serialize + DeserializeOwned + Send + 'static,
         CallBackError,
         Callback,
         CallbackReturn,
@@ -172,7 +172,7 @@ impl CachedClient {
         }
     }

-    async fn read_cache<Payload: Serialize + DeserializeOwned + Send>(
+    async fn read_cache<Payload: Serialize + DeserializeOwned + Send + 'static>(
         cache_entry: &CacheEntry,
     ) -> Option<DataWithCachePolicy<Payload>> {
         let read_span = info_span!("read_cache", file = %cache_entry.path().display());
@@ -185,8 +185,12 @@ impl CachedClient {
             "parse_cache",
             path = %cache_entry.path().display()
         );
-        let parse_result = parse_span
-            .in_scope(|| rmp_serde::from_slice::<DataWithCachePolicy<Payload>>(&cached));
+        let parse_result = tokio::task::spawn_blocking(move || {
+            parse_span
+                .in_scope(|| rmp_serde::from_slice::<DataWithCachePolicy<Payload>>(&cached))
+        })
+        .await
+        .expect("Tokio executor failed, was there a panic?");
         match parse_result {
             Ok(data) => Some(data),
             Err(err) => {
diff --git a/crates/puffin-resolver/Cargo.toml b/crates/puffin-resolver/Cargo.toml
index 67064f539f82..a716a75ea15d 100644
--- a/crates/puffin-resolver/Cargo.toml
+++ b/crates/puffin-resolver/Cargo.toml
@@ -54,6 +54,7 @@ sha2 = { workspace = true }
 tempfile = { workspace = true }
 thiserror = { workspace = true }
 tokio = { workspace = true, features = ["macros"] }
+tokio-stream = { workspace = true }
 tokio-util = { workspace = true, features = ["compat"] }
 tracing = { workspace = true }
 url = { workspace = true }
diff --git a/crates/puffin-resolver/src/error.rs b/crates/puffin-resolver/src/error.rs
index d6f18685b3ba..6e44f0315dc2 100644
--- a/crates/puffin-resolver/src/error.rs
+++ b/crates/puffin-resolver/src/error.rs
@@ -24,14 +24,11 @@ pub enum ResolveError {
     #[error("Failed to find a version of {0} that satisfies the requirement")]
     NotFound(Requirement),

-    #[error("The request stream terminated unexpectedly")]
-    StreamTermination,
-
     #[error(transparent)]
     Client(#[from] puffin_client::Error),

-    #[error(transparent)]
-    TrySend(#[from] futures::channel::mpsc::SendError),
+    #[error("The channel is closed, was there a panic?")]
+    ChannelClosed,

     #[error(transparent)]
     Join(#[from] tokio::task::JoinError),
@@ -88,9 +85,11 @@ pub enum ResolveError {
     Failure(String),
 }

-impl<T> From<futures::channel::mpsc::TrySendError<T>> for ResolveError {
-    fn from(value: futures::channel::mpsc::TrySendError<T>) -> Self {
-        value.into_send_error().into()
+impl From<tokio::sync::mpsc::error::SendError<Request>> for ResolveError {
+    /// Drop the value we want to send to not leak the private type we're sending.
+    /// The tokio error only says "channel closed", so we don't lose information.
+    fn from(_value: tokio::sync::mpsc::error::SendError<Request>) -> Self {
+        Self::ChannelClosed
     }
 }
diff --git a/crates/puffin-resolver/src/resolver/mod.rs b/crates/puffin-resolver/src/resolver/mod.rs
index 74155a33a482..f320bd2ac1f0 100644
--- a/crates/puffin-resolver/src/resolver/mod.rs
+++ b/crates/puffin-resolver/src/resolver/mod.rs
@@ -5,7 +5,6 @@ use std::sync::Arc;

 use anyhow::Result;
 use dashmap::{DashMap, DashSet};
-use futures::channel::mpsc::UnboundedReceiver;
 use futures::{FutureExt, StreamExt};
 use itertools::Itertools;
 use pubgrub::error::PubGrubError;
@@ -14,6 +13,7 @@ use pubgrub::solver::{Incompatibility, State};
 use pubgrub::type_aliases::DependencyConstraints;
 use rustc_hash::{FxHashMap, FxHashSet};
 use tokio::select;
+use tokio_stream::wrappers::ReceiverStream;
 use tracing::{debug, info_span, instrument, trace, Instrument};
 use url::Url;

@@ -202,7 +202,8 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
     pub async fn resolve(self) -> Result<ResolutionGraph, ResolveError> {
         // A channel to fetch package metadata (e.g., given `flask`, fetch all versions) and version
         // metadata (e.g., given `flask==1.0.0`, fetch the metadata for that version).
-        let (request_sink, request_stream) = futures::channel::mpsc::unbounded();
+        // Channel size is set to the same size as the task buffer for simplicity.
+        let (request_sink, request_stream) = tokio::sync::mpsc::channel(50);

         // Run the fetcher.
         let requests_fut = self.fetch(request_stream).fuse();
@@ -213,7 +214,7 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
         let resolution = select! {
             result = requests_fut => {
                 result?;
-                return Err(ResolveError::StreamTermination);
+                return Err(ResolveError::ChannelClosed);
             }
             resolution = resolve_fut => {
                 resolution.map_err(|err| {
@@ -241,7 +242,7 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
     #[instrument(skip_all)]
     async fn solve(
         &self,
-        request_sink: &futures::channel::mpsc::UnboundedSender<Request>,
+        request_sink: &tokio::sync::mpsc::Sender<Request>,
     ) -> Result<ResolutionGraph, ResolveError> {
         let root = PubGrubPackage::Root(self.project.clone());

@@ -265,7 +266,7 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
             state.unit_propagation(next)?;

             // Pre-visit all candidate packages, to allow metadata to be fetched in parallel.
-            Self::pre_visit(state.partial_solution.prioritized_packages(), request_sink)?;
+            Self::pre_visit(state.partial_solution.prioritized_packages(), request_sink).await?;

             // Choose a package version.
             let Some(highest_priority_pkg) =
@@ -386,7 +387,7 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
         &self,
         package: &PubGrubPackage,
         priorities: &mut PubGrubPriorities,
-        request_sink: &futures::channel::mpsc::UnboundedSender<Request>,
+        request_sink: &tokio::sync::mpsc::Sender<Request>,
     ) -> Result<(), ResolveError> {
         match package {
             PubGrubPackage::Root(_) => {}
@@ -395,10 +396,9 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
                 // Emit a request to fetch the metadata for this package.
                 if self.index.packages.register(package_name.clone()) {
                     priorities.add(package_name.clone());
-                    request_sink.unbounded_send(Request::Package(package_name.clone()))?;
-
-                    // Yield to allow subscribers to continue, as the channel is sync.
-                    tokio::task::yield_now().await;
+                    request_sink
+                        .send(Request::Package(package_name.clone()))
+                        .await?;
                 }
             }
             PubGrubPackage::Package(package_name, _extra, Some(url)) => {
@@ -406,10 +406,7 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
                 let dist = Dist::from_url(package_name.clone(), url.clone())?;
                 if self.index.distributions.register(dist.package_id()) {
                     priorities.add(dist.name().clone());
-                    request_sink.unbounded_send(Request::Dist(dist))?;
-
-                    // Yield to allow subscribers to continue, as the channel is sync.
-                    tokio::task::yield_now().await;
+                    request_sink.send(Request::Dist(dist)).await?;
                 }
             }
         }
@@ -418,9 +415,9 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {

     /// Visit the set of [`PubGrubPackage`] candidates prior to selection. This allows us to fetch
     /// metadata for all of the packages in parallel.
-    fn pre_visit<'data>(
+    async fn pre_visit<'data>(
         packages: impl Iterator<Item = (&'data PubGrubPackage, &'data Range<Version>)>,
-        request_sink: &futures::channel::mpsc::UnboundedSender<Request>,
+        request_sink: &tokio::sync::mpsc::Sender<Request>,
     ) -> Result<(), ResolveError> {
         // Iterate over the potential packages, and fetch file metadata for any of them. These
         // represent our current best guesses for the versions that we _might_ select.
@@ -428,7 +425,9 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
             let PubGrubPackage::Package(package_name, _extra, None) = package else {
                 continue;
             };
-            request_sink.unbounded_send(Request::Prefetch(package_name.clone(), range.clone()))?;
+            request_sink
+                .send(Request::Prefetch(package_name.clone(), range.clone()))
+                .await?;
         }
         Ok(())
     }
@@ -441,9 +440,9 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
         package: &PubGrubPackage,
         range: &Range<Version>,
         pins: &mut FilePins,
-        request_sink: &futures::channel::mpsc::UnboundedSender<Request>,
+        request_sink: &tokio::sync::mpsc::Sender<Request>,
     ) -> Result<Option<Version>, ResolveError> {
-        return match package {
+        match package {
             PubGrubPackage::Root(_) => Ok(Some(MIN_VERSION.clone())),

             PubGrubPackage::Python(PubGrubPython::Installed) => {
@@ -576,24 +575,22 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
                 // Emit a request to fetch the metadata for this version.
                 if self.index.distributions.register(candidate.package_id()) {
                     let dist = candidate.resolve().dist.clone();
-                    request_sink.unbounded_send(Request::Dist(dist))?;
-
-                    // Yield to allow subscribers to continue, as the channel is sync.
-                    tokio::task::yield_now().await;
+                    request_sink.send(Request::Dist(dist)).await?;
                 }

                 Ok(Some(version))
             }
-        };
+        }
     }

     /// Given a candidate package and version, return its dependencies.
+    #[instrument(skip_all, fields(%package, %version))]
     async fn get_dependencies(
         &self,
         package: &PubGrubPackage,
         version: &Version,
         priorities: &mut PubGrubPriorities,
-        request_sink: &futures::channel::mpsc::UnboundedSender<Request>,
+        request_sink: &tokio::sync::mpsc::Sender<Request>,
     ) -> Result<Dependencies, ResolveError> {
         match package {
             PubGrubPackage::Root(_) => {
@@ -724,8 +721,11 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
     }

     /// Fetch the metadata for a stream of packages and versions.
-    async fn fetch(&self, request_stream: UnboundedReceiver<Request>) -> Result<(), ResolveError> {
-        let mut response_stream = request_stream
+    async fn fetch(
+        &self,
+        request_stream: tokio::sync::mpsc::Receiver<Request>,
+    ) -> Result<(), ResolveError> {
+        let mut response_stream = ReceiverStream::new(request_stream)
             .map(|request| self.process_request(request).boxed())
             .buffer_unordered(50);

@@ -769,9 +769,6 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
                 }
                 None => {}
             }
-
-            // Yield to allow subscribers to continue, as the channel is sync.
-            tokio::task::yield_now().await;
         }
         Ok::<(), ResolveError>(())
     }
@@ -902,7 +899,7 @@ impl<'a, Provider: ResolverProvider> Resolver<'a, Provider> {
 /// Fetch the metadata for an item
 #[derive(Debug)]
 #[allow(clippy::large_enum_variant)]
-enum Request {
+pub(crate) enum Request {
     /// A request to fetch the metadata for a package.
     Package(PackageName),
     /// A request to fetch the metadata for a built or source distribution.
     Dist(Dist),
@@ -915,10 +912,10 @@ impl Display for Request {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
         match self {
             Request::Package(package_name) => {
-                write!(f, "Package {package_name}")
+                write!(f, "Versions {package_name}")
             }
             Request::Dist(dist) => {
-                write!(f, "Dist {dist}")
+                write!(f, "Metadata {dist}")
             }
             Request::Prefetch(package_name, range) => {
                 write!(f, "Prefetch {package_name} {range}")
diff --git a/crates/puffin-resolver/src/resolver/provider.rs b/crates/puffin-resolver/src/resolver/provider.rs
index b8f7e26a8e2a..297671869df6 100644
--- a/crates/puffin-resolver/src/resolver/provider.rs
+++ b/crates/puffin-resolver/src/resolver/provider.rs
@@ -1,8 +1,9 @@
 use std::future::Future;
+use std::ops::Deref;
+use std::sync::Arc;

 use anyhow::Result;
 use chrono::{DateTime, Utc};
-use futures::FutureExt;
 use url::Url;

 use distribution_types::Dist;
@@ -45,17 +46,30 @@ pub trait ResolverProvider: Send + Sync {
 /// The main IO backend for the resolver, which does cached requests network requests using the
 /// [`RegistryClient`] and [`DistributionDatabase`].
 pub struct DefaultResolverProvider<'a, Context: BuildContext + Send + Sync> {
-    /// The [`RegistryClient`] used to query the index.
-    client: &'a RegistryClient,
     /// The [`DistributionDatabase`] used to build source distributions.
     fetcher: DistributionDatabase<'a, Context>,
+    /// Allow moving the parameters to `VersionMap::from_metadata` to a different thread.
+    inner: Arc<DefaultResolverProviderInner>,
+}
+
+pub struct DefaultResolverProviderInner {
+    /// The [`RegistryClient`] used to query the index.
+    client: RegistryClient,
     /// These are the entries from `--find-links` that act as overrides for index responses.
-    flat_index: &'a FlatIndex,
-    tags: &'a Tags,
+    flat_index: FlatIndex,
+    tags: Tags,
     python_requirement: PythonRequirement,
     exclude_newer: Option<DateTime<Utc>>,
     allowed_yanks: AllowedYanks,
-    no_binary: &'a NoBinary,
+    no_binary: NoBinary,
+}
+
+impl<'a, Context: BuildContext + Send + Sync> Deref for DefaultResolverProvider<'a, Context> {
+    type Target = DefaultResolverProviderInner;
+
+    fn deref(&self) -> &Self::Target {
+        self.inner.as_ref()
+    }
 }

 impl<'a, Context: BuildContext + Send + Sync> DefaultResolverProvider<'a, Context> {
@@ -72,14 +86,16 @@ impl<'a, Context: BuildContext + Send + Sync> DefaultResolverProvider<'a, Contex
         no_binary: &'a NoBinary,
     ) -> Self {
         Self {
-            client,
             fetcher,
-            flat_index,
-            tags,
-            python_requirement,
-            exclude_newer,
-            allowed_yanks,
-            no_binary,
+            inner: Arc::new(DefaultResolverProviderInner {
+                client: client.clone(),
+                flat_index: flat_index.clone(),
+                tags: tags.clone(),
+                python_requirement,
+                exclude_newer,
+                allowed_yanks,
+                no_binary: no_binary.clone(),
+            }),
         }
     }
 }
@@ -87,43 +103,48 @@ impl<'a, Context: BuildContext + Send + Sync> ResolverProvider
     for DefaultResolverProvider<'a, Context>
 {
-    fn get_version_map<'io>(
-        &'io self,
-        package_name: &'io PackageName,
-    ) -> impl Future<Output = VersionMapResponse> + Send + 'io {
-        self.client
-            .simple(package_name)
-            .map(move |result| match result {
-                Ok((index, metadata)) => Ok(VersionMap::from_metadata(
-                    metadata,
-                    package_name,
-                    &index,
-                    self.tags,
-                    &self.python_requirement,
-                    &self.allowed_yanks,
-                    self.exclude_newer.as_ref(),
-                    self.flat_index.get(package_name).cloned(),
-                    self.no_binary,
-                )),
-                Err(err) => match err.into_kind() {
-                    kind @ (puffin_client::ErrorKind::PackageNotFound(_)
-                    | puffin_client::ErrorKind::NoIndex(_)) => {
-                        if let Some(flat_index) = self.flat_index.get(package_name).cloned() {
-                            Ok(VersionMap::from(flat_index))
-                        } else {
-                            Err(kind.into())
-                        }
+    /// Make a simple api request for the package and convert the result to a [`VersionMap`].
+    async fn get_version_map<'io>(&'io self, package_name: &'io PackageName) -> VersionMapResponse {
+        let result = self.client.simple(package_name).await;
+
+        // If the simple api request was successful, perform on the slow conversion to `VersionMap` on the tokio
+        // threadpool
+        match result {
+            Ok((index, metadata)) => {
+                let self_send = self.inner.clone();
+                let package_name_owned = package_name.clone();
+                Ok(tokio::task::spawn_blocking(move || {
+                    VersionMap::from_metadata(
+                        metadata,
+                        &package_name_owned,
+                        &index,
+                        &self_send.tags,
+                        &self_send.python_requirement,
+                        &self_send.allowed_yanks,
+                        self_send.exclude_newer.as_ref(),
+                        self_send.flat_index.get(&package_name_owned).cloned(),
+                        &self_send.no_binary,
+                    )
+                })
+                .await
+                .expect("Tokio executor failed, was there a panic?"))
+            }
+            Err(err) => match err.into_kind() {
+                kind @ (puffin_client::ErrorKind::PackageNotFound(_)
+                | puffin_client::ErrorKind::NoIndex(_)) => {
+                    if let Some(flat_index) = self.flat_index.get(package_name).cloned() {
+                        Ok(VersionMap::from(flat_index))
+                    } else {
+                        Err(kind.into())
                     }
-                    kind => Err(kind.into()),
-                },
-            })
+                }
+                kind => Err(kind.into()),
+            },
+        }
     }

-    fn get_or_build_wheel_metadata<'io>(
-        &'io self,
-        dist: &'io Dist,
-    ) -> impl Future<Output = WheelMetadataResponse> + Send + 'io {
-        self.fetcher.get_or_build_wheel_metadata(dist)
+    async fn get_or_build_wheel_metadata<'io>(&'io self, dist: &'io Dist) -> WheelMetadataResponse {
+        self.fetcher.get_or_build_wheel_metadata(dist).await
     }

     /// Set the [`puffin_distribution::Reporter`] to use for this installer.
diff --git a/crates/puffin-traits/src/lib.rs b/crates/puffin-traits/src/lib.rs
index 6d62adedcd3d..9a779770243f 100644
--- a/crates/puffin-traits/src/lib.rs
+++ b/crates/puffin-traits/src/lib.rs
@@ -160,7 +160,7 @@ impl Display for BuildKind {
     }
 }

-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub enum NoBinary {
     /// Allow installation of any wheel.
     None,
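For reference, the concurrency pattern the diff converges on (a bounded tokio::sync::mpsc channel drained through tokio_stream::wrappers::ReceiverStream with buffer_unordered, plus tokio::task::spawn_blocking for CPU-bound work) can be sketched in isolation, outside the patch, as follows. Request, handle_request, and the buffer size of 50 below are illustrative stand-ins, not the crate's actual items; the resolver itself uses its own Request enum and process_request.

// Sketch of the bounded-channel + spawn_blocking pattern; names are illustrative.
use futures::{FutureExt, StreamExt};
use tokio_stream::wrappers::ReceiverStream;

#[derive(Debug)]
enum Request {
    Package(String),
}

async fn handle_request(request: Request) -> String {
    // CPU-heavy work (deserialization, building a version map, ...) is moved off the
    // async executor onto tokio's blocking thread pool.
    tokio::task::spawn_blocking(move || match request {
        Request::Package(name) => format!("processed {name}"),
    })
    .await
    .expect("Tokio executor failed, was there a panic?")
}

#[tokio::main]
async fn main() {
    // A bounded channel: a slow consumer applies backpressure to the producer instead
    // of letting an unbounded queue grow, so no manual yield points are needed.
    let (request_sink, request_stream) = tokio::sync::mpsc::channel::<Request>(50);

    let consumer = tokio::spawn(async move {
        // Wrap the receiver so it works with `StreamExt` combinators and process up to
        // 50 requests concurrently, mirroring the resolver's `buffer_unordered(50)`.
        let mut responses = ReceiverStream::new(request_stream)
            .map(|request| handle_request(request).boxed())
            .buffer_unordered(50);
        while let Some(response) = responses.next().await {
            println!("{response}");
        }
    });

    for name in ["flask", "requests", "numpy"] {
        // `send` waits while the channel is full; a closed channel surfaces as an
        // error, which the patch maps to `ResolveError::ChannelClosed`.
        request_sink
            .send(Request::Package(name.to_string()))
            .await
            .expect("channel closed");
    }
    // Close the channel so the consumer's stream terminates.
    drop(request_sink);

    consumer.await.unwrap();
}

Because the channel is bounded, a producer now suspends in send() when the consumer falls behind, which is why the patch can drop the manual yield_now() calls; and because spawn_blocking moves values onto other threads, the payload types involved pick up 'static bounds, as seen in cached_client.rs.

The provider change follows a related pattern: shared configuration moves into an Arc-wrapped inner struct exposed through Deref, so that a cheap Arc clone (rather than borrowed &'a fields) can be sent into spawn_blocking, which requires 'static data. A minimal sketch under the same caveat, with hypothetical names (Inner, Provider, expensive_conversion):

// Sketch of the Arc<Inner> + Deref layout used to hand owned data to spawn_blocking.
use std::ops::Deref;
use std::sync::Arc;

struct Inner {
    tags: Vec<String>,
}

struct Provider {
    inner: Arc<Inner>,
}

impl Deref for Provider {
    type Target = Inner;

    fn deref(&self) -> &Self::Target {
        self.inner.as_ref()
    }
}

// Stand-in for CPU-heavy work such as `VersionMap::from_metadata`.
fn expensive_conversion(metadata: &str, tags: &[String]) -> usize {
    metadata.len() + tags.len()
}

impl Provider {
    async fn convert(&self, metadata: String) -> usize {
        // Clone the `Arc`, not the data: the clone is `'static` and can cross threads,
        // which borrowed `&'a` fields could not.
        let inner = self.inner.clone();
        tokio::task::spawn_blocking(move || expensive_conversion(&metadata, &inner.tags))
            .await
            .expect("Tokio executor failed, was there a panic?")
    }
}

#[tokio::main]
async fn main() {
    let provider = Provider {
        inner: Arc::new(Inner {
            tags: vec!["cp312".to_string()],
        }),
    };
    // `Deref` keeps call sites unchanged: fields of `Inner` read as fields of `Provider`.
    let result = provider.convert("metadata".to_string()).await;
    println!("{} tag(s), result {result}", provider.tags.len());
}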