From e1b8c62bad13fd071bbdb8f0cad3b911a9d99adb Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Mon, 8 Feb 2021 11:38:28 -0500 Subject: [PATCH] Miscellaneous QOL (#467) Pre-1.1.0 update features (changelist in #467) --- .conda/bld.bat | 1 + .conda/build.sh | 1 + docs/conf.py | 6 +- requirements.txt | 3 +- setup.py | 10 +- sleap/__init__.py | 5 +- sleap/config/shortcuts.yaml | 22 +- sleap/config/training_editor_form.yaml | 18 +- sleap/gui/app.py | 203 +++- sleap/gui/commands.py | 239 +++- sleap/gui/dataviews.py | 21 +- sleap/gui/dialogs/shortcuts.py | 47 +- sleap/gui/learning/dialog.py | 179 ++- sleap/gui/learning/runners.py | 208 +++- sleap/gui/learning/scopedkeydict.py | 12 + sleap/gui/overlays/instance.py | 4 +- sleap/gui/shortcuts.py | 7 +- sleap/instance.py | 350 +++--- sleap/io/dataset.py | 335 +++++- sleap/io/format/hdf5.py | 2 + sleap/io/video.py | 58 +- sleap/nn/config/__init__.py | 2 +- sleap/nn/config/data.py | 28 +- sleap/nn/config/outputs.py | 7 + sleap/nn/config/training_job.py | 46 +- sleap/nn/data/pipelines.py | 12 +- sleap/nn/data/providers.py | 71 +- sleap/nn/data/training.py | 52 + sleap/nn/evals.py | 10 +- sleap/nn/inference.py | 1446 ++++++++++++------------ sleap/nn/monitor.py | 16 +- sleap/nn/training.py | 220 +++- sleap/prefs.py | 15 +- sleap/skeleton.py | 110 +- sleap/util.py | 16 +- sleap/version.py | 19 + tests/gui/test_commands.py | 21 +- tests/io/test_video.py | 9 +- tests/nn/data/test_data_training.py | 123 +- tests/nn/test_inference.py | 14 +- tests/test_skeleton.py | 10 - 41 files changed, 2602 insertions(+), 1376 deletions(-) diff --git a/.conda/bld.bat b/.conda/bld.bat index 5aea27a19..040a42587 100644 --- a/.conda/bld.bat +++ b/.conda/bld.bat @@ -38,6 +38,7 @@ pip install jsmin pip install seaborn pip install pykalman==0.9.5 pip install segmentation-models==1.0.1 +pip install rich==9.10.0 rem # Use and update environment.yml call to install pip dependencies. This is slick. rem # While environment.yml contains the non pip dependencies, the only thing left diff --git a/.conda/build.sh b/.conda/build.sh index f43797a54..984701e19 100644 --- a/.conda/build.sh +++ b/.conda/build.sh @@ -36,6 +36,7 @@ pip install jsmin pip install seaborn pip install pykalman==0.9.5 pip install segmentation-models==1.0.1 +pip install rich==9.10.0 pip install setuptools-scm diff --git a/docs/conf.py b/docs/conf.py index 6eb8197cf..5f2580a18 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -89,15 +89,13 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'alabaster' +html_theme = "alabaster" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. # -html_theme_options = { - 'font_size': '15px' -} +html_theme_options = {"font_size": "15px"} # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, diff --git a/requirements.txt b/requirements.txt index be7a19828..67364300f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,4 +22,5 @@ qimage2ndarray==1.8 jsmin seaborn pykalman==0.9.5 -segmentation-models==1.0.1 \ No newline at end of file +segmentation-models==1.0.1 +rich==9.10.0 \ No newline at end of file diff --git a/setup.py b/setup.py index 67893fb4f..2dbaadf51 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ # Get the sleap version with open(path.join(here, "sleap/version.py")) as f: version_file = f.read() - sleap_version = re.search("\d.+(?=['\"])", version_file).group(0) + sleap_version = re.search('__version__ = "([0-9\\.a]+)"', version_file).group(1) def get_requirements(require_name=None): @@ -31,11 +31,13 @@ def get_requirements(require_name=None): version=sleap_version, setup_requires=["setuptools_scm"], install_requires=get_requirements(), - extras_require={"dev": get_requirements("dev"),}, - description="SLEAP (Social LEAP Estimates Animal Pose) is a deep learning framework for estimating animal pose.", + extras_require={ + "dev": get_requirements("dev"), + }, + description="SLEAP (Social LEAP Estimates Animal Poses) is a deep learning framework for animal pose tracking.", long_description=long_description, long_description_content_type="text/x-rst", - author="Talmo Pereira, David Turner, Nat Tabris", + author="Talmo Pereira, Arie Matsliah, David Turner, Nat Tabris", author_email="talmo@princeton.edu", project_urls={ "Documentation": "https://sleap.ai/#sleap", diff --git a/sleap/__init__.py b/sleap/__init__.py index 60de74ae5..3b6aa14a5 100644 --- a/sleap/__init__.py +++ b/sleap/__init__.py @@ -5,8 +5,9 @@ logging.basicConfig(stream=sys.stdout, level=logging.INFO) # Import submodules we want available at top-level +from sleap.version import __version__, versions from sleap.io.dataset import Labels, load_file -from sleap.io.video import Video +from sleap.io.video import Video, load_video from sleap.instance import LabeledFrame, Instance, PredictedInstance, Track from sleap.skeleton import Skeleton import sleap.nn @@ -15,4 +16,4 @@ from sleap.nn.inference import load_model from sleap.nn.system import use_cpu_only, disable_preallocation from sleap.nn.system import summary as system_summary -from sleap.version import __version__ +from sleap.nn.config import TrainingJobConfig, load_config diff --git a/sleap/config/shortcuts.yaml b/sleap/config/shortcuts.yaml index 8123231ce..e4a86436a 100644 --- a/sleap/config/shortcuts.yaml +++ b/sleap/config/shortcuts.yaml @@ -1,7 +1,7 @@ add instance: Ctrl+I add videos: Ctrl+A clear selection: Esc -close: QKeySequence.Close +close: Ctrl+Q color predicted: delete area: Ctrl+K delete clip: @@ -11,19 +11,19 @@ export clip: fit: Ctrl+= goto frame: Ctrl+J goto marked: Ctrl+Shift+M -goto next suggestion: QKeySequence.FindNext +goto next suggestion: Space goto next track spawn: Ctrl+E -goto next user: Ctrl+> -goto next labeled: Ctrl+. -goto prev suggestion: QKeySequence.FindPrevious -goto prev labeled: +goto next user: Ctrl+U +goto next labeled: Alt+Right +goto prev suggestion: Shift+Space +goto prev labeled: Alt+Left learning: Ctrl+L mark frame: Ctrl+M new: Ctrl+N next video: QKeySequence.Forward open: Ctrl+O prev video: QKeySequence.Back -save as: QKeySequence.SaveAs +save as: Ctrl+Shift+S save: Ctrl+S select next: '`' select to frame: Ctrl+Shift+J @@ -33,7 +33,7 @@ show trails: transpose: Ctrl+T frame next: Right frame prev: Left -frame next medium step: Down -frame prev medium step: Up -frame next large step: Space -frame prev large step: / +frame next medium step: Ctrl+Right +frame prev medium step: Ctrl+Left +frame next large step: Ctrl+Alt+Right +frame prev large step: Ctrl+Alt+Left diff --git a/sleap/config/training_editor_form.yaml b/sleap/config/training_editor_form.yaml index 44cafe008..d362e8cee 100644 --- a/sleap/config/training_editor_form.yaml +++ b/sleap/config/training_editor_form.yaml @@ -466,6 +466,13 @@ augmentation: label: Scale Max name: optimization.augmentation_config.scale_max type: double +- label: Random flip + default: none + type: list + name: optimization.augmentation_config.random_flip + options: none,horizontal,vertical + default: none + help: 'Randomly reflect images and instances. IMPORTANT: Left/right symmetric nodes must be indicated in the skeleton or this will lead to incorrect results!' - default: false help: If True, uniformly distributed noise will be added to the image. This is effectively adding a different random value to each pixel to simulate shot noise. See `imgaug.augmenters.arithmetic.AddElementwise`. @@ -533,17 +540,6 @@ augmentation: label: Brightness Max Val name: optimization.augmentation_config.brightness_max_val type: double -- name: optimization.augmentation_config.random_flip - label: Random flip - help: 'Randomly reflect images and instances. IMPORTANT: Left/right symmetric nodes must be indicated in the skeleton or this will lead to incorrect results!' - type: bool - default: false -- name: optimization.augmentation_config.flip_horizontal - label: Flip left/right - help: Flip images horizontally when randomly reflecting. If unchecked, flipping will - reflect images up/down. - type: bool - default: true optimization: - default: 8 diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 9152825e6..859543e3a 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -106,14 +106,14 @@ class MainWindow(QMainWindow): whether to show node labels, etc. """ - def __init__(self, labels_path: Optional[str] = None, *args, **kwargs): + def __init__( + self, labels_path: Optional[str] = None, reset: bool = False, *args, **kwargs + ): """Initialize the app. Args: labels_path: Path to saved :class:`Labels` dataset. - - Returns: - None. + reset: If `True`, reset preferences to default (including window state). """ super(MainWindow, self).__init__(*args, **kwargs) @@ -137,14 +137,24 @@ def __init__(self, labels_path: Optional[str] = None, *args, **kwargs): self.state["filename"] = None self.state["show labels"] = True self.state["show edges"] = True - self.state["edge style"] = "Line" + self.state["edge style"] = prefs["edge style"] self.state["fit"] = False self.state["color predicted"] = prefs["color predicted"] + self.state["marker size"] = prefs["marker size"] + self.state["propagate track labels"] = prefs["propagate track labels"] + self.state.connect("marker size", self.plotFrame) self.release_checker = ReleaseChecker() self._initialize_gui() + if reset: + print("Reseting GUI state and preferences...") + prefs.reset_to_default() + elif len(prefs["window state"]) > 0: + print("Restoring GUI state...") + self.restoreState(prefs["window state"]) + if labels_path: self.loadProjectFile(labels_path) else: @@ -174,7 +184,17 @@ def event(self, e: QEvent) -> bool: return super().event(e) def closeEvent(self, event): - """Closes application window, prompting for saving as needed.""" + """Close application window, prompting for saving as needed.""" + # Save window state. + prefs["window state"] = self.saveState() + prefs["marker size"] = self.state["marker size"] + prefs["edge style"] = self.state["edge style"] + prefs["propagate track labels"] = self.state["propagate track labels"] + prefs["color predicted"] = self.state["color predicted"] + + # Save preferences. + prefs.save() + if not self.state["has_changes"]: # No unsaved changes, so accept event (close) event.accept() @@ -273,8 +293,9 @@ def connect_check(key): # add checkable menu item connected to state variable def add_menu_check_item(menu, key: str, name: str): - add_menu_item(menu, key, name, lambda: self.state.toggle(key)) + menu_item = add_menu_item(menu, key, name, lambda: self.state.toggle(key)) connect_check(key) + return menu_item # check and uncheck submenu items def _menu_check_single(menu, item_text): @@ -359,8 +380,6 @@ def add_submenu_choices(menu, title, options, key): fileMenu.addSeparator() add_menu_item(fileMenu, "save", "Save", self.commands.saveProject) add_menu_item(fileMenu, "save as", "Save As...", self.commands.saveProjectAs) - - fileMenu.addSeparator() add_menu_item( fileMenu, "export analysis", @@ -368,6 +387,11 @@ def add_submenu_choices(menu, title, options, key): self.commands.exportAnalysisFile, ) + fileMenu.addSeparator() + add_menu_item( + fileMenu, "reset prefs", "Reset preferences to defaults...", self.resetPrefs + ) + fileMenu.addSeparator() add_menu_item(fileMenu, "close", "Quit", self.close) @@ -489,10 +513,17 @@ def prev_vid(): key="edge style", ) + add_submenu_choices( + menu=viewMenu, + title="Node Marker Size", + options=(1, 4, 6, 8, 12), + key="marker size", + ) + add_submenu_choices( menu=viewMenu, title="Trail Length", - options=(0, 10, 20, 50, 100, 200, 500), + options=(0, 10, 50, 100, 250), key="trail_length", ) @@ -627,17 +658,28 @@ def new_instance_menu_action(): self.commands.deleteFrameLimitPredictions, ) - labelMenu.addSeparator() + # labelMenu.addSeparator() - self.track_menu = labelMenu.addMenu("Set Instance Track") + ### Tracks Menu ### + + tracksMenu = self.menuBar().addMenu("Tracks") + self.track_menu = tracksMenu.addMenu("Set Instance Track") + self.delete_tracks_menu = tracksMenu.addMenu("Delete Track") + self.delete_tracks_menu.setEnabled(False) + add_menu_check_item( + tracksMenu, "propagate track labels", "Propagate Track Labels" + ).setToolTip( + "If enabled, setting a track will also apply to subsequent instances of " + "the same track." + ) add_menu_item( - labelMenu, + tracksMenu, "transpose", "Transpose Instance Tracks", self.commands.transposeInstance, ) add_menu_item( - labelMenu, + tracksMenu, "delete track", "Delete Instance and Track", self.commands.deleteSelectedInstanceTrack, @@ -678,10 +720,21 @@ def new_instance_menu_action(): ) predictionMenu.addSeparator() + + training_package_menu = predictionMenu.addMenu("Export Training Package...") add_menu_item( - predictionMenu, + training_package_menu, + "export user labels package", + "Labeled frames", + self.commands.exportUserLabelsPackage, + ).setToolTip( + "Export user-labeled frames with image data into a single SLP file.\n\n" + "Use this for archiving a dataset with labeled frames only." + ) + add_menu_item( + training_package_menu, "export training package", - "Export Training Package...", + "Labeled + suggested frames (recommended)", self.commands.exportTrainingPackage, ).setToolTip( "Export user-labeled frames and suggested frames with image data into a " @@ -690,18 +743,9 @@ def new_instance_menu_action(): "unlabeled frames." ) add_menu_item( - predictionMenu, - "export user labels package", - "Export Package with User Labels Only...", - self.commands.exportUserLabelsPackage, - ).setToolTip( - "Export user-labeled frames with image data into a single SLP file.\n\n" - "Use this for archiving a dataset with labeled frames only." - ) - add_menu_item( - predictionMenu, + training_package_menu, "export full package", - "Export Package with All Labels...", + "Labeled + predicted + suggested frames", self.commands.exportFullPackage, ).setToolTip( "Export all frames (including predictions) and suggested frames with image " @@ -730,7 +774,7 @@ def new_instance_menu_action(): helpMenu.addSeparator() - helpMenu.addAction("Check for updates...", self.commands.checkForUpdates) + helpMenu.addAction("Latest versions:", self.commands.checkForUpdates) self.state["stable_version_menu"] = helpMenu.addAction( " Stable: N/A", self.commands.openStableVersion ) @@ -754,14 +798,16 @@ def wrapped_function(*args): return wrapped_function def _create_dock_windows(self): - """Create dock windows and connects them to gui.""" + """Create dock windows and connect them to GUI.""" def _make_dock(name, widgets=[], tab_with=None): dock = QDockWidget(name) + dock.setObjectName(name + "Dock") dock.setAllowedAreas(Qt.LeftDockWidgetArea | Qt.RightDockWidgetArea) dock_widget = QWidget() + dock_widget.setObjectName(name + "Widget") layout = QVBoxLayout() for widget in widgets: @@ -770,10 +816,6 @@ def _make_dock(name, widgets=[], tab_with=None): dock_widget.setLayout(layout) dock.setWidget(dock_widget) - key = f"hide {name.lower()} dock" - if key in prefs and prefs[key]: - dock.hide() - self.addDockWidget(Qt.RightDockWidgetArea, dock) self.viewMenu.addAction(dock.toggleViewAction()) @@ -921,6 +963,33 @@ def new_edge(): hb = QHBoxLayout() + _add_button( + hb, + "Add current frame", + self.process_events_then(self.commands.addCurrentFrameAsSuggestion), + "add current frame as suggestion", + ) + + _add_button( + hb, + "Remove", + self.process_events_then(self.commands.removeSuggestion), + "remove suggestion", + ) + + _add_button( + hb, + "Clear all", + self.process_events_then(self.commands.clearSuggestions), + "clear suggestions", + ) + + hbw = QWidget() + hbw.setLayout(hb) + suggestions_layout.addWidget(hbw) + + hb = QHBoxLayout() + _add_button( hb, "Prev", @@ -1032,6 +1101,7 @@ def _update_gui_state(self): # Update menus self.track_menu.setEnabled(has_selected_instance) + self.delete_tracks_menu.setEnabled(has_tracks) self._menu_actions["clear selection"].setEnabled(has_selected_instance) self._menu_actions["delete instance"].setEnabled(has_selected_instance) @@ -1125,13 +1195,14 @@ def _has_topic(topic_list): suggestion_status_text = "" suggestion_list = self.labels.get_suggestions() if suggestion_list: - suggestion_label_counts = [ - self.labels.instance_count(item.video, item.frame_idx) - for item in suggestion_list - ] - labeled_count = len(suggestion_list) - suggestion_label_counts.count(0) + labeled_count = 0 + for suggestion in suggestion_list: + lf = self.labels.get((suggestion.video, suggestion.frame_idx)) + if lf is not None and lf.has_user_instances: + labeled_count += 1 + prc = (labeled_count / len(suggestion_list)) * 100 suggestion_status_text = ( - f"{labeled_count}/{len(suggestion_list)} labeled" + f"{labeled_count}/{len(suggestion_list)} labeled ({prc:.1f}%)" ) self.suggested_count_label.setText(suggestion_status_text) @@ -1177,13 +1248,17 @@ def updateStatusMessage(self, message: Optional[str] = None): spacer = " " if message is None: - message = f"Frame: {frame_idx+1:,}/{len(current_video):,}" + message = "" + if len(self.labels.videos) > 1: + message += f"Video {self.labels.videos.index(current_video)+1}/" + message += f"{len(self.labels.videos)}" + message += spacer + + message += f"Frame: {frame_idx+1:,}/{len(current_video):,}" if self.player.seekbar.hasSelection(): start, end = self.state["frame_range"] - message += f" (selection: {start+1:,}-{end:,})" - - if len(self.labels.videos) > 1: - message += f" of video {self.labels.videos.index(current_video)+1}" + message += spacer + message += f"Selection: {start+1:,}-{end:,} ({end-start+1:,} frames)" message += f"{spacer}Labeled Frames: " if current_video is not None: @@ -1212,6 +1287,15 @@ def updateStatusMessage(self, message: Optional[str] = None): self.statusBar().showMessage(message) + def resetPrefs(self): + """Reset preferences to defaults.""" + prefs.reset_to_default() + msg = QMessageBox() + msg.setText( + "Note: Some preferences may not take effect until application is restarted." + ) + msg.exec_() + def loadProjectFile(self, filename: Optional[str] = None): """ Loads given labels file into GUI. @@ -1281,15 +1365,19 @@ def loadLabelsObject(self, labels: Labels, filename: Optional[str] = None): def _update_track_menu(self): """Updates track menu options.""" self.track_menu.clear() - for track in self.labels.tracks: + self.delete_tracks_menu.clear() + for track_ind, track in enumerate(self.labels.tracks): key_command = "" - if self.labels.tracks.index(track) < 9: + if track_ind < 9: key_command = Qt.CTRL + Qt.Key_0 + self.labels.tracks.index(track) + 1 self.track_menu.addAction( f"{track.name}", lambda x=track: self.commands.setInstanceTrack(x), key_command, ) + self.delete_tracks_menu.addAction( + f"{track.name}", lambda x=track: self.commands.deleteTrack(x) + ) self.track_menu.addAction( "New Track", self.commands.addTrack, Qt.CTRL + Qt.Key_0 ) @@ -1562,6 +1650,16 @@ def main(): const=True, default=False, ) + parser.add_argument( + "--reset", + help=( + "Reset GUI state and preferences. Use this flag if the GUI appears " + "incorrectly or fails to open." + ), + action="store_const", + const=True, + default=False, + ) args = parser.parse_args() @@ -1577,11 +1675,18 @@ def main(): app = QApplication([]) app.setApplicationName(f"SLEAP Label v{sleap.version.__version__}") - window = MainWindow(labels_path=args.labels_path) + window = MainWindow(labels_path=args.labels_path, reset=args.reset) window.showMaximized() - # if not args.labels_path: - # window.commands.openProject(first_open=True) + # Disable GPU in GUI process. This does not affect subprocesses. + sleap.use_cpu_only() + + # Print versions. + print() + print("Software versions:") + sleap.versions() + print() + print("Happy SLEAPing! :)") if args.profiling: import cProfile diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index 6230d93eb..503b9786a 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -29,6 +29,9 @@ class which inherits from `AppCommand` (or a more specialized class such as import attr import operator import os +import re +import sys +import subprocess from abc import ABC from enum import Enum @@ -40,7 +43,7 @@ class which inherits from `AppCommand` (or a more specialized class such as from PySide2 import QtCore, QtWidgets, QtGui -from PySide2.QtWidgets import QMessageBox +from PySide2.QtWidgets import QMessageBox, QProgressDialog from sleap.gui.dialogs.delete import DeleteDialog from sleap.skeleton import Skeleton @@ -326,6 +329,18 @@ def prevSuggestedFrame(self): """Goes to previous suggested frame.""" self.execute(GoPrevSuggestedFrame) + def addCurrentFrameAsSuggestion(self): + """Add current frame as a suggestion.""" + self.execute(AddSuggestion) + + def removeSuggestion(self): + """Remove the selected frame from suggestions.""" + self.execute(RemoveSuggestion) + + def clearSuggestions(self): + """Clear all suggestions.""" + self.execute(ClearSuggestions) + def nextTrackFrame(self): """Goes to next frame on which a track starts.""" self.execute(GoNextTrackFrame) @@ -486,6 +501,10 @@ def setInstanceTrack(self, new_track: "Track"): """Sets track for selected instance.""" self.execute(SetSelectedInstanceTrack, new_track=new_track) + def deleteTrack(self, track: "Track"): + """Delete a track and remove from all instances.""" + self.execute(DeleteTrack, track=track) + def setTrackName(self, track: "Track", name: str): """Sets name for track.""" self.execute(SetTrackName, track=track, name=name) @@ -681,18 +700,6 @@ def ask(context: "CommandContext", params: dict) -> bool: if len(filename) == 0: return False - # QtWidgets.QMessageBox( - # text="Please locate the directory with image files for this dataset." - # ).exec_() - # - # img_dir = FileDialog.openDir( - # None, - # directory=os.path.dirname(filename), - # caption="Open Image Directory" - # ) - # if len(img_dir) == 0: - # return False - params["filename"] = filename params["img_dir"] = os.path.dirname(filename) @@ -817,6 +824,22 @@ def ask(context: "CommandContext", params: dict) -> bool: return True +def get_new_version_filename(filename: str) -> str: + """Increment version number in filenames that end in `.v###.slp`.""" + p = PurePath(filename) + + match = re.match(".*\\.v(\\d+)\\.slp", filename) + if match is not None: + old_ver = match.group(1) + new_ver = str(int(old_ver) + 1).zfill(len(old_ver)) + filename = filename.replace(f".v{old_ver}.slp", f".v{new_ver}.slp") + filename = str(PurePath(filename)) + else: + filename = str(p.with_name(f"{p.stem} copy{p.suffix}")) + + return filename + + class SaveProjectAs(AppCommand): @staticmethod def _try_save(context, labels: Labels, filename: str): @@ -849,15 +872,12 @@ def do_action(cls, context: CommandContext, params: dict): @staticmethod def ask(context: CommandContext, params: dict) -> bool: - default_name = context.state["filename"] or "untitled" - p = PurePath(default_name) - default_name = str(p.with_name(f"{p.stem} copy{p.suffix}")) - - filters = [ - "SLEAP HDF5 dataset (*.slp)", - "SLEAP JSON dataset (*.json)", - "Compressed JSON (*.zip)", - ] + default_name = context.state["filename"] + if default_name: + default_name = get_new_version_filename(default_name) + else: + default_name = "labels.v000.slp" + filters = ["SLEAP labels dataset (*.slp)"] filename, selected_filter = FileDialog.save( context.app, caption="Save As...", @@ -881,7 +901,7 @@ def do_action(cls, context: CommandContext, params: dict): @staticmethod def ask(context: CommandContext, params: dict) -> bool: - default_name = context.state["filename"] or "untitled" + default_name = context.state["filename"] or "labels" p = PurePath(default_name) default_name = str(p.with_name(f"{p.stem}.analysis.h5")) @@ -1009,25 +1029,65 @@ def ask(context: CommandContext, params: dict) -> bool: return True +def export_dataset_gui( + labels: Labels, filename: str, all_labeled: bool = False, suggested: bool = False +) -> str: + """Export dataset with image data and display progress GUI dialog. + + Args: + labels: `sleap.Labels` dataset to export. + filename: Output filename. Should end in `.pkg.slp`. + all_labeled: If `True`, export all labeled frames, including frames with no user + instances. + suggested: If `True`, include image data for suggested frames. + """ + win = QProgressDialog("Exporting dataset with frame images...", "Cancel", 0, 1) + + def update_progress(n, n_total): + if win.wasCanceled(): + return False + win.setMaximum(n_total) + win.setValue(n) + win.setLabelText( + "Exporting dataset with frame images...
" + f"{n}/{n_total} ({(n/n_total)*100:.1f}%)" + ) + QtWidgets.QApplication.instance().processEvents() + return True + + Labels.save_file( + labels, + filename, + default_suffix="slp", + save_frame_data=True, + all_labeled=all_labeled, + suggested=suggested, + progress_callback=update_progress, + ) + + if win.wasCanceled(): + # Delete output if saving was canceled. + os.remove(filename) + return "canceled" + + win.hide() + + return filename + + class ExportDatasetWithImages(AppCommand): all_labeled = False suggested = False @classmethod def do_action(cls, context: CommandContext, params: dict): - win = MessageDialog("Exporting dataset with frame images...", context.app) - - Labels.save_file( - context.state["labels"], - params["filename"], - default_suffix="slp", - save_frame_data=True, + export_dataset_gui( + labels=context.state["labels"], + filename=params["filename"], all_labeled=cls.all_labeled, suggested=cls.suggested, ) - win.hide() - @staticmethod def ask(context: CommandContext, params: dict) -> bool: filters = [ @@ -1809,11 +1869,12 @@ def do_action(context: CommandContext, params: dict): if selected_instance is None: return - old_track = selected_instance.track - # When setting track for an instance that doesn't already have a track set, # just set for selected instance. - if old_track is None: + if ( + selected_instance.track is None + or not context.state["propagate track labels"] + ): # Move anything already in the new track out of it new_track_instances = context.labels.find_track_occupancy( video=context.state["video"], @@ -1833,6 +1894,7 @@ def do_action(context: CommandContext, params: dict): # When the instance does already have a track, then we want to update # the track for a range of frames. else: + old_track = selected_instance.track # Determine range that should be affected if context.state["has_frame_range"]: @@ -1854,6 +1916,15 @@ def do_action(context: CommandContext, params: dict): context.state["instance"] = selected_instance +class DeleteTrack(EditCommand): + topics = [UpdateTopic.tracks] + + @staticmethod + def do_action(context: CommandContext, params: dict): + track = params["track"] + context.labels.remove_track(track) + + class SetTrackName(EditCommand): topics = [UpdateTopic.tracks, UpdateTopic.frame] @@ -1870,7 +1941,11 @@ class GenerateSuggestions(EditCommand): @classmethod def do_action(cls, context: CommandContext, params: dict): - win = MessageDialog("Generating list of suggested frames...", context.app) + # TODO: Progress bar + win = MessageDialog( + "Generating list of suggested frames... " "This may take a few minutes.", + context.app, + ) new_suggestions = VideoFrameSuggestions.suggest( labels=context.labels, params=params @@ -1881,6 +1956,56 @@ def do_action(cls, context: CommandContext, params: dict): win.hide() +class AddSuggestion(EditCommand): + topics = [UpdateTopic.suggestions] + + @classmethod + def do_action(cls, context: CommandContext, params: dict): + context.labels.add_suggestion( + context.state["video"], context.state["frame_idx"] + ) + context.app.suggestionsTable.selectRow(len(context.labels) - 1) + + +class RemoveSuggestion(EditCommand): + topics = [UpdateTopic.suggestions] + + @classmethod + def do_action(cls, context: CommandContext, params: dict): + selected_frame = context.app.suggestionsTable.getSelectedRowItem() + if selected_frame is not None: + context.labels.remove_suggestion( + selected_frame.video, selected_frame.frame_idx + ) + + +class ClearSuggestions(EditCommand): + topics = [UpdateTopic.suggestions] + + @staticmethod + def ask(context: CommandContext, params: dict) -> bool: + if len(context.labels.suggestions) == 0: + return False + + # Warn that suggestions will be cleared + + response = QMessageBox.warning( + context.app, + "Clearing all suggestions", + "Are you sure you want to remove all suggestions from the project?", + QMessageBox.Yes, + QMessageBox.No, + ) + if response == QMessageBox.No: + return False + + return True + + @classmethod + def do_action(cls, context: CommandContext, params: dict): + context.labels.clear_suggestions() + + class MergeProject(EditCommand): topics = [UpdateTopic.all] @@ -1942,7 +2067,8 @@ def do_action(cls, context: CommandContext, params: dict): if context.state["labeled_frame"] is None: return - # FIXME: filter by skeleton type + if len(context.state["skeleton"]) == 0: + return from_predicted = copy_instance from_prev_frame = False @@ -2135,6 +2261,7 @@ def add_best_nodes(cls, context, instance, visible): @classmethod def add_random_nodes(cls, context, instance, visible): + # TODO: Move this to Instance so we can do this on-demand # the rect that's currently visible in the window view in_view_rect = context.app.player.getVisibleRect() @@ -2257,10 +2384,19 @@ def do_action(cls, context: CommandContext, params: dict): context.labels.add_instance(context.state["labeled_frame"], new_instance) +def open_website(url: str): + """Open website in default browser. + + Args: + url: URL to open. + """ + QtGui.QDesktopServices.openUrl(QtCore.QUrl(url)) + + class OpenWebsite(AppCommand): @staticmethod def do_action(context: CommandContext, params: dict): - QtGui.QDesktopServices.openUrl(QtCore.QUrl(params["url"])) + open_website(params["url"]) class CheckForUpdates(AppCommand): @@ -2294,3 +2430,30 @@ def do_action(context: CommandContext, params: dict): rls = context.app.release_checker.latest_prerelease if rls is not None: context.openWebsite(rls.url) + + +def copy_to_clipboard(text: str): + """Copy a string to the system clipboard. + + Args: + text: String to copy to clipboard. + """ + clipboard = QtWidgets.QApplication.clipboard() + clipboard.clear(mode=clipboard.Clipboard) + clipboard.setText(text, mode=clipboard.Clipboard) + + +def open_file(filename: str): + """Opens file in native system file browser or registered application. + + Args: + filename: Path to file or folder. + + Notes: + Source: https://stackoverflow.com/a/16204023 + """ + if sys.platform == "win32": + os.startfile(filename) + else: + opener = "open" if sys.platform == "darwin" else "xdg-open" + subprocess.call([opener, filename]) diff --git a/sleap/gui/dataviews.py b/sleap/gui/dataviews.py index 2834468c1..4106575d3 100644 --- a/sleap/gui/dataviews.py +++ b/sleap/gui/dataviews.py @@ -303,8 +303,7 @@ def selectionChanged(self, new, old): self.state[self.name_prefix + self.row_name] = item def activateSelected(self, *args): - """ - Activates item currently selected in table. + """Activate item currently selected in table. "Activate" means that the relevant :py:class:`GuiState` state variable is set to the currently selected item. @@ -313,8 +312,7 @@ def activateSelected(self, *args): self.state[self.row_name] = self.getSelectedRowItem() def selectRowItem(self, item: Any): - """ - Selects row corresponding to item. + """Select row corresponding to item. If the table model converts items to dictionaries (using `item_to_data` method), then `item` argument should be the original item, not the @@ -330,9 +328,12 @@ def selectRowItem(self, item: Any): if self.row_name: self.state[self.name_prefix + self.row_name] = item + def selectRow(self, idx: int): + """Select row corresponding to index.""" + self.selectRowItem(self.model().original_items[idx]) + def getSelectedRowItem(self) -> Any: - """ - Returns item corresponding to currently selected row. + """Return item corresponding to currently selected row. Note that if the table model converts items to dictionaries (using `item_to_data` method), then returned item will be the original item, @@ -461,7 +462,10 @@ def item_to_data(self, obj, item): item_dict["SuggestionFrame"] = item - video_string = f"{labels.videos.index(item.video)+1}: {os.path.basename(item.video.filename)}" + video_string = ( + f"{labels.videos.index(item.video)+1}: " + f"{os.path.basename(item.video.filename)}" + ) item_dict["group"] = str(item.group + 1) if item.group is not None else "" item_dict["group_int"] = item.group if item.group is not None else -1 @@ -469,7 +473,8 @@ def item_to_data(self, obj, item): item_dict["frame"] = int(item.frame_idx) + 1 # start at frame 1 rather than 0 # show how many labeled instances are in this frame - val = labels.instance_count(item.video, item.frame_idx) + lf = labels.get((item.video, item.frame_idx)) + val = 0 if lf is None else len(lf.user_instances) val = str(val) if val > 0 else "" item_dict["labeled"] = val diff --git a/sleap/gui/dialogs/shortcuts.py b/sleap/gui/dialogs/shortcuts.py index f4e912321..89f700019 100644 --- a/sleap/gui/dialogs/shortcuts.py +++ b/sleap/gui/dialogs/shortcuts.py @@ -28,11 +28,26 @@ def accept(self): for action, widget in self.key_widgets.items(): self.shortcuts[action] = widget.keySequence().toString() self.shortcuts.save() + self.info_msg() + super(ShortcutDialog, self).accept() + + def info_msg(self): + """Display information about changes.""" + msg = QtWidgets.QMessageBox() + msg.setText( + "Application must be restarted before changes to keyboard shortcuts take " + "effect." + ) + msg.exec_() + def reset(self): + """Reset to defaults.""" + self.shortcuts.reset_to_default() + self.info_msg() super(ShortcutDialog, self).accept() def load_shortcuts(self): - """Loads shortcuts object.""" + """Load shortcuts object.""" self.shortcuts = Shortcuts() def make_form(self): @@ -41,26 +56,28 @@ def make_form(self): layout = QtWidgets.QVBoxLayout() layout.addWidget(self.make_shortcuts_widget()) - layout.addWidget( - QtWidgets.QLabel( - "Any changes to keyboard shortcuts will not take effect " - "until you quit and re-open the application." - ) - ) layout.addWidget(self.make_buttons_widget()) self.setLayout(layout) def make_buttons_widget(self) -> QtWidgets.QDialogButtonBox: - """Makes the form buttons.""" - buttons = QtWidgets.QDialogButtonBox( - QtWidgets.QDialogButtonBox.Ok | QtWidgets.QDialogButtonBox.Cancel - ) - buttons.accepted.connect(self.accept) - buttons.rejected.connect(self.reject) + """Make the form buttons.""" + buttons = QtWidgets.QDialogButtonBox() + save = QtWidgets.QPushButton("Save") + save.clicked.connect(self.accept) + buttons.addButton(save, QtWidgets.QDialogButtonBox.AcceptRole) + + cancel = QtWidgets.QPushButton("Cancel") + cancel.clicked.connect(self.reject) + buttons.addButton(cancel, QtWidgets.QDialogButtonBox.RejectRole) + + reset = QtWidgets.QPushButton("Reset to defaults") + reset.clicked.connect(self.reset) + buttons.addButton(reset, QtWidgets.QDialogButtonBox.ActionRole) + return buttons def make_shortcuts_widget(self) -> QtWidgets.QWidget: - """Makes the widget will fields for all shortcuts.""" + """Make the widget will fields for all shortcuts.""" shortcuts = self.shortcuts widget = QtWidgets.QWidget() @@ -75,7 +92,7 @@ def make_shortcuts_widget(self) -> QtWidgets.QWidget: return widget def make_column_widget(self, shortcuts: List) -> QtWidgets.QWidget: - """Makes a single column of shortcut fields. + """Make a single column of shortcut fields. Args: shortcuts: The list of shortcuts to include in this column. diff --git a/sleap/gui/learning/dialog.py b/sleap/gui/learning/dialog.py index c5c3361d9..12f4b74e9 100644 --- a/sleap/gui/learning/dialog.py +++ b/sleap/gui/learning/dialog.py @@ -3,22 +3,22 @@ """ import cattr import os +import shutil +import atexit +import tempfile +from pathlib import Path -import networkx as nx - +import sleap from sleap import Labels, Video from sleap.gui.dialogs.filedialog import FileDialog from sleap.gui.dialogs.formbuilder import YamlFormWidget from sleap.gui.learning import runners, scopedkeydict, configs, datagen, receptivefield -from typing import Dict, List, Optional, Text +from typing import Dict, List, Optional, Text, Optional from PySide2 import QtWidgets, QtCore -# Debug option to skip the training run -SKIP_TRAINING = False - # List of fields which should show list of skeleton nodes NODE_LIST_FIELDS = [ "data.instance_cropping.center_on_part", @@ -84,13 +84,20 @@ def __init__( # Layout for buttons buttons = QtWidgets.QDialogButtonBox() - self.cancel_button = buttons.addButton(QtWidgets.QDialogButtonBox.Cancel) self.save_button = buttons.addButton( - "Save configuration files...", QtWidgets.QDialogButtonBox.ApplyRole + "Save configuration files...", QtWidgets.QDialogButtonBox.ActionRole ) - self.run_button = buttons.addButton( - "Run", QtWidgets.QDialogButtonBox.AcceptRole + self.export_button = buttons.addButton( + "Export training job package...", QtWidgets.QDialogButtonBox.ActionRole + ) + self.cancel_button = buttons.addButton(QtWidgets.QDialogButtonBox.Cancel) + self.run_button = buttons.addButton("Run", QtWidgets.QDialogButtonBox.ApplyRole) + + self.save_button.setToolTip("Save scripts and configuration to run pipeline.") + self.export_button.setToolTip( + "Export data, configuration, and scripts for remote training and inference." ) + self.run_button.setToolTip("Run pipeline locally (GPU recommended).") buttons_layout = QtWidgets.QHBoxLayout() buttons_layout.addWidget(buttons, alignment=QtCore.Qt.AlignTop) @@ -132,9 +139,10 @@ def __init__( self.connect_signals() # Connect actions for buttons - buttons.accepted.connect(self.run) - buttons.rejected.connect(self.reject) - buttons.clicked.connect(self.on_button_click) + self.save_button.clicked.connect(self.save) + self.export_button.clicked.connect(self.export_package) + self.cancel_button.clicked.connect(self.reject) + self.run_button.clicked.connect(self.run) # Connect button for previewing the training data if "_view_datagen" in self.pipeline_form_widget.buttons: @@ -450,8 +458,6 @@ def get_items_for_inference(self, pipeline_form_data) -> runners.ItemsForInferen frame_count = self.count_total_frames_for_selection_option(frame_selection) if predict_frames_choice.startswith("user"): - # For inference on user labeled frames, we'll have a single - # inference item. items_for_inference = runners.ItemsForInference( items=[ runners.DatasetItemForInference( @@ -461,8 +467,6 @@ def get_items_for_inference(self, pipeline_form_data) -> runners.ItemsForInferen total_frame_count=frame_count, ) elif predict_frames_choice.startswith("suggested"): - # For inference on all suggested frames, we'll have a single - # inference item. items_for_inference = runners.ItemsForInference( items=[ runners.DatasetItemForInference( @@ -472,7 +476,6 @@ def get_items_for_inference(self, pipeline_form_data) -> runners.ItemsForInferen total_frame_count=frame_count, ) else: - # Otherwise, make an inference item for each video with list of frames. items_for_inference = runners.ItemsForInference.from_video_frames_dict( frame_selection, total_frame_count=frame_count ) @@ -491,24 +494,37 @@ def _validate_pipeline(self): ] if untrained: can_run = False - message = f"Cannot run inference with untrained models ({', '.join(untrained)})." + message = ( + "Cannot run inference with untrained models " + f"({', '.join(untrained)})." + ) # Make sure skeleton will be valid for bottom-up inference. if self.mode == "training" and self.current_pipeline == "bottom-up": skeleton = self.labels.skeletons[0] if not skeleton.is_arborescence: - message += "Cannot run bottom-up pipeline when skeleton is not an arborescence." + message += ( + "Cannot run bottom-up pipeline when skeleton is not an " + "arborescence." + ) root_names = [n.name for n in skeleton.root_nodes] over_max_in_degree = [n.name for n in skeleton.in_degree_over_one] cycles = skeleton.cycles if len(root_names) > 1: - message += f" There are multiple root nodes: {', '.join(root_names)} (there should be exactly one node which is not a target)." + message += ( + f" There are multiple root nodes: {', '.join(root_names)} " + "(there should be exactly one node which is not a target)." + ) if over_max_in_degree: - message += f" There are nodes which are target in multiple edges: {', '.join(over_max_in_degree)} (maximum in-degree should be 1)." + message += ( + " There are nodes which are target in multiple edges: " + f"{', '.join(over_max_in_degree)} (maximum in-degree should be " + "1)." + ) if cycles: cycle_strings = [] @@ -577,27 +593,128 @@ def run(self): win.setWindowTitle("Inference Results") win.exec_() - def save(self): - models_dir = os.path.join(os.path.dirname(self.labels_filename), "/models") - output_dir = FileDialog.openDir( - None, directory=models_dir, caption="Select directory to save scripts" - ) + def save( + self, output_dir: Optional[str] = None, labels_filename: Optional[str] = None + ): + """Save scripts and configs to run pipeline.""" + if output_dir is None: + models_dir = os.path.join(os.path.dirname(self.labels_filename), "/models") + output_dir = FileDialog.openDir( + None, directory=models_dir, caption="Select directory to save scripts" + ) - if not output_dir: - return + if not output_dir: + return pipeline_form_data = self.pipeline_form_widget.get_form_data() items_for_inference = self.get_items_for_inference(pipeline_form_data) config_info_list = self.get_every_head_config_data(pipeline_form_data) + if labels_filename is None: + labels_filename = self.labels_filename + runners.write_pipeline_files( output_dir=output_dir, - labels_filename=self.labels_filename, + labels_filename=labels_filename, config_info_list=config_info_list, inference_params=pipeline_form_data, items_for_inference=items_for_inference, ) + def export_package(self, output_path: Optional[str] = None, gui: bool = True): + """Export training job package.""" + # TODO: Warn if self.mode != "training"? + if output_path is None: + # Prompt for output path. + output_path, _ = FileDialog.save( + caption="Export Training Job Package...", + dir=f"{self.labels_filename}.training_job.zip", + filter="Training Job Package (*.zip)", + ) + if len(output_path) == 0: + return + + # Create temp dir before packaging. + tmp_dir = tempfile.TemporaryDirectory() + + # Remove the temp dir when program exits in case something goes wrong. + # atexit.register(shutil.rmtree, tmp_dir.name, ignore_errors=True) + + # Check if we need to include suggestions. + include_suggestions = False + items_for_inference = self.get_items_for_inference( + self.pipeline_form_widget.get_form_data() + ) + for item in items_for_inference.items: + if ( + isinstance(item, runners.DatasetItemForInference) + and item.frame_filter == "suggested" + ): + include_suggestions = True + + # Save dataset with images. + labels_pkg_filename = str( + Path(self.labels_filename).with_suffix(".pkg.slp").name + ) + if gui: + ret = sleap.gui.commands.export_dataset_gui( + self.labels, + tmp_dir.name + "/" + labels_pkg_filename, + all_labeled=False, + suggested=include_suggestions, + ) + if ret == "canceled": + # Quit if user canceled during export. + tmp_dir.cleanup() + return + else: + self.labels.save( + tmp_dir.name + "/" + labels_pkg_filename, + with_images=True, + embed_all_labeled=False, + embed_suggested=include_suggestions, + ) + + # Save config and scripts. + self.save(tmp_dir.name, labels_filename=labels_pkg_filename) + + # Package everything. + shutil.make_archive( + base_name=str(Path(output_path).with_suffix("")), + format="zip", + root_dir=tmp_dir.name, + ) + + msg = f"Saved training job package to: {output_path}" + print(msg) + + # Close training editor. + self.accept() + + if gui: + msgBox = QtWidgets.QMessageBox(text=f"Created training job package:") + msgBox.setDetailedText(output_path) + msgBox.setWindowTitle("Training Job Package") + okButton = msgBox.addButton(QtWidgets.QMessageBox.Ok) + openFolderButton = msgBox.addButton( + "Open containing folder", QtWidgets.QMessageBox.ActionRole + ) + colabButton = msgBox.addButton( + "Go to Colab", QtWidgets.QMessageBox.ActionRole + ) + msgBox.exec_() + + if msgBox.clickedButton() == openFolderButton: + sleap.gui.commands.open_file(str(Path(output_path).resolve().parent)) + elif msgBox.clickedButton() == colabButton: + # TODO: Update this to more workflow-tailored notebook. + sleap.gui.commands.copy_to_clipboard(output_path) + sleap.gui.commands.open_website( + "https://colab.research.google.com/github/murthylab/sleap-notebooks/blob/master/Training_and_inference_using_Google_Drive.ipynb" + ) + + tmp_dir.cleanup() + class TrainingPipelineWidget(QtWidgets.QWidget): """ @@ -808,7 +925,7 @@ def __init__( @classmethod def from_trained_config(cls, cfg_info: configs.ConfigFileInfo): - widget = cls(require_trained=True) + widget = cls(require_trained=True, head=cfg_info.head_name) widget.acceptSelectedConfigInfo(cfg_info) widget.setWindowTitle(cfg_info.path_dir) return widget diff --git a/sleap/gui/learning/runners.py b/sleap/gui/learning/runners.py index 48a17c37b..1839c1522 100644 --- a/sleap/gui/learning/runners.py +++ b/sleap/gui/learning/runners.py @@ -4,9 +4,12 @@ import abc import attr import os -import subprocess as sub +import psutil +import json +import subprocess import tempfile import time +import shutil from datetime import datetime from typing import Any, Callable, Dict, List, Optional, Text, Tuple @@ -17,7 +20,17 @@ from sleap.nn import training from sleap.nn.config import TrainingJobConfig -SKIP_TRAINING = False + +def kill_process(pid: int): + """Force kill a running process and any child processes. + + Args: + proc: A process ID. + """ + proc_ = psutil.Process(pid) + for subproc_ in proc_.children(recursive=True): + subproc_.kill() + proc_.kill() @attr.s(auto_attribs=True) @@ -151,7 +164,10 @@ class InferenceTask: results: List[LabeledFrame] = attr.ib(default=attr.Factory(list)) def make_predict_cli_call( - self, item_for_inference: ItemForInference, output_path: Optional[str] = None + self, + item_for_inference: ItemForInference, + output_path: Optional[str] = None, + gui: bool = True, ) -> List[Text]: """Makes list of CLI arguments needed for running inference.""" cli_args = ["sleap-track"] @@ -222,6 +238,11 @@ def make_predict_cli_call( cli_args.extend(("-o", output_path)) + if gui: + cli_args.extend(("--verbosity", "json")) + + cli_args.extend(("--no-empty-frames",)) + return cli_args, output_path def predict_subprocess( @@ -229,23 +250,43 @@ def predict_subprocess( item_for_inference: ItemForInference, append_results: bool = False, waiting_callback: Optional[Callable] = None, + gui: bool = True, ) -> Tuple[Text, bool]: """Runs inference in a subprocess.""" - cli_args, output_path = self.make_predict_cli_call(item_for_inference) + cli_args, output_path = self.make_predict_cli_call(item_for_inference, gui=gui) print("Command line call:") - print(" \\\n".join(cli_args)) + print(" ".join(cli_args)) print() - with sub.Popen(cli_args) as proc: + # Run inference CLI capturing output. + with subprocess.Popen(cli_args, stdout=subprocess.PIPE) as proc: + + # Poll until finished. while proc.poll() is None: - if waiting_callback is not None: - if waiting_callback() == -1: - # -1 signals user cancellation - return "", False + # Read line. + line = proc.stdout.readline() + line = line.decode().rstrip() + + if line.startswith("{"): + # Parse line. + line_data = json.loads(line) + else: + # Pass through non-json output. + print(line) + line_data = {} - time.sleep(0.1) + if waiting_callback is not None: + # Pass line data to callback. + ret = waiting_callback(**line_data) + + if ret == "cancel": + # Stop if callback returned cancel signal. + kill_process(proc.pid) + print(f"Killed PID: {proc.pid}") + return "", "canceled" + time.sleep(0.05) print(f"Process return code: {proc.returncode}") success = proc.returncode == 0 @@ -255,7 +296,9 @@ def predict_subprocess( new_inference_labels = Labels.load_file(output_path, match_to=self.labels) self.results.extend(new_inference_labels.labeled_frames) - return output_path, success + # Return "success" or return code if failed. + ret = "success" if success else proc.returncode + return output_path, ret def merge_results(self): """Merges result frames into labels dataset.""" @@ -498,7 +541,8 @@ def run_gui_training( os.path.dirname(labels_filename), "models" ) training.setup_new_run_folder( - job.outputs, base_run_name=f"{model_type}.{len(labels)}" + job.outputs, + base_run_name=f"{model_type}.n={len(labels.user_labeled_frames)}", ) if gui: @@ -518,9 +562,11 @@ def run_gui_training( def waiting(): if gui: QtWidgets.QApplication.instance().processEvents() + if win.canceled: + return "cancel" # Run training - trained_job_path, success = train_subprocess( + trained_job_path, ret = train_subprocess( job_config=job, labels_filename=labels_filename, video_paths=video_path_list, @@ -528,10 +574,17 @@ def waiting(): save_viz=save_viz, ) - if success: + if ret == "success": # get the path to the resulting TrainingJob file trained_job_paths[model_type] = trained_job_path print(f"Finished training {str(model_type)}.") + elif ret == "canceled": + if gui: + win.close() + print("Deleting canceled run data:", trained_job_path) + shutil.rmtree(trained_job_path, ignore_errors=True) + trained_job_paths[model_type] = None + break else: if gui: win.close() @@ -568,48 +621,93 @@ def run_gui_inference( """ if gui: - # show message while running inference progress = QtWidgets.QProgressDialog( - f"Running inference on {len(items_for_inference)} videos...", + "Initializing...", "Cancel", 0, - len(items_for_inference), + 1, ) progress.show() QtWidgets.QApplication.instance().processEvents() # Make callback to process events while running inference - def waiting(done_count): + def waiting( + n_processed: Optional[int] = None, + n_total: Optional[int] = None, + elapsed: Optional[float] = None, + rate: Optional[float] = None, + eta: Optional[float] = None, + current_item: Optional[int] = None, + total_items: Optional[int] = None, + **kwargs, + ) -> str: if gui: QtWidgets.QApplication.instance().processEvents() - progress.setValue(done_count) + if n_total is not None: + progress.setMaximum(n_total) + if n_processed is not None: + progress.setValue(n_processed) + + msg = "Predicting..." + # if current_item is not None and total_items is not None: + # msg += f" [{current_item + 1}/{total_items}]" + + if n_processed is not None and n_total is not None: + + prc = (n_processed / n_total) * 100 + msg = f"Predicted: {n_processed+1}/{n_total} ({prc:.1f}%)" + + # Show time elapsed? + if rate is not None and eta is not None: + eta_mins, eta_secs = divmod(eta, 60) + if eta_mins > 60: + eta_hours, eta_mins = divmod(eta, 60) + eta_str = f"{int(eta_hours)}:{int(eta_mins):02}:{int(eta_secs):02}" + else: + eta_str = f"{int(eta_mins)}:{int(eta_secs):02}" + msg += f"
ETA: {eta_str} FPS: {rate:.1f}" + + msg = msg.replace(" ", " ") + + progress.setLabelText(msg) + QtWidgets.QApplication.instance().processEvents() + if progress.wasCanceled(): - return -1 + return "cancel" for i, item_for_inference in enumerate(items_for_inference.items): - # Run inference for desired frames in this video - predictions_path, success = inference_task.predict_subprocess( - item_for_inference, append_results=True, waiting_callback=lambda: waiting(i) + + def waiting_item(**kwargs): + kwargs["current_item"] = i + kwargs["total_items"] = len(items_for_inference.items) + return waiting(**kwargs) + + # Run inference for desired frames in this video. + predictions_path, ret = inference_task.predict_subprocess( + item_for_inference, + append_results=True, + waiting_callback=waiting_item, + gui=gui, ) - if not success: + if gui: + progress.close() + + if ret == "success": + inference_task.merge_results() + return len(inference_task.results) + elif ret == "canceled": + return -1 + else: if gui: - progress.close() QtWidgets.QMessageBox( - text="An error occcured during inference. Your command line " - "terminal may have more information about the error." + text=( + "An error occcured during inference. Your command line " + "terminal may have more information about the error." + ) ).exec_() return -1 - inference_task.merge_results() - - # close message window - if gui: - progress.close() - - # return total_new_lf_count - return len(inference_task.results) - def train_subprocess( job_config: TrainingJobConfig, @@ -619,8 +717,6 @@ def train_subprocess( save_viz: bool = False, ): """Runs training inside subprocess.""" - - # run_name = job_config.outputs.run_name run_path = job_config.outputs.run_path success = False @@ -655,20 +751,26 @@ def train_subprocess( print(cli_args) - if not SKIP_TRAINING: - # Run training in a subprocess - with sub.Popen(cli_args) as proc: - - # Wait till training is done, calling a callback if given. - while proc.poll() is None: - if waiting_callback is not None: - if waiting_callback() == -1: - # -1 signals user cancellation - return "", False - time.sleep(0.1) - - success = proc.returncode == 0 + # Run training in a subprocess. + proc = subprocess.Popen(cli_args) + + # Wait till training is done, calling a callback if given. + while proc.poll() is None: + if waiting_callback is not None: + ret = waiting_callback() + if ret == "cancel": + print("Canceling training...") + kill_process(proc.pid) + print(f"Killed PID: {proc.pid}") + return run_path, "canceled" + time.sleep(0.1) + + # Check return code. + if proc.returncode == 0: + ret = "success" + else: + ret = proc.returncode print("Run Path:", run_path) - return run_path, success + return run_path, ret diff --git a/sleap/gui/learning/scopedkeydict.py b/sleap/gui/learning/scopedkeydict.py index 7d298d5e3..d10867ddc 100644 --- a/sleap/gui/learning/scopedkeydict.py +++ b/sleap/gui/learning/scopedkeydict.py @@ -119,6 +119,18 @@ def apply_cfg_transforms_to_key_val_dict(key_val_dict: dict): key_val_dict[f"model.backbone.{backbone_name}.output_stride"] = output_stride key_val_dict[f"model.backbone.{backbone_name}.max_stride"] = max_stride + # Convert random flip dropdown selection to config. + random_flip = key_val_dict.get( + "optimization.augmentation_config.random_flip", "none" + ) + if random_flip == "none": + key_val_dict["optimization.augmentation_config.random_flip"] = False + else: + key_val_dict["optimization.augmentation_config.random_flip"] = True + key_val_dict["optimization.augmentation_config.flip_horizontal"] = ( + random_flip == "horizontal" + ) + def find_backbone_name_from_key_val_dict(key_val_dict: dict): """Find the backbone model name from the config dictionary.""" diff --git a/sleap/gui/overlays/instance.py b/sleap/gui/overlays/instance.py index 998e15b6f..7ec8ae894 100644 --- a/sleap/gui/overlays/instance.py +++ b/sleap/gui/overlays/instance.py @@ -43,7 +43,9 @@ def add_to_scene(self, video, frame_idx): has_user = any((True for inst in instances if not hasattr(inst, "score"))) for instance in instances: - self.player.addInstance(instance=instance) + self.player.addInstance( + instance=instance, markerRadius=self.state.get("marker size", 4) + ) self.player.showLabels(self.state.get("show labels", default=True)) self.player.showEdges(self.state.get("show edges", default=True)) diff --git a/sleap/gui/shortcuts.py b/sleap/gui/shortcuts.py index 280081d9f..ec76e45a3 100644 --- a/sleap/gui/shortcuts.py +++ b/sleap/gui/shortcuts.py @@ -3,9 +3,7 @@ """ from typing import Dict, Union - from PySide2.QtGui import QKeySequence - from sleap import util @@ -100,6 +98,11 @@ def save(self): util.save_config_yaml("shortcuts.yaml", data) + def reset_to_default(self): + """Reset shortcuts to default and save.""" + self._shortcuts = util.get_config_yaml("shortcuts.yaml", get_defaults=True) + self.save() + def __getitem__(self, idx: Union[slice, int, str]) -> Union[str, Dict[str, str]]: """ Returns shortcut value, accessed by range, index, or key. diff --git a/sleap/instance.py b/sleap/instance.py index 34b1feb3e..214fc8f25 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -319,7 +319,7 @@ class Track: name: A name given to this track for identifying purposes. """ - spawned_on: int = attr.ib(converter=int) + spawned_on: int = attr.ib(default=0, converter=int) name: str = attr.ib(default="", converter=str) def matches(self, other: "Track"): @@ -341,10 +341,9 @@ def matches(self, other: "Track"): # that are created in post init so they are not serialized. -@attr.s(eq=False, order=False, slots=True) +@attr.s(eq=False, order=False, slots=True, repr=False, str=False) class Instance: - """ - The class :class:`Instance` represents a labelled instance of a skeleton. + """This class represents a labeled instance. Args: skeleton: The skeleton that this instance is associated with. @@ -375,8 +374,7 @@ class Instance: def _validate_from_predicted_( self, attribute, from_predicted: Optional["PredictedInstance"] ): - """ - Validation method called by attrs. + """Validation method called by attrs. Checks that from_predicted is None or :class:`PredictedInstance` @@ -391,13 +389,13 @@ def _validate_from_predicted_( """ if from_predicted is not None and type(from_predicted) != PredictedInstance: raise TypeError( - f"Instance.from_predicted type must be PredictedInstance (not {type(from_predicted)})" + f"Instance.from_predicted type must be PredictedInstance (not " + "{type(from_predicted)})" ) @_points.validator def _validate_all_points(self, attribute, points: Union[dict, PointArray]): - """ - Validation method called by attrs. + """Validation method called by attrs. Checks that all the _points defined for the skeleton are found in the skeleton. @@ -425,12 +423,12 @@ def _validate_all_points(self, attribute, points: Union[dict, PointArray]): elif isinstance(points, PointArray): if len(points) != len(self.skeleton.nodes): raise ValueError( - "PointArray does not have the same number of rows as skeleton nodes." + "PointArray does not have the same number of rows as skeleton " + "nodes." ) def __attrs_post_init__(self): - """ - Method called by attrs after __init__() + """Method called by attrs after __init__(). Initializes points if none were specified when creating object, caches list of nodes so what we can still find points in array @@ -441,12 +439,8 @@ def __attrs_post_init__(self): Raises: ValueError: If object has no `Skeleton`. - - Returns: - None """ - - if not self.skeleton: + if self.skeleton is None: raise ValueError("No skeleton set for Instance") # If the user did not pass a points list initialize a point array for future @@ -472,8 +466,7 @@ def __attrs_post_init__(self): def _points_dict_to_array( points: Dict[Union[str, Node], Point], parray: PointArray, skeleton: Skeleton ): - """ - Sets values in given :class:`PointsArray` from dictionary. + """Set values in given :class:`PointsArray` from dictionary. Args: points: The dictionary of points. Keys can be either node @@ -485,11 +478,7 @@ def _points_dict_to_array( Raises: ValueError: If dictionary keys are not either all strings or all :class:`Node`s. - - Returns: - None """ - # Check if the dict contains all strings is_string_dict = set(map(type, points)) == {str} @@ -522,8 +511,7 @@ def _points_dict_to_array( pass def _node_to_index(self, node: Union[str, Node]) -> int: - """ - Helper method to get the index of a node from its name. + """Helper method to get the index of a node from its name. Args: node: Node name or :class:`Node` object. @@ -537,8 +525,7 @@ def __getitem__( self, node: Union[List[Union[str, Node, int]], Union[str, Node, int], np.ndarray], ) -> Union[List[Point], Point, np.ndarray]: - """ - Get the Points associated with particular skeleton node(s). + """Get the Points associated with particular skeleton node(s). Args: node: A single node or list of nodes within the skeleton @@ -575,8 +562,7 @@ def __getitem__( return self._points[node] def __contains__(self, node: Union[str, Node, int]) -> bool: - """ - Whether this instance has a point with the specified node. + """Whether this instance has a point with the specified node. Args: node: Node name or :class:`Node` object. @@ -602,8 +588,7 @@ def __setitem__( node: Union[List[Union[str, Node, int]], Union[str, Node, int], np.ndarray], value: Union[List[Point], Point, np.ndarray], ): - """ - Set the point(s) for given node(s). + """Set the point(s) for given node(s). Args: node: Either node (by name or `Node`) or list of nodes. @@ -613,9 +598,6 @@ def __setitem__( IndexError: If lengths of lists don't match, or if exactly one of the inputs is a list. KeyError: If skeleton does not have (one of) the node(s). - - Returns: - None """ # Make sure node and value, if either are lists, are of compatible size if isinstance(node, (list, np.ndarray)): @@ -647,8 +629,7 @@ def __setitem__( self._points[node_idx] = value def __delitem__(self, node: Union[str, Node]): - """ - Delete node key and points associated with that node. + """Delete node key and points associated with that node. Args: node: Node name or :class:`Node` object. @@ -668,9 +649,24 @@ def __delitem__(self, node: Union[str, Node]): f"The underlying skeleton ({self.skeleton}) has no node '{node}'" ) + def __repr__(self) -> str: + """Return string representation of this object.""" + pts = [] + for node, pt in self.nodes_points: + pts.append(f"{node.name}: ({pt.x:.1f}, {pt.y:.1f})") + pts = ", ".join(pts) + + return ( + "Instance(" + f"video={self.video}, " + f"frame_idx={self.frame_idx}, " + f"points=[{pts}], " + f"track={self.track}" + ")" + ) + def matches(self, other: "Instance") -> bool: - """ - Whether two instances match by value. + """Whether two instances match by value. Checks the types, points, track, and frame index. @@ -703,9 +699,7 @@ def matches(self, other: "Instance") -> bool: @property def nodes(self) -> Tuple[Node, ...]: - """ - The tuple of nodes that have been labelled for this instance. - """ + """Return nodes that have been labelled for this instance.""" self._fix_array() return tuple( self._nodes[i] @@ -715,32 +709,23 @@ def nodes(self) -> Tuple[Node, ...]: @property def nodes_points(self) -> List[Tuple[Node, Point]]: - """ - The list of (node, point) tuples for all labelled points. - """ + """Return a list of (node, point) tuples for all labeled points.""" names_to_points = dict(zip(self.nodes, self.points)) return names_to_points.items() @property def points(self) -> Tuple[Point, ...]: - """ - The tuple of labelled points, in order they were labelled. - """ + """Return a tuple of labelled points, in the order they were labelled.""" self._fix_array() return tuple(point for point in self._points if not point.isnan()) def _fix_array(self): - """ - Fixes PointArray after nodes have been added or removed. + """Fix PointArray after nodes have been added or removed. This updates the PointArray as required by comparing the cached list of nodes to the nodes in the `Skeleton` object (which may have changed). - - Returns: - None """ - # Check if cached skeleton nodes are different than current nodes if self._nodes != self.skeleton.nodes: # Create new PointArray (or PredictedPointArray) @@ -759,8 +744,7 @@ def _fix_array(self): def get_points_array( self, copy: bool = True, invisible_as_nan: bool = False, full: bool = False ) -> Union[np.ndarray, np.recarray]: - """ - Return the instance's points in array form. + """Return the instance's points in array form. Args: copy: If True, the return a copy of the points array as an ndarray. @@ -803,14 +787,13 @@ def get_points_array( @property def points_array(self) -> np.ndarray: - """ - Nx2 array of x and y for visible points. + """Return array of x and y coordinates for visible points. - Row in array corresponds to order of points in skeleton. - Invisible points will have nans. + Row in array corresponds to order of points in skeleton. Invisible points will + be denoted by NaNs. Returns: - ndarray of visible point coordinates. + A numpy array of of shape `(n_nodes, 2)` point coordinates. """ return self.get_points_array(invisible_as_nan=True) @@ -820,13 +803,19 @@ def numpy(self) -> np.ndarray: Alias for `points_array`. Returns: - Array of shape (n_nodes, 2) of dtype float32 containing the coordinates of - the instance's nodes. Missing/not visible nodes will be replaced with NaNs. + Array of shape `(n_nodes, 2)` of dtype `float32` containing the coordinates + of the instance's nodes. Missing/not visible nodes will be replaced with + `NaN`. """ return self.points_array def transform_points(self, transformation_matrix): - """Applies transformation matrix to points.""" + """Apply affine transformation matrix to points in the instance. + + Args: + transformation_matrix: Affine transformation matrix as a numpy array of + shape `(3, 3)`. + """ points = self.get_points_array(copy=True, full=False, invisible_as_nan=False) if transformation_matrix.shape[1] == 3: @@ -843,36 +832,50 @@ def transform_points(self, transformation_matrix): @property def centroid(self) -> np.ndarray: - """Returns instance centroid as (x,y) numpy row vector.""" + """Return instance centroid as an array of `(x, y)` coordinates + + Notes: + This computes the centroid as the median of the visible points. + """ points = self.points_array centroid = np.nanmedian(points, axis=0) return centroid @property def bounding_box(self) -> np.ndarray: - """Returns the instance's containing bounding box in [y1, x1, y2, x2] format.""" + """Return bounding box containing all points in `[y1, x1, y2, x2]` format.""" points = self.points_array bbox = np.concatenate( [np.nanmin(points, axis=0)[::-1], np.nanmax(points, axis=0)[::-1]] ) return bbox + @property + def midpoint(self) -> np.ndarray: + """Return the center of the bounding box of the instance points.""" + y1, x1, y2, x2 = self.bounding_box + return np.array([(x2 - x1) / 2, (y2 - y1) / 2]) + @property def n_visible_points(self) -> int: - """Returns the count of points that are visible in this instance.""" + """Return the number of visible points in this instance.""" return sum(~np.isnan(self.points_array[:, 0])) - @property - def frame_idx(self) -> Optional[int]: - """ - Get the index of the frame that this instance was found on. + def __len__(self) -> int: + """Return the number of visible points in this instance.""" + return self.n_visible_points - This is a convenience method for Instance.frame.frame_idx. + @property + def video(self) -> Optional[Video]: + """Return the video of the labeled frame this instance is associated with.""" + if self.frame is None: + return None + else: + return self.frame.video - Returns: - The frame number this instance was found on, or None if the - instance is not associated with frame. - """ + @property + def frame_idx(self) -> Optional[int]: + """Return the index of the labeled frame this instance is associated with.""" if self.frame is None: return None else: @@ -880,22 +883,20 @@ def frame_idx(self) -> Optional[int]: @classmethod def from_pointsarray( - cls, - points: np.ndarray, - skeleton: Skeleton, - track: Optional[Track] = None, + cls, points: np.ndarray, skeleton: Skeleton, track: Optional[Track] = None ) -> "Instance": - """Create an instance from pointsarray. + """Create an instance from an array of points. Args: - points: A numpy array of shape (n_nodes, 2) and dtype float32 that contains - the points in (x, y) coordinates of each node. Missing nodes should be - represented as NaNs. - skeleton: A sleap.Skeleton instance with n_nodes nodes to associate with the - predicted instance. + points: A numpy array of shape `(n_nodes, 2)` and dtype `float32` that + contains the points in (x, y) coordinates of each node. Missing nodes + should be represented as `NaN`. + skeleton: A `sleap.Skeleton` instance with `n_nodes` nodes to associate with + the instance. + track: Optional `sleap.Track` object to associate with the instance. Returns: - A new Instance. + A new `Instance` object. """ predicted_points = dict() for point, node_name in zip(points, skeleton.node_names): @@ -906,8 +907,30 @@ def from_pointsarray( return cls(points=predicted_points, skeleton=skeleton, track=track) + @classmethod + def from_numpy( + cls, points: np.ndarray, skeleton: Skeleton, track: Optional[Track] = None + ) -> "Instance": + """Create an instance from a numpy array. -@attr.s(eq=False, order=False, slots=True) + Args: + points: A numpy array of shape `(n_nodes, 2)` and dtype `float32` that + contains the points in (x, y) coordinates of each node. Missing nodes + should be represented as `NaN`. + skeleton: A `sleap.Skeleton` instance with `n_nodes` nodes to associate with + the instance. + track: Optional `sleap.Track` object to associate with the instance. + + Returns: + A new `Instance` object. + + Notes: + This is an alias for `Instance.from_pointsarray()`. + """ + return cls.from_pointsarray(points, skeleton, track=track) + + +@attr.s(eq=False, order=False, slots=True, repr=False, str=False) class PredictedInstance(Instance): """ A predicted instance is an output of the inference procedure. @@ -929,35 +952,55 @@ def __attrs_post_init__(self): if self.from_predicted is not None: raise ValueError("PredictedInstance should not have from_predicted.") + def __repr__(self) -> str: + """Return string representation of this object.""" + pts = [] + for node, pt in self.nodes_points: + pts.append(f"{node.name}: ({pt.x:.1f}, {pt.y:.1f}, {pt.score:.2f})") + pts = ", ".join(pts) + + return ( + "PredictedInstance(" + f"video={self.video}, " + f"frame_idx={self.frame_idx}, " + f"points=[{pts}], " + f"score={self.score:.2f}, " + f"track={self.track}, " + f"tracking_score={self.tracking_score:.2f}" + ")" + ) + @property def points_and_scores_array(self) -> np.ndarray: - """ - (N, 3) array of (x, y, score) for predicted points. + """Return the instance points and scores as an array. - Row in arrow corresponds to order of points in skeleton. - Invisible points will have NaNs. + This will be a `(n_nodes, 3)` array of `(x, y, score)` for each predicted point. - Returns: - ndarray of visible point coordinates and scores. + Rows in the array correspond to the order of points in skeleton. Invisible + points will be represented as NaNs. """ pts = self.get_points_array(full=True, copy=True, invisible_as_nan=True) return pts[:, (0, 1, 4)] # (x, y, score) + @property + def scores(self) -> np.ndarray: + """Return point scores for each predicted node.""" + return self.points_and_scores_array[:, 2] + @classmethod def from_instance(cls, instance: Instance, score: float) -> "PredictedInstance": - """ - Create a :class:`PredictedInstance` from an :class:`Instance`. + """Create a `PredictedInstance` from an `Instance`. - The fields are copied in a shallow manner with the exception of - points. For each point in the instance a :class:`PredictedPoint` - is created with score set to default value. + The fields are copied in a shallow manner with the exception of points. For each + point in the instance a `PredictedPoint` is created with score set to default + value. Args: - instance: The Instance object to shallow copy data from. + instance: The `Instance` object to shallow copy data from. score: The score for this instance. Returns: - A PredictedInstance for the given Instance. + A `PredictedInstance` for the given `Instance`. """ kw_args = attr.asdict( instance, @@ -980,18 +1023,19 @@ def from_arrays( """Create a predicted instance from data arrays. Args: - points: A numpy array of shape (n_nodes, 2) and dtype float32 that contains - the points in (x, y) coordinates of each node. Missing nodes should be - represented as NaNs. - point_confidences: A numpy array of shape (n_nodes,) and dtype float32 that - contains the confidence/score of the points. + points: A numpy array of shape `(n_nodes, 2)` and dtype `float32` that + contains the points in `(x, y)` coordinates of each node. Missing nodes + should be represented as `NaN`. + point_confidences: A numpy array of shape `(n_nodes,)` and dtype `float32` + that contains the confidence/score of the points. instance_score: Scalar float representing the overall instance score, e.g., the PAF grouping score. skeleton: A sleap.Skeleton instance with n_nodes nodes to associate with the predicted instance. + track: Optional `sleap.Track` to associate with the instance. Returns: - A new PredictedInstance. + A new `PredictedInstance`. """ predicted_points = dict() for point, confidence, node_name in zip( @@ -1013,18 +1057,15 @@ def from_arrays( def make_instance_cattr() -> cattr.Converter: - """ - Create a cattr converter for Lists of Instances/PredictedInstances. + """Create a cattr converter for Lists of Instances/PredictedInstances. - This is required because cattrs doesn't automatically detect the - class when the attributes of one class are a subset of another. + This is required because cattrs doesn't automatically detect the class when the + attributes of one class are a subset of another. Returns: - A cattr converter with hooks registered for structuring and - unstructuring :class:`Instance` objects and - :class:`PredictedInstance`s. + A cattr converter with hooks registered for structuring and unstructuring + `Instance` and `PredictedInstance` objects. """ - converter = cattr.Converter() #### UNSTRUCTURE HOOKS @@ -1105,12 +1146,12 @@ def structure_point_array(x, t): @attr.s(auto_attribs=True, eq=False, repr=False, str=False) class LabeledFrame: - """ - Holds labeled data for a single frame of a video. + """Holds labeled data for a single frame of a video. Args: video: The :class:`Video` associated with this frame. frame_idx: The index of frame in video. + instances: List of instances associated with the frame. """ video: Video = attr.ib() @@ -1120,8 +1161,7 @@ class LabeledFrame: ) def __attrs_post_init__(self): - """ - Called by attrs. + """Called by attrs. Updates :attribute:`Instance.frame` for each instance associated with this :class:`LabeledFrame`. @@ -1132,19 +1172,19 @@ def __attrs_post_init__(self): instance.frame = self def __len__(self) -> int: - """Returns number of instances associated with frame.""" + """Return number of instances associated with frame.""" return len(self.instances) def __getitem__(self, index) -> Instance: - """Returns instance (retrieved by index).""" + """Return instance (retrieved by index).""" return self.instances.__getitem__(index) def index(self, value: Instance) -> int: - """Returns index of given :class:`Instance`.""" + """Return index of given :class:`Instance`.""" return self.instances.index(value) def __delitem__(self, index): - """Removes instance (by index) from frame.""" + """Remove instance (by index) from frame.""" value = self.instances.__getitem__(index) self.instances.__delitem__(index) @@ -1162,8 +1202,7 @@ def __repr__(self) -> str: ) def insert(self, index: int, value: Instance): - """ - Adds instance to frame. + """Add instance to frame. Args: index: The index in list of frame instances where we should @@ -1179,8 +1218,7 @@ def insert(self, index: int, value: Instance): value.frame = self def __setitem__(self, index, value: Instance): - """ - Sets nth instance in frame to the given instance. + """Set nth instance in frame to the given instance. Args: index: The index of instance to replace with new instance. @@ -1197,8 +1235,7 @@ def __setitem__(self, index, value: Instance): def find( self, track: Optional[Union[Track, int]] = -1, user: bool = False ) -> List[Instance]: - """ - Retrieves instances (if any) matching specifications. + """Retrieve instances (if any) matching specifications. Args: track: The :class:`Track` to match. Note that None will only @@ -1218,13 +1255,12 @@ def find( @property def instances(self) -> List[Instance]: - """Returns list of all instances associated with this frame.""" + """Return list of all instances associated with this frame.""" return self._instances @instances.setter def instances(self, instances: List[Instance]): - """ - Sets the list of instances associated with this frame. + """Set the list of instances associated with this frame. Updates the `frame` attribute on each instance to the :class:`LabeledFrame` which will contain the instance. @@ -1246,14 +1282,12 @@ def instances(self, instances: List[Instance]): @property def user_instances(self) -> List[Instance]: - """Returns list of user instances associated with this frame.""" - return [ - inst for inst in self._instances if not isinstance(inst, PredictedInstance) - ] + """Return list of user instances associated with this frame.""" + return [inst for inst in self._instances if type(inst) == Instance] @property def training_instances(self) -> List[Instance]: - """Returns list of user instances with points for training.""" + """Return list of user instances with points for training.""" return [ inst for inst in self._instances @@ -1262,23 +1296,22 @@ def training_instances(self) -> List[Instance]: @property def predicted_instances(self) -> List[PredictedInstance]: - """Returns list of predicted instances associated with frame.""" - return [inst for inst in self._instances if isinstance(inst, PredictedInstance)] + """Return list of predicted instances associated with frame.""" + return [inst for inst in self._instances if type(inst) == PredictedInstance] @property def has_user_instances(self) -> bool: - """Whether the frame contains any user instances.""" + """Return whether the frame contains any user instances.""" return len(self.user_instances) > 0 @property def has_predicted_instances(self) -> bool: - """Whether the frame contains any predicted instances.""" + """Return whether the frame contains any predicted instances.""" return len(self.predicted_instances) > 0 @property def unused_predictions(self) -> List[Instance]: - """ - Returns list of "unused" :class:`PredictedInstance` objects in frame. + """Return a list of "unused" :class:`PredictedInstance` objects in frame. This is all the :class:`PredictedInstance` objects which do not have a corresponding :class:`Instance` in the same track in frame. @@ -1316,8 +1349,7 @@ def unused_predictions(self) -> List[Instance]: @property def instances_to_show(self) -> List[Instance]: - """ - Return a list of instances to show in GUI for this frame. + """Return a list of instances to show in GUI for this frame. This list will not include any predicted instances for which there's a corresponding regular instance. @@ -1342,7 +1374,7 @@ def instances_to_show(self) -> List[Instance]: def merge_frames( labeled_frames: List["LabeledFrame"], video: "Video", remove_redundant=True ) -> List["LabeledFrame"]: - """Merged LabeledFrames for same video and frame index. + """Return merged LabeledFrames for same video and frame index. Args: labeled_frames: List of :class:`LabeledFrame` objects to merge. @@ -1390,8 +1422,7 @@ def merge_frames( def complex_merge_between( cls, base_labels: "Labels", new_frames: List["LabeledFrame"] ) -> Tuple[Dict[Video, Dict[int, List[Instance]]], List[Instance], List[Instance]]: - """ - Merge data from new frames into a :class:`Labels` object. + """Merge data from new frames into a :class:`Labels` object. Everything that can be merged cleanly is merged, any conflicts are returned. @@ -1444,8 +1475,7 @@ def complex_merge_between( def complex_frame_merge( cls, base_frame: "LabeledFrame", new_frame: "LabeledFrame" ) -> Tuple[List[Instance], List[Instance], List[Instance]]: - """ - Merge two frames, return conflicts if any. + """Merge two frames, return conflicts if any. A conflict occurs when * each frame has Instances which don't perfectly match those @@ -1542,34 +1572,40 @@ def image(self) -> np.ndarray: """Return the image for this frame of shape (height, width, channels).""" return self.video.get_frame(self.frame_idx) - def plot(self, image: bool = True): + def numpy(self) -> np.ndarray: + """Return the instances as an array of shape (instances, nodes, 2).""" + return np.stack([inst.numpy() for inst in self.instances], axis=0) + + def plot(self, image: bool = True, scale: float = 1.0): """Plot the frame with all instances. Args: image: If False, only the instances will be plotted without loading the original image. + scale: Relative scaling for the figure. Notes: - See sleap.nn.viz.plot_img and sleap.nn.viz.plot_instances for more plotting - options. + See `sleap.nn.viz.plot_img` and `sleap.nn.viz.plot_instances` for more + plotting options. """ if image: - sleap.nn.viz.plot_img(self.image) + sleap.nn.viz.plot_img(self.image, scale=scale) sleap.nn.viz.plot_instances(self.instances) - def plot_predicted(self, image: bool = True): + def plot_predicted(self, image: bool = True, scale: float = 1.0): """Plot the frame with all predicted instances. Args: image: If False, only the instances will be plotted without loading the original image. + scale: Relative scaling for the figure. Notes: - See sleap.nn.viz.plot_img and sleap.nn.viz.plot_instances for more plotting - options. + See `sleap.nn.viz.plot_img` and `sleap.nn.viz.plot_instances` for more + plotting options. """ if image: - sleap.nn.viz.plot_img(self.image) + sleap.nn.viz.plot_img(self.image, scale=scale) sleap.nn.viz.plot_instances( self.predicted_instances, color_by_track=(len(self.predicted_instances) > 0) diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index 30a848638..a6ee370a5 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -51,6 +51,7 @@ Iterable, Any, Set, + Callable, ) import attr @@ -203,7 +204,7 @@ def get_video_track_occupancy(self, video: Video) -> Dict[Track, RangeList]: return self._track_occupancy[video] def remove_frame(self, frame: LabeledFrame): - """Remvoe frame and update cache as needed.""" + """Remove frame and update cache as needed.""" self._lf_by_video[frame.video].remove(frame) # We'll assume that there's only a single LabeledFrame for this video and # frame_idx, and remove the frame_idx from the cache. @@ -412,7 +413,7 @@ class Labels(MutableSequence): skeletons: List[Skeleton] = attr.ib(default=attr.Factory(list)) nodes: List[Node] = attr.ib(default=attr.Factory(list)) tracks: List[Track] = attr.ib(default=attr.Factory(list)) - suggestions: List["SuggestionFrame"] = attr.ib(default=attr.Factory(list)) + suggestions: List[SuggestionFrame] = attr.ib(default=attr.Factory(list)) negative_anchors: Dict[Video, list] = attr.ib(default=attr.Factory(dict)) provenance: Dict[Text, Union[str, int, float, bool]] = attr.ib( default=attr.Factory(dict) @@ -644,6 +645,8 @@ def __getitem__(self, key, *args) -> Union[LabeledFrame, List[LabeledFrame]]: scalar key was provided. """ if len(args) > 0: + if type(key) != tuple: + key = (key,) key = key + tuple(args) if isinstance(key, int): @@ -685,6 +688,46 @@ def __getitem__(self, key, *args) -> Union[LabeledFrame, List[LabeledFrame]]: else: raise KeyError("Invalid label indexing arguments.") + def get(self, *args) -> Union[LabeledFrame, List[LabeledFrame]]: + """Get an item from the labels or return `None` if not found. + + This is a safe version of `labels[...]` that will not raise an exception if the + item is not found. + """ + try: + return self.__getitem__(*args) + except KeyError: + return None + + def extract(self, inds) -> "Labels": + """Extract labeled frames from indices and return a new `Labels` object. + + Args: + inds: Any valid indexing keys, e.g., a range, slice, list of label indices, + numpy array, `Video`, etc. See `__getitem__` for full list. + + Returns: + A new `Labels` object with the specified labeled frames. + + This will preserve the other data structures even if they are not found in + the extracted labels, including: + - `Labels.videos` + - `Labels.skeletons` + - `Labels.tracks` + - `Labels.suggestions` + - `Labels.provenance` + """ + lfs = self.__getitem__(inds) + new_labels = type(self)( + labeled_frames=lfs, + videos=self.videos, + skeletons=self.skeletons, + tracks=self.tracks, + suggestions=self.suggestions, + provenance=self.provenance, + ) + return new_labels + def __setitem__(self, index, value: LabeledFrame): """Set labeled frame at given index.""" # TODO: Maybe we should remove this method altogether? @@ -733,6 +776,13 @@ def remove_frames(self, lfs: List[LabeledFrame]): self.labeled_frames = [lf for lf in self.labeled_frames if lf not in to_remove] self.update_cache() + def remove_empty_frames(self): + """Remove frames with no instances.""" + self.labeled_frames = [ + lf for lf in self.labeled_frames if len(lf.instances) > 0 + ] + self.update_cache() + def find( self, video: Video, @@ -830,10 +880,19 @@ def find_last( return label @property - def user_labeled_frames(self): + def user_labeled_frames(self) -> List[LabeledFrame]: """Return all labeled frames with user (non-predicted) instances.""" return [lf for lf in self.labeled_frames if lf.has_user_instances] + @property + def user_labeled_frame_inds(self) -> List[int]: + """Return a list of indices of frames with user labeled instances.""" + return [i for i, lf in enumerate(self.labeled_frames) if lf.has_user_instances] + + def with_user_labels_only(self) -> "Labels": + """Return a new `Labels` object with only user labels.""" + return self.extract(self.user_labeled_frame_inds) + def get_labeled_frame_count(self, video: Optional[Video] = None, filter: Text = ""): return self._cache.get_frame_count(video, filter) @@ -855,46 +914,37 @@ def all_instances(self) -> List[Instance]: @property def user_instances(self) -> List[Instance]: """Return list of all user (non-predicted) instances.""" - return [inst for inst in self.all_instances if isinstance(inst, Instance)] + return [inst for inst in self.all_instances if type(inst) == Instance] @property def predicted_instances(self) -> List[PredictedInstance]: """Return list of all predicted instances.""" - return [ - inst for inst in self.all_instances if isinstance(inst, PredictedInstance) - ] + return [inst for inst in self.all_instances if type(inst) == PredictedInstance] def describe(self): """Print basic statistics about the labels dataset.""" - print(f"Videos: {len(self.videos)}") - n_user_inst = len(self.user_instances) - n_predicted_inst = len(self.predicted_instances) - print( - f"Instances: {n_user_inst:,} (user-labeled), " - f"{n_predicted_inst:,} (predicted), " - f"{n_user_inst + n_predicted_inst:,} (total)" - ) - n_user_only = 0 - n_pred_only = 0 - n_both = 0 + print(f"Skeleton: {self.skeleton}") + print(f"Videos: {[v.filename for v in self.videos]}") + n_user = 0 + n_pred = 0 + n_user_inst = 0 + n_pred_inst = 0 for lf in self.labeled_frames: - has_user = lf.has_user_instances - has_pred = lf.has_predicted_instances - if has_user and not has_pred: - n_user_only += 1 - elif not has_user and has_pred: - n_pred_only += 1 - elif has_user and has_pred: - n_both += 1 - n_total = len(self.labeled_frames) - print( - f"Frames: {n_user_only:,} (user-labeled), " - f"{n_pred_only:,} (predicted), " - f"{n_both:,} (both), " - f"{n_total:,} (total)" - ) - - def instances(self, video: Video = None, skeleton: Skeleton = None): + if lf.has_user_instances: + n_user += 1 + n_user_inst += len(lf.user_instances) + if lf.has_predicted_instances: + n_pred += 1 + n_pred_inst += len(lf.predicted_instances) + print(f"Frames (user/predicted): {n_user:,}/{n_pred:,}") + print(f"Instances (user/predicted): {n_user_inst:,}/{n_pred_inst:,}") + print("Tracks:", self.tracks) + print(f"Suggestions: {len(self.suggestions):,}") + print("Provenance:", self.provenance) + + def instances( + self, video: Optional[Video] = None, skeleton: Optional[Skeleton] = None + ): """Iterate over instances in the labels, optionally with filters. Args: @@ -1086,16 +1136,54 @@ def does_track_match(inst, tr, labeled_frame): ] return track_frame_inst - def get_video_suggestions(self, video: Video) -> List[int]: + def add_suggestion(self, video: Video, frame_idx: int): + """Add a suggested frame to the labels. + + Args: + video: `sleap.Video` instance of the suggestion. + frame_idx: Index of the frame of the suggestion. + """ + for suggestion in self.suggestions: + if suggestion.video == video and suggestion.frame_idx == frame_idx: + return + self.suggestions.append(SuggestionFrame(video=video, frame_idx=frame_idx)) + + def remove_suggestion(self, video: Video, frame_idx: int): + """Remove a suggestion from the list by video and frame index. + + Args: + video: `sleap.Video` instance of the suggestion. + frame_idx: Index of the frame of the suggestion. + """ + for suggestion in self.suggestions: + if suggestion.video == video and suggestion.frame_idx == frame_idx: + self.suggestions.remove(suggestion) + return + + def get_video_suggestions( + self, video: Video, user_labeled: bool = True + ) -> List[int]: """Return a list of suggested frame indices. Args: video: Video to get suggestions for. + user_labeled: If `True` (the default), return frame indices for suggestions + that already have user labels. If `False`, only suggestions with no user + labeled instances will be returned. Returns: - Indices of the labeled frames for for the specified video. + Indices of the suggested frames for for the specified video. """ - return [item.frame_idx for item in self.suggestions if item.video == video] + frame_indices = [] + for suggestion in self.suggestions: + if suggestion.video == video: + fidx = suggestion.frame_idx + if not user_labeled: + lf = self.get((video, fidx)) + if lf is not None and lf.has_user_instances: + continue + frame_indices.append(fidx) + return frame_indices def get_suggestions(self) -> List[SuggestionFrame]: """Return all suggestions as a list of SuggestionFrame items.""" @@ -1170,6 +1258,47 @@ def delete_suggestions(self, video): """Delete suggestions for specified video.""" self.suggestions = [item for item in self.suggestions if item.video != video] + def clear_suggestions(self): + """Delete all suggestions.""" + self.suggestions = [] + + @property + def unlabeled_suggestions(self) -> List[SuggestionFrame]: + """Return suggestions without user labels.""" + unlabeled_suggestions = [] + for suggestion in self.suggestions: + lf = self.get(suggestion.video, suggestion.frame_idx) + if lf is None or not lf.has_user_instances: + unlabeled_suggestions.append(suggestion) + return unlabeled_suggestions + + def get_unlabeled_suggestion_inds(self) -> List[int]: + """Find labeled frames for unlabeled suggestions and return their indices. + + This is useful for generating a list of example indices for inference on + unlabeled suggestions. + + Returns: + List of indices of the labeled frames that correspond to the suggestions + that do not have user instances. + + If a labeled frame corresponding to a suggestion does not exist, an empty + one will be created. + + See also: `Labels.remove_empty_frames` + """ + inds = [] + for suggestion in self.unlabeled_suggestions: + lf = self.get((suggestion.video, suggestion.frame_idx)) + if lf is None: + self.append( + LabeledFrame(video=suggestion.video, frame_idx=suggestion.frame_idx) + ) + inds.append(len(self.labeled_frames) - 1) + else: + inds.append(self.index(lf)) + return inds + def add_video(self, video: Video): """Add a video to the labels if it is not already in it. @@ -1643,6 +1772,29 @@ def save( suggested=embed_suggested, ) + def export(self, filename: str): + """Export labels to analysis HDF5 format. + + This expects the labels to contain data for a single video (e.g., predictions). + + Args: + filename: Path to output HDF5 file. + + Notes: + This will write the contents of the labels out as a HDF5 file without + complete metadata. + + The resulting file will have datasets: + - `/node_names`: List of skeleton node names. + - `/track_names`: List of track names. + - `/tracks`: All coordinates of the instances in the labels. + - `/track_occupancy`: Mask denoting which instances are present in each + frame. + """ + from sleap.io.format.sleap_analysis import SleapAnalysisAdaptor + + SleapAnalysisAdaptor.write(filename, self) + @classmethod def load_json(cls, filename: str, *args, **kwargs) -> "Labels": from .format import read @@ -1679,6 +1831,18 @@ def load_deeplabcut(cls, filename: str) -> "Labels": return read(filename, for_object="labels", as_format="deeplabcut") + @classmethod + def load_deeplabcut_folder(cls, filename: str) -> "Labels": + csv_files = glob(f"{filename}/*/*.csv") + merged_labels = None + for csv_file in csv_files: + labels = cls.load_file(csv_file, as_format="deeplabcut") + if merged_labels is None: + merged_labels = labels + else: + merged_labels.extend_from(labels, unify=True) + return merged_labels + @classmethod def load_coco( cls, filename: str, img_dir: str, use_missing_gui: bool = False @@ -1750,6 +1914,7 @@ def save_frame_data_hdf5( user_labeled: bool = True, all_labeled: bool = False, suggested: bool = False, + progress_callback: Optional[Callable[[int, int], None]] = None, ) -> List[HDF5Video]: """Write images for labeled frames from all videos to hdf5 file. @@ -1765,12 +1930,21 @@ def save_frame_data_hdf5( Defaults to `False`. suggested: Include suggested frames even if they do not have instances. Useful for inference after training. Defaults to `False`. + progress_callback: If provided, this function will be called to report the + progress of the frame data saving. This function should be a callable + of the form: `fn(n, n_total)` where `n` is the number of frames saved so + far and `n_total` is the total number of frames that will be saved. This + is called after each video is processed. If the function has a return + value and it returns `False`, saving will be canceled and the output + deleted. Returns: A list of :class:`HDF5Video` objects with the stored frames. """ - new_vids = [] - for v_idx, video in enumerate(self.videos): + # Build list of frames to save. + vids = [] + frame_idxs = [] + for video in self.videos: lfs_v = self.find(video) frame_nums = [ lf.frame_idx @@ -1784,13 +1958,29 @@ def save_frame_data_hdf5( if suggestion.video == video ] frame_nums = sorted(list(set(frame_nums))) + vids.append(video) + frame_idxs.append(frame_nums) + + n_total = sum([len(x) for x in frame_idxs]) + n = 0 + # Save images for each video. + new_vids = [] + for v_idx, (video, frame_nums) in enumerate(zip(vids, frame_idxs)): vid = video.to_hdf5( path=output_path, dataset=f"video{v_idx}", format=format, frame_numbers=frame_nums, ) + n += len(frame_nums) + if progress_callback is not None: + # Notify update callback. + ret = progress_callback(n, n_total) + if ret == False: + vid.close() + return [] + vid.close() new_vids.append(vid) @@ -1837,6 +2027,57 @@ def to_pipeline( pipeline += pipelines.Prefetcher() return pipeline + def numpy( + self, video: Optional[Video] = None, all_frames: bool = True + ) -> np.ndarray: + """Construct a numpy array from tracked instance points. + + Args: + video: Video to convert to numpy arrays. If `None` (the default), uses the + first video. + all_frames: If `True` (the default), allocate array of the same number of + frames as the video. If `False`, only return data between the first and + last frame with data. + + Returns: + An array of tracks of shape `(n_frames, n_tracks, n_nodes, 2)`. + + Missing data will be replaced with `np.nan`. + + Notes: + This method assumes that instances have tracks assigned and is intended to + function primarily for single-video prediction results. + """ + if video is None: + video = self.videos[0] + lfs = self.find(video=video) + + if all_frames: + first_frame, last_frame = None, None + for lf in lfs: + if first_frame is None: + first_frame = lf.frame_idx + if last_frame is None: + last_frame = lf.frame_idx + first_frame = min(first_frame, lf.frame_idx) + last_frame = max(last_frame, lf.frame_idx) + else: + first_frame, last_frame = 0, video.shape[0] - 1 + + n_frames = last_frame - first_frame + 1 + n_tracks = len(self.tracks) + n_nodes = len(self.skeleton.nodes) + + tracks = np.full((n_frames, n_tracks, n_nodes, 2), np.nan, dtype="float32") + for lf in lfs: + i = lf.frame_idx - first_frame + for inst in lf: + if inst.track is not None: + j = self.tracks.index(inst.track) + tracks[i, j] = inst.numpy() + + return tracks + @classmethod def make_gui_video_callback(cls, search_paths: Optional[List] = None) -> Callable: return cls.make_video_callback(search_paths=search_paths, use_gui=True) @@ -1964,9 +2205,16 @@ def load_file( filename: Text, detect_videos: bool = True, search_paths: Optional[Union[List[Text], Text]] = None, + match_to: Optional[Labels] = None, ) -> Labels: """Load a SLEAP labels file. + SLEAP labels files (`.slp`) contain all the metadata for a labeling project or the + predicted labels from a video. This includes the skeleton, videos, labeled frames, + user-labeled and predicted instances, suggestions and tracks. + + See `sleap.io.dataset.Labels` for more detailed information. + Args: filename: Path to a SLEAP labels (.slp) file. detect_videos: If True, will attempt to detect missing videos by searching for @@ -1976,6 +2224,9 @@ def load_file( be the direct path to the video file or its containing folder. If not specified, defaults to searching for the videos in the same folder as the labels. + match_to: If a `sleap.Labels` object is provided, attempt to match and reuse + video and skeleton objects when loading. This is useful when comparing the + contents across sets of labels. Returns: The loaded `Labels` instance. @@ -1990,6 +2241,6 @@ def load_file( if detect_videos: if search_paths is None: search_paths = os.path.dirname(filename) - return Labels.load_file(filename, search_paths) + return Labels.load_file(filename, search_paths, match_to=match_to) else: - return Labels.load_file(filename) + return Labels.load_file(filename, match_to=match_to) diff --git a/sleap/io/format/hdf5.py b/sleap/io/format/hdf5.py index a6f6933d1..7ba60cd79 100644 --- a/sleap/io/format/hdf5.py +++ b/sleap/io/format/hdf5.py @@ -225,6 +225,7 @@ def write( frame_data_format: str = "png", all_labeled: bool = False, suggested: bool = False, + progress_callback: Optional[Callable[[int, int], None]] = None, ): labels = source_object @@ -245,6 +246,7 @@ def write( user_labeled=True, all_labeled=all_labeled, suggested=suggested, + progress_callback=progress_callback, ) # Replace path to video file with "." (which indicates that the diff --git a/sleap/io/video.py b/sleap/io/video.py index 0203f91b7..62288223b 100644 --- a/sleap/io/video.py +++ b/sleap/io/video.py @@ -982,7 +982,13 @@ def shape(self) -> Tuple[int, int, int, int]: def __str__(self) -> str: """Informal string representation (for print or format).""" - return type(self).__name__ + " ([%d x %d x %d x %d])" % self.shape + return ( + "Video(" + f"filename={self.filename}, " + f"shape={self.shape}, " + f"backend={type(self.backend).__name__}" + ")" + ) def __len__(self) -> int: """Return the length of the video as the number of frames.""" @@ -1172,7 +1178,7 @@ def from_filename(cls, filename: str, *args, **kwargs) -> "Video": backend_class = HDF5Video elif filename.endswith(("npy")): backend_class = NumpyVideo - elif filename.lower().endswith(("mp4", "avi")): + elif filename.lower().endswith(("mp4", "avi", "mov")): backend_class = MediaVideo elif os.path.isdir(filename) or "metadata.yaml" in filename: backend_class = ImgStoreVideo @@ -1512,3 +1518,51 @@ def fixup_path( if raise_warning: logger.warning(f"Cannot find a video file: {path}") return path + + +def load_video( + filename: str, + grayscale: Optional[bool] = None, + dataset=Optional[None], + channels_first: bool = False, +) -> Video: + """Open a video from disk. + + Args: + filename: Path to a video file. The video reader backend will be determined by + the file extension. Support extensions include: `.mp4`, `.avi`, `.h5`, + `.hdf5` and `.slp` (for embedded images in a labels file). If the path to a + folder is provided, images within that folder will be treated as video + frames. + grayscale: Read frames as a single channel grayscale images. If `None` (the + default), this will be auto-detected. + dataset: Name of the dataset that contains the video if loading a video stored + in an HDF5 file. This has no effect for non-HDF5 inputs. + channels_first: If `False` (the default), assume the data in the HDF5 dataset + are formatted in `(frames, height, width, channels)` order. If `False`, + assume the data are in `(frames, channels, width, height)` format. This has + no effect for non-HDF5 inputs. + + Returns: + A `sleap.Video` instance with the appropriate backend for its format. + + This enables numpy-like access to video data. + + Example: + >>> video = sleap.load_video("centered_pair_small.mp4") + >>> video.shape + (1100, 384, 384, 1) + >>> imgs = video[0:3] + >>> imgs.shape + (3, 384, 384, 1) + + See also: + sleap.io.video.Video + """ + kwargs = {} + if grayscale is not None: + kwargs["grayscale"] = grayscale + if dataset is not None: + kwargs["dataset"] = dataset + kwargs["input_format"] = "channels_first" if channels_first else "channels_last" + return Video.from_filename(filename, **kwargs) diff --git a/sleap/nn/config/__init__.py b/sleap/nn/config/__init__.py index dbcc4c3a7..0ff2d565d 100644 --- a/sleap/nn/config/__init__.py +++ b/sleap/nn/config/__init__.py @@ -34,4 +34,4 @@ ZMQConfig, OutputsConfig, ) -from sleap.nn.config.training_job import TrainingJobConfig +from sleap.nn.config.training_job import TrainingJobConfig, load_config diff --git a/sleap/nn/config/data.py b/sleap/nn/config/data.py index bc091a547..37cf181e8 100644 --- a/sleap/nn/config/data.py +++ b/sleap/nn/config/data.py @@ -26,6 +26,15 @@ class LabelsConfig: can be computed from these data during model optimization. This is also useful to explicitly keep track of the test set that should be used when multiple splits are created for training. + split_by_inds: If `True`, splits used for training will be determined by the + lists below by indexing into the labels in `training_labels`. If this is + `False`, the indices below will not be used even if specified. This is + useful for specifying the fixed split sets from examples within a single + labels file. If splits are generated automatically (using + `validation_fraction`), the selected indices are stored below for reference. + training_inds: List of indices of the training split labels. + validation_inds: List of indices of the validation split labels. + test_inds: List of indices of the test split labels. search_path_hints: List of paths to use for searching for missing data. This is useful when labels and data are moved across computers, network storage, or operating systems that may have different absolute paths than those stored @@ -40,6 +49,10 @@ class LabelsConfig: validation_labels: Optional[Text] = None validation_fraction: float = 0.1 test_labels: Optional[Text] = None + split_by_inds: bool = False + training_inds: Optional[List[int]] = None + validation_inds: Optional[List[int]] = None + test_inds: Optional[List[int]] = None search_path_hints: List[Text] = attr.ib(factory=list) skeletons: List[sleap.Skeleton] = attr.ib(factory=list) @@ -76,13 +89,14 @@ class PreprocessingConfig: max stride (typically 32). This padding will be ignored when instance cropping inputs since the crop size should already be divisible by the model's max stride. - resize_and_pad_to_target: If True, will resize and pad all images in the dataset to match target dimensions. - This is useful when preprocessing datasets with mixed image dimensions (from different video resolutions). - Aspect ratio is preserved, and padding applied (if needed) to bottom or right of image only. - target_height: Target image height for 'resize_and_pad_to_target'. When not explicitly provided, inferred as the - max image height from the dataset. - target_width: Target image width for 'resize_and_pad_to_target'. When not explicitly provided, inferred as the - max image width from the dataset. + resize_and_pad_to_target: If True, will resize and pad all images in the dataset + to match target dimensions. This is useful when preprocessing datasets with + mixed image dimensions (from different video resolutions). Aspect ratio is + preserved, and padding applied (if needed) to bottom or right of image only. + target_height: Target image height for 'resize_and_pad_to_target'. When not + explicitly provided, inferred as the max image height from the dataset. + target_width: Target image width for 'resize_and_pad_to_target'. When not + explicitly provided, inferred as the max image width from the dataset. """ ensure_rgb: bool = False diff --git a/sleap/nn/config/outputs.py b/sleap/nn/config/outputs.py index 99369e335..ffb0d76e4 100644 --- a/sleap/nn/config/outputs.py +++ b/sleap/nn/config/outputs.py @@ -151,6 +151,11 @@ class OutputsConfig: save_visualizations: If True, will render and save visualizations of the model predictions as PNGs to "{run_folder}/viz/{split}.{epoch:04d}.png", where the split is one of "train", "validation", "test". + delete_viz_images: If True, delete the saved visualizations after training + completes. This is useful to reduce the model folder size if you do not need + to keep the visualization images. + zip_outputs: If True, compress the run folder to a zip file. This will be named + "{run_folder}.zip". log_to_csv: If True, loss and metrics will be saved to a simple CSV after each epoch to "{run_folder}/training_log.csv" checkpointing: Configuration options related to model checkpointing. @@ -165,6 +170,8 @@ class OutputsConfig: runs_folder: Text = "models" tags: List[Text] = attr.ib(factory=list) save_visualizations: bool = True + delete_viz_images: bool = True + zip_outputs: bool = False log_to_csv: bool = True checkpointing: CheckpointingConfig = attr.ib(factory=CheckpointingConfig) tensorboard: TensorBoardConfig = attr.ib(factory=TensorBoardConfig) diff --git a/sleap/nn/config/training_job.py b/sleap/nn/config/training_job.py index fb7721afc..680fd8f15 100644 --- a/sleap/nn/config/training_job.py +++ b/sleap/nn/config/training_job.py @@ -27,13 +27,14 @@ import os import attr import cattr +import sleap from sleap.nn.config.data import DataConfig from sleap.nn.config.model import ModelConfig from sleap.nn.config.optimization import OptimizationConfig from sleap.nn.config.outputs import OutputsConfig import json from jsmin import jsmin -from typing import Text, Dict, Any +from typing import Text, Dict, Any, Optional @attr.s(auto_attribs=True) @@ -45,13 +46,20 @@ class TrainingJobConfig: model: Configuration options related to the model architecture. optimization: Configuration options related to the training. outputs: Configuration options related to outputs during training. + name: Optional name for this configuration profile. + description: Optional description of the configuration. + sleap_version: Version of SLEAP that generated this configuration. + filename: Path to this config file if it was loaded from disk. """ data: DataConfig = attr.ib(factory=DataConfig) model: ModelConfig = attr.ib(factory=ModelConfig) optimization: OptimizationConfig = attr.ib(factory=OptimizationConfig) outputs: OutputsConfig = attr.ib(factory=OutputsConfig) - # TODO: store fixed config format version + SLEAP version? + name: Optional[Text] = "" + description: Optional[Text] = "" + sleap_version: Optional[Text] = sleap.__version__ + filename: Optional[Text] = "" @classmethod def from_json_dicts(cls, json_data_dicts: Dict[Text, Any]) -> "TrainingJobConfig": @@ -82,16 +90,27 @@ def from_json(cls, json_data: Text) -> "TrainingJobConfig": return cls.from_json_dicts(json_data_dicts) @classmethod - def load_json(cls, filename: Text) -> "TrainingJobConfig": + def load_json( + cls, filename: Text, load_training_config: bool = True + ) -> "TrainingJobConfig": """Load a training job configuration from a file. Arguments: filename: Path to a training job configuration JSON file or a directory containing `"training_job.json"`. + load_training_config: If `True` (the default), prefer `training_job.json` + over `initial_config.json` if it is present in the same folder. Returns: A TrainingJobConfig instance parsed from the file. """ + if load_training_config and filename.endswith("initial_config.json"): + training_config_path = os.path.join( + os.path.dirname(filename), "training_config.json" + ) + if os.path.exists(training_config_path): + filename = training_config_path + # Use stored configuration if a directory was provided. if os.path.isdir(filename): filename = os.path.join(filename, "training_config.json") @@ -100,7 +119,9 @@ def load_json(cls, filename: Text) -> "TrainingJobConfig": with open(filename, "r") as f: json_data = f.read() - return cls.from_json(json_data) + obj = cls.from_json(json_data) + obj.filename = filename + return obj def to_json(self) -> str: """Serialize the configuration into JSON-encoded string format. @@ -117,5 +138,22 @@ def save_json(self, filename: Text): Arguments: filename: Path to save the training job file to. """ + self.filename = filename with open(filename, "w") as f: f.write(self.to_json()) + + +def load_config(filename: Text, load_training_config: bool = True) -> TrainingJobConfig: + """Load a training job configuration for a model run. + + Args: + filename: Path to a JSON file or directory containing `training_job.json`. + load_training_config: If `True` (the default), prefer `training_job.json` over + `initial_config.json` if it is present in the same folder. + + Returns: + The parsed `TrainingJobConfig`. + """ + return TrainingJobConfig.load_json( + filename, load_training_config=load_training_config + ) diff --git a/sleap/nn/data/pipelines.py b/sleap/nn/data/pipelines.py index 4ce8df9e6..76b30faa4 100644 --- a/sleap/nn/data/pipelines.py +++ b/sleap/nn/data/pipelines.py @@ -392,7 +392,7 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline: if self.optimization_config.augmentation_config.random_flip: pipeline += RandomFlipper.from_skeleton( - self.data_config.skeletons[0], + self.data_config.labels.skeletons[0], horizontal=self.optimization_config.augmentation_config.flip_horizontal, ) pipeline += ImgaugAugmenter.from_config( @@ -541,7 +541,11 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline: pipeline += Shuffler( shuffle=True, buffer_size=self.optimization_config.shuffle_buffer_size ) - + if self.optimization_config.augmentation_config.random_flip: + pipeline += RandomFlipper.from_skeleton( + self.data_config.labels.skeletons[0], + horizontal=self.optimization_config.augmentation_config.flip_horizontal, + ) pipeline += ImgaugAugmenter.from_config( self.optimization_config.augmentation_config ) @@ -702,7 +706,7 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline: ) if self.optimization_config.augmentation_config.random_flip: pipeline += RandomFlipper.from_skeleton( - self.data_config.skeletons[0], + self.data_config.labels.skeletons[0], horizontal=self.optimization_config.augmentation_config.flip_horizontal, ) pipeline += ImgaugAugmenter.from_config( @@ -848,7 +852,7 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline: aug_config = self.optimization_config.augmentation_config if aug_config.random_flip: pipeline += RandomFlipper.from_skeleton( - self.data_config.skeletons[0], + self.data_config.labels.skeletons[0], horizontal=aug_config.flip_horizontal, ) pipeline += ImgaugAugmenter.from_config(aug_config) diff --git a/sleap/nn/data/providers.py b/sleap/nn/data/providers.py index d2e04f2cb..87e6b02e2 100644 --- a/sleap/nn/data/providers.py +++ b/sleap/nn/data/providers.py @@ -23,10 +23,13 @@ class LabelsReader: the entire labels dataset will be read. These indices will be applicable to the labeled frames in `labels` attribute, which may have changed in ordering or filtered. + user_instances_only: If `True`, load only user labeled instances. If `False`, + all instances will be loaded. """ labels: sleap.Labels example_indices: Optional[Union[Sequence[int], np.ndarray]] = None + user_instances_only: bool = False @classmethod def from_user_instances(cls, labels: sleap.Labels) -> "LabelsReader": @@ -36,17 +39,40 @@ def from_user_instances(cls, labels: sleap.Labels) -> "LabelsReader": labels: A `sleap.Labels` instance containing user instances. Returns: - A `LabelsReader` instance that can create a dataset for pipelining. Note - that the examples may change in ordering relative to the input `labels`, so - be sure to use the `labels` attribute in the returned instance. + A `LabelsReader` instance that can create a dataset for pipelining. """ - user_labels = sleap.Labels( - [ - sleap.LabeledFrame(lf.video, lf.frame_idx, lf.training_instances) - for lf in labels.user_labeled_frames - ] - ) - return cls(labels=user_labels) + obj = cls.from_user_labeled_frames(labels) + obj.user_instances_only = True + return obj + + @classmethod + def from_user_labeled_frames(cls, labels: sleap.Labels) -> "LabelsReader": + """Create a `LabelsReader` using the user labeled frames in a `Labels` set. + + Args: + labels: A `sleap.Labels` instance containing user labeled frames. + + Returns: + A `LabelsReader` instance that can create a dataset for pipelining. + + Note that this constructor will load ALL instances in frames that have user + instances. To load only user labeled indices, use + `LabelsReader.from_user_instances`. + """ + return cls(labels=labels, example_indices=labels.user_labeled_frame_inds) + + @classmethod + def from_unlabeled_suggestions(cls, labels: sleap.Labels) -> "LabelsReader": + """Create a `LabelsReader` using the unlabeled suggestions in a `Labels` set. + + Args: + labels: A `sleap.Labels` instance containing unlabeled suggestions. + + Returns: + A `LabelsReader` instance that can create a dataset for pipelining. + """ + inds = labels.get_unlabeled_suggestion_inds() + return cls(labels=labels, example_indices=inds) @classmethod def from_filename( @@ -95,12 +121,14 @@ def videos(self) -> List[sleap.Video]: @property def max_height_and_width(self) -> Tuple[int, int]: + """Return `(height, width)` that is the maximum of all videos.""" return max(video.shape[1] for video in self.videos), max( video.shape[2] for video in self.videos ) @property def is_from_multi_size_videos(self) -> bool: + """Return `True` if labels contain videos with different sizes.""" return ( len(set(v.shape[1] for v in self.videos)) > 1 or len(set(v.shape[2] for v in self.videos)) > 1 @@ -146,15 +174,27 @@ def make_dataset( def py_fetch_lf(ind): """Local function that will not be autographed.""" lf = self.labels[int(ind.numpy())] + video_ind = np.array(self.videos.index(lf.video)).astype("int32") frame_ind = np.array(lf.frame_idx).astype("int64") + raw_image = lf.image raw_image_size = np.array(raw_image.shape).astype("int32") - instances = np.stack( - [inst.points_array.astype("float32") for inst in lf.instances], axis=0 - ) + + if self.user_instances_only: + insts = lf.user_instances + else: + insts = lf.instances + insts = [inst for inst in insts if len(inst) > 0] + n_instances = len(insts) + n_nodes = len(insts[0].skeleton) if n_instances > 0 else 0 + + instances = np.full((n_instances, n_nodes, 2), np.nan, dtype="float32") + for i, instance in enumerate(insts): + instances[i] = instance.numpy() + skeleton_inds = np.array( - [self.labels.skeletons.index(inst.skeleton) for inst in lf.instances] + [self.labels.skeletons.index(inst.skeleton) for inst in insts] ).astype("int32") return ( raw_image, @@ -181,7 +221,8 @@ def fetch_lf(ind): [image_dtype, tf.int32, tf.float32, tf.int32, tf.int64, tf.int32], ) - # Ensure shape with constant or variable height/width, based on whether or not the videos have mixed sizes + # Ensure shape with constant or variable height/width, based on whether or + # not the videos have mixed sizes. if self.is_from_multi_size_videos: image = tf.ensure_shape(image, (None, None, image_num_channels)) else: diff --git a/sleap/nn/data/training.py b/sleap/nn/data/training.py index f75ded45c..cb6177375 100644 --- a/sleap/nn/data/training.py +++ b/sleap/nn/data/training.py @@ -7,6 +7,58 @@ from sleap.nn.data.utils import expand_to_rank, ensure_list import attr from typing import List, Text, Optional, Any, Union, Dict, Tuple, Sequence +from sklearn.model_selection import train_test_split + + +def split_labels_train_val( + labels: sleap.Labels, validation_fraction: float +) -> Tuple[sleap.Labels, List[int], sleap.Labels, List[int]]: + """Make a train/validation split from a labels dataset. + + Args: + labels: A `sleap.Labels` dataset with labeled frames. + validation_fraction: Fraction of frames to use for validation. + + Returns: + A tuple of `(labels_train, idx_train, labels_val, idx_val)`. + + `labels_train` and `labels_val` are `sleap.Label` objects containing the + selected frames for each split. Their `videos`, `tracks` and `provenance` + attributes are identical to `labels` even if the split does not contain + instances with a particular video or track. + + `idx_train` and `idx_val` are list indices of the labeled frames within the + input labels that were assigned to each split, i.e.: + + `labels[idx_train] == labels_train[:]` + + If there is only one labeled frame in `labels`, both of the labels will contain + the same frame. + + If `validation_fraction` would result in fewer than one label for either split, + it will be rounded to ensure there is at least one label in each. + """ + if len(labels) == 1: + return labels, [0], labels, [0] + + # Split indices. + n_val = round(len(labels) * validation_fraction) + n_val = max(min(n_val, len(labels) - 1), 1) + + idx_train, idx_val = train_test_split(list(range(len(labels))), test_size=n_val) + + # Create labels and keep original metadata. + labels_train = sleap.Labels(labels[idx_train]) + labels_train.videos = labels.videos + labels_train.tracks = labels.tracks + labels_train.provenance = labels.provenance + + labels_val = sleap.Labels(labels[idx_val]) + labels_val.videos = labels.videos + labels_val.tracks = labels.tracks + labels_val.provenance = labels.provenance + + return labels_train, idx_train, labels_val, idx_val def split_labels( diff --git a/sleap/nn/evals.py b/sleap/nn/evals.py index 4c630ee60..612d3681f 100644 --- a/sleap/nn/evals.py +++ b/sleap/nn/evals.py @@ -37,8 +37,8 @@ from sleap.nn.model import Model from sleap.nn.data.pipelines import LabelsReader from sleap.nn.inference import ( - TopdownPredictor, - BottomupPredictor, + TopDownPredictor, + BottomUpPredictor, SingleInstancePredictor, ) @@ -671,21 +671,21 @@ def evaluate_model( # Setup predictor for evaluation. head_config = cfg.model.heads.which_oneof() if isinstance(head_config, CentroidsHeadConfig): - predictor = TopdownPredictor( + predictor = TopDownPredictor( centroid_config=cfg, centroid_model=model, confmap_config=None, confmap_model=None, ) elif isinstance(head_config, CenteredInstanceConfmapsHeadConfig): - predictor = TopdownPredictor( + predictor = TopDownPredictor( centroid_config=None, centroid_model=None, confmap_config=cfg, confmap_model=model, ) elif isinstance(head_config, MultiInstanceConfig): - predictor = sleap.nn.inference.BottomupPredictor( + predictor = sleap.nn.inference.BottomUpPredictor( bottomup_config=cfg, bottomup_model=model ) elif isinstance(head_config, SingleInstanceConfmapsHeadConfig): diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index f66a39154..87ea34626 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -22,23 +22,32 @@ """ import attr +import argparse import logging import warnings import os -import time +import tempfile +import platform +import shutil +import atexit +import rich.progress +from collections import deque +import json +from time import time +from datetime import datetime +from pathlib import Path + from abc import ABC, abstractmethod -from typing import Text, Optional, List, Dict, Union, Iterator +from typing import Text, Optional, List, Dict, Union, Iterator, Tuple import tensorflow as tf import numpy as np import sleap -from sleap import util from sleap.nn.config import TrainingJobConfig from sleap.nn.model import Model -from sleap.nn.tracking import Tracker, run_tracker +from sleap.nn.tracking import Tracker from sleap.nn.paf_grouping import PAFScorer -from sleap.nn.data.grouping import group_examples_iter from sleap.nn.data.pipelines import ( Provider, Pipeline, @@ -47,138 +56,327 @@ Normalizer, Resizer, Prefetcher, - LambdaFilter, KerasModelPredictor, - LocalPeakFinder, - PredictedInstanceCropper, - InstanceCentroidFinder, - InstanceCropper, - GlobalPeakFinder, - MockGlobalPeakFinder, - KeyFilter, - KeyRenamer, - KeyDeviceMover, - PredictedCenterInstanceNormalizer, - PointsRescaler, ) logger = logging.getLogger(__name__) -def safely_generate(ds: tf.data.Dataset, progress: bool = True): - """Yields examples from dataset, catching and logging exceptions.""" - # Unsafe generating: - # for example in ds: - # yield example - - ds_iter = iter(ds) - - i = 0 - wall_t0 = time.time() - done = False - while not done: - try: - next_val = next(ds_iter) - yield next_val - except StopIteration: - done = True - except Exception as e: - logger.info(f"ERROR in sample index {i}") - logger.info(e) - logger.info("") - finally: - if not done: - i += 1 - - # Show the current progress (frames, time, fps) - if progress: - if (i and i % 1000 == 0) or done: - elapsed_time = time.time() - wall_t0 - logger.info( - f"Finished {i} examples in {elapsed_time:.2f} seconds " - "(inference + postprocessing)" - ) - if elapsed_time: - logger.info(f"examples/s = {i/elapsed_time}") +def get_keras_model_path(path: Text) -> str: + """Utility method for finding the path to a saved Keras model. + Args: + path: Path to a model run folder or job file. -def get_keras_model_path(path: Text) -> Text: + Returns: + Path to `best_model.h5` in the run folder. + """ + # TODO: Move this to TrainingJobConfig or Model? if path.endswith(".json"): path = os.path.dirname(path) return os.path.join(path, "best_model.h5") +class RateColumn(rich.progress.ProgressColumn): + """Renders the progress rate.""" + + def render(self, task: "Task") -> rich.progress.Text: + """Show progress rate.""" + speed = task.speed + if speed is None: + return rich.progress.Text("?", style="progress.data.speed") + return rich.progress.Text(f"{speed:.1f} FPS", style="progress.data.speed") + + @attr.s(auto_attribs=True) class Predictor(ABC): """Base interface class for predictors.""" + verbosity: str = attr.ib( + validator=attr.validators.in_(["none", "rich", "json"]), + default="rich", + kw_only=True, + ) + report_rate: float = attr.ib(default=2.0, kw_only=True) + model_paths: List[str] = attr.ib(factory=list, kw_only=True) + + @property + def report_period(self) -> float: + """Time between progress reports in seconds.""" + return 1.0 / self.report_rate + + @classmethod + def from_model_paths( + cls, + model_paths: Union[str, List[str]], + peak_threshold: float = 0.2, + integral_refinement: bool = True, + integral_patch_size: int = 5, + batch_size: int = 4, + ) -> "Predictor": + """Create the appropriate `Predictor` subclass from a list of model paths. + + Args: + model_paths: A single or list of trained model paths. + peak_threshold: Minimum confidence map value to consider a peak as valid. + integral_refinement: If `True`, peaks will be refined with integral + regression. If `False`, `"local"`, peaks will be refined with quarter + pixel local gradient offset. This has no effect if the model has an + offset regression head. + integral_patch_size: Size of patches to crop around each rough peak for + integral refinement as an integer scalar. + batch_size: The default batch size to use when loading data for inference. + Higher values increase inference speed at the cost of higher memory + usage. + + Returns: + A subclass of `Predictor`. + + See also: `SingleInstancePredictor`, `TopDownPredictor`, `BottomUpPredictor` + """ + # Read configs and find model types. + if isinstance(model_paths, str): + model_paths = [model_paths] + model_configs = [sleap.load_config(model_path) for model_path in model_paths] + model_paths = [cfg.filename for cfg in model_configs] + model_types = [ + cfg.model.heads.which_oneof_attrib_name() for cfg in model_configs + ] + + if "single_instance" in model_types: + predictor = SingleInstancePredictor.from_trained_models( + model_path=model_paths[model_types.index("single_instance")], + peak_threshold=peak_threshold, + integral_refinement=integral_refinement, + integral_patch_size=integral_patch_size, + batch_size=batch_size, + ) + + elif "centroid" in model_types or "centered_instance" in model_types: + centroid_model_path = None + if "centroid" in model_types: + centroid_model_path = model_paths[model_types.index("centroid")] + + confmap_model_path = None + if "centered_instance" in model_types: + confmap_model_path = model_paths[model_types.index("centered_instance")] + + predictor = TopDownPredictor.from_trained_models( + centroid_model_path=centroid_model_path, + confmap_model_path=confmap_model_path, + batch_size=batch_size, + peak_threshold=peak_threshold, + integral_refinement=integral_refinement, + integral_patch_size=integral_patch_size, + ) + + elif "multi_instance" in model_types: + predictor = BottomUpPredictor.from_trained_models( + model_path=model_paths[model_types.index("multi_instance")], + peak_threshold=peak_threshold, + integral_refinement=integral_refinement, + integral_patch_size=integral_patch_size, + batch_size=batch_size, + ) + + else: + raise ValueError( + "Could not create predictor from model paths:" + "\n".join(model_paths) + ) + predictor.model_paths = model_paths + return predictor + @classmethod @abstractmethod def from_trained_models(cls, *args, **kwargs): pass - @abstractmethod - def make_pipeline(self): - pass + def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline: + """Make a data loading pipeline. + + Args: + data_provider: If not `None`, the pipeline will be created with an instance + of a `sleap.pipelines.Provider`. + + Returns: + The created `sleap.pipelines.Pipeline` with batching and prefetching. + + Notes: + This method also updates the class attribute for the pipeline and will be + called automatically when predicting on data from a new source. + """ + pipeline = Pipeline() + if data_provider is not None: + pipeline.providers = [data_provider] + + pipeline += sleap.nn.data.pipelines.Batcher( + batch_size=self.batch_size, drop_remainder=False, unrag=False + ) + + pipeline += Prefetcher() + + self.pipeline = pipeline + + return pipeline @abstractmethod - def predict(self, data_provider: Provider): + def _initialize_inference_model(self): pass + def _predict_generator( + self, data_provider: Provider + ) -> Iterator[Dict[str, np.ndarray]]: + """Create a generator that yields batches of inference results. -@attr.s(auto_attribs=True) -class MockPredictor(Predictor): - labels: sleap.Labels + This method handles creating or updating the input `sleap.pipelines.Pipeline` + for loading the data, as well as looping over the batches and running inference. - @classmethod - def from_trained_models(cls, labels_path: Text): - labels = sleap.Labels.load_file(labels_path) - return cls(labels=labels) + Args: + data_provider: The `sleap.pipelines.Provider` that contains data that should + be used for inference. - def make_pipeline(self): - pass + Returns: + A generator yielding batches predicted results as dictionaries of numpy + arrays. + """ + # Initialize data pipeline and inference model if needed. + if self.pipeline is None: + self.make_pipeline() + if self.inference_model is None: + self._initialize_inference_model() - def predict(self, data_provider: Provider): + # Update the data provider source. + self.pipeline.providers = [data_provider] - prediction_video = None - - # Try to match specified video by its full path - prediction_video_path = os.path.abspath(data_provider.video.filename) - for video in self.labels.videos: - if os.path.abspath(video.filename) == prediction_video_path: - prediction_video = video - break - - if prediction_video is None: - # Try to match on filename (without path) - prediction_video_path = os.path.basename(data_provider.video.filename) - for video in self.labels.videos: - if os.path.basename(video.filename) == prediction_video_path: - prediction_video = video - break + def process_batch(ex): + # Run inference on current batch. + preds = self.inference_model.predict_on_batch(ex) - if prediction_video is None: - # Default to first video in labels file - prediction_video = self.labels.videos[0] + # Add model outputs to the input data example. + ex.update(preds) - # Get specified frames from labels file (or use None for all frames) - frame_idx_list = ( - list(data_provider.example_indices) - if data_provider.example_indices - else None - ) + # Convert to numpy arrays if not already. + if isinstance(ex["video_ind"], tf.Tensor): + ex["video_ind"] = ex["video_ind"].numpy().flatten() + if isinstance(ex["frame_ind"], tf.Tensor): + ex["frame_ind"] = ex["frame_ind"].numpy().flatten() - frames = self.labels.find(video=prediction_video, frame_idx=frame_idx_list) + return ex + + # Loop over data batches with optional progress reporting. + if self.verbosity == "rich": + with rich.progress.Progress( + "{task.description}", + rich.progress.BarColumn(), + "[progress.percentage]{task.percentage:>3.0f}%", + "ETA:", + rich.progress.TimeRemainingColumn(), + RateColumn(), + auto_refresh=False, + refresh_per_second=self.report_rate, + speed_estimate_period=5, + ) as progress: + task = progress.add_task("Predicting...", total=len(data_provider)) + last_report = time() + for ex in self.pipeline.make_dataset(): + ex = process_batch(ex) + progress.update(task, advance=len(ex["frame_ind"])) + + # Handle refreshing manually to support notebooks. + elapsed_since_last_report = time() - last_report + if elapsed_since_last_report > self.report_period: + progress.refresh() + + # Return results. + yield ex + + elif self.verbosity == "json": + n_processed = 0 + n_total = len(data_provider) + n_recent = deque(maxlen=30) + elapsed_recent = deque(maxlen=30) + last_report = time() + t0_all = time() + t0_batch = time() + for ex in self.pipeline.make_dataset(): + # Process batch of examples. + ex = process_batch(ex) + + # Track timing and progress. + elapsed_batch = time() - t0_batch + t0_batch = time() + n_batch = len(ex["frame_ind"]) + n_processed += n_batch + elapsed_all = time() - t0_all + + # Compute recent rate. + n_recent.append(n_batch) + elapsed_recent.append(elapsed_batch) + rate = sum(n_recent) / sum(elapsed_recent) + eta = (n_total - n_processed) / rate + + # Report. + elapsed_since_last_report = time() - last_report + if elapsed_since_last_report > self.report_period: + print( + json.dumps( + { + "n_processed": n_processed, + "n_total": n_total, + "elapsed": elapsed_all, + "rate": rate, + "eta": eta, + } + ), + flush=True, + ) + last_report = time() - # Run tracker as specified - if self.tracker: - frames = run_tracker(tracker=self.tracker, frames=frames) - self.tracker.final_pass(frames) + # Return results. + yield ex + else: + for ex in self.pipeline.make_dataset(): + yield process_batch(ex) + + def predict( + self, data: Union[Provider, sleap.Labels, sleap.Video], make_labels: bool = True + ) -> Union[List[Dict[str, np.ndarray]], sleap.Labels]: + """Run inference on a data source. + + Args: + data: A `sleap.pipelines.Provider`, `sleap.Labels` or `sleap.Video` to + run inference over. + make_labels: If `True` (the default), returns a `sleap.Labels` instance with + `sleap.PredictedInstance`s. If `False`, just return a list of + dictionaries containing the raw arrays returned by the inference model. + + Returns: + A `sleap.Labels` with `sleap.PredictedInstance`s if `make_labels` is `True`, + otherwise a list of dictionaries containing batches of numpy arrays with the + raw results. + """ + # Create provider if necessary. + if isinstance(data, np.ndarray): + data = sleap.Video(backend=sleap.io.video.NumpyVideo(data)) + if isinstance(data, sleap.Labels): + data = LabelsReader(data) + elif isinstance(data, sleap.Video): + data = VideoReader(data) + + # Initialize inference loop generator. + generator = self._predict_generator(data) - # Return frames (there are no "raw" predictions we could return) - return frames + if make_labels: + # Create SLEAP data structures while consuming results. + return sleap.Labels( + self._make_labeled_frames_from_generator(generator, data) + ) + else: + # Just return the raw results. + return list(generator) +# TODO: Rewrite this class. @attr.s(auto_attribs=True) class VisualPredictor(Predictor): """Predictor class for generating the visual output of model.""" @@ -248,6 +446,42 @@ def make_pipeline(self): self.pipeline = pipeline + def safely_generate(self, ds: tf.data.Dataset, progress: bool = True): + """Yields examples from dataset, catching and logging exceptions.""" + # Unsafe generating: + # for example in ds: + # yield example + + ds_iter = iter(ds) + + i = 0 + wall_t0 = time() + done = False + while not done: + try: + next_val = next(ds_iter) + yield next_val + except StopIteration: + done = True + except Exception as e: + logger.info(f"ERROR in sample index {i}") + logger.info(e) + logger.info("") + finally: + if not done: + i += 1 + + # Show the current progress (frames, time, fps) + if progress: + if (i and i % 1000 == 0) or done: + elapsed_time = time() - wall_t0 + logger.info( + f"Finished {i} examples in {elapsed_time:.2f} seconds " + "(inference + postprocessing)" + ) + if elapsed_time: + logger.info(f"examples/s = {i/elapsed_time}") + def predict_generator(self, data_provider: Provider): if self.pipeline is None: # Pass in data provider when mocking one of the models. @@ -256,7 +490,7 @@ def predict_generator(self, data_provider: Provider): self.pipeline.providers = [data_provider] # Yield each example from dataset, catching and logging exceptions - return safely_generate(self.pipeline.make_dataset()) + return self.safely_generate(self.pipeline.make_dataset()) def predict(self, data_provider: Provider): generator = self.predict_generator(data_provider) @@ -911,75 +1145,6 @@ def from_trained_models( obj._initialize_inference_model() return obj - def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline: - """Make a data loading pipeline. - - Args: - data_provider: If not `None`, the pipeline will be created with an instance - of a `sleap.pipelines.Provider`. - - Returns: - The created `sleap.pipelines.Pipeline` with batching and prefetching. - - Notes: - This method also updates the class attribute for the pipeline and will be - called automatically when predicting on data from a new source. - """ - pipeline = Pipeline() - if data_provider is not None: - pipeline.providers = [data_provider] - - pipeline += sleap.nn.data.pipelines.Batcher( - batch_size=self.batch_size, drop_remainder=False, unrag=False - ) - - pipeline += Prefetcher() - - self.pipeline = pipeline - - return pipeline - - def _predict_generator( - self, data_provider: Provider - ) -> Iterator[Dict[str, np.ndarray]]: - """Create a generator that yields batches of inference results. - - This method handles creating or updating the input `sleap.pipelines.Pipeline` - for loading the data, as well as looping over the batches and running inference. - - Args: - data_provider: The `sleap.pipelines.Provider` that contains data that should - be used for inference. - - Returns: - A generator yielding batches predicted results as dictionaries of numpy - arrays. - """ - # Initialize data pipeline and inference model if needed. - if self.pipeline is None: - self.make_pipeline() - if self.inference_model is None: - self._initialize_inference_model() - - # Update the data provider source. - self.pipeline.providers = [data_provider] - - # Loop over data batches. - for ex in self.pipeline.make_dataset(): - # Run inference on current batch. - preds = self.inference_model.predict(ex) - - ex["peaks"] = preds["peaks"] - ex["peak_vals"] = preds["peak_vals"] - - # Convert to numpy arrays if not already. - if isinstance(ex["video_ind"], tf.Tensor): - ex["video_ind"] = ex["video_ind"].numpy().flatten() - if isinstance(ex["frame_ind"], tf.Tensor): - ex["frame_ind"] = ex["frame_ind"].numpy().flatten() - - yield ex - def _make_labeled_frames_from_generator( self, generator: Iterator[Dict[str, np.ndarray]], data_provider: Provider ) -> List[sleap.LabeledFrame]: @@ -1030,43 +1195,6 @@ def _make_labeled_frames_from_generator( return predicted_frames - def predict( - self, - data: Union[Provider, sleap.Labels, sleap.Video], - make_labels: bool = True, - ) -> Union[List[Dict[str, np.ndarray]], sleap.Labels]: - """Run inference on a data source. - - Args: - data: A `sleap.pipelines.Provider`, `sleap.Labels` or `sleap.Video` to - run inference over. - make_labels: If `True` (the default), returns a `sleap.Labels` instance with - `sleap.PredictedInstance`s. If `False`, just return a list of - dictionaries containing the raw arrays returned by the inference model. - - Returns: - A `sleap.Labels` with `sleap.PredictedInstance`s if `make_labels` is `True`, - otherwise a list of dictionaries containing batches of numpy arrays with the - raw results. - """ - # Create provider if necessary. - if isinstance(data, sleap.Labels): - data = LabelsReader(data) - elif isinstance(data, sleap.Video): - data = VideoReader(data) - - # Initialize inference loop generator. - generator = self._predict_generator(data) - - if make_labels: - # Create SLEAP data structures while consuming results. - return sleap.Labels( - self._make_labeled_frames_from_generator(generator, data) - ) - else: - # Just return the raw results. - return list(generator) - class CentroidCrop(InferenceLayer): """Inference layer for applying centroid crop-based models. @@ -1586,7 +1714,7 @@ def call( @attr.s(auto_attribs=True) -class TopdownPredictor(Predictor): +class TopDownPredictor(Predictor): """Top-down multi-instance predictor. This high-level class handles initialization, preprocessing and tracking using a @@ -1692,7 +1820,7 @@ def from_trained_models( peak_threshold: float = 0.2, integral_refinement: bool = True, integral_patch_size: int = 5, - ) -> "TopdownPredictor": + ) -> "TopDownPredictor": """Create predictor from saved models. Args: @@ -1715,7 +1843,7 @@ def from_trained_models( integral refinement as an integer scalar. Returns: - An instance of `TopdownPredictor` with the loaded models. + An instance of `TopDownPredictor` with the loaded models. One of the two models can be left as `None` to perform inference with ground truth data. This will only work with `LabelsReader` as the provider. @@ -1798,69 +1926,13 @@ def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline: return pipeline - def _predict_generator( - self, data_provider: Provider - ) -> Iterator[Dict[str, np.ndarray]]: - """Create a generator that yields batches of inference results. + def _make_labeled_frames_from_generator( + self, generator: Iterator[Dict[str, np.ndarray]], data_provider: Provider + ) -> List[sleap.LabeledFrame]: + """Create labeled frames from a generator that yields inference results. - This method handles creating or updating the input `sleap.pipelines.Pipeline` - for loading the data, as well as looping over the batches and running inference. - - Args: - data_provider: The `sleap.pipelines.Provider` that contains data that should - be used for inference. - - Returns: - A generator yielding batches predicted results as dictionaries of numpy - arrays. - """ - # Initialize data pipeline and inference model if needed. - if self.pipeline is None: - if self.centroid_config is not None and self.confmap_config is not None: - self.make_pipeline() - else: - # Pass in data provider when mocking one of the models. - self.make_pipeline(data_provider=data_provider) - if self.inference_model is None: - self._initialize_inference_model() - - # Update the data provider source. - self.pipeline.providers = [data_provider] - - # Loop over data batches. - for ex in self.pipeline.make_dataset(): - # Run inference on current batch. - preds = self.inference_model.predict(ex) - - # Crop possibly variable length results. - ex["instance_peaks"] = [ - x[:n] for x, n in zip(preds["instance_peaks"], preds["n_valid"]) - ] - ex["instance_peak_vals"] = [ - x[:n] for x, n in zip(preds["instance_peak_vals"], preds["n_valid"]) - ] - ex["centroids"] = [ - x[:n] for x, n in zip(preds["centroids"], preds["n_valid"]) - ] - ex["centroid_vals"] = [ - x[:n] for x, n in zip(preds["centroid_vals"], preds["n_valid"]) - ] - - # Convert to numpy arrays if not already. - if isinstance(ex["video_ind"], tf.Tensor): - ex["video_ind"] = ex["video_ind"].numpy().flatten() - if isinstance(ex["frame_ind"], tf.Tensor): - ex["frame_ind"] = ex["frame_ind"].numpy().flatten() - - yield ex - - def _make_labeled_frames_from_generator( - self, generator: Iterator[Dict[str, np.ndarray]], data_provider: Provider - ) -> List[sleap.LabeledFrame]: - """Create labeled frames from a generator that yields inference results. - - This method converts pure arrays into SLEAP-specific data structures and runs - them through the tracker if it is specified. + This method converts pure arrays into SLEAP-specific data structures and runs + them through the tracker if it is specified. Args: generator: A generator that returns dictionaries with inference results. @@ -1885,6 +1957,20 @@ def _make_labeled_frames_from_generator( predicted_frames = [] for ex in generator: + if "n_valid" in ex: + ex["instance_peaks"] = [ + x[:n] for x, n in zip(ex["instance_peaks"], ex["n_valid"]) + ] + ex["instance_peak_vals"] = [ + x[:n] for x, n in zip(ex["instance_peak_vals"], ex["n_valid"]) + ] + ex["centroids"] = [ + x[:n] for x, n in zip(ex["centroids"], ex["n_valid"]) + ] + ex["centroid_vals"] = [ + x[:n] for x, n in zip(ex["centroid_vals"], ex["n_valid"]) + ] + # Loop over frames. for image, video_ind, frame_ind, points, confidences, scores in zip( ex["image"], @@ -1926,43 +2012,6 @@ def _make_labeled_frames_from_generator( return predicted_frames - def predict( - self, - data: Union[Provider, sleap.Labels, sleap.Video], - make_labels: bool = True, - ) -> Union[List[Dict[str, np.ndarray]], sleap.Labels]: - """Run inference and tracking on a data source. - - Args: - data: A `sleap.pipelines.Provider`, `sleap.Labels` or `sleap.Video` to - run inference over. - make_labels: If `True` (the default), returns a `sleap.Labels` instance with - `sleap.PredictedInstance`s. If `False`, just return a list of - dictionaries containing the raw arrays returned by the inference model. - - Returns: - A `sleap.Labels` with `sleap.PredictedInstance`s if `make_labels` is `True`, - otherwise a list of dictionaries containing batches of numpy arrays with the - raw results. - """ - # Create provider if necessary. - if isinstance(data, sleap.Labels): - data = LabelsReader(data) - elif isinstance(data, sleap.Video): - data = VideoReader(data) - - # Initialize inference loop generator. - generator = self._predict_generator(data) - - if make_labels: - # Create SLEAP data structures while consuming results. - return sleap.Labels( - self._make_labeled_frames_from_generator(generator, data) - ) - else: - # Just return the raw results. - return list(generator) - class BottomUpInferenceLayer(InferenceLayer): """Keras layer that predicts instances from images using a trained model. @@ -2263,7 +2312,7 @@ def call(self, example): @attr.s(auto_attribs=True) -class BottomupPredictor(Predictor): +class BottomUpPredictor(Predictor): """Bottom-up multi-instance predictor. This high-level class handles initialization, preprocessing and tracking using a @@ -2326,6 +2375,7 @@ def _initialize_inference_model(self): ), input_scale=self.bottomup_config.data.preprocessing.input_scaling, pad_to_stride=self.bottomup_model.maximum_stride, + peak_threshold=self.peak_threshold, refinement="integral" if self.integral_refinement else "local", integral_patch_size=self.integral_patch_size, ) @@ -2339,7 +2389,7 @@ def from_trained_models( peak_threshold: float = 0.2, integral_refinement: bool = True, integral_patch_size: int = 5, - ) -> "BottomupPredictor": + ) -> "BottomUpPredictor": """Create predictor from a saved model. Args: @@ -2359,7 +2409,7 @@ def from_trained_models( integral refinement as an integer scalar. Returns: - An instance of `BottomupPredictor` with the loaded model. + An instance of `BottomUpPredictor` with the loaded model. """ # Load bottomup model. bottomup_config = TrainingJobConfig.load_json(model_path) @@ -2379,83 +2429,6 @@ def from_trained_models( obj._initialize_inference_model() return obj - def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline: - """Make a data loading pipeline. - - Args: - data_provider: If not `None`, the pipeline will be created with an instance - of a `sleap.pipelines.Provider`. - - Returns: - The created `sleap.pipelines.Pipeline` with batching and prefetching. - - Notes: - This method also updates the class attribute for the pipeline and will be - called automatically when predicting on data from a new source. - """ - pipeline = Pipeline() - if data_provider is not None: - pipeline.providers = [data_provider] - - pipeline += sleap.nn.data.pipelines.Batcher( - batch_size=self.batch_size, drop_remainder=False, unrag=False - ) - - pipeline += Prefetcher() - - self.pipeline = pipeline - - return pipeline - - def _predict_generator( - self, data_provider: Provider - ) -> Iterator[Dict[str, np.ndarray]]: - """Create a generator that yields batches of inference results. - - This method handles creating or updating the input `sleap.pipelines.Pipeline` - for loading the data, as well as looping over the batches and running inference. - - Args: - data_provider: The `sleap.pipelines.Provider` that contains data that should - be used for inference. - - Returns: - A generator yielding batches predicted results as dictionaries of numpy - arrays. - """ - # Initialize data pipeline and inference model if needed. - if self.pipeline is None: - self.make_pipeline() - if self.inference_model is None: - self._initialize_inference_model() - - # Update the data provider source. - self.pipeline.providers = [data_provider] - - # Loop over data batches. - for ex in self.pipeline.make_dataset(): - # Run inference on current batch. - preds = self.inference_model.predict(ex) - - # Crop possibly variable length results. - ex["instance_peaks"] = [ - x[:n] for x, n in zip(preds["instance_peaks"], preds["n_valid"]) - ] - ex["instance_peak_vals"] = [ - x[:n] for x, n in zip(preds["instance_peak_vals"], preds["n_valid"]) - ] - ex["instance_scores"] = [ - x[:n] for x, n in zip(preds["instance_scores"], preds["n_valid"]) - ] - - # Convert to numpy arrays if not already. - if isinstance(ex["video_ind"], tf.Tensor): - ex["video_ind"] = ex["video_ind"].numpy().flatten() - if isinstance(ex["frame_ind"], tf.Tensor): - ex["frame_ind"] = ex["frame_ind"].numpy().flatten() - - yield ex - def _make_labeled_frames_from_generator( self, generator: Iterator[Dict[str, np.ndarray]], data_provider: Provider ) -> List[sleap.LabeledFrame]: @@ -2484,6 +2457,18 @@ def _make_labeled_frames_from_generator( predicted_frames = [] for ex in generator: + if "n_valid" in ex: + # Crop possibly variable length results. + ex["instance_peaks"] = [ + x[:n] for x, n in zip(ex["instance_peaks"], ex["n_valid"]) + ] + ex["instance_peak_vals"] = [ + x[:n] for x, n in zip(ex["instance_peak_vals"], ex["n_valid"]) + ] + ex["instance_scores"] = [ + x[:n] for x, n in zip(ex["instance_scores"], ex["n_valid"]) + ] + # Loop over frames. for image, video_ind, frame_ind, points, confidences, scores in zip( ex["image"], @@ -2525,116 +2510,210 @@ def _make_labeled_frames_from_generator( return predicted_frames - def predict( - self, - data: Union[Provider, sleap.Labels, sleap.Video], - make_labels: bool = True, - ) -> Union[List[Dict[str, np.ndarray]], sleap.Labels]: - """Run inference and tracking on a data source. - Args: - data: A `sleap.pipelines.Provider`, `sleap.Labels` or `sleap.Video` to - run inference over. - make_labels: If `True` (the default), returns a `sleap.Labels` instance with - `sleap.PredictedInstance`s. If `False`, just return a list of - dictionaries containing the raw arrays returned by the inference model. +def load_model( + model_path: Union[str, List[str]], + batch_size: int = 4, + peak_threshold: float = 0.2, + refinement: str = "integral", + tracker: Optional[str] = None, + tracker_window: int = 5, + tracker_max_instances: Optional[int] = None, + disable_gpu_preallocation: bool = True, + progress_reporting: str = "rich", +) -> Predictor: + """Load a trained SLEAP model. - Returns: - A `sleap.Labels` with `sleap.PredictedInstance`s if `make_labels` is `True`, - otherwise a list of dictionaries containing batches of numpy arrays with the - raw results. - """ - # Create provider if necessary. - if isinstance(data, sleap.Labels): - data = LabelsReader(data) - elif isinstance(data, sleap.Video): - data = VideoReader(data) + Args: + model_path: Path to model or list of path to models that were trained by SLEAP. + These should be the directories that contain `training_job.json` and + `best_model.h5`. + batch_size: Number of frames to predict at a time. Larger values result in + faster inference speeds, but require more memory. + peak_threshold: Minimum confidence map value to consider a peak as valid. + refinement: If `"integral"`, peak locations will be refined with integral + regression. If `"local"`, peaks will be refined with quarter pixel local + gradient offset. This has no effect if the model has an offset regression + head. + tracker: Name of the tracker to use with the inference model. Must be one of + `"simple"` or `"flow"`. If `None`, no identity tracking across frames will + be performed. + tracker_window: Number of frames of history to use when tracking. No effect when + `tracker` is `None`. + tracker_max_instances: If not `None`, discard instances beyond this count when + tracking. No effect when `tracker` is `None`. + disable_gpu_preallocation: If `True` (the default), initialize the GPU and + disable preallocation of memory. This is necessary to prevent freezing on + some systems with low GPU memory and has negligible impact on performance. + If `False`, no GPU initialization is performed. No effect if running in + CPU-only mode. + progress_reporting: Mode of inference progress reporting. If `"rich"` (the + default), an updating progress bar is displayed in the console or notebook. + If `"json"`, a JSON-serialized message is printed out which can be captured + for programmatic progress monitoring. If `"none"`, nothing is displayed + during inference -- this is recommended when running on clusters or headless + machines where the output is captured to a log file. - # Initialize inference loop generator. - generator = self._predict_generator(data) + Returns: + An instance of a `Predictor` based on which model type was detected. - if make_labels: - # Create SLEAP data structures while consuming results. - return sleap.Labels( - self._make_labeled_frames_from_generator(generator, data) - ) - else: - # Just return the raw results. - return list(generator) + If this is a top-down model, paths to the centroids model as well as the + centered instance model must be provided. A `TopDownPredictor` instance will be + returned. + + If this is a bottom-up model, a `BottomUpPredictor` will be returned. + + If this is a single-instance model, a `SingleInstancePredictor` will be + returned. + + If a `tracker` is specified, the predictor will also run identity tracking over + time. + + See also: TopDownPredictor, BottomUpPredictor, SingleInstancePredictor + """ + if isinstance(model_path, str): + model_paths = [model_path] + else: + model_paths = model_path + + # Uncompress ZIP packaged models. + tmp_dirs = [] + for i, model_path in enumerate(model_paths): + if model_path.endswith(".zip"): + # Create temp dir on demand. + tmp_dir = tempfile.TemporaryDirectory() + tmp_dirs.append(tmp_dir) + + # Remove the temp dir when program exits in case something goes wrong. + atexit.register(shutil.rmtree, tmp_dir.name, ignore_errors=True) + + # Extract and replace in the list. + shutil.unpack_archive(model_path, extract_dir=tmp_dir.name) + model_paths[i] = tmp_dir.name + + if disable_gpu_preallocation: + sleap.disable_preallocation() + + predictor = Predictor.from_model_paths( + model_paths, + peak_threshold=peak_threshold, + integral_refinement=refinement == "integral", + batch_size=batch_size, + ) + predictor.verbosity = progress_reporting + if tracker is not None: + predictor.tracker = Tracker.make_tracker_by_name( + tracker=tracker, + track_window=tracker_window, + post_connect_single_breaks=True, + clean_instance_count=tracker_max_instances, + ) + # Remove temp dirs. + for tmp_dir in tmp_dirs: + tmp_dir.cleanup() -CLI_PREDICTORS = { - "topdown": TopdownPredictor, - "bottomup": BottomupPredictor, - "single": SingleInstancePredictor, -} + return predictor -def make_cli_parser(): - import argparse - from sleap.util import frame_list +def _make_cli_parser() -> argparse.ArgumentParser: + """Create argument parser for CLI. + Returns: + The `argparse.ArgumentParser` that defines the CLI options. + """ parser = argparse.ArgumentParser() - # Add args for entire pipeline parser.add_argument( - "video_path", type=str, nargs="?", default="", help="Path to video file" + "data_path", + type=str, + nargs="?", + default="", + help=( + "Path to data to predict on. This can be a labels (.slp) file or any " + "supported video format." + ), ) parser.add_argument( "-m", "--model", dest="models", action="append", - help="Path to trained model directory (with training_config.json). " - "Multiple models can be specified, each preceded by --model.", + help=( + "Path to trained model directory (with training_config.json). " + "Multiple models can be specified, each preceded by --model." + ), ) - parser.add_argument( "--frames", - type=frame_list, + type=sleap.util.frame_list, default="", - help="List of frames to predict. Either comma separated list (e.g. 1,2,3) or " - "a range separated by hyphen (e.g. 1-3, for 1,2,3). (default is entire video)", + help=( + "List of frames to predict when running on a video. Can be specified as a " + "comma separated list (e.g. 1,2,3) or a range separated by hyphen (e.g., " + "1-3, for 1,2,3). If not provided, defaults to predicting on the entire " + "video." + ), ) parser.add_argument( "--only-labeled-frames", action="store_true", default=False, - help="Only run inference on labeled frames (when running on labels dataset file).", + help=( + "Only run inference on user labeled frames when running on labels dataset. " + "This is useful for generating predictions to compare against ground truth." + ), ) parser.add_argument( "--only-suggested-frames", action="store_true", default=False, - help="Only run inference on suggested frames (when running on labels dataset file).", + help=( + "Only run inference on unlabeled suggested frames when running on labels " + "dataset. This is useful for generating predictions for initialization " + "during labeling." + ), ) parser.add_argument( "-o", "--output", type=str, default=None, - help="The output filename to use for the predicted data.", + help=( + "The output filename to use for the predicted data. If not provided, " + "defaults to '[data_path].predictions.slp'." + ), ) parser.add_argument( - "--labels", + "--no-empty-frames", + action="store_true", + default=False, + help=( + "Clear any empty frames that did not have any detected instances before " + "saving to output." + ), + ) + parser.add_argument( + "--verbosity", type=str, - default=None, - help="Path to labels dataset file (for inference on multiple videos or for re-tracking pre-existing predictions).", + choices=["none", "rich", "json"], + default="rich", + help=( + "Verbosity of inference progress reporting. 'none' does not output " + "anything during inference, 'rich' displays an updating progress bar, " + "and 'json' outputs the progress as a JSON encoded response to the " + "console." + ), ) - - # TODO: better video parameters - parser.add_argument( - "--video.dataset", type=str, default="", help="The dataset for HDF5 videos." + "--video.dataset", type=str, default=None, help="The dataset for HDF5 videos." ) - parser.add_argument( "--video.input_format", type=str, - default="", + default="channels_last", help="The input_format for HDF5 videos.", ) - device_group = parser.add_mutually_exclusive_group(required=False) device_group.add_argument( "--cpu", @@ -2655,257 +2734,196 @@ def make_cli_parser(): "--gpu", type=int, default=0, help="Run inference on the i-th GPU specified." ) - # Add args for each predictor class - for predictor_name, predictor_class in CLI_PREDICTORS.items(): - if "peak_threshold" in attr.fields_dict(predictor_class): - # get the default value to show in help string, although we'll - # use None as default so that unspecified vals won't be passed to - # builder. - default_val = attr.fields_dict(predictor_class)["peak_threshold"].default - - parser.add_argument( - f"--{predictor_name}.peak_threshold", - type=float, - default=None, - help=f"Threshold to use when finding peaks in {predictor_class.__name__} (default: {default_val}).", - ) - - if "batch_size" in attr.fields_dict(predictor_class): - default_val = attr.fields_dict(predictor_class)["batch_size"].default - parser.add_argument( - f"--{predictor_name}.batch_size", - type=int, - default=4, - help=f"Batch size to use for model inference in {predictor_class.__name__} (default: {default_val}).", - ) - - # Add args for tracking - Tracker.add_cli_parser_args(parser, arg_scope="tracking") + parser.add_argument( + "--peak_threshold", + type=float, + default=0.2, + help="Minimum confidence map value to consider a peak as valid.", + ) + parser.add_argument( + "--batch_size", + type=int, + default=4, + help=( + "Number of frames to predict at a time. Larger values result in faster " + "inference speeds, but require more memory." + ), + ) + # Deprecated legacy args. These will still be parsed for backward compatibility but + # are hidden from the CLI help. parser.add_argument( - "--test-pipeline", - default=False, - action="store_true", - help="Test pipeline construction without running anything.", + "--labels", + type=str, + default=argparse.SUPPRESS, + help=argparse.SUPPRESS, + ) + parser.add_argument( + "--single.peak_threshold", + type=float, + default=argparse.SUPPRESS, + help=argparse.SUPPRESS, + ) + parser.add_argument( + "--topdown.peak_threshold", + type=float, + default=argparse.SUPPRESS, + help=argparse.SUPPRESS, ) + parser.add_argument( + "--bottomup.peak_threshold", + type=float, + default=argparse.SUPPRESS, + help=argparse.SUPPRESS, + ) + parser.add_argument( + "--single.batch_size", + type=float, + default=argparse.SUPPRESS, + help=argparse.SUPPRESS, + ) + parser.add_argument( + "--topdown.batch_size", + type=float, + default=argparse.SUPPRESS, + help=argparse.SUPPRESS, + ) + parser.add_argument( + "--bottomup.batch_size", + type=float, + default=argparse.SUPPRESS, + help=argparse.SUPPRESS, + ) + + # Add tracker args. + Tracker.add_cli_parser_args(parser, arg_scope="tracking") return parser -def make_video_readers_from_cli(args) -> List[VideoReader]: - if args.video_path: - # TODO: better support for video params - video_kwargs = dict( - dataset=vars(args).get("video.dataset"), - input_format=vars(args).get("video.input_format"), - ) +def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]: + """Make data provider from parsed CLI args. - video_reader = VideoReader.from_filepath( - filename=args.video_path, example_indices=args.frames, **video_kwargs - ) + Args: + args: Parsed CLI namespace. - return [video_reader] + Returns: + A tuple of `(provider, data_path)` with the data `Provider` and path to the data + that was specified in the args. + """ + # Figure out which input path to use. + labels_path = getattr(args, "labels", None) + if labels_path is not None: + data_path = labels_path + else: + data_path = args.data_path - if args.labels: - # TODO: Replace with LabelsReader. - labels = sleap.Labels.load_file(args.labels) + if data_path is None: + raise ValueError("You must specify a path to a video or a labels dataset.") - readers = [] + if data_path.endswith(".slp"): + labels = sleap.Labels.load_file(data_path) if args.only_labeled_frames: - user_labeled_frames = labels.user_labeled_frames + provider = LabelsReader.from_user_labeled_frames(labels) + elif args.only_suggested_frames: + provider = LabelsReader.from_unlabeled_suggestions(labels) else: - user_labeled_frames = [] - - for video in labels.videos: - if args.only_labeled_frames: - frame_indices = [ - lf.frame_idx for lf in user_labeled_frames if lf.video == video - ] - readers.append(VideoReader(video=video, example_indices=frame_indices)) - elif args.only_suggested_frames: - readers.append( - VideoReader( - video=video, example_indices=labels.get_video_suggestions(video) - ) - ) - else: - readers.append(VideoReader(video=video)) - - return readers - - raise ValueError("You must specify either video_path or labels dataset path.") - - -def make_predictor_from_paths(paths, **kwargs) -> Predictor: - """Build predictor object from a list of model paths.""" - return make_predictor_from_models(find_heads_for_model_paths(paths), **kwargs) - - -def find_heads_for_model_paths(paths) -> Dict[str, str]: - """Given list of models paths, returns dict with path keyed by head name.""" - trained_model_paths = dict() - - if paths is None: - return trained_model_paths - - for model_path in paths: - # Load the model config - cfg = TrainingJobConfig.load_json(model_path) - - # Get the head from the model (i.e., what the model will predict) - key = cfg.model.heads.which_oneof_attrib_name() - - # If path is to config file json, then get the path to parent dir - if model_path.endswith(".json"): - model_path = os.path.dirname(model_path) - - trained_model_paths[key] = model_path + provider = LabelsReader(labels) - return trained_model_paths - - -def make_predictor_from_models( - trained_model_paths: Dict[str, str], - labels_path: Optional[str] = None, - policy_args: Optional[dict] = None, - **kwargs, -) -> Predictor: - """Given dict of paths keyed by head name, returns appropriate predictor.""" - - def get_relevant_args(key): - if policy_args is not None and key in policy_args: - return policy_args[key] - return dict() - - if "multi_instance" in trained_model_paths: - predictor = BottomupPredictor.from_trained_models( - trained_model_paths["multi_instance"], - **get_relevant_args("bottomup"), - **kwargs, - ) - elif "single_instance" in trained_model_paths: - predictor = SingleInstancePredictor.from_trained_models( - trained_model_paths["single_instance"], - **get_relevant_args("single"), - **kwargs, - ) - elif ( - "centroid" in trained_model_paths and "centered_instance" in trained_model_paths - ): - predictor = TopdownPredictor.from_trained_models( - centroid_model_path=trained_model_paths["centroid"], - confmap_model_path=trained_model_paths["centered_instance"], - **get_relevant_args("topdown"), - **kwargs, - ) - elif len(trained_model_paths) == 0 and labels_path: - predictor = MockPredictor.from_trained_models(labels_path=labels_path) else: - raise ValueError( - f"Unable to run inference with {list(trained_model_paths.keys())} heads." + # TODO: Clean this up. + video_kwargs = dict( + dataset=vars(args).get("video.dataset"), + input_format=vars(args).get("video.input_format"), + ) + provider = VideoReader.from_filepath( + filename=data_path, example_indices=args.frames, **video_kwargs ) - return predictor - - -def make_tracker_from_cli(policy_args): - if "tracking" in policy_args: - tracker = Tracker.make_tracker_by_name(**policy_args["tracking"]) - return tracker - - return None - - -def save_predictions_from_cli(args, predicted_frames, prediction_metadata=None): - from sleap import Labels - - if args.output: - output_path = args.output - elif args.video_path: - out_dir = os.path.dirname(args.video_path) - out_name = os.path.basename(args.video_path) + ".predictions.slp" - output_path = os.path.join(out_dir, out_name) - elif args.labels: - out_dir = os.path.dirname(args.labels) - out_name = os.path.basename(args.labels) + ".predictions.slp" - output_path = os.path.join(out_dir, out_name) - else: - # We shouldn't ever get here but if we do, just save in working dir. - output_path = "predictions.slp" - - labels = Labels(labeled_frames=predicted_frames, provenance=prediction_metadata) - - print(f"Saving: {output_path}") - Labels.save_file(labels, output_path) + return provider, data_path -def load_model( - model_path: Union[str, List[str]], - batch_size: int = 4, - refinement: str = "integral", - tracker: Optional[str] = None, - tracker_window: int = 5, - tracker_max_instances: Optional[int] = None, -) -> Predictor: - """Load a trained SLEAP model. +def _make_predictor_from_cli(args: argparse.Namespace) -> Predictor: + """Make predictor from parsed CLI args. Args: - model_path: Path to model or list of path to models that were trained by SLEAP. - These should be the directories that contain `training_job.json` and - `best_model.h5`. - batch_size: Number of frames to predict at a time. Larger values result in - faster inference speeds, but require more memory. - refinement: If `"integral"`, peak locations will be refined with integral - regression. If `"local"`, peaks will be refined with quarter pixel local - gradient offset. This has no effect if the model has an offset regression - head. - tracker: Name of the tracker to use with the inference model. Must be one of - `"simple"` or `"flow"`. If `None`, no identity tracking across frames will - be performed. - tracker_window: Number of frames of history to use when tracking. No effect when - `tracker` is `None`. - tracker_max_instances: If not `None`, discard instances beyond this count when - tracking. No effect when `tracker` is `None`. + args: Parsed CLI namespace. Returns: - An instance of a `Predictor` based on which model type was detected. - - If this is a top-down model, paths to the centroids model as well as the - centered instance model must be provided. A `TopdownPredictor` instance will be - returned. + The `Predictor` created from loaded models. + """ + peak_threshold = None + for deprecated_arg in [ + "single.peak_threshold", + "topdown.peak_threshold", + "bottomup.peak_threshold", + ]: + val = getattr(args, deprecated_arg, None) + if val is not None: + peak_threshold = val + if peak_threshold is None: + peak_threshold = args.peak_threshold + + batch_size = None + for deprecated_arg in [ + "single.batch_size", + "topdown.batch_size", + "bottomup.batch_size", + ]: + val = getattr(args, deprecated_arg, None) + if val is not None: + batch_size = val + if batch_size is None: + batch_size = args.batch_size + + predictor = Predictor.from_model_paths( + args.models, + peak_threshold=peak_threshold, + integral_refinement=True, + batch_size=batch_size, + ) + predictor.verbosity = args.verbosity + return predictor - If this is a bottom-up model, a `BottomupPredictor` will be returned. - If this is a single-instance model, a `SingleInstancePredictor` will be - returned. +def _make_tracker_from_cli(args: argparse.Namespace) -> Optional[Tracker]: + """Make tracker from parsed CLI arguments. - If a `tracker` is specified, the predictor will also run identity tracking over - time. + Args: + args: Parsed CLI namespace. - See also: TopdownPredictor, BottomupPredictor, SingleInstancePredictor + Returns: + An instance of `Tracker` or `None` if tracking method was not specified. """ - if isinstance(model_path, str): - model_path = [model_path] - predictor = make_predictor_from_paths( - model_path, batch_size=batch_size, integral_refinement=refinement == "integral" - ) - if tracker is not None: - predictor.tracker = Tracker.make_tracker_by_name( - tracker=tracker, - track_window=tracker_window, - post_connect_single_breaks=True, - clean_instance_count=tracker_max_instances, - ) - return predictor + policy_args = sleap.util.make_scoped_dictionary(vars(args), exclude_nones=True) + if "tracking" in policy_args: + tracker = Tracker.make_tracker_by_name(**policy_args["tracking"]) + return tracker + return None def main(): - """CLI for running inference.""" - parser = make_cli_parser() + """Entrypoint for `sleap-track` CLI for running inference.""" + t0 = time() + start_timestamp = str(datetime.now()) + print("Started inference at:", start_timestamp) + + # Setup CLI. + parser = _make_cli_parser() + + # Parse inputs. args, _ = parser.parse_known_args() - print(args) + args_msg = ["Args:"] + for name, val in vars(args).items(): + if name == "frames" and val is not None: + args_msg.append(f" frames: {min(val)}-{max(val)} ({len(val)})") + else: + args_msg.append(f" {name}: {val}") + print("\n".join(args_msg)) + + # Setup devices. if args.cpu or not sleap.nn.system.is_gpu_system(): sleap.nn.system.use_cpu_only() else: @@ -2915,69 +2933,57 @@ def main(): sleap.nn.system.use_last_gpu() else: sleap.nn.system.use_gpu(args.gpu) - sleap.nn.system.disable_preallocation() + sleap.disable_preallocation() + + print("Versions:") + sleap.versions() + print() print("System:") sleap.nn.system.summary() + print() - video_readers = make_video_readers_from_cli(args) + # Setup data loader. + provider, data_path = _make_provider_from_cli(args) - # Find the specified models - model_paths_by_head = find_heads_for_model_paths(args.models) - - # Make a scoped dictionary with args specified from cli - policy_args = util.make_scoped_dictionary(vars(args), exclude_nones=True) - - # Create appropriate predictor given these models - predictor = make_predictor_from_models( - model_paths_by_head, labels_path=args.labels, policy_args=policy_args - ) + # Setup models. + predictor = _make_predictor_from_cli(args) - # Make the tracker - tracker = make_tracker_from_cli(policy_args) + # Setup tracker. + tracker = _make_tracker_from_cli(args) predictor.tracker = tracker - if args.test_pipeline: - print() - - print(policy_args) - print() - - print(predictor) - print() - - predictor.make_pipeline() - print("===pipeline transformers===") - print() - for transformer in predictor.pipeline.transformers: - print(transformer.__class__.__name__) - print(f"\t-> {transformer.input_keys}") - print(f"\t {transformer.output_keys} ->") - print() - - print("--test-pipeline arg set so stopping here.") - return - # Run inference! - t0 = time.time() - predicted_frames = [] - - for video_reader in video_readers: - video_predicted_frames = predictor.predict(video_reader).labeled_frames - predicted_frames.extend(video_predicted_frames) - - # Create dictionary of metadata we want to save with predictions - prediction_metadata = dict() - for head, path in model_paths_by_head.items(): - prediction_metadata[f"model.{head}.path"] = os.path.abspath(path) - for scope in policy_args.keys(): - for key, val in policy_args[scope].items(): - prediction_metadata[f"{scope}.{key}"] = val - prediction_metadata["video.path"] = args.video_path - prediction_metadata["sleap.version"] = sleap.__version__ - - save_predictions_from_cli(args, predicted_frames, prediction_metadata) - print(f"Total Time: {time.time() - t0}") + labels_pr = predictor.predict(provider) + + if args.no_empty_frames: + # Clear empty frames if specified. + labels_pr.remove_empty_frames() + + finish_timestamp = str(datetime.now()) + total_elapsed = time() - t0 + print("Finished inference at:", finish_timestamp) + print(f"Total runtime: {total_elapsed} secs") + print(f"Predicted frames: {len(labels_pr)}/{len(provider)}") + + output_path = args.output + if output_path is None: + output_path = data_path + ".predictions.slp" + + # Add provenance metadata to predictions. + labels_pr.provenance["sleap_version"] = sleap.__version__ + labels_pr.provenance["platform"] = platform.platform() + labels_pr.provenance["data_path"] = data_path + labels_pr.provenance["model_paths"] = predictor.model_paths + labels_pr.provenance["output_path"] = output_path + labels_pr.provenance["predictor"] = type(predictor).__name__ + labels_pr.provenance["total_elapsed"] = total_elapsed + labels_pr.provenance["start_timestamp"] = start_timestamp + labels_pr.provenance["finish_timestamp"] = finish_timestamp + + # Save results. + labels_pr.save(output_path) + print("Saved output:", output_path) if __name__ == "__main__": diff --git a/sleap/nn/monitor.py b/sleap/nn/monitor.py index 40fae2a0e..0e60469df 100644 --- a/sleap/nn/monitor.py +++ b/sleap/nn/monitor.py @@ -28,6 +28,8 @@ def __init__( self.show_controller = show_controller self.stop_button = None + self.cancel_button = None + self.canceled = False self.redraw_batch_interval = 40 self.batches_to_show = -1 # -1 to show all @@ -162,9 +164,12 @@ def reset(self, what=""): control_layout.addStretch(1) - self.stop_button = QtWidgets.QPushButton("Stop Training") + self.stop_button = QtWidgets.QPushButton("Stop Early") self.stop_button.clicked.connect(self.stop) control_layout.addWidget(self.stop_button) + self.cancel_button = QtWidgets.QPushButton("Cancel Training") + self.cancel_button.clicked.connect(self.cancel) + control_layout.addWidget(self.cancel_button) widget = QtWidgets.QWidget() widget.setLayout(control_layout) @@ -242,12 +247,19 @@ def setup_zmq(self, zmq_context: Optional[zmq.Context]): self.timer.timeout.connect(self.check_messages) self.timer.start(20) + def cancel(self): + """Set the cancel flag.""" + self.canceled = True + if self.cancel_button is not None: + self.cancel_button.setText("Canceling...") + self.cancel_button.setEnabled(False) + def stop(self): """Action to stop training.""" if self.zmq_ctrl is not None: # send command to stop training - logger.info("Sending command to stop training") + logger.info("Sending command to stop training.") self.zmq_ctrl.send_string(jsonpickle.encode(dict(command="stop"))) # Disable the button diff --git a/sleap/nn/training.py b/sleap/nn/training.py index 12b7fd8bd..33c7f663f 100644 --- a/sleap/nn/training.py +++ b/sleap/nn/training.py @@ -5,6 +5,7 @@ from datetime import datetime from time import time import logging +import shutil import tensorflow as tf import numpy as np @@ -43,7 +44,7 @@ BottomUpPipeline, KeyMapper, ) -from sleap.nn.data.training import split_labels +from sleap.nn.data.training import split_labels_train_val # Optimization from sleap.nn.config import OptimizationConfig @@ -116,6 +117,11 @@ def from_config( if test is None: test = labels_config.test_labels + if video_search_paths is None: + video_search_paths = [] + if labels_config.search_path_hints is not None: + video_search_paths.extend(labels_config.search_path_hints) + # Update the config fields with arguments (if not a full sleap.Labels instance). if update_config: if isinstance(training, Text): @@ -126,14 +132,16 @@ def from_config( labels_config.validation_fraction = validation if isinstance(test, Text): labels_config.test_labels = test + labels_config.search_path_hints = video_search_paths # Build class. - # TODO: use labels_config.search_path_hints for loading return cls.from_labels( training=training, validation=validation, test=test, video_search_paths=video_search_paths, + labels_config=labels_config, + update_config=update_config, ) @classmethod @@ -143,22 +151,72 @@ def from_labels( validation: Union[Text, sleap.Labels, float], test: Optional[Union[Text, sleap.Labels]] = None, video_search_paths: Optional[List[Text]] = None, + labels_config: Optional[LabelsConfig] = None, + update_config: bool = False, ) -> "DataReaders": """Create data readers from sleap.Labels datasets as data providers.""" - if isinstance(training, str): - print("video search paths: ", video_search_paths) + logger.info(f"Loading training labels from: {training}") training = sleap.Labels.load_file(training, video_search=video_search_paths) - print(training.videos) + + if labels_config is not None and labels_config.split_by_inds: + # First try to split by indices if specified in config. + if ( + labels_config.validation_inds is not None + and len(labels_config.validation_inds) > 0 + ): + logger.info( + "Creating validation split from explicit indices " + f"(n = {len(labels_config.validation_inds)})." + ) + validation = training[labels_config.validation_inds] + + if labels_config.test_inds is not None and len(labels_config.test_inds) > 0: + logger.info( + "Creating test split from explicit indices " + f"(n = {len(labels_config.test_inds)})." + ) + test = training[labels_config.test_inds] + + if ( + labels_config.training_inds is not None + and len(labels_config.training_inds) > 0 + ): + logger.info( + "Creating training split from explicit indices " + f"(n = {len(labels_config.training_inds)})." + ) + training = training[labels_config.training_inds] if isinstance(validation, str): + # If validation is still a path, load it. + logger.info(f"Loading validation labels from: {validation}") validation = sleap.Labels.load_file( validation, video_search=video_search_paths ) elif isinstance(validation, float): - training, validation = split_labels(training, [-1, validation]) + logger.info( + "Creating training and validation splits from " + f"validation fraction: {validation}" + ) + # If validation is still a float, create the split from training. + ( + training, + training_inds, + validation, + validation_inds, + ) = split_labels_train_val(training.with_user_labels_only(), validation) + logger.info( + f" Splits: Training = {len(training_inds)} /" + f" Validation = {len(validation_inds)}." + ) + if update_config and labels_config is not None: + labels_config.training_inds = training_inds + labels_config.validation_inds = validation_inds if isinstance(test, str): + # If test is still a path, load it. + logger.info(f"Loading test labels from: {test}") test = sleap.Labels.load_file(test, video_search=video_search_paths) test_reader = None @@ -207,7 +265,9 @@ def setup_optimizer(config: OptimizationConfig) -> tf.keras.optimizers.Optimizer return optimizer -def setup_losses(config: OptimizationConfig) -> Callable[[tf.Tensor], tf.Tensor]: +def setup_losses( + config: OptimizationConfig, +) -> Callable[[tf.Tensor, tf.Tensor], tf.Tensor]: """Set up model loss function from config.""" losses = [tf.keras.losses.MeanSquaredError()] @@ -554,6 +614,9 @@ def from_config( # Copy input config before we make any changes. initial_config = copy.deepcopy(config) + # Store SLEAP version on the training process. + config.sleap_version = sleap.__version__ + # Create data readers and store loaded skeleton. data_readers = DataReaders.from_config( config.data.labels, @@ -642,7 +705,10 @@ def _setup_model(self): logger.info(f" Parameters: {self.model.keras_model.count_params():3,d}") logger.info(" Heads: ") for i, head in enumerate(self.model.heads): - logger.info(f" heads[{i}] = {head}") + logger.info(f" [{i}] = {head}") + logger.info(" Outputs: ") + for i, output in enumerate(self.model.keras_model.outputs): + logger.info(f" [{i}] = {output}") @property def keras_model(self) -> tf.keras.Model: @@ -674,14 +740,18 @@ def _setup_pipelines(self): ) + key_mapper ) - logger.info(f"Training set: n = {len(self.data_readers.training_labels)}") + logger.info( + f"Training set: n = {len(self.data_readers.training_labels_reader)}" + ) self.validation_pipeline = ( self.pipeline_builder.make_training_pipeline( self.data_readers.validation_labels_reader ) + key_mapper ) - logger.info(f"Validation set: n = {len(self.data_readers.validation_labels)}") + logger.info( + f"Validation set: n = {len(self.data_readers.validation_labels_reader)}" + ) def _setup_optimization(self): """Set up optimizer, loss functions and compile the model.""" @@ -718,9 +788,14 @@ def _setup_optimization(self): def _setup_outputs(self): """Set up output-related functionality.""" if self.config.outputs.save_outputs: - # Build path to run folder. + # Build path to run folder. Timestamp will be added automatically. + # Example: 210204_041707.centroid.n=300 + model_type = self.config.model.heads.which_oneof_attrib_name() + n = len(self.data_readers.training_labels_reader) + len( + self.data_readers.validation_labels_reader + ) self.run_path = setup_new_run_folder( - self.config.outputs, base_run_name=type(self.model.backbone).__name__ + self.config.outputs, base_run_name=f"{model_type}.n={n}" ) # Setup output callbacks. @@ -818,30 +893,60 @@ def train(self): ) logger.info(f"Finished training loop. [{(time() - t0) / 60:.1f} min]") - # Save predictions and evaluations. + # Run post-training actions. if self.config.outputs.save_outputs: + if ( + self.config.outputs.save_visualizations + and self.config.outputs.delete_viz_images + ): + self.cleanup() + + self.evaluate() + + if self.config.outputs.zip_outputs: + self.package() + + def evaluate(self): + """Compute evaluation metrics on data splits and save them.""" + logger.info("Saving evaluation metrics to model folder...") + sleap.nn.evals.evaluate_model( + cfg=self.config, + labels_reader=self.data_readers.training_labels_reader, + model=self.model, + save=True, + split_name="train", + ) + sleap.nn.evals.evaluate_model( + cfg=self.config, + labels_reader=self.data_readers.validation_labels_reader, + model=self.model, + save=True, + split_name="val", + ) + if self.data_readers.test_labels_reader is not None: sleap.nn.evals.evaluate_model( cfg=self.config, - labels_reader=self.data_readers.training_labels_reader, - model=self.model, - save=True, - split_name="train", - ) - sleap.nn.evals.evaluate_model( - cfg=self.config, - labels_reader=self.data_readers.validation_labels_reader, + labels_reader=self.data_readers.test_labels_reader, model=self.model, save=True, - split_name="val", + split_name="test", ) - if self.data_readers.test_labels_reader is not None: - sleap.nn.evals.evaluate_model( - cfg=self.config, - labels_reader=self.data_readers.test_labels_reader, - model=self.model, - save=True, - split_name="test", - ) + + def cleanup(self): + """Delete visualization images subdirectory.""" + viz_path = os.path.join(self.run_path, "viz") + if os.path.exists(viz_path): + logger.info(f"Deleting visualization directory: {viz_path}") + shutil.rmtree(viz_path) + + def package(self): + """Package model folder into a zip file for portability.""" + if self.config.outputs.delete_viz_images: + self.cleanup() + logger.info(f"Packaging results to: {self.run_path}.zip") + shutil.make_archive( + base_name=self.run_path, root_dir=self.run_path, format="zip" + ) @attr.s(auto_attribs=True) @@ -1332,35 +1437,63 @@ def main(): parser.add_argument( "training_job_path", help="Path to training job profile JSON file." ) - parser.add_argument("labels_path", help="Path to labels file to use for training.") + parser.add_argument( + "labels_path", + nargs="?", + default="", + help=( + "Path to labels file to use for training. If specified, overrides the path " + "specified in the training job config." + ), + ) parser.add_argument( "--video-paths", type=str, default="", - help="List of paths for finding videos in case paths inside labels file need fixing.", + help=( + "List of paths for finding videos in case paths inside labels file are " + "not accessible." + ), ) parser.add_argument( "--val_labels", "--val", - help="Path to labels file to use for validation (overrides training job path if set).", + help=( + "Path to labels file to use for validation. If specified, overrides the " + "path specified in the training job config." + ), ) parser.add_argument( "--test_labels", "--test", - help="Path to labels file to use for test (overrides training job path if set).", + help=( + "Path to labels file to use for test. If specified, overrides the path " + "specified in the training job config." + ), ) parser.add_argument( "--tensorboard", action="store_true", - help="Enables TensorBoard logging to the run path.", + help=( + "Enable TensorBoard logging to the run path if not already specified in " + "the training job config." + ), ) parser.add_argument( "--save_viz", action="store_true", - help="Enables saving of prediction visualizations to the run folder.", + help=( + "Enable saving of prediction visualizations to the run folder if not " + "already specified in the training job config." + ), ) parser.add_argument( - "--zmq", action="store_true", help="Enables ZMQ logging (for GUI)." + "--zmq", + action="store_true", + help=( + "Enable ZMQ logging (for GUI) if not already specified in the training " + "job config." + ), ) parser.add_argument( "--run_name", @@ -1387,16 +1520,21 @@ def main(): # Override config settings for CLI-based training. job_config.outputs.save_outputs = True - job_config.outputs.tensorboard.write_logs = args.tensorboard - job_config.outputs.zmq.publish_updates = args.zmq - job_config.outputs.zmq.subscribe_to_controller = args.zmq + job_config.outputs.tensorboard.write_logs |= args.tensorboard + job_config.outputs.zmq.publish_updates |= args.zmq + job_config.outputs.zmq.subscribe_to_controller |= args.zmq if args.run_name != "": job_config.outputs.run_name = args.run_name if args.prefix != "": job_config.outputs.run_name_prefix = args.prefix if args.suffix != "": job_config.outputs.run_name_suffix = args.suffix - job_config.outputs.save_visualizations = args.save_viz + job_config.outputs.save_visualizations |= args.save_viz + if args.labels_path == "": + args.labels_path = None + + logger.info("Versions:") + sleap.versions() logger.info(f"Training labels file: {args.labels_path}") logger.info(f"Training profile: {job_filename}") diff --git a/sleap/prefs.py b/sleap/prefs.py index d39c86a16..5269ccd5d 100644 --- a/sleap/prefs.py +++ b/sleap/prefs.py @@ -12,18 +12,18 @@ class Preferences(object): _prefs = None _defaults = { - "medium step size": 4, + "medium step size": 10, "large step size": 100, "color predicted": False, + "propagate track labels": True, "palette": "standard", "bold lines": False, "trail length": 0, "trail width": 4.0, "trail node count": 1, - "hide videos dock": False, - "hide skeleton dock": False, - "hide instances dock": False, - "hide labeling suggestions dock": False, + "marker size": 4, + "edge style": "Line", + "window state": b"", } _filename = "preferences.yaml" @@ -48,6 +48,11 @@ def save(self): """Save preferences to file.""" util.save_config_yaml(self._filename, self._prefs) + def reset_to_default(self): + """Reset preferences to default.""" + util.save_config_yaml(self._filename, self._defaults) + self.load() + def _validate_key(self, key): if key not in self._defaults: raise KeyError(f"No preference matching '{key}'") diff --git a/sleap/skeleton.py b/sleap/skeleton.py index a8eb9476e..1b9a1d740 100644 --- a/sleap/skeleton.py +++ b/sleap/skeleton.py @@ -23,12 +23,14 @@ from networkx.readwrite import json_graph from scipy.io import loadmat + NodeRef = Union[str, "Node"] H5FileRef = Union[str, h5py.File] class EdgeType(Enum): - """ + """Type of edge in the skeleton graph. + The skeleton graph can store different types of edges to represent different things. All edges must specify one or more of the following types: @@ -45,9 +47,13 @@ class EdgeType(Enum): @attr.s(auto_attribs=True, slots=True, eq=False, order=False) class Node: - """ - The class :class:`Node` represents a potential skeleton node. - (But note that nodes can exist without being part of a skeleton.) + """This class represents node in the skeleton graph, i.e., a body part. + + Note: Nodes can exist without being part of a skeleton. + + Attributes: + name: String name of the node. + weight: Weight of the node (not currently used). """ name: str @@ -93,8 +99,7 @@ class Skeleton: _skeleton_idx = count(0) def __init__(self, name: str = None): - """ - Initialize an empty skeleton object. + """Initialize an empty skeleton object. Skeleton objects, once created, can be modified by adding nodes and edges. @@ -102,7 +107,6 @@ def __init__(self, name: str = None): Args: name: A name for this skeleton. """ - # If no skeleton was create, try to create a unique name for this Skeleton. if name is None or not isinstance(name, str) or not name: name = "Skeleton-" + str(next(self._skeleton_idx)) @@ -117,23 +121,33 @@ def __repr__(self) -> str: """Return full description of the skeleton.""" return ( f"Skeleton(name='{self.name}', " - f"nodes={self.node_names}, edges={self.edge_names})" + f"nodes={self.node_names}, " + f"edges={self.edge_names}, " + f"symmetries={self.symmetry_names}" + ")" ) def __str__(self) -> str: """Return short readable description of the skeleton.""" - return f"Skeleton(nodes={len(self.nodes)}, edges={len(self.edges)})" + nodes = ", ".join(self.node_names) + edges = ", ".join([f"{s}->{d}" for (s, d) in self.edge_names]) + symm = ", ".join([f"{s}<->{d}" for (s, d) in self.symmetry_names]) + return ( + "Skeleton(" + f"nodes=[{nodes}], " + f"edges=[{edges}], " + f"symmetries=[{symm}]" + ")" + ) def matches(self, other: "Skeleton") -> bool: - """ - Compare this `Skeleton` to another, ignoring skeleton name and - the identities of the `Node` objects in each graph. + """Compare this `Skeleton` to another, ignoring name and node identities. Args: other: The other skeleton. Returns: - True if match, False otherwise. + `True` if the skeleton graphs are isomorphic and node names. """ def dict_match(dict1, dict2): @@ -147,17 +161,16 @@ def dict_match(dict1, dict2): if not is_isomorphic: return False - # Now check that the nodes have the same labels and order. They can have - # different weights I guess?! + # Now check that the nodes have the same labels and order. for node1, node2 in zip(self._graph.nodes, other._graph.nodes): if node1.name != node2.name: return False - # Check if the two graphs are equal return True @property def is_arborescence(self) -> bool: + """Return whether this skeleton graph forms an arborescence.""" return nx.algorithms.tree.recognition.is_arborescence(self._graph) @property @@ -174,7 +187,7 @@ def cycles(self) -> List[List[Node]]: @property def graph(self): - """Returns subgraph of BODY edges for skeleton.""" + """Return subgraph of BODY edges for skeleton.""" edges = [ (src, dst, key) for src, dst, key, edge_type in self._graph.edges(keys=True, data="type") @@ -187,7 +200,7 @@ def graph(self): @property def graph_symmetry(self): - """Returns subgraph of symmetric edges for skeleton.""" + """Return subgraph of symmetric edges for skeleton.""" edges = [ (src, dst, key) for src, dst, key, edge_type in self._graph.edges(keys=True, data="type") @@ -197,8 +210,7 @@ def graph_symmetry(self): @staticmethod def find_unique_nodes(skeletons: List["Skeleton"]) -> List[Node]: - """ - Find all unique nodes from a list of skeletons. + """Find all unique nodes from a list of skeletons. Args: skeletons: The list of skeletons. @@ -210,8 +222,7 @@ def find_unique_nodes(skeletons: List["Skeleton"]) -> List[Node]: @staticmethod def make_cattr(idx_to_node: Dict[int, Node] = None) -> cattr.Converter: - """ - Make cattr.Convert() for `Skeleton`. + """Make cattr.Convert() for `Skeleton`. Make a cattr.Converter() that registers structure/unstructure hooks for Skeleton objects to handle serialization of skeletons. @@ -249,7 +260,8 @@ def name(self) -> str: @name.setter def name(self, name: str): - """ + """Set skeleton name (no-op). + A skeleton object cannot change its name. This property is immutable because it is used to hash skeletons. @@ -274,8 +286,7 @@ def name(self, name: str): @classmethod def rename_skeleton(cls, skeleton: "Skeleton", name: str) -> "Skeleton": - """ - Make copy of skeleton with new name. + """Make copy of skeleton with new name. This property is immutable because it is used to hash skeletons. If you want to rename a Skeleton you must use this class method. @@ -361,7 +372,6 @@ def edge_inds(self) -> List[Tuple[int, int]]: A list of (src_node_ind, dst_node_ind), where indices are subscripts into the Skeleton.nodes list. """ - return [ (self.nodes.index(src_node), self.nodes.index(dst_node)) for src_node, dst_node in self.edges @@ -399,6 +409,11 @@ def symmetries(self) -> List[Tuple[Node, Node]]: ) return symmetries + @property + def symmetry_names(self) -> List[Tuple[str, str]]: + """List of symmetry edges as tuples of node names.""" + return [(s.name, d.name) for (s, d) in self.symmetries] + @property def symmetries_full(self) -> List[Tuple[Node, Node, Any, Any]]: """Get a list of all symmetries with keys and attributes. @@ -427,8 +442,7 @@ def symmetric_inds(self) -> np.ndarray: ) def node_to_index(self, node: NodeRef) -> int: - """ - Return the index of the node, accepts either `Node` or name. + """Return the index of the node, accepts either `Node` or name. Args: node: The name of the node or the Node object. @@ -445,8 +459,8 @@ def node_to_index(self, node: NodeRef) -> int: except ValueError: return node_list.index(self.find_node(node)) - def edge_to_index(self, source: NodeRef, destination: NodeRef): - """Returns the index of edge from source to destination.""" + def edge_to_index(self, source: NodeRef, destination: NodeRef) -> int: + """Return the index of edge from source to destination.""" source = self.find_node(source) destination = self.find_node(destination) edge = (source, destination) @@ -464,9 +478,6 @@ def add_node(self, name: str): Raises: ValueError: If name is not unique. - - Returns: - None """ if not isinstance(name, str): raise TypeError("Cannot add nodes to the skeleton that are not str") @@ -477,14 +488,10 @@ def add_node(self, name: str): self._graph.add_node(Node(name)) def add_nodes(self, name_list: List[str]): - """ - Add a list of nodes representing animal parts to the skeleton. + """Add a list of nodes representing animal parts to the skeleton. Args: name_list: List of strings representing the nodes. - - Returns: - None """ for node in name_list: self.add_node(node) @@ -632,8 +639,7 @@ def delete_edge(self, source: str, destination: str): self._graph.remove_edge(source_node, destination_node) def clear_edges(self): - """Deletes all edges in skeleton.""" - + """Delete all edges in skeleton.""" for src, dst in self.edges: self.delete_edge(src, dst) @@ -657,8 +663,9 @@ def add_symmetry(self, node1: str, node2: str): """ node1_node, node2_node = self.find_node(node1), self.find_node(node2) - # We will represent symmetric pairs in the skeleton via additional edges in the _graph - # These edges will have a special attribute signifying they are not part of the skeleton itself + # We will represent symmetric pairs in the skeleton via additional edges in the + # _graph. These edges will have a special attribute signifying they are not part + # of the skeleton itself if node1 == node2: raise ValueError("Cannot add symmetry to the same node.") @@ -677,8 +684,7 @@ def add_symmetry(self, node1: str, node2: str): self._graph.add_edge(node2_node, node1_node, type=EdgeType.SYMMETRY) def delete_symmetry(self, node1: NodeRef, node2: NodeRef): - """ - Deletes a previously established symmetry between two nodes. + """Delete a previously established symmetry between two nodes. Args: node1: One node (by `Node` object or name) in symmetric pair. @@ -709,8 +715,7 @@ def delete_symmetry(self, node1: NodeRef, node2: NodeRef): self._graph.remove_edges_from(edges) def get_symmetry(self, node: NodeRef) -> Optional[Node]: - """ - Returns the node symmetric with the specified node. + """Return the node symmetric with the specified node. Args: node: Node (by `Node` object or name) to query. @@ -737,8 +742,7 @@ def get_symmetry(self, node: NodeRef) -> Optional[Node]: raise ValueError(f"{node} has more than one symmetry.") def get_symmetry_name(self, node: NodeRef) -> Optional[str]: - """ - Returns the name of the node symmetric with the specified node. + """Return the name of the node symmetric with the specified node. Args: node: Node (by `Node` object or name) to query. @@ -750,8 +754,7 @@ def get_symmetry_name(self, node: NodeRef) -> Optional[str]: return None if symmetric_node is None else symmetric_node.name def __getitem__(self, node_name: str) -> dict: - """ - Retrieves the node data associated with skeleton node. + """Retrieve the node data associated with skeleton node. Args: node_name: The name from which to retrieve data. @@ -770,8 +773,7 @@ def __getitem__(self, node_name: str) -> dict: return self._graph.nodes.data()[node] def __contains__(self, node_name: str) -> bool: - """ - Checks if specified node exists in skeleton. + """Check if specified node exists in skeleton. Args: node_name: the node name to query @@ -781,6 +783,10 @@ def __contains__(self, node_name: str) -> bool: """ return self.has_node(node_name) + def __len__(self) -> int: + """Return the number of nodes in the skeleton.""" + return len(self.nodes) + def relabel_node(self, old_name: str, new_name: str): """ Relabel a single node to a new name. diff --git a/sleap/util.py b/sleap/util.py index 3b70b6395..0cc47328f 100644 --- a/sleap/util.py +++ b/sleap/util.py @@ -310,12 +310,13 @@ def get_config_file( def get_config_yaml(shortname: str, get_defaults: bool = False) -> dict: config_path = get_config_file(shortname, get_defaults=get_defaults) with open(config_path, "r") as f: - return yaml.load(f, Loader=yaml.SafeLoader) + return yaml.load(f, Loader=yaml.Loader) def save_config_yaml(shortname: str, data: Any) -> dict: yaml_path = get_config_file(shortname, ignore_file_not_found=True) with open(yaml_path, "w") as f: + print(f"Saving config: {yaml_path}") yaml.dump(data, f) @@ -379,16 +380,3 @@ def find_files_by_suffix( ) return matching_files - - -def open_file(filename): - """ - Opens file (as if double-clicked by user). - - https://stackoverflow.com/questions/17317219/is-there-an-platform-independent-equivalent-of-os-startfile/17317468#17317468 - """ - if sys.platform == "win32": - os.startfile(filename) - else: - opener = "open" if sys.platform == "darwin" else "xdg-open" - subprocess.call([opener, filename]) diff --git a/sleap/version.py b/sleap/version.py index 108d51f53..50ce1fbc2 100644 --- a/sleap/version.py +++ b/sleap/version.py @@ -10,4 +10,23 @@ Must be a semver string, "aN" should be appended for alpha releases. """ + + __version__ = "1.1.0a9" + + +def versions(): + """Print versions of SLEAP and other libraries.""" + import tensorflow as tf + import numpy as np + import platform + + vers = {} + vers["SLEAP"] = __version__ + vers["TensorFlow"] = tf.__version__ + vers["Numpy"] = np.__version__ + vers["Python"] = platform.python_version() + vers["OS"] = platform.platform() + + msg = "\n".join([f"{k}: {v}" for k, v in vers.items()]) + print(msg) diff --git a/tests/gui/test_commands.py b/tests/gui/test_commands.py index acdcb1b8b..6528e6677 100644 --- a/tests/gui/test_commands.py +++ b/tests/gui/test_commands.py @@ -1,5 +1,10 @@ -from sleap.gui.commands import CommandContext, ImportDeepLabCutFolder +from sleap.gui.commands import ( + CommandContext, + ImportDeepLabCutFolder, + get_new_version_filename, +) from sleap.io.pathutils import fix_path_separator +from pathlib import PurePath def test_delete_user_dialog(centered_pair_predictions): @@ -44,3 +49,17 @@ def test_import_labels_from_dlc_folder(): } assert set([l.frame_idx for l in labels.labeled_frames]) == {0, 0, 1} + + +def test_get_new_version_filename(): + assert get_new_version_filename("labels.slp") == "labels copy.slp" + assert get_new_version_filename("labels.v0.slp") == "labels.v1.slp" + assert get_new_version_filename("/a/b/labels.slp") == str( + PurePath("/a/b/labels copy.slp") + ) + assert get_new_version_filename("/a/b/labels.v0.slp") == str( + PurePath("/a/b/labels.v1.slp") + ) + assert get_new_version_filename("/a/b/labels.v01.slp") == str( + PurePath("/a/b/labels.v02.slp") + ) diff --git a/tests/io/test_video.py b/tests/io/test_video.py index 05421bdf1..767ae60ce 100644 --- a/tests/io/test_video.py +++ b/tests/io/test_video.py @@ -4,12 +4,13 @@ import numpy as np -from sleap.io.video import Video, HDF5Video, MediaVideo, DummyVideo +from sleap.io.video import Video, HDF5Video, MediaVideo, DummyVideo, load_video from tests.fixtures.videos import ( TEST_H5_FILE, TEST_SMALL_ROBOT_MP4_FILE, TEST_H5_DSET, TEST_H5_INPUT_FORMAT, + TEST_SMALL_CENTERED_PAIR_VID, ) # FIXME: @@ -398,3 +399,9 @@ def test_safe_frame_loading_all_invalid(): assert idxs == [] assert frames is None + + +def test_load_video(): + video = load_video(TEST_SMALL_CENTERED_PAIR_VID) + assert video.shape == (1100, 384, 384, 1) + assert video[:3].shape == (3, 384, 384, 1) diff --git a/tests/nn/data/test_data_training.py b/tests/nn/data/test_data_training.py index 9aca2cb64..eb79464e0 100644 --- a/tests/nn/data/test_data_training.py +++ b/tests/nn/data/test_data_training.py @@ -1,64 +1,75 @@ import numpy as np -import tensorflow as tf -from sleap.nn.system import use_cpu_only +import sleap +from sleap.nn.data.training import split_labels_train_val -use_cpu_only() # hide GPUs for test -import sleap -from sleap.nn.data.providers import LabelsReader -from sleap.nn.data import training - - -def test_split_labels_reader(min_labels): - labels = sleap.Labels([min_labels[0], min_labels[0], min_labels[0], min_labels[0]]) - labels_reader = LabelsReader(labels) - reader1, reader2 = training.split_labels_reader(labels_reader, [0.5, 0.5]) - assert len(reader1) == 2 - assert len(reader2) == 2 - assert ( - len(set(reader1.example_indices).intersection(set(reader2.example_indices))) - == 0 - ) +sleap.use_cpu_only() # hide GPUs for test - reader1, reader2 = training.split_labels_reader(labels_reader, [0.1, 0.5]) - assert len(reader1) == 1 - assert len(reader2) == 2 - assert ( - len(set(reader1.example_indices).intersection(set(reader2.example_indices))) - == 0 - ) - reader1, reader2 = training.split_labels_reader(labels_reader, [0.1, -1]) - assert len(reader1) == 1 - assert len(reader2) == 3 - assert ( - len(set(reader1.example_indices).intersection(set(reader2.example_indices))) - == 0 +def test_split_labels_train_val(): + vid = sleap.Video(backend=sleap.io.video.MediaVideo) + labels = sleap.Labels([sleap.LabeledFrame(video=vid, frame_idx=0)]) + + train, train_inds, val, val_inds = split_labels_train_val(labels, 0) + assert len(train) == 1 + assert len(val) == 1 + + train, train_inds, val, val_inds = split_labels_train_val(labels, 0.1) + assert len(train) == 1 + assert len(val) == 1 + + train, train_inds, val, val_inds = split_labels_train_val(labels, 0.5) + assert len(train) == 1 + assert len(val) == 1 + + train, train_inds, val, val_inds = split_labels_train_val(labels, 1.0) + assert len(train) == 1 + assert len(val) == 1 + + labels = sleap.Labels( + [ + sleap.LabeledFrame(video=vid, frame_idx=0), + sleap.LabeledFrame(video=vid, frame_idx=1), + ] ) + train, train_inds, val, val_inds = split_labels_train_val(labels, 0) + assert len(train) == 1 + assert len(val) == 1 + assert train[0].frame_idx != val[0].frame_idx + + train, train_inds, val, val_inds = split_labels_train_val(labels, 0.1) + assert len(train) == 1 + assert len(val) == 1 + assert train[0].frame_idx != val[0].frame_idx + + train, train_inds, val, val_inds = split_labels_train_val(labels, 0.5) + assert len(train) == 1 + assert len(val) == 1 + assert train[0].frame_idx != val[0].frame_idx + + train, train_inds, val, val_inds = split_labels_train_val(labels, 1.0) + assert len(train) == 1 + assert len(val) == 1 + assert train[0].frame_idx != val[0].frame_idx - labels = sleap.Labels([min_labels[0], min_labels[0], min_labels[0], min_labels[0]]) - labels_reader = LabelsReader(labels, example_indices=[1, 2, 3]) - reader1, reader2 = training.split_labels_reader(labels_reader, [0.1, -1]) - assert len(reader1) == 1 - assert len(reader2) == 2 - assert ( - len(set(reader1.example_indices).intersection(set(reader2.example_indices))) - == 0 + labels = sleap.Labels( + [ + sleap.LabeledFrame(video=vid, frame_idx=0), + sleap.LabeledFrame(video=vid, frame_idx=1), + sleap.LabeledFrame(video=vid, frame_idx=2), + ] ) - assert 0 not in reader1.example_indices - assert 0 not in reader2.example_indices - - -def test_keymapper(): - ds = tf.data.Dataset.from_tensors({"a": 0, "b": 1}) - mapper = training.KeyMapper(key_maps={"a": "x", "b": "y"}) - ds = mapper.transform_dataset(ds) - np.testing.assert_array_equal(next(iter(ds)), {"x": 0, "y": 1}) - assert mapper.input_keys == ["a", "b"] - assert mapper.output_keys == ["x", "y"] - - ds = tf.data.Dataset.from_tensors({"a": 0, "b": 1}) - ds = training.KeyMapper(key_maps=[{"a": "x"}, {"b": "y"}]).transform_dataset(ds) - np.testing.assert_array_equal(next(iter(ds)), ({"x": 0}, {"y": 1})) - assert mapper.input_keys == ["a", "b"] - assert mapper.output_keys == ["x", "y"] + train, train_inds, val, val_inds = split_labels_train_val(labels, 0) + assert len(train) == 2 + assert len(val) == 1 + + train, train_inds, val, val_inds = split_labels_train_val(labels, 0.1) + assert len(train) == 2 + assert len(val) == 1 + + train, train_inds, val, val_inds = split_labels_train_val(labels, 0.5) + assert len(train) + len(val) == 3 + + train, train_inds, val, val_inds = split_labels_train_val(labels, 1.0) + assert len(train) == 1 + assert len(val) == 2 diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index 9782fd944..e5b4b1a54 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -23,8 +23,8 @@ FindInstancePeaksGroundTruth, FindInstancePeaks, TopDownInferenceModel, - TopdownPredictor, - BottomupPredictor, + TopDownPredictor, + BottomUpPredictor, load_model, ) @@ -492,7 +492,7 @@ def test_single_instance_predictor( def test_topdown_predictor_centroid(min_labels, min_centroid_model_path): - predictor = TopdownPredictor.from_trained_models( + predictor = TopDownPredictor.from_trained_models( centroid_model_path=min_centroid_model_path ) labels_pr = predictor.predict(min_labels) @@ -512,7 +512,7 @@ def test_topdown_predictor_centroid(min_labels, min_centroid_model_path): def test_topdown_predictor_centered_instance( min_labels, min_centered_instance_model_path ): - predictor = TopdownPredictor.from_trained_models( + predictor = TopDownPredictor.from_trained_models( confmap_model_path=min_centered_instance_model_path ) labels_pr = predictor.predict(min_labels) @@ -530,7 +530,7 @@ def test_topdown_predictor_centered_instance( def test_topdown_predictor_bottomup(min_labels, min_bottomup_model_path): - predictor = BottomupPredictor.from_trained_models( + predictor = BottomUpPredictor.from_trained_models( model_path=min_bottomup_model_path ) labels_pr = predictor.predict(min_labels) @@ -557,7 +557,7 @@ def test_load_model( assert isinstance(predictor, SingleInstancePredictor) predictor = load_model([min_centroid_model_path, min_centered_instance_model_path]) - assert isinstance(predictor, TopdownPredictor) + assert isinstance(predictor, TopDownPredictor) predictor = load_model(min_bottomup_model_path) - assert isinstance(predictor, BottomupPredictor) + assert isinstance(predictor, BottomUpPredictor) diff --git a/tests/test_skeleton.py b/tests/test_skeleton.py index ad3667a94..e35aa5bec 100644 --- a/tests/test_skeleton.py +++ b/tests/test_skeleton.py @@ -396,13 +396,3 @@ def test_arborescence(): assert len(skeleton.cycles) == 0 assert len(skeleton.root_nodes) == 1 assert len(skeleton.in_degree_over_one) == 1 - - -def test_repr_str(): - skel = Skeleton(name="skel") - skel.add_node("A") - skel.add_node("B") - skel.add_edge("A", "B") - - assert repr(skel) == "Skeleton(name='skel', nodes=['A', 'B'], edges=[('A', 'B')])" - assert str(skel) == "Skeleton(nodes=2, edges=1)"