Skip to content

Commit

Permalink
just use an arc
Browse files Browse the repository at this point in the history
  • Loading branch information
codabrink committed Jan 9, 2025
1 parent ae29dfa commit b5734f1
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 25 deletions.
20 changes: 5 additions & 15 deletions xmtp_mls/src/groups/device_sync/backup.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use crate::storage::DbConnection;
use backup_element::{BackupElement, BackupRecordStreamer};
use futures::Stream;
use serde::{Deserialize, Serialize};
use std::{ops::Range, sync::Arc};
use xmtp_proto::xmtp::device_sync::consent_backup::ConsentRecordSave;

use crate::storage::DbConnection;

mod backup_element;

#[derive(Serialize, Deserialize)]
Expand All @@ -17,8 +17,7 @@ pub struct BackupMetadata {
}

pub struct BackupOptions {
from_ns: u64,
to_ns: u64,
range_ns: Option<Range<u64>>,
elements: Vec<BackupSelection>,
}

Expand All @@ -31,7 +30,7 @@ pub enum BackupSelection {
impl BackupSelection {
fn to_streamers(
&self,
conn: &'static DbConnection,
conn: &Arc<DbConnection>,
) -> Vec<Box<dyn Stream<Item = Vec<BackupElement>>>> {
match self {
Self::Consent => vec![Box::new(BackupRecordStreamer::<ConsentRecordSave>::new(
Expand All @@ -43,7 +42,7 @@ impl BackupSelection {
}

impl BackupOptions {
pub fn write(self, conn: &'static DbConnection) -> BackupWriter {
pub fn write(self, conn: &Arc<DbConnection>) -> BackupWriter {
let input_streams = self
.elements
.iter()
Expand All @@ -61,12 +60,3 @@ struct BackupWriter {
options: BackupOptions,
input_streams: Vec<Vec<Box<dyn Stream<Item = Vec<BackupElement>>>>>,
}

impl Stream for BackupWriter {
type Item = Vec<BackupElement>;
fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
}
}
18 changes: 9 additions & 9 deletions xmtp_mls/src/groups/device_sync/backup/backup_element.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::marker::PhantomData;
use std::{marker::PhantomData, sync::Arc};

use futures::Stream;
use serde::{Deserialize, Serialize};
Expand All @@ -23,28 +23,28 @@ pub enum BackupElement {

trait BackupRecordProvider {
const BATCH_SIZE: i64;
fn backup_records(streamer: &BackupRecordStreamer<'_, Self>) -> Vec<BackupElement>
fn backup_records(streamer: &BackupRecordStreamer<Self>) -> Vec<BackupElement>
where
Self: Sized;
}

pub(super) struct BackupRecordStreamer<'a, R> {
pub(super) struct BackupRecordStreamer<R> {
offset: i64,
conn: &'a DbConnection,
conn: Arc<DbConnection>,
_phantom: PhantomData<R>,
}

impl<'a, R> BackupRecordStreamer<'a, R> {
pub(super) fn new(conn: &'a DbConnection) -> Self {
impl<R> BackupRecordStreamer<R> {
pub(super) fn new(conn: &Arc<DbConnection>) -> Self {
Self {
offset: 0,
conn,
conn: conn.clone(),
_phantom: PhantomData,
}
}
}

impl<'a, R> Stream for BackupRecordStreamer<'a, R>
impl<R> Stream for BackupRecordStreamer<R>
where
R: BackupRecordProvider + Unpin,
{
Expand All @@ -57,7 +57,7 @@ where

// Get a mutable reference to self
let this = self.get_mut();
let batch = R::backup_records(&*this);
let batch = R::backup_records(this);

// If no records found, we've reached the end of the stream
if batch.is_empty() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use xmtp_proto::xmtp::device_sync::consent_backup::{

impl BackupRecordProvider for ConsentRecordSave {
const BATCH_SIZE: i64 = 100;
fn backup_records(streamer: &BackupRecordStreamer<'_, Self>) -> Vec<BackupElement>
fn backup_records(streamer: &BackupRecordStreamer<Self>) -> Vec<BackupElement>
where
Self: Sized,
{
Expand Down

0 comments on commit b5734f1

Please sign in to comment.