diff --git a/bindings_ffi/src/mls.rs b/bindings_ffi/src/mls.rs index 5deb0291f..53c8e1297 100644 --- a/bindings_ffi/src/mls.rs +++ b/bindings_ffi/src/mls.rs @@ -16,6 +16,7 @@ use xmtp_id::{ associations::{builder::SignatureRequest, generate_inbox_id as xmtp_id_generate_inbox_id}, InboxId, }; +use xmtp_mls::client::FindGroupParams; use xmtp_mls::groups::group_mutable_metadata::MetadataField; use xmtp_mls::groups::group_permissions::BasePolicies; use xmtp_mls::groups::group_permissions::GroupMutablePermissionsError; @@ -763,6 +764,20 @@ impl FfiConversations { Ok(out) } + pub async fn create_dm(&self, account_address: String) -> Result, GenericError> { + log::info!("creating dm with target address: {}", account_address); + + let convo = self.inner_client.create_dm(account_address).await?; + + let out = Arc::new(FfiGroup { + inner_client: self.inner_client.clone(), + group_id: convo.group_id, + created_at_ns: convo.created_at_ns, + }); + + Ok(out) + } + pub async fn process_streamed_welcome_message( &self, envelope_bytes: Vec, @@ -787,7 +802,16 @@ impl FfiConversations { pub async fn sync_all_groups(&self) -> Result { let inner = self.inner_client.as_ref(); - let groups = inner.find_groups(None, None, None, None)?; + let groups = inner.find_groups(FindGroupParams { + include_dm_groups: true, + ..FindGroupParams::default() + })?; + + log::info!( + "groups for client inbox id {:?}: {:?}", + self.inner_client.inbox_id(), + groups.len() + ); let num_groups_synced: usize = inner.sync_all_groups(groups).await?; // Uniffi does not work with usize, so we need to convert to u32 @@ -807,12 +831,13 @@ impl FfiConversations { ) -> Result>, GenericError> { let inner = self.inner_client.as_ref(); let convo_list: Vec> = inner - .find_groups( - None, - opts.created_after_ns, - opts.created_before_ns, - opts.limit, - )? + .find_groups(FindGroupParams { + allowed_states: None, + created_after_ns: opts.created_after_ns, + created_before_ns: opts.created_before_ns, + limit: opts.limit, + include_dm_groups: false, + })? .into_iter() .map(|group| { Arc::new(FfiGroup { @@ -828,14 +853,17 @@ impl FfiConversations { pub async fn stream(&self, callback: Box) -> FfiStreamCloser { let client = self.inner_client.clone(); - let handle = - RustXmtpClient::stream_conversations_with_callback(client.clone(), move |convo| { + let handle = RustXmtpClient::stream_conversations_with_callback( + client.clone(), + move |convo| { callback.on_conversation(Arc::new(FfiGroup { inner_client: client.clone(), group_id: convo.group_id, created_at_ns: convo.created_at_ns, })) - }); + }, + false, + ); FfiStreamCloser::new(handle) } @@ -3683,6 +3711,37 @@ mod tests { ); } + #[tokio::test(flavor = "multi_thread", worker_threads = 5)] + async fn test_dms_sync_but_do_not_list() { + let alix = new_test_client().await; + let bola = new_test_client().await; + + let alix_conversations = alix.conversations(); + let bola_conversations = bola.conversations(); + + let _alix_group = alix_conversations + .create_dm(bola.account_address.clone()) + .await + .unwrap(); + let alix_num_sync = alix_conversations.sync_all_groups().await.unwrap(); + bola_conversations.sync().await.unwrap(); + let bola_num_sync = bola_conversations.sync_all_groups().await.unwrap(); + assert_eq!(alix_num_sync, 1); + assert_eq!(bola_num_sync, 1); + + let alix_groups = alix_conversations + .list(FfiListConversationsOptions::default()) + .await + .unwrap(); + assert_eq!(alix_groups.len(), 0); + + let bola_groups = bola_conversations + .list(FfiListConversationsOptions::default()) + .await + .unwrap(); + assert_eq!(bola_groups.len(), 0); + } + #[tokio::test(flavor = "multi_thread", worker_threads = 5)] async fn test_set_and_get_group_consent() { let alix = new_test_client().await; diff --git a/bindings_node/src/conversations.rs b/bindings_node/src/conversations.rs index 3928c36f0..ed39f37fa 100644 --- a/bindings_node/src/conversations.rs +++ b/bindings_node/src/conversations.rs @@ -6,6 +6,7 @@ use napi::bindgen_prelude::{Error, Result, Uint8Array}; use napi::threadsafe_function::{ErrorStrategy, ThreadsafeFunction, ThreadsafeFunctionCallMode}; use napi::JsFunction; use napi_derive::napi; +use xmtp_mls::client::FindGroupParams; use xmtp_mls::groups::{GroupMetadataOptions, PreconfiguredPolicies}; use crate::messages::NapiMessage; @@ -171,12 +172,12 @@ impl NapiConversations { }; let convo_list: Vec = self .inner_client - .find_groups( - None, - opts.created_after_ns, - opts.created_before_ns, - opts.limit, - ) + .find_groups(FindGroupParams { + created_after_ns: opts.created_after_ns, + created_before_ns: opts.created_before_ns, + limit: opts.limit, + ..FindGroupParams::default() + }) .map_err(|e| Error::from_reason(format!("{}", e)))? .into_iter() .map(|group| { @@ -196,8 +197,9 @@ impl NapiConversations { let tsfn: ThreadsafeFunction = callback.create_threadsafe_function(0, |ctx| Ok(vec![ctx.value]))?; let client = self.inner_client.clone(); - let stream_closer = - RustXmtpClient::stream_conversations_with_callback(client.clone(), move |convo| { + let stream_closer = RustXmtpClient::stream_conversations_with_callback( + client.clone(), + move |convo| { tsfn.call( Ok(NapiGroup::new( client.clone(), @@ -206,7 +208,9 @@ impl NapiConversations { )), ThreadsafeFunctionCallMode::Blocking, ); - }); + }, + false, + ); Ok(NapiStreamCloser::new(stream_closer)) } diff --git a/examples/cli/cli-client.rs b/examples/cli/cli-client.rs index 6ede48fa3..9125e2ff6 100755 --- a/examples/cli/cli-client.rs +++ b/examples/cli/cli-client.rs @@ -19,6 +19,7 @@ use futures::future::join_all; use kv_log_macro::{error, info}; use prost::Message; use xmtp_id::associations::unverified::{UnverifiedRecoverableEcdsaSignature, UnverifiedSignature}; +use xmtp_mls::client::FindGroupParams; use xmtp_mls::groups::message_history::MessageHistoryContent; use xmtp_mls::storage::group_message::GroupMessageKind; @@ -209,7 +210,7 @@ async fn main() { // recv(&client).await.unwrap(); let group_list = client - .find_groups(None, None, None, None) + .find_groups(FindGroupParams::default()) .expect("failed to list groups"); for group in group_list.iter() { group.sync(&client).await.expect("error syncing group"); diff --git a/xmtp_mls/migrations/2024-09-09-231735_create_dm_inbox_id/down.sql b/xmtp_mls/migrations/2024-09-09-231735_create_dm_inbox_id/down.sql new file mode 100644 index 000000000..2dbbf5787 --- /dev/null +++ b/xmtp_mls/migrations/2024-09-09-231735_create_dm_inbox_id/down.sql @@ -0,0 +1,3 @@ +-- This file should undo anything in `up.sql` +ALTER TABLE groups DROP COLUMN dm_inbox_id; +DROP INDEX idx_dm_target; diff --git a/xmtp_mls/migrations/2024-09-09-231735_create_dm_inbox_id/up.sql b/xmtp_mls/migrations/2024-09-09-231735_create_dm_inbox_id/up.sql new file mode 100644 index 000000000..6614f371a --- /dev/null +++ b/xmtp_mls/migrations/2024-09-09-231735_create_dm_inbox_id/up.sql @@ -0,0 +1,3 @@ +-- Your SQL goes here +ALTER TABLE groups ADD COLUMN dm_inbox_id text; +CREATE INDEX idx_dm_target ON groups(dm_inbox_id); diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index 47d742584..a5982cb26 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -218,6 +218,15 @@ impl From<&str> for ClientError { } } +#[derive(Debug, Default)] +pub struct FindGroupParams { + pub allowed_states: Option>, + pub created_after_ns: Option, + pub created_before_ns: Option, + pub limit: Option, + pub include_dm_groups: bool, +} + /// Clients manage access to the network, identity, and data store #[derive(Debug)] pub struct Client { @@ -496,15 +505,42 @@ where } /// Create a new Direct Message with the default settings - pub fn create_dm(&self, dm_target_inbox_id: InboxId) -> Result { + pub async fn create_dm(&self, account_address: String) -> Result { + log::info!("creating dm with address: {}", account_address); + + let inbox_id = match self + .find_inbox_id_from_address(account_address.clone()) + .await? + { + Some(id) => id, + None => { + return Err(ClientError::Storage(StorageError::NotFound(format!( + "inbox id for address {} not found", + account_address + )))) + } + }; + + self.create_dm_by_inbox_id(inbox_id).await + } + + /// Create a new Direct Message with the default settings + pub async fn create_dm_by_inbox_id( + &self, + dm_target_inbox_id: InboxId, + ) -> Result { log::info!("creating dm with {}", dm_target_inbox_id); let group = MlsGroup::create_dm_and_insert( self.context.clone(), GroupMembershipState::Allowed, - dm_target_inbox_id, + dm_target_inbox_id.clone(), )?; + group + .add_members_by_inbox_id(self, vec![dm_target_inbox_id]) + .await?; + // notify any streams of the new group let _ = self.local_events.send(LocalEvents::NewGroup(group.clone())); @@ -558,17 +594,17 @@ where /// - created_after_ns: only return groups created after the given timestamp (in nanoseconds) /// - created_before_ns: only return groups created before the given timestamp (in nanoseconds) /// - limit: only return the first `limit` groups - pub fn find_groups( - &self, - allowed_states: Option>, - created_after_ns: Option, - created_before_ns: Option, - limit: Option, - ) -> Result, ClientError> { + pub fn find_groups(&self, params: FindGroupParams) -> Result, ClientError> { Ok(self .store() .conn()? - .find_groups(allowed_states, created_after_ns, created_before_ns, limit)? + .find_groups( + params.allowed_states, + params.created_after_ns, + params.created_before_ns, + params.limit, + params.include_dm_groups, + )? .into_iter() .map(|stored_group| { MlsGroup::new( @@ -870,6 +906,7 @@ mod tests { use crate::{ builder::ClientBuilder, + client::FindGroupParams, groups::GroupMetadataOptions, hpke::{decrypt_welcome, encrypt_welcome}, identity::serialize_key_package_hash_ref, @@ -971,7 +1008,7 @@ mod tests { .create_group(None, GroupMetadataOptions::default()) .unwrap(); - let groups = client.find_groups(None, None, None, None).unwrap(); + let groups = client.find_groups(FindGroupParams::default()).unwrap(); assert_eq!(groups.len(), 2); assert_eq!(groups[0].group_id, group_1.group_id); assert_eq!(groups[1].group_id, group_2.group_id); @@ -1037,7 +1074,7 @@ mod tests { let bob_received_groups = bo.sync_welcomes().await.unwrap(); assert_eq!(bob_received_groups.len(), 2); - let bo_groups = bo.find_groups(None, None, None, None).unwrap(); + let bo_groups = bo.find_groups(FindGroupParams::default()).unwrap(); let bo_group1 = bo.group(alix_bo_group1.clone().group_id).unwrap(); let bo_messages1 = bo_group1 .find_messages(None, None, None, None, None) @@ -1142,7 +1179,7 @@ mod tests { log::info!("Syncing bolas welcomes"); // See if Bola can see that they were added to the group bola.sync_welcomes().await.unwrap(); - let bola_groups = bola.find_groups(None, None, None, None).unwrap(); + let bola_groups = bola.find_groups(FindGroupParams::default()).unwrap(); assert_eq!(bola_groups.len(), 1); let bola_group = bola_groups.first().unwrap(); log::info!("Syncing bolas messages"); @@ -1275,7 +1312,7 @@ mod tests { bo.sync_welcomes().await.unwrap(); // Bo should have two groups now - let bo_groups = bo.find_groups(None, None, None, None).unwrap(); + let bo_groups = bo.find_groups(FindGroupParams::default()).unwrap(); assert_eq!(bo_groups.len(), 2); // Bo's original key should be deleted diff --git a/xmtp_mls/src/groups/group_metadata.rs b/xmtp_mls/src/groups/group_metadata.rs index b274df3f1..d80ea93ec 100644 --- a/xmtp_mls/src/groups/group_metadata.rs +++ b/xmtp_mls/src/groups/group_metadata.rs @@ -17,6 +17,8 @@ pub enum GroupMetadataError { InvalidConversationType, #[error("missing extension")] MissingExtension, + #[error("invalid dm members")] + InvalidDmMembers, #[error("missing a dm member")] MissingDmMember, } diff --git a/xmtp_mls/src/groups/group_mutable_metadata.rs b/xmtp_mls/src/groups/group_mutable_metadata.rs index b97c14c0e..13a91aa42 100644 --- a/xmtp_mls/src/groups/group_mutable_metadata.rs +++ b/xmtp_mls/src/groups/group_mutable_metadata.rs @@ -114,7 +114,7 @@ impl GroupMutableMetadata { } // Admin / super admin is not needed for a DM - pub fn new_dm_default(_creator_inbox_id: String, _dm_target_inbox_id: String) -> Self { + pub fn new_dm_default(_creator_inbox_id: String, _dm_target_inbox_id: &str) -> Self { let mut attributes = HashMap::new(); // TODO: would it be helpful to incorporate the dm inbox ids in the name or description? attributes.insert( diff --git a/xmtp_mls/src/groups/message_history.rs b/xmtp_mls/src/groups/message_history.rs index 0f0768397..86aaca7ba 100644 --- a/xmtp_mls/src/groups/message_history.rs +++ b/xmtp_mls/src/groups/message_history.rs @@ -135,7 +135,7 @@ where pub async fn ensure_member_of_all_groups(&self, inbox_id: String) -> Result<(), GroupError> { let conn = self.store().conn()?; - let groups = conn.find_groups(None, None, None, None)?; + let groups = conn.find_groups(None, None, None, None, false)?; for group in groups { let group = self.group(group.id)?; Box::pin(group.add_members_by_inbox_id(self, vec![inbox_id.clone()])).await?; @@ -384,7 +384,7 @@ where self.sync_welcomes().await?; let conn = self.store().conn()?; - let groups = conn.find_groups(None, None, None, None)?; + let groups = conn.find_groups(None, None, None, None, false)?; for crate::storage::group::StoredGroup { id, .. } in groups.into_iter() { let group = self.group(id)?; Box::pin(group.sync(self)).await?; @@ -502,14 +502,14 @@ where async fn prepare_groups_to_sync(&self) -> Result, MessageHistoryError> { let conn = self.store().conn()?; - Ok(conn.find_groups(None, None, None, None)?) + Ok(conn.find_groups(None, None, None, None, false)?) } async fn prepare_messages_to_sync( &self, ) -> Result, MessageHistoryError> { let conn = self.store().conn()?; - let groups = conn.find_groups(None, None, None, None)?; + let groups = conn.find_groups(None, None, None, None, false)?; let mut all_messages: Vec = vec![]; for StoredGroup { id, .. } in groups.into_iter() { diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index f60f37a8d..61cea7a98 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -324,6 +324,7 @@ impl MlsGroup { now_ns(), membership_state, context.inbox_id(), + None, ); stored_group.store(provider.conn_ref())?; @@ -345,7 +346,7 @@ impl MlsGroup { let protected_metadata = build_dm_protected_metadata_extension(context.inbox_id(), dm_target_inbox_id.clone())?; let mutable_metadata = - build_dm_mutable_metadata_extension_default(context.inbox_id(), dm_target_inbox_id)?; + build_dm_mutable_metadata_extension_default(context.inbox_id(), &dm_target_inbox_id)?; let group_membership = build_starting_group_membership_extension(context.inbox_id(), 0); let mutable_permissions = PolicySet::new_dm(); let mutable_permission_extension = @@ -373,6 +374,7 @@ impl MlsGroup { now_ns(), membership_state, context.inbox_id(), + Some(dm_target_inbox_id), ); stored_group.store(provider.conn_ref())?; @@ -398,6 +400,16 @@ impl MlsGroup { let mls_group = mls_welcome.into_group(provider)?; let group_id = mls_group.group_id().to_vec(); let metadata = extract_group_metadata(&mls_group)?; + let dm_members = metadata.dm_members; + let dm_inbox_id = if let Some(dm_members) = &dm_members { + if dm_members.member_one_inbox_id == client.inbox_id() { + Some(dm_members.member_two_inbox_id.clone()) + } else { + Some(dm_members.member_one_inbox_id.clone()) + } + } else { + None + }; let group_type = metadata.conversation_type; let to_store = match group_type { @@ -408,6 +420,7 @@ impl MlsGroup { added_by_inbox, welcome_id, Purpose::Conversation, + dm_inbox_id, ), ConversationType::Dm => { validate_dm_group(client, &mls_group, &added_by_inbox)?; @@ -418,6 +431,7 @@ impl MlsGroup { added_by_inbox, welcome_id, Purpose::Conversation, + dm_inbox_id, ) } ConversationType::Sync => StoredGroup::new_from_welcome( @@ -427,6 +441,7 @@ impl MlsGroup { added_by_inbox, welcome_id, Purpose::Sync, + dm_inbox_id, ), }; @@ -1096,11 +1111,8 @@ impl MlsGroup { .unwrap() }); let mutable_metadata = custom_mutable_metadata.unwrap_or_else(|| { - build_dm_mutable_metadata_extension_default( - context.inbox_id(), - dm_target_inbox_id.clone(), - ) - .unwrap() + build_dm_mutable_metadata_extension_default(context.inbox_id(), &dm_target_inbox_id) + .unwrap() }); let group_membership = custom_group_membership .unwrap_or_else(|| build_starting_group_membership_extension(context.inbox_id(), 0)); @@ -1131,6 +1143,7 @@ impl MlsGroup { now_ns(), GroupMembershipState::Allowed, // Use Allowed as default for tests context.inbox_id(), + Some(dm_target_inbox_id), ); stored_group.store(provider.conn_ref())?; @@ -1212,7 +1225,7 @@ pub fn build_mutable_metadata_extension_default( pub fn build_dm_mutable_metadata_extension_default( creator_inbox_id: String, - dm_target_inbox_id: String, + dm_target_inbox_id: &str, ) -> Result { let mutable_metadata: Vec = GroupMutableMetadata::new_dm_default(creator_inbox_id, dm_target_inbox_id).try_into()?; @@ -1525,7 +1538,7 @@ mod tests { use crate::{ assert_err, assert_logged, builder::ClientBuilder, - client::MessageProcessingError, + client::{FindGroupParams, MessageProcessingError}, codecs::{group_updated::GroupUpdatedCodec, ContentCodec}, groups::{ build_dm_protected_metadata_extension, build_group_membership_extension, @@ -1559,7 +1572,7 @@ mod tests { ApiClient: XmtpApi, { client.sync_welcomes().await.unwrap(); - let mut groups = client.find_groups(None, None, None, None).unwrap(); + let mut groups = client.find_groups(FindGroupParams::default()).unwrap(); groups.remove(0) } @@ -1871,7 +1884,7 @@ mod tests { // Bo should not be able to actually read this group bo.sync_welcomes().await.unwrap(); - let groups = bo.find_groups(None, None, None, None).unwrap(); + let groups = bo.find_groups(FindGroupParams::default()).unwrap(); assert_eq!(groups.len(), 0); assert_logged!("failed to create group from welcome", 1); } @@ -1999,7 +2012,7 @@ mod tests { .expect("send message"); bola_client.sync_welcomes().await.unwrap(); - let bola_groups = bola_client.find_groups(None, None, None, None).unwrap(); + let bola_groups = bola_client.find_groups(FindGroupParams::default()).unwrap(); let bola_group = bola_groups.first().unwrap(); bola_group.sync(&bola_client).await.unwrap(); let bola_messages = bola_group @@ -2352,7 +2365,7 @@ mod tests { .await .unwrap(); bola.sync_welcomes().await.unwrap(); - let bola_groups = bola.find_groups(None, None, None, None).unwrap(); + let bola_groups = bola.find_groups(FindGroupParams::default()).unwrap(); assert_eq!(bola_groups.len(), 1); let bola_group = bola_groups.first().unwrap(); bola_group.sync(&bola).await.unwrap(); @@ -2520,7 +2533,7 @@ mod tests { .await .unwrap(); bola.sync_welcomes().await.unwrap(); - let bola_groups = bola.find_groups(None, None, None, None).unwrap(); + let bola_groups = bola.find_groups(FindGroupParams::default()).unwrap(); assert_eq!(bola_groups.len(), 1); let bola_group = bola_groups.first().unwrap(); bola_group.sync(&bola).await.unwrap(); @@ -2599,7 +2612,7 @@ mod tests { .await .unwrap(); bola.sync_welcomes().await.unwrap(); - let bola_groups = bola.find_groups(None, None, None, None).unwrap(); + let bola_groups = bola.find_groups(FindGroupParams::default()).unwrap(); assert_eq!(bola_groups.len(), 1); let bola_group = bola_groups.first().unwrap(); bola_group.sync(&bola).await.unwrap(); @@ -2615,7 +2628,7 @@ mod tests { // Verify that bola can not add caro because they are not an admin bola.sync_welcomes().await.unwrap(); - let bola_groups = bola.find_groups(None, None, None, None).unwrap(); + let bola_groups = bola.find_groups(FindGroupParams::default()).unwrap(); assert_eq!(bola_groups.len(), 1); let bola_group: &MlsGroup = bola_groups.first().unwrap(); bola_group.sync(&bola).await.unwrap(); @@ -2679,7 +2692,7 @@ mod tests { // Verify that bola can not add charlie because they are not an admin bola.sync_welcomes().await.unwrap(); - let bola_groups = bola.find_groups(None, None, None, None).unwrap(); + let bola_groups = bola.find_groups(FindGroupParams::default()).unwrap(); assert_eq!(bola_groups.len(), 1); let bola_group: &MlsGroup = bola_groups.first().unwrap(); bola_group.sync(&bola).await.unwrap(); @@ -2707,7 +2720,7 @@ mod tests { .await .unwrap(); bola.sync_welcomes().await.unwrap(); - let bola_groups = bola.find_groups(None, None, None, None).unwrap(); + let bola_groups = bola.find_groups(FindGroupParams::default()).unwrap(); assert_eq!(bola_groups.len(), 1); let bola_group = bola_groups.first().unwrap(); bola_group.sync(&bola).await.unwrap(); @@ -2723,7 +2736,7 @@ mod tests { // Verify that bola can not add caro as an admin because they are not a super admin bola.sync_welcomes().await.unwrap(); - let bola_groups = bola.find_groups(None, None, None, None).unwrap(); + let bola_groups = bola.find_groups(FindGroupParams::default()).unwrap(); assert_eq!(bola_groups.len(), 1); let bola_group: &MlsGroup = bola_groups.first().unwrap(); bola_group.sync(&bola).await.unwrap(); @@ -2970,7 +2983,7 @@ mod tests { // Step 3: Verify that Bola can update the group name, and amal sees the update bola.sync_welcomes().await.unwrap(); - let bola_groups = bola.find_groups(None, None, None, None).unwrap(); + let bola_groups = bola.find_groups(FindGroupParams::default()).unwrap(); let bola_group: &MlsGroup = bola_groups.first().unwrap(); bola_group.sync(&bola).await.unwrap(); bola_group @@ -3035,7 +3048,7 @@ mod tests { // Step 3: Bola attemps to add Caro, but fails because group is admin only let caro = ClientBuilder::new_test_client(&generate_local_wallet()).await; bola.sync_welcomes().await.unwrap(); - let bola_groups = bola.find_groups(None, None, None, None).unwrap(); + let bola_groups = bola.find_groups(FindGroupParams::default()).unwrap(); let bola_group: &MlsGroup = bola_groups.first().unwrap(); bola_group.sync(&bola).await.unwrap(); let result = bola_group @@ -3179,33 +3192,31 @@ mod tests { let caro = ClientBuilder::new_test_client(&generate_local_wallet()).await; // Amal creates a dm group targetting bola - let amal_dm: MlsGroup = amal.create_dm(bola.inbox_id()).unwrap(); + let amal_dm: MlsGroup = amal.create_dm_by_inbox_id(bola.inbox_id()).await.unwrap(); // Amal can not add caro to the dm group let result = amal_dm - .add_members_by_inbox_id(&amal, vec![bola.inbox_id(), caro.inbox_id()]) + .add_members_by_inbox_id(&amal, vec![caro.inbox_id()]) .await; assert!(result.is_err()); + + // Bola is already a member let result = amal_dm - .add_members_by_inbox_id(&amal, vec![caro.inbox_id()]) + .add_members_by_inbox_id(&amal, vec![bola.inbox_id(), caro.inbox_id()]) .await; assert!(result.is_err()); amal_dm.sync(&amal).await.unwrap(); let members = amal_dm.members(&amal).await.unwrap(); - assert_eq!(members.len(), 1); - - // Amal can add bola - amal_dm - .add_members_by_inbox_id(&amal, vec![bola.inbox_id()]) - .await - .unwrap(); - amal_dm.sync(&amal).await.unwrap(); - let members = amal_dm.members(&amal).await.unwrap(); assert_eq!(members.len(), 2); // Bola can message amal let _ = bola.sync_welcomes().await; - let bola_groups = bola.find_groups(None, None, None, None).unwrap(); + let bola_groups = bola + .find_groups(FindGroupParams { + include_dm_groups: true, + ..FindGroupParams::default() + }) + .unwrap(); let bola_dm: &MlsGroup = bola_groups.first().unwrap(); bola_dm.send_message(b"test one", &bola).await.unwrap(); diff --git a/xmtp_mls/src/storage/encrypted_store/group.rs b/xmtp_mls/src/storage/encrypted_store/group.rs index d5516e7ec..b68e61e11 100644 --- a/xmtp_mls/src/storage/encrypted_store/group.rs +++ b/xmtp_mls/src/storage/encrypted_store/group.rs @@ -39,6 +39,8 @@ pub struct StoredGroup { pub added_by_inbox_id: String, /// The sequence id of the welcome message pub welcome_id: Option, + /// The inbox_id of the DM target + pub dm_inbox_id: Option, } impl_fetch!(StoredGroup, groups, Vec); @@ -53,6 +55,7 @@ impl StoredGroup { added_by_inbox_id: String, welcome_id: i64, purpose: Purpose, + dm_inbox_id: Option, ) -> Self { Self { id, @@ -62,6 +65,7 @@ impl StoredGroup { purpose, added_by_inbox_id, welcome_id: Some(welcome_id), + dm_inbox_id, } } @@ -71,6 +75,7 @@ impl StoredGroup { created_at_ns: i64, membership_state: GroupMembershipState, added_by_inbox_id: String, + dm_inbox_id: Option, ) -> Self { Self { id, @@ -80,6 +85,7 @@ impl StoredGroup { purpose: Purpose::Conversation, added_by_inbox_id, welcome_id: None, + dm_inbox_id, } } @@ -98,6 +104,7 @@ impl StoredGroup { purpose: Purpose::Sync, added_by_inbox_id: "".into(), welcome_id: None, + dm_inbox_id: None, } } } @@ -110,6 +117,7 @@ impl DbConnection { created_after_ns: Option, created_before_ns: Option, limit: Option, + include_dm_groups: bool, ) -> Result, StorageError> { let mut query = dsl::groups.order(dsl::created_at_ns.asc()).into_boxed(); @@ -129,6 +137,10 @@ impl DbConnection { query = query.limit(limit); } + if !include_dm_groups { + query = query.filter(dsl::dm_inbox_id.is_null()); + } + query = query.filter(dsl::purpose.eq(Purpose::Conversation)); Ok(self.raw_query(|conn| query.load(conn))?) @@ -331,6 +343,22 @@ pub(crate) mod tests { created_at_ns, membership_state, "placeholder_address".to_string(), + None, + ) + } + + /// Generate a test dm group + pub fn generate_dm(state: Option) -> StoredGroup { + let id = rand_vec(); + let created_at_ns = now_ns(); + let membership_state = state.unwrap_or(GroupMembershipState::Allowed); + let dm_inbox_id = Some("placeholder_inbox_id".to_string()); + StoredGroup::new( + id, + created_at_ns, + membership_state, + "placeholder_address".to_string(), + dm_inbox_id, ) } @@ -392,23 +420,31 @@ pub(crate) mod tests { test_group_1.store(conn).unwrap(); let test_group_2 = generate_group(Some(GroupMembershipState::Allowed)); test_group_2.store(conn).unwrap(); + let test_group_3 = generate_dm(Some(GroupMembershipState::Allowed)); + test_group_3.store(conn).unwrap(); - let all_results = conn.find_groups(None, None, None, None).unwrap(); + let all_results = conn.find_groups(None, None, None, None, false).unwrap(); assert_eq!(all_results.len(), 2); let pending_results = conn - .find_groups(Some(vec![GroupMembershipState::Pending]), None, None, None) + .find_groups( + Some(vec![GroupMembershipState::Pending]), + None, + None, + None, + false, + ) .unwrap(); assert_eq!(pending_results[0].id, test_group_1.id); assert_eq!(pending_results.len(), 1); // Offset and limit - let results_with_limit = conn.find_groups(None, None, None, Some(1)).unwrap(); + let results_with_limit = conn.find_groups(None, None, None, Some(1), false).unwrap(); assert_eq!(results_with_limit.len(), 1); assert_eq!(results_with_limit[0].id, test_group_1.id); let results_with_created_at_ns_after = conn - .find_groups(None, Some(test_group_1.created_at_ns), None, Some(1)) + .find_groups(None, Some(test_group_1.created_at_ns), None, Some(1), false) .unwrap(); assert_eq!(results_with_created_at_ns_after.len(), 1); assert_eq!(results_with_created_at_ns_after[0].id, test_group_2.id); @@ -417,7 +453,10 @@ pub(crate) mod tests { let synced_groups = conn.find_sync_groups().unwrap(); assert_eq!(synced_groups.len(), 0); - // test that ONLY normal groups show up. + // test that dm groups are included + let dm_results = conn.find_groups(None, None, None, None, true).unwrap(); + assert_eq!(dm_results.len(), 3); + assert_eq!(dm_results[2].id, test_group_3.id); }) } diff --git a/xmtp_mls/src/storage/encrypted_store/group_intent.rs b/xmtp_mls/src/storage/encrypted_store/group_intent.rs index 4ff7615e7..b5104f9ac 100644 --- a/xmtp_mls/src/storage/encrypted_store/group_intent.rs +++ b/xmtp_mls/src/storage/encrypted_store/group_intent.rs @@ -404,6 +404,7 @@ mod tests { 100, GroupMembershipState::Allowed, "placeholder_address".to_string(), + None, ); group.store(conn).unwrap(); } diff --git a/xmtp_mls/src/storage/encrypted_store/mod.rs b/xmtp_mls/src/storage/encrypted_store/mod.rs index 687ba7631..6b1ab0731 100644 --- a/xmtp_mls/src/storage/encrypted_store/mod.rs +++ b/xmtp_mls/src/storage/encrypted_store/mod.rs @@ -679,6 +679,7 @@ mod tests { 0, GroupMembershipState::Allowed, "goodbye".to_string(), + None, ); group.store(connection)?; Ok(()) @@ -730,6 +731,7 @@ mod tests { 0, GroupMembershipState::Allowed, "goodbye".to_string(), + None, ); group.store(conn1).unwrap(); diff --git a/xmtp_mls/src/storage/encrypted_store/schema.rs b/xmtp_mls/src/storage/encrypted_store/schema.rs index 7d835a5b9..7266406cd 100644 --- a/xmtp_mls/src/storage/encrypted_store/schema.rs +++ b/xmtp_mls/src/storage/encrypted_store/schema.rs @@ -53,6 +53,7 @@ diesel::table! { purpose -> Integer, added_by_inbox_id -> Text, welcome_id -> Nullable, + dm_inbox_id -> Nullable, } } diff --git a/xmtp_mls/src/subscriptions.rs b/xmtp_mls/src/subscriptions.rs index 2bb6f8cc4..c69bfdaac 100644 --- a/xmtp_mls/src/subscriptions.rs +++ b/xmtp_mls/src/subscriptions.rs @@ -1,5 +1,6 @@ use std::{collections::HashMap, sync::Arc}; +use crate::xmtp_openmls_provider::XmtpOpenMlsProvider; use futures::{FutureExt, Stream, StreamExt}; use prost::Message; use tokio::{sync::oneshot, task::JoinHandle}; @@ -9,7 +10,7 @@ use xmtp_proto::xmtp::mls::api::v1::WelcomeMessage; use crate::{ api::GroupFilter, client::{extract_welcome_message, ClientError}, - groups::{extract_group_id, GroupError, MlsGroup}, + groups::{extract_group_id, group_metadata::ConversationType, GroupError, MlsGroup}, retry::Retry, retry_async, storage::{group::StoredGroup, group_message::StoredGroupMessage}, @@ -130,18 +131,42 @@ where pub async fn stream_conversations( &self, + include_dm: bool, ) -> Result + '_, ClientError> { + let provider = Arc::new(self.context.mls_provider()?); + let event_queue = tokio_stream::wrappers::BroadcastStream::new(self.local_events.subscribe()); - let event_queue = event_queue.filter_map(|event| async move { - match event { - Ok(LocalEvents::NewGroup(g)) => Some(g), - Err(BroadcastStreamRecvError::Lagged(missed)) => { - log::warn!("Missed {missed} messages due to local event queue lagging"); + // Helper function for filtering Dm groups + let filter_group = move |group: MlsGroup, provider: Arc| async move { + match group.metadata(provider.as_ref()) { + Ok(metadata) => { + if include_dm || metadata.conversation_type != ConversationType::Dm { + Some(group) + } else { + None + } + } + Err(err) => { + log::error!("Error processing group metadata: {:?}", err); None } } + }; + + let event_provider = Arc::clone(&provider); + let event_queue = event_queue.filter_map(move |event| { + let provider = Arc::clone(&event_provider); + async move { + match event { + Ok(LocalEvents::NewGroup(group)) => filter_group(group, provider).await, + Err(BroadcastStreamRecvError::Lagged(missed)) => { + log::warn!("Missed {missed} messages due to local event queue lagging"); + None + } + } + } }); let installation_key = self.installation_public_key(); @@ -152,17 +177,24 @@ where .subscribe_welcome_messages(installation_key, Some(id_cursor)) .await?; + let stream_provider = Arc::clone(&provider); let stream = subscription .map(|welcome| async { log::info!("Received conversation streaming payload"); self.process_streamed_welcome(welcome?).await }) - .filter_map(|res| async { - match res.await { - Ok(group) => Some(group), - Err(err) => { - log::error!("Error processing stream entry for conversation: {:?}", err); - None + .filter_map(move |res| { + let provider = Arc::clone(&stream_provider); + async move { + match res.await { + Ok(group) => filter_group(group, provider).await, + Err(err) => { + log::error!( + "Error processing stream entry for conversation: {:?}", + err + ); + None + } } } }); @@ -234,11 +266,12 @@ where pub fn stream_conversations_with_callback( client: Arc>, mut convo_callback: impl FnMut(MlsGroup) + Send + 'static, + include_dm: bool, ) -> StreamHandle> { let (tx, rx) = oneshot::channel(); let handle = tokio::spawn(async move { - let stream = client.stream_conversations().await?; + let stream = client.stream_conversations(include_dm).await?; futures::pin_mut!(stream); let _ = tx.send(()); while let Some(convo) = stream.next().await { @@ -287,7 +320,7 @@ where let mut group_id_to_info = self .store() .conn()? - .find_groups(None, None, None, None)? + .find_groups(None, None, None, None, false)? .into_iter() .map(Into::into) .collect::, MessagesStreamInfo>>(); @@ -298,7 +331,7 @@ where .await?; futures::pin_mut!(messages_stream); - let convo_stream = self.stream_conversations().await?; + let convo_stream = self.stream_conversations(true).await?; futures::pin_mut!(convo_stream); let mut extra_messages = Vec::new(); @@ -416,7 +449,7 @@ mod tests { let mut stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx); let bob_ptr = bob.clone(); tokio::spawn(async move { - let bob_stream = bob_ptr.stream_conversations().await.unwrap(); + let bob_stream = bob_ptr.stream_conversations(true).await.unwrap(); futures::pin_mut!(bob_stream); while let Some(item) = bob_stream.next().await { let _ = tx.send(item); @@ -726,12 +759,15 @@ mod tests { let notify = Delivery::new(None); let (notify_pointer, groups_pointer) = (notify.clone(), groups.clone()); - let closer = - Client::::stream_conversations_with_callback(alix.clone(), move |g| { + let closer = Client::::stream_conversations_with_callback( + alix.clone(), + move |g| { let mut groups = groups_pointer.lock(); groups.push(g); notify_pointer.notify_one(); - }); + }, + false, + ); alix.create_group(None, GroupMetadataOptions::default()) .unwrap(); @@ -763,4 +799,67 @@ mod tests { closer.handle.abort(); } + + #[tokio::test(flavor = "multi_thread")] + async fn test_dm_streaming() { + let alix = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); + let bo = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); + + let groups = Arc::new(Mutex::new(Vec::new())); + // Wait for 2 seconds for the group creation to be streamed + let notify = Delivery::new(Some(std::time::Duration::from_secs(1))); + let (notify_pointer, groups_pointer) = (notify.clone(), groups.clone()); + + // Start a stream with enableDm set to false + let closer = Client::::stream_conversations_with_callback( + alix.clone(), + move |g| { + let mut groups = groups_pointer.lock(); + groups.push(g); + notify_pointer.notify_one(); + }, + false, + ); + + alix.create_dm_by_inbox_id(bo.inbox_id()).await.unwrap(); + + let result = notify.wait_for_delivery().await; + assert!(result.is_err(), "Stream unexpectedly received a DM group"); + + closer.handle.abort(); + + // Start a stream with enableDm set to true + let groups = Arc::new(Mutex::new(Vec::new())); + // Wait for 2 seconds for the group creation to be streamed + let notify = Delivery::new(Some(std::time::Duration::from_secs(1))); + let (notify_pointer, groups_pointer) = (notify.clone(), groups.clone()); + let closer = Client::::stream_conversations_with_callback( + alix.clone(), + move |g| { + let mut groups = groups_pointer.lock(); + groups.push(g); + notify_pointer.notify_one(); + }, + true, + ); + + alix.create_dm_by_inbox_id(bo.inbox_id()).await.unwrap(); + notify.wait_for_delivery().await.unwrap(); + { + let grps = groups.lock(); + assert_eq!(grps.len(), 1); + } + + let dm = bo.create_dm_by_inbox_id(alix.inbox_id()).await.unwrap(); + dm.add_members_by_inbox_id(&bo, vec![alix.inbox_id()]) + .await + .unwrap(); + notify.wait_for_delivery().await.unwrap(); + { + let grps = groups.lock(); + assert_eq!(grps.len(), 2); + } + + closer.handle.abort(); + } }