From 24a046111bfc7373714d8aa8b1b6dcd141deb47b Mon Sep 17 00:00:00 2001 From: Philipp Schlegel Date: Tue, 7 Jan 2025 10:30:48 +0000 Subject: [PATCH] Python: generate_segments returns segment lengths if weights not None --- py/navis_fastcore/dag.py | 15 +++++++++++---- py/src/dag.rs | 6 +++--- py/tests/test_fastcore.py | 2 +- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/py/navis_fastcore/dag.py b/py/navis_fastcore/dag.py index 44c84b2..0eb2ec0 100644 --- a/py/navis_fastcore/dag.py +++ b/py/navis_fastcore/dag.py @@ -36,6 +36,9 @@ def generate_segments(node_ids, parent_ids, weights=None): Segments as list of arrays, sorted from longest to shortest. Each segment starts with a leaf and stops with a branch point or root node. + lengths : array | None + If `weights` is provided this will be an array of segment + lengths. If `weights` is not provided this will be ``None``. Examples -------- @@ -43,7 +46,8 @@ def generate_segments(node_ids, parent_ids, weights=None): >>> import numpy as np >>> node_ids = np.arange(7) >>> parent_ids = np.array([-1, 0, 1, 2, 1, 4, 5]) - >>> fastcore.generate_segments(node_ids, parent_ids) + >>> segs, _ = fastcore.generate_segments(node_ids, parent_ids) + >>> segs [array([6, 5, 4, 1, 0]), array([3, 2, 1])] """ @@ -58,12 +62,15 @@ def generate_segments(node_ids, parent_ids, weights=None): ), "`weights` must have the same length as `node_ids`" # Get the segments (this will be a list of arrays of node indices) - segments = _fastcore.generate_segments(parent_ix, weights=weights) + segments, lengths = _fastcore.generate_segments(parent_ix, weights=weights) + + if lengths is not None: + lengths = np.asarray(lengths, dtype=np.float32) # Map node indices back to IDs seg_ids = [node_ids[s] for s in segments] - return seg_ids + return seg_ids, lengths def break_segments(node_ids, parent_ids): @@ -163,7 +170,7 @@ def segment_coords( ), "`weights` must have the same length as `node_ids`" # Get the segments (this will be a list of arrays of node indices) - segments = _fastcore.generate_segments(parent_ix, weights=weights) + segments, _ = _fastcore.generate_segments(parent_ix, weights=weights) # Translate into coordinates seg_coords = [coords[s] for s in segments] diff --git a/py/src/dag.rs b/py/src/dag.rs index ec54239..77f452f 100644 --- a/py/src/dag.rs +++ b/py/src/dag.rs @@ -106,16 +106,16 @@ pub fn node_indices_32<'py>( pub fn generate_segments_py( parents: PyReadonlyArray1, weights: Option>, -) -> Vec> { +) -> (Vec>, Option>) { let weights: Option> = if weights.is_some() { Some(weights.unwrap().as_array().to_owned()) } else { None }; - let all_segments = generate_segments(&parents.as_array(), weights); + let (all_segments, lengths) = generate_segments(&parents.as_array(), weights); - all_segments + (all_segments, lengths) } #[pyfunction] diff --git a/py/tests/test_fastcore.py b/py/tests/test_fastcore.py index f217b35..115d144 100644 --- a/py/tests/test_fastcore.py +++ b/py/tests/test_fastcore.py @@ -96,7 +96,7 @@ def test_node_indices(swc): def test_generate_segments(swc, weights): nodes, parents, _ = swc start = time.time() - segments = fastcore.generate_segments(nodes, parents, weights=weights) + segments, lengths = fastcore.generate_segments(nodes, parents, weights=weights) dur = time.time() - start # print("Segments:", segments)