From 1ea02af9800cbfee0b39aff2192d9b2107080a94 Mon Sep 17 00:00:00 2001 From: Andrew Plaza Date: Fri, 20 Dec 2024 12:42:16 -0500 Subject: [PATCH 1/3] Fix NotFound string errors --- Cargo.lock | 4 +- xmtp_debug/src/args.rs | 11 +++ xmtp_mls/src/client.rs | 55 ++++++--------- xmtp_mls/src/groups/device_sync.rs | 26 ++++--- xmtp_mls/src/groups/mod.rs | 6 +- xmtp_mls/src/storage/encrypted_store/group.rs | 12 ++-- .../storage/encrypted_store/group_intent.rs | 69 ++++++++----------- .../storage/encrypted_store/refresh_state.rs | 46 ++++++++----- .../encrypted_store/sqlcipher_connection.rs | 8 +-- xmtp_mls/src/storage/errors.rs | 47 +++++++++++-- 10 files changed, 158 insertions(+), 126 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cf5f4a245..f13a17713 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3211,7 +3211,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" dependencies = [ "cfg-if", - "windows-targets 0.48.5", + "windows-targets 0.52.6", ] [[package]] @@ -6899,7 +6899,7 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.59.0", ] [[package]] diff --git a/xmtp_debug/src/args.rs b/xmtp_debug/src/args.rs index 3c2cf67c5..a0241216b 100644 --- a/xmtp_debug/src/args.rs +++ b/xmtp_debug/src/args.rs @@ -154,6 +154,17 @@ pub enum EntityKind { Identity, } +impl std::fmt::Display for EntityKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use EntityKind::*; + match self { + Group => write!(f, "group"), + Message => write!(f, "message"), + Identity => write!(f, "identity"), + } + } +} + /// specify the log output #[derive(Args, Debug)] pub struct LogOptions { diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index 02327f11b..d1041dfcf 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -50,7 +50,7 @@ use crate::{ group_message::StoredGroupMessage, refresh_state::EntityKind, wallet_addresses::WalletEntry, - EncryptedMessageStore, StorageError, + EncryptedMessageStore, NotFound, StorageError, }, subscriptions::{LocalEventError, LocalEvents}, types::InstallationId, @@ -108,6 +108,12 @@ pub enum ClientError { Generic(String), } +impl From for ClientError { + fn from(value: NotFound) -> Self { + ClientError::Storage(StorageError::NotFound(value)) + } +} + impl From for ClientError { fn from(err: GroupError) -> ClientError { ClientError::Group(Box::new(err)) @@ -309,11 +315,7 @@ where address: String, ) -> Result, ClientError> { let results = self.find_inbox_ids_from_addresses(conn, &[address]).await?; - if let Some(first_result) = results.into_iter().next() { - Ok(first_result) - } else { - Ok(None) - } + Ok(results.into_iter().next().flatten()) } /// Calls the server to look up the `inbox_id`s` associated with a list of addresses. @@ -556,10 +558,9 @@ where { Some(id) => id, None => { - return Err(ClientError::Storage(StorageError::NotFound(format!( - "inbox id for address {} not found", - account_address - )))) + return Err(ClientError::Storage(StorageError::NotFound( + NotFound::InboxIdForAddress(account_address), + ))); } }; @@ -610,13 +611,10 @@ where group_id: Vec, ) -> Result, ClientError> { let stored_group: Option = conn.fetch(&group_id)?; - match stored_group { - Some(group) => Ok(MlsGroup::new(self.clone(), group.id, group.created_at_ns)), - None => Err(ClientError::Storage(StorageError::NotFound(format!( - "group {}", - hex::encode(group_id) - )))), - } + stored_group + .map(|g| MlsGroup::new(self.clone(), g.id, g.created_at_ns)) + .ok_or(NotFound::GroupById(group_id)) + .map_err(Into::into) } /// Look up a group by its ID @@ -638,17 +636,10 @@ where target_inbox_id: String, ) -> Result, ClientError> { let conn = self.store().conn()?; - match conn.find_dm_group(&target_inbox_id)? { - Some(dm_group) => Ok(MlsGroup::new( - self.clone(), - dm_group.id, - dm_group.created_at_ns, - )), - None => Err(ClientError::Storage(StorageError::NotFound(format!( - "dm_target_inbox_id {}", - hex::encode(target_inbox_id) - )))), - } + let group = conn + .find_dm_group(&target_inbox_id)? + .ok_or(NotFound::DmByInbox(target_inbox_id))?; + Ok(MlsGroup::new(self.clone(), group.id, group.created_at_ns)) } /// Look up a message by its ID @@ -656,13 +647,7 @@ where pub fn message(&self, message_id: Vec) -> Result { let conn = &mut self.store().conn()?; let message = conn.get_group_message(&message_id)?; - match message { - Some(message) => Ok(message), - None => Err(ClientError::Storage(StorageError::NotFound(format!( - "message {}", - hex::encode(message_id) - )))), - } + Ok(message.ok_or(NotFound::MessageById(message_id))?) } /// Query for groups with optional filters diff --git a/xmtp_mls/src/groups/device_sync.rs b/xmtp_mls/src/groups/device_sync.rs index 070caabe6..43660e643 100644 --- a/xmtp_mls/src/groups/device_sync.rs +++ b/xmtp_mls/src/groups/device_sync.rs @@ -6,11 +6,9 @@ use crate::{ configuration::NS_IN_HOUR, storage::{ consent_record::StoredConsentRecord, - group::StoredGroup, - group::{ConversationType, GroupQueryArgs}, - group_message::MsgQueryArgs, - group_message::{GroupMessageKind, StoredGroupMessage}, - DbConnection, StorageError, + group::{ConversationType, GroupQueryArgs, StoredGroup}, + group_message::{GroupMessageKind, MsgQueryArgs, StoredGroupMessage}, + DbConnection, NotFound, StorageError, }, subscriptions::{LocalEvents, StreamMessages, SubscribeError, SyncMessage}, xmtp_openmls_provider::XmtpOpenMlsProvider, @@ -115,6 +113,12 @@ impl RetryableError for DeviceSyncError { } } +impl From for DeviceSyncError { + fn from(value: NotFound) -> Self { + DeviceSyncError::Storage(StorageError::NotFound(value)) + } +} + impl Client where ApiClient: XmtpApi + Send + Sync + 'static, @@ -211,9 +215,9 @@ where retry, (async { conn.get_group_message(&message_id)? - .ok_or(DeviceSyncError::Storage(StorageError::NotFound(format!( - "Message id {message_id:?} not found." - )))) + .ok_or(DeviceSyncError::from(NotFound::MessageById( + message_id.clone(), + ))) }) )?; @@ -240,9 +244,9 @@ where retry, (async { conn.get_group_message(&message_id)? - .ok_or(DeviceSyncError::Storage(StorageError::NotFound(format!( - "Message id {message_id:?} not found." - )))) + .ok_or(DeviceSyncError::from(NotFound::MessageById( + message_id.clone(), + ))) }) )?; diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index 3c1320f6c..5f0bfa22c 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -58,7 +58,7 @@ use self::{ intents::IntentError, validated_commit::CommitValidationError, }; -use crate::storage::{group_message::ContentType, StorageError}; +use crate::storage::{group_message::ContentType, NotFound, StorageError}; use xmtp_common::time::now_ns; use xmtp_proto::xmtp::mls::{ api::v1::{ @@ -418,7 +418,9 @@ impl MlsGroup { let mls_group = OpenMlsGroup::load(provider.storage(), &GroupId::from_slice(&self.group_id)) .map_err(crate::StorageError::from)? - .ok_or(crate::StorageError::NotFound("Group Not Found".into()))?; + .ok_or(StorageError::from(NotFound::GroupById( + self.group_id.to_vec(), + )))?; // Perform the operation with the MLS group operation(mls_group).await.map_err(Into::into) diff --git a/xmtp_mls/src/storage/encrypted_store/group.rs b/xmtp_mls/src/storage/encrypted_store/group.rs index 010eab713..547f85134 100644 --- a/xmtp_mls/src/storage/encrypted_store/group.rs +++ b/xmtp_mls/src/storage/encrypted_store/group.rs @@ -5,7 +5,7 @@ use super::{ schema::groups::{self, dsl}, Sqlite, }; -use crate::{impl_fetch, impl_store, DuplicateItem, StorageError}; +use crate::{impl_fetch, impl_store, storage::NotFound, DuplicateItem, StorageError}; use diesel::{ backend::Backend, deserialize::{self, FromSql, FromSqlRow}, @@ -379,9 +379,8 @@ impl DbConnection { Ok::, StorageError>(ts) })?; - last_ts.ok_or(StorageError::NotFound(format!( - "installation time for group {}", - hex::encode(group_id) + last_ts.ok_or(StorageError::NotFound(NotFound::InstallationTimeForGroup( + group_id, ))) } @@ -407,10 +406,7 @@ impl DbConnection { Ok::<_, StorageError>(ts) })?; - last_ts.ok_or(StorageError::NotFound(format!( - "installation time for group {}", - hex::encode(group_id) - ))) + last_ts.ok_or(NotFound::InstallationTimeForGroup(group_id).into()) } /// Updates the 'last time checked' we checked for new installations. diff --git a/xmtp_mls/src/storage/encrypted_store/group_intent.rs b/xmtp_mls/src/storage/encrypted_store/group_intent.rs index 70edb0956..743eccc5e 100644 --- a/xmtp_mls/src/storage/encrypted_store/group_intent.rs +++ b/xmtp_mls/src/storage/encrypted_store/group_intent.rs @@ -17,7 +17,7 @@ use super::{ use crate::{ groups::intents::{IntentError, SendMessageIntentData}, impl_fetch, impl_store, - storage::StorageError, + storage::{NotFound, StorageError}, utils::id::calculate_message_id, Delete, }; @@ -197,7 +197,7 @@ impl DbConnection { staged_commit: Option>, published_in_epoch: i64, ) -> Result<(), StorageError> { - let res = self.raw_query(|conn| { + let rows_changed = self.raw_query(|conn| { diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) // State machine requires that the only valid state transition to Published is from @@ -213,30 +213,25 @@ impl DbConnection { .execute(conn) })?; - match res { - // If nothing matched the query, check if its already published, otherwise return an error. Either ID or state was wrong - 0 => { - let already_published = self.raw_query(|conn| { - dsl::group_intents - .filter(dsl::id.eq(intent_id)) - .first::(conn) - }); - - if already_published.is_ok() { - Ok(()) - } else { - Err(StorageError::NotFound(format!( - "Published intent {intent_id} for commit" - ))) - } + if rows_changed == 0 { + let already_published = self.raw_query(|conn| { + dsl::group_intents + .filter(dsl::id.eq(intent_id)) + .first::(conn) + }); + + if already_published.is_ok() { + return Ok(()); + } else { + return Err(NotFound::IntentForToPublish(intent_id).into()); } - _ => Ok(()), } + Ok(()) } // Set the intent with the given ID to `Committed` pub fn set_group_intent_committed(&self, intent_id: ID) -> Result<(), StorageError> { - let res = self.raw_query(|conn| { + let rows_changed = self.raw_query(|conn| { diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) // State machine requires that the only valid state transition to Committed is from @@ -246,19 +241,18 @@ impl DbConnection { .execute(conn) })?; - match res { - // If nothing matched the query, return an error. Either ID or state was wrong - 0 => Err(StorageError::NotFound(format!( - "Published intent {intent_id} for commit" - ))), - _ => Ok(()), + // If nothing matched the query, return an error. Either ID or state was wrong + if rows_changed == 0 { + return Err(NotFound::IntentForCommitted(intent_id).into()); } + + Ok(()) } // Set the intent with the given ID to `ToPublish`. Wipe any values for `payload_hash` and // `post_commit_data` pub fn set_group_intent_to_publish(&self, intent_id: ID) -> Result<(), StorageError> { - let res = self.raw_query(|conn| { + let rows_changed = self.raw_query(|conn| { diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) // State machine requires that the only valid state transition to ToPublish is from @@ -275,32 +269,27 @@ impl DbConnection { .execute(conn) })?; - match res { - // If nothing matched the query, return an error. Either ID or state was wrong - 0 => Err(StorageError::NotFound(format!( - "Published intent {intent_id} for ToPublish" - ))), - _ => Ok(()), + if rows_changed == 0 { + return Err(NotFound::IntentForPublish(intent_id).into()); } + Ok(()) } /// Set the intent with the given ID to `Error` #[tracing::instrument(level = "trace", skip(self))] pub fn set_group_intent_error(&self, intent_id: ID) -> Result<(), StorageError> { - let res = self.raw_query(|conn| { + let rows_changed = self.raw_query(|conn| { diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) .set(dsl::state.eq(IntentState::Error)) .execute(conn) })?; - match res { - // If nothing matched the query, return an error. Either ID or state was wrong - 0 => Err(StorageError::NotFound(format!( - "state for intent {intent_id}" - ))), - _ => Ok(()), + if rows_changed == 0 { + return Err(NotFound::IntentById(intent_id).into()); } + + Ok(()) } // Simple lookup of intents by payload hash, meant to be used when processing messages off the diff --git a/xmtp_mls/src/storage/encrypted_store/refresh_state.rs b/xmtp_mls/src/storage/encrypted_store/refresh_state.rs index f5cf0ba33..b1cfefcb0 100644 --- a/xmtp_mls/src/storage/encrypted_store/refresh_state.rs +++ b/xmtp_mls/src/storage/encrypted_store/refresh_state.rs @@ -8,7 +8,11 @@ use diesel::{ }; use super::{db_connection::DbConnection, schema::refresh_state, Sqlite}; -use crate::{impl_store, impl_store_or_ignore, storage::StorageError, StoreOrIgnore}; +use crate::{ + impl_store, impl_store_or_ignore, + storage::{NotFound, StorageError}, + StoreOrIgnore, +}; #[repr(i32)] #[derive(Debug, Clone, Copy, PartialEq, Eq, AsExpression, Hash, FromSqlRow)] @@ -18,6 +22,16 @@ pub enum EntityKind { Group = 2, } +impl std::fmt::Display for EntityKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use EntityKind::*; + match self { + Welcome => write!(f, "welcome"), + Group => write!(f, "group"), + } + } +} + impl ToSql for EntityKind where i32: ToSql, @@ -96,24 +110,18 @@ impl DbConnection { entity_kind: EntityKind, cursor: i64, ) -> Result { - let state: Option = self.get_refresh_state(&entity_id, entity_kind)?; - match state { - Some(state) => { - use super::schema::refresh_state::dsl; - let num_updated = self.raw_query(|conn| { - diesel::update(&state) - .filter(dsl::cursor.lt(cursor)) - .set(dsl::cursor.eq(cursor)) - .execute(conn) - })?; - Ok(num_updated == 1) - } - None => Err(StorageError::NotFound(format!( - "state for entity ID {} with kind {:?}", - hex::encode(entity_id.as_ref()), - entity_kind - ))), - } + use super::schema::refresh_state::dsl; + let state: RefreshState = self.get_refresh_state(&entity_id, entity_kind)?.ok_or( + NotFound::RefreshStateByIdAndKind(entity_id.as_ref().to_vec(), entity_kind), + )?; + + let num_updated = self.raw_query(|conn| { + diesel::update(&state) + .filter(dsl::cursor.lt(cursor)) + .set(dsl::cursor.eq(cursor)) + .execute(conn) + })?; + Ok(num_updated == 1) } } diff --git a/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs b/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs index fe9350b26..6723f0df1 100644 --- a/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/sqlcipher_connection.rs @@ -12,7 +12,7 @@ use std::{ path::{Path, PathBuf}, }; -use crate::storage::StorageError; +use crate::storage::{NotFound, StorageError}; use super::{EncryptionKey, StorageOption}; @@ -165,9 +165,9 @@ impl EncryptedConnection { ) -> Result<(), StorageError> { let mut row_iter = conn.load(sql_query("PRAGMA cipher_salt"))?; // cipher salt should always exist. if it doesn't SQLCipher is misconfigured. - let row = row_iter.next().ok_or(StorageError::NotFound( - "Cipher salt doesn't exist in database".into(), - ))??; + let row = row_iter + .next() + .ok_or(NotFound::CipherSalt(path.to_string()))??; let salt = >::build_from_row(&row)?; tracing::debug!( salt, diff --git a/xmtp_mls/src/storage/errors.rs b/xmtp_mls/src/storage/errors.rs index 3cc3df2a8..8351acbe6 100644 --- a/xmtp_mls/src/storage/errors.rs +++ b/xmtp_mls/src/storage/errors.rs @@ -3,7 +3,10 @@ use std::sync::PoisonError; use diesel::result::DatabaseErrorKind; use thiserror::Error; -use super::sql_key_store::{self, SqlKeyStoreError}; +use super::{ + refresh_state::EntityKind, + sql_key_store::{self, SqlKeyStoreError}, +}; use crate::groups::intents::IntentError; use xmtp_common::{retryable, RetryableError}; @@ -27,10 +30,9 @@ pub enum StorageError { Serialization(String), #[error("deserialization error")] Deserialization(String), - // TODO:insipx Make NotFound into an enum of possible items that may not be found - #[error("{0} not found")] - NotFound(String), - #[error("lock")] + #[error(transparent)] + NotFound(#[from] NotFound), + #[error("lock {0}")] Lock(String), #[error("Pool needs to reconnect before use")] PoolNeedsConnection, @@ -50,6 +52,35 @@ pub enum StorageError { OpenMlsStorage(#[from] SqlKeyStoreError), } +#[derive(Error, Debug)] +// Monolithic enum for all things lost +pub enum NotFound { + #[error("group with welcome id {0} not found")] + GroupByWelcome(i64), + #[error("group with id {id} not found", id = hex::encode(_0))] + GroupById(Vec), + #[error("installation time for group {id}", id = hex::encode(_0))] + InstallationTimeForGroup(Vec), + #[error("inbox id for address {0} not found")] + InboxIdForAddress(String), + #[error("message id {id} not found", id = hex::encode(_0))] + MessageById(Vec), + #[error("dm by dm_target_inbox_id {0} not found")] + DmByInbox(String), + #[error("intent with id {0} for state Publish from ToPublish not found")] + IntentForToPublish(i32), + #[error("intent with id {0} for state ToPublish from Published not found")] + IntentForPublish(i32), + #[error("intent with id {0} for state Committed from Published not found")] + IntentForCommitted(i32), + #[error("Intent with id {0} not found")] + IntentById(i32), + #[error("refresh state with id {id} and kind {1} not found", id = hex::encode(_0))] + RefreshStateByIdAndKind(Vec, EntityKind), + #[error("Cipher salt for db at [`{0}`] not found")] + CipherSalt(String), +} + #[derive(Error, Debug)] pub enum DuplicateItem { #[error("the welcome id {0:?} already exists")] @@ -105,6 +136,12 @@ impl RetryableError for StorageError { } } +impl RetryableError for NotFound { + fn is_retryable(&self) -> bool { + true + } +} + // OpenMLS KeyStore errors impl RetryableError for openmls::group::AddMembersError { fn is_retryable(&self) -> bool { From 4e2cb59f66c0d9aaa8a06defe37c08ae00e576d4 Mon Sep 17 00:00:00 2001 From: Andrew Plaza Date: Fri, 20 Dec 2024 12:43:32 -0500 Subject: [PATCH 2/3] feat(wasm): unblock streams in the browser --- Cargo.toml | 23 +- common/src/test.rs | 4 + xmtp_api_grpc/Cargo.toml | 2 +- xmtp_api_http/Cargo.toml | 5 +- xmtp_api_http/src/http_stream.rs | 231 +++++++++++++++++ xmtp_api_http/src/lib.rs | 4 +- xmtp_api_http/src/util.rs | 83 +----- xmtp_mls/Cargo.toml | 3 +- xmtp_mls/src/api/mls.rs | 12 +- xmtp_mls/src/storage/encrypted_store/group.rs | 53 ++++ .../mod.rs} | 189 +++++--------- xmtp_mls/src/subscriptions/stream_all.rs | 87 +++++++ .../src/subscriptions/stream_conversations.rs | 241 ++++++++++++++++++ 13 files changed, 711 insertions(+), 226 deletions(-) create mode 100644 xmtp_api_http/src/http_stream.rs rename xmtp_mls/src/{subscriptions.rs => subscriptions/mod.rs} (88%) create mode 100644 xmtp_mls/src/subscriptions/stream_all.rs create mode 100644 xmtp_mls/src/subscriptions/stream_conversations.rs diff --git a/Cargo.toml b/Cargo.toml index b46b42352..d34d8d219 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,8 +37,7 @@ ctor = "0.2" ed25519 = "2.2.3" ed25519-dalek = { version = "2.1.1", features = ["zeroize"] } ethers = { version = "2.0", default-features = false } -futures = "0.3.30" -futures-core = "0.3.30" +futures = { version = "0.3.30", default-features = false } getrandom = { version = "0.2", default-features = false } hex = "0.4.3" hkdf = "0.12.3" @@ -61,16 +60,7 @@ thiserror = "2.0" tls_codec = "0.4.1" tokio = { version = "1.35.1", default-features = false } uuid = "1.10" -wasm-timer = "0.2" web-time = "1.1" -# Changing this version and rustls may potentially break the android build. Use Caution. -# Test with Android and Swift first. -# Its probably preferable to one day use https://github.com/rustls/rustls-platform-verifier -# Until then, always test agains iOS/Android after updating these dependencies & making a PR -# Related Issues: -# - https://github.com/seanmonstar/reqwest/issues/2159 -# - https://github.com/hyperium/tonic/pull/1974 -# - https://github.com/rustls/rustls-platform-verifier/issues/58 bincode = "1.3" console_error_panic_hook = "0.1" const_format = "0.2" @@ -87,6 +77,14 @@ openssl = { version = "0.10", features = ["vendored"] } openssl-sys = { version = "0.9", features = ["vendored"] } parking_lot = "0.12.3" sqlite-web = "0.0.1" +# Changing this version and rustls may potentially break the android build. Use Caution. +# Test with Android and Swift first. +# Its probably preferable to one day use https://github.com/rustls/rustls-platform-verifier +# Until then, always test agains iOS/Android after updating these dependencies & making a PR +# Related Issues: +# - https://github.com/seanmonstar/reqwest/issues/2159 +# - https://github.com/hyperium/tonic/pull/1974 +# - https://github.com/rustls/rustls-platform-verifier/issues/58 tonic = { version = "0.12", default-features = false } tracing = { version = "0.1", features = ["log"] } tracing-subscriber = { version = "0.3", default-features = false } @@ -101,7 +99,8 @@ criterion = { version = "0.5", features = [ "html_reports", "async_tokio", ]} - once_cell = "1.2" +once_cell = "1.2" +pin-project-lite = "0.2" # Internal Crate Dependencies xmtp_api_grpc = { path = "xmtp_api_grpc" } diff --git a/common/src/test.rs b/common/src/test.rs index 4cfb2442d..e8ae377a6 100644 --- a/common/src/test.rs +++ b/common/src/test.rs @@ -108,6 +108,10 @@ pub fn rand_u64() -> u64 { crypto_utils::rng().gen() } +pub fn rand_i64() -> i64 { + crypto_utils::rng().gen() +} + #[cfg(not(target_arch = "wasm32"))] pub fn tmp_path() -> String { let db_name = crate::rand_string::<24>(); diff --git a/xmtp_api_grpc/Cargo.toml b/xmtp_api_grpc/Cargo.toml index b69a0d0c4..67ea6fb16 100644 --- a/xmtp_api_grpc/Cargo.toml +++ b/xmtp_api_grpc/Cargo.toml @@ -8,7 +8,7 @@ version.workspace = true async-stream.workspace = true async-trait = "0.1" base64.workspace = true -futures.workspace = true +futures = { workspace = true, features = ["alloc"] } hex.workspace = true prost = { workspace = true, features = ["prost-derive"] } tokio = { workspace = true, features = ["macros", "time"] } diff --git a/xmtp_api_http/Cargo.toml b/xmtp_api_http/Cargo.toml index b26a414a9..09a6a9214 100644 --- a/xmtp_api_http/Cargo.toml +++ b/xmtp_api_http/Cargo.toml @@ -8,16 +8,17 @@ license.workspace = true crate-type = ["cdylib", "rlib"] [dependencies] -async-stream.workspace = true futures = { workspace = true } tracing.workspace = true reqwest = { version = "0.12.5", features = ["json", "stream"] } serde = { workspace = true } serde_json = { workspace = true } -thiserror = "2.0" +thiserror.workspace = true tokio = { workspace = true, features = ["sync", "rt", "macros"] } xmtp_proto = { path = "../xmtp_proto", features = ["proto_full"] } async-trait = "0.1" +bytes = "1.9" +pin-project-lite = "0.2.15" [dev-dependencies] xmtp_proto = { path = "../xmtp_proto", features = ["test-utils"] } diff --git a/xmtp_api_http/src/http_stream.rs b/xmtp_api_http/src/http_stream.rs new file mode 100644 index 000000000..8e969b0c4 --- /dev/null +++ b/xmtp_api_http/src/http_stream.rs @@ -0,0 +1,231 @@ +//! Streams that work with HTTP POST requests + +use crate::util::GrpcResponse; +use futures::{ + stream::{self, Stream, StreamExt}, + Future, +}; +use reqwest::Response; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use serde_json::Deserializer; +use std::{marker::PhantomData, pin::Pin, task::Poll}; +use xmtp_proto::{Error, ErrorKind}; + +#[derive(Deserialize, Serialize, Debug)] +pub(crate) struct SubscriptionItem { + pub result: T, +} + +#[cfg(target_arch = "wasm32")] +pub type BytesStream = stream::LocalBoxStream<'static, Result>; + +// #[cfg(not(target_arch = "wasm32"))] +// pub type BytesStream = Pin> + Send>>; + +#[cfg(not(target_arch = "wasm32"))] +pub type BytesStream = stream::BoxStream<'static, Result>; + +pin_project_lite::pin_project! { + #[project = PostStreamProject] + enum HttpPostStream { + NotStarted{#[pin] fut: F}, + // `Reqwest::bytes_stream` returns `impl Stream` rather than a type generic, + // so we can't use a type generic here + // this makes wasm a bit tricky. + Started { + #[pin] http: BytesStream, + remaining: Vec, + _marker: PhantomData, + }, + } +} + +impl Stream for HttpPostStream +where + F: Future>, + for<'de> R: Send + Deserialize<'de>, +{ + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + use std::task::Poll::*; + match self.as_mut().project() { + PostStreamProject::NotStarted { fut } => match fut.poll(cx) { + Ready(response) => { + let s = response.unwrap().bytes_stream(); + self.set(Self::started(s)); + self.as_mut().poll_next(cx) + } + Pending => { + cx.waker().wake_by_ref(); + Pending + } + }, + PostStreamProject::Started { + ref mut http, + ref mut remaining, + .. + } => { + let mut pinned = std::pin::pin!(http); + let next = pinned.as_mut().poll_next(cx); + Self::on_bytes(next, remaining, cx) + } + } + } +} + +impl HttpPostStream +where + R: Send, +{ + #[cfg(not(target_arch = "wasm32"))] + fn started( + http: impl Stream> + Send + 'static, + ) -> Self { + Self::Started { + http: http.boxed(), + remaining: Vec::new(), + _marker: PhantomData, + } + } + + #[cfg(target_arch = "wasm32")] + fn started(http: impl Stream> + 'static) -> Self { + Self::Started { + http: http.boxed_local(), + remaining: Vec::new(), + _marker: PhantomData, + } + } +} + +impl HttpPostStream +where + F: Future>, + for<'de> R: Deserialize<'de> + DeserializeOwned + Send, +{ + fn new(request: F) -> Self { + Self::NotStarted { fut: request } + } + + fn on_bytes( + p: Poll>>, + remaining: &mut Vec, + cx: &mut std::task::Context<'_>, + ) -> Poll::Item>> { + use futures::task::Poll::*; + match p { + Ready(Some(bytes)) => { + let bytes = bytes.map_err(|e| { + Error::new(ErrorKind::SubscriptionUpdateError).with(e.to_string()) + })?; + let bytes = &[remaining.as_ref(), bytes.as_ref()].concat(); + let de = Deserializer::from_slice(bytes); + let mut stream = de.into_iter::>(); + 'messages: loop { + tracing::debug!("Waiting on next response ..."); + let response = stream.next(); + let res = match response { + Some(Ok(GrpcResponse::Ok(response))) => Ok(response), + Some(Ok(GrpcResponse::SubscriptionItem(item))) => Ok(item.result), + Some(Ok(GrpcResponse::Err(e))) => { + Err(Error::new(ErrorKind::MlsError).with(e.message)) + } + Some(Err(e)) => { + if e.is_eof() { + *remaining = (&**bytes)[stream.byte_offset()..].to_vec(); + return Pending; + } else { + Err(Error::new(ErrorKind::MlsError).with(e.to_string())) + } + } + Some(Ok(GrpcResponse::Empty {})) => continue 'messages, + None => return Ready(None), + }; + return Ready(Some(res)); + } + } + Ready(None) => Ready(None), + Pending => { + cx.waker().wake_by_ref(); + Pending + } + } + } +} + +#[cfg(not(target_arch = "wasm32"))] +impl HttpPostStream +where + F: Future> + Unpin, + for<'de> R: Deserialize<'de> + DeserializeOwned + Send, +{ + /// Establish the initial HTTP Stream connection + fn establish(&mut self) -> () { + // we need to poll the future once to progress the future state & + // establish the initial POST request. + // It should always be pending + let noop_waker = futures::task::noop_waker(); + let mut cx = std::task::Context::from_waker(&noop_waker); + // let mut this = Pin::new(self); + let mut this = Pin::new(self); + let _ = this.poll_next_unpin(&mut cx); + } +} + +#[cfg(target_arch = "wasm32")] +impl HttpPostStream +where + F: Future>, + for<'de> R: Deserialize<'de> + DeserializeOwned + Send, +{ + fn establish(&mut self) -> () { + // we need to poll the future once to progress the future state & + // establish the initial POST request. + // It should always be pending + let noop_waker = futures::task::noop_waker(); + let mut cx = std::task::Context::from_waker(&noop_waker); + let mut this = unsafe { Pin::new_unchecked(self) }; + let _ = this.poll_next_unpin(&mut cx); + } +} + +#[cfg(target_arch = "wasm32")] +pub fn create_grpc_stream( + request: T, + endpoint: String, + http_client: reqwest::Client, +) -> stream::LocalBoxStream<'static, Result> { + create_grpc_stream_inner(request, endpoint, http_client).boxed_local() +} + +#[cfg(not(target_arch = "wasm32"))] +pub fn create_grpc_stream( + request: T, + endpoint: String, + http_client: reqwest::Client, +) -> stream::BoxStream<'static, Result> +where + T: Serialize + 'static, + R: DeserializeOwned + Send + 'static, +{ + create_grpc_stream_inner(request, endpoint, http_client).boxed() +} + +fn create_grpc_stream_inner( + request: T, + endpoint: String, + http_client: reqwest::Client, +) -> impl Stream> +where + T: Serialize + 'static, + R: DeserializeOwned + Send + 'static, +{ + let request = http_client.post(endpoint).json(&request).send(); + let mut http = HttpPostStream::new(request); + http.establish(); + http +} diff --git a/xmtp_api_http/src/lib.rs b/xmtp_api_http/src/lib.rs index 80489fb3c..8a3f972c4 100755 --- a/xmtp_api_http/src/lib.rs +++ b/xmtp_api_http/src/lib.rs @@ -1,11 +1,13 @@ #![warn(clippy::unwrap_used)] pub mod constants; +mod http_stream; mod util; use futures::stream; +use http_stream::create_grpc_stream; use reqwest::header; -use util::{create_grpc_stream, handle_error}; +use util::handle_error; use xmtp_proto::api_client::{ClientWithMetadata, XmtpIdentityClient}; use xmtp_proto::xmtp::identity::api::v1::{ GetIdentityUpdatesRequest as GetIdentityUpdatesV2Request, diff --git a/xmtp_api_http/src/util.rs b/xmtp_api_http/src/util.rs index 8a839fc56..34c878c4a 100644 --- a/xmtp_api_http/src/util.rs +++ b/xmtp_api_http/src/util.rs @@ -1,9 +1,5 @@ -use futures::{ - stream::{self, StreamExt}, - Stream, -}; +use crate::http_stream::SubscriptionItem; use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_json::Deserializer; use std::io::Read; use xmtp_proto::{Error, ErrorKind}; @@ -23,11 +19,6 @@ pub(crate) struct ErrorResponse { details: Vec, } -#[derive(Deserialize, Serialize, Debug)] -pub(crate) struct SubscriptionItem { - pub result: T, -} - /// handle JSON response from gRPC, returning either /// the expected deserialized response object or a gRPC [`Error`] pub fn handle_error(reader: R) -> Result @@ -43,78 +34,6 @@ where } } -#[cfg(target_arch = "wasm32")] -pub fn create_grpc_stream< - T: Serialize + Send + 'static, - R: DeserializeOwned + Send + std::fmt::Debug + 'static, ->( - request: T, - endpoint: String, - http_client: reqwest::Client, -) -> stream::LocalBoxStream<'static, Result> { - create_grpc_stream_inner(request, endpoint, http_client).boxed_local() -} - -#[cfg(not(target_arch = "wasm32"))] -pub fn create_grpc_stream< - T: Serialize + Send + 'static, - R: DeserializeOwned + Send + std::fmt::Debug + 'static, ->( - request: T, - endpoint: String, - http_client: reqwest::Client, -) -> stream::BoxStream<'static, Result> { - create_grpc_stream_inner(request, endpoint, http_client).boxed() -} - -pub fn create_grpc_stream_inner< - T: Serialize + Send + 'static, - R: DeserializeOwned + Send + std::fmt::Debug + 'static, ->( - request: T, - endpoint: String, - http_client: reqwest::Client, -) -> impl Stream> { - async_stream::stream! { - let request = http_client - .post(endpoint) - .json(&request) - .send() - .await - .map_err(|e| Error::new(ErrorKind::MlsError).with(e))?; - - let mut remaining = vec![]; - for await bytes in request.bytes_stream() { - let bytes = bytes - .map_err(|e| Error::new(ErrorKind::SubscriptionUpdateError).with(e.to_string()))?; - let bytes = &[remaining.as_ref(), bytes.as_ref()].concat(); - let de = Deserializer::from_slice(bytes); - let mut stream = de.into_iter::>(); - 'messages: loop { - let response = stream.next(); - let res = match response { - Some(Ok(GrpcResponse::Ok(response))) => Ok(response), - Some(Ok(GrpcResponse::SubscriptionItem(item))) => Ok(item.result), - Some(Ok(GrpcResponse::Err(e))) => { - Err(Error::new(ErrorKind::MlsError).with(e.message)) - } - Some(Err(e)) => { - if e.is_eof() { - remaining = (&**bytes)[stream.byte_offset()..].to_vec(); - break 'messages; - } else { - Err(Error::new(ErrorKind::MlsError).with(e.to_string())) - } - } - Some(Ok(GrpcResponse::Empty {})) => continue 'messages, - None => break 'messages, - }; - yield res; - } - } - } -} - #[cfg(feature = "test-utils")] #[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] #[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] diff --git a/xmtp_mls/Cargo.toml b/xmtp_mls/Cargo.toml index 0e160e485..ffd6f042e 100644 --- a/xmtp_mls/Cargo.toml +++ b/xmtp_mls/Cargo.toml @@ -49,7 +49,7 @@ async-stream.workspace = true async-trait.workspace = true bincode.workspace = true diesel_migrations.workspace = true -futures.workspace = true +futures = { workspace = true, features = ["alloc"] } hex.workspace = true hkdf.workspace = true openmls_rust_crypto = { workspace = true } @@ -70,6 +70,7 @@ tracing.workspace = true trait-variant.workspace = true xmtp_common.workspace = true zeroize.workspace = true +pin-project-lite.workspace = true # XMTP/Local xmtp_content_types = { path = "../xmtp_content_types" } diff --git a/xmtp_mls/src/api/mls.rs b/xmtp_mls/src/api/mls.rs index 3994cd8fa..86b206121 100644 --- a/xmtp_mls/src/api/mls.rs +++ b/xmtp_mls/src/api/mls.rs @@ -274,10 +274,10 @@ where Ok(()) } - pub async fn subscribe_group_messages( - &self, + pub(crate) async fn subscribe_group_messages<'a>( + &'a self, filters: Vec, - ) -> Result> + '_, ApiError> + ) -> Result<::GroupMessageStream<'a>, ApiError> where ApiClient: XmtpMlsStreams, { @@ -289,11 +289,11 @@ where .await } - pub async fn subscribe_welcome_messages( - &self, + pub(crate) async fn subscribe_welcome_messages<'a>( + &'a self, installation_key: &[u8], id_cursor: Option, - ) -> Result> + '_, ApiError> + ) -> Result<::WelcomeMessageStream<'a>, ApiError> where ApiClient: XmtpMlsStreams, { diff --git a/xmtp_mls/src/storage/encrypted_store/group.rs b/xmtp_mls/src/storage/encrypted_store/group.rs index 547f85134..93f3e8e86 100644 --- a/xmtp_mls/src/storage/encrypted_store/group.rs +++ b/xmtp_mls/src/storage/encrypted_store/group.rs @@ -454,6 +454,22 @@ impl DbConnection { Ok(stored_group) } + + /// Get all the welcome ids turned into groups + pub(crate) fn group_welcome_ids(&self) -> Result, StorageError> { + self.raw_query(|conn| { + Ok::<_, StorageError>( + dsl::groups + .filter(dsl::welcome_id.is_not_null()) + .select(dsl::welcome_id) + .load::>(conn)? + .into_iter() + .map(|id| id.expect("SQL explicity filters for none")) + .collect(), + ) + }) + .map_err(Into::into) + } } #[repr(i32)] @@ -566,6 +582,25 @@ pub(crate) mod tests { ) } + /// Generate a test group with welcome + pub fn generate_group_with_welcome( + state: Option, + welcome_id: Option, + ) -> StoredGroup { + let id = rand_vec::<24>(); + let created_at_ns = now_ns(); + let membership_state = state.unwrap_or(GroupMembershipState::Allowed); + StoredGroup::new_from_welcome( + id, + created_at_ns, + membership_state, + "placeholder_address".to_string(), + welcome_id.unwrap_or(xmtp_common::rand_i64()), + ConversationType::Group, + None, + ) + } + /// Generate a test consent fn generate_consent_record( entity_type: ConsentType, @@ -852,4 +887,22 @@ pub(crate) mod tests { }) .await } + + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr(not(target_arch = "wasm32"), tokio::test)] + async fn test_get_group_welcome_ids() { + with_connection(|conn| { + let mls_groups = vec![ + generate_group_with_welcome(None, Some(30)), + generate_group(None), + generate_group(None), + generate_group_with_welcome(None, Some(10)), + ]; + for g in mls_groups.iter() { + g.store(conn).unwrap(); + } + assert_eq!(vec![30, 10], conn.group_welcome_ids().unwrap()); + }) + .await + } } diff --git a/xmtp_mls/src/subscriptions.rs b/xmtp_mls/src/subscriptions/mod.rs similarity index 88% rename from xmtp_mls/src/subscriptions.rs rename to xmtp_mls/src/subscriptions/mod.rs index 97f538504..e9b8dcad0 100644 --- a/xmtp_mls/src/subscriptions.rs +++ b/xmtp_mls/src/subscriptions/mod.rs @@ -10,6 +10,9 @@ use tracing::instrument; use xmtp_id::scw_verifier::SmartContractSignatureVerifier; use xmtp_proto::{api_client::XmtpMlsStreams, xmtp::mls::api::v1::WelcomeMessage}; +// mod stream_all; +// mod stream_conversations; + use crate::{ client::{extract_welcome_message, ClientError}, groups::{ @@ -454,9 +457,9 @@ where futures::pin_mut!(messages_stream); let convo_stream = self.stream_conversations(conversation_type).await?; - futures::pin_mut!(convo_stream); + tracing::info!("\n\n Waiting on messages \n\n"); let mut extra_messages = Vec::new(); loop { @@ -609,6 +612,45 @@ pub(crate) mod tests { use xmtp_cryptography::utils::generate_local_wallet; use xmtp_id::InboxOwner; + /// A macro for asserting that a stream yields a specific decrypted message. + /// + /// # Example + /// ```rust + /// assert_msg!(stream, b"first"); + /// ``` + #[macro_export] + macro_rules! assert_msg { + ($stream:expr, $expected:expr) => { + assert_eq!( + $stream + .next() + .await + .unwrap() + .unwrap() + .decrypted_message_bytes, + $expected.as_bytes() + ); + }; + } + + /// A macro for asserting that a stream yields a specific decrypted message. + /// + /// # Example + /// ```rust + /// assert_msg!(stream, b"first"); + /// ``` + #[macro_export] + macro_rules! assert_msg_exists { + ($stream:expr) => { + assert!(!$stream + .next() + .await + .unwrap() + .unwrap() + .decrypted_message_bytes + .is_empty()); + }; + } #[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread", worker_threads = 10))] async fn test_stream_welcomes() { let alice = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); @@ -617,23 +659,8 @@ pub(crate) mod tests { .create_group(None, GroupMetadataOptions::default()) .unwrap(); - // FIXME:insipx we run into an issue where the reqwest::post().send() request - // blocks the executor and we cannot progress the runtime if we dont `tokio::spawn` this. - // A solution might be to use `hyper` instead, and implement a custom connection pool with - // `deadpool`. This is a bit more work but shouldn't be too complicated since - // we're only using `post` requests. It would be nice for all streams to work - // w/o spawning a separate task. - let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); - let mut stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx); - let bob_ptr = bob.clone(); - crate::spawn(None, async move { - let bob_stream = bob_ptr.stream_conversations(None).await.unwrap(); - futures::pin_mut!(bob_stream); - while let Some(item) = bob_stream.next().await { - let _ = tx.send(item); - } - }); - + let stream = bob.stream_conversations(None).await.unwrap(); + futures::pin_mut!(stream); let group_id = alice_bob_group.group_id.clone(); alice_bob_group .add_members_by_inbox_id(&[bob.inbox_id()]) @@ -644,7 +671,7 @@ pub(crate) mod tests { assert_eq!(bob_received_groups.group_id, group_id); } - #[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread", worker_threads = 10))] + #[wasm_bindgen_test(unsupported = tokio::test(flavor = "current_thread"))] async fn test_stream_messages() { xmtp_common::logger(); let alice = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); @@ -653,6 +680,7 @@ pub(crate) mod tests { let alice_group = alice .create_group(None, GroupMetadataOptions::default()) .unwrap(); + tracing::info!("Group Id = [{}]", hex::encode(&alice_group.group_id)); alice_group .add_members_by_inbox_id(&[bob.inbox_id()]) @@ -664,33 +692,16 @@ pub(crate) mod tests { .unwrap(); let bob_group = bob_groups.first().unwrap(); - let notify = Delivery::new(None); - let notify_ptr = notify.clone(); - let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); - crate::spawn(None, async move { - let stream = alice_group.stream().await.unwrap(); - futures::pin_mut!(stream); - while let Some(item) = stream.next().await { - let _ = tx.send(item); - notify_ptr.notify_one(); - } - }); - - let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx); - // let stream = alice_group.stream().await.unwrap(); + let stream = alice_group.stream().await.unwrap(); futures::pin_mut!(stream); bob_group.send_message(b"hello").await.unwrap(); - tracing::debug!("Bob Sent Message!, waiting for delivery"); - // notify.wait_for_delivery().await.unwrap(); + let message = stream.next().await.unwrap().unwrap(); assert_eq!(message.decrypted_message_bytes, b"hello"); bob_group.send_message(b"hello2").await.unwrap(); - // notify.wait_for_delivery().await.unwrap(); let message = stream.next().await.unwrap().unwrap(); assert_eq!(message.decrypted_message_bytes, b"hello2"); - - // assert_eq!(bob_received_groups.group_id, alice_bob_group.group_id); } #[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread", worker_threads = 10))] @@ -714,40 +725,20 @@ pub(crate) mod tests { .add_members_by_inbox_id(&[caro.inbox_id()]) .await .unwrap(); - xmtp_common::time::sleep(core::time::Duration::from_millis(100)).await; - let messages: Arc>> = Arc::new(Mutex::new(Vec::new())); - let messages_clone = messages.clone(); + let stream = caro.stream_all_messages(None).await.unwrap(); + futures::pin_mut!(stream); + bo_group.send_message(b"first").await.unwrap(); + assert_msg!(stream, "first"); - let notify = Delivery::new(None); - let notify_pointer = notify.clone(); - let mut handle = Client::::stream_all_messages_with_callback( - Arc::new(caro), - None, - move |message| { - (*messages_clone.lock()).push(message.unwrap()); - notify_pointer.notify_one(); - }, - ); - handle.wait_for_ready().await; + bo_group.send_message(b"second").await.unwrap(); + assert_msg!(stream, "second"); - alix_group.send_message("first".as_bytes()).await.unwrap(); - notify - .wait_for_delivery() - .await - .expect("didn't get `first`"); - bo_group.send_message("second".as_bytes()).await.unwrap(); - notify.wait_for_delivery().await.unwrap(); - alix_group.send_message("third".as_bytes()).await.unwrap(); - notify.wait_for_delivery().await.unwrap(); - bo_group.send_message("fourth".as_bytes()).await.unwrap(); - notify.wait_for_delivery().await.unwrap(); + alix_group.send_message(b"third").await.unwrap(); + assert_msg!(stream, "third"); - let messages = messages.lock(); - assert_eq!(messages[0].decrypted_message_bytes, b"first"); - assert_eq!(messages[1].decrypted_message_bytes, b"second"); - assert_eq!(messages[2].decrypted_message_bytes, b"third"); - assert_eq!(messages[3].decrypted_message_bytes, b"fourth"); + bo_group.send_message(b"fourth").await.unwrap(); + assert_msg!(stream, "fourth"); } #[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread", worker_threads = 10))] @@ -765,39 +756,21 @@ pub(crate) mod tests { .await .unwrap(); - let messages: Arc>> = Arc::new(Mutex::new(Vec::new())); - let messages_clone = messages.clone(); - let delivery = Delivery::new(None); - let delivery_pointer = delivery.clone(); - let mut handle = Client::::stream_all_messages_with_callback( - caro.clone(), - None, - move |message| { - delivery_pointer.notify_one(); - (*messages_clone.lock()).push(message.unwrap()); - }, - ); - handle.wait_for_ready().await; + let stream = caro.stream_all_messages(None).await.unwrap(); + futures::pin_mut!(stream); + tracing::info!("\n\nSENDING FIRST MESSAGE\n\n"); alix_group.send_message(b"first").await.unwrap(); - delivery - .wait_for_delivery() - .await - .expect("timed out waiting for `first`"); + assert_msg!(stream, "first"); let bo_group = bo.create_dm(caro_wallet.get_address()).await.unwrap(); + assert_msg_exists!(stream); bo_group.send_message(b"second").await.unwrap(); - delivery - .wait_for_delivery() - .await - .expect("timed out waiting for `second`"); + assert_msg!(stream, "second"); alix_group.send_message(b"third").await.unwrap(); - delivery - .wait_for_delivery() - .await - .expect("timed out waiting for `third`"); + assert_msg!(stream, "third"); let alix_group_2 = alix .create_group(None, GroupMetadataOptions::default()) @@ -808,36 +781,10 @@ pub(crate) mod tests { .unwrap(); alix_group.send_message(b"fourth").await.unwrap(); - delivery - .wait_for_delivery() - .await - .expect("timed out waiting for `fourth`"); + assert_msg!(stream, "fourth"); alix_group_2.send_message(b"fifth").await.unwrap(); - delivery - .wait_for_delivery() - .await - .expect("timed out waiting for `fifth`"); - - { - let messages = messages.lock(); - assert_eq!(messages.len(), 5); - } - - let a = handle.abort_handle(); - a.end(); - let _ = handle.join().await; - assert!(a.is_finished()); - - alix_group - .send_message("should not show up".as_bytes()) - .await - .unwrap(); - xmtp_common::time::sleep(core::time::Duration::from_millis(100)).await; - - let messages = messages.lock(); - - assert_eq!(messages.len(), 5); + assert_msg!(stream, "fifth"); } #[ignore] diff --git a/xmtp_mls/src/subscriptions/stream_all.rs b/xmtp_mls/src/subscriptions/stream_all.rs new file mode 100644 index 000000000..a6b0c3913 --- /dev/null +++ b/xmtp_mls/src/subscriptions/stream_all.rs @@ -0,0 +1,87 @@ +use std::{collections::HashMap, sync::Arc}; + +use crate::{ + client::ClientError, + groups::scoped_client::ScopedGroupClient, + groups::subscriptions, + storage::{ + group::{ConversationType, GroupQueryArgs}, + group_message::StoredGroupMessage, + }, + Client, +}; +use futures::{ + stream::{self, Stream, StreamExt}, + Future, +}; +use xmtp_id::scw_verifier::SmartContractSignatureVerifier; +use xmtp_proto::api_client::{trait_impls::XmtpApi, XmtpMlsStreams}; + +use super::{MessagesStreamInfo, SubscribeError}; +pub struct StreamAllMessages<'a, C, Welcomes, Messages> { + /// The monolithic XMTP Client + client: &'a C, + /// Type of conversation to stream + conversation_type: Option, + /// Conversations that are being actively streamed + active_conversations: HashMap, MessagesStreamInfo>, + /// Welcomes Stream + welcomes: Welcomes, + /// Messages Stream + messages: Messages, + /// Extra messages from message stream, when the stream switches because + /// of a new group received. + extra_messages: Vec, +} + +impl<'a, A, V, Welcomes, Messages> StreamAllMessages<'a, Client, Welcomes, Messages> +where + A: XmtpApi + XmtpMlsStreams + Send + Sync + 'static, + V: SmartContractSignatureVerifier + Send + Sync + 'static, +{ + pub async fn new( + client: &'a Client, + conversation_type: Option, + ) -> Result { + let mut active_conversations = async { + let provider = client.mls_provider()?; + client.sync_welcomes(&provider).await?; + + let active_conversations = provider + .conn_ref() + .find_groups(GroupQueryArgs::default().maybe_conversation_type(conversation_type))? + .into_iter() + .map(Into::into) + .collect::, MessagesStreamInfo>>(); + Ok::<_, ClientError>(active_conversations) + } + .await?; + + let messages = + subscriptions::stream_messages(client, Arc::new(active_conversations.clone())).await?; + let welcomes = client.stream_conversations(conversation_type).await?; + + Self { + client, + conversation_type, + messages, + welcomes, + active_conversations, + extra_messages: Vec::new(), + } + } +} + +impl<'a, C, Welcomes, Messages> Stream for StreamAllMessages<'a, C, Welcomes, Messages> +where + C: ScopedGroupClient, +{ + type Item = Result; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + todo!() + } +} diff --git a/xmtp_mls/src/subscriptions/stream_conversations.rs b/xmtp_mls/src/subscriptions/stream_conversations.rs new file mode 100644 index 000000000..3f3f3abb5 --- /dev/null +++ b/xmtp_mls/src/subscriptions/stream_conversations.rs @@ -0,0 +1,241 @@ +use std::{collections::HashSet, marker::PhantomData, sync::Arc, task::Poll}; + +use futures::{prelude::stream::Select, Stream}; +use pin_project_lite::pin_project; +use tokio_stream::wrappers::BroadcastStream; +use xmtp_common::{retry_async, Retry}; +use xmtp_id::scw_verifier::SmartContractSignatureVerifier; +use xmtp_proto::{ + api_client::{trait_impls::XmtpApi, XmtpMlsStreams}, + xmtp::mls::api::v1::WelcomeMessage, +}; + +use crate::{ + groups::{scoped_client::ScopedGroupClient, MlsGroup}, + storage::{group::ConversationType, DbConnection}, + Client, XmtpOpenMlsProvider, +}; + +use super::{LocalEvents, SubscribeError}; + +enum WelcomeOrGroup { + Group(Result, SubscribeError>), + Welcome(Result), +} + +pin_project! { + /// Broadcast stream filtered + mapped to WelcomeOrGroup + struct BroadcastGroupStream { + #[pin] inner: BroadcastStream>, + } +} + +impl BroadcastGroupStream { + fn new(inner: BroadcastStream>) -> Self { + Self { inner } + } +} + +impl Stream for BroadcastGroupStream +where + C: Clone + Send + Sync + 'static, // required by tokio::BroadcastStream +{ + type Item = WelcomeOrGroup; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + use std::task::Poll::*; + let this = self.project(); + + match this.inner.poll_next(cx) { + Ready(Some(event)) => { + let ev = xmtp_common::optify!(event, "Missed messages due to event queue lag") + .and_then(LocalEvents::group_filter); + if let Some(g) = ev { + Ready(Some(WelcomeOrGroup::::Group(Ok(g)))) + } else { + // skip this item since it was either missed due to lag, or not a group + Pending + } + } + Pending => Pending, + Ready(None) => Ready(None), + } + } +} + +pin_project! { + /// Subscription Stream mapped to WelcomeOrGroup + struct SubscriptionStream { + #[pin] inner: S, + _marker: PhantomData, + } +} + +impl SubscriptionStream { + fn new(inner: S) -> Self { + Self { + inner, + _marker: PhantomData, + } + } +} + +impl Stream for SubscriptionStream +where + S: Stream>, +{ + type Item = WelcomeOrGroup; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + use std::task::Poll::*; + let this = self.project(); + + match this.inner.poll_next(cx) { + Ready(Some(welcome)) => Ready(Some(WelcomeOrGroup::Welcome(welcome))), + Pending => Pending, + Ready(None) => Ready(None), + } + } +} + +pin_project! { + pub struct StreamConversations<'a, C, Subscription> { + client: &'a C, + #[pin] inner: Subscription, + conversation_type: Option, + known_welcome_ids: HashSet + } +} + +type MultiplexedSelect = Select, SubscriptionStream>; + +impl<'a, A, V> + StreamConversations< + 'a, + Client, + MultiplexedSelect, ::WelcomeMessageStream<'a>>, + > +where + A: XmtpApi + XmtpMlsStreams + Send + Sync + 'static, + V: SmartContractSignatureVerifier + Send + Sync + 'static, +{ + pub async fn new( + client: &'a Client, + conversation_type: Option, + conn: &DbConnection, + ) -> Result { + let installation_key = client.installation_public_key(); + let id_cursor = 0; + tracing::info!( + inbox_id = client.inbox_id(), + "Setting up conversation stream" + ); + + let events = + BroadcastGroupStream::new(BroadcastStream::new(client.local_events.subscribe())); + + let subscription = client + .api_client + .subscribe_welcome_messages(installation_key.as_ref(), Some(id_cursor)) + .await?; + let subscription = SubscriptionStream::new(subscription); + let known_welcome_ids = HashSet::from_iter(conn.group_welcome_ids()?.into_iter()); + + let stream = futures::stream::select(events, subscription); + + Ok(Self { + client, + inner: stream, + known_welcome_ids, + conversation_type, + }) + } +} + +impl<'a, C, Subscription> Stream for StreamConversations<'a, C, Subscription> +where + C: ScopedGroupClient + Clone, + Subscription: Stream, SubscribeError>>, +{ + type Item = Result, SubscribeError>; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + use std::task::Poll::*; + let this = self.project(); + + match this.inner.poll_next(cx) { + Ready(Some(msg)) => { + todo!() + } + // stream ended + Ready(None) => Ready(None), + Pending => { + cx.waker().wake_by_ref(); + Pending + } + } + } +} + +impl<'a, C, Subscription> StreamConversations<'a, C, Subscription> +where + C: ScopedGroupClient + Clone, +{ + async fn process_streamed_welcome( + &mut self, + client: C, + provider: &XmtpOpenMlsProvider, + welcome: WelcomeMessage, + ) -> Result, SubscribeError> { + let welcome_v1 = crate::client::extract_welcome_message(welcome)?; + if self.known_welcome_ids.contains(&(welcome_v1.id as i64)) { + let conn = provider.conn_ref(); + self.known_welcome_ids.insert(welcome_v1.id as i64); + let group = conn.find_group_by_welcome_id(welcome_v1.id as i64)?; + tracing::info!( + inbox_id = client.inbox_id(), + group_id = hex::encode(&group.id), + welcome_id = ?group.welcome_id, + "Loading existing group for welcome_id: {:?}", + group.welcome_id + ); + return Ok(MlsGroup::new(client.clone(), group.id, group.created_at_ns)); + } + + let creation_result = retry_async!( + Retry::default(), + (async { + tracing::info!( + installation_id = &welcome_v1.id, + "Trying to process streamed welcome" + ); + let welcome_v1 = &welcome_v1; + client + .context + .store() + .transaction_async(provider, |provider| async move { + MlsGroup::create_from_encrypted_welcome( + Arc::new(client.clone()), + provider, + welcome_v1.hpke_public_key.as_slice(), + &welcome_v1.data, + welcome_v1.id as i64, + ) + .await + }) + .await + }) + ); + + Ok(creation_result?) + } +} From b2f0822a3d3323be151d86054b315e1bf246dc76 Mon Sep 17 00:00:00 2001 From: Andrew Plaza Date: Sun, 22 Dec 2024 18:56:20 -0500 Subject: [PATCH 3/3] progress from friday --- Cargo.lock | 4 +- xmtp_mls/src/client.rs | 4 +- xmtp_mls/src/subscriptions/mod.rs | 15 +- .../src/subscriptions/stream_conversations.rs | 185 +++++++++++++----- 4 files changed, 153 insertions(+), 55 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f13a17713..9795538da 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7246,9 +7246,10 @@ dependencies = [ name = "xmtp_api_http" version = "0.1.0" dependencies = [ - "async-stream", "async-trait", + "bytes", "futures", + "pin-project-lite", "reqwest 0.12.9", "serde", "serde_json", @@ -7424,6 +7425,7 @@ dependencies = [ "openssl", "openssl-sys", "parking_lot 0.12.3", + "pin-project-lite", "prost", "rand", "reqwest 0.12.9", diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index d1041dfcf..fe6504bbb 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -558,9 +558,7 @@ where { Some(id) => id, None => { - return Err(ClientError::Storage(StorageError::NotFound( - NotFound::InboxIdForAddress(account_address), - ))); + return Err(NotFound::InboxIdForAddress(account_address).into()); } }; diff --git a/xmtp_mls/src/subscriptions/mod.rs b/xmtp_mls/src/subscriptions/mod.rs index e9b8dcad0..c0f6d1307 100644 --- a/xmtp_mls/src/subscriptions/mod.rs +++ b/xmtp_mls/src/subscriptions/mod.rs @@ -11,7 +11,7 @@ use xmtp_id::scw_verifier::SmartContractSignatureVerifier; use xmtp_proto::{api_client::XmtpMlsStreams, xmtp::mls::api::v1::WelcomeMessage}; // mod stream_all; -// mod stream_conversations; +mod stream_conversations; use crate::{ client::{extract_welcome_message, ClientError}, @@ -24,7 +24,7 @@ use crate::{ consent_record::StoredConsentRecord, group::{ConversationType, GroupQueryArgs, StoredGroup}, group_message::StoredGroupMessage, - StorageError, + NotFound, StorageError, }, Client, XmtpApi, XmtpOpenMlsProvider, }; @@ -223,6 +223,13 @@ impl From for (Vec, MessagesStreamInfo) { } } +// TODO: REMOVE BEFORE MERGING +// TODO: REMOVE BEFORE MERGING +// TODO: REMOVE BEFORE MERGING +pub(self) mod temp { + pub(super) type Result = std::result::Result; +} + #[derive(thiserror::Error, Debug)] pub enum SubscribeError { #[error("failed to start new messages stream {0}")] @@ -231,6 +238,9 @@ pub enum SubscribeError { Client(#[from] ClientError), #[error(transparent)] Group(#[from] GroupError), + #[error(transparent)] + NotFound(#[from] NotFound), + // TODO: Add this to `NotFound` #[error("group message expected in database but is missing")] GroupMessageNotFound, #[error("processing group message in stream: {0}")] @@ -258,6 +268,7 @@ impl RetryableError for SubscribeError { Storage(e) => retryable!(e), Api(e) => retryable!(e), Decode(_) => false, + NotFound(e) => retryable!(e), } } } diff --git a/xmtp_mls/src/subscriptions/stream_conversations.rs b/xmtp_mls/src/subscriptions/stream_conversations.rs index 3f3f3abb5..3bc47894c 100644 --- a/xmtp_mls/src/subscriptions/stream_conversations.rs +++ b/xmtp_mls/src/subscriptions/stream_conversations.rs @@ -1,26 +1,28 @@ -use std::{collections::HashSet, marker::PhantomData, sync::Arc, task::Poll}; +use std::{ + collections::HashSet, future::Future, marker::PhantomData, pin::Pin, + sync::Arc, task::Poll, +}; -use futures::{prelude::stream::Select, Stream}; +use crate::{ + groups::{scoped_client::ScopedGroupClient, MlsGroup}, + storage::{group::ConversationType, DbConnection, NotFound}, + Client, XmtpOpenMlsProvider, +}; +use futures::{future::FutureExt, prelude::stream::Select, Stream}; use pin_project_lite::pin_project; use tokio_stream::wrappers::BroadcastStream; use xmtp_common::{retry_async, Retry}; use xmtp_id::scw_verifier::SmartContractSignatureVerifier; use xmtp_proto::{ api_client::{trait_impls::XmtpApi, XmtpMlsStreams}, - xmtp::mls::api::v1::WelcomeMessage, -}; - -use crate::{ - groups::{scoped_client::ScopedGroupClient, MlsGroup}, - storage::{group::ConversationType, DbConnection}, - Client, XmtpOpenMlsProvider, + xmtp::mls::api::v1::{welcome_message::V1 as WelcomeMessageV1, WelcomeMessage}, }; -use super::{LocalEvents, SubscribeError}; +use super::{temp::Result, LocalEvents, SubscribeError}; enum WelcomeOrGroup { - Group(Result, SubscribeError>), - Welcome(Result), + Group(Result>), + Welcome(Result), } pin_project! { @@ -85,7 +87,7 @@ impl SubscriptionStream { impl Stream for SubscriptionStream where - S: Stream>, + S: Stream>, { type Item = WelcomeOrGroup; @@ -97,7 +99,10 @@ where let this = self.project(); match this.inner.poll_next(cx) { - Ready(Some(welcome)) => Ready(Some(WelcomeOrGroup::Welcome(welcome))), + Ready(Some(welcome)) => { + let welcome = welcome.map_err(SubscribeError::from); + Ready(Some(WelcomeOrGroup::Welcome(welcome))) + } Pending => Pending, Ready(None) => Ready(None), } @@ -109,7 +114,27 @@ pin_project! { client: &'a C, #[pin] inner: Subscription, conversation_type: Option, - known_welcome_ids: HashSet + known_welcome_ids: HashSet, + #[pin] state: ProcessState<'a, C>, + } +} + +pin_project! { + #[project = ProcessProject] + enum ProcessState<'a, C> { + /// State where we are waiting on the next Message from the network + Waiting, + /// State where we are waiting on an IO/Network future to finish processing the current message + /// before moving on to the next one + Processing { + #[pin] future: Pin, Option) >> + 'a >> + } + } +} + +impl<'a, C> Default for ProcessState<'a, C> { + fn default() -> Self { + ProcessState::Waiting } } @@ -129,7 +154,7 @@ where client: &'a Client, conversation_type: Option, conn: &DbConnection, - ) -> Result { + ) -> Result { let installation_key = client.installation_public_key(); let id_cursor = 0; tracing::info!( @@ -154,6 +179,7 @@ where inner: stream, known_welcome_ids, conversation_type, + state: ProcessState::Waiting, }) } } @@ -161,27 +187,58 @@ where impl<'a, C, Subscription> Stream for StreamConversations<'a, C, Subscription> where C: ScopedGroupClient + Clone, - Subscription: Stream, SubscribeError>>, + Subscription: Stream>> + 'a, { - type Item = Result, SubscribeError>; + type Item = Result>; fn poll_next( - self: std::pin::Pin<&mut Self>, + mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { use std::task::Poll::*; - let this = self.project(); + use ProcessState::*; + let mut this = self.as_mut().project(); - match this.inner.poll_next(cx) { - Ready(Some(msg)) => { - todo!() - } - // stream ended - Ready(None) => Ready(None), - Pending => { - cx.waker().wake_by_ref(); - Pending + match this.state.as_mut().project() { + ProcessProject::Waiting => { + match this.inner.poll_next(cx) { + Ready(Some(item)) => { + let future = + // need to clone client into Arc<> here b/c: + // otherwise the `'1` ref for `Pin<&mut Self>` in arg to `poll_next` needs to + // live as long as `'a` ref for `Client`. + // This is because we're boxing this future (i.e `Box`). + // There maybe a way to avoid it, but we need to `Box<>` the type + // b/c there's no way to get the anonymous future type on the stack generated by an + // `async fn`. If we can somehow store `impl Trait` on a struct (or + // something similar), we could avoid the `Clone` + `Arc`ing. + Self::process_new_item(this.known_welcome_ids.clone(), Arc::new(this.client.clone()), item); + + this.state.set(ProcessState::Processing { + future: future.boxed(), + }); + Pending + } + // stream ended + Ready(None) => Ready(None), + Pending => { + cx.waker().wake_by_ref(); + Pending + } + } } + /// We're processing a message we received + ProcessProject::Processing { future } => match future.poll(cx) { + Ready(Ok((group, welcome_id))) => { + if let Some(id) = welcome_id { + this.known_welcome_ids.insert(id); + } + this.state.set(ProcessState::Waiting); + Ready(Some(Ok(group))) + } + Ready(Err(e)) => Ready(Some(Err(e))), + Pending => Pending, + }, } } } @@ -190,17 +247,42 @@ impl<'a, C, Subscription> StreamConversations<'a, C, Subscription> where C: ScopedGroupClient + Clone, { - async fn process_streamed_welcome( - &mut self, - client: C, + async fn process_new_item( + known_welcome_ids: HashSet, + client: Arc, + item: Result>, + ) -> Result<(MlsGroup, Option)> { + use WelcomeOrGroup::*; + let provider = client.context().mls_provider()?; + match item? { + Welcome(w) => Self::on_welcome(&known_welcome_ids, client, &provider, w?).await, + Group(g) => { + todo!() + } + } + } + + // process a new welcome, returning the new welcome ID + async fn on_welcome( + known_welcome_ids: &HashSet, + client: Arc, provider: &XmtpOpenMlsProvider, welcome: WelcomeMessage, - ) -> Result, SubscribeError> { - let welcome_v1 = crate::client::extract_welcome_message(welcome)?; - if self.known_welcome_ids.contains(&(welcome_v1.id as i64)) { + ) -> Result<(MlsGroup, Option)> { + let WelcomeMessageV1 { + id, + ref created_ns, + ref installation_key, + ref data, + ref hpke_public_key, + } = crate::client::extract_welcome_message(welcome)?; + let id = id as i64; + + if known_welcome_ids.contains(&(id)) { let conn = provider.conn_ref(); - self.known_welcome_ids.insert(welcome_v1.id as i64); - let group = conn.find_group_by_welcome_id(welcome_v1.id as i64)?; + let group = conn + .find_group_by_welcome_id(id)? + .ok_or(NotFound::GroupByWelcome(id))?; tracing::info!( inbox_id = client.inbox_id(), group_id = hex::encode(&group.id), @@ -208,34 +290,39 @@ where "Loading existing group for welcome_id: {:?}", group.welcome_id ); - return Ok(MlsGroup::new(client.clone(), group.id, group.created_at_ns)); + return Ok(( + MlsGroup::new(Arc::unwrap_or_clone(client), group.id, group.created_at_ns), + Some(id), + )); } - let creation_result = retry_async!( + let c = &client; + let mls_group = retry_async!( Retry::default(), (async { tracing::info!( - installation_id = &welcome_v1.id, + installation_id = hex::encode(installation_key), + welcome_id = &id, "Trying to process streamed welcome" ); - let welcome_v1 = &welcome_v1; - client - .context + + (*client) + .context() .store() .transaction_async(provider, |provider| async move { MlsGroup::create_from_encrypted_welcome( - Arc::new(client.clone()), + Arc::clone(c), provider, - welcome_v1.hpke_public_key.as_slice(), - &welcome_v1.data, - welcome_v1.id as i64, + hpke_public_key.as_slice(), + data, + id, ) .await }) .await }) - ); + )?; - Ok(creation_result?) + Ok((mls_group, Some(id))) } }