Skip to content

Commit

Permalink
run rpc futures in a tokio thread till completion
Browse files Browse the repository at this point in the history
when ran directly inside hyper's service function, it might get aborted mid-way when a client disconnects, leaving the future in complete.

this is an issue as we might have some code that need to be executed atomically inside any part of the code base that could be triggered by an RPC request. if the client disconnects, such a function will abort on the next await call, leaving us with non-atomic state (e.g. partial update for a map and its inverse).
  • Loading branch information
mariocynicys committed Oct 8, 2024
1 parent e65fefe commit 604740c
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 22 deletions.
48 changes: 34 additions & 14 deletions mm2src/mm2_main/src/rpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ use std::net::SocketAddr;

cfg_native! {
use hyper::{self, Body, Server};
use futures::channel::oneshot;
use mm2_net::sse_handler::{handle_sse, SSE_ENDPOINT};
}

Expand Down Expand Up @@ -333,35 +334,54 @@ pub extern "C" fn spawn_rpc(ctx_h: u32) {
Ok((cert_chain, privkey))
}

// Handles incoming HTTP requests.
async fn handle_request(
req: Request<Body>,
remote_addr: SocketAddr,
ctx_h: u32,
is_event_stream_enabled: bool,
) -> Result<Response<Body>, Infallible> {
let (tx, rx) = oneshot::channel();
// We execute the request in a separate task to avoid it being left uncompleted if the client disconnects.
// So what's inside the spawn here will complete till completion (or panic).
common::executor::spawn(async move {
if is_event_stream_enabled && req.uri().path() == SSE_ENDPOINT {
tx.send(handle_sse(req, ctx_h).await).ok();
} else {
tx.send(rpc_service(req, ctx_h, remote_addr).await).ok();
}
});
// On the other hand, this `.await` might be aborted if the client disconnects.
match rx.await {
Ok(res) => Ok(res),
Err(_) => {
let err = "The RPC service aborted without responding.";
error!("{}", err);
Ok(Response::builder().status(500).body(Body::from(err)).unwrap())
},
}
}

// NB: We need to manually handle the incoming connections in order to get the remote IP address,
// cf. https://github.com/hyperium/hyper/issues/1410#issuecomment-419510220.
// Although if the ability to access the remote IP address is solved by the Hyper in the future
// then we might want to refactor into starting it ideomatically in order to benefit from a more graceful shutdown,
// cf. https://github.com/hyperium/hyper/pull/1640.

let ctx = MmArc::from_ffi_handle(ctx_h).expect("No context");

let is_event_stream_enabled = ctx.event_stream_configuration.is_some();

let make_svc_fut = move |remote_addr: SocketAddr| async move {
Ok::<_, Infallible>(service_fn(move |req: Request<Body>| async move {
if is_event_stream_enabled && req.uri().path() == SSE_ENDPOINT {
let res = handle_sse(req, ctx_h).await?;
return Ok::<_, Infallible>(res);
}

let res = rpc_service(req, ctx_h, remote_addr).await;
Ok::<_, Infallible>(res)
}))
};

//The `make_svc` macro creates a `make_service_fn` for a specified socket type.
// `$socket_type`: The socket type with a `remote_addr` method that returns a `SocketAddr`.
macro_rules! make_svc {
($socket_type:ty) => {
make_service_fn(move |socket: &$socket_type| {
let remote_addr = socket.remote_addr();
make_svc_fut(remote_addr)
async move {
Ok::<_, Infallible>(service_fn(move |req: Request<Body>| {
handle_request(req, remote_addr, ctx_h, is_event_stream_enabled)
}))
}
})
};
}
Expand Down
13 changes: 5 additions & 8 deletions mm2src/mm2_net/src/sse_handler.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
use hyper::{body::Bytes, Body, Request, Response};
use mm2_core::mm_ctx::MmArc;
use serde_json::json;
use std::convert::Infallible;

pub const SSE_ENDPOINT: &str = "/event-stream";

/// Handles broadcasted messages from `mm2_event_stream` continuously.
pub async fn handle_sse(request: Request<Body>, ctx_h: u32) -> Result<Response<Body>, Infallible> {
pub async fn handle_sse(request: Request<Body>, ctx_h: u32) -> Response<Body> {
// This is only called once for per client on the initialization,
// meaning this is not a resource intensive computation.
let ctx = match MmArc::from_ffi_handle(ctx_h) {
Expand Down Expand Up @@ -62,17 +61,15 @@ pub async fn handle_sse(request: Request<Body>, ctx_h: u32) -> Result<Response<B
.body(body);

match response {
Ok(res) => Ok(res),
Ok(res) => res,
Err(err) => handle_internal_error(err.to_string()).await,
}
}

/// Fallback function for handling errors in SSE connections
async fn handle_internal_error(message: String) -> Result<Response<Body>, Infallible> {
let response = Response::builder()
async fn handle_internal_error(message: String) -> Response<Body> {
Response::builder()
.status(500)
.body(Body::from(message))
.expect("Returning 500 should never fail.");

Ok(response)
.expect("Returning 500 should never fail.")
}

0 comments on commit 604740c

Please sign in to comment.