diff --git a/core/src/server/rpc_module.rs b/core/src/server/rpc_module.rs index 30cb2c9a7c..1177c1fdd3 100644 --- a/core/src/server/rpc_module.rs +++ b/core/src/server/rpc_module.rs @@ -213,6 +213,7 @@ impl Debug for MethodCallback { #[derive(Default, Debug, Clone)] pub struct Methods { callbacks: Arc>, + extensions: Extensions, } impl Methods { @@ -366,22 +367,25 @@ impl Methods { subscription_permit: SubscriptionPermit, ) -> RawRpcResponse { let (tx, mut rx) = mpsc::channel(buf_size); - let Request { id, method, params, mut extensions, .. } = req; + // The extensions is always empty when calling the method directly because decoding an JSON-RPC + // request doesn't have any extensions. + let Request { id, method, params, .. } = req; let params = Params::new(params.as_ref().map(|params| params.as_ref().get())); let max_response_size = usize::MAX; let conn_id = ConnectionId(0); - extensions.insert(conn_id); + let mut ext = self.extensions.clone(); + ext.insert(conn_id); let response = match self.method(&method) { None => MethodResponse::error(id, ErrorObject::from(ErrorCode::MethodNotFound)), - Some(MethodCallback::Sync(cb)) => (cb)(id, params, max_response_size, extensions), + Some(MethodCallback::Sync(cb)) => (cb)(id, params, max_response_size, ext), Some(MethodCallback::Async(cb)) => { - (cb)(id.into_owned(), params.into_owned(), conn_id, max_response_size, extensions).await + (cb)(id.into_owned(), params.into_owned(), conn_id, max_response_size, ext).await } Some(MethodCallback::Subscription(cb)) => { let conn_state = SubscriptionState { conn_id, id_provider: &RandomIntegerIdProvider, subscription_permit }; - let res = (cb)(id, params, MethodSink::new(tx.clone()), conn_state, extensions).await; + let res = (cb)(id, params, MethodSink::new(tx.clone()), conn_state, ext).await; // This message is not used because it's used for metrics so we discard in other to // not read once this is used for subscriptions. @@ -391,7 +395,7 @@ impl Methods { res } - Some(MethodCallback::Unsubscription(cb)) => (cb)(id, params, conn_id, max_response_size, extensions), + Some(MethodCallback::Unsubscription(cb)) => (cb)(id, params, conn_id, max_response_size, ext), }; let is_success = response.is_success(); @@ -469,6 +473,43 @@ impl Methods { pub fn method_names(&self) -> impl Iterator + '_ { self.callbacks.keys().copied() } + + /// Similar to [`Methods::extensions_mut`] but it's immutable. + pub fn extensions(&mut self) -> &Extensions { + &self.extensions + } + + /// Get a mutable reference to the extensions to add or remove data from + /// the extensions. + /// + /// This only affects direct calls to the methods and subscriptions + /// and can be used for example to unit test the API without a server. + /// + /// # Examples + /// + /// ``` + /// #[tokio::main] + /// async fn main() { + /// use jsonrpsee::{RpcModule, IntoResponse, Extensions}; + /// use jsonrpsee::core::RpcResult; + /// + /// let mut module = RpcModule::new(()); + /// module.register_method::, _>("magic_multiply", |params, _, ext| { + /// let magic = ext.get::().copied().unwrap(); + /// let val = params.one::()?; + /// Ok(val * magic) + /// }).unwrap(); + /// + /// // inject arbitrary data into the extensions. + /// module.extensions_mut().insert(33_u64); + /// + /// let magic: u64 = module.call("magic_multiply", [1_u64]).await.unwrap(); + /// assert_eq!(magic, 33); + /// } + /// ``` + pub fn extensions_mut(&mut self) -> &mut Extensions { + &mut self.extensions + } } impl Deref for RpcModule {