Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add edges to analysis h5 #707

Merged
merged 4 commits into from
Apr 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions sleap/info/write_tracking_h5.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ def get_nodes_as_np_strings(labels: Labels) -> List[np.string_]:
return [np.string_(node.name) for node in labels.skeletons[0].nodes]


def get_edges_as_np_strings(labels: Labels) -> List[Tuple[np.string_, np.string_]]:
"""Get list of edge names as `np.string_`."""
return [
(np.string_(src_name), np.string_(dst_name))
for (src_name, dst_name) in labels.skeletons[0].edge_names
]


def get_occupancy_and_points_matrices(
labels: Labels, all_frames: bool
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
Expand Down Expand Up @@ -252,6 +260,8 @@ def main(labels: Labels, output_path: str, all_frames: bool = True):
data_dict = dict(
track_names=track_names,
node_names=get_nodes_as_np_strings(labels),
edge_names=get_edges_as_np_strings(labels),
roomrys marked this conversation as resolved.
Show resolved Hide resolved
edge_inds=labels.skeletons[0].edge_inds,
tracks=locations_matrix,
track_occupancy=occupancy_matrix,
point_scores=point_scores,
Expand Down
39 changes: 38 additions & 1 deletion tests/info/test_h5.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,52 @@
get_occupancy_and_points_matrices,
remove_empty_tracks_from_matrices,
write_occupancy_file,
get_nodes_as_np_strings,
get_edges_as_np_strings,
)
from sleap.io.dataset import Labels


def test_output_matrices(centered_pair_predictions):
def test_output_matrices(centered_pair_predictions: Labels):

names = get_tracks_as_np_strings(centered_pair_predictions)
assert len(names) == 27
assert isinstance(names[0], np.string_)

# Check that node names and edges are read correctly
node_names = [
n.decode() for n in get_nodes_as_np_strings(centered_pair_predictions)
]
edge_names = [
(s.decode(), d.decode())
for (s, d) in get_edges_as_np_strings(centered_pair_predictions)
]

assert node_names[0] == "head"
assert node_names[1] == "neck"
assert node_names[2] == "thorax"
assert node_names[3] == "abdomen"
assert node_names[4] == "wingL"
assert node_names[5] == "wingR"
assert node_names[6] == "forelegL1"
assert node_names[7] == "forelegL2"
assert node_names[8] == "forelegL3"
assert node_names[9] == "forelegR1"
assert node_names[10] == "forelegR2"
assert node_names[11] == "forelegR3"
assert node_names[12] == "midlegL1"
assert node_names[13] == "midlegL2"
assert node_names[14] == "midlegL3"
assert node_names[15] == "midlegR1"
assert node_names[16] == "midlegR2"
assert node_names[17] == "midlegR3"

# Both lines check edge_names are read correctly, but latter is used in bento plugin
assert edge_names == centered_pair_predictions.skeleton.edge_names
for (src_node, dst_node) in edge_names:
assert src_node in node_names
assert dst_node in node_names
roomrys marked this conversation as resolved.
Show resolved Hide resolved

# Remove the first labeled frame
del centered_pair_predictions[0]
assert len(centered_pair_predictions) == 1099
Expand Down