From ebfc47b300a61bf1a649d8f5eed780039318d73c Mon Sep 17 00:00:00 2001 From: Liezl Maree <38435167+roomrys@users.noreply.github.com> Date: Wed, 26 Jun 2024 10:36:24 -0700 Subject: [PATCH 01/27] Add `InstancesList` class to handle backref to `LabeledFrame` (#1807) * Add InstancesList class to handle backref to LabeledFrame * Register structure/unstructure hooks for InstancesList * Add tests for the InstanceList class * Handle case where instance are passed in but labeled_frame is None * Add tests relevant methods in LabeledFrame * Delegate setting frame to InstancesList * Add test for PredictedInstance.frame after complex merge * Add todo comment to not use Instance.frame * Add rtest for InstasnceList.remove * Use normal list for informative `merged_instances` * Add test for copy and clear * Add copy and clear methods, use normal lists in merge method --- sleap/instance.py | 167 ++++++++++++++++++++++++----- tests/test_instance.py | 231 +++++++++++++++++++++++++++++++++++++++-- 2 files changed, 363 insertions(+), 35 deletions(-) diff --git a/sleap/instance.py b/sleap/instance.py index c14038552..67e96f330 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -364,7 +364,7 @@ class Instance: from_predicted: Optional["PredictedInstance"] = attr.ib(default=None) _points: PointArray = attr.ib(default=None) _nodes: List = attr.ib(default=None) - frame: Union["LabeledFrame", None] = attr.ib(default=None) + frame: Union["LabeledFrame", None] = attr.ib(default=None) # TODO(LM): Make private # The underlying Point array type that this instances point array should be. _point_array_type = PointArray @@ -1214,6 +1214,9 @@ def unstructure_instance(x: Instance): converter.register_unstructure_hook(Instance, unstructure_instance) converter.register_unstructure_hook(PredictedInstance, unstructure_instance) + converter.register_unstructure_hook( + InstancesList, lambda x: [converter.unstructure(inst) for inst in x] + ) ## STRUCTURE HOOKS @@ -1247,6 +1250,7 @@ def structure_instances_list(x, type): converter.register_structure_hook( Union[List[Instance], List[PredictedInstance]], structure_instances_list ) + converter.register_structure_hook(InstancesList, structure_instances_list) # Structure forward reference for PredictedInstance for the Instance.from_predicted # attribute. @@ -1278,6 +1282,127 @@ def structure_point_array(x, t): return converter +class InstancesList(list): + """A list of `Instance`s associated with a `LabeledFrame`. + + This class should only be used for the `LabeledFrame.instances` attribute. + """ + + def __init__(self, *args, labeled_frame: Optional["LabeledFrame"] = None): + super(InstancesList, self).__init__(*args) + + # Set the labeled frame for each instance + self.labeled_frame = labeled_frame + + @property + def labeled_frame(self) -> "LabeledFrame": + """Return the `LabeledFrame` associated with this list of instances.""" + + return self._labeled_frame + + @labeled_frame.setter + def labeled_frame(self, labeled_frame: "LabeledFrame"): + """Set the `LabeledFrame` associated with this list of instances. + + This updates the `frame` attribute on each instance. + + Args: + labeled_frame: The `LabeledFrame` to associate with this list of instances. 
+ """ + + try: + # If the labeled frame is the same as the one we're setting, then skip + if self._labeled_frame == labeled_frame: + return + except AttributeError: + # Only happens on init and updates each instance.frame (even if None) + pass + + # Otherwise, update the frame for each instance + self._labeled_frame = labeled_frame + for instance in self: + instance.frame = labeled_frame + + def append(self, instance: Union[Instance, PredictedInstance]): + """Append an `Instance` or `PredictedInstance` to the list, setting the frame. + + Args: + item: The `Instance` or `PredictedInstance` to append to the list. + """ + + if not isinstance(instance, (Instance, PredictedInstance)): + raise ValueError( + f"InstancesList can only contain Instance or PredictedInstance objects," + f" but got {type(instance)}." + ) + instance.frame = self.labeled_frame + super().append(instance) + + def extend(self, instances: List[Union[PredictedInstance, Instance]]): + """Extend the list with a list of `Instance`s or `PredictedInstance`s. + + Args: + instances: A list of `Instance` or `PredictedInstance` objects to add to the + list. + + Returns: + None + """ + for instance in instances: + self.append(instance) + + def __delitem__(self, index): + """Remove instance (by index), and set instance.frame to None.""" + + instance: Instance = self.__getitem__(index) + super().__delitem__(index) + + # Modify the instance to remove reference to the frame + instance.frame = None + + def insert(self, index: int, instance: Union[Instance, PredictedInstance]) -> None: + super().insert(index, instance) + instance.frame = self.labeled_frame + + def __setitem__(self, index, instance: Union[Instance, PredictedInstance]): + """Set nth instance in frame to the given instance. + + Args: + index: The index of instance to replace with new instance. + value: The new instance to associate with frame. + + Returns: + None. + """ + super().__setitem__(index, instance) + instance.frame = self.labeled_frame + + def pop(self, index: int) -> Union[Instance, PredictedInstance]: + """Remove and return instance at index, setting instance.frame to None.""" + + instance = super().pop(index) + instance.frame = None + return instance + + def remove(self, instance: Union[Instance, PredictedInstance]) -> None: + """Remove instance from list, setting instance.frame to None.""" + super().remove(instance) + instance.frame = None + + def clear(self) -> None: + """Remove all instances from list, setting instance.frame to None.""" + for instance in self: + instance.frame = None + super().clear() + + def copy(self) -> list: + """Return a shallow copy of the list of instances as a list. + + Note: This will not return an `InstancesList` object, but a normal list. + """ + return list(self) + + @attr.s(auto_attribs=True, eq=False, repr=False, str=False) class LabeledFrame: """Holds labeled data for a single frame of a video. @@ -1290,9 +1415,7 @@ class LabeledFrame: video: Video = attr.ib() frame_idx: int = attr.ib(converter=int) - _instances: Union[List[Instance], List[PredictedInstance]] = attr.ib( - default=attr.Factory(list) - ) + _instances: InstancesList = attr.ib(default=attr.Factory(InstancesList)) def __attrs_post_init__(self): """Called by attrs. 
@@ -1302,8 +1425,7 @@ def __attrs_post_init__(self): """ # Make sure all instances have a reference to this frame - for instance in self.instances: - instance.frame = self + self.instances = self._instances def __len__(self) -> int: """Return number of instances associated with frame.""" @@ -1319,13 +1441,8 @@ def index(self, value: Instance) -> int: def __delitem__(self, index): """Remove instance (by index) from frame.""" - value = self.instances.__getitem__(index) - self.instances.__delitem__(index) - # Modify the instance to remove reference to this frame - value.frame = None - def __repr__(self) -> str: """Return a readable representation of the LabeledFrame.""" return ( @@ -1348,9 +1465,6 @@ def insert(self, index: int, value: Instance): """ self.instances.insert(index, value) - # Modify the instance to have a reference back to this frame - value.frame = self - def __setitem__(self, index, value: Instance): """Set nth instance in frame to the given instance. @@ -1363,9 +1477,6 @@ def __setitem__(self, index, value: Instance): """ self.instances.__setitem__(index, value) - # Modify the instance to have a reference back to this frame - value.frame = self - def find( self, track: Optional[Union[Track, int]] = -1, user: bool = False ) -> List[Instance]: @@ -1393,7 +1504,7 @@ def instances(self) -> List[Instance]: return self._instances @instances.setter - def instances(self, instances: List[Instance]): + def instances(self, instances: Union[InstancesList, List[Instance]]): """Set the list of instances associated with this frame. Updates the `frame` attribute on each instance to the @@ -1408,9 +1519,11 @@ def instances(self, instances: List[Instance]): None """ - # Make sure to set the frame for each instance to this LabeledFrame - for instance in instances: - instance.frame = self + # Make sure to set the LabeledFrame for each instance to this frame + if isinstance(instances, InstancesList): + instances.labeled_frame = self + else: + instances = InstancesList(instances, labeled_frame=self) self._instances = instances @@ -1685,22 +1798,20 @@ def complex_frame_merge( * list of conflicting instances from base * list of conflicting instances from new """ - merged_instances = [] - redundant_instances = [] - extra_base_instances = copy(base_frame.instances) - extra_new_instances = [] + merged_instances: List[Instance] = [] # Only used for informing user + redundant_instances: List[Instance] = [] + extra_base_instances: List[Instance] = list(base_frame.instances) + extra_new_instances: List[Instance] = [] for new_inst in new_frame: redundant = False for base_inst in base_frame.instances: if new_inst.matches(base_inst): - base_inst.frame = None extra_base_instances.remove(base_inst) redundant_instances.append(base_inst) redundant = True continue if not redundant: - new_inst.frame = None extra_new_instances.append(new_inst) conflict = False @@ -1732,7 +1843,7 @@ def complex_frame_merge( else: # No conflict, so include all instances in base base_frame.instances.extend(extra_new_instances) - merged_instances = copy(extra_new_instances) + merged_instances: List[Instance] = copy(extra_new_instances) extra_base_instances = [] extra_new_instances = [] diff --git a/tests/test_instance.py b/tests/test_instance.py index 74a8b192e..58a630a8b 100644 --- a/tests/test_instance.py +++ b/tests/test_instance.py @@ -1,19 +1,21 @@ -import os -import math import copy +import math +import os +from typing import List -import pytest import numpy as np +import pytest -from sleap.skeleton import Skeleton +from sleap 
import Labels from sleap.instance import ( Instance, - PredictedInstance, + InstancesList, + LabeledFrame, Point, + PredictedInstance, PredictedPoint, - LabeledFrame, ) -from sleap import Labels +from sleap.skeleton import Skeleton def test_instance_node_get_set_item(skeleton): @@ -310,6 +312,8 @@ def test_frame_merge_predicted_and_user(skeleton, centered_pair_vid): # and we want to retain both even though they perfectly match. assert user_inst in user_frame.instances assert pred_inst in user_frame.instances + assert user_inst.frame == user_frame + assert pred_inst.frame == user_frame assert len(user_frame.instances) == 2 @@ -529,3 +533,216 @@ def test_instance_structuring_from_predicted(centered_pair_predictions): # Unstructure -> structure labels_copy = labels.copy() + + +def test_instances_list(centered_pair_predictions): + + labels = centered_pair_predictions + + def test_extend(instances: InstancesList, list_of_instances: List[Instance]): + instances.extend(list_of_instances) + assert len(instances) == len(list_of_instances) + for instance in instances: + assert isinstance(instance, PredictedInstance) + if instances.labeled_frame is None: + assert instance.frame is None + else: + assert instance.frame == instances.labeled_frame + + def test_append(instances: InstancesList, instance: Instance): + prev_len = len(instances) + instances.append(instance) + assert len(instances) == prev_len + 1 + assert instances[-1] == instance + assert instance.frame == instances.labeled_frame + + def test_labeled_frame_setter( + instances: InstancesList, labeled_frame: LabeledFrame + ): + instances.labeled_frame = labeled_frame + for instance in instances: + assert instance.frame == labeled_frame + + # Case 1: Create an empty instances list + labeled_frame = labels.labeled_frames[0] + list_of_instances = list(labeled_frame.instances) + instances = InstancesList() + assert len(instances) == 0 + assert instances._labeled_frame is None + assert instances.labeled_frame is None + + # Extend instances list + assert not isinstance(list_of_instances, InstancesList) + assert isinstance(list_of_instances, list) + test_extend(instances, list_of_instances) + + # Set the labeled frame + test_labeled_frame_setter(instances, labeled_frame) + + # Case 2: Create an empy instances list but initialize the labeled frame + instances = InstancesList(labeled_frame=labeled_frame) + assert len(instances) == 0 + assert instances._labeled_frame == labeled_frame + assert instances.labeled_frame == labeled_frame + + # Extend instances to the list from a different labeled frame + labeled_frame = labels.labeled_frames[1] + list_of_instances = list(labeled_frame.instances) + test_extend(instances, list_of_instances) + + # Add instance to the list + instance = list_of_instances[0] + instance.frame = None + test_append(instances, instance) + + # Set the labeled frame + test_labeled_frame_setter(instances, labeled_frame) + + # Test InstancesList.copy + instances_copy = instances.copy() + assert len(instances_copy) == len(instances) + assert not isinstance(instances_copy, InstancesList) + assert isinstance(instances_copy, list) + + # Test InstancesList.clear + instances_in_instances = list(instances) + instances.clear() + assert len(instances) == 0 + for instance in instances_in_instances: + assert instance.frame is None + + # Case 3: Create an instances list with a list of instances + labeled_frame = labels.labeled_frames[0] + list_of_instances = list(labeled_frame.instances) + instances = InstancesList(list_of_instances) + assert 
len(instances) == len(list_of_instances) + assert instances._labeled_frame is None + assert instances.labeled_frame is None + for instance in instances: + assert instance.frame is None + + # Add instance to the list + instance = list_of_instances[0] + test_append(instances, instance) + + # Case 4: Create an instances list with a list of instances and initialize the frame + labeled_frame_1 = labels.labeled_frames[0] + labeled_frame_2 = labels.labeled_frames[1] + list_of_instances = list(labeled_frame_2.instances) + instances = InstancesList(list_of_instances, labeled_frame=labeled_frame_1) + assert len(instances) == len(list_of_instances) + assert instances._labeled_frame == labeled_frame + assert instances.labeled_frame == labeled_frame + for instance in instances: + assert instance.frame == labeled_frame + + # Test InstancesList.__delitem__ + instance_to_remove = instances[0] + del instances[0] + assert instance_to_remove not in instances + assert instance_to_remove.frame is None + + # Test InstancesList.insert + instances.insert(0, instance_to_remove) + assert instances[0] == instance_to_remove + assert instance_to_remove.frame == instances.labeled_frame + + # Test InstancesList.__setitem__ + new_instance = labeled_frame_1.instances[0] + new_instance.frame = None + instances[0] = new_instance + assert instances[0] == new_instance + assert new_instance.frame == instances.labeled_frame + + # Test InstancesList.pop + popped_instance = instances.pop(0) + assert popped_instance.frame is None + + # Test InstancesList.remove + instance_to_remove = instances[0] + instances.remove(instance_to_remove) + assert instance_to_remove.frame is None + assert instance_to_remove not in instances + + # Case 5: Create an instances list from an instances list + instances_1 = InstancesList(list_of_instances, labeled_frame=labeled_frame_1) + instances = InstancesList(instances_1) + assert len(instances) == len(instances_1) + assert instances._labeled_frame is None + assert instances.labeled_frame is None + for instance in instances: + assert instance.frame is None + + +def test_instances_list_with_labeled_frame(centered_pair_predictions): + labels: Labels = centered_pair_predictions + labels_lf_0: LabeledFrame = labels.labeled_frames[0] + video = labels_lf_0.video + frame_idx = labels_lf_0.frame_idx + + def test_post_init(labeled_frame: LabeledFrame): + for instance in labeled_frame.instances: + assert instance.frame == labeled_frame + + # Create labeled frame from list of instances + instances = list(labels_lf_0.instances) + for instance in instances: + instance.frame = None # Change frame to None to test if it is set correctly + labeled_frame = LabeledFrame(video=video, frame_idx=frame_idx, instances=instances) + assert isinstance(labeled_frame.instances, InstancesList) + assert len(labeled_frame.instances) == len(instances) + test_post_init(labeled_frame) + + # Create labeled frame from instances list + instances = InstancesList(labels_lf_0.instances) + labeled_frame = LabeledFrame(video=video, frame_idx=frame_idx, instances=instances) + assert isinstance(labeled_frame.instances, InstancesList) + assert len(labeled_frame.instances) == len(instances) + test_post_init(labeled_frame) + + # Test LabeledFrame.__len__ + assert len(labeled_frame.instances) == len(instances) + + # Test LabeledFrame.__getitem__ + assert labeled_frame[0] == instances[0] + + # Test LabeledFrame.index + assert labeled_frame.index(instances[0]) == instances.index(instances[0]) == 0 + + # Test LabeledFrame.__delitem__ + instance_to_remove 
= labeled_frame[0] + del labeled_frame[0] + assert instance_to_remove not in labeled_frame.instances + assert instance_to_remove.frame is None + + # Test LabeledFrame.__repr__ + print(labeled_frame) + + # Test LabeledFrame.insert + labeled_frame.insert(0, instance_to_remove) + assert labeled_frame[0] == instance_to_remove + assert instance_to_remove.frame == labeled_frame + + # Test LabeledFrame.__setitem__ + new_instance = instances[1] + new_instance.frame = None + labeled_frame[0] = new_instance + assert labeled_frame[0] == new_instance + assert new_instance.frame == labeled_frame + + # Test instances.setter (empty list) + labeled_frame.instances = [] + assert len(labeled_frame.instances) == 0 + assert labeled_frame.instances.labeled_frame == labeled_frame + # Test instances.setter (InstancesList) + labeled_frame.instances = labels.labeled_frames[1].instances + assert len(labeled_frame.instances) == len(labels.labeled_frames[1].instances) + assert labeled_frame.instances.labeled_frame == labeled_frame + for instance in labeled_frame.instances: + assert instance.frame == labeled_frame + # Test instances.setter (populated list) + labeled_frame.instances = list(labels.labeled_frames[1].instances) + assert len(labeled_frame.instances) == len(labels.labeled_frames[1].instances) + assert labeled_frame.instances.labeled_frame == labeled_frame + for instance in labeled_frame.instances: + assert instance.frame == labeled_frame From 05283629a67385f40a6e8302519393fb2f87ecfb Mon Sep 17 00:00:00 2001 From: Elizabeth <106755962+eberrigan@users.noreply.github.com> Date: Fri, 28 Jun 2024 10:32:35 -0500 Subject: [PATCH 02/27] Bump to v1.4.1a2 (#1835) bump to 1.4.1a2 --- .github/ISSUE_TEMPLATE/bug_report.md | 2 +- docs/conf.py | 5 +++-- docs/installation.md | 6 +++--- sleap/version.py | 3 +-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index c0fdb26ca..108c2a65e 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -28,7 +28,7 @@ Please include information about how you installed. 
- OS: - Version(s): - + - SLEAP installation method (listed [here](https://sleap.ai/installation.html#)): - [ ] [Conda from package](https://sleap.ai/installation.html#conda-package) - [ ] [Conda from source](https://sleap.ai/installation.html#conda-from-source) diff --git a/docs/conf.py b/docs/conf.py index d8470a190..6cd9593ef 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -28,7 +28,7 @@ copyright = f"2019–{date.today().year}, Talmo Lab" # The short X.Y version -version = "1.4.1a1" +version = "1.4.1a2" # Get the sleap version # with open("../sleap/version.py") as f: @@ -36,7 +36,7 @@ # version = re.search("\d.+(?=['\"])", version_file).group(0) # Release should be the full branch name -release = "v1.4.1a1" +release = "v1.4.1a2" html_title = f"SLEAP ({release})" html_short_title = "SLEAP" @@ -85,6 +85,7 @@ pygments_style = "sphinx" pygments_dark_style = "monokai" + # Autosummary linkcode resolution # https://www.sphinx-doc.org/en/master/usage/extensions/linkcode.html def linkcode_resolve(domain, info): diff --git a/docs/installation.md b/docs/installation.md index dcb233904..c0ab66580 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -137,13 +137,13 @@ SLEAP can be installed three different ways: via {ref}`conda package Date: Fri, 28 Jun 2024 10:51:50 -0700 Subject: [PATCH 03/27] Updated trail length viewing options (#1822) * updated trail length optptions * Updated trail length options in the view menu * Updated `prefs` to include length info from `preferences.yaml` * Added trail length as method of `MainWindow` * Updated trail length documentation * black formatting --------- Co-authored-by: Keya Loding --- docs/guides/gui.md | 2 +- sleap/gui/app.py | 3 +++ sleap/gui/overlays/tracks.py | 4 +++- sleap/prefs.py | 4 ++++ 4 files changed, 11 insertions(+), 2 deletions(-) diff --git a/docs/guides/gui.md b/docs/guides/gui.md index 88cf3f656..813ed68fa 100644 --- a/docs/guides/gui.md +++ b/docs/guides/gui.md @@ -60,7 +60,7 @@ Note that many of the menu command have keyboard shortcuts which can be configur "**Edge Style**" controls whether edges are drawn as thin lines or as wedges which indicate the {ref}`orientation` of the instance (as well as the direction of the part affinity field which would be used to predict the connection between nodes when using a "bottom-up" approach). -"**Trail Length**" allows you to show a trail of where each instance was located in prior frames (the length of the trail is the number of prior frames). This can be useful when proofreading predictions since it can help you detect swaps in the identities of animals across frames. +"**Trail Length**" allows you to show a trail of where each instance was located in prior frames (the length of the trail is the number of prior frames). This can be useful when proofreading predictions since it can help you detect swaps in the identities of animals across frames. By default, you can only select trail lengths of up to 250 frames. You can use a custom trail length by modifying the default length in the `preferences.yaml` file. However, using trail lengths longer than about 500 frames can result in significant lag. "**Fit Instances to View**" allows you to toggle whether the view is auto-zoomed to the instances in each frame. This can be useful when proofreading predictions. 
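For the custom trail length mentioned above, a minimal sketch of the preference (assuming the `trail length` key that the overlay code below reads from `prefs`, with 500 as an arbitrary example value) is a single line in `preferences.yaml`, which lives under `~/.sleap/<version>/`:

    trail length: 500

When this value is nonzero, the patched `TrackTrailOverlay.get_length_options()` appends it to the built-in menu choices.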
diff --git a/sleap/gui/app.py b/sleap/gui/app.py index becc1d83a..736d7207f 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -151,6 +151,7 @@ def __init__( self.state["edge style"] = prefs["edge style"] self.state["fit"] = False self.state["color predicted"] = prefs["color predicted"] + self.state["trail_length"] = prefs["trail length"] self.state["trail_shade"] = prefs["trail shade"] self.state["marker size"] = prefs["marker size"] self.state["propagate track labels"] = prefs["propagate track labels"] @@ -221,6 +222,7 @@ def closeEvent(self, event): prefs["edge style"] = self.state["edge style"] prefs["propagate track labels"] = self.state["propagate track labels"] prefs["color predicted"] = self.state["color predicted"] + prefs["trail length"] = self.state["trail_length"] prefs["trail shade"] = self.state["trail_shade"] prefs["share usage data"] = self.state["share usage data"] @@ -1025,6 +1027,7 @@ def _load_overlays(self): labels=self.labels, player=self.player, trail_shade=self.state["trail_shade"], + trail_length=self.state["trail_length"], ) self.overlays["instance"] = InstanceOverlay( labels=self.labels, player=self.player, state=self.state diff --git a/sleap/gui/overlays/tracks.py b/sleap/gui/overlays/tracks.py index 361585719..bf0b633cd 100644 --- a/sleap/gui/overlays/tracks.py +++ b/sleap/gui/overlays/tracks.py @@ -48,7 +48,9 @@ def __attrs_post_init__(self): @classmethod def get_length_options(cls): - return (0, 10, 50, 100, 250) + if prefs["trail length"] != 0: + return (0, 10, 50, 100, 250, 500, prefs["trail length"]) + return (0, 10, 50, 100, 250, 500) @classmethod def get_shade_options(cls): diff --git a/sleap/prefs.py b/sleap/prefs.py index 3d5a2113e..8790f1d3f 100644 --- a/sleap/prefs.py +++ b/sleap/prefs.py @@ -45,6 +45,10 @@ def load_(self): self._prefs = util.get_config_yaml(self._filename) if not hasattr(self._prefs, "get"): self._prefs = self._defaults + else: + self._prefs["trail length"] = self._prefs.get( + "trail length", self._defaults["trail length"] + ) except FileNotFoundError: self._prefs = self._defaults From c8e3cd09bc6325ff18e2ad352b705f6175880263 Mon Sep 17 00:00:00 2001 From: Liezl Maree <38435167+roomrys@users.noreply.github.com> Date: Mon, 1 Jul 2024 11:20:59 -0700 Subject: [PATCH 04/27] Handle case when no frame selection for trail overlay (#1832) --- sleap/gui/overlays/tracks.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/sleap/gui/overlays/tracks.py b/sleap/gui/overlays/tracks.py index bf0b633cd..c5f091658 100644 --- a/sleap/gui/overlays/tracks.py +++ b/sleap/gui/overlays/tracks.py @@ -1,17 +1,16 @@ """Track trail and track list overlays.""" +from typing import Dict, Iterable, List, Optional, Tuple + +import attr +from qtpy import QtCore, QtGui + from sleap.gui.overlays.base import BaseOverlay +from sleap.gui.widgets.video import QtTextWithBackground from sleap.instance import Track from sleap.io.dataset import Labels from sleap.io.video import Video from sleap.prefs import prefs -from sleap.gui.widgets.video import QtTextWithBackground - -import attr - -from typing import Iterable, List, Optional, Dict - -from qtpy import QtCore, QtGui @attr.s(auto_attribs=True) @@ -58,7 +57,9 @@ def get_shade_options(cls): return {"Dark": 0.6, "Normal": 1.0, "Light": 1.25} - def get_track_trails(self, frame_selection: Iterable["LabeledFrame"]): + def get_track_trails( + self, frame_selection: Iterable["LabeledFrame"] + ) -> Optional[Dict[Track, List[List[Tuple[float, float]]]]]: """Get data needed to draw track 
trail. Args: @@ -154,6 +155,8 @@ def add_to_scene(self, video: Video, frame_idx: int): frame_selection = self.get_frame_selection(video, frame_idx) all_track_trails = self.get_track_trails(frame_selection) + if all_track_trails is None: + return for track, trails in all_track_trails.items(): trail_color = tuple( From 324377ebe0720911b8c3ae7bb6f080670445e0d1 Mon Sep 17 00:00:00 2001 From: Shrivaths Shyam <52810689+shrivaths16@users.noreply.github.com> Date: Mon, 8 Jul 2024 15:39:22 -0700 Subject: [PATCH 05/27] Menu option to open preferences directory and update to util functions to pathlib (#1843) * Add menu to view preferences directory and update to pathlib * text formatting --- sleap/gui/app.py | 25 ++++++++++++++++++++++++- sleap/util.py | 30 ++++++++++-------------------- 2 files changed, 34 insertions(+), 21 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 736d7207f..e872ce9a6 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -53,6 +53,8 @@ from logging import getLogger from pathlib import Path from typing import Callable, List, Optional, Tuple +import sys +import subprocess from qtpy import QtCore, QtGui from qtpy.QtCore import QEvent, Qt @@ -84,7 +86,7 @@ from sleap.io.video import available_video_exts from sleap.prefs import prefs from sleap.skeleton import Skeleton -from sleap.util import parse_uri_path +from sleap.util import parse_uri_path, get_config_file logger = getLogger(__name__) @@ -515,6 +517,13 @@ def add_submenu_choices(menu, title, options, key): fileMenu, "reset prefs", "Reset preferences to defaults...", self.resetPrefs ) + add_menu_item( + fileMenu, + "open preference directory", + "Open Preferences Directory...", + self.openPrefs, + ) + fileMenu.addSeparator() add_menu_item(fileMenu, "close", "Quit", self.close) @@ -1330,6 +1339,20 @@ def resetPrefs(self): ) msg.exec_() + def openPrefs(self): + """Open preference file directory""" + pref_path = get_config_file("preferences.yaml") + # Make sure the pref_path is a directory rather than a file + if pref_path.is_file(): + pref_path = pref_path.parent + # Open the file explorer at the folder containing the preferences.yaml file + if sys.platform == "win32": + subprocess.Popen(["explorer", str(pref_path)]) + elif sys.platform == "darwin": + subprocess.Popen(["open", str(pref_path)]) + else: + subprocess.Popen(["xdg-open", str(pref_path)]) + def _update_track_menu(self): """Updates track menu options.""" self.track_menu.clear() diff --git a/sleap/util.py b/sleap/util.py index 5edbf164b..eef762ff4 100644 --- a/sleap/util.py +++ b/sleap/util.py @@ -270,30 +270,20 @@ def get_config_file( The full path to the specified config file. """ - desired_path = None # Handle case where get_defaults, but cannot find package_path + desired_path = Path.home() / f".sleap/{sleap_version.__version__}/{shortname}" - if not get_defaults: - desired_path = os.path.expanduser( - f"~/.sleap/{sleap_version.__version__}/{shortname}" - ) + # Make sure there's a ~/.sleap// directory to store user version of the config file. + desired_path.parent.mkdir(parents=True, exist_ok=True) - # Make sure there's a ~/.sleap// directory to store user version of the - # config file. - try: - os.makedirs(os.path.expanduser(f"~/.sleap/{sleap_version.__version__}")) - except FileExistsError: - pass - - # If we don't care whether the file exists, just return the path - if ignore_file_not_found: - return desired_path - - # If we do care whether the file exists, check the package version of the - # config file if we can't find the user version. 
+ # If we don't care whether the file exists, just return the path + if ignore_file_not_found: + return desired_path - if get_defaults or not os.path.exists(desired_path): + # If we do care whether the file exists, check the package version of the config file if we can't find the user version. + if get_defaults or not desired_path.exists(): package_path = get_package_file(f"config/{shortname}") - if not os.path.exists(package_path): + package_path = Path(package_path) + if not package_path.exists(): raise FileNotFoundError( f"Cannot locate {shortname} config file at {desired_path} or {package_path}." ) From 14c21b4459773c99b941b77867d5746becc68ff7 Mon Sep 17 00:00:00 2001 From: Hajin Park Date: Wed, 10 Jul 2024 11:49:43 -0700 Subject: [PATCH 06/27] Add `Keep visualizations` checkbox to training GUI (#1824) * Renamed save_visualizations to view_visualizations for clarity * Added Delete Visualizations button to the training pipeline gui, exposed del_viz_predictions config option to the user * Reverted view_ back to save_ and changed new training checkbox to Keep visualization images after training. * Fixed keep_viz config option state override bug and updated keep_viz doc description * Added test case for reading training CLI argument correctly * Removed unnecessary testing code * Creating test case to check for viz folder * Finished tests to check CLI argument reading and viz directory existence * Use empty string instead of None in cli args test * Use keep_viz_images false in most all test configs (except test to override config) --------- Co-authored-by: roomrys <38435167+roomrys@users.noreply.github.com> --- docs/guides/cli.md | 6 ++- ..._and_inference_on_an_example_dataset.ipynb | 4 +- sleap/config/pipeline_form.yaml | 5 ++ sleap/gui/learning/runners.py | 17 ++++-- sleap/nn/config/outputs.py | 6 +-- sleap/nn/training.py | 13 ++++- .../training_profiles/baseline.centroid.json | 1 + .../baseline_large_rf.bottomup.json | 1 + .../baseline_large_rf.single.json | 1 + .../baseline_large_rf.topdown.json | 1 + .../baseline_medium_rf.bottomup.json | 1 + .../baseline_medium_rf.single.json | 1 + .../baseline_medium_rf.topdown.json | 1 + .../pretrained.bottomup.json | 1 + .../pretrained.centroid.json | 1 + .../training_profiles/pretrained.single.json | 1 + .../training_profiles/pretrained.topdown.json | 1 + .../initial_config.json | 1 + .../training_config.json | 1 + .../initial_config.json | 2 +- .../training_config.json | 2 +- .../initial_config.json | 1 + .../training_config.json | 1 + .../initial_config.json | 1 + .../training_config.json | 1 + .../initial_config.json | 1 + .../training_config.json | 1 + .../initial_config.json | 1 + .../training_config.json | 1 + tests/gui/test_dialogs.py | 1 - tests/nn/test_training.py | 54 ++++++++++++++++++- 31 files changed, 114 insertions(+), 17 deletions(-) diff --git a/docs/guides/cli.md b/docs/guides/cli.md index c29270299..ab62f3130 100644 --- a/docs/guides/cli.md +++ b/docs/guides/cli.md @@ -36,8 +36,8 @@ optional arguments: ```none usage: sleap-train [-h] [--video-paths VIDEO_PATHS] [--val_labels VAL_LABELS] - [--test_labels TEST_LABELS] [--tensorboard] [--save_viz] - [--zmq] [--run_name RUN_NAME] [--prefix PREFIX] + [--test_labels TEST_LABELS] [--tensorboard] [--save_viz] + [--keep_viz] [--zmq] [--run_name RUN_NAME] [--prefix PREFIX] [--suffix SUFFIX] training_job_path [labels_path] @@ -68,6 +68,8 @@ optional arguments: --save_viz Enable saving of prediction visualizations to the run folder if not already specified in the training job config. 
+ --keep_viz Keep prediction visualization images in the run + folder after training if --save_viz is enabled. --zmq Enable ZMQ logging (for GUI) if not already specified in the training job config. --run_name RUN_NAME Run name to use when saving file, overrides other run diff --git a/docs/notebooks/Training_and_inference_on_an_example_dataset.ipynb b/docs/notebooks/Training_and_inference_on_an_example_dataset.ipynb index b0211bbca..4e26cb286 100644 --- a/docs/notebooks/Training_and_inference_on_an_example_dataset.ipynb +++ b/docs/notebooks/Training_and_inference_on_an_example_dataset.ipynb @@ -335,7 +335,7 @@ " \"runs_folder\": \"models\",\n", " \"tags\": [],\n", " \"save_visualizations\": true,\n", - " \"delete_viz_images\": true,\n", + " \"keep_viz_images\": true,\n", " \"zip_outputs\": false,\n", " \"log_to_csv\": true,\n", " \"checkpointing\": {\n", @@ -727,7 +727,7 @@ " \"runs_folder\": \"models\",\n", " \"tags\": [],\n", " \"save_visualizations\": true,\n", - " \"delete_viz_images\": true,\n", + " \"keep_viz_images\": true,\n", " \"zip_outputs\": false,\n", " \"log_to_csv\": true,\n", " \"checkpointing\": {\n", diff --git a/sleap/config/pipeline_form.yaml b/sleap/config/pipeline_form.yaml index be9e272c7..c730fa9c4 100644 --- a/sleap/config/pipeline_form.yaml +++ b/sleap/config/pipeline_form.yaml @@ -286,6 +286,11 @@ training: type: bool default: true +- name: _keep_viz + label: Keep Prediction Visualization Images After Training + type: bool + default: false + - name: _predict_frames label: Predict On type: list diff --git a/sleap/gui/learning/runners.py b/sleap/gui/learning/runners.py index a2e84788c..7569607a0 100644 --- a/sleap/gui/learning/runners.py +++ b/sleap/gui/learning/runners.py @@ -1,4 +1,5 @@ """Run training/inference in background process via CLI.""" + import abc import attr import os @@ -500,9 +501,11 @@ def write_pipeline_files( "data_path": os.path.basename(data_path), "models": [Path(p).as_posix() for p in new_cfg_filenames], "output_path": prediction_output_path, - "type": "labels" - if type(item_for_inference) == DatasetItemForInference - else "video", + "type": ( + "labels" + if type(item_for_inference) == DatasetItemForInference + else "video" + ), "only_suggested_frames": only_suggested_frames, "tracking": tracking_args, } @@ -544,6 +547,7 @@ def run_learning_pipeline( """ save_viz = inference_params.get("_save_viz", False) + keep_viz = inference_params.get("_keep_viz", False) if "movenet" in inference_params["_pipeline"]: trained_job_paths = [inference_params["_pipeline"]] @@ -557,6 +561,7 @@ def run_learning_pipeline( inference_params=inference_params, gui=True, save_viz=save_viz, + keep_viz=keep_viz, ) # Check that all the models were trained @@ -585,6 +590,7 @@ def run_gui_training( inference_params: Dict[str, Any], gui: bool = True, save_viz: bool = False, + keep_viz: bool = False, ) -> Dict[Text, Text]: """ Runs training for each training job. @@ -594,6 +600,7 @@ def run_gui_training( config_info_list: List of ConfigFileInfo with configs for training. gui: Whether to show gui windows and process gui events. save_viz: Whether to save visualizations from training. + keep_viz: Whether to keep prediction visualization images after training. Returns: Dictionary, keys are head name, values are path to trained config. 
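As a usage sketch (file names hypothetical), the checkbox above maps onto the new CLI flag, so the equivalent command-line training run would be:

    sleap-train baseline.centroid.json labels.v001.slp --save_viz --keep_viz

With --save_viz, per-epoch prediction renderings are written to "{run_folder}/viz/"; --keep_viz keeps that folder from being deleted when training finishes.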
@@ -683,6 +690,7 @@ def waiting(): video_paths=video_path_list, waiting_callback=waiting, save_viz=save_viz, + keep_viz=keep_viz, ) if ret == "success": @@ -825,6 +833,7 @@ def train_subprocess( video_paths: Optional[List[Text]] = None, waiting_callback: Optional[Callable] = None, save_viz: bool = False, + keep_viz: bool = False, ): """Runs training inside subprocess.""" run_path = job_config.outputs.run_path @@ -853,6 +862,8 @@ def train_subprocess( if save_viz: cli_args.append("--save_viz") + if keep_viz: + cli_args.append("--keep_viz") # Use cli arg since cli ignores setting in config if job_config.outputs.tensorboard.write_logs: diff --git a/sleap/nn/config/outputs.py b/sleap/nn/config/outputs.py index ffb0d76e4..ccb6077b1 100644 --- a/sleap/nn/config/outputs.py +++ b/sleap/nn/config/outputs.py @@ -151,8 +151,8 @@ class OutputsConfig: save_visualizations: If True, will render and save visualizations of the model predictions as PNGs to "{run_folder}/viz/{split}.{epoch:04d}.png", where the split is one of "train", "validation", "test". - delete_viz_images: If True, delete the saved visualizations after training - completes. This is useful to reduce the model folder size if you do not need + keep_viz_images: If True, keep the saved visualization images after training + completes. This is useful unchecked to reduce the model folder size if you do not need to keep the visualization images. zip_outputs: If True, compress the run folder to a zip file. This will be named "{run_folder}.zip". @@ -170,7 +170,7 @@ class OutputsConfig: runs_folder: Text = "models" tags: List[Text] = attr.ib(factory=list) save_visualizations: bool = True - delete_viz_images: bool = True + keep_viz_images: bool = False zip_outputs: bool = False log_to_csv: bool = True checkpointing: CheckpointingConfig = attr.ib(factory=CheckpointingConfig) diff --git a/sleap/nn/training.py b/sleap/nn/training.py index 6a64e43b6..9e4245b88 100644 --- a/sleap/nn/training.py +++ b/sleap/nn/training.py @@ -946,7 +946,7 @@ def train(self): if self.config.outputs.save_outputs: if ( self.config.outputs.save_visualizations - and self.config.outputs.delete_viz_images + and not self.config.outputs.keep_viz_images ): self.cleanup() @@ -997,7 +997,7 @@ def cleanup(self): def package(self): """Package model folder into a zip file for portability.""" - if self.config.outputs.delete_viz_images: + if not self.config.outputs.keep_viz_images: self.cleanup() logger.info(f"Packaging results to: {self.run_path}.zip") shutil.make_archive( @@ -1864,6 +1864,14 @@ def create_trainer_using_cli(args: Optional[List] = None): "already specified in the training job config." ), ) + parser.add_argument( + "--keep_viz", + action="store_true", + help=( + "Keep prediction visualization images in the run folder after training when " + "--save_viz is enabled." 
+ ), + ) parser.add_argument( "--zmq", action="store_true", @@ -1949,6 +1957,7 @@ def create_trainer_using_cli(args: Optional[List] = None): if args.suffix != "": job_config.outputs.run_name_suffix = args.suffix job_config.outputs.save_visualizations |= args.save_viz + job_config.outputs.keep_viz_images = args.keep_viz if args.labels_path == "": args.labels_path = None args.video_paths = args.video_paths.split(",") diff --git a/sleap/training_profiles/baseline.centroid.json b/sleap/training_profiles/baseline.centroid.json index 933989ecf..3a54db25c 100755 --- a/sleap/training_profiles/baseline.centroid.json +++ b/sleap/training_profiles/baseline.centroid.json @@ -116,6 +116,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": true, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/sleap/training_profiles/baseline_large_rf.bottomup.json b/sleap/training_profiles/baseline_large_rf.bottomup.json index ea45c9b25..18fb3104f 100644 --- a/sleap/training_profiles/baseline_large_rf.bottomup.json +++ b/sleap/training_profiles/baseline_large_rf.bottomup.json @@ -125,6 +125,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": true, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/sleap/training_profiles/baseline_large_rf.single.json b/sleap/training_profiles/baseline_large_rf.single.json index 75e97b1a6..3feeccd69 100644 --- a/sleap/training_profiles/baseline_large_rf.single.json +++ b/sleap/training_profiles/baseline_large_rf.single.json @@ -116,6 +116,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": true, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/sleap/training_profiles/baseline_large_rf.topdown.json b/sleap/training_profiles/baseline_large_rf.topdown.json index 9b17f6832..38e96594b 100644 --- a/sleap/training_profiles/baseline_large_rf.topdown.json +++ b/sleap/training_profiles/baseline_large_rf.topdown.json @@ -117,6 +117,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": true, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/sleap/training_profiles/baseline_medium_rf.bottomup.json b/sleap/training_profiles/baseline_medium_rf.bottomup.json index 1cc35330a..61b08515c 100644 --- a/sleap/training_profiles/baseline_medium_rf.bottomup.json +++ b/sleap/training_profiles/baseline_medium_rf.bottomup.json @@ -125,6 +125,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": true, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/sleap/training_profiles/baseline_medium_rf.single.json b/sleap/training_profiles/baseline_medium_rf.single.json index 579f6c8c3..0951bc761 100644 --- a/sleap/training_profiles/baseline_medium_rf.single.json +++ b/sleap/training_profiles/baseline_medium_rf.single.json @@ -116,6 +116,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": true, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/sleap/training_profiles/baseline_medium_rf.topdown.json b/sleap/training_profiles/baseline_medium_rf.topdown.json index 9e3a0bde5..9eccb76c1 100755 --- a/sleap/training_profiles/baseline_medium_rf.topdown.json +++ b/sleap/training_profiles/baseline_medium_rf.topdown.json @@ -117,6 +117,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": true, + "keep_viz_images": false, 
"log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/sleap/training_profiles/pretrained.bottomup.json b/sleap/training_profiles/pretrained.bottomup.json index 3e4f3935f..57b7398b5 100644 --- a/sleap/training_profiles/pretrained.bottomup.json +++ b/sleap/training_profiles/pretrained.bottomup.json @@ -122,6 +122,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": true, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/sleap/training_profiles/pretrained.centroid.json b/sleap/training_profiles/pretrained.centroid.json index a5df5e48a..74c43d3e2 100644 --- a/sleap/training_profiles/pretrained.centroid.json +++ b/sleap/training_profiles/pretrained.centroid.json @@ -113,6 +113,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": true, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/sleap/training_profiles/pretrained.single.json b/sleap/training_profiles/pretrained.single.json index 7ca907007..615f0de4d 100644 --- a/sleap/training_profiles/pretrained.single.json +++ b/sleap/training_profiles/pretrained.single.json @@ -113,6 +113,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": true, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/sleap/training_profiles/pretrained.topdown.json b/sleap/training_profiles/pretrained.topdown.json index aeeaebbd8..be0d97de8 100644 --- a/sleap/training_profiles/pretrained.topdown.json +++ b/sleap/training_profiles/pretrained.topdown.json @@ -114,6 +114,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": true, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/tests/data/models/min_tracks_2node.UNet.bottomup_multiclass/initial_config.json b/tests/data/models/min_tracks_2node.UNet.bottomup_multiclass/initial_config.json index 7e52d1703..2ae0e925c 100644 --- a/tests/data/models/min_tracks_2node.UNet.bottomup_multiclass/initial_config.json +++ b/tests/data/models/min_tracks_2node.UNet.bottomup_multiclass/initial_config.json @@ -128,6 +128,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": false, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/tests/data/models/min_tracks_2node.UNet.bottomup_multiclass/training_config.json b/tests/data/models/min_tracks_2node.UNet.bottomup_multiclass/training_config.json index bcb2f26d5..7b6f817aa 100644 --- a/tests/data/models/min_tracks_2node.UNet.bottomup_multiclass/training_config.json +++ b/tests/data/models/min_tracks_2node.UNet.bottomup_multiclass/training_config.json @@ -191,6 +191,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": false, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/tests/data/models/min_tracks_2node.UNet.topdown_multiclass/initial_config.json b/tests/data/models/min_tracks_2node.UNet.topdown_multiclass/initial_config.json index 045890b21..5d8081628 100644 --- a/tests/data/models/min_tracks_2node.UNet.topdown_multiclass/initial_config.json +++ b/tests/data/models/min_tracks_2node.UNet.topdown_multiclass/initial_config.json @@ -141,7 +141,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": false, - "delete_viz_images": true, + "keep_viz_images": false, "zip_outputs": false, "log_to_csv": true, "checkpointing": { diff --git 
a/tests/data/models/min_tracks_2node.UNet.topdown_multiclass/training_config.json b/tests/data/models/min_tracks_2node.UNet.topdown_multiclass/training_config.json index 070e9d3c0..9591e5b52 100644 --- a/tests/data/models/min_tracks_2node.UNet.topdown_multiclass/training_config.json +++ b/tests/data/models/min_tracks_2node.UNet.topdown_multiclass/training_config.json @@ -208,7 +208,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": false, - "delete_viz_images": true, + "keep_viz_images": false, "zip_outputs": false, "log_to_csv": true, "checkpointing": { diff --git a/tests/data/models/minimal_instance.UNet.bottomup/initial_config.json b/tests/data/models/minimal_instance.UNet.bottomup/initial_config.json index 8e39fea3f..68e4f894e 100644 --- a/tests/data/models/minimal_instance.UNet.bottomup/initial_config.json +++ b/tests/data/models/minimal_instance.UNet.bottomup/initial_config.json @@ -127,6 +127,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": false, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/tests/data/models/minimal_instance.UNet.bottomup/training_config.json b/tests/data/models/minimal_instance.UNet.bottomup/training_config.json index d1fb718ba..e3bfbc5f8 100644 --- a/tests/data/models/minimal_instance.UNet.bottomup/training_config.json +++ b/tests/data/models/minimal_instance.UNet.bottomup/training_config.json @@ -192,6 +192,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": false, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/tests/data/models/minimal_instance.UNet.centered_instance/initial_config.json b/tests/data/models/minimal_instance.UNet.centered_instance/initial_config.json index 739d8e3e7..f4914aae4 100644 --- a/tests/data/models/minimal_instance.UNet.centered_instance/initial_config.json +++ b/tests/data/models/minimal_instance.UNet.centered_instance/initial_config.json @@ -119,6 +119,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": false, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/tests/data/models/minimal_instance.UNet.centered_instance/training_config.json b/tests/data/models/minimal_instance.UNet.centered_instance/training_config.json index 7b6782a68..e747f6862 100644 --- a/tests/data/models/minimal_instance.UNet.centered_instance/training_config.json +++ b/tests/data/models/minimal_instance.UNet.centered_instance/training_config.json @@ -179,6 +179,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": false, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/tests/data/models/minimal_instance.UNet.centroid/initial_config.json b/tests/data/models/minimal_instance.UNet.centroid/initial_config.json index 41d8ac8c3..977654b2e 100644 --- a/tests/data/models/minimal_instance.UNet.centroid/initial_config.json +++ b/tests/data/models/minimal_instance.UNet.centroid/initial_config.json @@ -118,6 +118,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": false, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/tests/data/models/minimal_instance.UNet.centroid/training_config.json b/tests/data/models/minimal_instance.UNet.centroid/training_config.json index 2d2280a31..02e9683e1 100644 --- a/tests/data/models/minimal_instance.UNet.centroid/training_config.json +++ 
b/tests/data/models/minimal_instance.UNet.centroid/training_config.json @@ -175,6 +175,7 @@ "runs_folder": "models", "tags": [], "save_visualizations": false, + "keep_viz_images": false, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/tests/data/models/minimal_robot.UNet.single_instance/initial_config.json b/tests/data/models/minimal_robot.UNet.single_instance/initial_config.json index cb2e4f353..f2bb907fa 100644 --- a/tests/data/models/minimal_robot.UNet.single_instance/initial_config.json +++ b/tests/data/models/minimal_robot.UNet.single_instance/initial_config.json @@ -120,6 +120,7 @@ "" ], "save_visualizations": false, + "keep_viz_images": true, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/tests/data/models/minimal_robot.UNet.single_instance/training_config.json b/tests/data/models/minimal_robot.UNet.single_instance/training_config.json index 66901c9f0..dffecc1d9 100644 --- a/tests/data/models/minimal_robot.UNet.single_instance/training_config.json +++ b/tests/data/models/minimal_robot.UNet.single_instance/training_config.json @@ -180,6 +180,7 @@ "" ], "save_visualizations": false, + "keep_viz_images": true, "log_to_csv": true, "checkpointing": { "initial_model": false, diff --git a/tests/gui/test_dialogs.py b/tests/gui/test_dialogs.py index 4455550fb..611a73c85 100644 --- a/tests/gui/test_dialogs.py +++ b/tests/gui/test_dialogs.py @@ -1,6 +1,5 @@ """Module to test the dialogs of the GUI (contained in sleap/gui/dialogs).""" - import os from pathlib import Path diff --git a/tests/nn/test_training.py b/tests/nn/test_training.py index b6696e819..72db17bb5 100644 --- a/tests/nn/test_training.py +++ b/tests/nn/test_training.py @@ -123,34 +123,61 @@ def test_train_load_single_instance( assert (w == w2).all() -def test_train_single_instance(min_labels_robot, cfg): +def test_train_single_instance(min_labels_robot, cfg, tmp_path): cfg.model.heads.single_instance = SingleInstanceConfmapsHeadConfig( sigma=1.5, output_stride=1, offset_refinement=False ) + + # Set save directory + cfg.outputs.run_name = "test_run" + cfg.outputs.runs_folder = str(tmp_path / "training_runs") # ensure it's a string + cfg.outputs.save_visualizations = True + cfg.outputs.keep_viz_images = True + cfg.outputs.save_outputs = True # enable saving + trainer = SingleInstanceModelTrainer.from_config( cfg, training_labels=min_labels_robot ) trainer.setup() trainer.train() + + run_path = Path(cfg.outputs.runs_folder, cfg.outputs.run_name) + viz_path = run_path / "viz" + assert trainer.keras_model.output_names[0] == "SingleInstanceConfmapsHead" assert tuple(trainer.keras_model.outputs[0].shape) == (None, 320, 560, 2) + assert viz_path.exists() -def test_train_single_instance_with_offset(min_labels_robot, cfg): +def test_train_single_instance_with_offset(min_labels_robot, cfg, tmp_path): cfg.model.heads.single_instance = SingleInstanceConfmapsHeadConfig( sigma=1.5, output_stride=1, offset_refinement=True ) + + # Set save directory + cfg.outputs.run_name = "test_run" + cfg.outputs.runs_folder = str(tmp_path / "training_runs") # ensure it's a string + cfg.outputs.save_visualizations = False + cfg.outputs.keep_viz_images = False + cfg.outputs.save_outputs = True # enable saving + trainer = SingleInstanceModelTrainer.from_config( cfg, training_labels=min_labels_robot ) trainer.setup() trainer.train() + + run_path = Path(cfg.outputs.runs_folder, cfg.outputs.run_name) + viz_path = run_path / "viz" + assert trainer.keras_model.output_names[0] == "SingleInstanceConfmapsHead" 
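    # The robot fixture has a 2-node skeleton, so the confmap head below should
    # have 2 channels and the offset head 4 (presumably one (dx, dy) pair per node).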
assert tuple(trainer.keras_model.outputs[0].shape) == (None, 320, 560, 2) assert trainer.keras_model.output_names[1] == "OffsetRefinementHead" assert tuple(trainer.keras_model.outputs[1].shape) == (None, 320, 560, 4) + assert not viz_path.exists() + def test_train_centroids(training_labels, cfg): cfg.model.heads.centroid = CentroidsHeadConfig( @@ -360,3 +387,26 @@ def test_resume_training_cli( trainer = sleap_train(cli_args) assert trainer.config.model.base_checkpoint == base_checkpoint_path + + +@pytest.mark.parametrize("keep_viz_cli", ["", "--keep_viz"]) +def test_keep_viz_cli( + keep_viz_cli, + min_single_instance_robot_model_path: str, + tmp_path: str, +): + """Test training CLI for --keep_viz option.""" + cfg_dir = min_single_instance_robot_model_path + cfg = TrainingJobConfig.load_json(str(Path(cfg_dir, "training_config.json"))) + + # Save training config to tmp folder + cfg_path = str(Path(tmp_path, "training_config.json")) + cfg.save_json(cfg_path) + + cli_args = [cfg_path, keep_viz_cli] + trainer = sleap_train(cli_args) + + # Check that --keep_viz is set correctly + assert trainer.config.outputs.keep_viz_images == ( + True if keep_viz_cli == "--keep_viz" else False + ) From 3e2bd2501f706f0694703e7a3d34de4ca00ccd57 Mon Sep 17 00:00:00 2001 From: Elise Davis Date: Thu, 18 Jul 2024 17:39:16 -0700 Subject: [PATCH 07/27] Allowing inference on multiple videos via `sleap-track` (#1784) * implementing proposed code changes from issue #1777 * comments * configuring output_path to support multiple video inputs * fixing errors from preexisting test cases * Test case / code fixes * extending test cases for mp4 folders * test case for output directory * black and code rabbit fixes * code rabbit fixes * as_posix errors resolved * syntax error * adding test data * black * output error resolved * edited for push to dev branch * black * errors fixed, test cases implemented * invalid output test and invalid input test * deleting debugging statements * deleting print statements * black * deleting unnecessary test case * implemented tmpdir * deleting extraneous file * fixing broken test case * fixing test_sleap_track_invalid_output * removing support for multiple slp files * implementing talmo's comments * adding comments --- sleap/nn/inference.py | 299 ++++++++++++++++++++++++++----------- tests/nn/test_inference.py | 274 ++++++++++++++++++++++++++++++++- 2 files changed, 486 insertions(+), 87 deletions(-) diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index af8ef2c6c..7f9e91ec9 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -5288,12 +5288,9 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]: A tuple of `(provider, data_path)` with the data `Provider` and path to the data that was specified in the args. """ + # Figure out which input path to use. - labels_path = getattr(args, "labels", None) - if labels_path is not None: - data_path = labels_path - else: - data_path = args.data_path + data_path = args.data_path if data_path is None or data_path == "": raise ValueError( @@ -5301,33 +5298,73 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]: "Run 'sleap-track -h' to see full command documentation." 
) - if data_path.endswith(".slp"): - labels = sleap.load_file(data_path) - - if args.only_labeled_frames: - provider = LabelsReader.from_user_labeled_frames(labels) - elif args.only_suggested_frames: - provider = LabelsReader.from_unlabeled_suggestions(labels) - elif getattr(args, "video.index") != "": - provider = VideoReader( - video=labels.videos[int(getattr(args, "video.index"))], - example_indices=frame_list(args.frames), - ) - else: - provider = LabelsReader(labels) + data_path_obj = Path(data_path) + + # Check that input value is valid + if not data_path_obj.exists(): + raise ValueError("Path to data_path does not exist") + + # Check for multiple video inputs + # Compile file(s) into a list for later itteration + if data_path_obj.is_dir(): + data_path_list = [] + for file_path in data_path_obj.iterdir(): + if file_path.is_file(): + data_path_list.append(Path(file_path)) + elif data_path_obj.is_file(): + data_path_list = [data_path_obj] + + # Provider list to accomodate multiple video inputs + output_provider_list = [] + output_data_path_list = [] + for file_path in data_path_list: + # Create a provider for each file + if file_path.as_posix().endswith(".slp") and len(data_path_list) > 1: + print(f"slp file skipped: {file_path.as_posix()}") + + elif file_path.as_posix().endswith(".slp"): + labels = sleap.load_file(file_path.as_posix()) + + if args.only_labeled_frames: + output_provider_list.append( + LabelsReader.from_user_labeled_frames(labels) + ) + elif args.only_suggested_frames: + output_provider_list.append( + LabelsReader.from_unlabeled_suggestions(labels) + ) + elif getattr(args, "video.index") != "": + output_provider_list.append( + VideoReader( + video=labels.videos[int(getattr(args, "video.index"))], + example_indices=frame_list(args.frames), + ) + ) + else: + output_provider_list.append(LabelsReader(labels)) - else: - print(f"Video: {data_path}") - # TODO: Clean this up. - video_kwargs = dict( - dataset=vars(args).get("video.dataset"), - input_format=vars(args).get("video.input_format"), - ) - provider = VideoReader.from_filepath( - filename=data_path, example_indices=frame_list(args.frames), **video_kwargs - ) + output_data_path_list.append(file_path) - return provider, data_path + else: + try: + video_kwargs = dict( + dataset=vars(args).get("video.dataset"), + input_format=vars(args).get("video.input_format"), + ) + output_provider_list.append( + VideoReader.from_filepath( + filename=file_path.as_posix(), + example_indices=frame_list(args.frames), + **video_kwargs, + ) + ) + print(f"Video: {file_path.as_posix()}") + output_data_path_list.append(file_path) + # TODO: Clean this up. + except Exception: + print(f"Error reading file: {file_path.as_posix()}") + + return output_provider_list, output_data_path_list def _make_predictor_from_cli(args: argparse.Namespace) -> Predictor: @@ -5422,8 +5459,6 @@ def main(args: Optional[list] = None): pprint(vars(args)) print() - output_path = args.output - # Setup devices. if args.cpu or not sleap.nn.system.is_gpu_system(): sleap.nn.system.use_cpu_only() @@ -5461,7 +5496,19 @@ def main(args: Optional[list] = None): print() # Setup data loader. 
-    provider, data_path = _make_provider_from_cli(args)
+    provider_list, data_path_list = _make_provider_from_cli(args)
+
+    output_path = args.output
+
+    # check if output_path is valid before running inference
+    if (
+        output_path is not None
+        and Path(output_path).is_file()
+        and len(data_path_list) > 1
+    ):
+        raise ValueError(
+            "output_path argument must be a directory if multiple video inputs are given"
+        )

     # Setup tracker.
     tracker = _make_tracker_from_cli(args)
@@ -5469,35 +5516,148 @@ def main(args: Optional[list] = None):
     if args.models is not None and "movenet" in args.models[0]:
         args.models = args.models[0]

-    # Either run inference (and tracking) or just run tracking
+    # Either run inference (and tracking) or just run tracking (if using an existing prediction where inference has already been run)
     if args.models is not None:
-        # Setup models.
-        predictor = _make_predictor_from_cli(args)
-        predictor.tracker = tracker

-        # Run inference!
-        labels_pr = predictor.predict(provider)
+        # Run inference on all input files
+        for data_path, provider in zip(data_path_list, provider_list):
+            # Setup models.
+            data_path_obj = Path(data_path)
+            predictor = _make_predictor_from_cli(args)
+            predictor.tracker = tracker
+
+            # Run inference!
+            labels_pr = predictor.predict(provider)

-        if output_path is None:
-            output_path = data_path + ".predictions.slp"
+            # if output path was not provided, create an output path
+            if output_path is None:
+                output_path = f"{data_path.as_posix()}.predictions.slp"
+                output_path_obj = Path(output_path)

-        labels_pr.provenance["model_paths"] = predictor.model_paths
-        labels_pr.provenance["predictor"] = type(predictor).__name__
+            else:
+                output_path_obj = Path(output_path)
+
+                # if output_path was provided and multiple inputs were provided, create a directory to store outputs
+                if len(data_path_list) > 1:
+                    output_path = (
+                        output_path_obj
+                        / data_path_obj.with_suffix(".predictions.slp").name
+                    )
+                    output_path_obj = Path(output_path)
+                    # Create the containing directory if needed.
+                    output_path_obj.parent.mkdir(exist_ok=True, parents=True)
+
+            labels_pr.provenance["model_paths"] = predictor.model_paths
+            labels_pr.provenance["predictor"] = type(predictor).__name__
+
+            if args.no_empty_frames:
+                # Clear empty frames if specified.
+                labels_pr.remove_empty_frames()
+
+            finish_timestamp = str(datetime.now())
+            total_elapsed = time() - t0
+            print("Finished inference at:", finish_timestamp)
+            print(f"Total runtime: {total_elapsed} secs")
+            print(f"Predicted frames: {len(labels_pr)}/{len(provider)}")
+
+            # Add provenance metadata to predictions.
+            labels_pr.provenance["sleap_version"] = sleap.__version__
+            labels_pr.provenance["platform"] = platform.platform()
+            labels_pr.provenance["command"] = " ".join(sys.argv)
+            labels_pr.provenance["data_path"] = data_path_obj.as_posix()
+            labels_pr.provenance["output_path"] = output_path_obj.as_posix()
+            labels_pr.provenance["total_elapsed"] = total_elapsed
+            labels_pr.provenance["start_timestamp"] = start_timestamp
+            labels_pr.provenance["finish_timestamp"] = finish_timestamp
+
+            print("Provenance:")
+            pprint(labels_pr.provenance)
+            print()
+
+            labels_pr.provenance["args"] = vars(args)
+
+            # Save results.
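The output-path handling above follows three naming rules that are easy to lose in the diff. A minimal sketch of those rules under the same assumptions (`resolve_output_path` is our name, not part of SLEAP):

    from pathlib import Path
    from typing import Optional

    def resolve_output_path(data_path: Path, output: Optional[str], n_inputs: int) -> Path:
        """Sketch: no -o gives a side-by-side default; -o with a batch names a directory."""
        if output is None:
            # Default: "<video>.predictions.slp" next to the input.
            return Path(f"{data_path.as_posix()}.predictions.slp")
        if n_inputs > 1:
            # -o names a directory; one predictions file per video inside it.
            return Path(output) / data_path.with_suffix(".predictions.slp").name
        return Path(output)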
+ labels_pr.save(output_path) + print("Saved output:", output_path) + + if args.open_in_gui: + subprocess.call(["sleap-label", output_path]) + + # Reset output_path for next iteration + output_path = args.output + + # running tracking on existing prediction file elif getattr(args, "tracking.tracker") is not None: - # Load predictions - print("Loading predictions...") - labels_pr = sleap.load_file(args.data_path) - frames = sorted(labels_pr.labeled_frames, key=lambda lf: lf.frame_idx) + for data_path, provider in zip(data_path_list, provider_list): + # Load predictions + data_path_obj = Path(data_path) + print("Loading predictions...") + labels_pr = sleap.load_file(data_path_obj.as_posix()) + frames = sorted(labels_pr.labeled_frames, key=lambda lf: lf.frame_idx) + + print("Starting tracker...") + frames = run_tracker(frames=frames, tracker=tracker) + tracker.final_pass(frames) + + labels_pr = Labels(labeled_frames=frames) + + if output_path is None: + output_path = f"{data_path}.{tracker.get_name()}.slp" + output_path_obj = Path(output_path) + + else: + output_path_obj = Path(output_path) + if ( + output_path_obj.exists() + and output_path_obj.is_file() + and len(data_path_list) > 1 + ): + raise ValueError( + "output_path argument must be a directory if multiple video inputs are given" + ) - print("Starting tracker...") - frames = run_tracker(frames=frames, tracker=tracker) - tracker.final_pass(frames) + elif not output_path_obj.exists() and len(data_path_list) > 1: + output_path = output_path_obj / data_path_obj.with_suffix( + ".predictions.slp" + ) + output_path_obj = Path(output_path) + output_path_obj.parent.mkdir(exist_ok=True, parents=True) + + if args.no_empty_frames: + # Clear empty frames if specified. + labels_pr.remove_empty_frames() + + finish_timestamp = str(datetime.now()) + total_elapsed = time() - t0 + print("Finished inference at:", finish_timestamp) + print(f"Total runtime: {total_elapsed} secs") + print(f"Predicted frames: {len(labels_pr)}/{len(provider)}") + + # Add provenance metadata to predictions. + labels_pr.provenance["sleap_version"] = sleap.__version__ + labels_pr.provenance["platform"] = platform.platform() + labels_pr.provenance["command"] = " ".join(sys.argv) + labels_pr.provenance["data_path"] = data_path_obj.as_posix() + labels_pr.provenance["output_path"] = output_path_obj.as_posix() + labels_pr.provenance["total_elapsed"] = total_elapsed + labels_pr.provenance["start_timestamp"] = start_timestamp + labels_pr.provenance["finish_timestamp"] = finish_timestamp + + print("Provenance:") + pprint(labels_pr.provenance) + print() + + labels_pr.provenance["args"] = vars(args) - labels_pr = Labels(labeled_frames=frames) + # Save results. + labels_pr.save(output_path) + print("Saved output:", output_path) - if output_path is None: - output_path = f"{data_path}.{tracker.get_name()}.slp" + if args.open_in_gui: + subprocess.call(["sleap-label", output_path]) + + # Reset output_path for next iteration + output_path = args.output else: raise ValueError( @@ -5506,36 +5666,3 @@ def main(args: Optional[list] = None): "To retrack on predictions, must specify tracker. " "Use \"sleap-track --tracking.tracker ...' to specify tracker to use." ) - - if args.no_empty_frames: - # Clear empty frames if specified. 
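For orientation, the tests below drive these branches through the package entry point rather than a subprocess. Assuming the `sleap_track` alias for `sleap.nn.inference.main` used by the test suite, a batch invocation looks roughly like this (the model and video paths are placeholders):

    from sleap.nn.inference import main as sleap_track  # alias assumed from the tests

    # A folder of videos, with predictions collected in an output directory:
    args = (
        "videos/ --model models/centroid --model models/centered_instance "
        "-o preds/ --frames 1-3 --cpu"
    ).split()
    sleap_track(args=args)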
- labels_pr.remove_empty_frames() - - finish_timestamp = str(datetime.now()) - total_elapsed = time() - t0 - print("Finished inference at:", finish_timestamp) - print(f"Total runtime: {total_elapsed} secs") - print(f"Predicted frames: {len(labels_pr)}/{len(provider)}") - - # Add provenance metadata to predictions. - labels_pr.provenance["sleap_version"] = sleap.__version__ - labels_pr.provenance["platform"] = platform.platform() - labels_pr.provenance["command"] = " ".join(sys.argv) - labels_pr.provenance["data_path"] = data_path - labels_pr.provenance["output_path"] = output_path - labels_pr.provenance["total_elapsed"] = total_elapsed - labels_pr.provenance["start_timestamp"] = start_timestamp - labels_pr.provenance["finish_timestamp"] = finish_timestamp - - print("Provenance:") - pprint(labels_pr.provenance) - print() - - labels_pr.provenance["args"] = vars(args) - - # Save results. - labels_pr.save(output_path) - print("Saved output:", output_path) - - if args.open_in_gui: - subprocess.call(["sleap-label", output_path]) diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index 1b0f88c7c..f99f136ab 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -3,6 +3,7 @@ import zipfile from pathlib import Path from typing import cast +import shutil import numpy as np import pytest @@ -1447,7 +1448,49 @@ def test_make_predictor_from_cli( assert predictor.max_instances == 5 -def test_sleap_track( +def test_make_predictor_from_cli_mult_input( + centered_pair_predictions: Labels, + min_centroid_model_path: str, + min_centered_instance_model_path: str, + min_bottomup_model_path: str, + tmpdir, +): + slp_path = tmpdir.mkdir("slp_directory") + + slp_file = slp_path / "old_slp.slp" + Labels.save(centered_pair_predictions, slp_file) + + # Copy and paste the video into the temp dir multiple times + num_copies = 3 + for i in range(num_copies): + # Construct the destination path with a unique name for the video + + # Construct the destination path with a unique name for the SLP file + slp_dest_path = slp_path / f"old_slp_copy_{i}.slp" + shutil.copy(slp_file, slp_dest_path) + + # Create sleap-track command + model_args = [ + f"--model {min_centroid_model_path} --model {min_centered_instance_model_path}", + f"--model {min_bottomup_model_path}", + ] + for model_arg in model_args: + args = ( + f"{slp_path} {model_arg} --video.index 0 --frames 1-3 " + "--cpu --max_instances 5" + ).split() + parser = _make_cli_parser() + args, _ = parser.parse_known_args(args=args) + + # Create predictor + predictor = _make_predictor_from_cli(args=args) + if isinstance(predictor, TopDownPredictor): + assert predictor.inference_model.centroid_crop.max_instances == 5 + elif isinstance(predictor, BottomUpPredictor): + assert predictor.max_instances == 5 + + +def test_sleap_track_single_input( centered_pair_predictions: Labels, min_centroid_model_path: str, min_centered_instance_model_path: str, @@ -1475,6 +1518,235 @@ def test_sleap_track( sleap_track(args=args) +@pytest.mark.parametrize("tracking", ["simple", "flow", "None"]) +def test_sleap_track_mult_input_slp( + min_centroid_model_path: str, + min_centered_instance_model_path: str, + tmpdir, + centered_pair_predictions: Labels, + tracking, +): + # Create temporary directory with the structured video files + slp_path = tmpdir.mkdir("slp_directory") + + slp_file = slp_path / "old_slp.slp" + Labels.save(centered_pair_predictions, slp_file) + + slp_path_obj = Path(slp_path) + + # Copy and paste the video into the temp dir multiple times + 
num_copies = 3 + for i in range(num_copies): + # Construct the destination path with a unique name for the video + + # Construct the destination path with a unique name for the SLP file + slp_dest_path = slp_path / f"old_slp_copy_{i}.slp" + shutil.copy(slp_file, slp_dest_path) + + # Create sleap-track command + args = ( + f"{slp_path} --model {min_centroid_model_path} " + f"--tracking.tracker {tracking} " + f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu" + ).split() + + slp_path_list = [file for file in slp_path_obj.iterdir() if file.is_file()] + + # Run inference + sleap_track(args=args) + + # Assert predictions file exists + expected_extensions = { + ".mp4", + } # Add other video formats if necessary + + for file_path in slp_path_list: + if file_path.suffix in expected_extensions: + expected_output_file = f"{file_path}.predictions.slp" + assert Path(expected_output_file).exists() + + +@pytest.mark.parametrize("tracking", ["simple", "flow", "None"]) +def test_sleap_track_mult_input_slp_mp4( + min_centroid_model_path: str, + min_centered_instance_model_path: str, + centered_pair_vid_path, + tracking, + tmpdir, + centered_pair_predictions: Labels, +): + # Create temporary directory with the structured video files + slp_path = tmpdir.mkdir("slp_mp4_directory") + + slp_file = slp_path / "old_slp.slp" + Labels.save(centered_pair_predictions, slp_file) + + # Copy and paste the video into temp dir multiple times + num_copies = 3 + for i in range(num_copies): + # Construct the destination path with a unique name + dest_path = slp_path / f"centered_pair_vid_copy_{i}.mp4" + shutil.copy(centered_pair_vid_path, dest_path) + + slp_path_obj = Path(slp_path) + + # Create sleap-track command + args = ( + f"{slp_path} --model {min_centroid_model_path} " + f"--tracking.tracker {tracking} " + f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu" + ).split() + + slp_path_list = [file for file in slp_path_obj.iterdir() if file.is_file()] + + # Run inference + sleap_track(args=args) + + # Assert predictions file exists + for file_path in slp_path_list: + if file_path.suffix == ".mp4": + expected_output_file = f"{file_path}.predictions.slp" + assert Path(expected_output_file).exists() + + +@pytest.mark.parametrize("tracking", ["simple", "flow", "None"]) +def test_sleap_track_mult_input_mp4( + min_centroid_model_path: str, + min_centered_instance_model_path: str, + centered_pair_vid_path, + tracking, + tmpdir, +): + + # Create temporary directory with the structured video files + slp_path = tmpdir.mkdir("mp4_directory") + + # Copy and paste the video into the temp dir multiple times + num_copies = 3 + for i in range(num_copies): + # Construct the destination path with a unique name + dest_path = slp_path / f"centered_pair_vid_copy_{i}.mp4" + shutil.copy(centered_pair_vid_path, dest_path) + + slp_path_obj = Path(slp_path) + + # Create sleap-track command + args = ( + f"{slp_path} --model {min_centroid_model_path} " + f"--tracking.tracker {tracking} " + f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu" + ).split() + + slp_path_list = [file for file in slp_path_obj.iterdir() if file.is_file()] + + # Run inference + sleap_track(args=args) + + # Assert predictions file exists + for file_path in slp_path_list: + if file_path.suffix == ".mp4": + expected_output_file = f"{file_path}.predictions.slp" + assert Path(expected_output_file).exists() + + +def test_sleap_track_output_mult( + min_centroid_model_path: str, + 
min_centered_instance_model_path: str, + centered_pair_vid_path, + tmpdir, +): + + output_path = tmpdir.mkdir("output_directory") + output_path_obj = Path(output_path) + + # Create temporary directory with the structured video files + slp_path = tmpdir.mkdir("mp4_directory") + + # Copy and paste the video into the temp dir multiple times + num_copies = 3 + for i in range(num_copies): + # Construct the destination path with a unique name + dest_path = slp_path / f"centered_pair_vid_copy_{i}.mp4" + shutil.copy(centered_pair_vid_path, dest_path) + + slp_path_obj = Path(slp_path) + + # Create sleap-track command + args = ( + f"{slp_path} --model {min_centroid_model_path} " + f"--tracking.tracker simple " + f"-o {output_path} " + f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu" + ).split() + + slp_path_list = [file for file in slp_path_obj.iterdir() if file.is_file()] + + # Run inference + sleap_track(args=args) + slp_path = Path(slp_path) + + # Check if there are any files in the directory + for file_path in slp_path_list: + if file_path.suffix == ".mp4": + expected_output_file = output_path_obj / ( + file_path.stem + ".predictions.slp" + ) + assert Path(expected_output_file).exists() + + +def test_sleap_track_invalid_output( + min_centroid_model_path: str, + min_centered_instance_model_path: str, + centered_pair_vid_path, + centered_pair_predictions: Labels, + tmpdir, +): + + output_path = Path(tmpdir, "output_file.slp").as_posix() + Labels.save(centered_pair_predictions, output_path) + + # Create temporary directory with the structured video files + slp_path = tmpdir.mkdir("mp4_directory") + + # Copy and paste the video into the temp dir multiple times + num_copies = 3 + for i in range(num_copies): + # Construct the destination path with a unique name + dest_path = slp_path / f"centered_pair_vid_copy_{i}.mp4" + shutil.copy(centered_pair_vid_path, dest_path) + + # Create sleap-track command + args = ( + f"{slp_path} --model {min_centroid_model_path} " + f"--tracking.tracker simple " + f"-o {output_path} " + f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu" + ).split() + + # Run inference + with pytest.raises(ValueError): + sleap_track(args=args) + + +def test_sleap_track_invalid_input( + min_centroid_model_path: str, + min_centered_instance_model_path: str, +): + + slp_path = "" + + # Create sleap-track command + args = ( + f"{slp_path} --model {min_centroid_model_path} " + f"--tracking.tracker simple " + f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu" + ).split() + + # Run inference + with pytest.raises(ValueError): + sleap_track(args=args) + + def test_flow_tracker(centered_pair_predictions: Labels, tmpdir): """Test flow tracker instances are pruned.""" labels: Labels = centered_pair_predictions From 38a5ca785d02bae74e09b4102635d6711f096c46 Mon Sep 17 00:00:00 2001 From: getzze Date: Tue, 23 Jul 2024 17:48:27 +0100 Subject: [PATCH 08/27] Add object keypoint similarity method (#1003) * Add object keypoint similarity method * fix max_tracking * correct off-by-one error * correct off-by-one error --- sleap/config/pipeline_form.yaml | 44 ++++++++++-- sleap/gui/learning/runners.py | 8 +++ sleap/nn/tracker/components.py | 94 ++++++++++++++++++++++++- sleap/nn/tracking.py | 103 +++++++++++++++++++++++----- tests/fixtures/datasets.py | 7 ++ tests/nn/test_inference.py | 10 +-- tests/nn/test_tracker_components.py | 66 ++++++++++++++++-- 7 files changed, 300 insertions(+), 32 deletions(-) diff --git 
a/sleap/config/pipeline_form.yaml b/sleap/config/pipeline_form.yaml index c730fa9c4..d130b9cb9 100644 --- a/sleap/config/pipeline_form.yaml +++ b/sleap/config/pipeline_form.yaml @@ -52,7 +52,7 @@ training: This pipeline uses two models: a "centroid" model to locate and crop around each animal in the frame, and a "centered-instance confidence map" model for predicted node locations - for each individual animal predicted by the centroid model.' + for each individual animal predicted by the centroid model.' - label: Max Instances name: max_instances type: optional_int @@ -217,7 +217,7 @@ training: - name: controller_port label: Controller Port type: int - default: 9000 + default: 9000 range: 1024,65535 - name: publish_port @@ -388,7 +388,7 @@ inference: tracking-only: - name: batch_size - label: Batch Size + label: Batch Size type: int default: 4 range: 1,512 @@ -439,7 +439,7 @@ inference: label: Similarity Method type: list default: instance - options: instance,centroid,iou + options: "instance,centroid,iou,object keypoint" - name: tracking.match label: Matching Method type: list @@ -478,6 +478,22 @@ inference: label: Nodes to use for Tracking type: string default: 0,1,2 + - type: text + text: 'Object keypoint similarity options:
+ Only used if this similarity method is selected.' + - name: tracking.oks_errors + label: Keypoints errors in pixels + help: 'Standard error in pixels of the distance for each keypoint. + If the list is empty, defaults to 1. If singleton list, each keypoint has + the same error. Otherwise, the length should be the same as the number of + keypoints in the skeleton.' + type: string + default: + - name: tracking.oks_score_weighting + label: Use prediction score for weighting + help: 'Use prediction scores to weight the similarity of each keypoint' + type: bool + default: false - type: text text: 'Post-tracker data cleaning:' - name: tracking.post_connect_single_breaks @@ -521,8 +537,8 @@ inference: - name: tracking.similarity label: Similarity Method type: list - default: iou - options: instance,centroid,iou + default: instance + options: "instance,centroid,iou,object keypoint" - name: tracking.match label: Matching Method type: list @@ -557,6 +573,22 @@ inference: label: Nodes to use for Tracking type: string default: 0,1,2 + - type: text + text: 'Object keypoint similarity options:
+ Only used if this similarity method is selected.' + - name: tracking.oks_errors + label: Keypoints errors in pixels + help: 'Standard error in pixels of the distance for each keypoint. + If the list is empty, defaults to 1. If singleton list, each keypoint has + the same error. Otherwise, the length should be the same as the number of + keypoints in the skeleton.' + type: string + default: + - name: tracking.oks_score_weighting + label: Use prediction score for weighting + help: 'Use prediction scores to weight the similarity of each keypoint' + type: bool + default: false - type: text text: 'Post-tracker data cleaning:' - name: tracking.post_connect_single_breaks diff --git a/sleap/gui/learning/runners.py b/sleap/gui/learning/runners.py index 7569607a0..d0bb1f3ba 100644 --- a/sleap/gui/learning/runners.py +++ b/sleap/gui/learning/runners.py @@ -260,12 +260,20 @@ def make_predict_cli_call( "tracking.max_tracking", "tracking.post_connect_single_breaks", "tracking.save_shifted_instances", + "tracking.oks_score_weighting", ) for key in bool_items_as_ints: if key in self.inference_params: self.inference_params[key] = int(self.inference_params[key]) + remove_spaces_items = ("tracking.similarity",) + + for key in remove_spaces_items: + if key in self.inference_params: + value = self.inference_params[key] + self.inference_params[key] = value.replace(" ", "_") + for key, val in self.inference_params.items(): if not key.startswith(("_", "outputs.", "model.", "data.")): cli_args.extend((f"--{key}", str(val))) diff --git a/sleap/nn/tracker/components.py b/sleap/nn/tracker/components.py index 10b2953b7..b2f35b21f 100644 --- a/sleap/nn/tracker/components.py +++ b/sleap/nn/tracker/components.py @@ -14,7 +14,8 @@ """ import operator from collections import defaultdict -from typing import List, Tuple, Optional, TypeVar, Callable +import logging +from typing import List, Tuple, Union, Optional, TypeVar, Callable import attr import numpy as np @@ -23,6 +24,8 @@ from sleap import PredictedInstance, Instance, Track from sleap.nn import utils +logger = logging.getLogger(__name__) + InstanceType = TypeVar("InstanceType", Instance, PredictedInstance) @@ -40,6 +43,95 @@ def instance_similarity( return similarity +def factory_object_keypoint_similarity( + keypoint_errors: Optional[Union[List, int, float]] = None, + score_weighting: bool = False, + normalization_keypoints: str = "all", +) -> Callable: + """Factory for similarity function based on object keypoints. + + Args: + keypoint_errors: The standard error of the distance between the predicted + keypoint and the true value, in pixels. + If None or empty list, defaults to 1. + If a scalar or singleton list, every keypoint has the same error. + If a list, defines the error for each keypoint, the length should be equal + to the number of keypoints in the skeleton. + score_weighting: If True, use `score` of `PredictedPoint` to weigh + `keypoint_errors`. If False, do not add a weight to `keypoint_errors`. + normalization_keypoints: Determine how to normalize similarity score. One of + ["all", "ref", "union"]. If "all", similarity score is normalized by number + of reference points. If "ref", similarity score is normalized by number of + visible reference points. If "union", similarity score is normalized by + number of points both visible in query and reference instance. + Default is "all". + + Returns: + Callable that returns object keypoint similarity between two `Instance`s. 
+ + """ + keypoint_errors = 1 if keypoint_errors is None else keypoint_errors + with np.errstate(divide="ignore"): + kp_precision = 1 / (2 * np.array(keypoint_errors) ** 2) + + def object_keypoint_similarity( + ref_instance: InstanceType, query_instance: InstanceType + ) -> float: + nonlocal kp_precision + # Keypoints + ref_points = ref_instance.points_array + query_points = query_instance.points_array + # Keypoint scores + if score_weighting: + ref_scores = getattr(ref_instance, "scores", np.ones(len(ref_points))) + query_scores = getattr(query_instance, "scores", np.ones(len(query_points))) + else: + ref_scores = 1 + query_scores = 1 + # Number of keypoint for normalization + if normalization_keypoints in ("ref", "union"): + ref_visible = ~(np.isnan(ref_points).any(axis=1)) + if normalization_keypoints == "ref": + max_n_keypoints = np.sum(ref_visible) + elif normalization_keypoints == "union": + query_visible = ~(np.isnan(query_points).any(axis=1)) + max_n_keypoints = np.sum(np.logical_and(ref_visible, query_visible)) + else: # if normalization_keypoints == "all": + max_n_keypoints = len(ref_points) + if max_n_keypoints == 0: + return 0 + + # Make sure the sizes of kp_precision and n_points match + if kp_precision.size > 1 and 2 * kp_precision.size != ref_points.size: + # Correct kp_precision size to fit number of points + n_points = ref_points.size // 2 + mess = ( + "keypoint_errors array should have the same size as the number of " + f"keypoints in the instance: {kp_precision.size} != {n_points}" + ) + + if kp_precision.size > n_points: + kp_precision = kp_precision[:n_points] + mess += "\nTruncating keypoint_errors array." + + else: # elif kp_precision.size < n_points: + pad = n_points - kp_precision.size + kp_precision = np.pad(kp_precision, (0, pad), "edge") + mess += "\nPadding keypoint_errors array by repeating the last value." 
+ logger.warning(mess) + + # Compute distances + dists = np.sum((query_points - ref_points) ** 2, axis=1) * kp_precision + + similarity = ( + np.nansum(ref_scores * query_scores * np.exp(-dists)) / max_n_keypoints + ) + + return similarity + + return object_keypoint_similarity + + def centroid_distance( ref_instance: InstanceType, query_instance: InstanceType, cache: dict = dict() ) -> float: diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 9865b7db5..2b02839de 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -10,6 +10,7 @@ from sleap import Track, LabeledFrame, Skeleton from sleap.nn.tracker.components import ( + factory_object_keypoint_similarity, instance_similarity, centroid_distance, instance_iou, @@ -391,6 +392,7 @@ def get_ref_instances( def get_candidates( self, track_matching_queue_dict: Dict[Track, Deque[MatchedFrameInstance]], + max_tracking: bool, t: int, img: np.ndarray, *args, @@ -404,7 +406,7 @@ def get_candidates( tracks = [] for track, matched_items in track_matching_queue_dict.items(): - if len(tracks) <= self.max_tracks: + if not max_tracking or len(tracks) < self.max_tracks: tracks.append(track) for matched_item in matched_items: ref_t, ref_img = ( @@ -466,6 +468,7 @@ class SimpleMaxTracksCandidateMaker(SimpleCandidateMaker): def get_candidates( self, track_matching_queue_dict: Dict, + max_tracking: bool, *args, **kwargs, ) -> List[InstanceType]: @@ -473,7 +476,7 @@ def get_candidates( candidate_instances = [] tracks = [] for track, matched_instances in track_matching_queue_dict.items(): - if len(tracks) <= self.max_tracks: + if not max_tracking or len(tracks) < self.max_tracks: tracks.append(track) for ref_instance in matched_instances: if ref_instance.instance_t.n_visible_points >= self.min_points: @@ -492,6 +495,7 @@ def get_candidates( instance=instance_similarity, centroid=centroid_distance, iou=instance_iou, + object_keypoint=instance_similarity, ) match_policies = dict( @@ -598,8 +602,15 @@ def _init_matching_queue(self): """Factory for instantiating default matching queue with specified size.""" return deque(maxlen=self.track_window) + @property + def has_max_tracking(self) -> bool: + return isinstance( + self.candidate_maker, + (SimpleMaxTracksCandidateMaker, FlowMaxTracksCandidateMaker), + ) + def reset_candidates(self): - if self.max_tracking: + if self.has_max_tracking: for track in self.track_matching_queue_dict: self.track_matching_queue_dict[track] = deque(maxlen=self.track_window) else: @@ -610,14 +621,15 @@ def unique_tracks_in_queue(self) -> List[Track]: """Returns the unique tracks in the matching queue.""" unique_tracks = set() - for match_item in self.track_matching_queue: - for instance in match_item.instances_t: - unique_tracks.add(instance.track) - - if self.max_tracking: + if self.has_max_tracking: for track in self.track_matching_queue_dict.keys(): unique_tracks.add(track) + else: + for match_item in self.track_matching_queue: + for instance in match_item.instances_t: + unique_tracks.add(instance.track) + return list(unique_tracks) @property @@ -646,7 +658,7 @@ def track( # Infer timestep if not provided. if t is None: - if self.max_tracking: + if self.has_max_tracking: if len(self.track_matching_queue_dict) > 0: # Default to last timestep + 1 if available. @@ -684,10 +696,10 @@ def track( self.pre_cull_function(untracked_instances) # Build a pool of matchable candidate instances. 
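To make the object keypoint similarity defined above concrete, here is a worked toy example with unit keypoint errors, no score weighting, and "all" normalization (all values are illustrative):

    import numpy as np

    # Two 3-node instances as (x, y) arrays; the reference is missing one node.
    ref = np.array([[10.0, 10.0], [20.0, 10.0], [np.nan, np.nan]])
    query = np.array([[11.0, 10.0], [20.0, 12.0], [15.0, 15.0]])

    kp_precision = 1 / (2 * 1.0**2)                            # keypoint_errors = 1
    dists = np.sum((query - ref) ** 2, axis=1) * kp_precision  # [0.5, 2.0, nan]
    similarity = np.nansum(np.exp(-dists)) / len(ref)          # normalize by all nodes
    print(round(float(similarity), 3))                         # -> 0.247

A perfect match would score 1.0; here the missing reference node and the two small offsets pull the score down to about 0.247.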
- if self.max_tracking: + if self.has_max_tracking: candidate_instances = self.candidate_maker.get_candidates( track_matching_queue_dict=self.track_matching_queue_dict, - max_tracks=self.max_tracks, + max_tracking=self.max_tracking, t=t, img=img, ) @@ -721,13 +733,16 @@ def track( ) # Add the tracked instances to the dictionary of matched instances. - if self.max_tracking: + if self.has_max_tracking: for tracked_instance in tracked_instances: if tracked_instance.track in self.track_matching_queue_dict: self.track_matching_queue_dict[tracked_instance.track].append( MatchedFrameInstance(t, tracked_instance, img) ) - elif len(self.track_matching_queue_dict) < self.max_tracks: + elif ( + not self.max_tracking + or len(self.track_matching_queue_dict) < self.max_tracks + ): self.track_matching_queue_dict[tracked_instance.track] = deque( maxlen=self.track_window ) @@ -773,7 +788,8 @@ def spawn_for_untracked_instances( # Skip if we've reached the maximum number of tracks. if ( - self.max_tracking + self.has_max_tracking + and self.max_tracking and len(self.track_matching_queue_dict) >= self.max_tracks ): break @@ -838,8 +854,17 @@ def make_tracker_by_name( # Max tracking options max_tracks: Optional[int] = None, max_tracking: bool = False, + # Object keypoint similarity options + oks_errors: Optional[list] = None, + oks_score_weighting: bool = False, + oks_normalization: str = "all", **kwargs, ) -> BaseTracker: + # Parse max_tracking arguments, only True if max_tracks is not None and > 0 + max_tracking = max_tracking if max_tracks else False + if max_tracking and tracker in ("simple", "flow"): + # Force a candidate maker of 'maxtracks' type + tracker += "maxtracks" if tracker.lower() == "none": candidate_maker = None @@ -858,7 +883,14 @@ def make_tracker_by_name( raise ValueError(f"{match} is not a valid tracker matching function.") candidate_maker = tracker_policies[tracker](min_points=min_match_points) - similarity_function = similarity_policies[similarity] + if similarity == "object_keypoint": + similarity_function = factory_object_keypoint_similarity( + keypoint_errors=oks_errors, + score_weighting=oks_score_weighting, + normalization_keypoints=oks_normalization, + ) + else: + similarity_function = similarity_policies[similarity] matching_function = match_policies[match] if tracker == "flow": @@ -931,7 +963,10 @@ def get_by_name_factory_options(cls): option = dict(name="max_tracking", default=False) option["type"] = bool - option["help"] = "If true then the tracker will cap the max number of tracks." + option["help"] = ( + "If true then the tracker will cap the max number of tracks. " + "Falls back to false if `max_tracks` is not defined or 0." + ) options.append(option) option = dict(name="max_tracks", default=None) @@ -1054,6 +1089,42 @@ def int_list_func(s): ] = "For Kalman filter: Number of frames to track with other tracker. 0 means no Kalman filters will be used." options.append(option) + def float_list_func(s): + return [float(x.strip()) for x in s.split(",")] if s else None + + option = dict(name="oks_errors", default="1") + option["type"] = float_list_func + option["help"] = ( + "For Object Keypoint similarity: the standard error of the distance " + "between the predicted keypoint and the true value, in pixels.\n" + "If None or empty list, defaults to 1. If a scalar or singleton list, " + "every keypoint has the same error. If a list, defines the error for each " + "keypoint, the length should be equal to the number of keypoints in the " + "skeleton." 
+ ) + options.append(option) + + option = dict(name="oks_score_weighting", default="0") + option["type"] = int + option["help"] = ( + "For Object Keypoint similarity: if 0 (default), only the distance between the reference " + "and query keypoint is used to compute the similarity. If 1, each distance is weighted " + "by the prediction scores of the reference and query keypoint." + ) + options.append(option) + + option = dict(name="oks_normalization", default="all") + option["type"] = str + option["options"] = ["all", "ref", "union"] + option["help"] = ( + "For Object Keypoint similarity: Determine how to normalize similarity score. " + "If 'all', similarity score is normalized by number of reference points. " + "If 'ref', similarity score is normalized by number of visible reference points. " + "If 'union', similarity score is normalized by number of points both visible " + "in query and reference instance." + ) + options.append(option) + return options @classmethod diff --git a/tests/fixtures/datasets.py b/tests/fixtures/datasets.py index 801fcc092..ec5dfbc29 100644 --- a/tests/fixtures/datasets.py +++ b/tests/fixtures/datasets.py @@ -41,6 +41,13 @@ def centered_pair_predictions(): return Labels.load_file(TEST_JSON_PREDICTIONS) +@pytest.fixture +def centered_pair_predictions_sorted(centered_pair_predictions): + labels: Labels = centered_pair_predictions + labels.labeled_frames.sort(key=lambda lf: lf.frame_idx) + return labels + + @pytest.fixture def min_labels(): return Labels.load_file(TEST_JSON_MIN_LABELS) diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index f99f136ab..98f5fbcec 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -1373,7 +1373,7 @@ def test_retracking( # Create sleap-track command cmd = ( f"{slp_path} --tracking.tracker {tracker_method} --video.index 0 --frames 1-3 " - "--cpu" + "--tracking.similarity object_keypoint --cpu" ) if tracker_method == "flow": cmd += " --tracking.save_shifted_instances 1" @@ -1393,6 +1393,8 @@ def test_retracking( parser = _make_cli_parser() args, _ = parser.parse_known_args(args=args) tracker = _make_tracker_from_cli(args) + # Additional check for similarity method + assert tracker.similarity_function.__name__ == "object_keypoint_similarity" output_path = f"{slp_path}.{tracker.get_name()}.slp" # Assert tracked predictions file exists @@ -1747,9 +1749,9 @@ def test_sleap_track_invalid_input( sleap_track(args=args) -def test_flow_tracker(centered_pair_predictions: Labels, tmpdir): +def test_flow_tracker(centered_pair_predictions_sorted: Labels, tmpdir): """Test flow tracker instances are pruned.""" - labels: Labels = centered_pair_predictions + labels: Labels = centered_pair_predictions_sorted track_window = 5 # Setup tracker @@ -1759,7 +1761,7 @@ def test_flow_tracker(centered_pair_predictions: Labels, tmpdir): tracker.candidate_maker = cast(FlowCandidateMaker, tracker.candidate_maker) # Run tracking - frames = sorted(labels.labeled_frames, key=lambda lf: lf.frame_idx) + frames = labels.labeled_frames # Run tracking on subset of frames using psuedo-implementation of # sleap.nn.tracking.run_tracker diff --git a/tests/nn/test_tracker_components.py b/tests/nn/test_tracker_components.py index f861241ee..5786945fb 100644 --- a/tests/nn/test_tracker_components.py +++ b/tests/nn/test_tracker_components.py @@ -9,23 +9,79 @@ FrameMatches, greedy_matching, ) +from sleap.io.dataset import Labels from sleap.instance import PredictedInstance from sleap.skeleton import Skeleton +def tracker_by_name(frames=None, 
**kwargs): + t = Tracker.make_tracker_by_name(**kwargs) + print(kwargs) + print(t.candidate_maker) + if frames is None: + t.track([]) + t.final_pass([]) + return + + for lf in frames: + # Clear the tracks + for inst in lf.instances: + inst.track = None + + track_args = dict(untracked_instances=lf.instances, img=lf.video[lf.frame_idx]) + t.track(**track_args) + t.final_pass(frames) + + @pytest.mark.parametrize( "tracker", ["simple", "flow", "simplemaxtracks", "flowmaxtracks"] ) @pytest.mark.parametrize("similarity", ["instance", "iou", "centroid"]) @pytest.mark.parametrize("match", ["greedy", "hungarian"]) @pytest.mark.parametrize("count", [0, 2]) -def test_tracker_by_name(tracker, similarity, match, count): - t = Tracker.make_tracker_by_name( - "flow", "instance", "greedy", clean_instance_count=2 +def test_tracker_by_name( + centered_pair_predictions_sorted, + tracker, + similarity, + match, + count, +): + # This is slow, so limit to 5 time points + frames = centered_pair_predictions_sorted[:5] + + tracker_by_name( + frames=frames, + tracker=tracker, + similarity=similarity, + match=match, + max_tracks=count, + ) + + +@pytest.mark.parametrize( + "tracker", ["simple", "flow", "simplemaxtracks", "flowmaxtracks"] +) +@pytest.mark.parametrize("oks_score_weighting", ["True", "False"]) +@pytest.mark.parametrize("oks_normalization", ["all", "ref", "union"]) +def test_oks_tracker_by_name( + centered_pair_predictions_sorted, + tracker, + oks_score_weighting, + oks_normalization, +): + # This is slow, so limit to 5 time points + frames = centered_pair_predictions_sorted[:5] + + tracker_by_name( + frames=frames, + tracker=tracker, + similarity="object_keypoint", + matching="greedy", + oks_score_weighting=oks_score_weighting, + oks_normalization=oks_normalization, + max_tracks=2, ) - t.track([]) - t.final_pass([]) def test_cull_instances(centered_pair_predictions): From 1581506ce888647dfc2cff07569026137e60a45f Mon Sep 17 00:00:00 2001 From: gqcpm <63070177+gqcpm@users.noreply.github.com> Date: Wed, 24 Jul 2024 15:51:35 -0700 Subject: [PATCH 09/27] Generate suggestions using max point displacement threshold (#1862) * create function max_point_displacement, _max_point_displacement_video. Add to yaml file. Create test for new function . . . 
will need to edit * remove unnecessary for loop, calculate proper displacement, adjusted tests accordingly * Increase range for displacement threshold * Fix frames not found bug * Return the latter frame index * Lint --------- Co-authored-by: roomrys <38435167+roomrys@users.noreply.github.com> --- sleap/config/suggestions.yaml | 9 +++++- sleap/gui/suggestions.py | 52 +++++++++++++++++++++++++++++++++++ tests/gui/test_suggestions.py | 14 ++++++++++ 3 files changed, 74 insertions(+), 1 deletion(-) diff --git a/sleap/config/suggestions.yaml b/sleap/config/suggestions.yaml index 8cf89728a..1440530fc 100644 --- a/sleap/config/suggestions.yaml +++ b/sleap/config/suggestions.yaml @@ -3,7 +3,7 @@ main: label: Method type: stacked default: " " - options: " ,image features,sample,prediction score,velocity,frame chunk" + options: " ,image features,sample,prediction score,velocity,frame chunk,max point displacement" " ": sample: @@ -175,6 +175,13 @@ main: type: double default: 0.1 range: 0.1,1.0 + + "max point displacement": + - name: displacement_threshold + label: Maximum Displacement Threshold + type: int + default: 10 + range: 0,999 - name: target label: Target diff --git a/sleap/gui/suggestions.py b/sleap/gui/suggestions.py index 48b916437..b85d6ac32 100644 --- a/sleap/gui/suggestions.py +++ b/sleap/gui/suggestions.py @@ -61,6 +61,7 @@ def suggest(cls, params: dict, labels: "Labels" = None) -> List[SuggestionFrame] prediction_score=cls.prediction_score, velocity=cls.velocity, frame_chunk=cls.frame_chunk, + max_point_displacement=cls.max_point_displacement, ) method = str.replace(params["method"], " ", "_") @@ -213,6 +214,7 @@ def _prediction_score_video( ): lfs = labels.find(video) frames = len(lfs) + # initiate an array filled with -1 to store frame index (starting from 0). 
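The `_max_point_displacement_video` helper added just below reduces to a frame-to-frame Euclidean norm, a mean over nodes, and a threshold test. A toy check of that arithmetic (the array shape follows `Labels.numpy()`; the threshold and coordinates are illustrative):

    import numpy as np

    # (frames, tracks, nodes, 2) array, as returned by labels.numpy().
    tracks = np.zeros((3, 1, 2, 2))
    tracks[1] = tracks[0] + 8.0  # every node moves 8 px in x and y
    tracks[2] = tracks[1] + 1.0  # then barely moves

    diff = tracks[1:] - tracks[:-1]            # (frames - 1, tracks, nodes, 2)
    euc_norm = np.linalg.norm(diff, axis=-1)   # per-node displacement
    mean_norm = np.nanmean(euc_norm, axis=-1)  # (frames - 1, tracks)
    frame_idxs = np.argwhere(np.any(mean_norm > 6.0, axis=-1)).flatten() + 1
    print(frame_idxs)  # [1] -> the latter frame of the big jump is suggested

The returned index is the latter frame of the jump, matching the "Return the latter frame index" fix noted in the commit message.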
idxs = np.full((frames), -1, dtype="int") @@ -291,6 +293,56 @@ def _velocity_video( return cls.idx_list_to_frame_list(frame_idxs, video) + @classmethod + def max_point_displacement( + cls, + labels: "Labels", + videos: List[Video], + displacement_threshold: float, + **kwargs, + ): + """Finds frames with maximum point displacement above a threshold.""" + + proposed_suggestions = [] + for video in videos: + proposed_suggestions.extend( + cls._max_point_displacement_video(video, labels, displacement_threshold) + ) + + suggestions = VideoFrameSuggestions.filter_unique_suggestions( + labels, videos, proposed_suggestions + ) + + return suggestions + + @classmethod + def _max_point_displacement_video( + cls, video: Video, labels: "Labels", displacement_threshold: float + ): + # Get numpy of shape (frames, tracks, nodes, x, y) + labels_numpy = labels.numpy(video=video, all_frames=True, untracked=False) + + # Return empty list if not enough frames + n_frames, n_tracks, n_nodes, _ = labels_numpy.shape + + if n_frames < 2: + return [] + + # Calculate displacements + diff = labels_numpy[1:] - labels_numpy[:-1] # (frames - 1, tracks, nodes, x, y) + euc_norm = np.linalg.norm(diff, axis=-1) # (frames - 1, tracks, nodes) + mean_euc_norm = np.nanmean(euc_norm, axis=-1) # (frames - 1, tracks) + + # Find frames where mean displacement is above threshold + threshold_mask = np.any( + mean_euc_norm > displacement_threshold, axis=-1 + ) # (frames - 1,) + frame_idxs = list( + np.argwhere(threshold_mask).flatten() + 1 + ) # [0, len(frames - 1)] + + return cls.idx_list_to_frame_list(frame_idxs, video) + @classmethod def frame_chunk( cls, diff --git a/tests/gui/test_suggestions.py b/tests/gui/test_suggestions.py index bbad73179..196ff2d35 100644 --- a/tests/gui/test_suggestions.py +++ b/tests/gui/test_suggestions.py @@ -24,6 +24,20 @@ def test_velocity_suggestions(centered_pair_predictions): assert suggestions[1].frame_idx == 45 +def test_max_point_displacement_suggestions(centered_pair_predictions): + suggestions = VideoFrameSuggestions.suggest( + labels=centered_pair_predictions, + params=dict( + videos=centered_pair_predictions.videos, + method="max_point_displacement", + displacement_threshold=6, + ), + ) + assert len(suggestions) == 19 + assert suggestions[0].frame_idx == 28 + assert suggestions[1].frame_idx == 82 + + def test_frame_increment(centered_pair_predictions: Labels): # Testing videos that have less frames than desired Samples per Video (stride) # Expected result is there should be n suggestions where n is equal to the frames From 28c34e22e0cf78e1774476d0ac76c7ea0b4814fe Mon Sep 17 00:00:00 2001 From: Andrew Park Date: Thu, 25 Jul 2024 20:15:04 -0700 Subject: [PATCH 10/27] Added Three Different Cases for Adding a New Instance (#1859) * implemented paste with offset * right click and then default will paste the new instance at the location of the cursor * modified the logics for creating new instance * refined the logic * fixed the logic for right click * refined logics for adding new instance at a specific location * Remove print statements * Comment code * Ensure that we choose a non nan reference node * Move OOB nodes to closest in-bounds position --------- Co-authored-by: roomrys <38435167+roomrys@users.noreply.github.com> --- sleap/gui/commands.py | 61 ++++++++++++++++++++++++++++++++++++-- sleap/gui/widgets/video.py | 5 +++- 2 files changed, 62 insertions(+), 4 deletions(-) diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index 1a64a071c..8df85fc8e 100644 --- a/sleap/gui/commands.py +++ 
b/sleap/gui/commands.py @@ -2913,6 +2913,8 @@ def create_new_instance( copy_instance=copy_instance, new_instance=new_instance, mark_complete=mark_complete, + init_method=init_method, + location=location, ) if has_missing_nodes: @@ -2984,6 +2986,8 @@ def set_visible_nodes( copy_instance: Optional[Union[Instance, PredictedInstance]], new_instance: Instance, mark_complete: bool, + init_method: str, + location: Optional[QtCore.QPoint] = None, ) -> bool: """Sets visible nodes for new instance. @@ -3010,6 +3014,25 @@ def set_visible_nodes( scale_width = new_size_width / old_size_width scale_height = new_size_height / old_size_height + # Default the offset is 0 + offset_x = 0 + offset_y = 0 + + # Using the menu or the hotkey + if init_method == "best": + offset_x = 10 + offset_y = 10 + + # Using right click and context menu + if location is not None: + reference_node = next( + (node for node in copy_instance if not node.isnan()), None + ) + reference_x = reference_node.x + reference_y = reference_node.y + offset_x = location.x() - (reference_x * scale_width) + offset_y = location.y() - (reference_y * scale_height) + # Go through each node in skeleton. for node in context.state["skeleton"].node_names: # If we're copying from a skeleton that has this node. @@ -3018,13 +3041,45 @@ def set_visible_nodes( # We don't want to copy a PredictedPoint or score attribute. x_old = copy_instance[node].x y_old = copy_instance[node].y - x_new = x_old * scale_width - y_new = y_old * scale_height + # Copy the instance without scale or offset if predicted + if isinstance(copy_instance, PredictedInstance): + x_new = x_old + y_new = y_old + else: + x_new = x_old * scale_width + y_new = y_old * scale_height + + # Apply offset if in bounds + x_new_offset = x_new + offset_x + y_new_offset = y_new + offset_y + + # Default visibility is same as copied instance. + visible = copy_instance[node].visible + + # If the node is offset to outside the frame, mark as not visible. + if x_new_offset < 0: + x_new = 0 + visible = False + elif x_new_offset > new_size_width: + x_new = new_size_width + visible = False + else: + x_new = x_new_offset + if y_new_offset < 0: + y_new = 0 + visible = False + elif y_new_offset > new_size_height: + y_new = new_size_height + visible = False + else: + y_new = y_new_offset + + # Update the new instance with the new x, y, and visibility. 
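In isolation, the offset-and-clamp rule above is: shift each node by the paste offset, and if that pushes it out of frame, snap it to the nearest in-bounds coordinate and mark it not visible. A standalone sketch that ignores the copied node's own visibility flag (the function name is ours):

    def apply_offset(x, y, dx, dy, width, height):
        """Shift a node, clamping out-of-bounds results and hiding them."""
        visible = True
        x_new, y_new = x + dx, y + dy
        if x_new < 0:
            x_new, visible = 0, False
        elif x_new > width:
            x_new, visible = width, False
        if y_new < 0:
            y_new, visible = 0, False
        elif y_new > height:
            y_new, visible = height, False
        return x_new, y_new, visible

    print(apply_offset(5, 5, -10, 10, 100, 100))  # -> (0, 15, False)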
new_instance[node] = Point( x=x_new, y=y_new, - visible=copy_instance[node].visible, + visible=visible, complete=mark_complete, ) else: diff --git a/sleap/gui/widgets/video.py b/sleap/gui/widgets/video.py index 502ea388e..745908048 100644 --- a/sleap/gui/widgets/video.py +++ b/sleap/gui/widgets/video.py @@ -367,7 +367,10 @@ def show_contextual_menu(self, where: QtCore.QPoint): menu.addAction("Add Instance:").setEnabled(False) - menu.addAction("Default", lambda: self.context.newInstance(init_method="best")) + menu.addAction( + "Default", + lambda: self.context.newInstance(init_method="best", location=scene_pos), + ) menu.addAction( "Average", From d3ad2261990a30effdbe568290c772fd3f6d5cc9 Mon Sep 17 00:00:00 2001 From: Elise Davis Date: Wed, 31 Jul 2024 13:09:35 -0700 Subject: [PATCH 11/27] Allow csv and text file support on sleap track (#1875) * initial changes * csv support and test case * increased code coverage * Error fixing, black, deletion of (self-written) unused code * final edits * black * documentation changes * documentation changes --- docs/guides/cli.md | 7 +- sleap/nn/inference.py | 256 +++++++++++++++++++++---------------- tests/nn/test_inference.py | 190 +++++++++++++++++++++++++-- 3 files changed, 328 insertions(+), 125 deletions(-) diff --git a/docs/guides/cli.md b/docs/guides/cli.md index ab62f3130..03b806903 100644 --- a/docs/guides/cli.md +++ b/docs/guides/cli.md @@ -138,7 +138,10 @@ usage: sleap-track [-h] [-m MODELS] [--frames FRAMES] [--only-labeled-frames] [- [data_path] positional arguments: - data_path Path to data to predict on. This can be a labels (.slp) file or any supported video format. + data_path Path to data to predict on. This can be one of the following: A .slp file containing labeled data; A folder containing multiple + video files in supported formats; An individual video file in a supported format; A CSV file with a column of video file paths. + If more than one column is provided in the CSV file, the first will be used for the input data paths and the next column will be + used as the output paths; A text file with a path to a video file on each line optional arguments: -h, --help show this help message and exit @@ -153,7 +156,7 @@ optional arguments: Only run inference on unlabeled suggested frames when running on labels dataset. This is useful for generating predictions for initialization during labeling. -o OUTPUT, --output OUTPUT - The output filename to use for the predicted data. If not provided, defaults to '[data_path].predictions.slp'. + The output filename or directory path to use for the predicted data. If not provided, defaults to '[data_path].predictions.slp'. --no-empty-frames Clear any empty frames that did not have any detected instances before saving to output. --verbosity {none,rich,json} Verbosity of inference progress reporting. 'none' does not output anything during inference, 'rich' displays an updating diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 7f9e91ec9..421378d56 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -33,6 +33,7 @@ import atexit import subprocess import rich.progress +import pandas as pd from rich.pretty import pprint from collections import deque import json @@ -5285,8 +5286,10 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]: args: Parsed CLI namespace. Returns: - A tuple of `(provider, data_path)` with the data `Provider` and path to the data - that was specified in the args. 
+ `(provider_list, data_path_list, output_path_list)` where `provider_list` contains the data providers, + `data_path_list` contains the paths to the specified data, and the `output_path_list` contains the list + of output paths if a CSV file with a column of output paths was provided; otherwise, `output_path_list` + defaults to None """ # Figure out which input path to use. @@ -5300,50 +5303,94 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]: data_path_obj = Path(data_path) + # Set output_path_list to None as a default to return later + output_path_list = None + # Check that input value is valid if not data_path_obj.exists(): raise ValueError("Path to data_path does not exist") - # Check for multiple video inputs - # Compile file(s) into a list for later itteration - if data_path_obj.is_dir(): - data_path_list = [] - for file_path in data_path_obj.iterdir(): - if file_path.is_file(): - data_path_list.append(Path(file_path)) elif data_path_obj.is_file(): - data_path_list = [data_path_obj] + # If the file is a CSV file, check for data_paths and output_paths + if data_path_obj.suffix.lower() == ".csv": + try: + data_path_column = None + # Read the CSV file + df = pd.read_csv(data_path) + + # collect data_paths from column + for col_index in range(df.shape[1]): + path_str = df.iloc[0, col_index] + if Path(path_str).exists(): + data_path_column = df.columns[col_index] + break + if data_path_column is None: + raise ValueError( + f"Column containing valid data_paths does not exist in the CSV file: {data_path}" + ) + raw_data_path_list = df[data_path_column].tolist() + + # optional output_path column to specify multiple output_paths + output_path_column_index = df.columns.get_loc(data_path_column) + 1 + if ( + output_path_column_index < df.shape[1] + and df.iloc[:, output_path_column_index].dtype == object + ): + # Ensure the next column exists + output_path_list = df.iloc[:, output_path_column_index].tolist() + else: + output_path_list = None + + except pd.errors.EmptyDataError as e: + raise ValueError(f"CSV file is empty: {data_path}. Error: {e}") from e + + # If the file is a text file, collect data_paths + elif data_path_obj.suffix.lower() == ".txt": + try: + with open(data_path_obj, "r") as file: + raw_data_path_list = [line.strip() for line in file.readlines()] + except Exception as e: + raise ValueError( + f"Error reading text file: {data_path}. 
Error: {e}" + ) from e + else: + raw_data_path_list = [data_path_obj.as_posix()] + + raw_data_path_list = [Path(p) for p in raw_data_path_list] + + # Check for multiple video inputs + # Compile file(s) into a list for later iteration + elif data_path_obj.is_dir(): + raw_data_path_list = [ + file_path for file_path in data_path_obj.iterdir() if file_path.is_file() + ] # Provider list to accomodate multiple video inputs - output_provider_list = [] - output_data_path_list = [] - for file_path in data_path_list: + provider_list = [] + data_path_list = [] + for file_path in raw_data_path_list: # Create a provider for each file - if file_path.as_posix().endswith(".slp") and len(data_path_list) > 1: + if file_path.as_posix().endswith(".slp") and len(raw_data_path_list) > 1: print(f"slp file skipped: {file_path.as_posix()}") elif file_path.as_posix().endswith(".slp"): labels = sleap.load_file(file_path.as_posix()) if args.only_labeled_frames: - output_provider_list.append( - LabelsReader.from_user_labeled_frames(labels) - ) + provider_list.append(LabelsReader.from_user_labeled_frames(labels)) elif args.only_suggested_frames: - output_provider_list.append( - LabelsReader.from_unlabeled_suggestions(labels) - ) + provider_list.append(LabelsReader.from_unlabeled_suggestions(labels)) elif getattr(args, "video.index") != "": - output_provider_list.append( + provider_list.append( VideoReader( video=labels.videos[int(getattr(args, "video.index"))], example_indices=frame_list(args.frames), ) ) else: - output_provider_list.append(LabelsReader(labels)) + provider_list.append(LabelsReader(labels)) - output_data_path_list.append(file_path) + data_path_list.append(file_path) else: try: @@ -5351,7 +5398,7 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]: dataset=vars(args).get("video.dataset"), input_format=vars(args).get("video.input_format"), ) - output_provider_list.append( + provider_list.append( VideoReader.from_filepath( filename=file_path.as_posix(), example_indices=frame_list(args.frames), @@ -5359,12 +5406,12 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]: ) ) print(f"Video: {file_path.as_posix()}") - output_data_path_list.append(file_path) + data_path_list.append(file_path) # TODO: Clean this up. except Exception: print(f"Error reading file: {file_path.as_posix()}") - return output_provider_list, output_data_path_list + return provider_list, data_path_list, output_path_list def _make_predictor_from_cli(args: argparse.Namespace) -> Predictor: @@ -5496,19 +5543,20 @@ def main(args: Optional[list] = None): print() # Setup data loader. - provider_list, data_path_list = _make_provider_from_cli(args) + provider_list, data_path_list, output_path_list = _make_provider_from_cli(args) - output_path = args.output + output_path = None - # check if output_path is valid before running inference - if ( - output_path is not None - and Path(output_path).is_file() - and len(data_path_list) > 1 - ): - raise ValueError( - "output_path argument must be a directory if multiple video inputs are given" - ) + # if output_path has not been extracted from a csv file yet + if output_path_list is None and args.output is not None: + output_path = args.output + output_path_obj = Path(output_path) + + # check if output_path is valid before running inference + if Path(output_path).is_file() and len(data_path_list) > 1: + raise ValueError( + "output_path argument must be a directory if multiple video inputs are given" + ) # Setup tracker. 
tracker = _make_tracker_from_cli(args) @@ -5520,7 +5568,7 @@ def main(args: Optional[list] = None): if args.models is not None: # Run inference on all files inputed - for data_path, provider in zip(data_path_list, provider_list): + for i, (data_path, provider) in enumerate(zip(data_path_list, provider_list)): # Setup models. data_path_obj = Path(data_path) predictor = _make_predictor_from_cli(args) @@ -5531,21 +5579,25 @@ def main(args: Optional[list] = None): # if output path was not provided, create an output path if output_path is None: - output_path = f"{data_path.as_posix()}.predictions.slp" - output_path_obj = Path(output_path) + # if output path was not provided, create an output path + if output_path_list: + output_path = output_path_list[i] + + else: + output_path = data_path_obj.with_suffix(".predictions.slp") - else: output_path_obj = Path(output_path) - # if output_path was provided and multiple inputs were provided, create a directory to store outputs - if len(data_path_list) > 1: - output_path = ( - output_path_obj - / data_path_obj.with_suffix(".predictions.slp").name - ) - output_path_obj = Path(output_path) - # Create the containing directory if needed. - output_path_obj.parent.mkdir(exist_ok=True, parents=True) + # if output_path was provided and multiple inputs were provided, create a directory to store outputs + elif len(data_path_list) > 1: + output_path_obj = Path(output_path) + output_path = ( + output_path_obj + / (data_path_obj.with_suffix(".predictions.slp")).name + ) + output_path_obj = Path(output_path) + # Create the containing directory if needed. + output_path_obj.parent.mkdir(exist_ok=True, parents=True) labels_pr.provenance["model_paths"] = predictor.model_paths labels_pr.provenance["predictor"] = type(predictor).__name__ @@ -5577,7 +5629,12 @@ def main(args: Optional[list] = None): labels_pr.provenance["args"] = vars(args) # Save results. 
- labels_pr.save(output_path) + try: + labels_pr.save(output_path) + except Exception: + print("WARNING: Provided output path invalid.") + fallback_path = data_path_obj.with_suffix(".predictions.slp") + labels_pr.save(fallback_path) print("Saved output:", output_path) if args.open_in_gui: @@ -5588,76 +5645,57 @@ def main(args: Optional[list] = None): # running tracking on existing prediction file elif getattr(args, "tracking.tracker") is not None: - for data_path, provider in zip(data_path_list, provider_list): - # Load predictions - data_path_obj = Path(data_path) - print("Loading predictions...") - labels_pr = sleap.load_file(data_path_obj.as_posix()) - frames = sorted(labels_pr.labeled_frames, key=lambda lf: lf.frame_idx) + provider = provider_list[0] + data_path = data_path_list[0] - print("Starting tracker...") - frames = run_tracker(frames=frames, tracker=tracker) - tracker.final_pass(frames) + # Load predictions + data_path = args.data_path + print("Loading predictions...") + labels_pr = sleap.load_file(data_path) + frames = sorted(labels_pr.labeled_frames, key=lambda lf: lf.frame_idx) - labels_pr = Labels(labeled_frames=frames) + print("Starting tracker...") + frames = run_tracker(frames=frames, tracker=tracker) + tracker.final_pass(frames) - if output_path is None: - output_path = f"{data_path}.{tracker.get_name()}.slp" - output_path_obj = Path(output_path) + labels_pr = Labels(labeled_frames=frames) - else: - output_path_obj = Path(output_path) - if ( - output_path_obj.exists() - and output_path_obj.is_file() - and len(data_path_list) > 1 - ): - raise ValueError( - "output_path argument must be a directory if multiple video inputs are given" - ) + if output_path is None: + output_path = f"{data_path}.{tracker.get_name()}.slp" - elif not output_path_obj.exists() and len(data_path_list) > 1: - output_path = output_path_obj / data_path_obj.with_suffix( - ".predictions.slp" - ) - output_path_obj = Path(output_path) - output_path_obj.parent.mkdir(exist_ok=True, parents=True) + if args.no_empty_frames: + # Clear empty frames if specified. + labels_pr.remove_empty_frames() - if args.no_empty_frames: - # Clear empty frames if specified. - labels_pr.remove_empty_frames() + finish_timestamp = str(datetime.now()) + total_elapsed = time() - t0 + print("Finished inference at:", finish_timestamp) + print(f"Total runtime: {total_elapsed} secs") + print(f"Predicted frames: {len(labels_pr)}/{len(provider)}") - finish_timestamp = str(datetime.now()) - total_elapsed = time() - t0 - print("Finished inference at:", finish_timestamp) - print(f"Total runtime: {total_elapsed} secs") - print(f"Predicted frames: {len(labels_pr)}/{len(provider)}") + # Add provenance metadata to predictions. + labels_pr.provenance["sleap_version"] = sleap.__version__ + labels_pr.provenance["platform"] = platform.platform() + labels_pr.provenance["command"] = " ".join(sys.argv) + labels_pr.provenance["data_path"] = data_path + labels_pr.provenance["output_path"] = output_path + labels_pr.provenance["total_elapsed"] = total_elapsed + labels_pr.provenance["start_timestamp"] = start_timestamp + labels_pr.provenance["finish_timestamp"] = finish_timestamp - # Add provenance metadata to predictions. 
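The provenance assembled here is a plain dictionary saved alongside the predictions. For the tracking-only branch, the keys set in this hunk amount to the following (all values are illustrative placeholders):

    provenance = {
        "sleap_version": "1.4.1",                       # illustrative
        "platform": "Linux-x86_64",                     # illustrative
        "command": "sleap-track preds.slp ...",
        "data_path": "preds.slp",
        "output_path": "preds.slp.tracker.slp",
        "total_elapsed": 12.3,
        "start_timestamp": "...",
        "finish_timestamp": "...",
        # "args" (the parsed CLI namespace) is added just before saving.
    }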
- labels_pr.provenance["sleap_version"] = sleap.__version__
-            labels_pr.provenance["platform"] = platform.platform()
-            labels_pr.provenance["command"] = " ".join(sys.argv)
-            labels_pr.provenance["data_path"] = data_path_obj.as_posix()
-            labels_pr.provenance["output_path"] = output_path_obj.as_posix()
-            labels_pr.provenance["total_elapsed"] = total_elapsed
-            labels_pr.provenance["start_timestamp"] = start_timestamp
-            labels_pr.provenance["finish_timestamp"] = finish_timestamp
-
-            print("Provenance:")
-            pprint(labels_pr.provenance)
-            print()
+        print("Provenance:")
+        pprint(labels_pr.provenance)
+        print()

-            labels_pr.provenance["args"] = vars(args)
+        labels_pr.provenance["args"] = vars(args)

-            # Save results.
-            labels_pr.save(output_path)
-            print("Saved output:", output_path)
+        # Save results.
+        labels_pr.save(output_path)

-            if args.open_in_gui:
-                subprocess.call(["sleap-label", output_path])
+        print("Saved output:", output_path)

-            # Reset output_path for next iteration
-            output_path = args.output
+        if args.open_in_gui:
+            subprocess.call(["sleap-label", output_path])

     else:
         raise ValueError(
diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py
index 98f5fbcec..fd615ea81 100644
--- a/tests/nn/test_inference.py
+++ b/tests/nn/test_inference.py
@@ -4,12 +4,14 @@
 from pathlib import Path
 from typing import cast
 import shutil
+import csv

 import numpy as np
 import pytest
+import pandas as pd
 import tensorflow as tf
-import tensorflow_hub as hub
 from numpy.testing import assert_array_equal, assert_allclose
+from sleap.io.video import available_video_exts

 import sleap
 from sleap.gui.learning import runners
@@ -1511,7 +1513,7 @@ def test_sleap_track_single_input(
     sleap_track(args=args)

     # Assert predictions file exists
-    output_path = f"{slp_path}.predictions.slp"
+    output_path = Path(slp_path).with_suffix(".predictions.slp")
     assert Path(output_path).exists()

     # Create invalid sleap-track command
@@ -1539,8 +1541,6 @@ def test_sleap_track_mult_input_slp(
     # Copy and paste the SLP file into the temp dir multiple times
     num_copies = 3
     for i in range(num_copies):
-        # Construct the destination path with a unique name for the video
-
         # Construct the destination path with a unique name for the SLP file
         slp_dest_path = slp_path / f"old_slp_copy_{i}.slp"
         shutil.copy(slp_file, slp_dest_path)
@@ -1558,13 +1558,11 @@ def test_sleap_track_mult_input_slp(
     sleap_track(args=args)

     # Assert predictions file exists
-    expected_extensions = {
-        ".mp4",
-    }  # Add other video formats if necessary
+    expected_extensions = available_video_exts()

     for file_path in slp_path_list:
         if file_path.suffix in expected_extensions:
-            expected_output_file = f"{file_path}.predictions.slp"
+            expected_output_file = Path(file_path).with_suffix(".predictions.slp")
             assert Path(expected_output_file).exists()


@@ -1604,10 +1602,11 @@ def test_sleap_track_mult_input_slp_mp4(
     # Run inference
     sleap_track(args=args)

-    # Assert predictions file exists
+    expected_extensions = available_video_exts()
+
     for file_path in slp_path_list:
-        if file_path.suffix == ".mp4":
-            expected_output_file = f"{file_path}.predictions.slp"
+        if file_path.suffix in expected_extensions:
+            expected_output_file = Path(file_path).with_suffix(".predictions.slp")
             assert Path(expected_output_file).exists()


@@ -1645,9 +1644,11 @@ def test_sleap_track_mult_input_mp4(
     sleap_track(args=args)

     # Assert predictions file exists
+    expected_extensions = available_video_exts()
+
     for file_path in slp_path_list:
-        if file_path.suffix == ".mp4":
-            expected_output_file =
f"{file_path}.predictions.slp" + if file_path.suffix in expected_extensions: + expected_output_file = Path(file_path).with_suffix(".predictions.slp") assert Path(expected_output_file).exists() @@ -1688,8 +1689,10 @@ def test_sleap_track_output_mult( slp_path = Path(slp_path) # Check if there are any files in the directory + expected_extensions = available_video_exts() + for file_path in slp_path_list: - if file_path.suffix == ".mp4": + if file_path.suffix in expected_extensions: expected_output_file = output_path_obj / ( file_path.stem + ".predictions.slp" ) @@ -1748,6 +1751,165 @@ def test_sleap_track_invalid_input( with pytest.raises(ValueError): sleap_track(args=args) + # Test with a non-existent path + slp_path = "/path/to/nonexistent/file.mp4" + + # Create sleap-track command for non-existent path + args = ( + f"{slp_path} --model {min_centroid_model_path} " + f"--tracking.tracker simple " + f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu" + ).split() + + # Run inference and expect a ValueError for non-existent path + with pytest.raises(ValueError): + sleap_track(args=args) + + +def test_sleap_track_csv_input( + min_centroid_model_path: str, + min_centered_instance_model_path: str, + centered_pair_vid_path, + tmpdir, +): + + # Create temporary directory with the structured video files + slp_path = Path(tmpdir.mkdir("mp4_directory")) + + # Copy and paste the video into the temp dir multiple times + num_copies = 3 + file_paths = [] + for i in range(num_copies): + # Construct the destination path with a unique name + dest_path = slp_path / f"centered_pair_vid_copy_{i}.mp4" + shutil.copy(centered_pair_vid_path, dest_path) + file_paths.append(dest_path) + + # Generate output paths for each data_path + output_paths = [ + file_path.with_suffix(".TESTpredictions.slp") for file_path in file_paths + ] + + # Create a CSV file with the file paths + csv_file_path = slp_path / "file_paths.csv" + with open(csv_file_path, mode="w", newline="") as csv_file: + csv_writer = csv.writer(csv_file) + csv_writer.writerow(["data_path", "output_path"]) + for data_path, output_path in zip(file_paths, output_paths): + csv_writer.writerow([data_path, output_path]) + + slp_path_obj = Path(slp_path) + + # Create sleap-track command + args = ( + f"{csv_file_path} --model {min_centroid_model_path} " + f"--tracking.tracker simple " + f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu" + ).split() + + slp_path_list = [file for file in slp_path_obj.iterdir() if file.is_file()] + + # Run inference + sleap_track(args=args) + + # Assert predictions file exists + expected_extensions = available_video_exts() + + for file_path in slp_path_list: + if file_path.suffix in expected_extensions: + expected_output_file = file_path.with_suffix(".TESTpredictions.slp") + assert Path(expected_output_file).exists() + + +def test_sleap_track_invalid_csv( + min_centroid_model_path: str, + min_centered_instance_model_path: str, + tmpdir, +): + + # Create a CSV file with nonexistant data files + csv_nonexistant_files_path = tmpdir / "nonexistant_files.csv" + df_nonexistant_files = pd.DataFrame( + {"data_path": ["video1.mp4", "video2.mp4", "video3.mp4"]} + ) + df_nonexistant_files.to_csv(csv_nonexistant_files_path, index=False) + + # Create an empty CSV file + csv_empty_path = tmpdir / "empty.csv" + open(csv_empty_path, "w").close() + + # Create sleap-track command for missing 'data_path' column + args_missing_column = ( + f"{csv_nonexistant_files_path} --model 
{min_centroid_model_path} "
+        f"--tracking.tracker simple "
+        f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu"
+    ).split()
+
+    # Run inference and expect ValueError for the nonexistent data files
+    with pytest.raises(
+        ValueError,
+    ):
+        sleap_track(args=args_nonexistent_files)
+
+    # Create sleap-track command for empty CSV file
+    args_empty = (
+        f"{csv_empty_path} --model {min_centroid_model_path} "
+        f"--tracking.tracker simple "
+        f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu"
+    ).split()
+
+    # Run inference and expect ValueError for empty CSV file
+    with pytest.raises(ValueError):
+        sleap_track(args=args_empty)
+
+
+def test_sleap_track_text_file_input(
+    min_centroid_model_path: str,
+    min_centered_instance_model_path: str,
+    centered_pair_vid_path,
+    tmpdir,
+):
+
+    # Create temporary directory with the structured video files
+    slp_path = Path(tmpdir.mkdir("mp4_directory"))
+
+    # Copy and paste the video into the temp dir multiple times
+    num_copies = 3
+    file_paths = []
+    for i in range(num_copies):
+        # Construct the destination path with a unique name
+        dest_path = slp_path / f"centered_pair_vid_copy_{i}.mp4"
+        shutil.copy(centered_pair_vid_path, dest_path)
+        file_paths.append(dest_path)
+
+    # Create a text file with the file paths
+    txt_file_path = slp_path / "file_paths.txt"
+    with open(txt_file_path, mode="w") as txt_file:
+        for file_path in file_paths:
+            txt_file.write(f"{file_path}\n")
+
+    slp_path_obj = Path(slp_path)
+
+    # Create sleap-track command
+    args = (
+        f"{txt_file_path} --model {min_centroid_model_path} "
+        f"--tracking.tracker simple "
+        f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu"
+    ).split()
+
+    slp_path_list = [file for file in slp_path_obj.iterdir() if file.is_file()]
+
+    # Run inference
+    sleap_track(args=args)
+
+    # Assert predictions file exists
+    expected_extensions = available_video_exts()
+
+    for file_path in slp_path_list:
+        if file_path.suffix in expected_extensions:
+            expected_output_file = Path(file_path).with_suffix(".predictions.slp")
+            assert Path(expected_output_file).exists()

 def test_flow_tracker(centered_pair_predictions_sorted: Labels, tmpdir):
     """Test flow tracker instances are pruned."""

From 5ab305c4cf8ef68f815f39071ec0b030ac66d4af Mon Sep 17 00:00:00 2001
From: Liezl Maree <38435167+roomrys@users.noreply.github.com>
Date: Thu, 1 Aug 2024 13:42:33 -0700
Subject: [PATCH 12/27] Fix GUI crash on scroll (#1883)

* Only pass wheelEvent to children that can handle it

* Add test for wheelEvent
---
 sleap/gui/widgets/video.py     | 27 ++++++++++-----------
 tests/gui/test_video_player.py | 44 +++++++++++++++++++++++++++++-----
 2 files changed, 50 insertions(+), 21 deletions(-)

diff --git a/sleap/gui/widgets/video.py b/sleap/gui/widgets/video.py
index 745908048..04965bbbb 100644
--- a/sleap/gui/widgets/video.py
+++ b/sleap/gui/widgets/video.py
@@ -1150,8 +1150,13 @@ def mouseDoubleClickEvent(self, event: QMouseEvent):
         QGraphicsView.mouseDoubleClickEvent(self, event)

     def wheelEvent(self, event):
-        """Custom event handler. Zoom in/out based on scroll wheel change."""
-        # zoom on wheel when no mouse buttons are pressed
+        """Custom event handler to zoom in/out based on scroll wheel change.
+
+        We cannot use the default QGraphicsView.wheelEvent behavior since that will
+        scroll the view.
+        """
+
+        # Zoom on wheel when no mouse buttons are pressed
         if event.buttons() == Qt.NoButton:
             angle = event.angleDelta().y()
             factor = 1.1 if angle > 0 else 0.9
@@ -1159,20 +1164,10 @@ def wheelEvent(self, event):
             self.zoomFactor = max(factor * self.zoomFactor, 1)
             self.updateViewer()

-        # Trigger wheelEvent for all child elements. This is a bit of a hack.
-        # We can't use QGraphicsView.wheelEvent(self, event) since that will scroll
-        # view.
-        # We want to trigger for all children, since wheelEvent should continue rotating
-        # a skeleton even if the skeleton node/node label is no longer under the
-        # cursor.
-        # Note that children expect a QGraphicsSceneWheelEvent event, which is why we're
-        # explicitly ignoring TypeErrors. Everything seems to work fine since we don't
-        # care about the mouse position; if we did, we'd need to map pos to scene.
+        # Trigger only for rotation-relevant children (otherwise GUI crashes)
         for child in self.items():
-            try:
+            if isinstance(child, (QtNode, QtNodeLabel)):
                 child.wheelEvent(event)
-            except TypeError:
-                pass

     def keyPressEvent(self, event):
         """Custom event handler, disables default QGraphicsView behavior."""
@@ -1590,7 +1585,9 @@ def mouseReleaseEvent(self, event):
     def wheelEvent(self, event):
         """Custom event handler for mouse scroll wheel."""
         if self.dragParent:
-            angle = event.delta() / 20 + self.parentObject().rotation()
+            angle = (
+                event.angleDelta().x() + event.angleDelta().y()
+            ) / 20 + self.parentObject().rotation()
             self.parentObject().setRotation(angle)
             event.accept()

diff --git a/tests/gui/test_video_player.py b/tests/gui/test_video_player.py
index b0661a4e1..c246f0489 100644
--- a/tests/gui/test_video_player.py
+++ b/tests/gui/test_video_player.py
@@ -3,14 +3,13 @@
 from sleap.gui.widgets.video import (
     QtVideoPlayer,
     GraphicsView,
-    QtInstance,
     QtVideoPlayer,
     QtTextWithBackground,
     VisibleBoundingBox,
 )

 from qtpy import QtCore, QtWidgets
-from qtpy.QtGui import QColor
+from qtpy.QtGui import QColor, QWheelEvent


 def test_gui_video(qtbot):
@@ -20,10 +19,6 @@ def test_gui_video(qtbot):

     assert vp.close()

-    # Click the button 20 times
-    # for i in range(20):
-    #     qtbot.mouseClick(vp.btn, QtCore.Qt.LeftButton)
-

 def test_gui_video_instances(qtbot, small_robot_mp4_vid, centered_pair_labels):
     vp = QtVideoPlayer(small_robot_mp4_vid)
@@ -144,3 +139,40 @@ def test_VisibleBoundingBox(qtbot, centered_pair_labels):
     # Check if bounding box scaled appropriately
     assert inst.box.rect().width() - initial_width == 2 * dx
     assert inst.box.rect().height() - initial_height == 2 * dy
+
+
+def test_wheelEvent(qtbot):
+    """Test the wheelEvent method of the GraphicsView class."""
+    graphics_view = GraphicsView()
+
+    # Create a QWheelEvent
+    position = QtCore.QPointF(100, 100)  # The position of the wheel event
+    global_position = QtCore.QPointF(100, 100)  # The global position of the wheel event
+    pixel_delta = QtCore.QPoint(0, 120)  # The distance in pixels the wheel is rotated
+    angle_delta = QtCore.QPoint(0, 120)  # The scroll distance in eighths of a degree
+    buttons = QtCore.Qt.NoButton  # No mouse button is pressed
+    modifiers = QtCore.Qt.NoModifier  # No keyboard modifier is pressed
+    phase = QtCore.Qt.ScrollUpdate  # The phase of the scroll event
+    inverted = False  # The scroll direction is not inverted
+    source = (
+        QtCore.Qt.MouseEventNotSynthesized
+    )  # The event is not synthesized from a touch or tablet event
+
+    event = QWheelEvent(
+        position,
+        global_position,
+        pixel_delta,
+        angle_delta,
+        buttons,
+        modifiers,
+        phase,
+        inverted,
+        source,
+    )
+
+    # Call
the wheelEvent method + print( + "Testing GraphicsView.wheelEvent which will result in exit code 127 " + "originating from a segmentation fault if it fails." + ) + graphics_view.wheelEvent(event) From 3813901ba4a0beb8745f6d8ba98132d74268eddc Mon Sep 17 00:00:00 2001 From: Liezl Maree <38435167+roomrys@users.noreply.github.com> Date: Wed, 7 Aug 2024 10:37:40 -0700 Subject: [PATCH 13/27] Fix typo to allow rendering videos with mp4 (Mac) (#1892) Fix typo to allow rendering videos with mp4 --- sleap/gui/commands.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index 8df85fc8e..fce6458f5 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -1335,7 +1335,7 @@ def ask(context: CommandContext, params: dict) -> bool: context.app, caption="Save Video As...", dir=default_out_filename, - filter="Video (*.avi *mp4)", + filter="Video (*.avi *.mp4)", ) # Check if user hit cancel From 076f3dda2036c5833dec35ff4479d3ae1c7f7f55 Mon Sep 17 00:00:00 2001 From: Liezl Maree <38435167+roomrys@users.noreply.github.com> Date: Wed, 7 Aug 2024 10:40:29 -0700 Subject: [PATCH 14/27] Do not apply offset when double clicking a `PredictedInstance` (#1888) * Add offset argument to newInstance and AddInstance * Apply offset of 10 for Add Instance menu button (Ctrl + I) * Add offset for docks Add Instance button * Make the QtVideoPlayer context menu unit-testable * Add test for creating a new instance * Add test for "New Instance" button in `InstancesDock` * Fix typo in docstring * Add docstrings and typehinting * Remove unused imports and sort imports --- sleap/gui/app.py | 6 +- sleap/gui/commands.py | 41 +++++++------ sleap/gui/widgets/docks.py | 2 +- sleap/gui/widgets/video.py | 72 ++++++++++++---------- tests/gui/test_commands.py | 103 ++++++++++++++++++++++++++++++++ tests/gui/widgets/test_docks.py | 36 +++++++++-- 6 files changed, 206 insertions(+), 54 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index e872ce9a6..e2396948a 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -704,13 +704,17 @@ def prev_vid(): ) def new_instance_menu_action(): + """Determine which action to use when using Ctrl + I or menu Add Instance. + + We always add an offset of 10. + """ method_key = [ key for (key, val) in instance_adding_methods.items() if val == self.state["instance_init_method"] ] if method_key: - self.commands.newInstance(init_method=method_key[0]) + self.commands.newInstance(init_method=method_key[0], offset=10) labelMenu = self.menuBar().addMenu("Labels") add_menu_item( diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index fce6458f5..e3ef8522d 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -504,6 +504,7 @@ def newInstance( init_method: str = "best", location: Optional[QtCore.QPoint] = None, mark_complete: bool = False, + offset: int = 0, ): """Creates a new instance, copying node coordinates as appropriate. @@ -513,6 +514,8 @@ def newInstance( init_method: Method to use for positioning nodes. location: The location where instance should be added (if node init method supports custom location). + mark_complete: Whether to mark the instance as complete. + offset: Offset to apply to the location if given. 
""" self.execute( AddInstance, @@ -520,6 +523,7 @@ def newInstance( init_method=init_method, location=location, mark_complete=mark_complete, + offset=offset, ) def setPointLocations( @@ -2858,6 +2862,7 @@ def do_action(cls, context: CommandContext, params: dict): init_method = params.get("init_method", "best") location = params.get("location", None) mark_complete = params.get("mark_complete", False) + offset = params.get("offset", 0) if context.state["labeled_frame"] is None: return @@ -2881,6 +2886,7 @@ def do_action(cls, context: CommandContext, params: dict): init_method=init_method, location=location, from_prev_frame=from_prev_frame, + offset=offset, ) # Add the instance @@ -2898,6 +2904,7 @@ def create_new_instance( init_method: str, location: Optional[QtCore.QPoint], from_prev_frame: bool, + offset: int = 0, ) -> Instance: """Create new instance.""" @@ -2915,6 +2922,7 @@ def create_new_instance( mark_complete=mark_complete, init_method=init_method, location=location, + offset=offset, ) if has_missing_nodes: @@ -2988,6 +2996,7 @@ def set_visible_nodes( mark_complete: bool, init_method: str, location: Optional[QtCore.QPoint] = None, + offset: int = 0, ) -> bool: """Sets visible nodes for new instance. @@ -2996,6 +3005,9 @@ def set_visible_nodes( copy_instance: The instance to copy from. new_instance: The new instance. mark_complete: Whether to mark the instance as complete. + init_method: The initialization method. + location: The location of the mouse click if any. + offset: The offset to apply to all nodes. Returns: Whether the new instance has missing nodes. @@ -3014,24 +3026,19 @@ def set_visible_nodes( scale_width = new_size_width / old_size_width scale_height = new_size_height / old_size_height - # Default the offset is 0 - offset_x = 0 - offset_y = 0 - - # Using the menu or the hotkey - if init_method == "best": - offset_x = 10 - offset_y = 10 + # The offset is 0, except when using Ctrl + I or Add Instance button. + offset_x = offset + offset_y = offset - # Using right click and context menu - if location is not None: - reference_node = next( - (node for node in copy_instance if not node.isnan()), None - ) - reference_x = reference_node.x - reference_y = reference_node.y - offset_x = location.x() - (reference_x * scale_width) - offset_y = location.y() - (reference_y * scale_height) + # Using right click and context menu with option "best" + if (init_method == "best") and (location is not None): + reference_node = next( + (node for node in copy_instance if not node.isnan()), None + ) + reference_x = reference_node.x + reference_y = reference_node.y + offset_x = location.x() - (reference_x * scale_width) + offset_y = location.y() - (reference_y * scale_height) # Go through each node in skeleton. 
for node in context.state["skeleton"].node_names: diff --git a/sleap/gui/widgets/docks.py b/sleap/gui/widgets/docks.py index 43e218adb..3375e4713 100644 --- a/sleap/gui/widgets/docks.py +++ b/sleap/gui/widgets/docks.py @@ -557,7 +557,7 @@ def create_table_edit_buttons(self) -> QWidget: hb = QHBoxLayout() self.add_button( - hb, "New Instance", lambda x: main_window.commands.newInstance() + hb, "New Instance", lambda x: main_window.commands.newInstance(offset=10) ) self.add_button( hb, "Delete Instance", main_window.commands.deleteSelectedInstance diff --git a/sleap/gui/widgets/video.py b/sleap/gui/widgets/video.py index 04965bbbb..949703020 100644 --- a/sleap/gui/widgets/video.py +++ b/sleap/gui/widgets/video.py @@ -240,6 +240,8 @@ def __init__( self._register_shortcuts() + self.context_menu = None + self._menu_actions = dict() if self.context: self.setContextMenuPolicy(QtCore.Qt.CustomContextMenu) self.customContextMenuRequested.connect(self.show_contextual_menu) @@ -358,44 +360,54 @@ def add_shortcut(key, step): def setSeekbarSelection(self, a: int, b: int): self.seekbar.setSelection(a, b) - def show_contextual_menu(self, where: QtCore.QPoint): - if not self.is_menu_enabled: - return + def create_contextual_menu(self, scene_pos: QtCore.QPointF) -> QtWidgets.QMenu: + """Create the context menu for the viewer. - scene_pos = self.view.mapToScene(where) - menu = QtWidgets.QMenu() + This is called when the user right-clicks in the viewer. This function also + stores the menu actions in the `_menu_actions` attribute so that they can be + accessed later and stores the context menu in the `context_menu` attribute. - menu.addAction("Add Instance:").setEnabled(False) + Args: + scene_pos: The position in the scene where the menu was requested. - menu.addAction( - "Default", - lambda: self.context.newInstance(init_method="best", location=scene_pos), - ) + Returns: + The created context menu. + """ - menu.addAction( - "Average", - lambda: self.context.newInstance( - init_method="template", location=scene_pos - ), - ) + self.context_menu = QtWidgets.QMenu() + self.context_menu.addAction("Add Instance:").setEnabled(False) + + self._menu_actions = dict() + params_by_action_name = { + "Default": {"init_method": "best", "location": scene_pos}, + "Average": {"init_method": "template", "location": scene_pos}, + "Force Directed": {"init_method": "force_directed", "location": scene_pos}, + "Copy Prior Frame": {"init_method": "prior_frame"}, + "Random": {"init_method": "random", "location": scene_pos}, + } + for action_name, params in params_by_action_name.items(): + self._menu_actions[action_name] = self.context_menu.addAction( + action_name, lambda params=params: self.context.newInstance(**params) + ) - menu.addAction( - "Force Directed", - lambda: self.context.newInstance( - init_method="force_directed", location=scene_pos - ), - ) + return self.context_menu - menu.addAction( - "Copy Prior Frame", - lambda: self.context.newInstance(init_method="prior_frame"), - ) + def show_contextual_menu(self, where: QtCore.QPoint): + """Show the context menu at the given position in the viewer. - menu.addAction( - "Random", - lambda: self.context.newInstance(init_method="random", location=scene_pos), - ) + This is called when the user right-clicks in the viewer. This function calls + `create_contextual_menu` to create the menu and then shows the menu at the + given position. + + Args: + where: The position in the viewer where the menu was requested. 
+ """ + if not self.is_menu_enabled: + return + + scene_pos = self.view.mapToScene(where) + menu = self.create_contextual_menu(scene_pos) menu.exec_(self.mapToGlobal(where)) def load_video(self, video: Video, plot=True): diff --git a/tests/gui/test_commands.py b/tests/gui/test_commands.py index 899b1f4a0..ffd382ab1 100644 --- a/tests/gui/test_commands.py +++ b/tests/gui/test_commands.py @@ -3,11 +3,15 @@ import sys import time +import numpy as np from pathlib import PurePath, Path +from qtpy import QtCore from typing import List from sleap import Skeleton, Track, PredictedInstance +from sleap.gui.app import MainWindow from sleap.gui.commands import ( + AddInstance, CommandContext, ExportAnalysisFile, ExportDatasetWithImages, @@ -922,3 +926,102 @@ def no_gui_ask(cls, context, params): # Case 3: Export all frames and suggested frames with image data. context.exportFullPackage() assert_loaded_package_similar(path_to_pkg, sugg=True, pred=True) + + +def test_newInstance(qtbot, centered_pair_predictions: Labels): + + # Get the data + labels = centered_pair_predictions + lf = labels[0] + pred_inst = lf.instances[0] + video = labels.video + + # Set-up command context + main_window = MainWindow(labels=labels) + context = main_window.commands + context.state["labeled_frame"] = lf + context.state["frame_idx"] = lf.frame_idx + context.state["skeleton"] = labels.skeleton + context.state["video"] = labels.videos[0] + + # Case 1: Double clicking a prediction results in no offset for new instance + + # Double click on prediction + assert len(lf.instances) == 2 + main_window._handle_instance_double_click(instance=pred_inst) + + # Check new instance + assert len(lf.instances) == 3 + new_inst = lf.instances[-1] + assert new_inst.from_predicted is pred_inst + assert np.array_equal(new_inst.numpy(), pred_inst.numpy()) # No offset + + # Case 2: Using Ctrl + I (or menu "Add Instance" button) + + # Connect the action to a slot + add_instance_menu_action = main_window._menu_actions["add instance"] + triggered = False + + def on_triggered(): + nonlocal triggered + triggered = True + + add_instance_menu_action.triggered.connect(on_triggered) + + # Find which instance we are going to copy from + ( + copy_instance, + from_predicted, + from_prev_frame, + ) = AddInstance.find_instance_to_copy_from( + context, copy_instance=None, init_method="best" + ) + + # Click on the menu action + assert len(lf.instances) == 3 + add_instance_menu_action.trigger() + assert triggered, "Action not triggered" + + # Check new instance + assert len(lf.instances) == 4 + new_inst = lf.instances[-1] + offset = 10 + np.nan_to_num(new_inst.numpy() - copy_instance.numpy(), nan=offset) + assert np.all( + np.nan_to_num(new_inst.numpy() - copy_instance.numpy(), nan=offset) == offset + ) + + # Case 3: Using right click and "Default" option + + # Find which instance we are going to copy from + ( + copy_instance, + from_predicted, + from_prev_frame, + ) = AddInstance.find_instance_to_copy_from( + context, copy_instance=None, init_method="best" + ) + + video_player = main_window.player + right_click_location_x = video.shape[2] / 2 + right_click_location_y = video.shape[1] / 2 + right_click_location = QtCore.QPointF( + right_click_location_x, right_click_location_y + ) + video_player.create_contextual_menu(scene_pos=right_click_location) + default_action = video_player._menu_actions["Default"] + default_action.trigger() + + # Check new instance + assert len(lf.instances) == 5 + new_inst = lf.instances[-1] + reference_node_idx = np.where( + np.all( + 
new_inst.numpy() == [right_click_location_x, right_click_location_y], axis=1 + ) + )[0][0] + offset = ( + new_inst.numpy()[reference_node_idx] - copy_instance.numpy()[reference_node_idx] + ) + diff = np.nan_to_num(new_inst.numpy() - copy_instance.numpy(), nan=offset) + assert np.all(diff == offset) diff --git a/tests/gui/widgets/test_docks.py b/tests/gui/widgets/test_docks.py index 69fe56a56..d5c16a763 100644 --- a/tests/gui/widgets/test_docks.py +++ b/tests/gui/widgets/test_docks.py @@ -1,15 +1,17 @@ """Module for testing dock widgets for the `MainWindow`.""" from pathlib import Path -import pytest + +import numpy as np + from sleap import Labels, Video from sleap.gui.app import MainWindow -from sleap.gui.commands import OpenSkeleton +from sleap.gui.commands import AddInstance, OpenSkeleton from sleap.gui.widgets.docks import ( InstancesDock, + SkeletonDock, SuggestionsDock, VideosDock, - SkeletonDock, ) @@ -99,11 +101,35 @@ def test_suggestions_dock(qtbot): assert dock.wgt_layout is dock.widget().layout() -def test_instances_dock(qtbot): +def test_instances_dock(qtbot, centered_pair_predictions: Labels): """Test the `DockWidget` class.""" - main_window = MainWindow() + main_window = MainWindow(labels=centered_pair_predictions) + labels = main_window.labels + context = main_window.commands + lf = context.state["labeled_frame"] dock = InstancesDock(main_window) assert dock.name == "Instances" assert dock.main_window is main_window assert dock.wgt_layout is dock.widget().layout() + + # Test new instance button + + offset = 10 + + # Find instance that we will copy from + ( + copy_instance, + from_predicted, + from_prev_frame, + ) = AddInstance.find_instance_to_copy_from( + context, copy_instance=None, init_method="best" + ) + n_instance = len(lf.instances) + dock.main_window._buttons["new instance"].click() + + # Check that new instance was added with offset + assert len(lf.instances) == n_instance + 1 + new_inst = lf.instances[-1] + diff = np.nan_to_num(new_inst.numpy() - copy_instance.numpy(), nan=offset) + assert np.all(diff == offset) From efdf3faa87019b070438eddbde1f917433aeff9a Mon Sep 17 00:00:00 2001 From: Elizabeth <106755962+eberrigan@users.noreply.github.com> Date: Thu, 15 Aug 2024 10:24:13 -0700 Subject: [PATCH 15/27] Refactor video writer to use imageio instead of skvideo (#1900) * modify `VideoWriter` to use imageio with ffmpeg backend * check to see if ffmpeg is present * use the new check for ffmpeg * import imageio.v2 * add imageio-ffmpeg to environments to test * using avi format for now * remove SKvideo videowriter * test `VideoWriterImageio` minimally * add more documentation for ffmpeg * default mp4 for ffmpeg should be mp4 * print using `IMAGEIO` when using ffmpeg * mp4 for ffmpeg * use mp4 ending in test * test `VideoWriterImageio` with avi file extension * test video with odd size * remove redundant filter since imageio-ffmpeg resizes automatically * black * remove unused import * use logging instead of print statement * import cv2 is needed for resize * remove logging --- environment.yml | 1 + environment_mac.yml | 1 + environment_no_cuda.yml | 1 + sleap/gui/commands.py | 6 +-- sleap/gui/dialogs/export_clip.py | 8 ++-- sleap/io/videowriter.py | 71 +++++++++++++++++++------------- tests/io/test_videowriter.py | 63 +++++++++++++++++++++++++++- 7 files changed, 114 insertions(+), 37 deletions(-) diff --git a/environment.yml b/environment.yml index 9c5758c13..2aba3c7d2 100644 --- a/environment.yml +++ b/environment.yml @@ -12,6 +12,7 @@ dependencies: # Packages SLEAP 
uses directly
   - conda-forge::attrs >=21.2.0  #,<=21.4.0
   - conda-forge::cattrs ==1.1.1
+  - conda-forge::imageio-ffmpeg  # Required for imageio to read/write videos with ffmpeg
   - conda-forge::jsmin
   - conda-forge::jsonpickle ==1.2
   - conda-forge::networkx
diff --git a/environment_mac.yml b/environment_mac.yml
index 42d6e028c..9ab10a1b8 100644
--- a/environment_mac.yml
+++ b/environment_mac.yml
@@ -12,6 +12,7 @@ dependencies:
   - conda-forge::importlib-metadata <7.1.0
   - conda-forge::cattrs ==1.1.1
   - conda-forge::h5py
+  - conda-forge::imageio-ffmpeg  # Required for imageio to read/write videos with ffmpeg
   - conda-forge::jsmin
   - conda-forge::jsonpickle ==1.2
   - conda-forge::keras <2.10.0,>=2.9.0rc0  # Required by tensorflow-macos
diff --git a/environment_no_cuda.yml b/environment_no_cuda.yml
index fc13f839a..2adee7a89 100644
--- a/environment_no_cuda.yml
+++ b/environment_no_cuda.yml
@@ -13,6 +13,7 @@ dependencies:
   # Packages SLEAP uses directly
   - conda-forge::attrs >=21.2.0  #,<=21.4.0
   - conda-forge::cattrs ==1.1.1
+  - conda-forge::imageio-ffmpeg  # Required for imageio to read/write videos with ffmpeg
   - conda-forge::jsmin
   - conda-forge::jsonpickle ==1.2
   - conda-forge::networkx
diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py
index e3ef8522d..692f19c78 100644
--- a/sleap/gui/commands.py
+++ b/sleap/gui/commands.py
@@ -1329,12 +1329,10 @@ def ask(context: CommandContext, params: dict) -> bool:
         # makes mp4's that most programs can't open (VLC can).
         default_out_filename = context.state["filename"] + ".avi"

-        # But if we can write mpegs using sci-kit video, use .mp4
-        # since it has trouble writing .avi files.
-        if VideoWriter.can_use_skvideo():
+        if VideoWriter.can_use_ffmpeg():
             default_out_filename = context.state["filename"] + ".mp4"

-        # Ask where use wants to save video file
+        # Ask where user wants to save video file
         filename, _ = FileDialog.save(
             context.app,
             caption="Save Video As...",
diff --git a/sleap/gui/dialogs/export_clip.py b/sleap/gui/dialogs/export_clip.py
index 312f9a807..f84766d18 100644
--- a/sleap/gui/dialogs/export_clip.py
+++ b/sleap/gui/dialogs/export_clip.py
@@ -11,16 +11,16 @@ def __init__(self):

         super().__init__(form_name="labeled_clip_form")

-        can_use_skvideo = VideoWriter.can_use_skvideo()
+        can_use_ffmpeg = VideoWriter.can_use_ffmpeg()

-        if can_use_skvideo:
+        if can_use_ffmpeg:
             message = (
                 "MP4 file will be encoded using "
-                "system ffmpeg via scikit-video (preferred option)."
+                "system ffmpeg via imageio (preferred option)."
             )
         else:
             message = (
-                "Unable to use ffpmeg via scikit-video. "
+                "Unable to use ffmpeg via imageio. "
                 "AVI file will be encoded using OpenCV."
) diff --git a/sleap/io/videowriter.py b/sleap/io/videowriter.py index 510fad739..cd710c9d5 100644 --- a/sleap/io/videowriter.py +++ b/sleap/io/videowriter.py @@ -12,6 +12,7 @@ from abc import ABC, abstractmethod import cv2 import numpy as np +import imageio.v2 as iio class VideoWriter(ABC): @@ -32,22 +33,26 @@ def close(self): @staticmethod def safe_builder(filename, height, width, fps): """Builds VideoWriter based on available dependencies.""" - if VideoWriter.can_use_skvideo(): - return VideoWriterSkvideo(filename, height, width, fps) + if VideoWriter.can_use_ffmpeg(): + return VideoWriterImageio(filename, height, width, fps) else: return VideoWriterOpenCV(filename, height, width, fps) @staticmethod - def can_use_skvideo(): - # See if we can import skvideo + def can_use_ffmpeg(): + """Check if ffmpeg is available for writing videos.""" try: - import skvideo + import imageio_ffmpeg as ffmpeg except ImportError: return False - # See if skvideo can find FFMPEG - if skvideo.getFFmpegVersion() != "0.0.0": - return True + try: + # Try to get the version of the ffmpeg plugin + ffmpeg_version = ffmpeg.get_ffmpeg_version() + if ffmpeg_version: + return True + except Exception: + return False return False @@ -68,11 +73,11 @@ def close(self): self._writer.release() -class VideoWriterSkvideo(VideoWriter): - """Writes video using scikit-video as wrapper for ffmpeg. +class VideoWriterImageio(VideoWriter): + """Writes video using imageio as a wrapper for ffmpeg. Attributes: - filename: Path to mp4 file to save to. + filename: Path to video file to save to. height: Height of movie frames. width: Width of movie frames. fps: Playback framerate to save at. @@ -85,28 +90,38 @@ class VideoWriterSkvideo(VideoWriter): def __init__( self, filename, height, width, fps, crf: int = 21, preset: str = "superfast" ): - import skvideo.io - - fps = str(fps) - self._writer = skvideo.io.FFmpegWriter( + self.filename = filename + self.height = height + self.width = width + self.fps = fps + self.crf = crf + self.preset = preset + + import imageio_ffmpeg as ffmpeg + + # Imageio's ffmpeg writer parameters + # https://imageio.readthedocs.io/en/stable/examples.html#writing-videos-with-ffmpeg-and-vaapi + # Use `ffmpeg -h encoder=libx264`` to see all options for libx264 output_params + # output_params must be a list of strings + # iio.help(name='FFMPEG') to test + self.writer = iio.get_writer( filename, - inputdict={ - "-r": fps, - }, - outputdict={ - "-c:v": "libx264", - "-preset": preset, - "-vf": "scale=trunc(iw/2)*2:trunc(ih/2)*2", # Need even dims for libx264 - "-framerate": fps, - "-crf": str(crf), - "-pix_fmt": "yuv420p", - }, + fps=fps, + codec="libx264", + format="FFMPEG", + pixelformat="yuv420p", + output_params=[ + "-preset", + preset, + "-crf", + str(crf), + ], ) def add_frame(self, img, bgr: bool = False): if bgr: img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - self._writer.writeFrame(img) + self.writer.append_data(img) def close(self): - self._writer.close() + self.writer.close() diff --git a/tests/io/test_videowriter.py b/tests/io/test_videowriter.py index dea193117..35d9bc6df 100644 --- a/tests/io/test_videowriter.py +++ b/tests/io/test_videowriter.py @@ -1,5 +1,7 @@ import os -from sleap.io.videowriter import VideoWriter, VideoWriterOpenCV +import cv2 +from pathlib import Path +from sleap.io.videowriter import VideoWriter, VideoWriterOpenCV, VideoWriterImageio def test_video_writer(tmpdir, small_robot_mp4_vid): @@ -38,3 +40,62 @@ def test_cv_video_writer(tmpdir, small_robot_mp4_vid): writer.close() assert 
os.path.exists(out_path) + + +def test_imageio_video_writer_avi(tmpdir, small_robot_mp4_vid): + out_path = Path(tmpdir) / "clip.avi" + + # Make sure imageio video writer works + writer = VideoWriterImageio( + out_path, + height=small_robot_mp4_vid.height, + width=small_robot_mp4_vid.width, + fps=small_robot_mp4_vid.fps, + ) + + writer.add_frame(small_robot_mp4_vid[0][0]) + writer.add_frame(small_robot_mp4_vid[1][0]) + + writer.close() + + assert os.path.exists(out_path) + # Check attributes + assert writer.height == small_robot_mp4_vid.height + assert writer.width == small_robot_mp4_vid.width + assert writer.fps == small_robot_mp4_vid.fps + assert writer.filename == out_path + assert writer.crf == 21 + assert writer.preset == "superfast" + + +def test_imageio_video_writer_odd_size(tmpdir, movenet_video): + out_path = Path(tmpdir) / "clip.mp4" + + # Reduce the size of the video frames by 1 pixel in each dimension + reduced_height = movenet_video.height - 1 + reduced_width = movenet_video.width - 1 + + # Initialize the writer with the reduced dimensions + writer = VideoWriterImageio( + out_path, + height=reduced_height, + width=reduced_width, + fps=movenet_video.fps, + ) + + # Resize frames and add them to the video + for i in range(len(movenet_video) - 1): + frame = movenet_video[i][0] # Access the actual frame object + reduced_frame = cv2.resize(frame, (reduced_width, reduced_height)) + writer.add_frame(reduced_frame) + + writer.close() + + # Assertions to validate the test + assert os.path.exists(out_path) + assert writer.height == reduced_height + assert writer.width == reduced_width + assert writer.fps == movenet_video.fps + assert writer.filename == out_path + assert writer.crf == 21 + assert writer.preset == "superfast" From f9d07b88995a87a29b5c4b5a9618139177b5fb60 Mon Sep 17 00:00:00 2001 From: Liezl Maree <38435167+roomrys@users.noreply.github.com> Date: Thu, 15 Aug 2024 19:53:45 -0700 Subject: [PATCH 16/27] Use `Video.from_filename` when structuring videos (#1905) * Use Video.from_filename when structuring videos * Modify removal_test_labels to have extension in filename --- sleap/io/video.py | 21 ++++++++------------- tests/io/test_dataset.py | 2 +- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/sleap/io/video.py b/sleap/io/video.py index f64373d37..c8272cfbd 100644 --- a/sleap/io/video.py +++ b/sleap/io/video.py @@ -1545,22 +1545,17 @@ def cattr(): A cattr converter. """ - # When we are structuring video backends, try to fixup the video file paths - # in case they are coming from a different computer or the file has been moved. - def fixup_video(x, cl): - if "filename" in x: - x["filename"] = Video.fixup_path(x["filename"]) - if "file" in x: - x["file"] = Video.fixup_path(x["file"]) + # Use from_filename to fixup the video path and determine backend + def fixup_video(x: dict, cl: Video): + backend_dict = x.pop("backend") + filename = backend_dict.pop("filename", None) or backend_dict.pop( + "file", None + ) - return Video.make_specific_backend(cl, x) + return Video.from_filename(filename, **backend_dict) vid_cattr = cattr.Converter() - - # Check the type hint for backend and register the video path - # fixup hook for each type in the Union. 
- for t in attr.fields(Video).backend.type.__args__: - vid_cattr.register_structure_hook(t, fixup_video) + vid_cattr.register_structure_hook(Video, fixup_video) return vid_cattr diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index 7f7ceb5d9..d71d4cc83 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -1236,7 +1236,7 @@ def test_has_frame(): @pytest.fixture def removal_test_labels(): skeleton = Skeleton() - video = Video(backend=MediaVideo(filename="test")) + video = Video(backend=MediaVideo(filename="test.mp4")) lf_user_only = LabeledFrame( video=video, frame_idx=0, instances=[Instance(skeleton=skeleton)] ) From 260fb85e957700aca928015a44c49982c7477459 Mon Sep 17 00:00:00 2001 From: Liezl Maree <38435167+roomrys@users.noreply.github.com> Date: Thu, 15 Aug 2024 20:08:42 -0700 Subject: [PATCH 17/27] Use | instead of + in key commands (#1907) * Use | instead of + in key commands * Lint --- sleap/gui/app.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index e2396948a..108930b66 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -757,12 +757,12 @@ def new_instance_menu_action(): labelMenu.addAction( "Copy Instance", self.commands.copyInstance, - Qt.CTRL + Qt.Key_C, + Qt.CTRL | Qt.Key_C, ) labelMenu.addAction( "Paste Instance", self.commands.pasteInstance, - Qt.CTRL + Qt.Key_V, + Qt.CTRL | Qt.Key_V, ) labelMenu.addSeparator() @@ -856,12 +856,12 @@ def new_instance_menu_action(): tracksMenu.addAction( "Copy Instance Track", self.commands.copyInstanceTrack, - Qt.CTRL + Qt.SHIFT + Qt.Key_C, + Qt.CTRL | Qt.SHIFT | Qt.Key_C, ) tracksMenu.addAction( "Paste Instance Track", self.commands.pasteInstanceTrack, - Qt.CTRL + Qt.SHIFT + Qt.Key_V, + Qt.CTRL | Qt.SHIFT | Qt.Key_V, ) tracksMenu.addSeparator() @@ -1361,10 +1361,24 @@ def _update_track_menu(self): """Updates track menu options.""" self.track_menu.clear() self.delete_tracks_menu.clear() + + # Create a dictionary mapping track indices to Qt.Key values + key_mapping = { + 0: Qt.Key_1, + 1: Qt.Key_2, + 2: Qt.Key_3, + 3: Qt.Key_4, + 4: Qt.Key_5, + 5: Qt.Key_6, + 6: Qt.Key_7, + 7: Qt.Key_8, + 8: Qt.Key_9, + 9: Qt.Key_0, + } for track_ind, track in enumerate(self.labels.tracks): key_command = "" if track_ind < 9: - key_command = Qt.CTRL + Qt.Key_0 + self.labels.tracks.index(track) + 1 + key_command = Qt.CTRL | key_mapping[track_ind] self.track_menu.addAction( f"{track.name}", lambda x=track: self.commands.setInstanceTrack(x), @@ -1374,7 +1388,7 @@ def _update_track_menu(self): f"{track.name}", lambda x=track: self.commands.deleteTrack(x) ) self.track_menu.addAction( - "New Track", self.commands.addTrack, Qt.CTRL + Qt.Key_0 + "New Track", self.commands.addTrack, Qt.CTRL | Qt.Key_0 ) def _update_seekbar_marks(self): From 8a8ed575cf597f3319e679ab6b43776fef6b3eba Mon Sep 17 00:00:00 2001 From: Liezl Maree <38435167+roomrys@users.noreply.github.com> Date: Fri, 16 Aug 2024 07:44:07 -0700 Subject: [PATCH 18/27] Replace QtDesktop widget in preparation for PySide6 (#1908) * Replace to-be-depreciated QDesktopWidget * Remove unused imports and sort remaining imports --- sleap/gui/learning/dialog.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/sleap/gui/learning/dialog.py b/sleap/gui/learning/dialog.py index 184088897..2c2617036 100644 --- a/sleap/gui/learning/dialog.py +++ b/sleap/gui/learning/dialog.py @@ -1,23 +1,20 @@ """ Dialogs for running training and/or inference in GUI. 
""" -import cattr -import os +import json import shutil -import atexit import tempfile from pathlib import Path +from typing import Dict, List, Optional, Text, cast + +import cattr +from qtpy import QtCore, QtGui, QtWidgets import sleap from sleap import Labels, Video from sleap.gui.dialogs.filedialog import FileDialog from sleap.gui.dialogs.formbuilder import YamlFormWidget -from sleap.gui.learning import runners, scopedkeydict, configs, datagen, receptivefield - -from typing import Dict, List, Text, Optional, cast - -from qtpy import QtWidgets, QtCore -import json +from sleap.gui.learning import configs, datagen, receptivefield, runners, scopedkeydict # List of fields which should show list of skeleton nodes NODE_LIST_FIELDS = [ @@ -171,7 +168,7 @@ def __init__( def adjust_initial_size(self): # Get screen size - screen = QtWidgets.QDesktopWidget().screenGeometry() + screen = QtGui.QGuiApplication.primaryScreen().availableGeometry() max_width = 1860 max_height = 1150 From 9e8f2b5a352edb6910211588a5b218aa09c3bdb5 Mon Sep 17 00:00:00 2001 From: Liezl Maree <38435167+roomrys@users.noreply.github.com> Date: Sat, 17 Aug 2024 14:38:28 -0700 Subject: [PATCH 19/27] Remove unsupported |= operand to prepare for PySide6 (#1910) Fixes TypeError: unsupported operand type(s) for |=: 'int' and 'Option' --- sleap/gui/dialogs/filedialog.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sleap/gui/dialogs/filedialog.py b/sleap/gui/dialogs/filedialog.py index 930c71b0d..ff394d191 100644 --- a/sleap/gui/dialogs/filedialog.py +++ b/sleap/gui/dialogs/filedialog.py @@ -29,7 +29,8 @@ def set_dialog_type(cls, *args, **kwargs): if cls.is_non_native: kwargs["options"] = kwargs.get("options", 0) - kwargs["options"] |= QtWidgets.QFileDialog.DontUseNativeDialog + if not kwargs["options"]: + kwargs["options"] = QtWidgets.QFileDialog.DontUseNativeDialog # Make sure we don't send empty options argument if "options" in kwargs and not kwargs["options"]: From 0f3cf4e9a5a54b39efaf5ac5be3174826ea32b4c Mon Sep 17 00:00:00 2001 From: Liezl Maree <38435167+roomrys@users.noreply.github.com> Date: Sat, 17 Aug 2024 15:21:58 -0700 Subject: [PATCH 20/27] Use positional argument for exception type (#1912) traceback.format_exception has changed it's first positional argument's name from etype to exc in python 3.7 to 3.10 --- sleap/gui/commands.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index 692f19c78..dfc0dbad8 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -2067,7 +2067,7 @@ def try_and_skip_if_error(func, *args, **kwargs): func(*args, **kwargs) except Exception as e: tb_str = traceback.format_exception( - etype=type(e), value=e, tb=e.__traceback__ + type(e), value=e, tb=e.__traceback__ ) logger.warning( f"Recieved the following error while replacing skeleton:\n" From 7a787bb424610fcf40e3c55b9c89f61c03380060 Mon Sep 17 00:00:00 2001 From: Liezl Maree <38435167+roomrys@users.noreply.github.com> Date: Fri, 23 Aug 2024 10:12:09 -0700 Subject: [PATCH 21/27] Replace all Video structuring with Video.cattr() (#1911) --- sleap/info/feature_suggestions.py | 2 +- sleap/io/asyncvideo.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sleap/info/feature_suggestions.py b/sleap/info/feature_suggestions.py index 51f9038a5..a5f773fa7 100644 --- a/sleap/info/feature_suggestions.py +++ b/sleap/info/feature_suggestions.py @@ -644,7 +644,7 @@ class ParallelFeaturePipeline(object): def get(self, video_idx): """Apply pipeline 
to single video by idx. Can be called in process.""" video_dict = self.videos_as_dicts[video_idx] - video = cattr.structure(video_dict, Video) + video = Video.cattr().structure(video_dict, Video) group_offset = video_idx * self.pipeline.n_clusters # t0 = time() diff --git a/sleap/io/asyncvideo.py b/sleap/io/asyncvideo.py index c48d21a8b..88876607e 100644 --- a/sleap/io/asyncvideo.py +++ b/sleap/io/asyncvideo.py @@ -166,7 +166,7 @@ def run(self): break if "video" in request: - self.video = cattr.structure(request["video"], Video) + self.video = Video.cattr().structure(request["video"], Video) logger.debug(f"loaded video: {self.video.filename}") if self.video is not None: From d9730e3a09abe337d69ede2d07dd7509548ace4f Mon Sep 17 00:00:00 2001 From: Liezl Maree <38435167+roomrys@users.noreply.github.com> Date: Mon, 26 Aug 2024 08:06:14 -0700 Subject: [PATCH 22/27] Remove unused AsyncVideo class (#1917) Remove unused AsyncVideo --- sleap/io/asyncvideo.py | 218 ------------------------------------ tests/io/test_asyncvideo.py | 31 ----- 2 files changed, 249 deletions(-) delete mode 100644 sleap/io/asyncvideo.py delete mode 100644 tests/io/test_asyncvideo.py diff --git a/sleap/io/asyncvideo.py b/sleap/io/asyncvideo.py deleted file mode 100644 index 88876607e..000000000 --- a/sleap/io/asyncvideo.py +++ /dev/null @@ -1,218 +0,0 @@ -""" -Support for loading video frames (by chunk) in background process. -""" - -from sleap import Video -from sleap.message import PairedSender, PairedReceiver - -import cattr -import logging -import time -import numpy as np -from math import ceil -from multiprocessing import Process -from typing import Iterable, Iterator, List, Optional, Tuple - - -logger = logging.getLogger(__name__) - - -class AsyncVideo: - """Supports fetching chunks from video in background process.""" - - def __init__(self, base_port: int = 9010): - self.base_port = base_port - - # Spawn the server as a background process - self.server = AsyncVideoServer(self.base_port) - self.server.start() - - # Create sender/receiver for sending requests and receiving data via ZMQ - sender = PairedSender.from_tcp_ports(self.base_port, self.base_port + 1) - result_receiver = PairedReceiver.from_tcp_ports( - send_port=self.base_port + 2, rec_port=self.base_port + 3 - ) - - sender.setup() - result_receiver.setup() - - self.sender = sender - self.receiver = result_receiver - - # Use "handshake" to ensure that initial messages aren't dropped - self.handshake_success = sender.send_handshake() - - def close(self): - """Close the async video server and communication ports.""" - if self.sender and self.server: - self.sender.send_dict(dict(stop=True)) - self.server.join() - - self.server = None - - if self.sender: - self.sender.close() - self.sender = None - - if self.receiver: - self.receiver.close() - self.receiver = None - - def __del__(self): - self.close() - - @classmethod - def from_video( - cls, - video: Video, - frame_idxs: Optional[Iterable[int]] = None, - frames_per_chunk: int = 64, - ) -> "AsyncVideo": - """Create object and start loading frames in background process.""" - obj = cls() - obj.load_by_chunk( - video=video, frame_idxs=frame_idxs, frames_per_chunk=frames_per_chunk - ) - return obj - - def load_by_chunk( - self, - video: Video, - frame_idxs: Optional[Iterable[int]] = None, - frames_per_chunk: int = 64, - ): - """ - Sends request for loading video in background process. - - Args: - video: The :py:class:`Video` to load - frame_idxs: Frame indices we want to load; if None, then full video - is loaded. 
- frames_per_chunk: How many frames to load per chunk. - - Returns: - None, data should be accessed via :py:method:`chunks`. - """ - # prime the video since this seems to make frames load faster (!?) - video.test_frame - - request_dict = dict( - video=cattr.unstructure(video), frames_per_chunk=frames_per_chunk - ) - # if no frames are specified, whole video will be loaded - if frame_idxs is not None: - request_dict["frame_idxs"] = list(frame_idxs) - - # send the request - self.sender.send_dict(request_dict) - - @property - def chunks(self) -> Iterator[Tuple[List[int], np.ndarray]]: - """ - Generator for fetching chunks of frames. - - When all chunks are loaded, closes the server and communication ports. - - Yields: - Tuple with (list of frame indices, ndarray of frames) - """ - done = False - while not done: - results = self.receiver.check_messages() - if results: - for result in results: - yield result["frame_idxs"], result["ndarray"] - - if result["chunk"] == result["last_chunk"]: - done = True - - # automatically close when all chunks have been received - self.close() - - -class AsyncVideoServer(Process): - """ - Class which loads video frames in background on request. - - All interactions with video server should go through :py:class:`AsyncVideo` - which runs in local thread. - """ - - def __init__(self, base_port: int): - super(AsyncVideoServer, self).__init__() - - self.video = None - self.base_port = base_port - - def run(self): - receiver = PairedReceiver.from_tcp_ports(self.base_port + 1, self.base_port) - receiver.setup() - - result_sender = PairedSender.from_tcp_ports( - send_port=self.base_port + 3, rec_port=self.base_port + 2 - ) - result_sender.setup() - - running = True - while running: - requests = receiver.check_messages() - if requests: - - for request in requests: - - if "stop" in request: - running = False - logger.debug("stopping async video server") - break - - if "video" in request: - self.video = Video.cattr().structure(request["video"], Video) - logger.debug(f"loaded video: {self.video.filename}") - - if self.video is not None: - if "frames_per_chunk" in request: - - load_time = 0 - send_time = 0 - - per_chunk = request["frames_per_chunk"] - - frame_idxs = request.get( - "frame_idxs", list(range(self.video.frames)) - ) - - frame_count = len(frame_idxs) - chunks = ceil(frame_count / per_chunk) - - for chunk_idx in range(chunks): - start = per_chunk * chunk_idx - end = min(per_chunk * (chunk_idx + 1), frame_count) - chunk_frame_idxs = frame_idxs[start:end] - - # load the frames - t0 = time.time() - frames = self.video[chunk_frame_idxs] - t1 = time.time() - load_time += t1 - t0 - - metadata = dict( - chunk=chunk_idx, - last_chunk=chunks - 1, - frame_idxs=chunk_frame_idxs, - ) - - # send back results - t0 = time.time() - result_sender.send_array(metadata, frames) - t1 = time.time() - send_time += t1 - t0 - - logger.debug(f"returned chunk: {chunk_idx+1}/{chunks}") - - logger.debug(f"total load time: {load_time}") - logger.debug(f"total send time: {send_time}") - else: - logger.warning( - "unable to process message since no video loaded" - ) - logger.warning(request) diff --git a/tests/io/test_asyncvideo.py b/tests/io/test_asyncvideo.py deleted file mode 100644 index 1bc3f19c8..000000000 --- a/tests/io/test_asyncvideo.py +++ /dev/null @@ -1,31 +0,0 @@ -import pytest -import sys -from sleap import Video -from sleap.io.asyncvideo import AsyncVideo - - -@pytest.mark.skipif( - sys.platform.startswith("win"), reason="ZMQ testing breaks locally on Windows" -) -def 
test_async_video(centered_pair_vid, small_robot_mp4_vid):
-    async_video = AsyncVideo.from_video(centered_pair_vid, frames_per_chunk=23)
-
-    all_idxs = []
-    for idxs, frames in async_video.chunks:
-        assert len(idxs) in (23, 19)  # 19 for last chunk
-        all_idxs.extend(idxs)
-
-        assert frames.shape[0] == len(idxs)
-        assert frames.shape[1:] == centered_pair_vid.shape[1:]
-
-    assert len(all_idxs) == centered_pair_vid.num_frames
-
-    # make sure we can load another video (i.e., previous video closed)
-
-    async_video = AsyncVideo.from_video(
-        small_robot_mp4_vid, frame_idxs=range(0, 10, 2), frames_per_chunk=10
-    )
-
-    for idxs, frames in async_video.chunks:
-        # there should only be single chunk
-        assert idxs == list(range(0, 10, 2))

From c88412ccc7c358c26abf22bf2ddaf75f43cf5d46 Mon Sep 17 00:00:00 2001
From: Elizabeth <106755962+eberrigan@users.noreply.github.com>
Date: Mon, 26 Aug 2024 08:08:29 -0700
Subject: [PATCH 23/27] Refactor `LossViewer` to use matplotlib (#1899)

* use updated syntax for QtAgg backend of matplotlib

* start adding features to `MplCanvas` to replace QtCharts features in `LossViewer` (untested)

* remove QtCharts imports and replace with MplCanvas

* remove QtCharts imports and replace with MplCanvas

* start using MplCanvas in LossViewer instead of QtCharts (untested)

* use updated syntax

* Uncomment all commented out QtChart

* Add debug code

* Refactor monitor to use LossViewer._init_series method

* Add monitor only debug code

* Add methods for setting up axes and legend

* Add the matplotlib canvas to the widget

* Resize axis with data (no log support yet)

* Try using PathCollection for "batch"

* Get "batch" plotting with ax.scatter (no log support yet)

* Add log support

* Add a _resize_axis method

* Modify init_series to work for ax.plot as well

* Use matplotlib to plot epoch_loss line

* Add method _add_data_to_scatter

* Add _add_data_to_plot method

* Add docstring to _resize_axes

* Add matplotlib plot for val_loss

* Add matplotlib scatter for val_loss_best

* Avoid errors with setting log scale before any positive values

* Add x and y axes labels

* Set title (removing html tags)

* Add legend

* Adjust positioning of plot

* Lint

* Leave MplCanvas unchanged

* Removed unused training_monitor.LossViewer

* Resize fonts

* Move legend outside of plot

* Add debug code for monitor aesthetics

* Use latex formatting to bold parts of title

* Make axes aesthetic

* Add midpoint grid lines

* Set initial limits on x and y axes to be 0+

* Ensure x axis minimum is always resized to 0+

* Adjust plot to account for plateau patience title

* Add debug code for plateau patience title line

* Lint

* Set thicker line width

* Remove unused import

* Set log axis on initialization

* Make tick labels smaller

* Move plot down a smidge

* Move ylabel left a bit

* Lint

* Add class LossPlot

* Refactor LossViewer to use LossPlot

* Remove QtCharts code

* Remove debug codes

* Allocate space for figure items based on item's size

* Refactor LossPlot to use underscores for internal methods

* Ensure y_min, y_max not equal

Otherwise we get an unnecessary terminal message:
UserWarning: Attempting to set identical bottom == top == 3.0 results in singular transformations; automatically expanding.
self.axes.set_ylim(y_min, y_max) --------- Co-authored-by: roomrys Co-authored-by: roomrys <38435167+roomrys@users.noreply.github.com> --- sleap/gui/app.py | 1 - sleap/gui/widgets/monitor.py | 884 ++++++++++++++++++++------ sleap/gui/widgets/mpl.py | 5 +- sleap/gui/widgets/training_monitor.py | 566 ----------------- sleap/nn/training.py | 2 +- 5 files changed, 705 insertions(+), 753 deletions(-) delete mode 100644 sleap/gui/widgets/training_monitor.py diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 108930b66..4c75dac3f 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -44,7 +44,6 @@ frame and instances listed in data view table. """ - import os import platform import random diff --git a/sleap/gui/widgets/monitor.py b/sleap/gui/widgets/monitor.py index a16456983..5b0ce1ae8 100644 --- a/sleap/gui/widgets/monitor.py +++ b/sleap/gui/widgets/monitor.py @@ -1,21 +1,590 @@ """GUI for monitoring training progress interactively.""" -import numpy as np -from time import perf_counter -from sleap.nn.config.training_job import TrainingJobConfig -from sleap.gui.utils import is_port_free, select_zmq_port -import zmq -import jsonpickle import logging -from typing import Optional, Dict -from qtpy import QtCore, QtWidgets, QtGui -from qtpy.QtCharts import QtCharts +from time import perf_counter +from typing import Dict, Optional, Tuple + import attr +import jsonpickle +import numpy as np +import zmq +from matplotlib.collections import PathCollection +import matplotlib.transforms as mtransforms +from qtpy import QtCore, QtWidgets +from sleap.gui.utils import is_port_free, select_zmq_port +from sleap.gui.widgets.mpl import MplCanvas +from sleap.nn.config.training_job import TrainingJobConfig logger = logging.getLogger(__name__) +class LossPlot(MplCanvas): + """Matplotlib canvas for diplaying training and validation loss curves.""" + + def __init__( + self, + width: int = 5, + height: int = 4, + dpi: int = 100, + log_scale: bool = True, + ignore_outliers: bool = False, + ): + super().__init__(width=width, height=height, dpi=dpi) + + self._log_scale: bool = log_scale + + self.ignore_outliers = ignore_outliers + + # Initialize the series for the plot + self.series: dict = {} + COLOR_TRAIN = (18, 158, 220) + COLOR_VAL = (248, 167, 52) + COLOR_BEST_VAL = (151, 204, 89) + + # Initialize scatter series for batch training loss + self.series["batch"] = self._init_series( + series_type=self.axes.scatter, + name="Batch Training Loss", + color=COLOR_TRAIN + (48,), + border_color=(255, 255, 255, 25), + ) + + # Initialize line series for epoch training loss + self.series["epoch_loss"] = self._init_series( + series_type=self.axes.plot, + name="Epoch Training Loss", + color=COLOR_TRAIN + (255,), + line_width=3.0, + ) + + # Initialize line series for epoch validation loss + self.series["val_loss"] = self._init_series( + series_type=self.axes.plot, + name="Epoch Validation Loss", + color=COLOR_VAL + (255,), + line_width=3.0, + zorder=4, # Below best validation loss series + ) + + # Initialize scatter series for best epoch validation loss + self.series["val_loss_best"] = self._init_series( + series_type=self.axes.scatter, + name="Best Validation Loss", + color=COLOR_BEST_VAL + (255,), + border_color=(255, 255, 255, 25), + zorder=5, # Above epoch validation loss series + ) + + # Set the x and y positions for the xy labels (as fraction of figure size) + self.ypos_xlabel = 0.1 + self.xpos_ylabel = 0.05 + + # Padding between the axes and the xy labels + self.xpos_padding = 0.2 + self.ypos_padding = 0.1 + + # Set up the 
major gridlines + self._setup_major_gridlines() + + # Set up the x-axis + self._setup_x_axis() + + # Set up the y-axis + self._set_up_y_axis() + + # Set up the legend + self.legend_width, legend_height = self._setup_legend() + + # Set up the title space + self.ypos_title = None + title_height = self._set_title_space() + self.ypos_title = 1 - title_height - self.ypos_padding + + # Determine the top height of the plot + top_height = max(title_height, legend_height) + + # Adjust the figure layout + self.xpos_left_plot = self.xpos_ylabel + self.xpos_padding + self.xpos_right_plot = 0.97 + self.ypos_bottom_plot = self.ypos_xlabel + self.ypos_padding + self.ypos_top_plot = 1 - top_height - self.ypos_padding + + # Adjust the top parameters as needed + self.fig.subplots_adjust( + left=self.xpos_left_plot, + right=self.xpos_right_plot, + top=self.ypos_top_plot, + bottom=self.ypos_bottom_plot, + ) + + @property + def log_scale(self): + """Returns True if the plot has a log scale for y-axis.""" + + return self._log_scale + + @log_scale.setter + def log_scale(self, val): + """Sets the scale of the y axis to log if True else linear.""" + + if isinstance(val, bool): + self._log_scale = val + + y_scale = "log" if self._log_scale else "linear" + self.axes.set_yscale(y_scale) + self.redraw_plot() + + def set_data_on_scatter(self, xs, ys, which): + """Set data on a scatter plot. + + Not to be used with line plots. + + Args: + xs: The x-coordinates of the data points. + ys: The y-coordinates of the data points. + which: The type of data point. Possible values are: + * "batch" + * "val_loss_best" + """ + + offsets = np.column_stack((xs, ys)) + self.series[which].set_offsets(offsets) + + def add_data_to_plot(self, x, y, which): + """Add data to a line plot. + + Not to be used with scatter plots. + + Args: + x: The x-coordinate of the data point. + y: The y-coordinate of the data point. + which: The type of data point. Possible values are: + * "epoch_loss" + * "val_loss" + """ + + x_data, y_data = self.series[which].get_data() + self.series[which].set_data(np.append(x_data, x), np.append(y_data, y)) + + def resize_axes(self, x, y): + """Resize axes to fit data. + + This is only called when plotting batches. + + Args: + x: The x-coordinates of the data points. + y: The y-coordinates of the data points. + """ + + # Set X scale to show all points + x_min, x_max = self._calculate_xlim(x) + self.axes.set_xlim(x_min, x_max) + + # Set Y scale, ensuring that y_min and y_max do not lead to sngular transform + y_min, y_max = self._calculate_ylim(y) + y_min, y_max = self.axes.yaxis.get_major_locator().nonsingular(y_min, y_max) + self.axes.set_ylim(y_min, y_max) + + # Add gridlines at midpoint between major ticks (major gridlines are automatic) + self._add_midpoint_gridlines() + + # Redraw the plot + self.redraw_plot() + + def redraw_plot(self): + """Redraw the plot.""" + + self.fig.canvas.draw_idle() + + def set_title(self, title, color=None): + """Set the title of the plot. + + Args: + title: The title text to display. 
+ """ + + if color is None: + color = "black" + + self.axes.set_title( + title, fontweight="light", fontsize="small", color=color, x=0.55, y=1.03 + ) + + def update_runtime_title( + self, + epoch: int, + dt_min: int, + dt_sec: int, + last_epoch_val_loss: float = None, + penultimate_epoch_val_loss: float = None, + mean_epoch_time_min: int = None, + mean_epoch_time_sec: int = None, + eta_ten_epochs_min: int = None, + epochs_in_plateau: int = None, + plateau_patience: int = None, + epoch_in_plateau_flag: bool = False, + best_val_x: int = None, + best_val_y: float = None, + epoch_size: int = None, + ): + + # Add training epoch and runtime info + title = self._get_training_epoch_and_runtime_text(epoch, dt_min, dt_sec) + + if last_epoch_val_loss is not None: + + if penultimate_epoch_val_loss is not None: + # Add mean epoch time and ETA for next 10 epochs + eta_text = self._get_eta_text( + mean_epoch_time_min, mean_epoch_time_sec, eta_ten_epochs_min + ) + title = self._add_with_newline(title, eta_text) + + # Add epochs in plateau if flag is set + if epoch_in_plateau_flag: + plateau_text = self._get_epochs_in_plateau_text( + epochs_in_plateau, plateau_patience + ) + title = self._add_with_newline(title, plateau_text) + + # Add last epoch validation loss + last_val_text = self._get_last_validation_loss_text(last_epoch_val_loss) + title = self._add_with_newline(title, last_val_text) + + # Add best epoch validation loss if available + if best_val_x is not None: + best_epoch = (best_val_x // epoch_size) + 1 + best_val_text = self._get_best_validation_loss_text( + best_val_y, best_epoch + ) + title = self._add_with_newline(title, best_val_text) + + self.set_title(title) + + @staticmethod + def _get_training_epoch_and_runtime_text(epoch: int, dt_min: int, dt_sec: int): + """Get the training epoch and runtime text to display in the plot. + + Args: + epoch: The current epoch. + dt_min: The number of minutes since training started. + dt_sec: The number of seconds since training started. + """ + + runtime_text = ( + r"Training Epoch $\mathbf{" + str(epoch + 1) + r"}$ / " + r"Runtime: $\mathbf{" + f"{int(dt_min):02}:{int(dt_sec):02}" + r"}$" + ) + + return runtime_text + + @staticmethod + def _get_eta_text(mean_epoch_time_min, mean_epoch_time_sec, eta_ten_epochs_min): + """Get the mean time and ETA text to display in the plot. + + Args: + mean_epoch_time_min: The mean time per epoch in minutes. + mean_epoch_time_sec: The mean time per epoch in seconds. + eta_ten_epochs_min: The estimated time for the next ten epochs in minutes. + """ + + runtime_text = ( + r"Mean Time per Epoch: $\mathbf{" + + f"{int(mean_epoch_time_min):02}:{int(mean_epoch_time_sec):02}" + + r"}$ / " + r"ETA Next 10 Epochs: $\mathbf{" + f"{int(eta_ten_epochs_min)}" + r"}$ min" + ) + + return runtime_text + + @staticmethod + def _get_epochs_in_plateau_text(epochs_in_plateau, plateau_patience): + """Get the epochs in plateau text to display in the plot. + + Args: + epochs_in_plateau: The number of epochs in plateau. + plateau_patience: The number of epochs to wait before stopping training. + """ + + plateau_text = ( + r"Epochs in Plateau: $\mathbf{" + f"{epochs_in_plateau}" + r"}$ / " + r"$\mathbf{" + f"{plateau_patience}" + r"}$" + ) + + return plateau_text + + @staticmethod + def _get_last_validation_loss_text(last_epoch_val_loss): + """Get the last epoch validation loss text to display in the plot. + + Args: + last_epoch_val_loss: The validation loss from the last epoch. 
+ """ + + last_val_loss_text = ( + "Last Epoch Validation Loss: " + r"$\mathbf{" + f"{last_epoch_val_loss:.3e}" + r"}$" + ) + + return last_val_loss_text + + @staticmethod + def _get_best_validation_loss_text(best_val_y, best_epoch): + """Get the best epoch validation loss text to display in the plot. + + Args: + best_val_x: The epoch number of the best validation loss. + best_val_y: The best validation loss. + """ + + best_val_loss_text = ( + r"Best Epoch Validation Loss: $\mathbf{" + + f"{best_val_y:.3e}" + + r"}$ (epoch $\mathbf{" + + str(best_epoch) + + r"}$)" + ) + + return best_val_loss_text + + @staticmethod + def _add_with_newline(old_text: str, new_text: str): + """Add a new line to the text. + + Args: + old_text: The existing text. + new_text: The text to add on a new line. + """ + + return old_text + "\n" + new_text + + @staticmethod + def _calculate_xlim(x: np.ndarray, dx: float = 0.5): + """Calculates x-axis limits. + + Args: + x: Array of x data to fit the limits to. + dx: The padding to add to the limits. + + Returns: + Tuple of the minimum and maximum x-axis limits. + """ + + x_min = min(x) - dx + x_min = x_min if x_min > 0 else 0 + x_max = max(x) + dx + + return x_min, x_max + + def _calculate_ylim(self, y: np.ndarray, dy: float = 0.02): + """Calculates y-axis limits. + + Args: + y: Array of y data to fit the limits to. + dy: The padding to add to the limits. + + Returns: + Tuple of the minimum and maximum y-axis limits. + """ + + if self.ignore_outliers: + dy = np.ptp(y) * 0.02 + # Set Y scale to exclude outliers + q1, q3 = np.quantile(y, (0.25, 0.75)) + iqr = q3 - q1 # Interquartile range + y_min = q1 - iqr * 1.5 + y_max = q3 + iqr * 1.5 + + # Keep within range of data + y_min = max(y_min, min(y) - dy) + y_max = min(y_max, max(y) + dy) + else: + # Set Y scale to show all points + dy = np.ptp(y) * 0.02 + y_min = min(y) - dy + y_max = max(y) + dy + + # For log scale, low cannot be 0 + if self.log_scale: + y_min = max(y_min, 1e-8) + + return y_min, y_max + + def _set_title_space(self): + """Set up the title space. + + Returns: + The height of the title space as a decimal fraction of the total figure height. + """ + + # Set a dummy title of the plot + n_lines = 5 # Number of lines in the title + title_str = "\n".join( + [r"Number: $\mathbf{" + str(n) + r"}$" for n in range(n_lines + 1)] + ) + self.set_title( + title_str, color="white" + ) # Set the title color to white so it's not visible + + # Draw the canvas to ensure the title is created + self.fig.canvas.draw() + + # Get the title Text object + title = self.axes.title + + # Get the bounding box of the title in display coordinates + bbox = title.get_window_extent() + + # Transform the bounding box to figure coordinates + bbox = bbox.transformed(self.fig.transFigure.inverted()) + + # Calculate the height of the title as a percentage of the total figure height + title_height = bbox.height + + return title_height + + def _setup_x_axis(self): + """Set up the x axis. + + This includes setting the label, limits, and bottom/right adjustment. + """ + + self.axes.set_xlim(0, 1) + self.axes.set_xlabel("Batches", fontweight="bold", fontsize="small") + + # Set the x-label in the center of the axes and some amount above the bottom of the figure + blended_transform = mtransforms.blended_transform_factory( + self.axes.transAxes, self.fig.transFigure + ) + self.axes.xaxis.set_label_coords( + 0.5, self.ypos_xlabel, transform=blended_transform + ) + + def _set_up_y_axis(self): + """Set up the y axis. 
+ + This includes setting the label, limits, scaling, and left adjustment. + """ + + # Set the minimum value of the y-axis depending on scaling + if self.log_scale: + yscale = "log" + y_min = 0.001 + else: + yscale = "linear" + y_min = 0 + self.axes.set_ylim(bottom=y_min) + self.axes.set_yscale(yscale) + + # Set the y-label name, size, wight, and position + self.axes.set_ylabel("Loss", fontweight="bold", fontsize="small") + self.axes.yaxis.set_label_coords( + self.xpos_ylabel, 0.5, transform=self.fig.transFigure + ) + + def _setup_legend(self): + """Set up the legend. + + Returns: + Tuple of the width and height of the legend as a decimal fraction of the total figure width and height. + """ + + # Move the legend outside the plot on the upper left + legend = self.axes.legend( + loc="upper left", + fontsize="small", + bbox_to_anchor=(0, 1), + bbox_transform=self.fig.transFigure, + ) + + # Draw the canvas to ensure the legend is created + self.fig.canvas.draw() + + # Get the bounding box of the legend in display coordinates + bbox = legend.get_window_extent() + + # Transform the bounding box to figure coordinates + bbox = bbox.transformed(self.fig.transFigure.inverted()) + + # Calculate the width and height of the legend as a percentage of the total figure width and height + return bbox.width, bbox.height + + def _setup_major_gridlines(self): + + # Set the outline color of the plot to gray + for spine in self.axes.spines.values(): + spine.set_edgecolor("#d3d3d3") # Light gray color + + # Remove the top and right axis spines + self.axes.spines["top"].set_visible(False) + self.axes.spines["right"].set_visible(False) + + # Set the tick markers color to light gray, but not the tick labels + self.axes.tick_params( + axis="both", which="both", color="#d3d3d3", labelsize="small" + ) + + # Add gridlines at the tick labels + self.axes.grid(True, which="major", linewidth=0.5, color="#d3d3d3") + + def _add_midpoint_gridlines(self): + # Clear existing minor vertical lines + for line in self.axes.get_lines(): + if line.get_linestyle() == ":": + line.remove() + + # Add gridlines at midpoint between major ticks + major_ticks = self.axes.yaxis.get_majorticklocs() + if len(major_ticks) > 1: + prev_major_tick = major_ticks[0] + for major_tick in major_ticks[:-1]: + midpoint = (major_tick + prev_major_tick) / 2 + self.axes.axhline( + midpoint, linestyle=":", linewidth=0.5, color="#d3d3d3" + ) + prev_major_tick = major_tick + + def _init_series( + self, + series_type, + color, + name: Optional[str] = None, + line_width: Optional[float] = None, + border_color: Optional[Tuple[int, int, int]] = None, + zorder: Optional[int] = None, + ): + + # Set the color + color = [c / 255.0 for c in color] # Normalize color values to [0, 1] + + # Create the series + series = series_type( + [], + [], + color=color, + label=name, + marker="o", + zorder=zorder, + ) + + # ax.plot returns a list of PathCollections, so we need to get the first one + if not isinstance(series, PathCollection): + series = series[0] + + if line_width is not None: + series.set_linewidth(line_width) + + # Set the border color (edge color) + if border_color is not None: + border_color = [ + c / 255.0 for c in border_color + ] # Normalize color values to [0, 1] + series.set_edgecolor(border_color) + + return series + + class LossViewer(QtWidgets.QMainWindow): """Qt window for showing in-progress training metrics sent over ZMQ.""" @@ -42,12 +611,13 @@ def __init__( self.zmq_ports = zmq_ports self.batches_to_show = -1 # -1 to show all - self.ignore_outliers = 
False - self.log_scale = True + self._ignore_outliers = False + self._log_scale = True self.message_poll_time_ms = 20 # ms self.redraw_batch_time_ms = 500 # ms self.last_redraw_batch = None + self.canvas = None self.reset() self.setup_zmq(zmq_context) @@ -88,100 +658,22 @@ def reset( what: String identifier indicating which job type the current run corresponds to. """ - self.chart = QtCharts.QChart() - - self.series = dict() + self.canvas = LossPlot( + width=5, + height=4, + dpi=100, + log_scale=self.log_scale, + ignore_outliers=self.ignore_outliers, + ) - COLOR_TRAIN = (18, 158, 220) - COLOR_VAL = (248, 167, 52) - COLOR_BEST_VAL = (151, 204, 89) + self.mp_series = dict() + self.mp_series["batch"] = self.canvas.series["batch"] + self.mp_series["epoch_loss"] = self.canvas.series["epoch_loss"] + self.mp_series["val_loss"] = self.canvas.series["val_loss"] + self.mp_series["val_loss_best"] = self.canvas.series["val_loss_best"] - self.series["batch"] = QtCharts.QScatterSeries() - self.series["batch"].setName("Batch Training Loss") - self.series["batch"].setColor(QtGui.QColor(*COLOR_TRAIN, 48)) - self.series["batch"].setMarkerSize(8.0) - self.series["batch"].setBorderColor(QtGui.QColor(255, 255, 255, 25)) - self.chart.addSeries(self.series["batch"]) - - self.series["epoch_loss"] = QtCharts.QLineSeries() - self.series["epoch_loss"].setName("Epoch Training Loss") - self.series["epoch_loss"].setColor(QtGui.QColor(*COLOR_TRAIN, 255)) - pen = self.series["epoch_loss"].pen() - pen.setWidth(4) - self.series["epoch_loss"].setPen(pen) - self.chart.addSeries(self.series["epoch_loss"]) - - self.series["epoch_loss_scatter"] = QtCharts.QScatterSeries() - self.series["epoch_loss_scatter"].setColor(QtGui.QColor(*COLOR_TRAIN, 255)) - self.series["epoch_loss_scatter"].setMarkerSize(12.0) - self.series["epoch_loss_scatter"].setBorderColor( - QtGui.QColor(255, 255, 255, 25) - ) - self.chart.addSeries(self.series["epoch_loss_scatter"]) - - self.series["val_loss"] = QtCharts.QLineSeries() - self.series["val_loss"].setName("Epoch Validation Loss") - self.series["val_loss"].setColor(QtGui.QColor(*COLOR_VAL, 255)) - pen = self.series["val_loss"].pen() - pen.setWidth(4) - self.series["val_loss"].setPen(pen) - self.chart.addSeries(self.series["val_loss"]) - - self.series["val_loss_scatter"] = QtCharts.QScatterSeries() - self.series["val_loss_scatter"].setColor(QtGui.QColor(*COLOR_VAL, 255)) - self.series["val_loss_scatter"].setMarkerSize(12.0) - self.series["val_loss_scatter"].setBorderColor(QtGui.QColor(255, 255, 255, 25)) - self.chart.addSeries(self.series["val_loss_scatter"]) - - self.series["val_loss_best"] = QtCharts.QScatterSeries() - self.series["val_loss_best"].setName("Best Validation Loss") - self.series["val_loss_best"].setColor(QtGui.QColor(*COLOR_BEST_VAL, 255)) - self.series["val_loss_best"].setMarkerSize(12.0) - self.series["val_loss_best"].setBorderColor(QtGui.QColor(32, 32, 32, 25)) - self.chart.addSeries(self.series["val_loss_best"]) - - axisX = QtCharts.QValueAxis() - axisX.setLabelFormat("%d") - axisX.setTitleText("Batches") - self.chart.addAxis(axisX, QtCore.Qt.AlignBottom) - - # Create the different Y axes that can be used. - self.axisY = dict() - - self.axisY["log"] = QtCharts.QLogValueAxis() - self.axisY["log"].setBase(10) - - self.axisY["linear"] = QtCharts.QValueAxis() - - # Apply settings that apply to all Y axes. - for axisY in self.axisY.values(): - axisY.setLabelFormat("%f") - axisY.setLabelsVisible(True) - axisY.setMinorTickCount(1) - axisY.setTitleText("Loss") - - # Use the default Y axis. 
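Editorial aside: the axis bookkeeping removed here (two pre-built Y axes plus the `update_y_axis()` swap further down) is what the matplotlib port collapses into a single call on one persistent axes object. A minimal standalone sketch, not part of the patch:

```python
# Minimal sketch: toggling y-axis scale in matplotlib needs no axis swapping.
import matplotlib

matplotlib.use("Agg")  # headless backend, just for this sketch
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([1, 2, 3], [0.1, 0.01, 0.001])
ax.set_yscale("log")  # switch to log scale in place
ax.set_yscale("linear")  # and back; nothing to detach or reattach
```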
- axisY = self.axisY["log"] if self.log_scale else self.axisY["linear"] - - # Add axes to chart and series. - self.chart.addAxis(axisY, QtCore.Qt.AlignLeft) - for series in self.chart.series(): - series.attachAxis(axisX) - series.attachAxis(axisY) - - # Setup legend. - self.chart.legend().setVisible(True) - self.chart.legend().setAlignment(QtCore.Qt.AlignTop) - self.chart.legend().setMarkerShape(QtCharts.QLegend.MarkerShapeCircle) - - # Hide scatters for epoch and val loss from legend. - for s in ("epoch_loss_scatter", "val_loss_scatter"): - self.chart.legend().markers(self.series[s])[0].setVisible(False) - - self.chartView = QtCharts.QChartView(self.chart) - self.chartView.setRenderHint(QtGui.QPainter.Antialiasing) layout = QtWidgets.QVBoxLayout() - layout.addWidget(self.chartView) + layout.addWidget(self.canvas) if self.show_controller: control_layout = QtWidgets.QHBoxLayout() @@ -256,14 +748,47 @@ def reset( self.last_batch_number = 0 self.is_running = False + @property + def log_scale(self): + """Returns True if the plot has a log scale for y-axis.""" + + return self._log_scale + + @log_scale.setter + def log_scale(self, val): + """Sets the scale of the y axis to log if True else linear.""" + + if isinstance(val, bool): + self._log_scale = val + + # Set the log scale on the canvas + self.canvas.log_scale = self._log_scale + + @property + def ignore_outliers(self): + """Returns True if the plot ignores outliers.""" + + return self._ignore_outliers + + @ignore_outliers.setter + def ignore_outliers(self, val): + """Sets whether to ignore outliers in the plot.""" + + if isinstance(val, bool): + self._ignore_outliers = val + + # Set the ignore_outliers on the canvas + self.canvas.ignore_outliers = self._ignore_outliers + def toggle_ignore_outliers(self): """Toggles whether to ignore outliers in chart scaling.""" + self.ignore_outliers = not self.ignore_outliers def toggle_log_scale(self): """Toggle whether to use log-scaled y-axis.""" + self.log_scale = not self.log_scale - self.update_y_axis() def set_batches_to_show(self, batches: str): """Set the number of batches to show on the x-axis. @@ -278,25 +803,6 @@ def set_batches_to_show(self, batches: str): else: self.batches_to_show = -1 - def update_y_axis(self): - """Update the y-axis when scale changes.""" - to = "log" if self.log_scale else "linear" - - # Remove other axes. - for name, axisY in self.axisY.items(): - if name != to: - if axisY in self.chart.axes(): - self.chart.removeAxis(axisY) - for series in self.chart.series(): - if axisY in series.attachedAxes(): - series.detachAxis(axisY) - - # Add axis. - axisY = self.axisY[to] - self.chart.addAxis(axisY, QtCore.Qt.AlignLeft) - for series in self.chart.series(): - series.attachAxis(axisY) - def setup_zmq(self, zmq_context: Optional[zmq.Context] = None): """Connect to ZMQ ports that listen to commands and updates. 
@@ -414,45 +920,67 @@ def add_datapoint(self, x: int, y: float, which: str): self.Y[-self.batches_to_show :], ) - points = [QtCore.QPointF(x, y) for x, y in zip(xs, ys) if y > 0] - self.series["batch"].replace(points) + # Set data, resize and redraw the plot + self._set_data_on_scatter(xs, ys, which) + self._resize_axes(xs, ys) + + else: - # Set X scale to show all points - dx = 0.5 - self.chart.axisX().setRange(min(xs) - dx, max(xs) + dx) + if which == "val_loss": + if self.best_val_y is None or y < self.best_val_y: + self.best_val_x = x + self.best_val_y = y + self._set_data_on_scatter([x], [y], "val_loss_best") - if self.ignore_outliers: - dy = np.ptp(ys) * 0.02 - # Set Y scale to exclude outliers - q1, q3 = np.quantile(ys, (0.25, 0.75)) - iqr = q3 - q1 # interquartile range - low = q1 - iqr * 1.5 - high = q3 + iqr * 1.5 + # Add data and redraw the plot + self._add_data_to_plot(x, y, which) + self._redraw_plot() - low = max(low, min(ys) - dy) # keep within range of data - high = min(high, max(ys) + dy) - else: - # Set Y scale to show all points - dy = np.ptp(ys) * 0.02 - low = min(ys) - dy - high = max(ys) + dy + def _set_data_on_scatter(self, xs, ys, which): + """Add data to a scatter plot. - if self.log_scale: - low = max(low, 1e-8) # for log scale, low cannot be 0 + Not to be used with line plots. - self.chart.axisY().setRange(low, high) + Args: + xs: The x-coordinates of the data points. + ys: The y-coordinates of the data points. + which: The type of data point. Possible values are: + * "batch" + * "val_loss_best" + """ - else: - if which == "epoch_loss": - self.series["epoch_loss"].append(x, y) - self.series["epoch_loss_scatter"].append(x, y) - elif which == "val_loss": - self.series["val_loss"].append(x, y) - self.series["val_loss_scatter"].append(x, y) - if self.best_val_y is None or y < self.best_val_y: - self.best_val_x = x - self.best_val_y = y - self.series["val_loss_best"].replace([QtCore.QPointF(x, y)]) + self.canvas.set_data_on_scatter(xs, ys, which) + + def _add_data_to_plot(self, x, y, which): + """Add data to a line plot. + + Not to be used with scatter plots. + + Args: + x: The x-coordinate of the data point. + y: The y-coordinate of the data point. + which: The type of data point. Possible values are: + * "epoch_loss" + * "val_loss" + """ + + self.canvas.add_data_to_plot(x, y, which) + + def _redraw_plot(self): + """Redraw the plot.""" + + self.canvas.redraw_plot() + + def _resize_axes(self, x, y): + """Resize axes to fit data. + + This is only called when plotting batches. + + Args: + x: The x-coordinates of the data points. + y: The y-coordinates of the data points. + """ + self.canvas.resize_axes(x, y) def set_start_time(self, t0: float): """Mark the start flag and time of the run. @@ -469,35 +997,27 @@ def set_end(self): def update_runtime(self): """Update the title text with the current running time.""" + if self.is_timer_running: dt = perf_counter() - self.t0 dt_min, dt_sec = divmod(dt, 60) - title = f"Training Epoch {self.epoch + 1} / " - title += f"Runtime: {int(dt_min):02}:{int(dt_sec):02}" - if self.last_epoch_val_loss is not None: - if self.penultimate_epoch_val_loss is not None: - title += ( - f"
Mean Time per Epoch: " - f"{int(self.mean_epoch_time_min):02}:{int(self.mean_epoch_time_sec):02} / " - f"ETA Next 10 Epochs: {int(self.eta_ten_epochs_min)} min" - ) - if self.epoch_in_plateau_flag: - title += ( - f"
Epochs in Plateau: " - f"{self.epochs_in_plateau} / " - f"{self.config.optimization.early_stopping.plateau_patience}" - ) - title += ( - f"
Last Epoch Validation Loss: " - f"{self.last_epoch_val_loss:.3e}" - ) - if self.best_val_x is not None: - best_epoch = (self.best_val_x // self.epoch_size) + 1 - title += ( - f"
Best Epoch Validation Loss: " - f"{self.best_val_y:.3e} (epoch {best_epoch})" - ) - self.set_message(title) + + self.canvas.update_runtime_title( + epoch=self.epoch, + dt_min=dt_min, + dt_sec=dt_sec, + last_epoch_val_loss=self.last_epoch_val_loss, + penultimate_epoch_val_loss=self.penultimate_epoch_val_loss, + mean_epoch_time_min=self.mean_epoch_time_min, + mean_epoch_time_sec=self.mean_epoch_time_sec, + eta_ten_epochs_min=self.eta_ten_epochs_min, + epochs_in_plateau=self.epochs_in_plateau, + plateau_patience=self.config.optimization.early_stopping.plateau_patience, + epoch_in_plateau_flag=self.epoch_in_plateau_flag, + best_val_x=self.best_val_x, + best_val_y=self.best_val_y, + epoch_size=self.epoch_size, + ) @property def is_timer_running(self) -> bool: @@ -506,7 +1026,7 @@ def is_timer_running(self) -> bool: def set_message(self, text: str): """Set the chart title text.""" - self.chart.setTitle(text) + self.canvas.set_title(text) def check_messages( self, timeout: int = 10, times_to_check: int = 10, do_update: bool = True diff --git a/sleap/gui/widgets/mpl.py b/sleap/gui/widgets/mpl.py index a9b7fc838..890c1a67a 100644 --- a/sleap/gui/widgets/mpl.py +++ b/sleap/gui/widgets/mpl.py @@ -6,11 +6,10 @@ from qtpy import QtWidgets from matplotlib.figure import Figure -from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as Canvas +from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg as Canvas import matplotlib -# Ensure using PyQt5 backend -matplotlib.use("QT5Agg") +matplotlib.use("QtAgg") class MplCanvas(Canvas): diff --git a/sleap/gui/widgets/training_monitor.py b/sleap/gui/widgets/training_monitor.py deleted file mode 100644 index ed405a747..000000000 --- a/sleap/gui/widgets/training_monitor.py +++ /dev/null @@ -1,566 +0,0 @@ -"""GUI for monitoring training progress interactively.""" - -import numpy as np -from time import perf_counter -from sleap.nn.config.training_job import TrainingJobConfig -import zmq -import jsonpickle -import logging -from typing import Optional -from qtpy import QtCore, QtWidgets, QtGui, QtCharts -import attr - -logger = logging.getLogger(__name__) - - -class LossViewer(QtWidgets.QMainWindow): - """Qt window for showing in-progress training metrics sent over ZMQ.""" - - on_epoch = QtCore.Signal() - - def __init__( - self, - zmq_context: Optional[zmq.Context] = None, - show_controller=True, - parent=None, - ): - super().__init__(parent) - - self.show_controller = show_controller - self.stop_button = None - self.cancel_button = None - self.canceled = False - - self.batches_to_show = -1 # -1 to show all - self.ignore_outliers = False - self.log_scale = True - self.message_poll_time_ms = 20 # ms - self.redraw_batch_time_ms = 500 # ms - self.last_redraw_batch = None - - self.reset() - self.setup_zmq(zmq_context) - - def __del__(self): - self.unbind() - - def close(self): - """Disconnect from ZMQ ports and close the window.""" - self.unbind() - super().close() - - def unbind(self): - """Disconnect from all ZMQ sockets.""" - if self.sub is not None: - self.sub.unbind(self.sub.LAST_ENDPOINT) - self.sub.close() - self.sub = None - - if self.zmq_ctrl is not None: - url = self.zmq_ctrl.LAST_ENDPOINT - self.zmq_ctrl.unbind(url) - self.zmq_ctrl.close() - self.zmq_ctrl = None - - # If we started out own zmq context, terminate it. - if not self.ctx_given and self.ctx is not None: - self.ctx.term() - self.ctx = None - - def reset( - self, - what: str = "", - config: TrainingJobConfig = attr.ib(factory=TrainingJobConfig), - ): - """Reset all chart series. 
- - Args: - what: String identifier indicating which job type the current run - corresponds to. - """ - self.chart = QtCharts.QChart() - - self.series = dict() - - COLOR_TRAIN = (18, 158, 220) - COLOR_VAL = (248, 167, 52) - COLOR_BEST_VAL = (151, 204, 89) - - self.series["batch"] = QtCharts.QScatterSeries() - self.series["batch"].setName("Batch Training Loss") - self.series["batch"].setColor(QtGui.QColor(*COLOR_TRAIN, 48)) - self.series["batch"].setMarkerSize(8.0) - self.series["batch"].setBorderColor(QtGui.QColor(255, 255, 255, 25)) - self.chart.addSeries(self.series["batch"]) - - self.series["epoch_loss"] = QtCharts.QLineSeries() - self.series["epoch_loss"].setName("Epoch Training Loss") - self.series["epoch_loss"].setColor(QtGui.QColor(*COLOR_TRAIN, 255)) - pen = self.series["epoch_loss"].pen() - pen.setWidth(4) - self.series["epoch_loss"].setPen(pen) - self.chart.addSeries(self.series["epoch_loss"]) - - self.series["epoch_loss_scatter"] = QtCharts.QScatterSeries() - self.series["epoch_loss_scatter"].setColor(QtGui.QColor(*COLOR_TRAIN, 255)) - self.series["epoch_loss_scatter"].setMarkerSize(12.0) - self.series["epoch_loss_scatter"].setBorderColor( - QtGui.QColor(255, 255, 255, 25) - ) - self.chart.addSeries(self.series["epoch_loss_scatter"]) - - self.series["val_loss"] = QtCharts.QLineSeries() - self.series["val_loss"].setName("Epoch Validation Loss") - self.series["val_loss"].setColor(QtGui.QColor(*COLOR_VAL, 255)) - pen = self.series["val_loss"].pen() - pen.setWidth(4) - self.series["val_loss"].setPen(pen) - self.chart.addSeries(self.series["val_loss"]) - - self.series["val_loss_scatter"] = QtCharts.QScatterSeries() - self.series["val_loss_scatter"].setColor(QtGui.QColor(*COLOR_VAL, 255)) - self.series["val_loss_scatter"].setMarkerSize(12.0) - self.series["val_loss_scatter"].setBorderColor(QtGui.QColor(255, 255, 255, 25)) - self.chart.addSeries(self.series["val_loss_scatter"]) - - self.series["val_loss_best"] = QtCharts.QScatterSeries() - self.series["val_loss_best"].setName("Best Validation Loss") - self.series["val_loss_best"].setColor(QtGui.QColor(*COLOR_BEST_VAL, 255)) - self.series["val_loss_best"].setMarkerSize(12.0) - self.series["val_loss_best"].setBorderColor(QtGui.QColor(32, 32, 32, 25)) - self.chart.addSeries(self.series["val_loss_best"]) - - axisX = QtCharts.QValueAxis() - axisX.setLabelFormat("%d") - axisX.setTitleText("Batches") - self.chart.addAxis(axisX, QtCore.Qt.AlignBottom) - - # Create the different Y axes that can be used. - self.axisY = dict() - - self.axisY["log"] = QtCharts.QLogValueAxis() - self.axisY["log"].setBase(10) - - self.axisY["linear"] = QtCharts.QValueAxis() - - # Apply settings that apply to all Y axes. - for axisY in self.axisY.values(): - axisY.setLabelFormat("%f") - axisY.setLabelsVisible(True) - axisY.setMinorTickCount(1) - axisY.setTitleText("Loss") - - # Use the default Y axis. - axisY = self.axisY["log"] if self.log_scale else self.axisY["linear"] - - # Add axes to chart and series. - self.chart.addAxis(axisY, QtCore.Qt.AlignLeft) - for series in self.chart.series(): - series.attachAxis(axisX) - series.attachAxis(axisY) - - # Setup legend. - self.chart.legend().setVisible(True) - self.chart.legend().setAlignment(QtCore.Qt.AlignTop) - self.chart.legend().setMarkerShape(QtCharts.QLegend.MarkerShapeCircle) - - # Hide scatters for epoch and val loss from legend. 
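Editorial aside: this marker-hiding workaround has no equivalent to port, because `ax.legend()` skips artists that carry no label (matplotlib assigns unlabeled artists an underscore-prefixed default, and underscore-prefixed labels are excluded). A quick illustrative sketch:

```python
# Minimal sketch: unlabeled (or underscore-labeled) artists stay out of legends.
import matplotlib

matplotlib.use("Agg")  # headless backend, just for this sketch
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([0, 1], [1, 2], label="Epoch Training Loss")  # gets a legend entry
ax.plot([0, 1], [2, 3], label="_helper")  # underscore prefix: hidden
ax.plot([0, 1], [3, 4])  # no label: hidden
assert len(ax.legend().get_texts()) == 1
```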
- for s in ("epoch_loss_scatter", "val_loss_scatter"): - self.chart.legend().markers(self.series[s])[0].setVisible(False) - - self.chartView = QtCharts.QChartView(self.chart) - self.chartView.setRenderHint(QtGui.QPainter.Antialiasing) - layout = QtWidgets.QVBoxLayout() - layout.addWidget(self.chartView) - - if self.show_controller: - control_layout = QtWidgets.QHBoxLayout() - - field = QtWidgets.QCheckBox("Log Scale") - field.setChecked(self.log_scale) - field.stateChanged.connect(self.toggle_log_scale) - control_layout.addWidget(field) - - field = QtWidgets.QCheckBox("Ignore Outliers") - field.setChecked(self.ignore_outliers) - field.stateChanged.connect(self.toggle_ignore_outliers) - control_layout.addWidget(field) - - control_layout.addWidget(QtWidgets.QLabel("Batches to Show:")) - - # Add field for how many batches to show in chart. - field = QtWidgets.QComboBox() - self.batch_options = "200,1000,5000,All".split(",") - for opt in self.batch_options: - field.addItem(opt) - cur_opt_str = ( - "All" if self.batches_to_show < 0 else str(self.batches_to_show) - ) - if cur_opt_str in self.batch_options: - field.setCurrentText(cur_opt_str) - - # Set connection action for when user selects another option. - field.currentIndexChanged.connect( - lambda x: self.set_batches_to_show(self.batch_options[x]) - ) - - # Store field as property and add to layout. - self.batches_to_show_field = field - control_layout.addWidget(self.batches_to_show_field) - - control_layout.addStretch(1) - - self.stop_button = QtWidgets.QPushButton("Stop Early") - self.stop_button.clicked.connect(self.stop) - control_layout.addWidget(self.stop_button) - self.cancel_button = QtWidgets.QPushButton("Cancel Training") - self.cancel_button.clicked.connect(self.cancel) - control_layout.addWidget(self.cancel_button) - - widget = QtWidgets.QWidget() - widget.setLayout(control_layout) - layout.addWidget(widget) - - wid = QtWidgets.QWidget() - wid.setLayout(layout) - self.setCentralWidget(wid) - - self.config = config - self.X = [] - self.Y = [] - self.best_val_x = None - self.best_val_y = None - - self.t0 = None - self.mean_epoch_time_min = None - self.mean_epoch_time_sec = None - self.eta_ten_epochs_min = None - - self.current_job_output_type = what - self.epoch = 0 - self.epoch_size = 1 - self.epochs_in_plateau = 0 - self.last_epoch_val_loss = None - self.penultimate_epoch_val_loss = None - self.epoch_in_plateau_flag = False - self.last_batch_number = 0 - self.is_running = False - - def toggle_ignore_outliers(self): - """Toggles whether to ignore outliers in chart scaling.""" - self.ignore_outliers = not self.ignore_outliers - - def toggle_log_scale(self): - """Toggle whether to use log-scaled y-axis.""" - self.log_scale = not self.log_scale - self.update_y_axis() - - def set_batches_to_show(self, batches: str): - """Set the number of batches to show on the x-axis. - - Args: - batches: Number of batches as a string. If numeric, this will be converted - to an integer. If non-numeric string (e.g., "All"), then all batches - will be shown. - """ - if batches.isdigit(): - self.batches_to_show = int(batches) - else: - self.batches_to_show = -1 - - def update_y_axis(self): - """Update the y-axis when scale changes.""" - to = "log" if self.log_scale else "linear" - - # Remove other axes. - for name, axisY in self.axisY.items(): - if name != to: - if axisY in self.chart.axes(): - self.chart.removeAxis(axisY) - for series in self.chart.series(): - if axisY in series.attachedAxes(): - series.detachAxis(axisY) - - # Add axis. 
- axisY = self.axisY[to] - self.chart.addAxis(axisY, QtCore.Qt.AlignLeft) - for series in self.chart.series(): - series.attachAxis(axisY) - - def setup_zmq(self, zmq_context: Optional[zmq.Context] = None): - """Connect to ZMQ ports that listen to commands and updates. - - Args: - zmq_context: The `zmq.Context` object to use for connections. A new one is - created if not specified and will be closed when the monitor exits. If - an existing one is provided, it will NOT be closed. - """ - # Keep track of whether we're using an existing context (which we won't close - # when done) or are creating our own (which we should close). - self.ctx_given = zmq_context is not None - self.ctx = zmq.Context() if zmq_context is None else zmq_context - - # Progress monitoring, SUBSCRIBER - self.sub = self.ctx.socket(zmq.SUB) - self.sub.subscribe("") - self.sub.bind("tcp://127.0.0.1:9001") - - # Controller, PUBLISHER - self.zmq_ctrl = None - if self.show_controller: - self.zmq_ctrl = self.ctx.socket(zmq.PUB) - self.zmq_ctrl.bind("tcp://127.0.0.1:9000") - - # Set timer to poll for messages. - self.timer = QtCore.QTimer() - self.timer.timeout.connect(self.check_messages) - self.timer.start(self.message_poll_time_ms) - - def cancel(self): - """Set the cancel flag.""" - self.canceled = True - if self.cancel_button is not None: - self.cancel_button.setText("Canceling...") - self.cancel_button.setEnabled(False) - - def stop(self): - """Send command to stop training.""" - if self.zmq_ctrl is not None: - # Send command to stop training. - logger.info("Sending command to stop training.") - self.zmq_ctrl.send_string(jsonpickle.encode(dict(command="stop"))) - - # Disable the button to prevent double messages. - if self.stop_button is not None: - self.stop_button.setText("Stopping...") - self.stop_button.setEnabled(False) - - def add_datapoint(self, x: int, y: float, which: str): - """Add a data point to graph. - - Args: - x: The batch number (out of all epochs, not just current), or epoch. - y: The loss value. - which: Type of data point we're adding. Possible values are: - * "batch" (loss for the batch) - * "epoch_loss" (loss for the entire epoch) - * "val_loss" (validation loss for the epoch) - """ - if which == "batch": - self.X.append(x) - self.Y.append(y) - - # Redraw batch at intervals (faster than plotting every batch). 
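Editorial aside: the redraw-interval guard that follows is a general throttling pattern, and it survives the refactor unchanged (it reappears in `_add_datapoint` later in the series). A self-contained sketch of the pattern; the class and names are illustrative, not from the patch:

```python
# Minimal sketch: allow an expensive action at most once per interval.
from time import perf_counter


class RedrawThrottle:
    def __init__(self, interval_ms: float = 500.0):
        self.interval_ms = interval_ms
        self._last = None  # time of last allowed redraw, None until first call

    def should_redraw(self) -> bool:
        now = perf_counter()
        if self._last is None or (now - self._last) * 1000 >= self.interval_ms:
            self._last = now
            return True
        return False
```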
- draw_batch = False - if self.last_redraw_batch is None: - draw_batch = True - else: - dt = perf_counter() - self.last_redraw_batch - draw_batch = (dt * 1000) >= self.redraw_batch_time_ms - - if draw_batch: - self.last_redraw_batch = perf_counter() - if self.batches_to_show < 0 or len(self.X) < self.batches_to_show: - xs, ys = self.X, self.Y - else: - xs, ys = ( - self.X[-self.batches_to_show :], - self.Y[-self.batches_to_show :], - ) - - points = [QtCore.QPointF(x, y) for x, y in zip(xs, ys) if y > 0] - self.series["batch"].replace(points) - - # Set X scale to show all points - dx = 0.5 - self.chart.axisX().setRange(min(xs) - dx, max(xs) + dx) - - if self.ignore_outliers: - dy = np.ptp(ys) * 0.02 - # Set Y scale to exclude outliers - q1, q3 = np.quantile(ys, (0.25, 0.75)) - iqr = q3 - q1 # interquartile range - low = q1 - iqr * 1.5 - high = q3 + iqr * 1.5 - - low = max(low, min(ys) - dy) # keep within range of data - high = min(high, max(ys) + dy) - else: - # Set Y scale to show all points - dy = np.ptp(ys) * 0.02 - low = min(ys) - dy - high = max(ys) + dy - - if self.log_scale: - low = max(low, 1e-8) # for log scale, low cannot be 0 - - self.chart.axisY().setRange(low, high) - - else: - if which == "epoch_loss": - self.series["epoch_loss"].append(x, y) - self.series["epoch_loss_scatter"].append(x, y) - elif which == "val_loss": - self.series["val_loss"].append(x, y) - self.series["val_loss_scatter"].append(x, y) - if self.best_val_y is None or y < self.best_val_y: - self.best_val_x = x - self.best_val_y = y - self.series["val_loss_best"].replace([QtCore.QPointF(x, y)]) - - def set_start_time(self, t0: float): - """Mark the start flag and time of the run. - - Args: - t0: Start time in seconds. - """ - self.t0 = t0 - self.is_running = True - - def set_end(self): - """Mark the end of the run.""" - self.is_running = False - - def update_runtime(self): - """Update the title text with the current running time.""" - if self.is_timer_running: - dt = perf_counter() - self.t0 - dt_min, dt_sec = divmod(dt, 60) - title = f"Training Epoch {self.epoch + 1} / " - title += f"Runtime: {int(dt_min):02}:{int(dt_sec):02}" - if self.last_epoch_val_loss is not None: - if self.penultimate_epoch_val_loss is not None: - title += ( - f"
Mean Time per Epoch: " - f"{int(self.mean_epoch_time_min):02}:{int(self.mean_epoch_time_sec):02} / " - f"ETA Next 10 Epochs: {int(self.eta_ten_epochs_min)} min" - ) - if self.epoch_in_plateau_flag: - title += ( - f"
Epochs in Plateau: " - f"{self.epochs_in_plateau} / " - f"{self.config.optimization.early_stopping.plateau_patience}" - ) - title += ( - f"
Last Epoch Validation Loss: " - f"{self.last_epoch_val_loss:.3e}" - ) - if self.best_val_x is not None: - best_epoch = (self.best_val_x // self.epoch_size) + 1 - title += ( - f"
Best Epoch Validation Loss: " - f"{self.best_val_y:.3e} (epoch {best_epoch})" - ) - self.set_message(title) - - @property - def is_timer_running(self) -> bool: - """Return True if the timer has started.""" - return self.t0 is not None and self.is_running - - def set_message(self, text: str): - """Set the chart title text.""" - self.chart.setTitle(text) - - def check_messages( - self, timeout: int = 10, times_to_check: int = 10, do_update: bool = True - ): - """Poll for ZMQ messages and adds any received data to graph. - - The message is a dictionary encoded as JSON: - * event - options include - * train_begin - * train_end - * epoch_begin - * epoch_end - * batch_end - * what - this should match the type of model we're training and - ensures that we ignore old messages when we start monitoring - a new training session (when we're training multiple types - of models in a sequence, as for the top-down pipeline). - * logs - dictionary with data relevant for plotting, can include - * loss - * val_loss - - Args: - timeout: Message polling timeout in milliseconds. This is how often we will - check for new command messages. - times_to_check: How many times to check for new messages in the queue before - going back to polling with a timeout. Helps to clear backlogs of - messages if necessary. - do_update: If True (the default), update the GUI text. - """ - if self.sub and self.sub.poll(timeout, zmq.POLLIN): - msg = jsonpickle.decode(self.sub.recv_string()) - - if msg["event"] == "train_begin": - self.set_start_time(perf_counter()) - self.current_job_output_type = msg["what"] - - # Make sure message matches current training job. - if msg.get("what", "") == self.current_job_output_type: - - if not self.is_timer_running: - # We must have missed the train_begin message, so start timer now. - self.set_start_time(perf_counter()) - - if msg["event"] == "train_end": - self.set_end() - elif msg["event"] == "epoch_begin": - self.epoch = msg["epoch"] - elif msg["event"] == "epoch_end": - self.epoch_size = max(self.epoch_size, self.last_batch_number + 1) - self.add_datapoint( - (self.epoch + 1) * self.epoch_size, - msg["logs"]["loss"], - "epoch_loss", - ) - if "val_loss" in msg["logs"].keys(): - # update variables and add points to plot - self.penultimate_epoch_val_loss = self.last_epoch_val_loss - self.last_epoch_val_loss = msg["logs"]["val_loss"] - self.add_datapoint( - (self.epoch + 1) * self.epoch_size, - msg["logs"]["val_loss"], - "val_loss", - ) - # calculate timing and flags at new epoch - if self.penultimate_epoch_val_loss is not None: - mean_epoch_time = (perf_counter() - self.t0) / ( - self.epoch + 1 - ) - self.mean_epoch_time_min, self.mean_epoch_time_sec = divmod( - mean_epoch_time, 60 - ) - self.eta_ten_epochs_min = (mean_epoch_time * 10) // 60 - - val_loss_delta = ( - self.penultimate_epoch_val_loss - - self.last_epoch_val_loss - ) - self.epoch_in_plateau_flag = ( - val_loss_delta - < self.config.optimization.early_stopping.plateau_min_delta - ) or (self.best_val_y < self.last_epoch_val_loss) - self.epochs_in_plateau = ( - self.epochs_in_plateau + 1 - if self.epoch_in_plateau_flag - else 0 - ) - self.on_epoch.emit() - elif msg["event"] == "batch_end": - self.last_batch_number = msg["batch"] - self.add_datapoint( - (self.epoch * self.epoch_size) + msg["batch"], - msg["logs"]["loss"], - "batch", - ) - - # Check for messages again (up to times_to_check times). 
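Editorial aside: the `check_messages()` method shown here (and its renamed `_check_messages()` later in the series) is only the receiving side of the protocol. A minimal sketch of the sending side it implies, assuming the trainer connects to the monitor's SUB socket on port 9001; the `what` value and loss numbers are placeholders:

```python
# Minimal sketch of a trainer-side publisher for the event schema above.
import jsonpickle
import zmq

ctx = zmq.Context()
pub = ctx.socket(zmq.PUB)
pub.connect("tcp://127.0.0.1:9001")  # the monitor's SUB socket binds this port


def publish(event: str, **kwargs):
    """Send one progress event in the shape check_messages() expects."""
    pub.send_string(jsonpickle.encode(dict(event=event, what="centroid", **kwargs)))


publish("train_begin")
publish("batch_end", batch=0, logs={"loss": 0.25})
publish("epoch_end", epoch=0, logs={"loss": 0.20, "val_loss": 0.30})
publish("train_end")
```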
- if times_to_check > 0: - self.check_messages( - timeout=timeout, times_to_check=times_to_check - 1, do_update=False - ) - - if do_update: - self.update_runtime() diff --git a/sleap/nn/training.py b/sleap/nn/training.py index 9e4245b88..c3692637c 100644 --- a/sleap/nn/training.py +++ b/sleap/nn/training.py @@ -508,7 +508,7 @@ def setup_visualization( callbacks = [] try: - matplotlib.use("Qt5Agg") + matplotlib.use("QtAgg") except ImportError: print( "Unable to use Qt backend for matplotlib. " From 1370782877edc10346023fecbaf502e5bf3ce006 Mon Sep 17 00:00:00 2001 From: Liezl Maree <38435167+roomrys@users.noreply.github.com> Date: Mon, 26 Aug 2024 10:03:53 -0700 Subject: [PATCH 24/27] Refactor `LossViewer` to use underscores for internal method names (#1919) Refactor LossViewer to use underscores for internal method names --- sleap/gui/widgets/monitor.py | 428 +++++++++++++++++------------------ tests/gui/test_monitor.py | 2 +- 2 files changed, 215 insertions(+), 215 deletions(-) diff --git a/sleap/gui/widgets/monitor.py b/sleap/gui/widgets/monitor.py index 5b0ce1ae8..fff8a0327 100644 --- a/sleap/gui/widgets/monitor.py +++ b/sleap/gui/widgets/monitor.py @@ -619,33 +619,47 @@ def __init__( self.canvas = None self.reset() - self.setup_zmq(zmq_context) + self._setup_zmq(zmq_context) def __del__(self): - self.unbind() + self._unbind() - def close(self): - """Disconnect from ZMQ ports and close the window.""" - self.unbind() - super().close() + @property + def is_timer_running(self) -> bool: + """Return True if the timer has started.""" + return self.t0 is not None and self.is_running - def unbind(self): - """Disconnect from all ZMQ sockets.""" - if self.sub is not None: - self.sub.unbind(self.sub.LAST_ENDPOINT) - self.sub.close() - self.sub = None + @property + def log_scale(self): + """Returns True if the plot has a log scale for y-axis.""" - if self.zmq_ctrl is not None: - url = self.zmq_ctrl.LAST_ENDPOINT - self.zmq_ctrl.unbind(url) - self.zmq_ctrl.close() - self.zmq_ctrl = None + return self._log_scale - # If we started out own zmq context, terminate it. - if not self.ctx_given and self.ctx is not None: - self.ctx.term() - self.ctx = None + @log_scale.setter + def log_scale(self, val): + """Sets the scale of the y axis to log if True else linear.""" + + if isinstance(val, bool): + self._log_scale = val + + # Set the log scale on the canvas + self.canvas.log_scale = self._log_scale + + @property + def ignore_outliers(self): + """Returns True if the plot ignores outliers.""" + + return self._ignore_outliers + + @ignore_outliers.setter + def ignore_outliers(self, val): + """Sets whether to ignore outliers in the plot.""" + + if isinstance(val, bool): + self._ignore_outliers = val + + # Set the ignore_outliers on the canvas + self.canvas.ignore_outliers = self._ignore_outliers def reset( self, @@ -680,12 +694,12 @@ def reset( field = QtWidgets.QCheckBox("Log Scale") field.setChecked(self.log_scale) - field.stateChanged.connect(self.toggle_log_scale) + field.stateChanged.connect(self._toggle_log_scale) control_layout.addWidget(field) field = QtWidgets.QCheckBox("Ignore Outliers") field.setChecked(self.ignore_outliers) - field.stateChanged.connect(self.toggle_ignore_outliers) + field.stateChanged.connect(self._toggle_ignore_outliers) control_layout.addWidget(field) control_layout.addWidget(QtWidgets.QLabel("Batches to Show:")) @@ -703,7 +717,7 @@ def reset( # Set connection action for when user selects another option. 
field.currentIndexChanged.connect( - lambda x: self.set_batches_to_show(self.batch_options[x]) + lambda x: self._set_batches_to_show(self.batch_options[x]) ) # Store field as property and add to layout. @@ -713,10 +727,10 @@ def reset( control_layout.addStretch(1) self.stop_button = QtWidgets.QPushButton("Stop Early") - self.stop_button.clicked.connect(self.stop) + self.stop_button.clicked.connect(self._stop) control_layout.addWidget(self.stop_button) self.cancel_button = QtWidgets.QPushButton("Cancel Training") - self.cancel_button.clicked.connect(self.cancel) + self.cancel_button.clicked.connect(self._cancel) control_layout.addWidget(self.cancel_button) widget = QtWidgets.QWidget() @@ -748,62 +762,16 @@ def reset( self.last_batch_number = 0 self.is_running = False - @property - def log_scale(self): - """Returns True if the plot has a log scale for y-axis.""" - - return self._log_scale - - @log_scale.setter - def log_scale(self, val): - """Sets the scale of the y axis to log if True else linear.""" - - if isinstance(val, bool): - self._log_scale = val - - # Set the log scale on the canvas - self.canvas.log_scale = self._log_scale - - @property - def ignore_outliers(self): - """Returns True if the plot ignores outliers.""" - - return self._ignore_outliers - - @ignore_outliers.setter - def ignore_outliers(self, val): - """Sets whether to ignore outliers in the plot.""" - - if isinstance(val, bool): - self._ignore_outliers = val - - # Set the ignore_outliers on the canvas - self.canvas.ignore_outliers = self._ignore_outliers - - def toggle_ignore_outliers(self): - """Toggles whether to ignore outliers in chart scaling.""" - - self.ignore_outliers = not self.ignore_outliers - - def toggle_log_scale(self): - """Toggle whether to use log-scaled y-axis.""" - - self.log_scale = not self.log_scale - - def set_batches_to_show(self, batches: str): - """Set the number of batches to show on the x-axis. + def set_message(self, text: str): + """Set the chart title text.""" + self.canvas.set_title(text) - Args: - batches: Number of batches as a string. If numeric, this will be converted - to an integer. If non-numeric string (e.g., "All"), then all batches - will be shown. - """ - if batches.isdigit(): - self.batches_to_show = int(batches) - else: - self.batches_to_show = -1 + def close(self): + """Disconnect from ZMQ ports and close the window.""" + self._unbind() + super().close() - def setup_zmq(self, zmq_context: Optional[zmq.Context] = None): + def _setup_zmq(self, zmq_context: Optional[zmq.Context] = None): """Connect to ZMQ ports that listen to commands and updates. Args: @@ -865,124 +833,23 @@ def find_free_port(port: int, zmq_context: zmq.Context): # Set timer to poll for messages. self.timer = QtCore.QTimer() - self.timer.timeout.connect(self.check_messages) + self.timer.timeout.connect(self._check_messages) self.timer.start(self.message_poll_time_ms) - def cancel(self): - """Set the cancel flag.""" - self.canceled = True - if self.cancel_button is not None: - self.cancel_button.setText("Canceling...") - self.cancel_button.setEnabled(False) - - def stop(self): - """Send command to stop training.""" - if self.zmq_ctrl is not None: - # Send command to stop training. - logger.info("Sending command to stop training.") - self.zmq_ctrl.send_string(jsonpickle.encode(dict(command="stop"))) - - # Disable the button to prevent double messages. 
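Editorial aside: the stop command sent here (and by the renamed `_stop()` later in this diff) needs a consumer in the training process. A sketch of that side, assuming the trainer connects to the controller's PUB socket on port 9000; the polling cadence is illustrative:

```python
# Minimal sketch: poll the control channel and honor a {"command": "stop"}.
import jsonpickle
import zmq

ctx = zmq.Context()
sub = ctx.socket(zmq.SUB)
sub.subscribe("")
sub.connect("tcp://127.0.0.1:9000")  # the controller's PUB socket binds this port


def should_stop(timeout_ms: int = 0) -> bool:
    """Return True if a stop command is waiting on the control socket."""
    if sub.poll(timeout_ms, zmq.POLLIN):
        msg = jsonpickle.decode(sub.recv_string())
        return msg.get("command") == "stop"
    return False
```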
- if self.stop_button is not None: - self.stop_button.setText("Stopping...") - self.stop_button.setEnabled(False) - - def add_datapoint(self, x: int, y: float, which: str): - """Add a data point to graph. + def _set_batches_to_show(self, batches: str): + """Set the number of batches to show on the x-axis. Args: - x: The batch number (out of all epochs, not just current), or epoch. - y: The loss value. - which: Type of data point we're adding. Possible values are: - * "batch" (loss for the batch) - * "epoch_loss" (loss for the entire epoch) - * "val_loss" (validation loss for the epoch) + batches: Number of batches as a string. If numeric, this will be converted + to an integer. If non-numeric string (e.g., "All"), then all batches + will be shown. """ - if which == "batch": - self.X.append(x) - self.Y.append(y) - - # Redraw batch at intervals (faster than plotting every batch). - draw_batch = False - if self.last_redraw_batch is None: - draw_batch = True - else: - dt = perf_counter() - self.last_redraw_batch - draw_batch = (dt * 1000) >= self.redraw_batch_time_ms - - if draw_batch: - self.last_redraw_batch = perf_counter() - if self.batches_to_show < 0 or len(self.X) < self.batches_to_show: - xs, ys = self.X, self.Y - else: - xs, ys = ( - self.X[-self.batches_to_show :], - self.Y[-self.batches_to_show :], - ) - - # Set data, resize and redraw the plot - self._set_data_on_scatter(xs, ys, which) - self._resize_axes(xs, ys) - + if batches.isdigit(): + self.batches_to_show = int(batches) else: + self.batches_to_show = -1 - if which == "val_loss": - if self.best_val_y is None or y < self.best_val_y: - self.best_val_x = x - self.best_val_y = y - self._set_data_on_scatter([x], [y], "val_loss_best") - - # Add data and redraw the plot - self._add_data_to_plot(x, y, which) - self._redraw_plot() - - def _set_data_on_scatter(self, xs, ys, which): - """Add data to a scatter plot. - - Not to be used with line plots. - - Args: - xs: The x-coordinates of the data points. - ys: The y-coordinates of the data points. - which: The type of data point. Possible values are: - * "batch" - * "val_loss_best" - """ - - self.canvas.set_data_on_scatter(xs, ys, which) - - def _add_data_to_plot(self, x, y, which): - """Add data to a line plot. - - Not to be used with scatter plots. - - Args: - x: The x-coordinate of the data point. - y: The y-coordinate of the data point. - which: The type of data point. Possible values are: - * "epoch_loss" - * "val_loss" - """ - - self.canvas.add_data_to_plot(x, y, which) - - def _redraw_plot(self): - """Redraw the plot.""" - - self.canvas.redraw_plot() - - def _resize_axes(self, x, y): - """Resize axes to fit data. - - This is only called when plotting batches. - - Args: - x: The x-coordinates of the data points. - y: The y-coordinates of the data points. - """ - self.canvas.resize_axes(x, y) - - def set_start_time(self, t0: float): + def _set_start_time(self, t0: float): """Mark the start flag and time of the run. 
Args: @@ -991,11 +858,7 @@ def set_start_time(self, t0: float): self.t0 = t0 self.is_running = True - def set_end(self): - """Mark the end of the run.""" - self.is_running = False - - def update_runtime(self): + def _update_runtime(self): """Update the title text with the current running time.""" if self.is_timer_running: @@ -1019,16 +882,7 @@ def update_runtime(self): epoch_size=self.epoch_size, ) - @property - def is_timer_running(self) -> bool: - """Return True if the timer has started.""" - return self.t0 is not None and self.is_running - - def set_message(self, text: str): - """Set the chart title text.""" - self.canvas.set_title(text) - - def check_messages( + def _check_messages( self, timeout: int = 10, times_to_check: int = 10, do_update: bool = True ): """Poll for ZMQ messages and adds any received data to graph. @@ -1060,7 +914,7 @@ def check_messages( msg = jsonpickle.decode(self.sub.recv_string()) if msg["event"] == "train_begin": - self.set_start_time(perf_counter()) + self._set_start_time(perf_counter()) self.current_job_output_type = msg["what"] # Make sure message matches current training job. @@ -1068,15 +922,15 @@ def check_messages( if not self.is_timer_running: # We must have missed the train_begin message, so start timer now. - self.set_start_time(perf_counter()) + self._set_start_time(perf_counter()) if msg["event"] == "train_end": - self.set_end() + self._set_end() elif msg["event"] == "epoch_begin": self.epoch = msg["epoch"] elif msg["event"] == "epoch_end": self.epoch_size = max(self.epoch_size, self.last_batch_number + 1) - self.add_datapoint( + self._add_datapoint( (self.epoch + 1) * self.epoch_size, msg["logs"]["loss"], "epoch_loss", @@ -1085,7 +939,7 @@ def check_messages( # update variables and add points to plot self.penultimate_epoch_val_loss = self.last_epoch_val_loss self.last_epoch_val_loss = msg["logs"]["val_loss"] - self.add_datapoint( + self._add_datapoint( (self.epoch + 1) * self.epoch_size, msg["logs"]["val_loss"], "val_loss", @@ -1116,7 +970,7 @@ def check_messages( self.on_epoch.emit() elif msg["event"] == "batch_end": self.last_batch_number = msg["batch"] - self.add_datapoint( + self._add_datapoint( (self.epoch * self.epoch_size) + msg["batch"], msg["logs"]["loss"], "batch", @@ -1124,9 +978,155 @@ def check_messages( # Check for messages again (up to times_to_check times). if times_to_check > 0: - self.check_messages( + self._check_messages( timeout=timeout, times_to_check=times_to_check - 1, do_update=False ) if do_update: - self.update_runtime() + self._update_runtime() + + def _add_datapoint(self, x: int, y: float, which: str): + """Add a data point to graph. + + Args: + x: The batch number (out of all epochs, not just current), or epoch. + y: The loss value. + which: Type of data point we're adding. Possible values are: + * "batch" (loss for the batch) + * "epoch_loss" (loss for the entire epoch) + * "val_loss" (validation loss for the epoch) + """ + if which == "batch": + self.X.append(x) + self.Y.append(y) + + # Redraw batch at intervals (faster than plotting every batch). 
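Editorial aside: two things happen in the batch branch that follows. The redraw is throttled (the same pattern sketched earlier), and the plotted series is windowed: a negative `batches_to_show` means show everything, otherwise only the most recent N points are drawn. A small sketch of the windowing rule with illustrative data:

```python
# Minimal sketch of the tail-window rule used when plotting batch losses.
def window(xs: list, ys: list, batches_to_show: int):
    if batches_to_show < 0 or len(xs) < batches_to_show:
        return xs, ys
    return xs[-batches_to_show:], ys[-batches_to_show:]


assert window([1, 2, 3], [4, 5, 6], -1) == ([1, 2, 3], [4, 5, 6])
assert window([1, 2, 3], [4, 5, 6], 2) == ([2, 3], [5, 6])
```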
+            draw_batch = False
+            if self.last_redraw_batch is None:
+                draw_batch = True
+            else:
+                dt = perf_counter() - self.last_redraw_batch
+                draw_batch = (dt * 1000) >= self.redraw_batch_time_ms
+
+            if draw_batch:
+                self.last_redraw_batch = perf_counter()
+                if self.batches_to_show < 0 or len(self.X) < self.batches_to_show:
+                    xs, ys = self.X, self.Y
+                else:
+                    xs, ys = (
+                        self.X[-self.batches_to_show :],
+                        self.Y[-self.batches_to_show :],
+                    )
+
+                # Set data, resize and redraw the plot
+                self._set_data_on_scatter(xs, ys, which)
+                self._resize_axes(xs, ys)
+
+        else:
+
+            if which == "val_loss":
+                if self.best_val_y is None or y < self.best_val_y:
+                    self.best_val_x = x
+                    self.best_val_y = y
+                    self._set_data_on_scatter([x], [y], "val_loss_best")
+
+            # Add data and redraw the plot
+            self._add_data_to_plot(x, y, which)
+            self._redraw_plot()
+
+    def _set_data_on_scatter(self, xs, ys, which):
+        """Add data to a scatter plot.
+
+        Not to be used with line plots.
+
+        Args:
+            xs: The x-coordinates of the data points.
+            ys: The y-coordinates of the data points.
+            which: The type of data point. Possible values are:
+                * "batch"
+                * "val_loss_best"
+        """
+
+        self.canvas.set_data_on_scatter(xs, ys, which)
+
+    def _add_data_to_plot(self, x, y, which):
+        """Add data to a line plot.
+
+        Not to be used with scatter plots.
+
+        Args:
+            x: The x-coordinate of the data point.
+            y: The y-coordinate of the data point.
+            which: The type of data point. Possible values are:
+                * "epoch_loss"
+                * "val_loss"
+        """
+
+        self.canvas.add_data_to_plot(x, y, which)
+
+    def _redraw_plot(self):
+        """Redraw the plot."""
+
+        self.canvas.redraw_plot()
+
+    def _resize_axes(self, x, y):
+        """Resize axes to fit data.
+
+        This is only called when plotting batches.
+
+        Args:
+            x: The x-coordinates of the data points.
+            y: The y-coordinates of the data points.
+        """
+        self.canvas.resize_axes(x, y)
+
+    def _toggle_ignore_outliers(self):
+        """Toggle whether to ignore outliers in chart scaling."""
+
+        self.ignore_outliers = not self.ignore_outliers
+
+    def _toggle_log_scale(self):
+        """Toggle whether to use log-scaled y-axis."""
+
+        self.log_scale = not self.log_scale
+
+    def _stop(self):
+        """Send command to stop training."""
+        if self.zmq_ctrl is not None:
+            # Send command to stop training.
+            logger.info("Sending command to stop training.")
+            self.zmq_ctrl.send_string(jsonpickle.encode(dict(command="stop")))
+
+        # Disable the button to prevent double messages.
+        if self.stop_button is not None:
+            self.stop_button.setText("Stopping...")
+            self.stop_button.setEnabled(False)
+
+    def _cancel(self):
+        """Set the cancel flag."""
+        self.canceled = True
+        if self.cancel_button is not None:
+            self.cancel_button.setText("Canceling...")
+            self.cancel_button.setEnabled(False)
+
+    def _unbind(self):
+        """Disconnect from all ZMQ sockets."""
+        if self.sub is not None:
+            self.sub.unbind(self.sub.LAST_ENDPOINT)
+            self.sub.close()
+            self.sub = None
+
+        if self.zmq_ctrl is not None:
+            url = self.zmq_ctrl.LAST_ENDPOINT
+            self.zmq_ctrl.unbind(url)
+            self.zmq_ctrl.close()
+            self.zmq_ctrl = None
+
+        # If we started our own zmq context, terminate it.
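+        # (A context supplied by the caller is left running, since the caller
+        # may still be using it for other sockets.)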
+        if not self.ctx_given and self.ctx is not None:
+            self.ctx.term()
+            self.ctx = None
+
+    def _set_end(self):
+        """Mark the end of the run."""
+        self.is_running = False
 
diff --git a/tests/gui/test_monitor.py b/tests/gui/test_monitor.py
index 7ea81d6dc..e0abea692 100644
--- a/tests/gui/test_monitor.py
+++ b/tests/gui/test_monitor.py
@@ -30,7 +30,7 @@ def test_monitor_release(qtbot, min_centroid_model_path):
     # Enter "best_val_x" conditional
    win.best_val_x = 0
    win.best_val_y = win.last_epoch_val_loss
-    win.update_runtime()
+    win._update_runtime()
 
    win.close()
 

From fa1c1b7838fadbd00ef99b1af733b67858d56344 Mon Sep 17 00:00:00 2001
From: Liezl Maree <38435167+roomrys@users.noreply.github.com>
Date: Fri, 30 Aug 2024 19:38:39 -0700
Subject: [PATCH 25/27] Manually handle `Instance.from_predicted` structuring when not `None` (#1930)

---
 sleap/instance.py | 45 +++++++++++++++++++++++----------------------
 1 file changed, 23 insertions(+), 22 deletions(-)

diff --git a/sleap/instance.py b/sleap/instance.py
index 67e96f330..08a5c6ae6 100644
--- a/sleap/instance.py
+++ b/sleap/instance.py
@@ -1232,37 +1232,38 @@ def structure_points(x, type):
     def structure_instances_list(x, type):
         inst_list = []
         for inst_data in x:
-            if "score" in inst_data.keys():
-                inst = converter.structure(inst_data, PredictedInstance)
-            else:
-                if (
-                    "from_predicted" in inst_data
-                    and inst_data["from_predicted"] is not None
-                ):
-                    inst_data["from_predicted"] = converter.structure(
-                        inst_data["from_predicted"], PredictedInstance
-                    )
-                inst = converter.structure(inst_data, Instance)
+            inst = structure_instance(inst_data, type)
             inst_list.append(inst)
 
         return inst_list
 
+    def structure_instance(inst_data, type):
+        """Structure hook for Instance and PredictedInstance objects."""
+        from_predicted = None
+
+        if "score" in inst_data.keys():
+            inst = converter.structure(inst_data, PredictedInstance)
+        else:
+            if (
+                "from_predicted" in inst_data
+                and inst_data["from_predicted"] is not None
+            ):
+                from_predicted = converter.structure(
+                    inst_data["from_predicted"], PredictedInstance
+                )
+                # Null-out from_predicted in the data so structuring skips it;
+                # the structured value is attached to the instance below.
+                inst_data["from_predicted"] = None
+
+            # Structure the instance data, then add the from_predicted attribute.
+            inst = converter.structure(inst_data, Instance)
+            inst.from_predicted = from_predicted
+
+        return inst
+
     converter.register_structure_hook(
         Union[List[Instance], List[PredictedInstance]], structure_instances_list
     )
     converter.register_structure_hook(InstancesList, structure_instances_list)
 
-    # Structure forward reference for PredictedInstance for the Instance.from_predicted
-    # attribute.
- converter.register_structure_hook_func( - lambda t: t.__class__ is ForwardRef, - lambda v, t: converter.structure(v, t.__forward_value__), - ) - # converter.register_structure_hook( - # ForwardRef("PredictedInstance"), - # lambda x, _: converter.structure(x, PredictedInstance), - # ) - # We can register structure hooks for point arrays that do nothing # because Instance can have a dict of points passed to it in place of # a PointArray From 35463a1ddf7649ab813d36f46680dad5eaf3edfc Mon Sep 17 00:00:00 2001 From: Liezl Maree <38435167+roomrys@users.noreply.github.com> Date: Fri, 30 Aug 2024 19:39:11 -0700 Subject: [PATCH 26/27] Use `tf.math.mod` instead of `%` (#1931) --- sleap/nn/peak_finding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sleap/nn/peak_finding.py b/sleap/nn/peak_finding.py index 84dca00ae..e1fb43a6e 100644 --- a/sleap/nn/peak_finding.py +++ b/sleap/nn/peak_finding.py @@ -221,7 +221,7 @@ def find_global_peaks_rough( channels = tf.cast(tf.shape(cms)[-1], tf.int64) total_peaks = tf.cast(tf.shape(argmax_cols)[0], tf.int64) sample_subs = tf.range(total_peaks, dtype=tf.int64) // channels - channel_subs = tf.range(total_peaks, dtype=tf.int64) % channels + channel_subs = tf.math.mod(tf.range(total_peaks, dtype=tf.int64), channels) # Gather subscripts. peak_subs = tf.stack([sample_subs, argmax_rows, argmax_cols, channel_subs], axis=1) From 83d6bc04ce3f0c93a8af9a7536db39c60affc1d8 Mon Sep 17 00:00:00 2001 From: MweinbergUmass <143860933+MweinbergUmass@users.noreply.github.com> Date: Wed, 4 Sep 2024 17:50:28 -0700 Subject: [PATCH 27/27] Option for Max Stride to be 128 (#1941) Co-authored-by: Max Weinberg --- sleap/config/training_editor_form.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sleap/config/training_editor_form.yaml b/sleap/config/training_editor_form.yaml index eabfc3940..7d7972892 100644 --- a/sleap/config/training_editor_form.yaml +++ b/sleap/config/training_editor_form.yaml @@ -44,7 +44,7 @@ model: label: Max Stride name: model.backbone.hourglass.max_stride type: list - options: 1,2,4,8,16,32,64 + options: 1,2,4,8,16,32,64,128 # - default: 4 # help: Determines the number of upsampling blocks in the network. # label: Output Stride @@ -81,7 +81,7 @@ model: label: Max Stride name: model.backbone.leap.max_stride type: list - options: 2,4,8,16,32,64 + options: 2,4,8,16,32,64,128 # - default: 1 # help: Determines the number of upsampling blocks in the network. # label: Output Stride @@ -190,7 +190,7 @@ model: label: Max Stride name: model.backbone.resnet.max_stride type: list - options: 2,4,8,16,32,64 + options: 2,4,8,16,32,64,128 # - default: 4 # help: Stride of the final output. If the upsampling branch is not defined, the # output stride is controlled via dilated convolutions or reduced pooling in the @@ -250,7 +250,7 @@ model: label: Max Stride name: model.backbone.unet.max_stride type: list - options: 2,4,8,16,32,64 + options: 2,4,8,16,32,64,128 # - default: 1 # help: Determines the number of upsampling blocks in the network. # label: Output Stride
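
A note on the max stride options above: in these stride-based backbones each
downsampling block halves the spatial resolution, so every max stride option is
a power of two, and input images typically must be divisible by the chosen max
stride (hence padding/resizing upstream). A minimal sketch of the arithmetic
(illustrative only; `n_down_blocks` is a hypothetical helper, not code from
these patches):

    import math

    def n_down_blocks(max_stride: int) -> int:
        # Each 2x downsampling block doubles the stride, so a max stride of
        # 2**n requires n blocks.
        assert max_stride > 0 and (max_stride & (max_stride - 1)) == 0, (
            "max_stride must be a power of 2"
        )
        return int(math.log2(max_stride))

    print(n_down_blocks(64))   # 6 (previous maximum)
    print(n_down_blocks(128))  # 7 (the new option adds one more block)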