Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Method aliases + RpcModule: Clone #383

Merged
merged 8 commits into from
Jun 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions types/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ pub enum Error {
/// Method was already registered.
#[error("Method: {0} was already registered")]
MethodAlreadyRegistered(String),
/// Method with that name has not yet been registered.
#[error("Method: {0} has not yet been registered")]
MethodNotFound(String),
/// Subscribe and unsubscribe method names are the same.
#[error("Cannot use the same method name for subscribe and unsubscribe, used: {0}")]
SubscriptionNameConflict(String),
Expand Down
86 changes: 62 additions & 24 deletions utils/src/server/rpc_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use std::sync::Arc;
/// implemented as a function pointer to a `Fn` function taking four arguments:
/// the `id`, `params`, a channel the function uses to communicate the result (or error)
/// back to `jsonrpsee`, and the connection ID (useful for the websocket transport).
pub type SyncMethod = Box<dyn Send + Sync + Fn(Id, RpcParams, &MethodSink, ConnectionId) -> Result<(), Error>>;
pub type SyncMethod = Arc<dyn Send + Sync + Fn(Id, RpcParams, &MethodSink, ConnectionId) -> Result<(), Error>>;
/// Similar to [`SyncMethod`], but represents an asynchronous handler.
pub type AsyncMethod = Arc<
dyn Send + Sync + Fn(OwnedId, OwnedRpcParams, MethodSink, ConnectionId) -> BoxFuture<'static, Result<(), Error>>,
Expand All @@ -41,6 +41,7 @@ struct SubscriptionKey {
}

/// Callback wrapper that can be either sync or async.
#[derive(Clone)]
pub enum MethodCallback {
/// Synchronous method handler.
Sync(SyncMethod),
Expand Down Expand Up @@ -81,10 +82,10 @@ impl Debug for MethodCallback {
}
}

/// Collection of synchronous and asynchronous methods.
#[derive(Default, Debug)]
/// Reference-counted, clone-on-write collection of synchronous and asynchronous methods.
#[derive(Default, Debug, Clone)]
pub struct Methods {
callbacks: FxHashMap<&'static str, MethodCallback>,
callbacks: Arc<FxHashMap<&'static str, MethodCallback>>,
}

impl Methods {
Expand All @@ -101,15 +102,22 @@ impl Methods {
Ok(())
}

/// Helper for obtaining a mut ref to the callbacks HashMap.
fn mut_callbacks(&mut self) -> &mut FxHashMap<&'static str, MethodCallback> {
Arc::make_mut(&mut self.callbacks)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

neat TIL.

}

/// Merge two [`Methods`]'s by adding all [`MethodCallback`]s from `other` into `self`.
/// Fails if any of the methods in `other` is present already.
pub fn merge(&mut self, other: Methods) -> Result<(), Error> {
pub fn merge(&mut self, mut other: Methods) -> Result<(), Error> {
for name in other.callbacks.keys() {
self.verify_method_name(name)?;
}

for (name, callback) in other.callbacks {
self.callbacks.insert(name, callback);
let callbacks = self.mut_callbacks();

for (name, callback) in other.mut_callbacks().drain() {
callbacks.insert(name, callback);
}

Ok(())
Expand Down Expand Up @@ -137,17 +145,33 @@ impl Methods {
/// Sets of JSON-RPC methods can be organized into a "module"s that are in turn registered on the server or,
/// alternatively, merged with other modules to construct a cohesive API. [`RpcModule`] wraps an additional context
/// argument that can be used to access data during call execution.
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct RpcModule<Context> {
ctx: Arc<Context>,
methods: Methods,
}

impl<Context: Send + Sync + 'static> RpcModule<Context> {
impl<Context> RpcModule<Context> {
/// Create a new module with a given shared `Context`.
pub fn new(ctx: Context) -> Self {
Self { ctx: Arc::new(ctx), methods: Default::default() }
}

/// Convert a module into methods. Consumes self.
pub fn into_methods(self) -> Methods {
self.methods
}

/// Merge two [`RpcModule`]'s by adding all [`Methods`] `other` into `self`.
/// Fails if any of the methods in `other` is present already.
pub fn merge<Context2>(&mut self, other: RpcModule<Context2>) -> Result<(), Error> {
self.methods.merge(other.methods)?;

Ok(())
}
}

impl<Context: Send + Sync + 'static> RpcModule<Context> {
/// Register a new synchronous RPC method, which computes the response with the given callback.
pub fn register_method<R, F>(&mut self, method_name: &'static str, callback: F) -> Result<(), Error>
where
Expand All @@ -159,9 +183,9 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {

let ctx = self.ctx.clone();

self.methods.callbacks.insert(
self.methods.mut_callbacks().insert(
method_name,
MethodCallback::Sync(Box::new(move |id, params, tx, _| {
MethodCallback::Sync(Arc::new(move |id, params, tx, _| {
match callback(params, &*ctx) {
Ok(res) => send_response(id, tx, res),
Err(CallError::InvalidParams) => send_error(id, tx, JsonRpcErrorCode::InvalidParams.into()),
Expand Down Expand Up @@ -192,7 +216,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {

let ctx = self.ctx.clone();

self.methods.callbacks.insert(
self.methods.mut_callbacks().insert(
method_name,
MethodCallback::Async(Arc::new(move |id, params, tx, _| {
let ctx = ctx.clone();
Expand Down Expand Up @@ -265,9 +289,9 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {

{
let subscribers = subscribers.clone();
self.methods.callbacks.insert(
self.methods.mut_callbacks().insert(
subscribe_method_name,
MethodCallback::Sync(Box::new(move |id, params, method_sink, conn_id| {
MethodCallback::Sync(Arc::new(move |id, params, method_sink, conn_id| {
let (conn_tx, conn_rx) = oneshot::channel::<()>();
let sub_id = {
const JS_NUM_MASK: SubscriptionId = !0 >> 11;
Expand All @@ -293,9 +317,9 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
}

{
self.methods.callbacks.insert(
self.methods.mut_callbacks().insert(
unsubscribe_method_name,
MethodCallback::Sync(Box::new(move |id, params, tx, conn_id| {
MethodCallback::Sync(Arc::new(move |id, params, tx, conn_id| {
let sub_id = params.one()?;
subscribers.lock().remove(&SubscriptionKey { conn_id, sub_id });
send_response(id, &tx, "Unsubscribed");
Expand All @@ -308,15 +332,16 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
Ok(())
}

/// Convert a module into methods. Consumes self.
pub fn into_methods(self) -> Methods {
self.methods
}
/// Register an `alias` name for an `existing_method`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

pub fn register_alias(&mut self, alias: &'static str, existing_method: &'static str) -> Result<(), Error> {
self.methods.verify_method_name(alias)?;

/// Merge two [`RpcModule`]'s by adding all [`Methods`] `other` into `self`.
/// Fails if any of the methods in `other` is present already.
pub fn merge<Context2>(&mut self, other: RpcModule<Context2>) -> Result<(), Error> {
self.methods.merge(other.methods)?;
let callback = match self.methods.callbacks.get(existing_method) {
Some(callback) => callback.clone(),
None => return Err(Error::MethodNotFound(existing_method.into())),
};

self.methods.mut_callbacks().insert(alias, callback);

Ok(())
}
Expand Down Expand Up @@ -431,4 +456,17 @@ mod tests {
assert!(methods.method("hi").is_some());
assert!(methods.method("goodbye").is_some());
}

#[test]
fn rpc_register_alias() {
let mut module = RpcModule::new(());

module.register_method("hello_world", |_: RpcParams, _| Ok(())).unwrap();
module.register_alias("hello_foobar", "hello_world").unwrap();

let methods = module.into_methods();

assert!(methods.method("hello_world").is_some());
assert!(methods.method("hello_foobar").is_some());
}
}
16 changes: 11 additions & 5 deletions ws-server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ impl Server {
/// Register all methods from a [`Methods`] of provided [`RpcModule`] on this server.
/// In case a method already is registered with the same name, no method is added and a [`Error::MethodAlreadyRegistered`]
/// is returned. Note that the [`RpcModule`] is consumed after this call.
pub fn register_module<Context: Send + Sync + 'static>(&mut self, module: RpcModule<Context>) -> Result<(), Error> {
pub fn register_module<Context>(&mut self, module: RpcModule<Context>) -> Result<(), Error> {
self.methods.merge(module.into_methods())?;
Ok(())
}
Expand All @@ -74,21 +74,27 @@ impl Server {
/// Start responding to connections requests. This will block current thread until the server is stopped.
pub async fn start(self) {
let mut incoming = TcpListenerStream::new(self.listener);
let methods = Arc::new(self.methods);
let methods = self.methods;
let conn_counter = Arc::new(());
let cfg = self.cfg;
let mut id = 0;

while let Some(socket) = incoming.next().await {
if let Ok(socket) = socket {
socket.set_nodelay(true).unwrap_or_else(|e| panic!("Could not set NODELAY on socket: {:?}", e));

if Arc::strong_count(&methods) > self.cfg.max_connections as usize {
if Arc::strong_count(&conn_counter) > self.cfg.max_connections as usize {
log::warn!("Too many connections. Try again in a while");
continue;
}
let methods = methods.clone();
let counter = conn_counter.clone();

tokio::spawn(background_task(socket, id, methods, cfg));
tokio::spawn(async move {
let r = background_task(socket, id, methods, cfg).await;
drop(counter);
r
});

id += 1;
}
Expand All @@ -99,7 +105,7 @@ impl Server {
async fn background_task(
socket: tokio::net::TcpStream,
conn_id: ConnectionId,
methods: Arc<Methods>,
methods: Methods,
cfg: Settings,
) -> Result<(), Error> {
// For each incoming background_task we perform a handshake.
Expand Down