Skip to content

Commit

Permalink
Add edges to analysis h5 (#707)
Browse files Browse the repository at this point in the history
* Add edge names and edge indices to analysis h5

* Add test for node names and edge names in analysis h5 export
  • Loading branch information
roomrys authored Apr 15, 2022
1 parent 13ba3f6 commit 487b3dd
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 1 deletion.
10 changes: 10 additions & 0 deletions sleap/info/write_tracking_h5.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ def get_nodes_as_np_strings(labels: Labels) -> List[np.string_]:
return [np.string_(node.name) for node in labels.skeletons[0].nodes]


def get_edges_as_np_strings(labels: Labels) -> List[Tuple[np.string_, np.string_]]:
"""Get list of edge names as `np.string_`."""
return [
(np.string_(src_name), np.string_(dst_name))
for (src_name, dst_name) in labels.skeletons[0].edge_names
]


def get_occupancy_and_points_matrices(
labels: Labels, all_frames: bool
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
Expand Down Expand Up @@ -252,6 +260,8 @@ def main(labels: Labels, output_path: str, all_frames: bool = True):
data_dict = dict(
track_names=track_names,
node_names=get_nodes_as_np_strings(labels),
edge_names=get_edges_as_np_strings(labels),
edge_inds=labels.skeletons[0].edge_inds,
tracks=locations_matrix,
track_occupancy=occupancy_matrix,
point_scores=point_scores,
Expand Down
39 changes: 38 additions & 1 deletion tests/info/test_h5.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,52 @@
get_occupancy_and_points_matrices,
remove_empty_tracks_from_matrices,
write_occupancy_file,
get_nodes_as_np_strings,
get_edges_as_np_strings,
)
from sleap.io.dataset import Labels


def test_output_matrices(centered_pair_predictions):
def test_output_matrices(centered_pair_predictions: Labels):

names = get_tracks_as_np_strings(centered_pair_predictions)
assert len(names) == 27
assert isinstance(names[0], np.string_)

# Check that node names and edges are read correctly
node_names = [
n.decode() for n in get_nodes_as_np_strings(centered_pair_predictions)
]
edge_names = [
(s.decode(), d.decode())
for (s, d) in get_edges_as_np_strings(centered_pair_predictions)
]

assert node_names[0] == "head"
assert node_names[1] == "neck"
assert node_names[2] == "thorax"
assert node_names[3] == "abdomen"
assert node_names[4] == "wingL"
assert node_names[5] == "wingR"
assert node_names[6] == "forelegL1"
assert node_names[7] == "forelegL2"
assert node_names[8] == "forelegL3"
assert node_names[9] == "forelegR1"
assert node_names[10] == "forelegR2"
assert node_names[11] == "forelegR3"
assert node_names[12] == "midlegL1"
assert node_names[13] == "midlegL2"
assert node_names[14] == "midlegL3"
assert node_names[15] == "midlegR1"
assert node_names[16] == "midlegR2"
assert node_names[17] == "midlegR3"

# Both lines check edge_names are read correctly, but latter is used in bento plugin
assert edge_names == centered_pair_predictions.skeleton.edge_names
for (src_node, dst_node) in edge_names:
assert src_node in node_names
assert dst_node in node_names

# Remove the first labeled frame
del centered_pair_predictions[0]
assert len(centered_pair_predictions) == 1099
Expand Down

0 comments on commit 487b3dd

Please sign in to comment.