diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index 8ac4d87fb..ef6055a45 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -49,7 +49,6 @@ class which inherits from `AppCommand` (or a more specialized class such as from sleap.gui.dialogs.merge import MergeDialog, ReplaceSkeletonTableDialog from sleap.gui.dialogs.message import MessageDialog from sleap.gui.dialogs.missingfiles import MissingFilesDialog -from sleap.gui.dialogs.query import QueryDialog from sleap.gui.state import GuiState from sleap.gui.suggestions import VideoFrameSuggestions from sleap.instance import Instance, LabeledFrame, Point, PredictedInstance, Track @@ -750,7 +749,6 @@ def ask(context: "CommandContext", params: dict) -> bool: class ImportAlphaTracker(AppCommand): @staticmethod def do_action(context: "CommandContext", params: dict): - video_path = params["video_path"] if "video_path" in params else None labels = Labels.load_alphatracker( @@ -790,7 +788,6 @@ def ask(context: "CommandContext", params: dict) -> bool: class ImportNWB(AppCommand): @staticmethod def do_action(context: "CommandContext", params: dict): - labels = Labels.load_nwb(filename=params["filename"]) new_window = context.app.__class__() @@ -823,7 +820,6 @@ def ask(context: "CommandContext", params: dict) -> bool: class ImportDeepPoseKit(AppCommand): @staticmethod def do_action(context: "CommandContext", params: dict): - labels = Labels.from_deepposekit( filename=params["filename"], video_path=params["video_path"], @@ -872,7 +868,6 @@ def ask(context: "CommandContext", params: dict) -> bool: class ImportLEAP(AppCommand): @staticmethod def do_action(context: "CommandContext", params: dict): - labels = Labels.load_leap_matlab( filename=params["filename"], ) @@ -903,7 +898,6 @@ def ask(context: "CommandContext", params: dict) -> bool: class ImportCoco(AppCommand): @staticmethod def do_action(context: "CommandContext", params: dict): - labels = Labels.load_coco( filename=params["filename"], img_dir=params["img_dir"], use_missing_gui=True ) @@ -935,7 +929,6 @@ def ask(context: "CommandContext", params: dict) -> bool: class ImportDeepLabCut(AppCommand): @staticmethod def do_action(context: "CommandContext", params: dict): - labels = Labels.load_deeplabcut(filename=params["filename"]) new_window = context.app.__class__() @@ -1309,7 +1302,6 @@ def do_action(context: CommandContext, params: dict): @staticmethod def ask(context: CommandContext, params: dict) -> bool: - from sleap.gui.dialogs.export_clip import ExportClipDialog dialog = ExportClipDialog() @@ -1585,7 +1577,6 @@ class GoNextSuggestedFrame(NavCommand): @classmethod def do_action(cls, context: CommandContext, params: dict): - next_suggestion_frame = context.labels.get_next_suggestion( context.state["video"], context.state["frame_idx"], cls.seek_direction ) @@ -1771,7 +1762,6 @@ class ReplaceVideo(EditCommand): @staticmethod def do_action(context: CommandContext, params: dict) -> bool: - import_list = params["import_list"] for import_item, video in import_list: @@ -1900,7 +1890,6 @@ def ask(context: CommandContext, params: dict) -> bool: video_file_names = [] total_num_labeled_frames = 0 for idx in row_idxs: - video = videos[idx] if video is None: return False @@ -1945,7 +1934,6 @@ def load_skeleton(filename: str): def compare_skeletons( skeleton: Skeleton, new_skeleton: Skeleton ) -> Tuple[List[str], List[str], List[str]]: - delete_nodes = [] add_nodes = [] if skeleton.node_names != new_skeleton.node_names: @@ -2724,7 +2712,6 @@ class GenerateSuggestions(EditCommand): @classmethod def do_action(cls, context: CommandContext, params: dict): - if len(context.labels.videos) == 0: print("Error: no videos to generate suggestions for") return @@ -2852,21 +2839,6 @@ def ask_and_do(cls, context: CommandContext, params: dict): class AddInstance(EditCommand): topics = [UpdateTopic.frame, UpdateTopic.project_instances, UpdateTopic.suggestions] - @staticmethod - def get_previous_frame_index(context: CommandContext) -> Optional[int]: - frames = context.labels.frames( - context.state["video"], - from_frame_idx=context.state["frame_idx"], - reverse=True, - ) - - try: - next_idx = next(frames).frame_idx - except: - return - - return next_idx - @classmethod def do_action(cls, context: CommandContext, params: dict): copy_instance = params.get("copy_instance", None) @@ -2880,6 +2852,175 @@ def do_action(cls, context: CommandContext, params: dict): if len(context.state["skeleton"]) == 0: return + ( + copy_instance, + from_predicted, + from_prev_frame, + ) = AddInstance.find_instance_to_copy_from( + context, copy_instance=copy_instance, init_method=init_method + ) + + new_instance = AddInstance.create_new_instance( + context=context, + from_predicted=from_predicted, + copy_instance=copy_instance, + mark_complete=mark_complete, + init_method=init_method, + location=location, + from_prev_frame=from_prev_frame, + ) + + # Add the instance + context.labels.add_instance(context.state["labeled_frame"], new_instance) + + if context.state["labeled_frame"] not in context.labels.labels: + context.labels.append(context.state["labeled_frame"]) + + @staticmethod + def create_new_instance( + context: CommandContext, + from_predicted: bool, + copy_instance: Optional[Instance], + mark_complete: bool, + init_method: str, + location: Optional[QtCore.QPoint], + from_prev_frame: bool, + ): + """Create new instance.""" + + # Now create the new instance + new_instance = Instance( + skeleton=context.state["skeleton"], + from_predicted=from_predicted, + frame=context.state["labeled_frame"], + ) + + has_missing_nodes = AddInstance.set_visible_nodes( + context=context, + copy_instance=copy_instance, + new_instance=new_instance, + mark_complete=mark_complete, + ) + + if has_missing_nodes: + AddInstance.fill_missing_nodes( + context=context, + copy_instance=copy_instance, + init_method=init_method, + new_instance=new_instance, + location=location, + ) + + # If we're copying a predicted instance or from another frame, copy the track + if hasattr(copy_instance, "score") or from_prev_frame: + new_instance.track = copy_instance.track + + return new_instance + + @staticmethod + def fill_missing_nodes( + context: CommandContext, + copy_instance: Optional[Instance], + init_method: str, + new_instance: Instance, + location: Optional[QtCore.QPoint], + ): + """Fill in missing nodes for new instance. + + Args: + context: The command context. + copy_instance: The instance to copy from. + init_method: The initialization method. + new_instance: The new instance. + location: The location of the instance. + + Returns: + None + """ + + # mark the node as not "visible" if we're copying from a predicted instance without this node + is_visible = copy_instance is None or (not hasattr(copy_instance, "score")) + + if init_method == "force_directed": + AddMissingInstanceNodes.add_force_directed_nodes( + context=context, + instance=new_instance, + visible=is_visible, + center_point=location, + ) + elif init_method == "random": + AddMissingInstanceNodes.add_random_nodes( + context=context, instance=new_instance, visible=is_visible + ) + elif init_method == "template": + AddMissingInstanceNodes.add_nodes_from_template( + context=context, + instance=new_instance, + visible=is_visible, + center_point=location, + ) + else: + AddMissingInstanceNodes.add_best_nodes( + context=context, instance=new_instance, visible=is_visible + ) + + @staticmethod + def set_visible_nodes( + context: CommandContext, + copy_instance: Optional[Instance], + new_instance: Instance, + mark_complete: bool, + ) -> Tuple[Instance, bool]: + """Sets visible nodes for new instance. + + Args: + context: The command context. + copy_instance: The instance to copy from. + new_instance: The new instance. + mark_complete: Whether to mark the instance as complete. + + Returns: + Whether the new instance has missing nodes. + """ + + if copy_instance is None: + return True + + has_missing_nodes = False + + # go through each node in skeleton + for node in context.state["skeleton"].node_names: + # if we're copying from a skeleton that has this node + if node in copy_instance and not copy_instance[node].isnan(): + # just copy x, y, and visible + # we don't want to copy a PredictedPoint or score attribute + new_instance[node] = Point( + x=copy_instance[node].x, + y=copy_instance[node].y, + visible=copy_instance[node].visible, + complete=mark_complete, + ) + else: + has_missing_nodes = True + + return has_missing_nodes + + @staticmethod + def find_instance_to_copy_from( + context: CommandContext, copy_instance: Optional[Instance], init_method: bool + ) -> Tuple[Optional[Instance], bool, bool]: + """Find instance to copy from. + + Args: + context: The command context. + copy_instance: The instance to copy from. + init_method: The initialization method. + + Returns: + The instance to copy from, whether it's from a predicted instance, and + whether it's from a previous frame. + """ + from_predicted = copy_instance from_prev_frame = False @@ -2905,7 +3046,7 @@ def do_action(cls, context: CommandContext, params: dict): ) or init_method == "prior_frame": # Otherwise, if there are instances in previous frames, # copy the points from one of those instances. - prev_idx = cls.get_previous_frame_index(context) + prev_idx = AddInstance.get_previous_frame_index(context) if prev_idx is not None: prev_instances = context.labels.find( @@ -2931,70 +3072,24 @@ def do_action(cls, context: CommandContext, params: dict): from_predicted = from_predicted if hasattr(from_predicted, "score") else None - # Now create the new instance - new_instance = Instance( - skeleton=context.state["skeleton"], - from_predicted=from_predicted, - frame=context.state["labeled_frame"], - ) - - has_missing_nodes = False + return copy_instance, from_predicted, from_prev_frame - # go through each node in skeleton - for node in context.state["skeleton"].node_names: - # if we're copying from a skeleton that has this node - if ( - copy_instance is not None - and node in copy_instance - and not copy_instance[node].isnan() - ): - # just copy x, y, and visible - # we don't want to copy a PredictedPoint or score attribute - new_instance[node] = Point( - x=copy_instance[node].x, - y=copy_instance[node].y, - visible=copy_instance[node].visible, - complete=mark_complete, - ) - else: - has_missing_nodes = True - - if has_missing_nodes: - # mark the node as not "visible" if we're copying from a predicted instance without this node - is_visible = copy_instance is None or (not hasattr(copy_instance, "score")) - - if init_method == "force_directed": - AddMissingInstanceNodes.add_force_directed_nodes( - context=context, - instance=new_instance, - visible=is_visible, - center_point=location, - ) - elif init_method == "random": - AddMissingInstanceNodes.add_random_nodes( - context=context, instance=new_instance, visible=is_visible - ) - elif init_method == "template": - AddMissingInstanceNodes.add_nodes_from_template( - context=context, - instance=new_instance, - visible=is_visible, - center_point=location, - ) - else: - AddMissingInstanceNodes.add_best_nodes( - context=context, instance=new_instance, visible=is_visible - ) + @staticmethod + def get_previous_frame_index(context: CommandContext) -> Optional[int]: + """Returns index of previous frame.""" - # If we're copying a predicted instance or from another frame, copy the track - if hasattr(copy_instance, "score") or from_prev_frame: - new_instance.track = copy_instance.track + frames = context.labels.frames( + context.state["video"], + from_frame_idx=context.state["frame_idx"], + reverse=True, + ) - # Add the instance - context.labels.add_instance(context.state["labeled_frame"], new_instance) + try: + next_idx = next(frames).frame_idx + except: + return - if context.state["labeled_frame"] not in context.labels.labels: - context.labels.append(context.state["labeled_frame"]) + return next_idx class SetInstancePointLocations(EditCommand):