Skip to content

Commit

Permalink
Some debugging around seq_join
Browse files Browse the repository at this point in the history
  • Loading branch information
akoshelev committed Nov 11, 2023
1 parent abc0403 commit 1a6229d
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 3 deletions.
3 changes: 2 additions & 1 deletion src/protocol/ipa_prf/prf_sharding/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use crate::{
},
seq_join::{seq_join, SeqJoin},
};
use crate::seq_join::seq_join_with_ctx;

pub mod bucket;
#[cfg(feature = "descriptive-gate")]
Expand Down Expand Up @@ -466,7 +467,7 @@ where
}));

// Execute all of the async futures (sequentially), and flatten the result
let flattenned_stream = seq_join(sh_ctx.active_work(), stream_of_per_user_circuits)
let flattenned_stream = seq_join_with_ctx("bk_tv".into(), sh_ctx.active_work(), stream_of_per_user_circuits)
.flat_map(|x| stream_iter(x.unwrap()));

// modulus convert breakdown keys and trigger values
Expand Down
1 change: 1 addition & 0 deletions src/protocol/modulus_conversion/convert_shares.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ where
let stream = unfold(
(ctx, locally_converted, first_record),
|(ctx, mut locally_converted, record_id)| async move {
tracing::trace!("convert bits for {}/{record_id}", ctx.gate().as_ref());
let Some((triple, residual)) = locally_converted.next().await else {
return None;
};
Expand Down
29 changes: 27 additions & 2 deletions src/seq_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::{
pin::Pin,
task::{Context, Poll},
};
use std::borrow::Cow;
use std::marker::PhantomData;
use futures::{stream::{iter, Iter as StreamIter, TryCollect}, Future, Stream, StreamExt, TryStreamExt};
use futures_util::stream::FuturesOrdered;
Expand Down Expand Up @@ -98,6 +99,21 @@ pub fn seq_join<'a, S, F, O>(active: NonZeroUsize, source: S) -> SequentialFutur
spawner: UnsafeSpawner::default(),
source: source.fuse(),
active: VecDeque::with_capacity(active.get()),
span: None
}
}

pub fn seq_join_with_ctx<'a, S, F, O>(span: Cow<'static, str>, active: NonZeroUsize, source: S) -> SequentialFutures<'a, S, F>
where
S: Stream<Item = F> + Send,
F: Future<Output = O> + Send + 'a,
O: Send + 'static
{
SequentialFutures {
spawner: UnsafeSpawner::default(),
source: source.fuse(),
active: VecDeque::with_capacity(active.get()),
span: Some(span)
}
}

Expand Down Expand Up @@ -224,6 +240,7 @@ impl<F: IntoFuture<Output = T>, T: Send + 'static> ActiveItem<F> {
}

#[pin_project]
#[must_use = "seq_join result must be used."]
pub struct SequentialFutures<'a, S, F>
where
S: Stream<Item = F> + Send,
Expand All @@ -233,6 +250,7 @@ pub struct SequentialFutures<'a, S, F>
#[pin]
source: futures::stream::Fuse<S>,
active: VecDeque<ActiveItem<F>>,
span: Option<Cow<'static, str>>,
}

impl <'a, S, F, T> Stream for SequentialFutures<'a, S, F>
Expand All @@ -246,6 +264,7 @@ impl <'a, S, F, T> Stream for SequentialFutures<'a, S, F>

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.project();
let active_before = this.active.len();

// Draw more values from the input, up to the capacity.
while this.active.len() < this.active.capacity() {
Expand All @@ -258,7 +277,7 @@ impl <'a, S, F, T> Stream for SequentialFutures<'a, S, F>
}
}

if let Some(item) = this.active.front_mut() {
let r = if let Some(item) = this.active.front_mut() {
if item.check_ready(cx) {
let v = this.active.pop_front().map(ActiveItem::take);
Poll::Ready(v)
Expand All @@ -272,7 +291,13 @@ impl <'a, S, F, T> Stream for SequentialFutures<'a, S, F>
Poll::Ready(None)
} else {
Poll::Pending
}
};

let _ = this.span.as_ref().map(|r| {
tracing::trace!("{r} seq_join polled. active before = {active_before}, active after = {}", this.active.len());
});

r
}

fn size_hint(&self) -> (usize, Option<usize>) {
Expand Down

0 comments on commit 1a6229d

Please sign in to comment.