Skip to content

Commit

Permalink
Python: generate_segments returns segment lengths if weights not None
Browse files Browse the repository at this point in the history
  • Loading branch information
schlegelp committed Jan 7, 2025
1 parent 98b3d7c commit 24a0461
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 8 deletions.
15 changes: 11 additions & 4 deletions py/navis_fastcore/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,18 @@ def generate_segments(node_ids, parent_ids, weights=None):
Segments as list of arrays, sorted from longest to shortest.
Each segment starts with a leaf and stops with a branch point
or root node.
lengths : array | None
If `weights` is provided this will be an array of segment
lengths. If `weights` is not provided this will be ``None``.
Examples
--------
>>> import navis_fastcore as fastcore
>>> import numpy as np
>>> node_ids = np.arange(7)
>>> parent_ids = np.array([-1, 0, 1, 2, 1, 4, 5])
>>> fastcore.generate_segments(node_ids, parent_ids)
>>> segs, _ = fastcore.generate_segments(node_ids, parent_ids)
>>> segs
[array([6, 5, 4, 1, 0]), array([3, 2, 1])]
"""
Expand All @@ -58,12 +62,15 @@ def generate_segments(node_ids, parent_ids, weights=None):
), "`weights` must have the same length as `node_ids`"

# Get the segments (this will be a list of arrays of node indices)
segments = _fastcore.generate_segments(parent_ix, weights=weights)
segments, lengths = _fastcore.generate_segments(parent_ix, weights=weights)

if lengths is not None:
lengths = np.asarray(lengths, dtype=np.float32)

# Map node indices back to IDs
seg_ids = [node_ids[s] for s in segments]

return seg_ids
return seg_ids, lengths


def break_segments(node_ids, parent_ids):
Expand Down Expand Up @@ -163,7 +170,7 @@ def segment_coords(
), "`weights` must have the same length as `node_ids`"

# Get the segments (this will be a list of arrays of node indices)
segments = _fastcore.generate_segments(parent_ix, weights=weights)
segments, _ = _fastcore.generate_segments(parent_ix, weights=weights)

# Translate into coordinates
seg_coords = [coords[s] for s in segments]
Expand Down
6 changes: 3 additions & 3 deletions py/src/dag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,16 +106,16 @@ pub fn node_indices_32<'py>(
pub fn generate_segments_py(
parents: PyReadonlyArray1<i32>,
weights: Option<PyReadonlyArray1<f32>>,
) -> Vec<Vec<i32>> {
) -> (Vec<Vec<i32>>, Option<Vec<f32>>) {
let weights: Option<Array1<f32>> = if weights.is_some() {
Some(weights.unwrap().as_array().to_owned())
} else {
None
};

let all_segments = generate_segments(&parents.as_array(), weights);
let (all_segments, lengths) = generate_segments(&parents.as_array(), weights);

all_segments
(all_segments, lengths)
}

#[pyfunction]
Expand Down
2 changes: 1 addition & 1 deletion py/tests/test_fastcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_node_indices(swc):
def test_generate_segments(swc, weights):
nodes, parents, _ = swc
start = time.time()
segments = fastcore.generate_segments(nodes, parents, weights=weights)
segments, lengths = fastcore.generate_segments(nodes, parents, weights=weights)
dur = time.time() - start

# print("Segments:", segments)
Expand Down

0 comments on commit 24a0461

Please sign in to comment.