diff --git a/sleap/nn/paf_grouping.py b/sleap/nn/paf_grouping.py index 4d763375e..4a6a0a157 100644 --- a/sleap/nn/paf_grouping.py +++ b/sleap/nn/paf_grouping.py @@ -1310,13 +1310,13 @@ def toposort_edges(edge_types: List[EdgeType]) -> Tuple[int]: See also: assign_connections_to_instances """ - dg = nx.DiGraph( - [(edge_type.src_node_ind, edge_type.dst_node_ind) for edge_type in edge_types] - ) - lg = nx.line_graph(dg) - sorted_dg = list(nx.topological_sort(lg)) - lg = list(lg) - sorted_edge_inds = tuple([lg.index(edge) for edge in sorted_dg]) + edges = [ + (edge_type.src_node_ind, edge_type.dst_node_ind) for edge_type in edge_types + ] + dg = nx.DiGraph(edges) + root_ind = next(nx.topological_sort(dg)) + sorted_edges = nx.bfs_edges(dg, root_ind) + sorted_edge_inds = tuple([edges.index(edge) for edge in sorted_edges]) return sorted_edge_inds diff --git a/tests/nn/test_paf_grouping.py b/tests/nn/test_paf_grouping.py index 7fb66ab91..4856c1fed 100644 --- a/tests/nn/test_paf_grouping.py +++ b/tests/nn/test_paf_grouping.py @@ -322,6 +322,22 @@ def test_toposort_edges(): sorted_edge_inds = toposort_edges(edge_types) assert sorted_edge_inds == (12, 13, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11) + edge_inds = [ + (1, 4), + (1, 5), + (6, 8), + (6, 7), + (6, 9), + (9, 10), + (1, 0), + (1, 3), + (1, 2), + (6, 1), + ] + edge_types = [EdgeType(src_node, dst_node) for src_node, dst_node in edge_inds] + sorted_edge_inds = toposort_edges(edge_types) + assert sorted_edge_inds == (2, 3, 4, 9, 5, 0, 1, 6, 7, 8) + def test_assign_connections_to_instances(): connections = {