Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Ws and Ipc providers #369

Merged
merged 14 commits into from
May 7, 2024
2 changes: 1 addition & 1 deletion crates/common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ async-openai = "0.10.0"
clap = { version = "3.1.18", features = ["derive"] }
colored = "2"
crossbeam-channel = "0.5.7"
ethers = "2.0.4"
ethers = { version = "2.0.4", features = [ "ipc", "ws", ] }
fancy-regex = "0.11.0"
heimdall-cache = { workspace = true }
indicatif = "0.17.0"
Expand Down
167 changes: 167 additions & 0 deletions crates/common/src/ether/http_or_ws_or_ipc.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
//! Create a custom data transport to use with a Provider.
use async_trait::async_trait;
use ethers::prelude::*;
use serde::{de::DeserializeOwned, Serialize};
use std::{fmt::Debug, str::FromStr};
use thiserror::Error;

/// First we must create an error type, and implement [`From`] for
/// [`ProviderError`].
///
/// Here we are using [`thiserror`](https://docs.rs/thiserror) to wrap
/// [`HttpClientError`], [`WsClientError`] and [`IpcError`].
///
/// This also provides a conversion implementation ([`From`]) for both, so we
/// can use the [question mark operator](https://doc.rust-lang.org/rust-by-example/std/result/question_mark.html)
/// later on in our implementations.
#[derive(Debug, Error)]
pub enum HttpOrWsOrIpcError {
#[error(transparent)]
Ws(#[from] WsClientError),

#[error(transparent)]
Ipc(#[from] IpcError),

#[error(transparent)]
Http(#[from] HttpClientError),
}

/// In order to use our `HttpOrWsOrIpcError` in the RPC client, we have to implement
/// this trait.
///
/// [`RpcError`] helps other parts off the stack get access to common provider
/// error cases. For example, any RPC connection may have a `serde_json` error,
/// so we want to make those easily accessible, so we implement
/// `as_serde_error()`
///
/// In addition, RPC requests may return JSON errors from the node, describing
/// why the request failed. In order to make these accessible, we implement
/// `as_error_response()`.
impl RpcError for HttpOrWsOrIpcError {
fn as_error_response(&self) -> Option<&ethers::providers::JsonRpcError> {
match self {
HttpOrWsOrIpcError::Ws(e) => e.as_error_response(),
HttpOrWsOrIpcError::Ipc(e) => e.as_error_response(),
HttpOrWsOrIpcError::Http(e) => e.as_error_response(),
}
}

fn as_serde_error(&self) -> Option<&serde_json::Error> {
match self {
HttpOrWsOrIpcError::Ws(WsClientError::JsonError(e)) => Some(e),
HttpOrWsOrIpcError::Ipc(IpcError::JsonError(e)) => Some(e),
HttpOrWsOrIpcError::Http(HttpClientError::SerdeJson { err: e, text: _ }) => Some(e),
_ => None,
}
}
}

/// This implementation helps us convert our Error to the library's
/// [`ProviderError`] so that we can use the `?` operator
impl From<HttpOrWsOrIpcError> for ProviderError {
fn from(value: HttpOrWsOrIpcError) -> Self {
Self::JsonRpcClientError(Box::new(value))
}
}

/// Next, we create our transport type, which in this case will be an enum that contains
/// either [`Http`], [`Ws`] or [`Ipc`].
#[derive(Clone, Debug)]
pub enum HttpOrWsOrIpc {
Ws(Ws),
Ipc(Ipc),
Http(Http),
}

// We implement a convenience "constructor" method, to easily initialize the transport.
// This will connect to [`Http`] if the rpc_url contains 'http', to [`Ws`] if it contains 'ws',
// otherwise it'll default to [`Ipc`].
impl HttpOrWsOrIpc {
pub async fn connect(rpc_url: &str) -> Result<Self, HttpOrWsOrIpcError> {
let this = if rpc_url.to_lowercase().contains("http") {
Self::Http(Http::from_str(rpc_url).unwrap())
} else if rpc_url.to_lowercase().contains("ws") {
Self::Ws(Ws::connect(rpc_url).await?)
} else {
Self::Ipc(Ipc::connect(rpc_url).await?)
};
Ok(this)
}
}

// Next, the most important step: implement [`JsonRpcClient`].
//
// For this implementation, we simply delegate to the wrapped transport and return the
// result.
//
// Note that we are using [`async-trait`](https://docs.rs/async-trait) for asynchronous
// functions in traits, as this is not yet supported in stable Rust; see:
// <https://blog.rust-lang.org/inside-rust/2022/11/17/async-fn-in-trait-nightly.html>
#[async_trait]
impl JsonRpcClient for HttpOrWsOrIpc {
type Error = HttpOrWsOrIpcError;

async fn request<T, R>(&self, method: &str, params: T) -> Result<R, Self::Error>
where
T: Debug + Serialize + Send + Sync,
R: DeserializeOwned + Send,
{
// println!("request");
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: remove

let res = match self {
Self::Ws(ws) => JsonRpcClient::request(ws, method, params).await?,
Self::Ipc(ipc) => JsonRpcClient::request(ipc, method, params).await?,
Self::Http(http) => JsonRpcClient::request(http, method, params).await?,
};
Ok(res)
}
}

// We can also implement [`PubsubClient`], since both `Ws` and `Ipc` implement it, by
// doing the same as in the `JsonRpcClient` implementation above.
// Trying to subscribe on a `Http` will panic.
impl PubsubClient for HttpOrWsOrIpc {
// Since both `Ws` and `Ipc`'s `NotificationStream` associated type is the same,
// we can simply return one of them.
// In case they differed, we would have to create a `HttpOrWsOrIpcNotificationStream`,
// similar to the error type.
type NotificationStream = <Ws as PubsubClient>::NotificationStream;

fn subscribe<T: Into<U256>>(&self, id: T) -> Result<Self::NotificationStream, Self::Error> {
let stream = match self {
Self::Ws(ws) => PubsubClient::subscribe(ws, id)?,
Self::Ipc(ipc) => PubsubClient::subscribe(ipc, id)?,
HttpOrWsOrIpc::Http(_) => unreachable!("Http RPC cannot be used for subscriptions!"),
};
Ok(stream)
}

fn unsubscribe<T: Into<U256>>(&self, id: T) -> Result<(), Self::Error> {
match self {
Self::Ws(ws) => PubsubClient::unsubscribe(ws, id)?,
Self::Ipc(ipc) => PubsubClient::unsubscribe(ipc, id)?,
HttpOrWsOrIpc::Http(_) => unreachable!("Http RPC cannot be used for subscriptions!"),
};
Ok(())
}
}

#[tokio::test]
async fn test_subscription() {
fala13 marked this conversation as resolved.
Show resolved Hide resolved
// Spawn Anvil
let anvil = ethers::utils::Anvil::new().block_time(1u64).spawn();

// Connect to our transport
let transport = HttpOrWsOrIpc::connect(&anvil.ws_endpoint()).await.unwrap();

// Wrap the transport in a provider
let provider = Provider::new(transport);

// Now we can use our custom transport provider like normal
let block_number = provider.get_block_number().await.unwrap();
println!("Current block: {block_number}");

let mut subscription = provider.subscribe_blocks().await.unwrap().take(3);
while let Some(block) = subscription.next().await {
println!("New block: {:?}", block.number);
}
}
1 change: 1 addition & 0 deletions crates/common/src/ether/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
pub mod bytecode;
pub mod compiler;
pub mod evm;
pub mod http_or_ws_or_ipc;
pub mod lexers;
pub mod rpc;
pub mod selectors;
Expand Down
45 changes: 37 additions & 8 deletions crates/common/src/ether/rpc.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use crate::error::Error;
use crate::{
error::Error,
ether::http_or_ws_or_ipc::{self, HttpOrWsOrIpc},
};
use backoff::ExponentialBackoff;
use ethers::{
core::types::Address,
providers::{Http, Middleware, Provider},
providers::{Middleware, Provider},
types::{
BlockNumber::{self},
BlockTrace, Filter, FilterBlockOption, StateDiff, TraceType, Transaction, H256,
Expand All @@ -12,6 +15,24 @@ use heimdall_cache::{read_cache, store_cache};
use std::{str::FromStr, time::Duration};
use tracing::{debug, error, trace};

/// Get the Provider object for RPC URL
///
/// ```no_run
/// use heimdall_common::ether::rpc::get_provider;
///
/// // let provider = get_provider("https://eth.llamarpc.com").await?;
/// //assert_eq!(provider.get_chainid().await.unwrap(), 1);
/// ```
pub async fn get_provider(rpc_url: &str) -> Result<Provider<HttpOrWsOrIpc>, Error> {
Ok(Provider::new(match http_or_ws_or_ipc::HttpOrWsOrIpc::connect(rpc_url).await {
Ok(provider) => provider,
Err(error) => {
error!("failed to connect to RPC provider '{}' .", &rpc_url);
return Err(Error::Generic(error.to_string()));
}
}))
}

/// Get the chainId of the provided RPC URL
///
/// ```no_run
Expand Down Expand Up @@ -44,7 +65,7 @@ pub async fn chain_id(rpc_url: &str) -> Result<u64, Error> {
}

// create new provider
let provider = match Provider::<Http>::try_from(rpc_url) {
let provider = match get_provider(rpc_url).await {
Ok(provider) => provider,
Err(_) => {
error!("failed to connect to RPC provider '{}' .", &rpc_url);
Expand Down Expand Up @@ -108,7 +129,7 @@ pub async fn get_code(contract_address: &str, rpc_url: &str) -> Result<Vec<u8>,
}

// create new provider
let provider = match Provider::<Http>::try_from(rpc_url) {
let provider = match get_provider(rpc_url).await {
Ok(provider) => provider,
Err(_) => {
error!("failed to connect to RPC provider '{}' .", &rpc_url);
Expand Down Expand Up @@ -175,7 +196,7 @@ pub async fn get_transaction(transaction_hash: &str, rpc_url: &str) -> Result<Tr
}

// create new provider
let provider = match Provider::<Http>::try_from(rpc_url) {
let provider = match get_provider(rpc_url).await {
Ok(provider) => provider,
Err(_) => {
error!("failed to connect to RPC provider '{}' .", &rpc_url);
Expand Down Expand Up @@ -250,7 +271,7 @@ pub async fn get_storage_diff(
);

// create new provider
let provider = match Provider::<Http>::try_from(rpc_url) {
let provider = match get_provider(rpc_url).await {
Ok(provider) => provider,
Err(_) => {
error!("failed to connect to RPC provider '{}' .", &rpc_url);
Expand Down Expand Up @@ -325,7 +346,7 @@ pub async fn get_trace(transaction_hash: &str, rpc_url: &str) -> Result<BlockTra
&transaction_hash);

// create new provider
let provider = match Provider::<Http>::try_from(rpc_url) {
let provider = match get_provider(rpc_url).await {
Ok(provider) => provider,
Err(_) => {
error!("failed to connect to RPC provider '{}' .", &rpc_url);
Expand Down Expand Up @@ -394,7 +415,7 @@ pub async fn get_block_logs(
trace!("fetching logs from node for block: '{}' .", &block_number);

// create new provider
let provider = match Provider::<Http>::try_from(rpc_url) {
let provider = match get_provider(rpc_url).await {
Ok(provider) => provider,
Err(_) => {
error!("failed to connect to RPC provider '{}' .", &rpc_url);
Expand Down Expand Up @@ -540,4 +561,12 @@ pub mod tests {

assert!(!logs.is_empty());
}

#[tokio::test]
async fn test_chain_id_with_ws_rpc() {
let rpc_url = "wss://zksync.drpc.org";
let rpc_chain_id = chain_id(rpc_url).await.expect("chain_id() returned an error!");

assert_eq!(rpc_chain_id, 324);
}
}
Loading