From 982b3a510dcea9ebac6134dca3d8b729c2c00678 Mon Sep 17 00:00:00 2001 From: DaniPopes <57450786+DaniPopes@users.noreply.github.com> Date: Wed, 31 Jan 2024 20:26:58 +0100 Subject: [PATCH] chore: simplify PubsubFrontend --- crates/pubsub/src/frontend.rs | 87 +++++++++++++++-------------------- crates/transport/src/lib.rs | 6 +-- 2 files changed, 39 insertions(+), 54 deletions(-) diff --git a/crates/pubsub/src/frontend.rs b/crates/pubsub/src/frontend.rs index c0e9d0b9a0d..d5e99646181 100644 --- a/crates/pubsub/src/frontend.rs +++ b/crates/pubsub/src/frontend.rs @@ -2,9 +2,12 @@ use crate::{ix::PubSubInstruction, managers::InFlight}; use alloy_json_rpc::{RequestPacket, Response, ResponsePacket, SerializedRequest}; use alloy_primitives::U256; use alloy_transport::{TransportError, TransportErrorKind, TransportFut}; -use futures::future::try_join_all; +use futures::{future::try_join_all, FutureExt, TryFutureExt}; use serde_json::value::RawValue; -use std::{future::Future, pin::Pin}; +use std::{ + future::Future, + task::{Context, Poll}, +}; use tokio::sync::{broadcast, mpsc, oneshot}; /// A `PubSubFrontend` is [`Transport`] composed of a channel to a running @@ -23,57 +26,51 @@ impl PubSubFrontend { } /// Get the subscription ID for a local ID. - pub async fn get_subscription( + pub fn get_subscription( &self, id: U256, - ) -> Result>, TransportError> { - let (tx, rx) = oneshot::channel(); - self.tx - .send(PubSubInstruction::GetSub(id, tx)) - .map_err(|_| TransportErrorKind::backend_gone())?; - rx.await.map_err(|_| TransportErrorKind::backend_gone()) + ) -> impl Future>, TransportError>> + Send + 'static + { + let backend_tx = self.tx.clone(); + async move { + let (tx, rx) = oneshot::channel(); + backend_tx + .send(PubSubInstruction::GetSub(id, tx)) + .map_err(|_| TransportErrorKind::backend_gone())?; + rx.await.map_err(|_| TransportErrorKind::backend_gone()) + } } /// Unsubscribe from a subscription. - pub async fn unsubscribe(&self, id: U256) -> Result<(), TransportError> { + pub fn unsubscribe(&self, id: U256) -> Result<(), TransportError> { self.tx .send(PubSubInstruction::Unsubscribe(id)) - .map_err(|_| TransportErrorKind::backend_gone())?; - Ok(()) + .map_err(|_| TransportErrorKind::backend_gone()) } /// Send a request. pub fn send( &self, req: SerializedRequest, - ) -> Pin> + Send>> { - let (in_flight, rx) = InFlight::new(req); - let ix = PubSubInstruction::Request(in_flight); + ) -> impl Future> + Send + 'static { let tx = self.tx.clone(); - - Box::pin(async move { - tx.send(ix).map_err(|_| TransportErrorKind::backend_gone())?; + async move { + let (in_flight, rx) = InFlight::new(req); + tx.send(PubSubInstruction::Request(in_flight)) + .map_err(|_| TransportErrorKind::backend_gone())?; rx.await.map_err(|_| TransportErrorKind::backend_gone())? - }) + } } /// Send a packet of requests, by breaking it up into individual requests. /// /// Once all responses are received, we return a single response packet. - /// This is a bit annoying - pub fn send_packet( - &self, - req: RequestPacket, - ) -> Pin> + Send>> { + pub fn send_packet(&self, req: RequestPacket) -> TransportFut<'static> { match req { - RequestPacket::Single(req) => { - let fut = self.send(req); - Box::pin(async move { Ok(ResponsePacket::Single(fut.await?)) }) - } - RequestPacket::Batch(reqs) => { - let futs = try_join_all(reqs.into_iter().map(|req| self.send(req))); - Box::pin(async move { Ok(futs.await?.into()) }) - } + RequestPacket::Single(req) => self.send(req).map_ok(ResponsePacket::Single).boxed(), + RequestPacket::Batch(reqs) => try_join_all(reqs.into_iter().map(|req| self.send(req))) + .map_ok(ResponsePacket::Batch) + .boxed(), } } } @@ -84,36 +81,26 @@ impl tower::Service for PubSubFrontend { type Future = TransportFut<'static>; #[inline] - fn poll_ready( - &mut self, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - if self.tx.is_closed() { - return std::task::Poll::Ready(Err(TransportErrorKind::backend_gone())); - } - std::task::Poll::Ready(Ok(())) + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + (&*self).poll_ready(cx) } #[inline] fn call(&mut self, req: RequestPacket) -> Self::Future { - self.send_packet(req) + (&*self).call(req) } } impl tower::Service for &PubSubFrontend { type Response = ResponsePacket; type Error = TransportError; - type Future = Pin> + Send>>; + type Future = TransportFut<'static>; #[inline] - fn poll_ready( - &mut self, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - if self.tx.is_closed() { - return std::task::Poll::Ready(Err(TransportErrorKind::backend_gone())); - } - std::task::Poll::Ready(Ok(())) + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + let result = + if self.tx.is_closed() { Err(TransportErrorKind::backend_gone()) } else { Ok(()) }; + Poll::Ready(result) } #[inline] diff --git a/crates/transport/src/lib.rs b/crates/transport/src/lib.rs index 89e9211961b..1cca6314114 100644 --- a/crates/transport/src/lib.rs +++ b/crates/transport/src/lib.rs @@ -41,9 +41,8 @@ pub use type_aliases::*; #[cfg(not(target_arch = "wasm32"))] mod type_aliases { - use alloy_json_rpc::ResponsePacket; - use crate::{TransportError, TransportResult}; + use alloy_json_rpc::ResponsePacket; /// Pin-boxed future. pub type Pbf<'a, T, E> = @@ -60,9 +59,8 @@ mod type_aliases { #[cfg(target_arch = "wasm32")] mod type_aliases { - use alloy_json_rpc::ResponsePacket; - use crate::{TransportError, TransportResult}; + use alloy_json_rpc::ResponsePacket; /// Pin-boxed future. pub type Pbf<'a, T, E> =