Skip to content

Commit

Permalink
Update group.members() (#749)
Browse files Browse the repository at this point in the history
  • Loading branch information
neekolas authored May 17, 2024
2 parents aa71fb9 + 6468cac commit f103d3b
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 44 deletions.
6 changes: 4 additions & 2 deletions bindings_ffi/src/mls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,8 @@ pub struct FfiGroup {

#[derive(uniffi::Record)]
pub struct FfiGroupMember {
pub account_address: String,
pub inbox_id: String,
pub account_addresses: Vec<String>,
pub installation_ids: Vec<Vec<u8>>,
}

Expand Down Expand Up @@ -413,7 +414,8 @@ impl FfiGroup {
.members()?
.into_iter()
.map(|member| FfiGroupMember {
account_address: member.account_address,
inbox_id: member.inbox_id,
account_addresses: member.account_addresses,
installation_ids: member.installation_ids,
})
.collect();
Expand Down
2 changes: 1 addition & 1 deletion examples/cli/serializable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ impl<'a> From<&'a MlsGroup> for SerializableGroup {
.members()
.expect("could not load members")
.into_iter()
.map(|m| m.account_address)
.map(|m| m.inbox_id)
.collect::<Vec<String>>();

let metadata = group.metadata().expect("could not load metadata");
Expand Down
1 change: 1 addition & 0 deletions xmtp_id/src/associations/serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ impl From<AssociationState> for AssociationStateProto {

impl TryFrom<AssociationStateProto> for AssociationState {
type Error = DeserializationError;

fn try_from(proto: AssociationStateProto) -> Result<Self, Self::Error> {
let members = proto
.members
Expand Down
20 changes: 20 additions & 0 deletions xmtp_id/src/associations/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,26 @@ impl AssociationState {
.collect()
}

pub fn account_addresses(&self) -> Vec<String> {
self.members_by_kind(MemberKind::Address)
.into_iter()
.filter_map(|member| match member.identifier {
MemberIdentifier::Address(address) => Some(address),
MemberIdentifier::Installation(_) => None,
})
.collect()
}

pub fn installation_ids(&self) -> Vec<Vec<u8>> {
self.members_by_kind(MemberKind::Installation)
.into_iter()
.filter_map(|member| match member.identifier {
MemberIdentifier::Address(_) => None,
MemberIdentifier::Installation(installation_id) => Some(installation_id),
})
.collect()
}

pub fn diff(&self, new_state: &Self) -> AssociationStateDiff {
let new_members: Vec<MemberIdentifier> = new_state
.members
Expand Down
66 changes: 29 additions & 37 deletions xmtp_mls/src/groups/members.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
use std::collections::HashMap;
use xmtp_id::InboxId;

use openmls::{credentials::BasicCredential, group::MlsGroup as OpenMlsGroup};
use super::{validated_commit::extract_group_membership, GroupError, MlsGroup};

use openmls_traits::OpenMlsProvider;

use super::{GroupError, MlsGroup};

use crate::identity::Identity;
use crate::{
storage::association_state::StoredAssociationState, xmtp_openmls_provider::XmtpOpenMlsProvider,
};

#[derive(Debug, Clone)]
pub struct GroupMember {
pub account_address: String,
pub inbox_id: InboxId,
pub account_addresses: Vec<String>,
pub installation_ids: Vec<Vec<u8>>,
}

Expand All @@ -24,39 +23,32 @@ impl MlsGroup {

pub fn members_with_provider(
&self,
provider: impl OpenMlsProvider,
provider: &XmtpOpenMlsProvider,
) -> Result<Vec<GroupMember>, GroupError> {
let openmls_group = self.load_mls_group(provider)?;
aggregate_member_list(&openmls_group)
}
}
// TODO: Replace with try_into from extensions
let group_membership = extract_group_membership(openmls_group.extensions())?;
let requests = group_membership
.members
.into_iter()
.map(|(inbox_id, sequence_id)| (inbox_id, sequence_id as i64))
.collect();

pub fn aggregate_member_list(openmls_group: &OpenMlsGroup) -> Result<Vec<GroupMember>, GroupError> {
let member_map: HashMap<String, GroupMember> = openmls_group
.members()
.filter_map(|member| {
let basic_credential = BasicCredential::try_from(&member.credential).ok()?;
Identity::get_validated_account_address(
basic_credential.identity(),
&member.signature_key,
)
.ok()
.map(|account_address| (account_address, member.signature_key.clone()))
})
.fold(
HashMap::new(),
|mut acc, (account_address, signature_key)| {
acc.entry(account_address.clone())
.and_modify(|e| e.installation_ids.push(signature_key.clone()))
.or_insert(GroupMember {
account_address,
installation_ids: vec![signature_key],
});
acc
},
);
let conn = provider.conn_ref();
let association_state_map = StoredAssociationState::batch_read_from_cache(conn, &requests)?;
// TODO: Figure out what to do with missing members from the local DB. Do we go to the network? Load from identity updates?
// Right now I am just omitting them
let members = association_state_map
.into_iter()
.map(|association_state| GroupMember {
inbox_id: association_state.inbox_id().to_string(),
account_addresses: association_state.account_addresses(),
installation_ids: association_state.installation_ids(),
})
.collect::<Vec<GroupMember>>();

Ok(member_map.into_values().collect())
Ok(members)
}
}

#[cfg(test)]
Expand Down
96 changes: 92 additions & 4 deletions xmtp_mls/src/storage/encrypted_store/association_state.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
use diesel::prelude::*;
use prost::Message;
use xmtp_id::associations::{AssociationState, DeserializationError};
use xmtp_id::{
associations::{AssociationState, DeserializationError},
InboxId,
};
use xmtp_proto::xmtp::identity::associations::AssociationState as AssociationStateProto;

use super::schema::association_state;
use super::{
schema::association_state::{self, dsl},
DbConnection,
};
use crate::{impl_fetch, impl_store_or_ignore, storage::StorageError, Fetch, StoreOrIgnore};

/// StoredIdentityUpdate holds a serialized IdentityUpdate record
Expand All @@ -28,7 +34,7 @@ impl TryFrom<StoredAssociationState> for AssociationState {

impl StoredAssociationState {
pub fn write_to_cache(
conn: &super::db_connection::DbConnection,
conn: &DbConnection,
inbox_id: String,
sequence_id: i64,
state: AssociationState,
Expand All @@ -43,7 +49,7 @@ impl StoredAssociationState {
}

pub fn read_from_cache(
conn: &super::db_connection::DbConnection,
conn: &DbConnection,
inbox_id: String,
sequence_id: i64,
) -> Result<Option<AssociationState>, StorageError> {
Expand All @@ -62,4 +68,86 @@ impl StoredAssociationState {
})
.transpose()
}

pub fn batch_read_from_cache(
conn: &DbConnection,
identifiers: &Vec<(InboxId, i64)>,
) -> Result<Vec<AssociationState>, StorageError> {
// If no identifier provided, return empty hash map
if identifiers.is_empty() {
return Ok(vec![]);
}
let mut query = dsl::association_state.into_boxed();
for (inbox_id, sequence_id) in identifiers {
query = query.or_filter(
dsl::inbox_id
.eq(inbox_id)
.and(dsl::sequence_id.eq(sequence_id)),
);
}
let association_states =
conn.raw_query(|query_conn| query.load::<StoredAssociationState>(query_conn))?;

association_states
.into_iter()
.map(|stored_association_state| stored_association_state.try_into())
.collect::<Result<Vec<AssociationState>, DeserializationError>>()
.map_err(|err| StorageError::Deserialization(err.to_string()))
}
}

#[cfg(test)]
mod tests {
use crate::storage::encrypted_store::tests::with_connection;

use super::*;

#[test]
fn test_batch_read() {
with_connection(|conn| {
let association_state = AssociationState::new("1234".to_string(), 0);
let inbox_id = association_state.inbox_id().clone();
StoredAssociationState::write_to_cache(
conn,
inbox_id.to_string(),
1,
association_state,
)
.unwrap();

let association_state_2 = AssociationState::new("456".to_string(), 2);
let inbox_id_2 = association_state_2.inbox_id().clone();
StoredAssociationState::write_to_cache(
conn,
association_state_2.inbox_id().clone(),
2,
association_state_2,
)
.unwrap();

let first_association_state = StoredAssociationState::batch_read_from_cache(
conn,
&vec![(inbox_id.to_string(), 1)],
)
.unwrap();
assert_eq!(first_association_state.len(), 1);
assert_eq!(first_association_state[0].inbox_id(), &inbox_id);

let both_association_states = StoredAssociationState::batch_read_from_cache(
conn,
&vec![(inbox_id.to_string(), 1), (inbox_id_2.to_string(), 2)],
)
.unwrap();

assert_eq!(both_association_states.len(), 2);

let no_results = StoredAssociationState::batch_read_from_cache(
conn,
// Mismatched inbox_id and sequence_id
&vec![(inbox_id.to_string(), 2)],
)
.unwrap();
assert_eq!(no_results.len(), 0);
})
}
}

0 comments on commit f103d3b

Please sign in to comment.