From 79b6c7bf2093550d1c8892659af76ebe9f6cab96 Mon Sep 17 00:00:00 2001 From: Andrew Plaza Date: Tue, 16 Jul 2024 16:33:56 -0400 Subject: [PATCH] Refactor of stream_all_messages, fix flaky stream tests (#835) * quick pass at stream_all_msg refactor * tests pass * try replacing StreamCloser with JoinHandle * use Notify * lose all messages test is definitely still flaky --- bindings_ffi/Cargo.lock | 99 +++++- bindings_ffi/Cargo.toml | 2 + bindings_ffi/src/mls.rs | 358 ++++++++++---------- bindings_ffi/src/v2.rs | 12 +- bindings_node/Cargo.lock | 1 + bindings_node/Cargo.toml | 1 + bindings_node/src/conversations.rs | 25 +- bindings_node/src/groups.rs | 8 +- bindings_node/src/streams.rs | 66 +++- xmtp_mls/Cargo.toml | 2 + xmtp_mls/src/groups/subscriptions.rs | 15 +- xmtp_mls/src/groups/sync.rs | 5 + xmtp_mls/src/subscriptions.rs | 483 +++++++++++++++------------ xmtp_mls/src/utils/test.rs | 27 ++ 14 files changed, 664 insertions(+), 440 deletions(-) diff --git a/bindings_ffi/Cargo.lock b/bindings_ffi/Cargo.lock index 9c6dbe3ad..391d1f797 100644 --- a/bindings_ffi/Cargo.lock +++ b/bindings_ffi/Cargo.lock @@ -2582,6 +2582,15 @@ version = "0.4.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" +[[package]] +name = "matchers" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +dependencies = [ + "regex-automata 0.1.10", +] + [[package]] name = "matchit" version = "0.7.3" @@ -2706,6 +2715,16 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nu-ansi-term" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +dependencies = [ + "overload", + "winapi", +] + [[package]] name = "num-bigint" version = "0.4.4" @@ -2970,6 +2989,12 @@ version = "0.2.0" 
source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "overload" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" + [[package]] name = "p256" version = "0.13.2" @@ -3707,10 +3732,19 @@ checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" dependencies = [ "aho-corasick", "memchr", - "regex-automata", + "regex-automata 0.4.4", "regex-syntax 0.8.2", ] +[[package]] +name = "regex-automata" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" +dependencies = [ + "regex-syntax 0.6.29", +] + [[package]] name = "regex-automata" version = "0.4.4" @@ -3722,6 +3756,12 @@ dependencies = [ "regex-syntax 0.8.2", ] +[[package]] +name = "regex-syntax" +version = "0.6.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" + [[package]] name = "regex-syntax" version = "0.7.5" @@ -4349,6 +4389,15 @@ dependencies = [ "keccak", ] +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + [[package]] name = "signal-hook-registry" version = "1.4.1" @@ -4673,6 +4722,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "thread_local" +version = "1.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c" +dependencies = [ + "cfg-if", + "once_cell", +] + [[package]] name = "time" version = "0.3.31" @@ -5051,6 +5110,7 @@ source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" dependencies = [ "once_cell", + "valuable", ] [[package]] @@ -5063,6 +5123,35 @@ dependencies = [ "tracing", ] +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "regex", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", +] + [[package]] name = "try-lock" version = "0.2.5" @@ -5359,6 +5448,12 @@ dependencies = [ "rand", ] +[[package]] +name = "valuable" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" + [[package]] name = "vcpkg" version = "0.2.15" @@ -5903,6 +5998,7 @@ dependencies = [ "thiserror", "tls_codec 0.4.0", "tokio", + "tokio-stream", "toml 0.8.8", "tracing", "xmtp_cryptography", @@ -5970,6 +6066,7 @@ dependencies = [ "thread-id", "tokio", "tokio-test", + "tracing-subscriber", "uniffi", "uniffi_macros", "uuid 1.9.1", diff --git a/bindings_ffi/Cargo.toml b/bindings_ffi/Cargo.toml index 2d58b716e..562874b0c 100644 --- a/bindings_ffi/Cargo.toml +++ b/bindings_ffi/Cargo.toml @@ -23,6 +23,7 @@ xmtp_proto = { path = "../xmtp_proto", features = ["proto_full", "grpc"] } xmtp_user_preferences = { path = "../xmtp_user_preferences" } xmtp_v2 = { path = "../xmtp_v2" } +tracing-subscriber = { version = "0.3", features = ["env-filter"] } # NOTE: A regression in openssl-sys exists where libatomic is 
dynamically linked # for i686-linux-android targets. https://github.com/sfackler/rust-openssl/issues/2163 # @@ -46,6 +47,7 @@ tempfile = "3.5.0" tokio = { version = "1.28.1", features = ["full"] } tokio-test = "0.4" uniffi = { version = "0.27.2", features = ["bindgen-tests"] } +tracing-subscriber = "0.3" uuid = { version = "1.9", features = ["v4", "fast-rng" ] } # NOTE: The release profile reduces bundle size from 230M to 41M - may have performance impliciations diff --git a/bindings_ffi/src/mls.rs b/bindings_ffi/src/mls.rs index 3fd948935..26145dc51 100644 --- a/bindings_ffi/src/mls.rs +++ b/bindings_ffi/src/mls.rs @@ -4,11 +4,8 @@ use crate::logger::FfiLogger; use crate::GenericError; use std::collections::HashMap; use std::convert::TryInto; -use std::sync::{ - atomic::{AtomicBool, Ordering}, - Arc, Mutex, -}; -use tokio::sync::oneshot::Sender; +use std::sync::Arc; +use tokio::{sync::Mutex, task::AbortHandle}; use xmtp_api_grpc::grpc_api_helper::Client as TonicApiClient; use xmtp_id::{ associations::{ @@ -32,6 +29,7 @@ use xmtp_mls::{ api::ApiClientWrapper, builder::ClientBuilder, client::Client as MlsClient, + client::ClientError, groups::{ group_metadata::{ConversationType, GroupMetadata}, group_permissions::GroupMutablePermissions, @@ -44,6 +42,7 @@ use xmtp_mls::{ group_message::{DeliveryStatus, GroupMessageKind, StoredGroupMessage}, EncryptedMessageStore, EncryptionKey, StorageOption, }, + subscriptions::StreamHandle, }; pub type RustXmtpClient = MlsClient; @@ -175,8 +174,7 @@ pub fn generate_inbox_id(account_address: String, nonce: u64) -> String { #[derive(uniffi::Object)] pub struct FfiSignatureRequest { - // Using `tokio::sync::Mutex` bc rust MutexGuard cannot be sent between threads. 
- inner: Arc>, + inner: Arc>, } #[uniffi::export(async_runtime = "tokio")] @@ -305,7 +303,7 @@ impl FfiXmtpClient { .signature_request() .map(|request| { Arc::new(FfiSignatureRequest { - inner: Arc::new(tokio::sync::Mutex::new(request)), + inner: Arc::new(Mutex::new(request)), }) }) } @@ -616,43 +614,30 @@ impl FfiConversations { Ok(convo_list) } - pub async fn stream( - &self, - callback: Box, - ) -> Result, GenericError> { + pub fn stream(&self, callback: Box) -> FfiStreamCloser { let client = self.inner_client.clone(); - let stream_closer = RustXmtpClient::stream_conversations_with_callback( - client.clone(), - move |convo| { + let handle = + RustXmtpClient::stream_conversations_with_callback(client.clone(), move |convo| { callback.on_conversation(Arc::new(FfiGroup { inner_client: client.clone(), group_id: convo.group_id, created_at_ns: convo.created_at_ns, })) - }, - || {}, // on_close_callback - )?; + }); - Ok(Arc::new(FfiStreamCloser { - close_fn: stream_closer.close_fn, - is_closed_atomic: stream_closer.is_closed_atomic, - })) + FfiStreamCloser::new(handle) } - pub async fn stream_all_messages( + pub fn stream_all_messages( &self, message_callback: Box, - ) -> Result, GenericError> { - let stream_closer = RustXmtpClient::stream_all_messages_with_callback( + ) -> FfiStreamCloser { + let handle = RustXmtpClient::stream_all_messages_with_callback( self.inner_client.clone(), move |message| message_callback.on_message(message.into()), - ) - .await?; + ); - Ok(Arc::new(FfiStreamCloser { - close_fn: stream_closer.close_fn, - is_closed_atomic: stream_closer.is_closed_atomic, - })) + FfiStreamCloser::new(handle) } } @@ -1131,22 +1116,16 @@ impl FfiGroup { Ok(()) } - pub async fn stream( - &self, - message_callback: Box, - ) -> Result, GenericError> { + pub fn stream(&self, message_callback: Box) -> FfiStreamCloser { let inner_client = Arc::clone(&self.inner_client); - let stream_closer = MlsGroup::stream_with_callback( + let handle = MlsGroup::stream_with_callback( 
inner_client, self.group_id.clone(), self.created_at_ns, move |message| message_callback.on_message(message.into()), - )?; + ); - Ok(Arc::new(FfiStreamCloser { - close_fn: stream_closer.close_fn, - is_closed_atomic: stream_closer.is_closed_atomic, - })) + FfiStreamCloser::new(handle) } pub fn created_at_ns(&self) -> i64 { @@ -1261,27 +1240,67 @@ impl From for FfiMessage { } } -#[derive(uniffi::Object)] +#[derive(uniffi::Object, Clone, Debug)] pub struct FfiStreamCloser { - close_fn: Arc>>>, - is_closed_atomic: Arc, + #[allow(clippy::type_complexity)] + stream_handle: Arc>>>>, + // for convenience, does not require locking mutex. + abort_handle: Arc, +} + +impl FfiStreamCloser { + pub fn new(stream_handle: StreamHandle>) -> Self { + Self { + abort_handle: Arc::new(stream_handle.handle.abort_handle()), + stream_handle: Arc::new(Mutex::new(Some(stream_handle))), + } + } + + #[cfg(test)] + pub async fn wait_for_ready(&self) { + let mut handle = self.stream_handle.lock().await; + if let Some(ref mut h) = &mut *handle { + h.wait_for_ready().await; + } + } } #[uniffi::export] impl FfiStreamCloser { + /// Signal the stream to end + /// Does not wait for the stream to end. 
pub fn end(&self) { - match self.close_fn.lock() { - Ok(mut close_fn_option) => { - let _ = close_fn_option.take().map(|close_fn| close_fn.send(())); - } - _ => { - log::warn!("close_fn already closed"); + self.abort_handle.abort(); + } + + /// End the stream and asyncronously wait for it to shutdown + pub async fn end_and_wait(&self) -> Result<(), GenericError> { + if self.abort_handle.is_finished() { + return Ok(()); + } + + let mut stream_handle = self.stream_handle.lock().await; + let stream_handle = stream_handle.take(); + if let Some(h) = stream_handle { + h.handle.abort(); + match h.handle.await { + Err(e) if !e.is_cancelled() => Err(GenericError::Generic { + err: format!("subscription event loop join error {}", e), + }), + Err(e) if e.is_cancelled() => Ok(()), + Ok(t) => t.map_err(|e| GenericError::Generic { err: e.to_string() }), + Err(e) => Err(GenericError::Generic { + err: format!("error joining task {}", e), + }), } + } else { + log::warn!("subscription already closed"); + Ok(()) } } pub fn is_closed(&self) -> bool { - self.is_closed_atomic.load(Ordering::Relaxed) + self.abort_handle.is_finished() } } @@ -1360,15 +1379,15 @@ impl FfiGroupPermissions { mod tests { use crate::{ get_inbox_id_for_address, inbox_owner::SigningError, logger::FfiLogger, - FfiConversationCallback, FfiCreateGroupOptions, FfiGroupPermissionsOptions, FfiInboxOwner, - FfiListConversationsOptions, FfiListMessagesOptions, FfiMetadataField, FfiPermissionPolicy, - FfiPermissionPolicySet, FfiPermissionUpdateType, + FfiConversationCallback, FfiCreateGroupOptions, FfiGroup, FfiGroupPermissionsOptions, + FfiInboxOwner, FfiListConversationsOptions, FfiListMessagesOptions, FfiMetadataField, + FfiPermissionPolicy, FfiPermissionPolicySet, FfiPermissionUpdateType, }; use std::{ env, sync::{ atomic::{AtomicU32, Ordering}, - Arc, + Arc, Mutex, }, }; @@ -1378,6 +1397,7 @@ mod tests { self, distributions::{Alphanumeric, DistString}, }; + use tokio::{sync::Notify, time::error::Elapsed}; use 
xmtp_cryptography::{signature::RecoverableSignature, utils::rng}; use xmtp_id::associations::generate_inbox_id; use xmtp_mls::{storage::EncryptionKey, InboxOwner}; @@ -1417,36 +1437,48 @@ mod tests { } } - #[derive(Clone)] + #[derive(Default, Clone)] struct RustStreamCallback { num_messages: Arc, + messages: Arc>>, + conversations: Arc>>>, + notify: Arc, } impl RustStreamCallback { - pub fn new() -> Self { - Self { - num_messages: Arc::new(AtomicU32::new(0)), - } - } - pub fn message_count(&self) -> u32 { self.num_messages.load(Ordering::SeqCst) } + + pub async fn wait_for_delivery(&self) -> Result<(), Elapsed> { + tokio::time::timeout(std::time::Duration::from_secs(60), async { + self.notify.notified().await + }) + .await?; + Ok(()) + } } impl FfiMessageCallback for RustStreamCallback { fn on_message(&self, message: FfiMessage) { - println!("Got a message"); - let message = String::from_utf8(message.content).unwrap_or("".to_string()); - log::info!("Received: {}", message); + let mut messages = self.messages.lock().unwrap(); + log::info!( + "ON MESSAGE Received\n-------- \n{}\n----------", + String::from_utf8_lossy(&message.content) + ); + messages.push(message); let _ = self.num_messages.fetch_add(1, Ordering::SeqCst); + self.notify.notify_one(); } } impl FfiConversationCallback for RustStreamCallback { - fn on_conversation(&self, _: Arc) { - println!("received new conversation"); + fn on_conversation(&self, group: Arc) { + log::debug!("received conversation"); let _ = self.num_messages.fetch_add(1, Ordering::SeqCst); + let mut convos = self.conversations.lock().unwrap(); + convos.push(group); + self.notify.notify_one(); } } @@ -1974,13 +2006,11 @@ mod tests { let bo = new_test_client().await; // Stream all group messages - let message_callbacks = RustStreamCallback::new(); + let message_callbacks = RustStreamCallback::default(); let stream_messages = bo .conversations() - .stream_all_messages(Box::new(message_callbacks.clone())) - .await - .unwrap(); - 
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + .stream_all_messages(Box::new(message_callbacks.clone())); + stream_messages.wait_for_ready().await; // Create group and send first message let alix_group = alix @@ -1992,12 +2022,11 @@ mod tests { .await .unwrap(); - tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await; - alix_group .update_group_name("Old Name".to_string()) .await .unwrap(); + message_callbacks.wait_for_delivery().await.unwrap(); let bo_groups = bo .conversations() @@ -2010,41 +2039,35 @@ mod tests { .update_group_name("Old Name2".to_string()) .await .unwrap(); - tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await; + message_callbacks.wait_for_delivery().await.unwrap(); // Uncomment the following lines to add more group name updates - // alix_group.update_group_name("Again Name".to_string()).await.unwrap(); - tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await; - bo_group - .update_group_name("Old Name2".to_string()) + .update_group_name("Old Name3".to_string()) .await .unwrap(); + message_callbacks.wait_for_delivery().await.unwrap(); - tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await; assert_eq!(message_callbacks.message_count(), 3); - stream_messages.end(); - tokio::time::sleep(tokio::time::Duration::from_millis(5)).await; + stream_messages.end_and_wait().await.unwrap(); + assert!(stream_messages.is_closed()); } // test is also showing intermittent failures with database locked msg - #[tokio::test(flavor = "multi_thread", worker_threads = 5)] #[ignore] + #[tokio::test(flavor = "multi_thread", worker_threads = 5)] async fn test_can_stream_and_update_name_without_forking_group() { let alix = new_test_client().await; let bo = new_test_client().await; // Stream all group messages - let message_callbacks = RustStreamCallback::new(); + let message_callbacks = RustStreamCallback::default(); let stream_messages = bo .conversations() - 
.stream_all_messages(Box::new(message_callbacks.clone())) - .await - .unwrap(); - - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + .stream_all_messages(Box::new(message_callbacks.clone())); + stream_messages.wait_for_ready().await; let first_msg_check = 2; let second_msg_check = 5; @@ -2059,16 +2082,13 @@ mod tests { .await .unwrap(); - tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await; - alix_group .update_group_name("hello".to_string()) .await .unwrap(); + message_callbacks.wait_for_delivery().await.unwrap(); alix_group.send("hello1".as_bytes().to_vec()).await.unwrap(); - tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await; - bo.conversations().sync().await.unwrap(); - tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await; + message_callbacks.wait_for_delivery().await.unwrap(); let bo_groups = bo .conversations() @@ -2085,49 +2105,41 @@ mod tests { assert_eq!(bo_messages1.len(), first_msg_check); bo_group.send("hello2".as_bytes().to_vec()).await.unwrap(); + message_callbacks.wait_for_delivery().await.unwrap(); bo_group.send("hello3".as_bytes().to_vec()).await.unwrap(); + message_callbacks.wait_for_delivery().await.unwrap(); + alix_group.sync().await.unwrap(); - tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await; let alix_messages = alix_group .find_messages(FfiListMessagesOptions::default()) .unwrap(); assert_eq!(alix_messages.len(), second_msg_check); alix_group.send("hello4".as_bytes().to_vec()).await.unwrap(); + message_callbacks.wait_for_delivery().await.unwrap(); bo_group.sync().await.unwrap(); let bo_messages2 = bo_group .find_messages(FfiListMessagesOptions::default()) .unwrap(); assert_eq!(bo_messages2.len(), second_msg_check); + assert_eq!(message_callbacks.message_count(), second_msg_check as u32); - // TODO: message_callbacks should eventually come through here, why does this - // not work? 
- // tokio::time::sleep(tokio::time::Duration::from_millis(10000)).await; - // assert_eq!(message_callbacks.message_count(), second_msg_check as u32); - - stream_messages.end(); - tokio::time::sleep(tokio::time::Duration::from_millis(5)).await; + stream_messages.end_and_wait().await.unwrap(); assert!(stream_messages.is_closed()); } #[tokio::test(flavor = "multi_thread", worker_threads = 5)] - // This one is flaky for me. Passes reliably locally and fails on CI - #[ignore] async fn test_conversation_streaming() { let amal = new_test_client().await; let bola = new_test_client().await; - let stream_callback = RustStreamCallback::new(); + let stream_callback = RustStreamCallback::default(); let stream = bola .conversations() - .stream(Box::new(stream_callback.clone())) - .await - .unwrap(); - - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + .stream(Box::new(stream_callback.clone())); amal.conversations() .create_group( @@ -2137,7 +2149,7 @@ mod tests { .await .unwrap(); - tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await; + stream_callback.wait_for_delivery().await.unwrap(); assert_eq!(stream_callback.message_count(), 1); // Create another group and add bola @@ -2148,12 +2160,11 @@ mod tests { ) .await .unwrap(); + stream_callback.wait_for_delivery().await.unwrap(); - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; assert_eq!(stream_callback.message_count(), 2); - stream.end(); - tokio::time::sleep(tokio::time::Duration::from_millis(5)).await; + stream.end_and_wait().await.unwrap(); assert!(stream.is_closed()); } @@ -2171,19 +2182,17 @@ mod tests { ) .await .unwrap(); - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - let stream_callback = RustStreamCallback::new(); + let stream_callback = RustStreamCallback::default(); let stream = caro .conversations() - .stream_all_messages(Box::new(stream_callback.clone())) - .await - .unwrap(); - 
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + .stream_all_messages(Box::new(stream_callback.clone())); + stream.wait_for_ready().await; alix_group.send("first".as_bytes().to_vec()).await.unwrap(); - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + stream_callback.wait_for_delivery().await.unwrap(); + let bo_group = bo .conversations() .create_group( @@ -2192,27 +2201,26 @@ mod tests { ) .await .unwrap(); - tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; + let _ = caro.inner_client.sync_welcomes().await.unwrap(); + bo_group.send("second".as_bytes().to_vec()).await.unwrap(); - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + stream_callback.wait_for_delivery().await.unwrap(); alix_group.send("third".as_bytes().to_vec()).await.unwrap(); - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + stream_callback.wait_for_delivery().await.unwrap(); bo_group.send("fourth".as_bytes().to_vec()).await.unwrap(); - tokio::time::sleep(tokio::time::Duration::from_millis(300)).await; + stream_callback.wait_for_delivery().await.unwrap(); assert_eq!(stream_callback.message_count(), 4); - stream.end(); - tokio::time::sleep(tokio::time::Duration::from_millis(5)).await; + stream.end_and_wait().await.unwrap(); assert!(stream.is_closed()); } - #[tokio::test(flavor = "multi_thread", worker_threads = 5)] - #[ignore] + #[tokio::test(flavor = "multi_thread")] async fn test_message_streaming() { let amal = new_test_client().await; let bola = new_test_client().await; - let group = amal + let amal_group: Arc = amal .conversations() .create_group( vec![bola.account_address.clone()], @@ -2221,19 +2229,25 @@ mod tests { .await .unwrap(); - let stream_callback = RustStreamCallback::new(); - let stream_closer = group - .stream(Box::new(stream_callback.clone())) + bola.inner_client.sync_welcomes().await.unwrap(); + let bola_group = bola.group(amal_group.group_id.clone()).unwrap(); + + let stream_callback = 
RustStreamCallback::default(); + let stream_closer = bola_group.stream(Box::new(stream_callback.clone())); + + stream_closer.wait_for_ready().await; + + amal_group.send("hello".as_bytes().to_vec()).await.unwrap(); + stream_callback.wait_for_delivery().await.unwrap(); + + amal_group + .send("goodbye".as_bytes().to_vec()) .await .unwrap(); - tokio::time::sleep(tokio::time::Duration::from_millis(250)).await; + stream_callback.wait_for_delivery().await.unwrap(); - group.send("hello".as_bytes().to_vec()).await.unwrap(); - group.send("goodbye".as_bytes().to_vec()).await.unwrap(); - tokio::time::sleep(tokio::time::Duration::from_millis(250)).await; assert_eq!(stream_callback.message_count(), 2); - - stream_closer.end(); + stream_closer.end_and_wait().await.unwrap(); } #[tokio::test(flavor = "multi_thread", worker_threads = 5)] @@ -2254,21 +2268,18 @@ mod tests { ) .await .unwrap(); - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - let stream_callback = RustStreamCallback::new(); + let stream_callback = RustStreamCallback::default(); let stream_closer = bola .conversations() - .stream_all_messages(Box::new(stream_callback.clone())) - .await - .unwrap(); + .stream_all_messages(Box::new(stream_callback.clone())); + stream_closer.wait_for_ready().await; - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + amal_group.send(b"hello1".to_vec()).await.unwrap(); + stream_callback.wait_for_delivery().await.unwrap(); + amal_group.send(b"hello2".to_vec()).await.unwrap(); + stream_callback.wait_for_delivery().await.unwrap(); - amal_group.send("hello1".as_bytes().to_vec()).await.unwrap(); - amal_group.send("hello2".as_bytes().to_vec()).await.unwrap(); - - tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; assert_eq!(stream_callback.message_count(), 2); assert!(!stream_closer.is_closed()); @@ -2276,29 +2287,30 @@ mod tests { .remove_members_by_inbox_id(vec![bola.inbox_id().clone()]) .await .unwrap(); - 
tokio::time::sleep(std::time::Duration::from_millis(2000)).await; + stream_callback.wait_for_delivery().await.unwrap(); assert_eq!(stream_callback.message_count(), 3); // Member removal transcript message - - amal_group.send("hello3".as_bytes().to_vec()).await.unwrap(); - tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; + // + amal_group.send(b"hello3".to_vec()).await.unwrap(); + //TODO: could verify with a log message + tokio::time::sleep(std::time::Duration::from_millis(200)).await; assert_eq!(stream_callback.message_count(), 3); // Don't receive messages while removed assert!(!stream_closer.is_closed()); - tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; amal_group .add_members(vec![bola.account_address.clone()]) .await .unwrap(); - tokio::time::sleep(std::time::Duration::from_millis(500)).await; + + // TODO: could check for LOG message with a Eviction error on receive + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; assert_eq!(stream_callback.message_count(), 3); // Don't receive transcript messages while removed amal_group.send("hello4".as_bytes().to_vec()).await.unwrap(); - tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; + stream_callback.wait_for_delivery().await.unwrap(); assert_eq!(stream_callback.message_count(), 4); // Receiving messages again assert!(!stream_closer.is_closed()); - stream_closer.end(); - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + stream_closer.end_and_wait().await.unwrap(); assert!(stream_closer.is_closed()); } @@ -2353,21 +2365,14 @@ mod tests { let bo = new_test_client().await; // Stream all group messages - let message_callbacks = RustStreamCallback::new(); - let group_callbacks = RustStreamCallback::new(); - let stream_groups = bo - .conversations() - .stream(Box::new(group_callbacks.clone())) - .await - .unwrap(); + let message_callback = RustStreamCallback::default(); + let group_callback = RustStreamCallback::default(); + let 
stream_groups = bo.conversations().stream(Box::new(group_callback.clone())); let stream_messages = bo .conversations() - .stream_all_messages(Box::new(message_callbacks.clone())) - .await - .unwrap(); - - tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await; + .stream_all_messages(Box::new(message_callback.clone())); + stream_messages.wait_for_ready().await; // Create group and send first message let alix_group = alix @@ -2378,19 +2383,18 @@ mod tests { ) .await .unwrap(); + group_callback.wait_for_delivery().await.unwrap(); alix_group.send("hello1".as_bytes().to_vec()).await.unwrap(); - tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await; + message_callback.wait_for_delivery().await.unwrap(); - assert_eq!(group_callbacks.message_count(), 1); - assert_eq!(message_callbacks.message_count(), 1); + assert_eq!(group_callback.message_count(), 1); + assert_eq!(message_callback.message_count(), 1); - stream_messages.end(); - tokio::time::sleep(tokio::time::Duration::from_millis(5)).await; + stream_messages.end_and_wait().await.unwrap(); assert!(stream_messages.is_closed()); - stream_groups.end(); - tokio::time::sleep(tokio::time::Duration::from_millis(5)).await; + stream_groups.end_and_wait().await.unwrap(); assert!(stream_groups.is_closed()); } diff --git a/bindings_ffi/src/v2.rs b/bindings_ffi/src/v2.rs index 932650f55..afee41745 100644 --- a/bindings_ffi/src/v2.rs +++ b/bindings_ffi/src/v2.rs @@ -312,9 +312,15 @@ impl FfiV2Subscription { let handle = handle.take(); if let Some(h) = handle { h.abort(); - h.await.map_err(|_| GenericError::Generic { - err: "subscription event loop join error".into(), - })?; + let join_result = h.await; + if matches!(join_result, Err(ref e) if !e.is_cancelled()) { + return Err(GenericError::Generic { + err: format!( + "subscription event loop join error {}", + join_result.unwrap_err() + ), + }); + } } Ok(()) } diff --git a/bindings_node/Cargo.lock b/bindings_node/Cargo.lock index d50511cea..a9c67462e 100644 --- 
a/bindings_node/Cargo.lock +++ b/bindings_node/Cargo.lock @@ -300,6 +300,7 @@ version = "0.0.1" dependencies = [ "futures", "hex", + "log", "napi", "napi-build", "napi-derive", diff --git a/bindings_node/Cargo.toml b/bindings_node/Cargo.toml index 4b4ba232e..c31c75cd8 100644 --- a/bindings_node/Cargo.toml +++ b/bindings_node/Cargo.toml @@ -25,6 +25,7 @@ xmtp_mls = { path = "../xmtp_mls", features = ["grpc", "native"] } xmtp_proto = { path = "../xmtp_proto", features = ["proto_full"] } xmtp_id = { path = "../xmtp_id" } rand = "0.8.5" +log = { version = "0.4", features = ["release_max_level_debug"] } [build-dependencies] napi-build = "2.0.1" diff --git a/bindings_node/src/conversations.rs b/bindings_node/src/conversations.rs index 10240189d..8f167e22f 100644 --- a/bindings_node/src/conversations.rs +++ b/bindings_node/src/conversations.rs @@ -186,9 +186,8 @@ impl NapiConversations { let tsfn: ThreadsafeFunction = callback.create_threadsafe_function(0, |ctx| Ok(vec![ctx.value]))?; let client = self.inner_client.clone(); - let stream_closer = RustXmtpClient::stream_conversations_with_callback( - client.clone(), - move |convo| { + let stream_closer = + RustXmtpClient::stream_conversations_with_callback(client.clone(), move |convo| { tsfn.call( Ok(NapiGroup::new( client.clone(), @@ -197,32 +196,22 @@ impl NapiConversations { )), ThreadsafeFunctionCallMode::Blocking, ); - }, - || {}, // on_close_callback - ) - .map_err(|e| Error::from_reason(format!("{}", e)))?; + }); - Ok(NapiStreamCloser::new( - stream_closer.close_fn, - stream_closer.is_closed_atomic, - )) + Ok(NapiStreamCloser::new(stream_closer)) } #[napi(ts_args_type = "callback: (err: null | Error, result: NapiMessage) => void")] pub fn stream_all_messages(&self, callback: JsFunction) -> Result { let tsfn: ThreadsafeFunction = callback.create_threadsafe_function(0, |ctx| Ok(vec![ctx.value]))?; - let stream_closer = RustXmtpClient::stream_all_messages_with_callback_sync( + let stream_closer = 
RustXmtpClient::stream_all_messages_with_callback( self.inner_client.clone(), move |message| { tsfn.call(Ok(message.into()), ThreadsafeFunctionCallMode::Blocking); }, - ) - .map_err(|e| Error::from_reason(format!("{}", e)))?; + ); - Ok(NapiStreamCloser::new( - stream_closer.close_fn, - stream_closer.is_closed_atomic, - )) + Ok(NapiStreamCloser::new(stream_closer)) } } diff --git a/bindings_node/src/groups.rs b/bindings_node/src/groups.rs index e5c118d66..a08065062 100644 --- a/bindings_node/src/groups.rs +++ b/bindings_node/src/groups.rs @@ -562,13 +562,9 @@ impl NapiGroup { move |message| { tsfn.call(Ok(message.into()), ThreadsafeFunctionCallMode::Blocking); }, - ) - .map_err(|e| Error::from_reason(format!("{}", e)))?; + ); - Ok(NapiStreamCloser::new( - stream_closer.close_fn, - stream_closer.is_closed_atomic, - )) + Ok(stream_closer.into()) } #[napi] diff --git a/bindings_node/src/streams.rs b/bindings_node/src/streams.rs index 0d82382db..c12301239 100644 --- a/bindings_node/src/streams.rs +++ b/bindings_node/src/streams.rs @@ -1,35 +1,73 @@ -use std::sync::{ - atomic::{AtomicBool, Ordering}, - Arc, Mutex, -}; -use tokio::sync::oneshot::Sender; +use napi::bindgen_prelude::Error; +use std::sync::Arc; +use tokio::{sync::Mutex, task::AbortHandle}; +use xmtp_mls::{client::ClientError, subscriptions::StreamHandle}; use napi_derive::napi; #[napi] pub struct NapiStreamCloser { - close_fn: Arc>>>, - is_closed_atomic: Arc, + #[allow(clippy::type_complexity)] + handle: Arc>>>>, + // for convenience, does not require locking mutex. 
+ abort_handle: Arc, } -#[napi] impl NapiStreamCloser { - pub fn new(close_fn: Arc>>>, is_closed_atomic: Arc) -> Self { + pub fn new(handle: StreamHandle>) -> Self { Self { - close_fn, - is_closed_atomic, + abort_handle: Arc::new(handle.handle.abort_handle()), + handle: Arc::new(Mutex::new(Some(handle))), } } +} + +impl From>> for NapiStreamCloser { + fn from(handle: StreamHandle>) -> Self { + NapiStreamCloser::new(handle) + } +} +#[napi] +impl NapiStreamCloser { + /// Signal the stream to end + /// Does not wait for the stream to end. #[napi] pub fn end(&self) { - if let Ok(mut close_fn_option) = self.close_fn.lock() { - let _ = close_fn_option.take().map(|close_fn| close_fn.send(())); + self.abort_handle.abort(); + } + + /// End the stream and `await` for it to shutdown + /// Returns the `Result` of the task. + #[napi] + /// End the stream and asynchronously wait for it to shutdown + pub async fn end_and_wait(&self) -> Result<(), Error> { + if self.abort_handle.is_finished() { + return Ok(()); + } + + let mut stream_handle = self.handle.lock().await; + let stream_handle = stream_handle.take(); + if let Some(h) = stream_handle { + h.handle.abort(); + match h.handle.await { + Err(e) if !e.is_cancelled() => Err(Error::from_reason(format!( + "subscription event loop join error {}", + e + ))), + Err(e) if e.is_cancelled() => Ok(()), + Ok(t) => t.map_err(|e| Error::from_reason(e.to_string())), + Err(e) => Err(Error::from_reason(format!("error joining task {}", e))), + } + } else { + log::warn!("subscription already closed"); + Ok(()) } } + /// Checks if this stream is closed #[napi] pub fn is_closed(&self) -> bool { - self.is_closed_atomic.load(Ordering::Relaxed) + self.abort_handle.is_finished() } } diff --git a/xmtp_mls/Cargo.toml b/xmtp_mls/Cargo.toml index 26ac6b9ae..e7e0b103a 100644 --- a/xmtp_mls/Cargo.toml +++ b/xmtp_mls/Cargo.toml @@ -71,6 +71,8 @@ flume = "0.11" mockall = "0.11.4" mockito = "1.4.0" tempfile = "3.5.0" +tracing.workspace = true 
+tracing-subscriber.workspace = true tracing-log = "0.2.0" tracing-test = "0.2.4" xmtp_api_grpc = { path = "../xmtp_api_grpc" } diff --git a/xmtp_mls/src/groups/subscriptions.rs b/xmtp_mls/src/groups/subscriptions.rs index 1ee1bbd5a..e5c6b4099 100644 --- a/xmtp_mls/src/groups/subscriptions.rs +++ b/xmtp_mls/src/groups/subscriptions.rs @@ -5,11 +5,10 @@ use std::sync::Arc; use futures::Stream; use super::{extract_message_v1, GroupError, MlsGroup}; -use crate::retry::Retry; use crate::storage::group_message::StoredGroupMessage; -use crate::subscriptions::{MessagesStreamInfo, StreamCloser}; +use crate::subscriptions::{MessagesStreamInfo, StreamHandle}; use crate::XmtpApi; -use crate::{retry_async, Client}; +use crate::{retry::Retry, retry_async, Client}; use prost::Message; use xmtp_proto::xmtp::mls::api::v1::GroupMessage; @@ -66,8 +65,10 @@ impl MlsGroup { }) ); - if let Some(GroupError::ReceiveError(_)) = process_result.err() { + if let Some(GroupError::ReceiveError(_)) = process_result.as_ref().err() { self.sync(&client).await?; + } else if process_result.is_err() { + log::error!("Process stream entry {:?}", process_result.err()); } // Load the message from the DB to handle cases where it may have been already processed in @@ -119,11 +120,11 @@ impl MlsGroup { group_id: Vec, created_at_ns: i64, callback: impl FnMut(StoredGroupMessage) + Send + 'static, - ) -> Result + ) -> StreamHandle> where ApiClient: crate::XmtpApi, { - Ok(Client::::stream_messages_with_callback( + Client::::stream_messages_with_callback( client, HashMap::from([( group_id, @@ -133,7 +134,7 @@ impl MlsGroup { }, )]), callback, - )?) 
+ ) } } diff --git a/xmtp_mls/src/groups/sync.rs b/xmtp_mls/src/groups/sync.rs index e4fe80643..ef059cd66 100644 --- a/xmtp_mls/src/groups/sync.rs +++ b/xmtp_mls/src/groups/sync.rs @@ -762,6 +762,11 @@ impl MlsGroup { sha256(payload_slice), post_commit_data, )?; + log::debug!( + "client [{}] set stored intent [{}] to state `published`", + client.inbox_id(), + intent.id + ); } Ok(()) diff --git a/xmtp_mls/src/subscriptions.rs b/xmtp_mls/src/subscriptions.rs index d30ff458c..490a8193f 100644 --- a/xmtp_mls/src/subscriptions.rs +++ b/xmtp_mls/src/subscriptions.rs @@ -1,16 +1,12 @@ -use std::{ - collections::HashMap, - pin::Pin, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, Mutex, - }, -}; +use std::{collections::HashMap, pin::Pin, sync::Arc}; -use futures::{Stream, StreamExt}; +use futures::{FutureExt, Stream, StreamExt}; use prost::Message; -use tokio::sync::oneshot::{self, Sender}; -use tokio_stream::wrappers::errors::BroadcastStreamRecvError; +use tokio::{ + sync::{mpsc, oneshot}, + task::JoinHandle, +}; +use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, UnboundedReceiverStream}; use xmtp_proto::xmtp::mls::api::v1::WelcomeMessage; use crate::{ @@ -19,10 +15,18 @@ use crate::{ groups::{extract_group_id, GroupError, MlsGroup}, retry::Retry, retry_async, - storage::group_message::StoredGroupMessage, + storage::{group::StoredGroup, group_message::StoredGroupMessage}, Client, XmtpApi, }; +#[derive(Debug)] +/// Wrapper around a [`tokio::task::JoinHandle`] but with a oneshot receiver +/// which allows waiting for a `with_callback` stream fn to be ready for stream items. 
+pub struct StreamHandle { + pub handle: JoinHandle, + start: Option>, +} + /// Events local to this client /// are broadcast across all senders/receivers of streams #[derive(Clone, Debug)] @@ -31,35 +35,39 @@ pub(crate) enum LocalEvents { NewGroup(MlsGroup), } -// TODO simplify FfiStreamCloser + StreamCloser duplication -pub struct StreamCloser { - pub close_fn: Arc>>>, - pub is_closed_atomic: Arc, -} - -impl StreamCloser { - pub fn end(&self) { - match self.close_fn.lock() { - Ok(mut close_fn_option) => { - let _ = close_fn_option.take().map(|close_fn| close_fn.send(())); - } - _ => { - log::warn!("close_fn already closed"); - } +impl StreamHandle { + /// Waits for the stream to be fully spawned + pub async fn wait_for_ready(&mut self) { + if let Some(s) = self.start.take() { + let _ = s.await; } } +} - pub fn is_closed(&self) -> bool { - self.is_closed_atomic.load(Ordering::Relaxed) +impl From> for JoinHandle { + fn from(stream: StreamHandle) -> JoinHandle { + stream.handle } } -#[derive(Clone)] +#[derive(Clone, Debug)] pub(crate) struct MessagesStreamInfo { pub convo_created_at_ns: i64, pub cursor: u64, } +impl From for (Vec, MessagesStreamInfo) { + fn from(group: StoredGroup) -> (Vec, MessagesStreamInfo) { + ( + group.id, + MessagesStreamInfo { + convo_created_at_ns: group.created_at_ns, + cursor: 0, + }, + ) + } +} + impl Client where ApiClient: XmtpApi, @@ -144,20 +152,19 @@ where let subscription = self .api_client - .subscribe_welcome_messages(installation_key, Some(id_cursor as u64)) + .subscribe_welcome_messages(installation_key, Some(id_cursor)) .await?; let stream = subscription - .map(|welcome_result| async { + .map(|welcome| async { log::info!("Received conversation streaming payload"); - let welcome = welcome_result?; - self.process_streamed_welcome(welcome).await + self.process_streamed_welcome(welcome?).await }) .filter_map(|res| async { match res.await { Ok(group) => Some(group), Err(err) => { - log::error!("Error processing stream entry: {:?}", 
err); + log::error!("Error processing stream entry for conversation: {:?}", err); None } } @@ -166,6 +173,7 @@ where Ok(Box::pin(futures::stream::select(stream, event_queue))) } + #[tracing::instrument(skip(self, group_id_to_info))] pub(crate) async fn stream_messages( self: Arc, group_id_to_info: HashMap, MessagesStreamInfo>, @@ -228,166 +236,159 @@ where pub fn stream_conversations_with_callback( client: Arc>, mut convo_callback: impl FnMut(MlsGroup) + Send + 'static, - mut on_close_callback: impl FnMut() + Send + 'static, - ) -> Result { - let (close_sender, close_receiver) = oneshot::channel::<()>(); - let is_closed = Arc::new(AtomicBool::new(false)); - let is_closed_clone = is_closed.clone(); + ) -> StreamHandle> { + let (tx, rx) = oneshot::channel(); - tokio::spawn(async move { + let handle = tokio::spawn(async move { let mut stream = client.stream_conversations().await.unwrap(); - let mut close_receiver = close_receiver; - loop { - tokio::select! { - item = stream.next() => { - match item { - Some(convo) => { convo_callback(convo) }, - None => break - } - } - _ = &mut close_receiver => { - on_close_callback(); - break; - } - } + let _ = tx.send(()); + while let Some(convo) = stream.next().await { + convo_callback(convo) } - is_closed_clone.store(true, Ordering::Relaxed); - log::info!("closing stream"); + Ok(()) }); - Ok(StreamCloser { - close_fn: Arc::new(Mutex::new(Some(close_sender))), - is_closed_atomic: is_closed, - }) + StreamHandle { + start: Some(rx), + handle, + } } pub(crate) fn stream_messages_with_callback( client: Arc>, group_id_to_info: HashMap, MessagesStreamInfo>, mut callback: impl FnMut(StoredGroupMessage) + Send + 'static, - ) -> Result { - let (close_sender, close_receiver) = oneshot::channel::<()>(); - let is_closed = Arc::new(AtomicBool::new(false)); - - let is_closed_clone = is_closed.clone(); - tokio::spawn(async move { - let mut stream = Self::stream_messages(client, group_id_to_info) - .await - .unwrap(); - let mut close_receiver = 
close_receiver; - loop { - tokio::select! { - item = stream.next() => { - match item { - Some(message) => callback(message), - None => break - } - } - _ = &mut close_receiver => { - break; - } - } + ) -> StreamHandle> { + let (tx, rx) = oneshot::channel(); + + let handle = tokio::spawn(async move { + let mut stream = Self::stream_messages(client, group_id_to_info).await?; + let _ = tx.send(()); + while let Some(message) = stream.next().await { + callback(message) } - is_closed_clone.store(true, Ordering::Relaxed); - log::info!("closing stream"); + Ok(()) }); - Ok(StreamCloser { - close_fn: Arc::new(Mutex::new(Some(close_sender))), - is_closed_atomic: is_closed, - }) + StreamHandle { + start: Some(rx), + handle, + } } - pub async fn stream_all_messages_with_callback( + pub async fn stream_all_messages( client: Arc>, - callback: impl FnMut(StoredGroupMessage) + Send + Sync + 'static, - ) -> Result { - client.sync_welcomes().await?; // TODO pipe cursor from welcomes sync into groups_stream - Self::stream_all_messages_with_callback_sync(client, callback) - } + ) -> Result, ClientError> { + let (tx, rx) = mpsc::unbounded_channel(); - /// Requires a sync welcomes before use - pub fn stream_all_messages_with_callback_sync( - client: Arc>, - callback: impl FnMut(StoredGroupMessage) + Send + Sync + 'static, - ) -> Result { - let callback = Arc::new(Mutex::new(callback)); + client.sync_welcomes().await?; - let mut group_id_to_info: HashMap, MessagesStreamInfo> = client + let mut group_id_to_info = client .store() .conn()? .find_groups(None, None, None, None)? 
.into_iter() - .map(|group| { - ( - group.id.clone(), - MessagesStreamInfo { - convo_created_at_ns: group.created_at_ns, - cursor: 0, + .map(Into::into) + .collect::, MessagesStreamInfo>>(); + + tokio::spawn(async move { + let client = client.clone(); + let mut messages_stream = client + .clone() + .stream_messages(group_id_to_info.clone()) + .await?; + let mut convo_stream = Self::stream_conversations(&client).await?; + let mut extra_messages = Vec::new(); + + loop { + tokio::select! { + // biased enforces an order to select!. If a message and a group are both ready + // at the same time, `biased` mode will process the message before the new + // group. + biased; + + messages = futures::future::ready(&mut extra_messages), if !extra_messages.is_empty() => { + for message in messages.drain(0..) { + if tx.send(message).is_err() { + break; + } + } }, - ) - }) - .collect(); + Some(message) = messages_stream.next() => { + // an error can only mean the receiver has been dropped or closed so we're + // safe to end the stream + if tx.send(message).is_err() { + break; + } + } + Some(new_group) = convo_stream.next() => { + if tx.is_closed() { + break; + } + if group_id_to_info.contains_key(&new_group.group_id) { + continue; + } - let callback_clone = callback.clone(); - let messages_stream_closer_mutex = - Arc::new(Mutex::new(Self::stream_messages_with_callback( - client.clone(), - group_id_to_info.clone(), - move |message| callback_clone.lock().unwrap()(message), // TODO fix unwrap - )?)); - let messages_stream_closer_mutex_clone = messages_stream_closer_mutex.clone(); - let groups_stream_closer = Self::stream_conversations_with_callback( - client.clone(), - move |convo| { - // TODO make sure key comparison works correctly - if group_id_to_info.contains_key(&convo.group_id) { - return; - } - // Close existing message stream - // TODO remove unwrap - let mut messages_stream_closer = messages_stream_closer_mutex.lock().unwrap(); - messages_stream_closer.end(); - - // Set up 
new stream. For existing groups, stream new messages only by unsetting the cursor - for info in group_id_to_info.values_mut() { - info.cursor = 0; - } - group_id_to_info.insert( - convo.group_id, - MessagesStreamInfo { - convo_created_at_ns: convo.created_at_ns, - cursor: 1, // For the new group, stream all messages since the group was created + for info in group_id_to_info.values_mut() { + info.cursor = 0; + } + group_id_to_info.insert( + new_group.group_id, + MessagesStreamInfo { + convo_created_at_ns: new_group.created_at_ns, + cursor: 1, // For the new group, stream all messages since the group was created + }, + ); + let new_messages_stream = client.clone().stream_messages(group_id_to_info.clone()).await?; + + // attempt to drain all ready messages from existing stream + while let Some(Some(message)) = messages_stream.next().now_or_never() { + extra_messages.push(message); + } + let _ = std::mem::replace(&mut messages_stream, new_messages_stream); }, - ); - - // Open new message stream - let callback_clone = callback.clone(); - *messages_stream_closer = Self::stream_messages_with_callback( - client.clone(), - group_id_to_info.clone(), - move |message| callback_clone.lock().unwrap()(message), // TODO fix unwrap - ) - .unwrap(); // TODO fix unwrap - }, - move || { - messages_stream_closer_mutex_clone.lock().unwrap().end(); - }, - )?; + } + } + Ok::<_, ClientError>(()) + }); - Ok(groups_stream_closer) + Ok(UnboundedReceiverStream::new(rx)) + } + + pub fn stream_all_messages_with_callback( + client: Arc>, + mut callback: impl FnMut(StoredGroupMessage) + Send + Sync + 'static, + ) -> StreamHandle> { + let (tx, rx) = oneshot::channel(); + + let handle = tokio::spawn(async move { + let mut stream = Self::stream_all_messages(client).await?; + let _ = tx.send(()); + while let Some(message) = stream.next().await { + callback(message) + } + Ok(()) + }); + + StreamHandle { + start: Some(rx), + handle, + } } } #[cfg(test)] mod tests { + use crate::utils::test::Delivery; 
use crate::{ builder::ClientBuilder, groups::GroupMetadataOptions, storage::group_message::StoredGroupMessage, Client, }; use futures::StreamExt; - use std::sync::{Arc, Mutex}; - use tokio::sync::Notify; + use std::sync::{ + atomic::{AtomicU64, Ordering}, + Arc, Mutex, + }; use xmtp_api_grpc::grpc_api_helper::Client as GrpcClient; use xmtp_cryptography::utils::generate_local_wallet; @@ -435,47 +436,49 @@ mod tests { let messages: Arc>> = Arc::new(Mutex::new(Vec::new())); let messages_clone = messages.clone(); - let stream = Client::::stream_all_messages_with_callback( + + let notify = Delivery::new(); + let notify_pointer = notify.clone(); + let mut handle = Client::::stream_all_messages_with_callback( Arc::new(caro), move |message| { (*messages_clone.lock().unwrap()).push(message); + notify_pointer.notify_one(); }, - ) - .await - .unwrap(); - - tokio::time::sleep(std::time::Duration::from_millis(50)).await; + ); + handle.wait_for_ready().await; alix_group .send_message("first".as_bytes(), &alix) .await .unwrap(); + notify.wait_for_delivery().await.unwrap(); bo_group .send_message("second".as_bytes(), &bo) .await .unwrap(); + notify.wait_for_delivery().await.unwrap(); alix_group .send_message("third".as_bytes(), &alix) .await .unwrap(); + notify.wait_for_delivery().await.unwrap(); bo_group .send_message("fourth".as_bytes(), &bo) .await .unwrap(); - tokio::time::sleep(std::time::Duration::from_millis(200)).await; + notify.wait_for_delivery().await.unwrap(); let messages = messages.lock().unwrap(); - assert_eq!(messages[0].decrypted_message_bytes, "first".as_bytes()); - assert_eq!(messages[1].decrypted_message_bytes, "second".as_bytes()); - assert_eq!(messages[2].decrypted_message_bytes, "third".as_bytes()); - assert_eq!(messages[3].decrypted_message_bytes, "fourth".as_bytes()); - - stream.end(); + assert_eq!(messages[0].decrypted_message_bytes, b"first"); + assert_eq!(messages[1].decrypted_message_bytes, b"second"); + assert_eq!(messages[2].decrypted_message_bytes, 
b"third"); + assert_eq!(messages[3].decrypted_message_bytes, b"fourth"); } #[tokio::test(flavor = "multi_thread", worker_threads = 10)] async fn test_stream_all_messages_changing_group_list() { - let alix = ClientBuilder::new_test_client(&generate_local_wallet()).await; + let alix = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); let bo = ClientBuilder::new_test_client(&generate_local_wallet()).await; let caro = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); @@ -487,27 +490,22 @@ mod tests { .await .unwrap(); - tokio::time::sleep(std::time::Duration::from_millis(100)).await; - let messages: Arc>> = Arc::new(Mutex::new(Vec::new())); let messages_clone = messages.clone(); - let stream = + let delivery = Delivery::new(); + let delivery_pointer = delivery.clone(); + let mut handle = Client::::stream_all_messages_with_callback(caro.clone(), move |message| { - let text = String::from_utf8(message.decrypted_message_bytes.clone()) - .unwrap_or("".to_string()); - println!("Received: {}", text); + delivery_pointer.notify_one(); (*messages_clone.lock().unwrap()).push(message); - }) - .await - .unwrap(); - - tokio::time::sleep(std::time::Duration::from_millis(50)).await; + }); + handle.wait_for_ready().await; alix_group .send_message("first".as_bytes(), &alix) .await .unwrap(); - tokio::time::sleep(std::time::Duration::from_millis(100)).await; + delivery.wait_for_delivery().await.unwrap(); let bo_group = bo .create_group(None, GroupMetadataOptions::default()) @@ -516,19 +514,18 @@ mod tests { .add_members_by_inbox_id(&bo, vec![caro.inbox_id()]) .await .unwrap(); - tokio::time::sleep(std::time::Duration::from_millis(300)).await; bo_group .send_message("second".as_bytes(), &bo) .await .unwrap(); - tokio::time::sleep(std::time::Duration::from_millis(100)).await; + delivery.wait_for_delivery().await.unwrap(); alix_group .send_message("third".as_bytes(), &alix) .await .unwrap(); - 
tokio::time::sleep(std::time::Duration::from_millis(100)).await; + delivery.wait_for_delivery().await.unwrap(); let alix_group_2 = alix .create_group(None, GroupMetadataOptions::default()) @@ -537,31 +534,31 @@ mod tests { .add_members_by_inbox_id(&alix, vec![caro.inbox_id()]) .await .unwrap(); - tokio::time::sleep(std::time::Duration::from_millis(300)).await; alix_group .send_message("fourth".as_bytes(), &alix) .await .unwrap(); - tokio::time::sleep(std::time::Duration::from_millis(100)).await; + delivery.wait_for_delivery().await.unwrap(); + alix_group_2 .send_message("fifth".as_bytes(), &alix) .await .unwrap(); - - tokio::time::sleep(std::time::Duration::from_millis(100)).await; + delivery.wait_for_delivery().await.unwrap(); { let messages = messages.lock().unwrap(); assert_eq!(messages.len(), 5); } - stream.end(); - tokio::time::sleep(std::time::Duration::from_millis(50)).await; - assert!(stream.is_closed()); + let a = handle.handle.abort_handle(); + a.abort(); + let _ = handle.handle.await; + assert!(a.is_finished()); alix_group - .send_message("first".as_bytes(), &alix) + .send_message("should not show up".as_bytes(), &alix) .await .unwrap(); tokio::time::sleep(std::time::Duration::from_millis(100)).await; @@ -570,34 +567,96 @@ mod tests { assert_eq!(messages.len(), 5); } + #[ignore] + #[tokio::test(flavor = "multi_thread", worker_threads = 10)] + async fn test_stream_all_messages_does_not_lose_messages() { + let alix = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); + let caro = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); + + let alix_group = alix + .create_group(None, GroupMetadataOptions::default()) + .unwrap(); + alix_group + .add_members_by_inbox_id(&alix, vec![caro.inbox_id()]) + .await + .unwrap(); + + let messages: Arc>> = Arc::new(Mutex::new(Vec::new())); + let messages_clone = messages.clone(); + + let blocked = Arc::new(AtomicU64::new(55)); + + let blocked_pointer = blocked.clone(); + let mut 
handle = + Client::::stream_all_messages_with_callback(caro.clone(), move |message| { + (*messages_clone.lock().unwrap()).push(message); + blocked_pointer.fetch_sub(1, Ordering::SeqCst); + }); + handle.wait_for_ready().await; + + let alix_group_pointer = alix_group.clone(); + let alix_pointer = alix.clone(); + tokio::spawn(async move { + for _ in 0..50 { + alix_group_pointer + .send_message(b"spam", &alix_pointer) + .await + .unwrap(); + tokio::time::sleep(std::time::Duration::from_micros(200)).await; + } + }); + + for _ in 0..5 { + let new_group = alix + .create_group(None, GroupMetadataOptions::default()) + .unwrap(); + new_group + .add_members_by_inbox_id(&alix, vec![caro.inbox_id()]) + .await + .unwrap(); + new_group + .send_message(b"spam from new group", &alix) + .await + .unwrap(); + } + + let _ = tokio::time::timeout(std::time::Duration::from_secs(120), async { + while blocked.load(Ordering::SeqCst) > 0 { + tokio::task::yield_now().await; + } + }) + .await; + + let missed_messages = blocked.load(Ordering::SeqCst); + if missed_messages > 0 { + println!("Missed {} Messages", missed_messages); + panic!("Test failed due to missed messages"); + } + } + #[tokio::test(flavor = "multi_thread")] async fn test_self_group_creation() { let alix = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); let bo = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); let groups = Arc::new(Mutex::new(Vec::new())); - let notify = Arc::new(Notify::new()); + let notify = Delivery::new(); let (notify_pointer, groups_pointer) = (notify.clone(), groups.clone()); - let closer = Client::::stream_conversations_with_callback( - alix.clone(), - move |g| { + let closer = + Client::::stream_conversations_with_callback(alix.clone(), move |g| { let mut groups = groups_pointer.lock().unwrap(); groups.push(g); notify_pointer.notify_one(); - }, - || {}, - ) - .unwrap(); + }); alix.create_group(None, GroupMetadataOptions::default()) .unwrap(); - 
tokio::time::timeout(std::time::Duration::from_secs(60), async { - notify.notified().await - }) - .await - .expect("Stream never received group"); + notify + .wait_for_delivery() + .await + .expect("Stream never received group"); { let grps = groups.lock().unwrap(); @@ -612,17 +671,13 @@ mod tests { .await .unwrap(); - tokio::time::timeout(std::time::Duration::from_secs(60), async { - notify.notified().await - }) - .await - .expect("Stream never received group"); + notify.wait_for_delivery().await.unwrap(); { let grps = groups.lock().unwrap(); assert_eq!(grps.len(), 2); } - closer.end(); + closer.handle.abort(); } } diff --git a/xmtp_mls/src/utils/test.rs b/xmtp_mls/src/utils/test.rs index cffc96d67..584c93ebe 100644 --- a/xmtp_mls/src/utils/test.rs +++ b/xmtp_mls/src/utils/test.rs @@ -4,6 +4,8 @@ use rand::{ distributions::{Alphanumeric, DistString}, Rng, }; +use std::sync::Arc; +use tokio::{sync::Notify, time::error::Elapsed}; use xmtp_api_grpc::grpc_api_helper::Client as GrpcClient; use xmtp_id::associations::{generate_inbox_id, RecoverableEcdsaSignature}; @@ -111,6 +113,31 @@ impl ClientBuilder { } } +/// wrapper over a `Notify` with a 60-second timeout for waiting +#[derive(Clone, Default)] +pub struct Delivery { + notify: Arc, +} + +impl Delivery { + pub fn new() -> Self { + Self { + notify: Arc::new(Notify::new()), + } + } + + pub async fn wait_for_delivery(&self) -> Result<(), Elapsed> { + tokio::time::timeout(std::time::Duration::from_secs(60), async { + self.notify.notified().await + }) + .await + } + + pub fn notify_one(&self) { + self.notify.notify_one() + } +} + impl Client { pub async fn is_registered(&self, address: &String) -> bool { let ids = self