Skip to content

Commit

Permalink
Rust: improve generate_segments
Browse files Browse the repository at this point in the history
- sort segments by physical length if weights are provided
- return segment lengths (Option)
  • Loading branch information
schlegelp committed Jan 7, 2025
1 parent c25e47f commit 98b3d7c
Showing 1 changed file with 28 additions and 4 deletions.
32 changes: 28 additions & 4 deletions fastcore/src/dag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,16 @@ fn find_roots(parents: &ArrayView1<i32>) -> Vec<i32> {
///
/// A vector of vectors where each vector contains the nodes of a segment.
///
pub fn generate_segments<T>(parents: &ArrayView1<i32>, weights: Option<Array1<T>>) -> Vec<Vec<i32>>
pub fn generate_segments<T>(parents: &ArrayView1<i32>, weights: Option<Array1<T>>) -> (Vec<Vec<i32>>, Option<Vec<T>>)
where
T: Float + AddAssign,
T: Float + AddAssign + std::iter::Sum + std::fmt::Debug,
{
let mut all_segments: Vec<Vec<i32>> = vec![];
let mut current_segment = Array::from_elem(parents.len(), -1i32);
let mut seen = Array::from_elem(parents.len(), false);
let mut i: usize;
let mut node: i32;
let mut lengths: Option<Vec<T>> = None;

let weights: Option<Array1<T>> = if weights.is_some() {
Some(weights.unwrap().to_owned())
Expand Down Expand Up @@ -186,8 +187,31 @@ where
all_segments.push(current_segment.slice(s![..i]).iter().cloned().collect());
}
// println!("Found {} segments", all_segments.len());
all_segments.sort_by(|a, b| b.len().cmp(&a.len()));
all_segments

// If no weights, we can just sort by length
if weights.is_none() {
all_segments.sort_by(|a, b| b.len().cmp(&a.len()));
} else {
// If weights are provided we need to sort by the sum of the weights
let weights = weights.unwrap();
lengths = Some(all_segments
.iter()
.map(|segment| {
segment
.iter()
.map(|&node| weights[node as usize])
.sum::<T>()
})
.collect());
let lengths_unwrapped = lengths.as_ref().unwrap();
// Generate indices for sorting
let mut indices: Vec<usize> = (0..all_segments.len()).collect();
// Sort indices by the lengths
indices.sort_by(|a, b| lengths_unwrapped[*b].partial_cmp(&lengths_unwrapped[*a]).unwrap());
// Sort the segments by the sorted indices
all_segments = indices.iter().map(|&i| all_segments[i].clone()).collect();
}
(all_segments, lengths)
}

/// Break neuron into linear segments connecting leafs, branch points and root(s).
Expand Down

0 comments on commit 98b3d7c

Please sign in to comment.