From 84dc0eb53366b1524da9776c2e5a7f2151ea7b55 Mon Sep 17 00:00:00 2001 From: ramiz hajj <122569822+ramizhajj1@users.noreply.github.com> Date: Wed, 1 May 2024 16:51:50 -0700 Subject: [PATCH 1/2] (4a -> 4) Add menu to assign an `Instance` to an `InstanceGroup` (#1747) * Implement session menu and instance group functionality * Add instance group management to sessions menu and update related commands * Add Delete Instance Group * Fix obvious errors (but no debug tests yet) * Test `SetSelectedInstanceGroup` class * Add fixes for `SetSelectedInstances` test * Test AddInstanceGroup class * Change AddInstanceGroup to pass tests * Add test for DeleteInstanceGroup class * Better the DeleteInstanceGroup error message --------- Co-authored-by: Liezl Maree <38435167+roomrys@users.noreply.github.com> --- sleap/gui/app.py | 58 ++++++- sleap/gui/commands.py | 117 +++++++++++++ sleap/io/cameras.py | 42 ++++- tests/gui/test_commands.py | 348 +++++++++++++++++++++++++++++++++++++ 4 files changed, 555 insertions(+), 10 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 5a479a2b2..25a07385d 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -86,7 +86,6 @@ from sleap.skeleton import Skeleton from sleap.util import parse_uri_path - logger = getLogger(__name__) @@ -810,9 +809,18 @@ def new_instance_menu_action(): self.commands.deleteFrameLimitPredictions, ) + ### Sessions Menu ### + + sessionsMenu = self.menuBar().addMenu("Sessions") + + self.inst_groups_menu = sessionsMenu.addMenu("Set Instance Group") + self.inst_groups_delete_menu = sessionsMenu.addMenu("Delete Instance Group") + self.state.connect("frame_idx", self._update_sessions_menu) + ### Tracks Menu ### tracksMenu = self.menuBar().addMenu("Tracks") + self.track_menu = tracksMenu.addMenu("Set Instance Track") add_menu_check_item( tracksMenu, "propagate track labels", "Propagate Track Labels" @@ -1120,6 +1128,7 @@ def _update_gui_state(self): # Update menus + self.inst_groups_menu.setEnabled(has_selected_instance) self.track_menu.setEnabled(has_selected_instance) self.delete_tracks_menu.setEnabled(has_tracks) self._menu_actions["clear selection"].setEnabled(has_selected_instance) @@ -1253,10 +1262,12 @@ def _has_topic(topic_list): if _has_topic([UpdateTopic.frame, UpdateTopic.project_instances]): self.state["last_interacted_frame"] = self.state["labeled_frame"] + self._update_sessions_menu() if _has_topic([UpdateTopic.sessions]): self.update_cameras_model() self.update_unlinked_videos_model() + self._update_sessions_menu() def update_unlinked_videos_model(self): """Update the unlinked videos model with the selected session.""" @@ -1411,6 +1422,51 @@ def _update_track_menu(self): "New Track", self.commands.addTrack, Qt.CTRL + Qt.Key_0 ) + def _update_sessions_menu(self): + """Update the instance groups menu based on the frame index.""" + + # Clear menus before adding more items + self.inst_groups_menu.clear() + self.inst_groups_delete_menu.clear() + + # Get the session + session = self.state.get("session") + if session is None: + return + + # Get the frame group for the current frame + frame_idx = self.state["frame_idx"] + frame_group = session.frame_groups.get(frame_idx, None) + if frame_group is not None: + for inst_group_ind, instance_group in enumerate( + frame_group.instance_groups + ): + # Create shortcut key for first 9 groups + key_command = "" + if inst_group_ind < 9: + key_command = Qt.SHIFT + Qt.Key_0 + inst_group_ind + 1 + + # Update the Set Instance Group menu + self.inst_groups_menu.addAction( + instance_group.name, + lambda x=instance_group: self.commands.setInstanceGroup(x), + key_command, + ) + + # Update the Delete Instance Group menu + self.inst_groups_delete_menu.addAction( + instance_group.name, + lambda x=instance_group: self.commands.deleteInstanceGroup( + instance_group=x + ), + ) + + self.inst_groups_menu.addAction( + "New Instance Group", + self.commands.addInstanceGroup, + Qt.SHIFT + Qt.Key_0, + ) + def _update_seekbar_marks(self): """Updates marks on seekbar.""" set_slider_marks_from_labels( diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index 8839369e2..247b93360 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -593,6 +593,14 @@ def setInstanceTrack(self, new_track: "Track"): """Sets track for selected instance.""" self.execute(SetSelectedInstanceTrack, new_track=new_track) + def addInstanceGroup(self): + """Sets the instance group for selected instance.""" + self.execute(AddInstanceGroup) + + def setInstanceGroup(self, instance_group: Optional["InstanceGroup"]): + """Sets the instance group for selected instance.""" + self.execute(SetSelectedInstanceGroup, instance_group=instance_group) + def deleteTrack(self, track: "Track"): """Delete a track and remove from all instances.""" self.execute(DeleteTrack, track=track) @@ -601,6 +609,10 @@ def deleteMultipleTracks(self, delete_all: bool = False): """Delete all tracks.""" self.execute(DeleteMultipleTracks, delete_all=delete_all) + def deleteInstanceGroup(self, instance_group: "InstanceGroup"): + """Delete an instance group.""" + self.execute(DeleteInstanceGroup, instance_group=instance_group) + def copyInstanceTrack(self): """Copies the selected instance's track to the track clipboard.""" self.execute(CopyInstanceTrack) @@ -2690,6 +2702,30 @@ def ask_and_do(context: CommandContext, params: dict): context.signal_update([UpdateTopic.project_instances]) +class AddInstanceGroup(EditCommand): + topics = [UpdateTopic.sessions] + + @staticmethod + def do_action(context, params): + + # Get session and frame index + frame_idx = context.state["frame_idx"] + session: RecordingSession = context.state["session"] + if session is None: + raise ValueError("Cannot add instance group without session.") + + # Get or create frame group + frame_group = session.frame_groups.get(frame_idx, None) + if frame_group is None: + frame_group = session.new_frame_group(frame_idx=frame_idx) + + # Create and add instance group + instance_group = frame_group.add_instance_group(instance_group=None) + + # Now add the selected instance to the `InstanceGroup` + context.execute(SetSelectedInstanceGroup, instance_group=instance_group) + + class AddTrack(EditCommand): topics = [UpdateTopic.tracks] @@ -2706,6 +2742,61 @@ def do_action(context: CommandContext, params: dict): context.execute(SetSelectedInstanceTrack, new_track=new_track) +class SetSelectedInstanceGroup(EditCommand): + @staticmethod + def do_action(context, params): + """Set the `selected_instance` to the `instance_group`. + + Args: + context: The command context. + state: The context state. + instance: The selected instance. + frame_idx: The frame index. + video: The video. + session: The recording session. + + params: The command parameters. + instance_group: The `InstanceGroup` to set the selected instance to. + + Raises: + ValueError: If the `RecordingSession` is None. + ValueError: If the `FrameGroup` does not exist for the frame index. + ValueError: If the `Video` is not linked to a `Camcorder`. + """ + + selected_instance = context.state["instance"] + frame_idx = context.state["frame_idx"] + video = context.state["video"] + + base_message = ( + f"Cannot set instance group for selected instance [{selected_instance}]." + ) + + # `RecordingSession` should not be None + session: RecordingSession = context.state["session"] + if session is None: + raise ValueError(f"{base_message} No session for video [{video}]") + + # `FrameGroup` should already exist + frame_group = session.frame_groups.get(frame_idx, None) + if frame_group is None: + raise ValueError( + f"{base_message} Frame group does not exist for frame [{frame_idx}] in " + f"{session}." + ) + + # We need the camera and instance group to set the instance group + camera = session.get_camera(video=video) + if camera is None: + raise ValueError(f"{base_message} No camera linked to video [{video}]") + instance_group = params["instance_group"] + + # Set the instance group + frame_group.add_instance( + instance=selected_instance, camera=camera, instance_group=instance_group + ) + + class SetSelectedInstanceTrack(EditCommand): topics = [UpdateTopic.tracks] @@ -2795,6 +2886,31 @@ def do_action(context: CommandContext, params: dict): context.labels.remove_unused_tracks() +class DeleteInstanceGroup(EditCommand): + topics = [UpdateTopic.sessions] + + @staticmethod + def do_action(context, params): + + instance_group = params["instance_group"] + frame_idx = context.state["frame_idx"] + + base_message = f"Cannot delete instance group [{instance_group}]." + + # `RecordingSession` should not be None + session: RecordingSession = context.state["session"] + if session is None: + raise ValueError(f"{base_message} No session in context state.") + + # `FrameGroup` should already exist + frame_group = session.frame_groups.get(frame_idx, None) + if frame_group is None: + raise ValueError(f"{base_message} No frame group for frame {frame_idx}.") + + # Remove the instance group + frame_group.remove_instance_group(instance_group=instance_group) + + class CopyInstanceTrack(EditCommand): @staticmethod def do_action(context: CommandContext, params: dict): @@ -3376,6 +3492,7 @@ def add_nodes_from_template( def add_force_directed_nodes( cls, context, instance, visible, center_point: QtCore.QPoint = None ): + import networkx as nx center_point = center_point or context.app.player.getVisibleRect().center() diff --git a/sleap/io/cameras.py b/sleap/io/cameras.py index 4de2432db..0e227e159 100644 --- a/sleap/io/cameras.py +++ b/sleap/io/cameras.py @@ -645,7 +645,7 @@ def add_instance(self, cam: Camcorder, instance: Instance): ) # Add the instance to the `InstanceGroup` - self.replace_instance(cam, instance) + self.replace_instance(cam=cam, instance=instance) def replace_instance(self, cam: Camcorder, instance: Instance): """Replace an `Instance` in the `InstanceGroup`. @@ -668,6 +668,9 @@ def replace_instance(self, cam: Camcorder, instance: Instance): # Remove the instance if it already exists self.remove_instance(instance_or_cam=instance) + # Remove the instance currently at the cam (if any) + self.remove_instance(instance_or_cam=cam) + # Replace the instance in the `InstanceGroup` self._instance_by_camcorder[cam] = instance self._camcorder_by_instance[instance] = cam @@ -1571,7 +1574,7 @@ def __attrs_post_init__(self): for camera in self.session.camera_cluster.cameras: self._instances_by_cam[camera] = set() for instance_group in self.instance_groups: - self.add_instance_group(instance_group) + self.add_instance_group(instance_group=instance_group) @property def instance_groups(self) -> List[InstanceGroup]: @@ -1713,6 +1716,15 @@ def add_instance( # Add the `Instance` to the `InstanceGroup` if instance_group is not None: + # Remove any existing `Instance` in given `InstanceGroup` at same `Camcorder` + preexisting_instance = instance_group.get_instance(camera) + if preexisting_instance is not None: + self.remove_instance(instance=preexisting_instance) + + # Remove the `Instance` from the `FrameGroup` if it is already exists + self.remove_instance(instance=instance, remove_empty_instance_group=True) + + # Add the `Instance` to the `InstanceGroup` instance_group.add_instance(cam=camera, instance=instance) else: self._raise_if_instance_not_in_instance_group(instance=instance) @@ -1726,11 +1738,15 @@ def add_instance( labeled_frame = instance.frame self.add_labeled_frame(labeled_frame=labeled_frame, camera=camera) - def remove_instance(self, instance: Instance): + def remove_instance( + self, instance: Instance, remove_empty_instance_group: bool = False + ): """Removes an `Instance` from the `FrameGroup`. Args: instance: `Instance` to remove from the `FrameGroup`. + remove_empty_instance_group: If True, then remove the `InstanceGroup` if it + is empty. Default is False. """ instance_group = self.get_instance_group(instance=instance) @@ -1747,12 +1763,22 @@ def remove_instance(self, instance: Instance): instance_group.remove_instance(instance_or_cam=instance) # Remove the `Instance` from the `FrameGroup` - self._instances_by_cam[camera].remove(instance) + if instance in self._instances_by_cam[camera]: + self._instances_by_cam[camera].remove(instance) + else: + logger.debug( + f"Instance {instance} not found in this FrameGroup: " + f"{self._instances_by_cam[camera]}." + ) # Remove "empty" `LabeledFrame`s from the `FrameGroup` if len(self._instances_by_cam[camera]) < 1: self.remove_labeled_frame(labeled_frame_or_camera=camera) + # Remove the `InstanceGroup` if it is empty + if remove_empty_instance_group and len(instance_group.instances) < 1: + self.remove_instance_group(instance_group=instance_group) + def add_instance_group( self, instance_group: Optional[InstanceGroup] = None ) -> InstanceGroup: @@ -1819,11 +1845,9 @@ def remove_instance_group(self, instance_group: InstanceGroup): # Remove the `Instance`s from the `FrameGroup` for camera, instance in instance_group.instance_by_camcorder.items(): self._instances_by_cam[camera].remove(instance) - - # Remove the `LabeledFrame` from the `FrameGroup` - labeled_frame = self.get_labeled_frame(camera=camera) - if labeled_frame is not None: - self.remove_labeled_frame(labeled_frame_or_camera=camera) + # Remove the `LabeledFrame` if no more grouped instances + if len(self._instances_by_cam[camera]) < 1: + self.remove_labeled_frame(labeled_frame_or_camera=camera) # TODO(LM): maintain this as a dictionary for quick lookups def get_instance_group(self, instance: Instance) -> Optional[InstanceGroup]: diff --git a/tests/gui/test_commands.py b/tests/gui/test_commands.py index 68c8fb578..f3051c14a 100644 --- a/tests/gui/test_commands.py +++ b/tests/gui/test_commands.py @@ -1103,3 +1103,351 @@ def test_TriangulateSession_do_action(multiview_min_session_frame_groups): assert np.allclose(inst_group_np, inst_group_np_post_tri, equal_nan=True) # TODO(LM): Test with `PredictedInstance`s + + +def test_SetSelectedInstanceGroup(multiview_min_session_frame_groups: Labels): + """Test that setting a new instance group works.""" + + labels = multiview_min_session_frame_groups + session: RecordingSession = labels.sessions[0] + frame_idx = 0 + frame_group: FrameGroup = session.frame_groups[frame_idx] + labeled_frame: LabeledFrame = frame_group.labeled_frames[0] + video = labeled_frame.video + camera = session.get_camera(video=video) + + # We want to replace `instance_0` with `instance_1` in the `InstanceGroup` + instance_0 = labeled_frame.user_instances[0] + instance_group_0 = frame_group.get_instance_group(instance=instance_0) + instance_1 = labeled_frame.user_instances[1] + instance_group_1 = frame_group.get_instance_group(instance=instance_1) + + # Set-up CommandContext + context: CommandContext = CommandContext.from_labels(labels) + context.state["instance"] = instance_1 + context.state["video"] = video + + # No session + with pytest.raises(ValueError): + context.setInstanceGroup(instance_group=instance_group_0) + # Check FrameGroup._instances_by_camcorder + assert instance_0 in frame_group._instances_by_cam[camera] + assert instance_1 in frame_group._instances_by_cam[camera] + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 8 + assert len(instance_group_1.instances) == 6 + + # No frame_idx + context.state["session"] = session + with pytest.raises(ValueError): + context.setInstanceGroup(instance_group=instance_group_0) + # Check FrameGroup._instances_by_camcorder + assert instance_0 in frame_group._instances_by_cam[camera] + assert instance_1 in frame_group._instances_by_cam[camera] + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 8 + assert len(instance_group_1.instances) == 6 + + # With session and frame_idx + context.state["frame_idx"] = frame_idx + context.setInstanceGroup(instance_group=instance_group_0) + # Check FrameGroup.instance_groups + assert len(frame_group.instance_groups) == 2 + assert instance_group_0 in frame_group.instance_groups + assert instance_group_1 in frame_group.instance_groups + # Check FrameGroup._instances_by_camcorder + assert instance_0 not in frame_group._instances_by_cam[camera] + assert instance_1 in frame_group._instances_by_cam[camera] + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 8 + assert len(instance_group_1.instances) == 5 + assert instance_0 not in instance_group_0.instances + assert instance_0 not in instance_group_1.instances + assert instance_1 in instance_group_0.instances + assert instance_1 not in instance_group_1.instances + # Check InstanceGroup._camcorder_by_instance + assert instance_0 not in instance_group_0._camcorder_by_instance + assert instance_0 not in instance_group_1._camcorder_by_instance + assert instance_1 in instance_group_0._camcorder_by_instance + assert instance_1 not in instance_group_1._camcorder_by_instance + # Check InstanceGroup._instance_by_camcorder + assert instance_0 not in instance_group_0._instance_by_camcorder.values() + assert instance_0 not in instance_group_1._instance_by_camcorder.values() + assert instance_1 in instance_group_0._instance_by_camcorder.values() + assert instance_1 not in instance_group_1._instance_by_camcorder.values() + + # Let's move the instance to the other `InstanceGroup` + context.setInstanceGroup(instance_group=instance_group_1) + # Check FrameGroup.instance_groups + assert len(frame_group.instance_groups) == 2 + assert instance_group_0 in frame_group.instance_groups + assert instance_group_1 in frame_group.instance_groups + # Check FrameGroup._instances_by_camcorder + assert instance_0 not in frame_group._instances_by_cam[camera] + assert instance_1 in frame_group._instances_by_cam[camera] + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 7 + assert len(instance_group_1.instances) == 6 + assert instance_0 not in instance_group_0.instances + assert instance_0 not in instance_group_1.instances + assert instance_1 not in instance_group_0.instances + assert instance_1 in instance_group_1.instances + # Check InstanceGroup._camcorder_by_instance + assert instance_0 not in instance_group_0._camcorder_by_instance + assert instance_0 not in instance_group_1._camcorder_by_instance + assert instance_1 not in instance_group_0._camcorder_by_instance + assert instance_1 in instance_group_1._camcorder_by_instance + # Check InstanceGroup._instance_by_camcorder + assert instance_0 not in instance_group_0._instance_by_camcorder.values() + assert instance_0 not in instance_group_1._instance_by_camcorder.values() + assert instance_1 not in instance_group_0._instance_by_camcorder.values() + assert instance_1 in instance_group_1._instance_by_camcorder.values() + + # Let's move the other instance back to its original `InstanceGroup` + context.state["instance"] = instance_0 + context.setInstanceGroup(instance_group=instance_group_0) + # Check FrameGroup.instance_groups + assert len(frame_group.instance_groups) == 2 + assert instance_group_0 in frame_group.instance_groups + assert instance_group_1 in frame_group.instance_groups + # Check FrameGroup._instances_by_camcorder + assert instance_0 in frame_group._instances_by_cam[camera] + assert instance_1 in frame_group._instances_by_cam[camera] + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 8 + assert len(instance_group_1.instances) == 6 + assert instance_0 in instance_group_0.instances + assert instance_0 not in instance_group_1.instances + assert instance_1 not in instance_group_0.instances + assert instance_1 in instance_group_1.instances + # Check InstanceGroup._camcorder_by_instance + assert instance_0 in instance_group_0._camcorder_by_instance + assert instance_0 not in instance_group_1._camcorder_by_instance + assert instance_1 not in instance_group_0._camcorder_by_instance + assert instance_1 in instance_group_1._camcorder_by_instance + # Check InstanceGroup._instance_by_camcorder + assert instance_0 in instance_group_0._instance_by_camcorder.values() + assert instance_0 not in instance_group_1._instance_by_camcorder.values() + assert instance_1 not in instance_group_0._instance_by_camcorder.values() + assert instance_1 in instance_group_1._instance_by_camcorder.values() + + # Let's remove all but one instance from an `InstanceGroup` + for instance in instance_group_0.instances: + if instance == instance_0: + continue + frame_group.remove_instance(instance=instance) + assert len(instance_group_0.instances) == 1 + context.setInstanceGroup(instance_group=instance_group_0) + # Check FrameGroup.instance_groups + assert len(frame_group.instance_groups) == 2 + assert instance_group_0 in frame_group.instance_groups + assert instance_group_1 in frame_group.instance_groups + # Check FrameGroup._instances_by_camcorder + assert instance_0 in frame_group._instances_by_cam[camera] + assert instance_1 in frame_group._instances_by_cam[camera] + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 1 + assert len(instance_group_1.instances) == 6 + assert instance_0 in instance_group_0.instances + assert instance_0 not in instance_group_1.instances + assert instance_1 not in instance_group_0.instances + assert instance_1 in instance_group_1.instances + # Check InstanceGroup._camcorder_by_instance + assert instance_0 in instance_group_0._camcorder_by_instance + assert instance_0 not in instance_group_1._camcorder_by_instance + assert instance_1 not in instance_group_0._camcorder_by_instance + assert instance_1 in instance_group_1._camcorder_by_instance + # Check InstanceGroup._instance_by_camcorder + assert instance_0 in instance_group_0._instance_by_camcorder.values() + assert instance_0 not in instance_group_1._instance_by_camcorder.values() + assert instance_1 not in instance_group_0._instance_by_camcorder.values() + assert instance_1 in instance_group_1._instance_by_camcorder.values() + + # Let's switch the last instance to a different `InstanceGroup` + context.setInstanceGroup(instance_group=instance_group_1) + # Check FrameGroup.instance_groups + assert len(frame_group.instance_groups) == 1 + assert instance_group_1 in frame_group.instance_groups + # Check FrameGroup._instances_by_camcorder + assert instance_0 in frame_group._instances_by_cam[camera] + assert instance_1 not in frame_group._instances_by_cam[camera] + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 0 + assert len(instance_group_1.instances) == 6 + assert instance_0 in instance_group_1.instances + assert instance_1 not in instance_group_1.instances + # Check InstanceGroup._camcorder_by_instance + assert instance_0 in instance_group_1._camcorder_by_instance + assert instance_1 not in instance_group_1._camcorder_by_instance + # Check InstanceGroup._instance_by_camcorder + assert instance_0 in instance_group_1._instance_by_camcorder.values() + assert instance_1 not in instance_group_1._instance_by_camcorder.values() + + +def test_AddInstanceGroup(multiview_min_session_frame_groups: Labels): + """Test that adding an instance group works.""" + + labels = multiview_min_session_frame_groups + session: RecordingSession = labels.sessions[0] + frame_idx = 1 + frame_group: FrameGroup = session.frame_groups[frame_idx] + instance_group_0: InstanceGroup = frame_group.instance_groups[0] + instance_group_1: InstanceGroup = frame_group.instance_groups[1] + labeled_frame: LabeledFrame = frame_group.labeled_frames[0] + video = labeled_frame.video + camera = session.get_camera(video=video) + + # Set-up CommandContext + context: CommandContext = CommandContext.from_labels(labels) + + # No session + with pytest.raises(ValueError): + context.addInstanceGroup() + # Check FrameGroup._instances_by_camcorder + assert len(frame_group._instances_by_cam[camera]) == 2 + # Check FrameGroup.instance_groups + assert len(frame_group.instance_groups) == 2 + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 8 + assert len(instance_group_1.instances) == 6 + + # No frame_idx + context.state["session"] = session + with pytest.raises(TypeError): + context.addInstanceGroup() + # Check FrameGroup._instances_by_camcorder + assert len(frame_group._instances_by_cam[camera]) == 2 + # Check FrameGroup.instance_groups + assert len(frame_group.instance_groups) == 2 + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 8 + assert len(instance_group_1.instances) == 6 + + # No instance + context.state["frame_idx"] = frame_idx + with pytest.raises(ValueError): + context.addInstanceGroup() + # Check FrameGroup._instances_by_camcorder + assert len(frame_group._instances_by_cam[camera]) == 2 + # Check FrameGroup.instance_groups + instance_group_2 = frame_group.instance_groups[-1] + assert len(frame_group.instance_groups) == 3 + assert instance_group_2 in frame_group.instance_groups + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 8 + assert len(instance_group_1.instances) == 6 + assert len(instance_group_2.instances) == 0 + + # No video + context.state["instance"] = instance_group_0.get_instance(cam=camera) + with pytest.raises(ValueError): + context.addInstanceGroup() + # Check FrameGroup._instances_by_camcorder + assert len(frame_group._instances_by_cam[camera]) == 2 + # Check FrameGroup.instance_groups + instance_group_3 = frame_group.instance_groups[-1] + assert len(frame_group.instance_groups) == 4 + assert instance_group_3 in frame_group.instance_groups + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 8 + assert len(instance_group_1.instances) == 6 + assert len(instance_group_2.instances) == 0 + assert len(instance_group_3.instances) == 0 + + # Everything, let's add an `InstanceGroup` and set the `Instance` to it + context.state["video"] = video + context.addInstanceGroup() + # Check FrameGroup._instances_by_camcorder + assert len(frame_group._instances_by_cam[camera]) == 2 + # Check FrameGroup.instance_groups + instance_group_4 = frame_group.instance_groups[-1] + assert len(frame_group.instance_groups) == 5 + assert instance_group_4 in frame_group.instance_groups + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 7 + assert len(instance_group_1.instances) == 6 + assert len(instance_group_2.instances) == 0 + assert len(instance_group_3.instances) == 0 + assert len(instance_group_4.instances) == 1 + + # Everything, let's add an `InstanceGroup` and set the last `Instance` to it + context.state["video"] = video + context.addInstanceGroup() + # Check FrameGroup._instances_by_camcorder + assert len(frame_group._instances_by_cam[camera]) == 2 + # Check FrameGroup.instance_groups + instance_group_5 = frame_group.instance_groups[-1] + assert len(frame_group.instance_groups) == 5 + assert instance_group_4 not in frame_group.instance_groups + assert instance_group_5 in frame_group.instance_groups + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 7 + assert len(instance_group_1.instances) == 6 + assert len(instance_group_2.instances) == 0 + assert len(instance_group_3.instances) == 0 + assert len(instance_group_4.instances) == 0 + assert len(instance_group_5.instances) == 1 + + +def test_DeleteInstanceGroup(multiview_min_session_frame_groups: Labels): + """Test that deleting an instance group works.""" + + labels = multiview_min_session_frame_groups + session: RecordingSession = labels.sessions[0] + frame_idx = 2 + frame_group: FrameGroup = session.frame_groups[frame_idx] + instance_group_0: InstanceGroup = frame_group.instance_groups[0] + instance_group_1: InstanceGroup = frame_group.instance_groups[1] + labeled_frame: LabeledFrame = frame_group.labeled_frames[0] + video = labeled_frame.video + camera = session.get_camera(video=video) + + # Set-up CommandContext + context: CommandContext = CommandContext.from_labels(labels) + + # No session + with pytest.raises(ValueError): + context.deleteInstanceGroup(instance_group=instance_group_0) + # Check FrameGroup._instances_by_camcorder + assert len(frame_group._instances_by_cam[camera]) == 2 + # Check FrameGroup.instance_groups + assert len(frame_group.instance_groups) == 2 + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 8 + assert len(instance_group_1.instances) == 6 + + # No frame_idx + context.state["session"] = session + with pytest.raises(ValueError): + context.deleteInstanceGroup(instance_group=instance_group_0) + # Check FrameGroup._instances_by_camcorder + assert len(frame_group._instances_by_cam[camera]) == 2 + # Check FrameGroup.instance_groups + assert len(frame_group.instance_groups) == 2 + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 8 + assert len(instance_group_1.instances) == 6 + + # Everything, let's delete an `InstanceGroup` + context.state["frame_idx"] = frame_idx + context.deleteInstanceGroup(instance_group=instance_group_0) + # Check FrameGroup._instances_by_camcorder + assert len(frame_group._instances_by_cam[camera]) == 1 + # Check FrameGroup.instance_groups + assert len(frame_group.instance_groups) == 1 + assert instance_group_0 not in frame_group.instance_groups + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 8 + assert len(instance_group_1.instances) == 6 + + # Everything, let's delete the last `InstanceGroup` + context.state["frame_idx"] = frame_idx + context.deleteInstanceGroup(instance_group=instance_group_1) + # Check FrameGroup._instances_by_camcorder + assert len(frame_group._instances_by_cam[camera]) == 0 + # Check FrameGroup.instance_groups + assert len(frame_group.instance_groups) == 0 + assert instance_group_1 not in frame_group.instance_groups + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 8 + assert len(instance_group_1.instances) == 6 From 07ea17b3a2b712b7340a559ae1b8482c48b202ce Mon Sep 17 00:00:00 2001 From: roomrys Date: Thu, 2 May 2024 12:55:33 -0700 Subject: [PATCH 2/2] Handle case when upserting_instances with no instances in at least one view --- sleap/io/cameras.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/sleap/io/cameras.py b/sleap/io/cameras.py index 3d115c20c..4ffd9adf4 100644 --- a/sleap/io/cameras.py +++ b/sleap/io/cameras.py @@ -1657,7 +1657,7 @@ def numpy( frame_group_numpy = np.stack(instance_group_numpys, axis=1) # M=all x T x N x 2 cams_to_include_mask = np.array( - [cam in self.cams_to_include for cam in self.cameras] + [cam in self.cams_to_include for cam in self.session.cameras] ) # M=all x 1 return frame_group_numpy[cams_to_include_mask] # M=include x T x N x 2 @@ -1953,8 +1953,12 @@ def _create_and_add_labeled_frame(self, camera: Camcorder) -> LabeledFrame: f"Camcorder {camera} is not linked to a video in this " f"RecordingSession {self.session}." ) - - labeled_frame = LabeledFrame(video=video, frame_idx=self.frame_idx) + # First try to find the `LabeledFrame` in the `RecordingSession`'s `Labels` + labeled_frames = self.session.labels.find(video=video, frame_idx=self.frame_idx) + if len(labeled_frames) > 0: + labeled_frame = labeled_frames[0] + else: + labeled_frame = LabeledFrame(video=video, frame_idx=self.frame_idx) self.add_labeled_frame(labeled_frame=labeled_frame, camera=camera) return labeled_frame