Skip to content

Commit

Permalink
Started refactoring on par_reduce
Browse files Browse the repository at this point in the history
  • Loading branch information
valebes committed Nov 15, 2023
1 parent 7e07d3a commit 75c24ae
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 16 deletions.
4 changes: 2 additions & 2 deletions src/templates/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ where
F: FnOnce(TKey, Vec<TIn>) -> (TKey, TReduce) + Send + Copy,
{
fn run(&mut self, input: TInIter) -> Option<TOutIter> {
let res: TOutIter = self.threadpool.par_reduce(input, self.f).collect();
let res: TOutIter = self.threadpool.par_reduce_by_key(input, self.f).collect();
Some(res)
}
fn number_of_replicas(&self) -> usize {
Expand Down Expand Up @@ -489,7 +489,7 @@ where
F: FnOnce(TKey, Vec<TIn>) -> (TKey, TReduce) + Send + Copy,
{
fn run(&mut self, input: TInIter) -> Option<TOutIter> {
let res: TOutIter = self.threadpool.par_reduce(input, self.f).collect();
let res: TOutIter = self.threadpool.par_reduce_by_key(input, self.f).collect();
Some(res)
}
fn number_of_replicas(&self) -> usize {
Expand Down
16 changes: 8 additions & 8 deletions src/templates/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ where
{
/// Creates a new source from any type that implements the `Iterator` trait.
/// The source will terminate when the iterator is exhausted.
///
///
/// # Arguments
/// * `iterator` - Type that implements the [`Iterator`] trait
/// and represents the stream of data we want emit.
Expand Down Expand Up @@ -123,7 +123,7 @@ where
///
/// # Arguments
/// * `chunk_size` - Number of elements for each chunk.
///
///
/// # Examples
/// Given a stream of numbers, we create a pipeline with a splitter that
/// create vectors of two elements each.
Expand Down Expand Up @@ -204,7 +204,7 @@ where
T: Send + 'static + Clone,
{
/// Creates a new aggregator node.
///
///
/// # Arguments
/// * `chunk_size` - Number of elements for each chunk.
///
Expand Down Expand Up @@ -232,7 +232,7 @@ where
}

/// Creates a new aggregator node with 'n_replicas' replicas of the same node.
///
///
/// # Arguments
/// * `n_replicas` - Number of replicas.
/// * `chunk_size` - Number of elements for each chunk.
Expand Down Expand Up @@ -291,7 +291,7 @@ where
F: FnMut(T) -> U + Send + 'static + Clone,
{
/// Creates a new sequential node.
///
///
/// # Arguments
/// * `f` - Function name or lambda function that specify the logic
/// of this node.
Expand Down Expand Up @@ -334,7 +334,7 @@ where
F: FnMut(T) -> U + Send + 'static + Clone,
{
/// Creates a new parallel node.
///
///
/// # Arguments
/// * `n_replicas` - Number of replicas.
/// * `f` - Function name or lambda function that specify the logic
Expand Down Expand Up @@ -380,7 +380,7 @@ where
F: FnMut(&T) -> bool + Send + 'static + Clone,
{
/// Creates a new filter node.
///
///
/// # Arguments
/// * `f` - Function name or lambda function that represent the predicate
/// function we want to apply.
Expand Down Expand Up @@ -561,7 +561,7 @@ where
T: Send + 'static + Clone,
{
/// Creates a new ordered aggregator node
///
///
/// # Arguments
/// * `chunk_size` - Number of elements for each chunk.
pub fn build(chunk_size: usize) -> impl InOut<T, Vec<T>> {
Expand Down
51 changes: 47 additions & 4 deletions src/thread_pool/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ impl ThreadPool {
});
});

drop(arc_tx);
drop(arc_tx); // Refactoring?

let mut disconnected = false;

Expand Down Expand Up @@ -442,7 +442,7 @@ impl ThreadPool {
Iter: IntoIterator,
{
let map = self.par_map(iter, f);
self.par_reduce(map, reduce)
self.par_reduce_by_key(map, reduce)
}

/// Reduces in parallel the elements of an iterator `iter` by the function `f`.
Expand All @@ -468,12 +468,16 @@ impl ThreadPool {
/// vec.push((i % 10, i));
/// }
///
/// let res: Vec<(i32, i32)> = pool.par_reduce(vec, |k, v| -> (i32, i32) {
/// let res: Vec<(i32, i32)> = pool.par_reduce_by_key(vec, |k, v| -> (i32, i32) {
/// (k, v.iter().sum())
/// }).collect();
/// assert_eq!(res.len(), 10);
/// ```
pub fn par_reduce<Iter, K, V, R, F>(&mut self, iter: Iter, f: F) -> impl Iterator<Item = (K, R)>
pub fn par_reduce_by_key<Iter, K, V, R, F>(
&mut self,
iter: Iter,
f: F,
) -> impl Iterator<Item = (K, R)>
where
<Iter as IntoIterator>::Item: Send,
K: Send + Ord + 'static,
Expand All @@ -491,6 +495,45 @@ impl ThreadPool {
self.par_map(ordered_map, move |(k, v)| f(k, v))
}

/// Reduce
///
pub fn par_reduce<Iter, V, F>(&mut self, iter: Iter, f: F) -> V
where
<Iter as IntoIterator>::Item: Send,
V: Send + 'static,
F: FnOnce(V, V) -> V + Send + Copy + Sync,
Iter: IntoIterator<Item = V>,
{
let mut data: Vec<V> = iter.into_iter().collect();

while data.len() != 1 {
let mut tmp = Vec::new();
let mut num_proc = self.num_workers;

while data.len() < 2 * num_proc {
num_proc -= 1;

Check warning on line 514 in src/thread_pool/mod.rs

View check run for this annotation

Codecov / codecov/patch

src/thread_pool/mod.rs#L514

Added line #L514 was not covered by tests
}
let mut counter = 0;

while !data.is_empty() {
counter %= num_proc;
tmp.push((counter, data.pop().unwrap()));
counter += 1;
}

data = self
.par_reduce_by_key(tmp, |k, v| {
(k, v.into_iter().reduce(|a, b| f(a, b)).unwrap())
})
.collect::<Vec<(usize, V)>>()
.into_iter()

Check warning on line 529 in src/thread_pool/mod.rs

View check run for this annotation

Codecov / codecov/patch

src/thread_pool/mod.rs#L528-L529

Added lines #L528 - L529 were not covered by tests
.map(|(_a, b)| b)
.collect();

Check warning on line 531 in src/thread_pool/mod.rs

View check run for this annotation

Codecov / codecov/patch

src/thread_pool/mod.rs#L531

Added line #L531 was not covered by tests
}
data.pop().unwrap()

}

/// Create a new scope to execute jobs on other threads.
/// The function passed to this method will be provided with a [`Scope`] object,
/// which can be used to spawn new jobs through the [`Scope::execute`] method.
Expand Down
19 changes: 17 additions & 2 deletions src/thread_pool/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,27 @@ fn test_par_reduce() {
}

let res: Vec<(i32, i32)> = pool
.par_reduce(vec, |k, v| -> (i32, i32) { (k, v.iter().sum()) })
.par_reduce_by_key(vec, |k, v| -> (i32, i32) { (k, v.iter().sum()) })
.collect();

assert_eq!(res.len(), 10);
}

#[test]
#[serial]
fn test_new_reduce() {
let mut pool = ThreadPool::new();

let mut vec = Vec::new();
for _i in 0..130 {
vec.push(1);
}

let res = pool.par_reduce(vec, |a, b| a + b);

assert_eq!(res, 130);
}

#[test]
#[serial]
fn test_par_map_reduce_seq() {
Expand All @@ -178,7 +193,7 @@ fn test_par_map_reduce_seq() {
}

let res = tp.par_map(vec, |el| -> (i32, i32) { (el, 1) });
let res = tp.par_reduce(res, |k, v| (k, v.iter().sum::<i32>()));
let res = tp.par_reduce_by_key(res, |k, v| (k, v.iter().sum::<i32>()));

let mut check = true;
for (k, v) in res {
Expand Down

0 comments on commit 75c24ae

Please sign in to comment.