diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index ef6055a45..90f40397e 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -36,7 +36,7 @@ class which inherits from `AppCommand` (or a more specialized class such as from enum import Enum from glob import glob from pathlib import Path, PurePath -from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, Union +from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, Union, cast import attr import cv2 @@ -2879,13 +2879,13 @@ def do_action(cls, context: CommandContext, params: dict): @staticmethod def create_new_instance( context: CommandContext, - from_predicted: bool, - copy_instance: Optional[Instance], + from_predicted: Optional[PredictedInstance], + copy_instance: Optional[Union[Instance, PredictedInstance]], mark_complete: bool, init_method: str, location: Optional[QtCore.QPoint], from_prev_frame: bool, - ): + ) -> Instance: """Create new instance.""" # Now create the new instance @@ -2913,6 +2913,7 @@ def create_new_instance( # If we're copying a predicted instance or from another frame, copy the track if hasattr(copy_instance, "score") or from_prev_frame: + copy_instance = cast(Union[PredictedInstance, Instance], copy_instance) new_instance.track = copy_instance.track return new_instance @@ -2920,7 +2921,7 @@ def create_new_instance( @staticmethod def fill_missing_nodes( context: CommandContext, - copy_instance: Optional[Instance], + copy_instance: Optional[Union[Instance, PredictedInstance]], init_method: str, new_instance: Instance, location: Optional[QtCore.QPoint], @@ -2967,10 +2968,10 @@ def fill_missing_nodes( @staticmethod def set_visible_nodes( context: CommandContext, - copy_instance: Optional[Instance], + copy_instance: Optional[Union[Instance, PredictedInstance]], new_instance: Instance, mark_complete: bool, - ) -> Tuple[Instance, bool]: + ) -> bool: """Sets visible nodes for new instance. Args: @@ -2988,15 +2989,28 @@ def set_visible_nodes( has_missing_nodes = False - # go through each node in skeleton + # Calculate scale factor for getting new x and y values. + old_size_width = copy_instance.frame.video.shape[2] + old_size_height = copy_instance.frame.video.shape[1] + new_size_width = new_instance.frame.video.shape[2] + new_size_height = new_instance.frame.video.shape[1] + scale_width = new_size_width / old_size_width + scale_height = new_size_height / old_size_height + + # Go through each node in skeleton. for node in context.state["skeleton"].node_names: - # if we're copying from a skeleton that has this node + # If we're copying from a skeleton that has this node. if node in copy_instance and not copy_instance[node].isnan(): - # just copy x, y, and visible - # we don't want to copy a PredictedPoint or score attribute + # Ensure x, y inside current frame, then copy x, y, and visible. + # We don't want to copy a PredictedPoint or score attribute. + x_old = copy_instance[node].x + y_old = copy_instance[node].y + x_new = x_old * scale_width + y_new = y_old * scale_height + new_instance[node] = Point( - x=copy_instance[node].x, - y=copy_instance[node].y, + x=x_new, + y=y_new, visible=copy_instance[node].visible, complete=mark_complete, ) @@ -3007,18 +3021,22 @@ def set_visible_nodes( @staticmethod def find_instance_to_copy_from( - context: CommandContext, copy_instance: Optional[Instance], init_method: bool - ) -> Tuple[Optional[Instance], bool, bool]: + context: CommandContext, + copy_instance: Optional[Union[Instance, PredictedInstance]], + init_method: bool, + ) -> Tuple[ + Optional[Union[Instance, PredictedInstance]], Optional[PredictedInstance], bool + ]: """Find instance to copy from. Args: context: The command context. - copy_instance: The instance to copy from. + copy_instance: The `Instance` to copy from. init_method: The initialization method. Returns: - The instance to copy from, whether it's from a predicted instance, and - whether it's from a previous frame. + The instance to copy from, the predicted instance (if it is from a predicted + instance, else None), and whether it's from a previous frame. """ from_predicted = copy_instance @@ -3071,6 +3089,7 @@ def find_instance_to_copy_from( from_prev_frame = True from_predicted = from_predicted if hasattr(from_predicted, "score") else None + from_predicted = cast(Optional[PredictedInstance], from_predicted) return copy_instance, from_predicted, from_prev_frame