Skip to content

Commit

Permalink
Merge branch 'liezl/add-gui-elements-for-sessions' into andrew/color-…
Browse files Browse the repository at this point in the history
…by-instance-groups
  • Loading branch information
roomrys authored May 24, 2024
2 parents ff0aedc + 45eeea7 commit 75e2a0d
Show file tree
Hide file tree
Showing 4 changed files with 562 additions and 13 deletions.
58 changes: 57 additions & 1 deletion sleap/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@
from sleap.skeleton import Skeleton
from sleap.util import parse_uri_path


logger = getLogger(__name__)


Expand Down Expand Up @@ -810,9 +809,18 @@ def new_instance_menu_action():
self.commands.deleteFrameLimitPredictions,
)

### Sessions Menu ###

sessionsMenu = self.menuBar().addMenu("Sessions")

self.inst_groups_menu = sessionsMenu.addMenu("Set Instance Group")
self.inst_groups_delete_menu = sessionsMenu.addMenu("Delete Instance Group")
self.state.connect("frame_idx", self._update_sessions_menu)

### Tracks Menu ###

tracksMenu = self.menuBar().addMenu("Tracks")

self.track_menu = tracksMenu.addMenu("Set Instance Track")
add_menu_check_item(
tracksMenu, "propagate track labels", "Propagate Track Labels"
Expand Down Expand Up @@ -1120,6 +1128,7 @@ def _update_gui_state(self):

# Update menus

self.inst_groups_menu.setEnabled(has_selected_instance)
self.track_menu.setEnabled(has_selected_instance)
self.delete_tracks_menu.setEnabled(has_tracks)
self._menu_actions["clear selection"].setEnabled(has_selected_instance)
Expand Down Expand Up @@ -1253,10 +1262,12 @@ def _has_topic(topic_list):

if _has_topic([UpdateTopic.frame, UpdateTopic.project_instances]):
self.state["last_interacted_frame"] = self.state["labeled_frame"]
self._update_sessions_menu()

if _has_topic([UpdateTopic.sessions]):
self.update_cameras_model()
self.update_unlinked_videos_model()
self._update_sessions_menu()

def update_unlinked_videos_model(self):
"""Update the unlinked videos model with the selected session."""
Expand Down Expand Up @@ -1411,6 +1422,51 @@ def _update_track_menu(self):
"New Track", self.commands.addTrack, Qt.CTRL + Qt.Key_0
)

def _update_sessions_menu(self):
"""Update the instance groups menu based on the frame index."""

# Clear menus before adding more items
self.inst_groups_menu.clear()
self.inst_groups_delete_menu.clear()

# Get the session
session = self.state.get("session")
if session is None:
return

# Get the frame group for the current frame
frame_idx = self.state["frame_idx"]
frame_group = session.frame_groups.get(frame_idx, None)
if frame_group is not None:
for inst_group_ind, instance_group in enumerate(
frame_group.instance_groups
):
# Create shortcut key for first 9 groups
key_command = ""
if inst_group_ind < 9:
key_command = Qt.SHIFT + Qt.Key_0 + inst_group_ind + 1

# Update the Set Instance Group menu
self.inst_groups_menu.addAction(
instance_group.name,
lambda x=instance_group: self.commands.setInstanceGroup(x),
key_command,
)

# Update the Delete Instance Group menu
self.inst_groups_delete_menu.addAction(
instance_group.name,
lambda x=instance_group: self.commands.deleteInstanceGroup(
instance_group=x
),
)

self.inst_groups_menu.addAction(
"New Instance Group",
self.commands.addInstanceGroup,
Qt.SHIFT + Qt.Key_0,
)

def _update_seekbar_marks(self):
"""Updates marks on seekbar."""
set_slider_marks_from_labels(
Expand Down
117 changes: 117 additions & 0 deletions sleap/gui/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,14 @@ def setInstanceTrack(self, new_track: "Track"):
"""Sets track for selected instance."""
self.execute(SetSelectedInstanceTrack, new_track=new_track)

def addInstanceGroup(self):
"""Sets the instance group for selected instance."""
self.execute(AddInstanceGroup)

def setInstanceGroup(self, instance_group: Optional["InstanceGroup"]):
"""Sets the instance group for selected instance."""
self.execute(SetSelectedInstanceGroup, instance_group=instance_group)

def deleteTrack(self, track: "Track"):
"""Delete a track and remove from all instances."""
self.execute(DeleteTrack, track=track)
Expand All @@ -601,6 +609,10 @@ def deleteMultipleTracks(self, delete_all: bool = False):
"""Delete all tracks."""
self.execute(DeleteMultipleTracks, delete_all=delete_all)

def deleteInstanceGroup(self, instance_group: "InstanceGroup"):
"""Delete an instance group."""
self.execute(DeleteInstanceGroup, instance_group=instance_group)

def copyInstanceTrack(self):
"""Copies the selected instance's track to the track clipboard."""
self.execute(CopyInstanceTrack)
Expand Down Expand Up @@ -2690,6 +2702,30 @@ def ask_and_do(context: CommandContext, params: dict):
context.signal_update([UpdateTopic.project_instances])


class AddInstanceGroup(EditCommand):
topics = [UpdateTopic.sessions]

@staticmethod
def do_action(context, params):

# Get session and frame index
frame_idx = context.state["frame_idx"]
session: RecordingSession = context.state["session"]
if session is None:
raise ValueError("Cannot add instance group without session.")

# Get or create frame group
frame_group = session.frame_groups.get(frame_idx, None)
if frame_group is None:
frame_group = session.new_frame_group(frame_idx=frame_idx)

# Create and add instance group
instance_group = frame_group.add_instance_group(instance_group=None)

# Now add the selected instance to the `InstanceGroup`
context.execute(SetSelectedInstanceGroup, instance_group=instance_group)


class AddTrack(EditCommand):
topics = [UpdateTopic.tracks]

Expand All @@ -2706,6 +2742,61 @@ def do_action(context: CommandContext, params: dict):
context.execute(SetSelectedInstanceTrack, new_track=new_track)


class SetSelectedInstanceGroup(EditCommand):
@staticmethod
def do_action(context, params):
"""Set the `selected_instance` to the `instance_group`.
Args:
context: The command context.
state: The context state.
instance: The selected instance.
frame_idx: The frame index.
video: The video.
session: The recording session.
params: The command parameters.
instance_group: The `InstanceGroup` to set the selected instance to.
Raises:
ValueError: If the `RecordingSession` is None.
ValueError: If the `FrameGroup` does not exist for the frame index.
ValueError: If the `Video` is not linked to a `Camcorder`.
"""

selected_instance = context.state["instance"]
frame_idx = context.state["frame_idx"]
video = context.state["video"]

base_message = (
f"Cannot set instance group for selected instance [{selected_instance}]."
)

# `RecordingSession` should not be None
session: RecordingSession = context.state["session"]
if session is None:
raise ValueError(f"{base_message} No session for video [{video}]")

# `FrameGroup` should already exist
frame_group = session.frame_groups.get(frame_idx, None)
if frame_group is None:
raise ValueError(
f"{base_message} Frame group does not exist for frame [{frame_idx}] in "
f"{session}."
)

# We need the camera and instance group to set the instance group
camera = session.get_camera(video=video)
if camera is None:
raise ValueError(f"{base_message} No camera linked to video [{video}]")
instance_group = params["instance_group"]

# Set the instance group
frame_group.add_instance(
instance=selected_instance, camera=camera, instance_group=instance_group
)


class SetSelectedInstanceTrack(EditCommand):
topics = [UpdateTopic.tracks]

Expand Down Expand Up @@ -2795,6 +2886,31 @@ def do_action(context: CommandContext, params: dict):
context.labels.remove_unused_tracks()


class DeleteInstanceGroup(EditCommand):
topics = [UpdateTopic.sessions]

@staticmethod
def do_action(context, params):

instance_group = params["instance_group"]
frame_idx = context.state["frame_idx"]

base_message = f"Cannot delete instance group [{instance_group}]."

# `RecordingSession` should not be None
session: RecordingSession = context.state["session"]
if session is None:
raise ValueError(f"{base_message} No session in context state.")

# `FrameGroup` should already exist
frame_group = session.frame_groups.get(frame_idx, None)
if frame_group is None:
raise ValueError(f"{base_message} No frame group for frame {frame_idx}.")

# Remove the instance group
frame_group.remove_instance_group(instance_group=instance_group)


class CopyInstanceTrack(EditCommand):
@staticmethod
def do_action(context: CommandContext, params: dict):
Expand Down Expand Up @@ -3376,6 +3492,7 @@ def add_nodes_from_template(
def add_force_directed_nodes(
cls, context, instance, visible, center_point: QtCore.QPoint = None
):

import networkx as nx

center_point = center_point or context.app.player.getVisibleRect().center()
Expand Down
52 changes: 40 additions & 12 deletions sleap/io/cameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ def add_instance(self, cam: Camcorder, instance: Instance):
)

# Add the instance to the `InstanceGroup`
self.replace_instance(cam, instance)
self.replace_instance(cam=cam, instance=instance)

def replace_instance(self, cam: Camcorder, instance: Instance):
"""Replace an `Instance` in the `InstanceGroup`.
Expand All @@ -668,6 +668,9 @@ def replace_instance(self, cam: Camcorder, instance: Instance):
# Remove the instance if it already exists
self.remove_instance(instance_or_cam=instance)

# Remove the instance currently at the cam (if any)
self.remove_instance(instance_or_cam=cam)

# Replace the instance in the `InstanceGroup`
self._instance_by_camcorder[cam] = instance
self._camcorder_by_instance[instance] = cam
Expand Down Expand Up @@ -1571,7 +1574,7 @@ def __attrs_post_init__(self):
for camera in self.session.camera_cluster.cameras:
self._instances_by_cam[camera] = set()
for instance_group in self.instance_groups:
self.add_instance_group(instance_group)
self.add_instance_group(instance_group=instance_group)

@property
def instance_groups(self) -> List[InstanceGroup]:
Expand Down Expand Up @@ -1668,7 +1671,7 @@ def numpy(

frame_group_numpy = np.stack(instance_group_numpys, axis=1) # M=all x T x N x 2
cams_to_include_mask = np.array(
[cam in self.cams_to_include for cam in self.cameras]
[cam in self.cams_to_include for cam in self.session.cameras]
) # M=all x 1

return frame_group_numpy[cams_to_include_mask] # M=include x T x N x 2
Expand Down Expand Up @@ -1713,6 +1716,15 @@ def add_instance(

# Add the `Instance` to the `InstanceGroup`
if instance_group is not None:
# Remove any existing `Instance` in given `InstanceGroup` at same `Camcorder`
preexisting_instance = instance_group.get_instance(camera)
if preexisting_instance is not None:
self.remove_instance(instance=preexisting_instance)

# Remove the `Instance` from the `FrameGroup` if it is already exists
self.remove_instance(instance=instance, remove_empty_instance_group=True)

# Add the `Instance` to the `InstanceGroup`
instance_group.add_instance(cam=camera, instance=instance)
else:
self._raise_if_instance_not_in_instance_group(instance=instance)
Expand All @@ -1726,11 +1738,15 @@ def add_instance(
labeled_frame = instance.frame
self.add_labeled_frame(labeled_frame=labeled_frame, camera=camera)

def remove_instance(self, instance: Instance):
def remove_instance(
self, instance: Instance, remove_empty_instance_group: bool = False
):
"""Removes an `Instance` from the `FrameGroup`.
Args:
instance: `Instance` to remove from the `FrameGroup`.
remove_empty_instance_group: If True, then remove the `InstanceGroup` if it
is empty. Default is False.
"""

instance_group = self.get_instance_group(instance=instance)
Expand All @@ -1747,12 +1763,22 @@ def remove_instance(self, instance: Instance):
instance_group.remove_instance(instance_or_cam=instance)

# Remove the `Instance` from the `FrameGroup`
self._instances_by_cam[camera].remove(instance)
if instance in self._instances_by_cam[camera]:
self._instances_by_cam[camera].remove(instance)
else:
logger.debug(
f"Instance {instance} not found in this FrameGroup: "
f"{self._instances_by_cam[camera]}."
)

# Remove "empty" `LabeledFrame`s from the `FrameGroup`
if len(self._instances_by_cam[camera]) < 1:
self.remove_labeled_frame(labeled_frame_or_camera=camera)

# Remove the `InstanceGroup` if it is empty
if remove_empty_instance_group and len(instance_group.instances) < 1:
self.remove_instance_group(instance_group=instance_group)

def add_instance_group(
self, instance_group: Optional[InstanceGroup] = None
) -> InstanceGroup:
Expand Down Expand Up @@ -1819,11 +1845,9 @@ def remove_instance_group(self, instance_group: InstanceGroup):
# Remove the `Instance`s from the `FrameGroup`
for camera, instance in instance_group.instance_by_camcorder.items():
self._instances_by_cam[camera].remove(instance)

# Remove the `LabeledFrame` from the `FrameGroup`
labeled_frame = self.get_labeled_frame(camera=camera)
if labeled_frame is not None:
self.remove_labeled_frame(labeled_frame_or_camera=camera)
# Remove the `LabeledFrame` if no more grouped instances
if len(self._instances_by_cam[camera]) < 1:
self.remove_labeled_frame(labeled_frame_or_camera=camera)

# TODO(LM): maintain this as a dictionary for quick lookups
def get_instance_group(self, instance: Instance) -> Optional[InstanceGroup]:
Expand Down Expand Up @@ -1964,8 +1988,12 @@ def _create_and_add_labeled_frame(self, camera: Camcorder) -> LabeledFrame:
f"Camcorder {camera} is not linked to a video in this "
f"RecordingSession {self.session}."
)

labeled_frame = LabeledFrame(video=video, frame_idx=self.frame_idx)
# First try to find the `LabeledFrame` in the `RecordingSession`'s `Labels`
labeled_frames = self.session.labels.find(video=video, frame_idx=self.frame_idx)
if len(labeled_frames) > 0:
labeled_frame = labeled_frames[0]
else:
labeled_frame = LabeledFrame(video=video, frame_idx=self.frame_idx)
self.add_labeled_frame(labeled_frame=labeled_frame, camera=camera)

return labeled_frame
Expand Down
Loading

0 comments on commit 75e2a0d

Please sign in to comment.