Skip to content
This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

Commit

Permalink
feat: cached provider
Browse files Browse the repository at this point in the history
  • Loading branch information
gakonst committed Feb 24, 2022
1 parent f5ef814 commit a4405f8
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 10 deletions.
96 changes: 88 additions & 8 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions ethers-providers/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ tracing = { version = "0.1.31", default-features = false }
tracing-futures = { version = "0.2.5", default-features = false, features = ["std-future"] }

bytes = { version = "1.1.0", default-features = false, optional = true }
dashmap = { version = "5.1.0", features = ["serde", "rayon"] }

[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
# tokio
Expand Down
98 changes: 98 additions & 0 deletions ethers-providers/src/cache.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
use crate::ProviderError;
use dashmap::DashMap;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{
fs::File,
io::{BufReader, BufWriter},
path::PathBuf,
};

#[derive(Clone, Debug, Default)]
/// Simple in-memory K-V cache using concurrent dashmap which flushes
/// its state to disk on `Drop`.
pub struct Cache {
path: PathBuf,
// serialized request / response pair
requests: DashMap<String, String>,
}

// Helper type for (de)serialization
#[derive(Serialize, Deserialize)]
struct CachedRequest<'a, T> {
method: &'a str,
params: T,
}

impl Cache {
/// Instantiates a new cache at a file path.
pub fn new(path: PathBuf) -> Result<Self, ProviderError> {
// try to read the already existing requests
let reader =
BufReader::new(File::options().write(true).read(true).create(true).open(&path)?);
let requests = serde_json::from_reader(reader).unwrap_or_default();
Ok(Self { path, requests })
}

pub fn get<T: Serialize, R: DeserializeOwned>(
&self,
method: &str,
params: &T,
) -> Result<Option<R>, ProviderError> {
let key = serde_json::to_string(&CachedRequest { method, params })?;
let value = self.requests.get(&key);
value.map(|x| serde_json::from_str(&x).map_err(ProviderError::SerdeJson)).transpose()
}

pub fn set<T: Serialize, R: Serialize>(
&self,
method: &str,
params: T,
response: R,
) -> Result<(), ProviderError> {
let key = serde_json::to_string(&CachedRequest { method, params })?;
let value = serde_json::to_string(&response)?;
self.requests.insert(key, value);
Ok(())
}
}

impl Drop for Cache {
fn drop(&mut self) {
let file = match File::options().write(true).read(true).create(true).open(&self.path) {
Ok(inner) => BufWriter::new(inner),
Err(err) => {
tracing::error!("could not open cache file {}", err);
return
}
};

// overwrite the cache
if let Err(err) = serde_json::to_writer(file, &self.requests) {
tracing::error!("could not write to cache file {}", err);
};
}
}

#[cfg(test)]
mod tests {
use crate::Provider;
use ethers_core::types::{Address, U256};

#[tokio::test]
async fn test_cache() {
let tmp = tempfile::tempdir().unwrap();
let cache = tmp.path().join("cache");
let (provider, mock) = Provider::mocked();
let provider = provider.with_cache(cache.clone());
let addr = Address::random();

assert!(provider.cache().unwrap().requests.is_empty());

mock.push(U256::from(100u64)).unwrap();
let res = provider.get_balance(addr, None).await.unwrap();
assert_eq!(res, 100.into());

assert!(!provider.cache().unwrap().requests.is_empty());
dbg!(&provider.cache);
}
}
2 changes: 2 additions & 0 deletions ethers-providers/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ pub use transports::*;

mod provider;

mod cache;

// ENS support
pub mod ens;

Expand Down
34 changes: 32 additions & 2 deletions ethers-providers/src/provider.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::{
cache::Cache,
ens, erc, maybe,
pubsub::{PubsubClient, SubscriptionStream},
stream::{FilterWatcher, DEFAULT_POLL_INTERVAL},
Expand Down Expand Up @@ -28,7 +29,7 @@ use thiserror::Error;
use url::{ParseError, Url};

use futures_util::{lock::Mutex, try_join};
use std::{convert::TryFrom, fmt::Debug, str::FromStr, sync::Arc, time::Duration};
use std::{convert::TryFrom, fmt::Debug, path::PathBuf, str::FromStr, sync::Arc, time::Duration};
use tracing::trace;
use tracing_futures::Instrument;

Expand Down Expand Up @@ -83,6 +84,7 @@ pub struct Provider<P> {
ens: Option<Address>,
interval: Option<Duration>,
from: Option<Address>,
cache: Option<Cache>,
/// Node client hasn't been checked yet = `None`
/// Unsupported node client = `Some(None)`
/// Supported node client = `Some(Some(NodeClient))`
Expand Down Expand Up @@ -132,6 +134,9 @@ pub enum ProviderError {

#[error("Attempted to sign a transaction with no available signer. Hint: did you mean to use a SignerMiddleware?")]
SignerUnavailable,

#[error(transparent)]
Io(#[from] std::io::Error),
}

/// Types of filters supported by the JSON-RPC.
Expand All @@ -157,9 +162,14 @@ impl<P: JsonRpcClient> Provider<P> {
interval: None,
from: None,
_node_client: Arc::new(Mutex::new(None)),
cache: None,
}
}

pub fn cache(&self) -> Option<&Cache> {
self.cache.as_ref()
}

/// Returns the type of node we're connected to, while also caching the value for use
/// in other node-specific API calls, such as the get_block_receipts call.
pub async fn node_client(&self) -> Result<NodeClient, ProviderError> {
Expand Down Expand Up @@ -193,9 +203,22 @@ impl<P: JsonRpcClient> Provider<P> {
tracing::trace_span!("rpc", method = method, params = ?serde_json::to_string(&params)?);
// https://docs.rs/tracing/0.1.22/tracing/span/struct.Span.html#in-asynchronous-code
let res = async move {
// if there's a cache hit, return it
if let Some(ref cache) = self.cache {
if let Some(res) = cache.get(method, &params)? {
return Ok(res)
}
}

trace!("tx");
let res: R = self.inner.request(method, params).await.map_err(Into::into)?;
let res: R = self.inner.request(method, &params).await.map_err(Into::into)?;
trace!(rx = ?serde_json::to_string(&res)?);

// save the response if there was a cache set
if let Some(ref cache) = self.cache {
cache.set(method, params, &res)?;
}

Ok::<_, ProviderError>(res)
}
.instrument(span)
Expand Down Expand Up @@ -1181,6 +1204,13 @@ impl<P: JsonRpcClient> Provider<P> {
self
}

#[must_use]
/// Sets the provider's cache to avoid making redundant network requests.
pub fn with_cache(mut self, cache: PathBuf) -> Self {
self.cache = Some(Cache::new(cache).unwrap());
self
}

/// Gets the polling interval which the provider currently uses for event filters
/// and pending transactions (default: 7 seconds)
pub fn get_interval(&self) -> Duration {
Expand Down

0 comments on commit a4405f8

Please sign in to comment.