Skip to content

Commit

Permalink
Add method and (failing) test to get instance grouping
Browse files Browse the repository at this point in the history
  • Loading branch information
roomrys committed Nov 15, 2023
1 parent 7fd89ec commit b9af5c3
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 0 deletions.
44 changes: 44 additions & 0 deletions sleap/gui/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -3751,6 +3751,50 @@ def get_all_views_at_frame(

return views

@staticmethod
def get_instance_grouping(
instances: Dict[int, Dict[Camcorder, List[Instance]]],
reprojection_error_per_frame: Dict[int, float],
) -> Dict[int, Dict[Camcorder, List[Instance]]]:
"""Get instance grouping for triangulation."""

frame_with_min_error = min(
reprojection_error_per_frame, key=reprojection_error_per_frame.get
)

best_instances = instances[frame_with_min_error]
best_instances_correct_format = {frame_with_min_error: best_instances}

return best_instances_correct_format

@staticmethod
def calculate_reprojection_per_frame(
session: RecordingSession, instances: Dict[int, Dict[Camcorder, List[Instance]]]
) -> Dict[int, float]:
"""Calculate reprojection error per frame."""

reprojection_error_per_frame = {}

# Triangulate and reproject instance coordinates.
instances_and_coords: Dict[
int, Dict[Camcorder, Iterator[Tuple[Instance, np.ndarray]]]
] = TriangulateSession.calculate_reprojected_points(
session=session, instances=instances
)
for frame_id, instances_in_frame in instances_and_coords.items():
frame_error = 0
for cam, instances_in_view in instances_in_frame.items():
# Compare instance coordinates here
view_error = 0
for inst, inst_coords in instances_in_view:
node_errors = np.nan_to_num(inst.numpy() - inst_coords)
instance_error = np.linalg.norm(node_errors)
view_error += instance_error
frame_error += view_error
reprojection_error_per_frame[frame_id] = frame_error

return reprojection_error_per_frame

@staticmethod
def get_permutations_of_instances(
selected_instance: Instance,
Expand Down
62 changes: 62 additions & 0 deletions tests/gui/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -1366,3 +1366,65 @@ def product(seq):
for inst in instances_in_view:
assert inst.frame_idx == selected_instance.frame_idx
assert inst.video == session[cam]


def test_triangulate_session_calculate_reprojection_per_frame(
multiview_min_session_labels: Labels,
):
"""Test `TriangulateSession.get_permutations_of_instances`."""

labels = multiview_min_session_labels
session = labels.sessions[0]
lf = labels.labeled_frames[0]
selected_instance = lf.instances[0]

instances = TriangulateSession.get_permutations_of_instances(
selected_instance=selected_instance,
session=session,
frame_idx=lf.frame_idx,
)

reprojection_error_per_frame = TriangulateSession.calculate_reprojection_per_frame(
session=session, instances=instances
)

for frame_id in instances.keys():
assert frame_id in reprojection_error_per_frame
assert isinstance(reprojection_error_per_frame[frame_id], float)


def test_triangulate_session_get_instance_grouping(
multiview_min_session_labels: Labels,
):
"""Test `TriangulateSession.get_permutations_of_instances`."""

labels = multiview_min_session_labels
session = labels.sessions[0]
lf = labels.labeled_frames[0]
selected_instance = lf.instances[0]

instances = TriangulateSession.get_permutations_of_instances(
selected_instance=selected_instance,
session=session,
frame_idx=lf.frame_idx,
)

reprojection_error_per_frame = TriangulateSession.calculate_reprojection_per_frame(
session=session, instances=instances
)

best_instances: Dict[
int, Dict[Camcorder, Instance]
] = TriangulateSession.get_instance_grouping(
instances=instances, reprojection_error_per_frame=reprojection_error_per_frame
)
assert len(best_instances) == 1
for frame_id, instances_in_frame in best_instances.items():
for cam, instances_in_view in instances_in_frame.items():
for inst in instances_in_view:
assert inst.frame_idx == selected_instance.frame_idx
assert inst.track == selected_instance.track


if __name__ == "__main__":
pytest.main([f"{__file__}::test_triangulate_session_get_instance_grouping"])

0 comments on commit b9af5c3

Please sign in to comment.