Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(3c -> 3b) Add method to test instance grouping #1599

44 changes: 44 additions & 0 deletions sleap/gui/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -3751,6 +3751,50 @@ def get_all_views_at_frame(

return views

@staticmethod
def get_instance_grouping(
instances: Dict[int, Dict[Camcorder, List[Instance]]],
reprojection_error_per_frame: Dict[int, float],
) -> Dict[int, Dict[Camcorder, List[Instance]]]:
"""Get instance grouping for triangulation."""

frame_with_min_error = min(
reprojection_error_per_frame, key=reprojection_error_per_frame.get
)

best_instances = instances[frame_with_min_error]
best_instances_correct_format = {frame_with_min_error: best_instances}

return best_instances_correct_format

@staticmethod
def calculate_reprojection_per_frame(
session: RecordingSession, instances: Dict[int, Dict[Camcorder, List[Instance]]]
) -> Dict[int, float]:
"""Calculate reprojection error per frame."""

reprojection_error_per_frame = {}

# Triangulate and reproject instance coordinates.
instances_and_coords: Dict[
int, Dict[Camcorder, Iterator[Tuple[Instance, np.ndarray]]]
] = TriangulateSession.calculate_reprojected_points(
session=session, instances=instances
)
for frame_id, instances_in_frame in instances_and_coords.items():
frame_error = 0
for cam, instances_in_view in instances_in_frame.items():
# Compare instance coordinates here
view_error = 0
for inst, inst_coords in instances_in_view:
node_errors = np.nan_to_num(inst.numpy() - inst_coords)
instance_error = np.linalg.norm(node_errors)
view_error += instance_error
frame_error += view_error
reprojection_error_per_frame[frame_id] = frame_error

return reprojection_error_per_frame
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The method calculate_reprojection_per_frame calculates the reprojection error for each frame, which seems correct. However, there are a few points to consider:

  • The use of np.nan_to_num in line 3790 may mask potential issues with the data. If NaN values are expected and normal, this is fine, but if they indicate a problem with the data or computation, it might be better to handle them explicitly.
  • The calculation of instance_error in line 3791 does not square the difference before taking the norm, which is typical in reprojection error calculations. If the intention is to calculate the Euclidean distance, this is correct. Otherwise, consider squaring the differences.
  • Ensure that the np.linalg.norm function is the correct choice for calculating the instance error. If you're looking for a sum of squared differences, you might need to square the node_errors before summing.
  • There is no check for empty instances_in_view. If there are no instances in view for a particular camera, this could lead to a division by zero or other unexpected behavior when calculating averages or normalizing errors.


@staticmethod
def get_permutations_of_instances(
selected_instance: Instance,
Expand Down
62 changes: 62 additions & 0 deletions tests/gui/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -1366,3 +1366,65 @@ def product(seq):
for inst in instances_in_view:
assert inst.frame_idx == selected_instance.frame_idx
assert inst.video == session[cam]


def test_triangulate_session_calculate_reprojection_per_frame(
multiview_min_session_labels: Labels,
):
"""Test `TriangulateSession.get_permutations_of_instances`."""

labels = multiview_min_session_labels
session = labels.sessions[0]
lf = labels.labeled_frames[0]
selected_instance = lf.instances[0]

instances = TriangulateSession.get_permutations_of_instances(
selected_instance=selected_instance,
session=session,
frame_idx=lf.frame_idx,
)

reprojection_error_per_frame = TriangulateSession.calculate_reprojection_per_frame(
session=session, instances=instances
)

for frame_id in instances.keys():
assert frame_id in reprojection_error_per_frame
assert isinstance(reprojection_error_per_frame[frame_id], float)


def test_triangulate_session_get_instance_grouping(
multiview_min_session_labels: Labels,
):
"""Test `TriangulateSession.get_permutations_of_instances`."""

labels = multiview_min_session_labels
session = labels.sessions[0]
lf = labels.labeled_frames[0]
selected_instance = lf.instances[0]

instances = TriangulateSession.get_permutations_of_instances(
selected_instance=selected_instance,
session=session,
frame_idx=lf.frame_idx,
)

reprojection_error_per_frame = TriangulateSession.calculate_reprojection_per_frame(
session=session, instances=instances
)

best_instances: Dict[
int, Dict[Camcorder, Instance]
] = TriangulateSession.get_instance_grouping(
instances=instances, reprojection_error_per_frame=reprojection_error_per_frame
)
assert len(best_instances) == 1
for frame_id, instances_in_frame in best_instances.items():
for cam, instances_in_view in instances_in_frame.items():
for inst in instances_in_view:
assert inst.frame_idx == selected_instance.frame_idx
assert inst.track == selected_instance.track


if __name__ == "__main__":
pytest.main([f"{__file__}::test_triangulate_session_get_instance_grouping"])
Loading