From e669a0f1349942f43d06c01253967893ce5efff6 Mon Sep 17 00:00:00 2001 From: Andrew Plaza Date: Mon, 1 Jul 2024 16:07:21 -0400 Subject: [PATCH] fix node bindings --- bindings_node/Cargo.lock | 2 + bindings_node/Cargo.toml | 1 + bindings_node/src/conversations.rs | 19 ++------ bindings_node/src/groups.rs | 8 +--- bindings_node/src/streams.rs | 74 +++++++++++++++++++++--------- 5 files changed, 62 insertions(+), 42 deletions(-) diff --git a/bindings_node/Cargo.lock b/bindings_node/Cargo.lock index a0e9a9c94..0b4bbb205 100644 --- a/bindings_node/Cargo.lock +++ b/bindings_node/Cargo.lock @@ -300,6 +300,7 @@ version = "0.0.1" dependencies = [ "futures", "hex", + "log", "napi", "napi-build", "napi-derive", @@ -5452,6 +5453,7 @@ dependencies = [ "thiserror", "tls_codec 0.4.1", "tokio", + "tokio-stream", "toml", "tracing", "xmtp_cryptography", diff --git a/bindings_node/Cargo.toml b/bindings_node/Cargo.toml index 4b4ba232e..c31c75cd8 100644 --- a/bindings_node/Cargo.toml +++ b/bindings_node/Cargo.toml @@ -25,6 +25,7 @@ xmtp_mls = { path = "../xmtp_mls", features = ["grpc", "native"] } xmtp_proto = { path = "../xmtp_proto", features = ["proto_full"] } xmtp_id = { path = "../xmtp_id" } rand = "0.8.5" +log = { version = "0.4", features = ["release_max_level_debug"] } [build-dependencies] napi-build = "2.0.1" diff --git a/bindings_node/src/conversations.rs b/bindings_node/src/conversations.rs index 298a6e21e..93472974d 100644 --- a/bindings_node/src/conversations.rs +++ b/bindings_node/src/conversations.rs @@ -196,31 +196,22 @@ impl NapiConversations { ThreadsafeFunctionCallMode::Blocking, ); }, - || {}, // on_close_callback - ) - .map_err(|e| Error::from_reason(format!("{}", e)))?; + ); - Ok(NapiStreamCloser::new( - stream_closer.close_fn, - stream_closer.is_closed_atomic, - )) + Ok(NapiStreamCloser::new(stream_closer)) } #[napi(ts_args_type = "callback: (err: null | Error, result: NapiMessage) => void")] pub fn stream_all_messages(&self, callback: JsFunction) -> Result { let tsfn: ThreadsafeFunction = callback.create_threadsafe_function(0, |ctx| Ok(vec![ctx.value]))?; - let stream_closer = RustXmtpClient::stream_all_messages_with_callback_sync( + let stream_closer = RustXmtpClient::stream_all_messages_with_callback( self.inner_client.clone(), move |message| { tsfn.call(Ok(message.into()), ThreadsafeFunctionCallMode::Blocking); }, - ) - .map_err(|e| Error::from_reason(format!("{}", e)))?; + ); - Ok(NapiStreamCloser::new( - stream_closer.close_fn, - stream_closer.is_closed_atomic, - )) + Ok(NapiStreamCloser::new(stream_closer)) } } diff --git a/bindings_node/src/groups.rs b/bindings_node/src/groups.rs index 7b4646f4a..f9e4ef67a 100644 --- a/bindings_node/src/groups.rs +++ b/bindings_node/src/groups.rs @@ -511,13 +511,9 @@ impl NapiGroup { move |message| { tsfn.call(Ok(message.into()), ThreadsafeFunctionCallMode::Blocking); }, - ) - .map_err(|e| Error::from_reason(format!("{}", e)))?; + ); - Ok(NapiStreamCloser::new( - stream_closer.close_fn, - stream_closer.is_closed_atomic, - )) + Ok(stream_closer.into()) } #[napi] diff --git a/bindings_node/src/streams.rs b/bindings_node/src/streams.rs index 0d82382db..58568bf1b 100644 --- a/bindings_node/src/streams.rs +++ b/bindings_node/src/streams.rs @@ -1,35 +1,65 @@ -use std::sync::{ - atomic::{AtomicBool, Ordering}, - Arc, Mutex, -}; -use tokio::sync::oneshot::Sender; +use std::sync::Arc; +use tokio::{sync::Mutex, task::{JoinHandle, AbortHandle}}; +use xmtp_mls::client::ClientError; +use napi::bindgen_prelude::Error; use napi_derive::napi; #[napi] pub struct NapiStreamCloser { - close_fn: Arc>>>, - is_closed_atomic: Arc, + handle: Arc>>>>, + // for convenience, does not require locking mutex. + abort_handle: Arc, } -#[napi] impl NapiStreamCloser { - pub fn new(close_fn: Arc>>>, is_closed_atomic: Arc) -> Self { - Self { - close_fn, - is_closed_atomic, + pub fn new(handle: JoinHandle>) -> Self { + Self { + abort_handle: Arc::new(handle.abort_handle()), + handle: Arc::new(Mutex::new(Some(handle))), + } } - } +} + +impl From>> for NapiStreamCloser { + fn from(handle: JoinHandle>) -> Self { + NapiStreamCloser::new(handle) + } +} - #[napi] - pub fn end(&self) { - if let Ok(mut close_fn_option) = self.close_fn.lock() { - let _ = close_fn_option.take().map(|close_fn| close_fn.send(())); +#[napi] +impl NapiStreamCloser { + /// Signal the stream to end + /// Does not wait for the stream to end. + pub fn end(&self) { + self.abort_handle.abort(); } - } - #[napi] - pub fn is_closed(&self) -> bool { - self.is_closed_atomic.load(Ordering::Relaxed) - } + /// End the stream and `await` for it to shutdown + /// Returns the `Result` of the task. + pub async fn end_and_wait(&self) -> Result<(), Error> { + if self.abort_handle.is_finished() { + return Ok(()); + } + + let mut handle = self.handle.lock().await; + let handle = handle.take(); + if let Some(h) = handle { + h.abort(); + let join_result = h.await; + if matches!(join_result, Err(ref e) if !e.is_cancelled()) { + return Err(Error::from_reason( + format!("subscription event loop join error {}", join_result.unwrap_err()) + )); + } + } else { + log::warn!("subscription already closed"); + } + Ok(()) + } + + /// Checks if this stream is closed + pub fn is_closed(&self) -> bool { + self.abort_handle.is_finished() + } }