Skip to content

Commit

Permalink
filter dms in stream_conversations
Browse files Browse the repository at this point in the history
  • Loading branch information
cameronvoell committed Sep 21, 2024
1 parent 6fb0385 commit fc02806
Showing 1 changed file with 48 additions and 21 deletions.
69 changes: 48 additions & 21 deletions xmtp_mls/src/subscriptions.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{collections::HashMap, sync::Arc};

use crate::xmtp_openmls_provider::XmtpOpenMlsProvider;
use futures::{FutureExt, Stream, StreamExt};
use prost::Message;
use tokio::{sync::oneshot, task::JoinHandle};
Expand Down Expand Up @@ -130,18 +131,42 @@ where

pub async fn stream_conversations(
&self,
include_dm: bool,
) -> Result<impl Stream<Item = MlsGroup> + '_, ClientError> {
let provider = Arc::new(self.context.mls_provider()?);

let event_queue =
tokio_stream::wrappers::BroadcastStream::new(self.local_events.subscribe());

let event_queue = event_queue.filter_map(|event| async move {
match event {
Ok(LocalEvents::NewGroup(g)) => Some(g),
Err(BroadcastStreamRecvError::Lagged(missed)) => {
log::warn!("Missed {missed} messages due to local event queue lagging");
// Helper function for filtering Dm groups
let filter_group = move |group: MlsGroup, provider: Arc<XmtpOpenMlsProvider>| async move {
match group.metadata(provider.as_ref()) {
Ok(metadata) => {
if include_dm || metadata.conversation_type != ConversationType::Dm {
Some(group)
} else {
None
}
}
Err(err) => {
log::error!("Error processing group metadata: {:?}", err);
None
}
}
};

let event_provider = Arc::clone(&provider);
let event_queue = event_queue.filter_map(move |event| {
let provider = Arc::clone(&event_provider);
async move {
match event {
Ok(LocalEvents::NewGroup(group)) => filter_group(group, provider).await,
Err(BroadcastStreamRecvError::Lagged(missed)) => {
log::warn!("Missed {missed} messages due to local event queue lagging");
None
}
}
}
});

let installation_key = self.installation_public_key();
Expand All @@ -152,17 +177,24 @@ where
.subscribe_welcome_messages(installation_key, Some(id_cursor))
.await?;

let stream_provider = Arc::clone(&provider);
let stream = subscription
.map(|welcome| async {
log::info!("Received conversation streaming payload");
self.process_streamed_welcome(welcome?).await
})
.filter_map(|res| async {
match res.await {
Ok(group) => Some(group),
Err(err) => {
log::error!("Error processing stream entry for conversation: {:?}", err);
None
.filter_map(move |res| {
let provider = Arc::clone(&stream_provider);
async move {
match res.await {
Ok(group) => filter_group(group, provider).await,
Err(err) => {
log::error!(
"Error processing stream entry for conversation: {:?}",
err
);
None
}
}
}
});
Expand Down Expand Up @@ -239,16 +271,11 @@ where
let (tx, rx) = oneshot::channel();

let handle = tokio::spawn(async move {
let stream = client.stream_conversations().await?;
let stream = client.stream_conversations(include_dm).await?;
futures::pin_mut!(stream);
let _ = tx.send(());
while let Some(convo) = stream.next().await {
let provider = client.context.mls_provider()?;
// Don't execute callback for dms unless include_dm is true
if include_dm || convo.metadata(provider)?.conversation_type != ConversationType::Dm
{
convo_callback(convo)
}
convo_callback(convo)
}
log::debug!("`stream_conversations` stream ended, dropping stream");
Ok(())
Expand Down Expand Up @@ -304,7 +331,7 @@ where
.await?;
futures::pin_mut!(messages_stream);

let convo_stream = self.stream_conversations().await?;
let convo_stream = self.stream_conversations(true).await?;
futures::pin_mut!(convo_stream);

let mut extra_messages = Vec::new();
Expand Down Expand Up @@ -422,7 +449,7 @@ mod tests {
let mut stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
let bob_ptr = bob.clone();
tokio::spawn(async move {
let bob_stream = bob_ptr.stream_conversations().await.unwrap();
let bob_stream = bob_ptr.stream_conversations(true).await.unwrap();
futures::pin_mut!(bob_stream);
while let Some(item) = bob_stream.next().await {
let _ = tx.send(item);
Expand Down Expand Up @@ -774,7 +801,7 @@ mod tests {
}

#[tokio::test(flavor = "multi_thread")]
async fn test_dm_creation() {
async fn test_dm_streaming() {
let alix = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await);
let bo = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await);

Expand Down

0 comments on commit fc02806

Please sign in to comment.