Skip to content

Commit

Permalink
Refactor AddInstance command (#1561)
Browse files Browse the repository at this point in the history
* Refactor AddInstance command

* Add staticmethod wrappers

* Return early from set_visible_nodes
  • Loading branch information
roomrys authored Oct 19, 2023
1 parent 1e0627a commit 5c3441c
Showing 1 changed file with 184 additions and 89 deletions.
273 changes: 184 additions & 89 deletions sleap/gui/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ class which inherits from `AppCommand` (or a more specialized class such as
from sleap.gui.dialogs.merge import MergeDialog, ReplaceSkeletonTableDialog
from sleap.gui.dialogs.message import MessageDialog
from sleap.gui.dialogs.missingfiles import MissingFilesDialog
from sleap.gui.dialogs.query import QueryDialog
from sleap.gui.state import GuiState
from sleap.gui.suggestions import VideoFrameSuggestions
from sleap.instance import Instance, LabeledFrame, Point, PredictedInstance, Track
Expand Down Expand Up @@ -750,7 +749,6 @@ def ask(context: "CommandContext", params: dict) -> bool:
class ImportAlphaTracker(AppCommand):
@staticmethod
def do_action(context: "CommandContext", params: dict):

video_path = params["video_path"] if "video_path" in params else None

labels = Labels.load_alphatracker(
Expand Down Expand Up @@ -790,7 +788,6 @@ def ask(context: "CommandContext", params: dict) -> bool:
class ImportNWB(AppCommand):
@staticmethod
def do_action(context: "CommandContext", params: dict):

labels = Labels.load_nwb(filename=params["filename"])

new_window = context.app.__class__()
Expand Down Expand Up @@ -823,7 +820,6 @@ def ask(context: "CommandContext", params: dict) -> bool:
class ImportDeepPoseKit(AppCommand):
@staticmethod
def do_action(context: "CommandContext", params: dict):

labels = Labels.from_deepposekit(
filename=params["filename"],
video_path=params["video_path"],
Expand Down Expand Up @@ -872,7 +868,6 @@ def ask(context: "CommandContext", params: dict) -> bool:
class ImportLEAP(AppCommand):
@staticmethod
def do_action(context: "CommandContext", params: dict):

labels = Labels.load_leap_matlab(
filename=params["filename"],
)
Expand Down Expand Up @@ -903,7 +898,6 @@ def ask(context: "CommandContext", params: dict) -> bool:
class ImportCoco(AppCommand):
@staticmethod
def do_action(context: "CommandContext", params: dict):

labels = Labels.load_coco(
filename=params["filename"], img_dir=params["img_dir"], use_missing_gui=True
)
Expand Down Expand Up @@ -935,7 +929,6 @@ def ask(context: "CommandContext", params: dict) -> bool:
class ImportDeepLabCut(AppCommand):
@staticmethod
def do_action(context: "CommandContext", params: dict):

labels = Labels.load_deeplabcut(filename=params["filename"])

new_window = context.app.__class__()
Expand Down Expand Up @@ -1309,7 +1302,6 @@ def do_action(context: CommandContext, params: dict):

@staticmethod
def ask(context: CommandContext, params: dict) -> bool:

from sleap.gui.dialogs.export_clip import ExportClipDialog

dialog = ExportClipDialog()
Expand Down Expand Up @@ -1585,7 +1577,6 @@ class GoNextSuggestedFrame(NavCommand):

@classmethod
def do_action(cls, context: CommandContext, params: dict):

next_suggestion_frame = context.labels.get_next_suggestion(
context.state["video"], context.state["frame_idx"], cls.seek_direction
)
Expand Down Expand Up @@ -1771,7 +1762,6 @@ class ReplaceVideo(EditCommand):

@staticmethod
def do_action(context: CommandContext, params: dict) -> bool:

import_list = params["import_list"]

for import_item, video in import_list:
Expand Down Expand Up @@ -1900,7 +1890,6 @@ def ask(context: CommandContext, params: dict) -> bool:
video_file_names = []
total_num_labeled_frames = 0
for idx in row_idxs:

video = videos[idx]
if video is None:
return False
Expand Down Expand Up @@ -1945,7 +1934,6 @@ def load_skeleton(filename: str):
def compare_skeletons(
skeleton: Skeleton, new_skeleton: Skeleton
) -> Tuple[List[str], List[str], List[str]]:

delete_nodes = []
add_nodes = []
if skeleton.node_names != new_skeleton.node_names:
Expand Down Expand Up @@ -2724,7 +2712,6 @@ class GenerateSuggestions(EditCommand):

@classmethod
def do_action(cls, context: CommandContext, params: dict):

if len(context.labels.videos) == 0:
print("Error: no videos to generate suggestions for")
return
Expand Down Expand Up @@ -2852,21 +2839,6 @@ def ask_and_do(cls, context: CommandContext, params: dict):
class AddInstance(EditCommand):
topics = [UpdateTopic.frame, UpdateTopic.project_instances, UpdateTopic.suggestions]

@staticmethod
def get_previous_frame_index(context: CommandContext) -> Optional[int]:
frames = context.labels.frames(
context.state["video"],
from_frame_idx=context.state["frame_idx"],
reverse=True,
)

try:
next_idx = next(frames).frame_idx
except:
return

return next_idx

@classmethod
def do_action(cls, context: CommandContext, params: dict):
copy_instance = params.get("copy_instance", None)
Expand All @@ -2880,6 +2852,175 @@ def do_action(cls, context: CommandContext, params: dict):
if len(context.state["skeleton"]) == 0:
return

(
copy_instance,
from_predicted,
from_prev_frame,
) = AddInstance.find_instance_to_copy_from(
context, copy_instance=copy_instance, init_method=init_method
)

new_instance = AddInstance.create_new_instance(
context=context,
from_predicted=from_predicted,
copy_instance=copy_instance,
mark_complete=mark_complete,
init_method=init_method,
location=location,
from_prev_frame=from_prev_frame,
)

# Add the instance
context.labels.add_instance(context.state["labeled_frame"], new_instance)

if context.state["labeled_frame"] not in context.labels.labels:
context.labels.append(context.state["labeled_frame"])

@staticmethod
def create_new_instance(
context: CommandContext,
from_predicted: bool,
copy_instance: Optional[Instance],
mark_complete: bool,
init_method: str,
location: Optional[QtCore.QPoint],
from_prev_frame: bool,
):
"""Create new instance."""

# Now create the new instance
new_instance = Instance(
skeleton=context.state["skeleton"],
from_predicted=from_predicted,
frame=context.state["labeled_frame"],
)

has_missing_nodes = AddInstance.set_visible_nodes(
context=context,
copy_instance=copy_instance,
new_instance=new_instance,
mark_complete=mark_complete,
)

if has_missing_nodes:
AddInstance.fill_missing_nodes(
context=context,
copy_instance=copy_instance,
init_method=init_method,
new_instance=new_instance,
location=location,
)

# If we're copying a predicted instance or from another frame, copy the track
if hasattr(copy_instance, "score") or from_prev_frame:
new_instance.track = copy_instance.track

return new_instance

@staticmethod
def fill_missing_nodes(
context: CommandContext,
copy_instance: Optional[Instance],
init_method: str,
new_instance: Instance,
location: Optional[QtCore.QPoint],
):
"""Fill in missing nodes for new instance.
Args:
context: The command context.
copy_instance: The instance to copy from.
init_method: The initialization method.
new_instance: The new instance.
location: The location of the instance.
Returns:
None
"""

# mark the node as not "visible" if we're copying from a predicted instance without this node
is_visible = copy_instance is None or (not hasattr(copy_instance, "score"))

if init_method == "force_directed":
AddMissingInstanceNodes.add_force_directed_nodes(
context=context,
instance=new_instance,
visible=is_visible,
center_point=location,
)
elif init_method == "random":
AddMissingInstanceNodes.add_random_nodes(
context=context, instance=new_instance, visible=is_visible
)
elif init_method == "template":
AddMissingInstanceNodes.add_nodes_from_template(
context=context,
instance=new_instance,
visible=is_visible,
center_point=location,
)
else:
AddMissingInstanceNodes.add_best_nodes(
context=context, instance=new_instance, visible=is_visible
)

@staticmethod
def set_visible_nodes(
context: CommandContext,
copy_instance: Optional[Instance],
new_instance: Instance,
mark_complete: bool,
) -> Tuple[Instance, bool]:
"""Sets visible nodes for new instance.
Args:
context: The command context.
copy_instance: The instance to copy from.
new_instance: The new instance.
mark_complete: Whether to mark the instance as complete.
Returns:
Whether the new instance has missing nodes.
"""

if copy_instance is None:
return True

has_missing_nodes = False

# go through each node in skeleton
for node in context.state["skeleton"].node_names:
# if we're copying from a skeleton that has this node
if node in copy_instance and not copy_instance[node].isnan():
# just copy x, y, and visible
# we don't want to copy a PredictedPoint or score attribute
new_instance[node] = Point(
x=copy_instance[node].x,
y=copy_instance[node].y,
visible=copy_instance[node].visible,
complete=mark_complete,
)
else:
has_missing_nodes = True

return has_missing_nodes

@staticmethod
def find_instance_to_copy_from(
context: CommandContext, copy_instance: Optional[Instance], init_method: bool
) -> Tuple[Optional[Instance], bool, bool]:
"""Find instance to copy from.
Args:
context: The command context.
copy_instance: The instance to copy from.
init_method: The initialization method.
Returns:
The instance to copy from, whether it's from a predicted instance, and
whether it's from a previous frame.
"""

from_predicted = copy_instance
from_prev_frame = False

Expand All @@ -2905,7 +3046,7 @@ def do_action(cls, context: CommandContext, params: dict):
) or init_method == "prior_frame":
# Otherwise, if there are instances in previous frames,
# copy the points from one of those instances.
prev_idx = cls.get_previous_frame_index(context)
prev_idx = AddInstance.get_previous_frame_index(context)

if prev_idx is not None:
prev_instances = context.labels.find(
Expand All @@ -2931,70 +3072,24 @@ def do_action(cls, context: CommandContext, params: dict):

from_predicted = from_predicted if hasattr(from_predicted, "score") else None

# Now create the new instance
new_instance = Instance(
skeleton=context.state["skeleton"],
from_predicted=from_predicted,
frame=context.state["labeled_frame"],
)

has_missing_nodes = False
return copy_instance, from_predicted, from_prev_frame

# go through each node in skeleton
for node in context.state["skeleton"].node_names:
# if we're copying from a skeleton that has this node
if (
copy_instance is not None
and node in copy_instance
and not copy_instance[node].isnan()
):
# just copy x, y, and visible
# we don't want to copy a PredictedPoint or score attribute
new_instance[node] = Point(
x=copy_instance[node].x,
y=copy_instance[node].y,
visible=copy_instance[node].visible,
complete=mark_complete,
)
else:
has_missing_nodes = True

if has_missing_nodes:
# mark the node as not "visible" if we're copying from a predicted instance without this node
is_visible = copy_instance is None or (not hasattr(copy_instance, "score"))

if init_method == "force_directed":
AddMissingInstanceNodes.add_force_directed_nodes(
context=context,
instance=new_instance,
visible=is_visible,
center_point=location,
)
elif init_method == "random":
AddMissingInstanceNodes.add_random_nodes(
context=context, instance=new_instance, visible=is_visible
)
elif init_method == "template":
AddMissingInstanceNodes.add_nodes_from_template(
context=context,
instance=new_instance,
visible=is_visible,
center_point=location,
)
else:
AddMissingInstanceNodes.add_best_nodes(
context=context, instance=new_instance, visible=is_visible
)
@staticmethod
def get_previous_frame_index(context: CommandContext) -> Optional[int]:
"""Returns index of previous frame."""

# If we're copying a predicted instance or from another frame, copy the track
if hasattr(copy_instance, "score") or from_prev_frame:
new_instance.track = copy_instance.track
frames = context.labels.frames(
context.state["video"],
from_frame_idx=context.state["frame_idx"],
reverse=True,
)

# Add the instance
context.labels.add_instance(context.state["labeled_frame"], new_instance)
try:
next_idx = next(frames).frame_idx
except:
return

if context.state["labeled_frame"] not in context.labels.labels:
context.labels.append(context.state["labeled_frame"])
return next_idx


class SetInstancePointLocations(EditCommand):
Expand Down

0 comments on commit 5c3441c

Please sign in to comment.