diff --git a/.conda/bld.bat b/.conda/bld.bat index 5aea27a19..040a42587 100644 --- a/.conda/bld.bat +++ b/.conda/bld.bat @@ -38,6 +38,7 @@ pip install jsmin pip install seaborn pip install pykalman==0.9.5 pip install segmentation-models==1.0.1 +pip install rich==9.10.0 rem # Use and update environment.yml call to install pip dependencies. This is slick. rem # While environment.yml contains the non pip dependencies, the only thing left diff --git a/.conda/build.sh b/.conda/build.sh index f43797a54..984701e19 100644 --- a/.conda/build.sh +++ b/.conda/build.sh @@ -36,6 +36,7 @@ pip install jsmin pip install seaborn pip install pykalman==0.9.5 pip install segmentation-models==1.0.1 +pip install rich==9.10.0 pip install setuptools-scm diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 000000000..bc6ed0ccb --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,26 @@ + +### Description +[describe your changes here] + +### Types of changes + +- [ ] Bugfix +- [ ] New feature +- [ ] Refactor / Code style update (no logical changes) +- [ ] Build / CI changes +- [ ] Documentation Update +- [ ] Other (explain) + +### Does this address any currently open issues? +[list open issues here] + +### Outside contributors checklist + +- [ ] Review the [guidelines for contributing](https://github.com/murthylab/sleap/wiki/Developer-Guide) to this repository +- [ ] Read and sign the [CLA](https://github.com/murthylab/sleap/blob/develop/sleap-cla.pdf) and add yourself to the [authors list](https://github.com/murthylab/sleap/blob/develop/AUTHORS) +- [ ] Make sure you are making a pull request against the **develop** branch (not *master*). Also you should start *your branch* off *develop* +- [ ] Add tests that prove your fix is effective or that your feature works +- [ ] Add necessary documentation (if appropriate) + +#### Thank you for contributing to SLEAP! +:heart: diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d76863ec4..d270dbc66 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,7 +8,39 @@ on: - 'tests/**' jobs: - build: + type_check: + name: Type Check + runs-on: "ubuntu-18.04" + steps: + - uses: actions/checkout@v1 + - name: Set up Python 3.6 + uses: actions/setup-python@v1 + with: + python-version: 3.6 + - name: Install Dependencies + run: | + pip install mypy + - name: Run MyPy + # TODO: remove this once all MyPy errors get fixed + continue-on-error: true + run: | + mypy --follow-imports=skip --ignore-missing-imports sleap tests + lint: + name: Lint + runs-on: "ubuntu-18.04" + steps: + - uses: actions/checkout@v1 + - name: Set up Python 3.6 + uses: actions/setup-python@v1 + with: + python-version: 3.6 + - name: Install Dependencies + run: | + pip install black + - name: Run Black + run: | + black --check sleap tests + tests: name: Tests (${{ matrix.os }}) runs-on: ${{ matrix.os }} strategy: diff --git a/AUTHORS b/AUTHORS new file mode 100644 index 000000000..f2e84e949 --- /dev/null +++ b/AUTHORS @@ -0,0 +1,11 @@ +# This is the official list of SLEAP authors not affiliated with Princeton University (for copyright purposes). + +# If you are contributing to SLEAP, please add your name and the name of your +# organization (which holds the copyright) to this list in alphabetical order. + +# Names should be added to this file as +# Name Organization name (or 'Individual Person' if not applicable) +# Please keep the list sorted. + + +John Smith Example Inc. diff --git a/docs/conf.py b/docs/conf.py index 6eb8197cf..5f2580a18 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -89,15 +89,13 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'alabaster' +html_theme = "alabaster" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. # -html_theme_options = { - 'font_size': '15px' -} +html_theme_options = {"font_size": "15px"} # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, diff --git a/requirements.txt b/requirements.txt index be7a19828..67364300f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,4 +22,5 @@ qimage2ndarray==1.8 jsmin seaborn pykalman==0.9.5 -segmentation-models==1.0.1 \ No newline at end of file +segmentation-models==1.0.1 +rich==9.10.0 \ No newline at end of file diff --git a/setup.py b/setup.py index 67893fb4f..2dbaadf51 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ # Get the sleap version with open(path.join(here, "sleap/version.py")) as f: version_file = f.read() - sleap_version = re.search("\d.+(?=['\"])", version_file).group(0) + sleap_version = re.search('__version__ = "([0-9\\.a]+)"', version_file).group(1) def get_requirements(require_name=None): @@ -31,11 +31,13 @@ def get_requirements(require_name=None): version=sleap_version, setup_requires=["setuptools_scm"], install_requires=get_requirements(), - extras_require={"dev": get_requirements("dev"),}, - description="SLEAP (Social LEAP Estimates Animal Pose) is a deep learning framework for estimating animal pose.", + extras_require={ + "dev": get_requirements("dev"), + }, + description="SLEAP (Social LEAP Estimates Animal Poses) is a deep learning framework for animal pose tracking.", long_description=long_description, long_description_content_type="text/x-rst", - author="Talmo Pereira, David Turner, Nat Tabris", + author="Talmo Pereira, Arie Matsliah, David Turner, Nat Tabris", author_email="talmo@princeton.edu", project_urls={ "Documentation": "https://sleap.ai/#sleap", diff --git a/sleap-cla.pdf b/sleap-cla.pdf new file mode 100644 index 000000000..1a5332ae1 Binary files /dev/null and b/sleap-cla.pdf differ diff --git a/sleap/__init__.py b/sleap/__init__.py index 60de74ae5..3b6aa14a5 100644 --- a/sleap/__init__.py +++ b/sleap/__init__.py @@ -5,8 +5,9 @@ logging.basicConfig(stream=sys.stdout, level=logging.INFO) # Import submodules we want available at top-level +from sleap.version import __version__, versions from sleap.io.dataset import Labels, load_file -from sleap.io.video import Video +from sleap.io.video import Video, load_video from sleap.instance import LabeledFrame, Instance, PredictedInstance, Track from sleap.skeleton import Skeleton import sleap.nn @@ -15,4 +16,4 @@ from sleap.nn.inference import load_model from sleap.nn.system import use_cpu_only, disable_preallocation from sleap.nn.system import summary as system_summary -from sleap.version import __version__ +from sleap.nn.config import TrainingJobConfig, load_config diff --git a/sleap/config/shortcuts.yaml b/sleap/config/shortcuts.yaml index 8123231ce..e4a86436a 100644 --- a/sleap/config/shortcuts.yaml +++ b/sleap/config/shortcuts.yaml @@ -1,7 +1,7 @@ add instance: Ctrl+I add videos: Ctrl+A clear selection: Esc -close: QKeySequence.Close +close: Ctrl+Q color predicted: delete area: Ctrl+K delete clip: @@ -11,19 +11,19 @@ export clip: fit: Ctrl+= goto frame: Ctrl+J goto marked: Ctrl+Shift+M -goto next suggestion: QKeySequence.FindNext +goto next suggestion: Space goto next track spawn: Ctrl+E -goto next user: Ctrl+> -goto next labeled: Ctrl+. -goto prev suggestion: QKeySequence.FindPrevious -goto prev labeled: +goto next user: Ctrl+U +goto next labeled: Alt+Right +goto prev suggestion: Shift+Space +goto prev labeled: Alt+Left learning: Ctrl+L mark frame: Ctrl+M new: Ctrl+N next video: QKeySequence.Forward open: Ctrl+O prev video: QKeySequence.Back -save as: QKeySequence.SaveAs +save as: Ctrl+Shift+S save: Ctrl+S select next: '`' select to frame: Ctrl+Shift+J @@ -33,7 +33,7 @@ show trails: transpose: Ctrl+T frame next: Right frame prev: Left -frame next medium step: Down -frame prev medium step: Up -frame next large step: Space -frame prev large step: / +frame next medium step: Ctrl+Right +frame prev medium step: Ctrl+Left +frame next large step: Ctrl+Alt+Right +frame prev large step: Ctrl+Alt+Left diff --git a/sleap/config/training_editor_form.yaml b/sleap/config/training_editor_form.yaml index 1a82280ef..d362e8cee 100644 --- a/sleap/config/training_editor_form.yaml +++ b/sleap/config/training_editor_form.yaml @@ -466,6 +466,13 @@ augmentation: label: Scale Max name: optimization.augmentation_config.scale_max type: double +- label: Random flip + default: none + type: list + name: optimization.augmentation_config.random_flip + options: none,horizontal,vertical + default: none + help: 'Randomly reflect images and instances. IMPORTANT: Left/right symmetric nodes must be indicated in the skeleton or this will lead to incorrect results!' - default: false help: If True, uniformly distributed noise will be added to the image. This is effectively adding a different random value to each pixel to simulate shot noise. See `imgaug.augmenters.arithmetic.AddElementwise`. @@ -533,6 +540,7 @@ augmentation: label: Brightness Max Val name: optimization.augmentation_config.brightness_max_val type: double + optimization: - default: 8 help: Number of examples per minibatch, i.e., a single step of training. Higher diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 3902f5861..859543e3a 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -106,14 +106,14 @@ class MainWindow(QMainWindow): whether to show node labels, etc. """ - def __init__(self, labels_path: Optional[str] = None, *args, **kwargs): + def __init__( + self, labels_path: Optional[str] = None, reset: bool = False, *args, **kwargs + ): """Initialize the app. Args: labels_path: Path to saved :class:`Labels` dataset. - - Returns: - None. + reset: If `True`, reset preferences to default (including window state). """ super(MainWindow, self).__init__(*args, **kwargs) @@ -137,14 +137,24 @@ def __init__(self, labels_path: Optional[str] = None, *args, **kwargs): self.state["filename"] = None self.state["show labels"] = True self.state["show edges"] = True - self.state["edge style"] = "Line" + self.state["edge style"] = prefs["edge style"] self.state["fit"] = False self.state["color predicted"] = prefs["color predicted"] + self.state["marker size"] = prefs["marker size"] + self.state["propagate track labels"] = prefs["propagate track labels"] + self.state.connect("marker size", self.plotFrame) self.release_checker = ReleaseChecker() self._initialize_gui() + if reset: + print("Reseting GUI state and preferences...") + prefs.reset_to_default() + elif len(prefs["window state"]) > 0: + print("Restoring GUI state...") + self.restoreState(prefs["window state"]) + if labels_path: self.loadProjectFile(labels_path) else: @@ -174,7 +184,17 @@ def event(self, e: QEvent) -> bool: return super().event(e) def closeEvent(self, event): - """Closes application window, prompting for saving as needed.""" + """Close application window, prompting for saving as needed.""" + # Save window state. + prefs["window state"] = self.saveState() + prefs["marker size"] = self.state["marker size"] + prefs["edge style"] = self.state["edge style"] + prefs["propagate track labels"] = self.state["propagate track labels"] + prefs["color predicted"] = self.state["color predicted"] + + # Save preferences. + prefs.save() + if not self.state["has_changes"]: # No unsaved changes, so accept event (close) event.accept() @@ -273,8 +293,9 @@ def connect_check(key): # add checkable menu item connected to state variable def add_menu_check_item(menu, key: str, name: str): - add_menu_item(menu, key, name, lambda: self.state.toggle(key)) + menu_item = add_menu_item(menu, key, name, lambda: self.state.toggle(key)) connect_check(key) + return menu_item # check and uncheck submenu items def _menu_check_single(menu, item_text): @@ -359,8 +380,6 @@ def add_submenu_choices(menu, title, options, key): fileMenu.addSeparator() add_menu_item(fileMenu, "save", "Save", self.commands.saveProject) add_menu_item(fileMenu, "save as", "Save As...", self.commands.saveProjectAs) - - fileMenu.addSeparator() add_menu_item( fileMenu, "export analysis", @@ -368,6 +387,11 @@ def add_submenu_choices(menu, title, options, key): self.commands.exportAnalysisFile, ) + fileMenu.addSeparator() + add_menu_item( + fileMenu, "reset prefs", "Reset preferences to defaults...", self.resetPrefs + ) + fileMenu.addSeparator() add_menu_item(fileMenu, "close", "Quit", self.close) @@ -489,10 +513,17 @@ def prev_vid(): key="edge style", ) + add_submenu_choices( + menu=viewMenu, + title="Node Marker Size", + options=(1, 4, 6, 8, 12), + key="marker size", + ) + add_submenu_choices( menu=viewMenu, title="Trail Length", - options=(0, 10, 20, 50, 100, 200, 500), + options=(0, 10, 50, 100, 250), key="trail_length", ) @@ -627,17 +658,28 @@ def new_instance_menu_action(): self.commands.deleteFrameLimitPredictions, ) - labelMenu.addSeparator() + # labelMenu.addSeparator() - self.track_menu = labelMenu.addMenu("Set Instance Track") + ### Tracks Menu ### + + tracksMenu = self.menuBar().addMenu("Tracks") + self.track_menu = tracksMenu.addMenu("Set Instance Track") + self.delete_tracks_menu = tracksMenu.addMenu("Delete Track") + self.delete_tracks_menu.setEnabled(False) + add_menu_check_item( + tracksMenu, "propagate track labels", "Propagate Track Labels" + ).setToolTip( + "If enabled, setting a track will also apply to subsequent instances of " + "the same track." + ) add_menu_item( - labelMenu, + tracksMenu, "transpose", "Transpose Instance Tracks", self.commands.transposeInstance, ) add_menu_item( - labelMenu, + tracksMenu, "delete track", "Delete Instance and Track", self.commands.deleteSelectedInstanceTrack, @@ -678,10 +720,21 @@ def new_instance_menu_action(): ) predictionMenu.addSeparator() + + training_package_menu = predictionMenu.addMenu("Export Training Package...") add_menu_item( - predictionMenu, + training_package_menu, + "export user labels package", + "Labeled frames", + self.commands.exportUserLabelsPackage, + ).setToolTip( + "Export user-labeled frames with image data into a single SLP file.\n\n" + "Use this for archiving a dataset with labeled frames only." + ) + add_menu_item( + training_package_menu, "export training package", - "Export Training Package...", + "Labeled + suggested frames (recommended)", self.commands.exportTrainingPackage, ).setToolTip( "Export user-labeled frames and suggested frames with image data into a " @@ -690,18 +743,9 @@ def new_instance_menu_action(): "unlabeled frames." ) add_menu_item( - predictionMenu, - "export user labels package", - "Export Package with User Labels Only...", - self.commands.exportUserLabelsPackage, - ).setToolTip( - "Export user-labeled frames with image data into a single SLP file.\n\n" - "Use this for archiving a dataset with labeled frames only." - ) - add_menu_item( - predictionMenu, + training_package_menu, "export full package", - "Export Package with All Labels...", + "Labeled + predicted + suggested frames", self.commands.exportFullPackage, ).setToolTip( "Export all frames (including predictions) and suggested frames with image " @@ -730,7 +774,7 @@ def new_instance_menu_action(): helpMenu.addSeparator() - helpMenu.addAction("Check for updates...", self.commands.checkForUpdates) + helpMenu.addAction("Latest versions:", self.commands.checkForUpdates) self.state["stable_version_menu"] = helpMenu.addAction( " Stable: N/A", self.commands.openStableVersion ) @@ -754,14 +798,16 @@ def wrapped_function(*args): return wrapped_function def _create_dock_windows(self): - """Create dock windows and connects them to gui.""" + """Create dock windows and connect them to GUI.""" def _make_dock(name, widgets=[], tab_with=None): dock = QDockWidget(name) + dock.setObjectName(name + "Dock") dock.setAllowedAreas(Qt.LeftDockWidgetArea | Qt.RightDockWidgetArea) dock_widget = QWidget() + dock_widget.setObjectName(name + "Widget") layout = QVBoxLayout() for widget in widgets: @@ -770,10 +816,6 @@ def _make_dock(name, widgets=[], tab_with=None): dock_widget.setLayout(layout) dock.setWidget(dock_widget) - key = f"hide {name.lower()} dock" - if key in prefs and prefs[key]: - dock.hide() - self.addDockWidget(Qt.RightDockWidgetArea, dock) self.viewMenu.addAction(dock.toggleViewAction()) @@ -921,6 +963,33 @@ def new_edge(): hb = QHBoxLayout() + _add_button( + hb, + "Add current frame", + self.process_events_then(self.commands.addCurrentFrameAsSuggestion), + "add current frame as suggestion", + ) + + _add_button( + hb, + "Remove", + self.process_events_then(self.commands.removeSuggestion), + "remove suggestion", + ) + + _add_button( + hb, + "Clear all", + self.process_events_then(self.commands.clearSuggestions), + "clear suggestions", + ) + + hbw = QWidget() + hbw.setLayout(hb) + suggestions_layout.addWidget(hbw) + + hb = QHBoxLayout() + _add_button( hb, "Prev", @@ -943,7 +1012,8 @@ def new_edge(): suggestions_layout.addWidget(hbw) self.suggestions_form_widget = YamlFormWidget.from_name( - "suggestions", title="Generate Suggestions", + "suggestions", + title="Generate Suggestions", ) self.suggestions_form_widget.mainAction.connect( self.process_events_then(self.commands.generateSuggestions) @@ -1031,6 +1101,7 @@ def _update_gui_state(self): # Update menus self.track_menu.setEnabled(has_selected_instance) + self.delete_tracks_menu.setEnabled(has_tracks) self._menu_actions["clear selection"].setEnabled(has_selected_instance) self._menu_actions["delete instance"].setEnabled(has_selected_instance) @@ -1124,13 +1195,14 @@ def _has_topic(topic_list): suggestion_status_text = "" suggestion_list = self.labels.get_suggestions() if suggestion_list: - suggestion_label_counts = [ - self.labels.instance_count(item.video, item.frame_idx) - for item in suggestion_list - ] - labeled_count = len(suggestion_list) - suggestion_label_counts.count(0) + labeled_count = 0 + for suggestion in suggestion_list: + lf = self.labels.get((suggestion.video, suggestion.frame_idx)) + if lf is not None and lf.has_user_instances: + labeled_count += 1 + prc = (labeled_count / len(suggestion_list)) * 100 suggestion_status_text = ( - f"{labeled_count}/{len(suggestion_list)} labeled" + f"{labeled_count}/{len(suggestion_list)} labeled ({prc:.1f}%)" ) self.suggested_count_label.setText(suggestion_status_text) @@ -1176,13 +1248,17 @@ def updateStatusMessage(self, message: Optional[str] = None): spacer = " " if message is None: - message = f"Frame: {frame_idx+1:,}/{len(current_video):,}" + message = "" + if len(self.labels.videos) > 1: + message += f"Video {self.labels.videos.index(current_video)+1}/" + message += f"{len(self.labels.videos)}" + message += spacer + + message += f"Frame: {frame_idx+1:,}/{len(current_video):,}" if self.player.seekbar.hasSelection(): start, end = self.state["frame_range"] - message += f" (selection: {start+1:,}-{end:,})" - - if len(self.labels.videos) > 1: - message += f" of video {self.labels.videos.index(current_video)+1}" + message += spacer + message += f"Selection: {start+1:,}-{end:,} ({end-start+1:,} frames)" message += f"{spacer}Labeled Frames: " if current_video is not None: @@ -1211,6 +1287,15 @@ def updateStatusMessage(self, message: Optional[str] = None): self.statusBar().showMessage(message) + def resetPrefs(self): + """Reset preferences to defaults.""" + prefs.reset_to_default() + msg = QMessageBox() + msg.setText( + "Note: Some preferences may not take effect until application is restarted." + ) + msg.exec_() + def loadProjectFile(self, filename: Optional[str] = None): """ Loads given labels file into GUI. @@ -1280,15 +1365,19 @@ def loadLabelsObject(self, labels: Labels, filename: Optional[str] = None): def _update_track_menu(self): """Updates track menu options.""" self.track_menu.clear() - for track in self.labels.tracks: + self.delete_tracks_menu.clear() + for track_ind, track in enumerate(self.labels.tracks): key_command = "" - if self.labels.tracks.index(track) < 9: + if track_ind < 9: key_command = Qt.CTRL + Qt.Key_0 + self.labels.tracks.index(track) + 1 self.track_menu.addAction( f"{track.name}", lambda x=track: self.commands.setInstanceTrack(x), key_command, ) + self.delete_tracks_menu.addAction( + f"{track.name}", lambda x=track: self.commands.deleteTrack(x) + ) self.track_menu.addAction( "New Track", self.commands.addTrack, Qt.CTRL + Qt.Key_0 ) @@ -1425,7 +1514,9 @@ def _show_learning_dialog(self, mode: str): if self._child_windows.get(mode, None) is None: self._child_windows[mode] = LearningDialog( - mode, self.state["filename"], self.labels, + mode, + self.state["filename"], + self.labels, ) self._child_windows[mode]._handle_learning_finished.connect( self._handle_learning_finished @@ -1559,6 +1650,16 @@ def main(): const=True, default=False, ) + parser.add_argument( + "--reset", + help=( + "Reset GUI state and preferences. Use this flag if the GUI appears " + "incorrectly or fails to open." + ), + action="store_const", + const=True, + default=False, + ) args = parser.parse_args() @@ -1574,15 +1675,23 @@ def main(): app = QApplication([]) app.setApplicationName(f"SLEAP Label v{sleap.version.__version__}") - window = MainWindow(labels_path=args.labels_path) + window = MainWindow(labels_path=args.labels_path, reset=args.reset) window.showMaximized() - # if not args.labels_path: - # window.commands.openProject(first_open=True) + # Disable GPU in GUI process. This does not affect subprocesses. + sleap.use_cpu_only() + + # Print versions. + print() + print("Software versions:") + sleap.versions() + print() + print("Happy SLEAPing! :)") if args.profiling: import cProfile - cProfile.runctx('app.exec_()', globals=globals(), locals=locals()) + + cProfile.runctx("app.exec_()", globals=globals(), locals=locals()) else: app.exec_() diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index e1a278de1..503b9786a 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -29,6 +29,9 @@ class which inherits from `AppCommand` (or a more specialized class such as import attr import operator import os +import re +import sys +import subprocess from abc import ABC from enum import Enum @@ -40,7 +43,7 @@ class which inherits from `AppCommand` (or a more specialized class such as from PySide2 import QtCore, QtWidgets, QtGui -from PySide2.QtWidgets import QMessageBox +from PySide2.QtWidgets import QMessageBox, QProgressDialog from sleap.gui.dialogs.delete import DeleteDialog from sleap.skeleton import Skeleton @@ -326,6 +329,18 @@ def prevSuggestedFrame(self): """Goes to previous suggested frame.""" self.execute(GoPrevSuggestedFrame) + def addCurrentFrameAsSuggestion(self): + """Add current frame as a suggestion.""" + self.execute(AddSuggestion) + + def removeSuggestion(self): + """Remove the selected frame from suggestions.""" + self.execute(RemoveSuggestion) + + def clearSuggestions(self): + """Clear all suggestions.""" + self.execute(ClearSuggestions) + def nextTrackFrame(self): """Goes to next frame on which a track starts.""" self.execute(GoNextTrackFrame) @@ -486,6 +501,10 @@ def setInstanceTrack(self, new_track: "Track"): """Sets track for selected instance.""" self.execute(SetSelectedInstanceTrack, new_track=new_track) + def deleteTrack(self, track: "Track"): + """Delete a track and remove from all instances.""" + self.execute(DeleteTrack, track=track) + def setTrackName(self, track: "Track", name: str): """Sets name for track.""" self.execute(SetTrackName, track=track, name=name) @@ -628,7 +647,9 @@ class ImportLEAP(AppCommand): @staticmethod def do_action(context: "CommandContext", params: dict): - labels = Labels.load_leap_matlab(filename=params["filename"],) + labels = Labels.load_leap_matlab( + filename=params["filename"], + ) new_window = context.app.__class__() new_window.showMaximized() @@ -679,18 +700,6 @@ def ask(context: "CommandContext", params: dict) -> bool: if len(filename) == 0: return False - # QtWidgets.QMessageBox( - # text="Please locate the directory with image files for this dataset." - # ).exec_() - # - # img_dir = FileDialog.openDir( - # None, - # directory=os.path.dirname(filename), - # caption="Open Image Directory" - # ) - # if len(img_dir) == 0: - # return False - params["filename"] = filename params["img_dir"] = os.path.dirname(filename) @@ -729,10 +738,16 @@ def ask(context: "CommandContext", params: dict) -> bool: class ImportDeepLabCutFolder(AppCommand): @staticmethod def do_action(context: "CommandContext", params: dict): - csv_files = ImportDeepLabCutFolder.find_dlc_files_in_folder(params['folder_name']) + csv_files = ImportDeepLabCutFolder.find_dlc_files_in_folder( + params["folder_name"] + ) if csv_files: - win = MessageDialog(f"Importing {len(csv_files)} DeepLabCut datasets...", context.app) - merged_labels = ImportDeepLabCutFolder.import_labels_from_dlc_files(csv_files) + win = MessageDialog( + f"Importing {len(csv_files)} DeepLabCut datasets...", context.app + ) + merged_labels = ImportDeepLabCutFolder.import_labels_from_dlc_files( + csv_files + ) win.hide() new_window = context.app.__class__() @@ -809,6 +824,22 @@ def ask(context: "CommandContext", params: dict) -> bool: return True +def get_new_version_filename(filename: str) -> str: + """Increment version number in filenames that end in `.v###.slp`.""" + p = PurePath(filename) + + match = re.match(".*\\.v(\\d+)\\.slp", filename) + if match is not None: + old_ver = match.group(1) + new_ver = str(int(old_ver) + 1).zfill(len(old_ver)) + filename = filename.replace(f".v{old_ver}.slp", f".v{new_ver}.slp") + filename = str(PurePath(filename)) + else: + filename = str(p.with_name(f"{p.stem} copy{p.suffix}")) + + return filename + + class SaveProjectAs(AppCommand): @staticmethod def _try_save(context, labels: Labels, filename: str): @@ -841,15 +872,12 @@ def do_action(cls, context: CommandContext, params: dict): @staticmethod def ask(context: CommandContext, params: dict) -> bool: - default_name = context.state["filename"] or "untitled" - p = PurePath(default_name) - default_name = str(p.with_name(f"{p.stem} copy{p.suffix}")) - - filters = [ - "SLEAP HDF5 dataset (*.slp)", - "SLEAP JSON dataset (*.json)", - "Compressed JSON (*.zip)", - ] + default_name = context.state["filename"] + if default_name: + default_name = get_new_version_filename(default_name) + else: + default_name = "labels.v000.slp" + filters = ["SLEAP labels dataset (*.slp)"] filename, selected_filter = FileDialog.save( context.app, caption="Save As...", @@ -873,7 +901,7 @@ def do_action(cls, context: CommandContext, params: dict): @staticmethod def ask(context: CommandContext, params: dict) -> bool: - default_name = context.state["filename"] or "untitled" + default_name = context.state["filename"] or "labels" p = PurePath(default_name) default_name = str(p.with_name(f"{p.stem}.analysis.h5")) @@ -1001,25 +1029,65 @@ def ask(context: CommandContext, params: dict) -> bool: return True +def export_dataset_gui( + labels: Labels, filename: str, all_labeled: bool = False, suggested: bool = False +) -> str: + """Export dataset with image data and display progress GUI dialog. + + Args: + labels: `sleap.Labels` dataset to export. + filename: Output filename. Should end in `.pkg.slp`. + all_labeled: If `True`, export all labeled frames, including frames with no user + instances. + suggested: If `True`, include image data for suggested frames. + """ + win = QProgressDialog("Exporting dataset with frame images...", "Cancel", 0, 1) + + def update_progress(n, n_total): + if win.wasCanceled(): + return False + win.setMaximum(n_total) + win.setValue(n) + win.setLabelText( + "Exporting dataset with frame images...
" + f"{n}/{n_total} ({(n/n_total)*100:.1f}%)" + ) + QtWidgets.QApplication.instance().processEvents() + return True + + Labels.save_file( + labels, + filename, + default_suffix="slp", + save_frame_data=True, + all_labeled=all_labeled, + suggested=suggested, + progress_callback=update_progress, + ) + + if win.wasCanceled(): + # Delete output if saving was canceled. + os.remove(filename) + return "canceled" + + win.hide() + + return filename + + class ExportDatasetWithImages(AppCommand): all_labeled = False suggested = False @classmethod def do_action(cls, context: CommandContext, params: dict): - win = MessageDialog("Exporting dataset with frame images...", context.app) - - Labels.save_file( - context.state["labels"], - params["filename"], - default_suffix="slp", - save_frame_data=True, + export_dataset_gui( + labels=context.state["labels"], + filename=params["filename"], all_labeled=cls.all_labeled, suggested=cls.suggested, ) - win.hide() - @staticmethod def ask(context: CommandContext, params: dict) -> bool: filters = [ @@ -1801,11 +1869,12 @@ def do_action(context: CommandContext, params: dict): if selected_instance is None: return - old_track = selected_instance.track - # When setting track for an instance that doesn't already have a track set, # just set for selected instance. - if old_track is None: + if ( + selected_instance.track is None + or not context.state["propagate track labels"] + ): # Move anything already in the new track out of it new_track_instances = context.labels.find_track_occupancy( video=context.state["video"], @@ -1825,6 +1894,7 @@ def do_action(context: CommandContext, params: dict): # When the instance does already have a track, then we want to update # the track for a range of frames. else: + old_track = selected_instance.track # Determine range that should be affected if context.state["has_frame_range"]: @@ -1846,6 +1916,15 @@ def do_action(context: CommandContext, params: dict): context.state["instance"] = selected_instance +class DeleteTrack(EditCommand): + topics = [UpdateTopic.tracks] + + @staticmethod + def do_action(context: CommandContext, params: dict): + track = params["track"] + context.labels.remove_track(track) + + class SetTrackName(EditCommand): topics = [UpdateTopic.tracks, UpdateTopic.frame] @@ -1862,7 +1941,11 @@ class GenerateSuggestions(EditCommand): @classmethod def do_action(cls, context: CommandContext, params: dict): - win = MessageDialog("Generating list of suggested frames...", context.app) + # TODO: Progress bar + win = MessageDialog( + "Generating list of suggested frames... " "This may take a few minutes.", + context.app, + ) new_suggestions = VideoFrameSuggestions.suggest( labels=context.labels, params=params @@ -1873,6 +1956,56 @@ def do_action(cls, context: CommandContext, params: dict): win.hide() +class AddSuggestion(EditCommand): + topics = [UpdateTopic.suggestions] + + @classmethod + def do_action(cls, context: CommandContext, params: dict): + context.labels.add_suggestion( + context.state["video"], context.state["frame_idx"] + ) + context.app.suggestionsTable.selectRow(len(context.labels) - 1) + + +class RemoveSuggestion(EditCommand): + topics = [UpdateTopic.suggestions] + + @classmethod + def do_action(cls, context: CommandContext, params: dict): + selected_frame = context.app.suggestionsTable.getSelectedRowItem() + if selected_frame is not None: + context.labels.remove_suggestion( + selected_frame.video, selected_frame.frame_idx + ) + + +class ClearSuggestions(EditCommand): + topics = [UpdateTopic.suggestions] + + @staticmethod + def ask(context: CommandContext, params: dict) -> bool: + if len(context.labels.suggestions) == 0: + return False + + # Warn that suggestions will be cleared + + response = QMessageBox.warning( + context.app, + "Clearing all suggestions", + "Are you sure you want to remove all suggestions from the project?", + QMessageBox.Yes, + QMessageBox.No, + ) + if response == QMessageBox.No: + return False + + return True + + @classmethod + def do_action(cls, context: CommandContext, params: dict): + context.labels.clear_suggestions() + + class MergeProject(EditCommand): topics = [UpdateTopic.all] @@ -1934,7 +2067,8 @@ def do_action(cls, context: CommandContext, params: dict): if context.state["labeled_frame"] is None: return - # FIXME: filter by skeleton type + if len(context.state["skeleton"]) == 0: + return from_predicted = copy_instance from_prev_frame = False @@ -2127,6 +2261,7 @@ def add_best_nodes(cls, context, instance, visible): @classmethod def add_random_nodes(cls, context, instance, visible): + # TODO: Move this to Instance so we can do this on-demand # the rect that's currently visible in the window view in_view_rect = context.app.player.getVisibleRect() @@ -2249,10 +2384,19 @@ def do_action(cls, context: CommandContext, params: dict): context.labels.add_instance(context.state["labeled_frame"], new_instance) +def open_website(url: str): + """Open website in default browser. + + Args: + url: URL to open. + """ + QtGui.QDesktopServices.openUrl(QtCore.QUrl(url)) + + class OpenWebsite(AppCommand): @staticmethod def do_action(context: CommandContext, params: dict): - QtGui.QDesktopServices.openUrl(QtCore.QUrl(params["url"])) + open_website(params["url"]) class CheckForUpdates(AppCommand): @@ -2268,6 +2412,7 @@ def do_action(context: CommandContext, params: dict): f" Prerelease: {prerelease.version}" ) context.state["prerelease_version_menu"].setEnabled(True) + # TODO: Provide GUI feedback about result. @@ -2285,3 +2430,30 @@ def do_action(context: CommandContext, params: dict): rls = context.app.release_checker.latest_prerelease if rls is not None: context.openWebsite(rls.url) + + +def copy_to_clipboard(text: str): + """Copy a string to the system clipboard. + + Args: + text: String to copy to clipboard. + """ + clipboard = QtWidgets.QApplication.clipboard() + clipboard.clear(mode=clipboard.Clipboard) + clipboard.setText(text, mode=clipboard.Clipboard) + + +def open_file(filename: str): + """Opens file in native system file browser or registered application. + + Args: + filename: Path to file or folder. + + Notes: + Source: https://stackoverflow.com/a/16204023 + """ + if sys.platform == "win32": + os.startfile(filename) + else: + opener = "open" if sys.platform == "darwin" else "xdg-open" + subprocess.call([opener, filename]) diff --git a/sleap/gui/dataviews.py b/sleap/gui/dataviews.py index 2834468c1..4106575d3 100644 --- a/sleap/gui/dataviews.py +++ b/sleap/gui/dataviews.py @@ -303,8 +303,7 @@ def selectionChanged(self, new, old): self.state[self.name_prefix + self.row_name] = item def activateSelected(self, *args): - """ - Activates item currently selected in table. + """Activate item currently selected in table. "Activate" means that the relevant :py:class:`GuiState` state variable is set to the currently selected item. @@ -313,8 +312,7 @@ def activateSelected(self, *args): self.state[self.row_name] = self.getSelectedRowItem() def selectRowItem(self, item: Any): - """ - Selects row corresponding to item. + """Select row corresponding to item. If the table model converts items to dictionaries (using `item_to_data` method), then `item` argument should be the original item, not the @@ -330,9 +328,12 @@ def selectRowItem(self, item: Any): if self.row_name: self.state[self.name_prefix + self.row_name] = item + def selectRow(self, idx: int): + """Select row corresponding to index.""" + self.selectRowItem(self.model().original_items[idx]) + def getSelectedRowItem(self) -> Any: - """ - Returns item corresponding to currently selected row. + """Return item corresponding to currently selected row. Note that if the table model converts items to dictionaries (using `item_to_data` method), then returned item will be the original item, @@ -461,7 +462,10 @@ def item_to_data(self, obj, item): item_dict["SuggestionFrame"] = item - video_string = f"{labels.videos.index(item.video)+1}: {os.path.basename(item.video.filename)}" + video_string = ( + f"{labels.videos.index(item.video)+1}: " + f"{os.path.basename(item.video.filename)}" + ) item_dict["group"] = str(item.group + 1) if item.group is not None else "" item_dict["group_int"] = item.group if item.group is not None else -1 @@ -469,7 +473,8 @@ def item_to_data(self, obj, item): item_dict["frame"] = int(item.frame_idx) + 1 # start at frame 1 rather than 0 # show how many labeled instances are in this frame - val = labels.instance_count(item.video, item.frame_idx) + lf = labels.get((item.video, item.frame_idx)) + val = 0 if lf is None else len(lf.user_instances) val = str(val) if val > 0 else "" item_dict["labeled"] = val diff --git a/sleap/gui/dialogs/delete.py b/sleap/gui/dialogs/delete.py index 08c5162e0..19bdae90b 100644 --- a/sleap/gui/dialogs/delete.py +++ b/sleap/gui/dialogs/delete.py @@ -24,7 +24,10 @@ class DeleteDialog(QtWidgets.QDialog): # NOTE: use type by name (rather than importing CommandContext) to avoid # circular includes. def __init__( - self, context: "CommandContext", *args, **kwargs, + self, + context: "CommandContext", + *args, + **kwargs, ): super(DeleteDialog, self).__init__(*args, **kwargs) @@ -161,7 +164,9 @@ def inst_condition(inst): frame_idx=self.context.state["frame_idx"], ) elif frames_value == "current video": - lf_list = labels.find(video=self.context.state["video"],) + lf_list = labels.find( + video=self.context.state["video"], + ) elif frames_value == "all videos": lf_list = labels.labeled_frames elif frames_value == "selected clip": diff --git a/sleap/gui/dialogs/formbuilder.py b/sleap/gui/dialogs/formbuilder.py index 41289162f..f69140e9e 100644 --- a/sleap/gui/dialogs/formbuilder.py +++ b/sleap/gui/dialogs/formbuilder.py @@ -501,7 +501,8 @@ def add_item(self, item: Dict[Text, Any]): add_blank_option = True field = TextOrListWidget( - result_as_idx=result_as_optional_idx, add_blank_option=add_blank_option, + result_as_idx=result_as_optional_idx, + add_blank_option=add_blank_option, ) if item["name"] in self.field_options_lists: @@ -910,7 +911,8 @@ def __init__(self, result_as_idx=False, add_blank_option=False, *args, **kwargs) self.text_widget = QtWidgets.QLineEdit() self.list_widget = FieldComboWidget( - result_as_idx=result_as_idx, add_blank_option=add_blank_option, + result_as_idx=result_as_idx, + add_blank_option=add_blank_option, ) layout.addWidget(self.text_widget) diff --git a/sleap/gui/dialogs/importvideos.py b/sleap/gui/dialogs/importvideos.py index 4890f741d..4b78ce404 100644 --- a/sleap/gui/dialogs/importvideos.py +++ b/sleap/gui/dialogs/importvideos.py @@ -49,10 +49,10 @@ def __init__(self): def ask(self): """Runs the import UI. - + 1. Show file selection dialog. 2. Show import parameter dialog with widget for each file. - + Args: None. Returns: @@ -81,7 +81,7 @@ def create_video(import_item: Dict[str, Any]) -> Video: class ImportParamDialog(QDialog): """Dialog for selecting parameters with preview when importing video. - + Args: filenames (list): List of files we want to import. """ @@ -253,7 +253,7 @@ def is_enabled(self): def get_data(self) -> dict: """Get all data (fixed and user-selected) for imported video. - + Returns: Dict with data for this video. """ diff --git a/sleap/gui/dialogs/shortcuts.py b/sleap/gui/dialogs/shortcuts.py index f4e912321..89f700019 100644 --- a/sleap/gui/dialogs/shortcuts.py +++ b/sleap/gui/dialogs/shortcuts.py @@ -28,11 +28,26 @@ def accept(self): for action, widget in self.key_widgets.items(): self.shortcuts[action] = widget.keySequence().toString() self.shortcuts.save() + self.info_msg() + super(ShortcutDialog, self).accept() + + def info_msg(self): + """Display information about changes.""" + msg = QtWidgets.QMessageBox() + msg.setText( + "Application must be restarted before changes to keyboard shortcuts take " + "effect." + ) + msg.exec_() + def reset(self): + """Reset to defaults.""" + self.shortcuts.reset_to_default() + self.info_msg() super(ShortcutDialog, self).accept() def load_shortcuts(self): - """Loads shortcuts object.""" + """Load shortcuts object.""" self.shortcuts = Shortcuts() def make_form(self): @@ -41,26 +56,28 @@ def make_form(self): layout = QtWidgets.QVBoxLayout() layout.addWidget(self.make_shortcuts_widget()) - layout.addWidget( - QtWidgets.QLabel( - "Any changes to keyboard shortcuts will not take effect " - "until you quit and re-open the application." - ) - ) layout.addWidget(self.make_buttons_widget()) self.setLayout(layout) def make_buttons_widget(self) -> QtWidgets.QDialogButtonBox: - """Makes the form buttons.""" - buttons = QtWidgets.QDialogButtonBox( - QtWidgets.QDialogButtonBox.Ok | QtWidgets.QDialogButtonBox.Cancel - ) - buttons.accepted.connect(self.accept) - buttons.rejected.connect(self.reject) + """Make the form buttons.""" + buttons = QtWidgets.QDialogButtonBox() + save = QtWidgets.QPushButton("Save") + save.clicked.connect(self.accept) + buttons.addButton(save, QtWidgets.QDialogButtonBox.AcceptRole) + + cancel = QtWidgets.QPushButton("Cancel") + cancel.clicked.connect(self.reject) + buttons.addButton(cancel, QtWidgets.QDialogButtonBox.RejectRole) + + reset = QtWidgets.QPushButton("Reset to defaults") + reset.clicked.connect(self.reset) + buttons.addButton(reset, QtWidgets.QDialogButtonBox.ActionRole) + return buttons def make_shortcuts_widget(self) -> QtWidgets.QWidget: - """Makes the widget will fields for all shortcuts.""" + """Make the widget will fields for all shortcuts.""" shortcuts = self.shortcuts widget = QtWidgets.QWidget() @@ -75,7 +92,7 @@ def make_shortcuts_widget(self) -> QtWidgets.QWidget: return widget def make_column_widget(self, shortcuts: List) -> QtWidgets.QWidget: - """Makes a single column of shortcut fields. + """Make a single column of shortcut fields. Args: shortcuts: The list of shortcuts to include in this column. diff --git a/sleap/gui/learning/configs.py b/sleap/gui/learning/configs.py index ec02f4353..ad6debece 100644 --- a/sleap/gui/learning/configs.py +++ b/sleap/gui/learning/configs.py @@ -249,7 +249,11 @@ def update(self, select: Optional[ConfigFileInfo] = None): @property def _menu_cfg_idx_offset(self): - if hasattr(self, "options_list") and self.options_list and self.options_list[0] == "": + if ( + hasattr(self, "options_list") + and self.options_list + and self.options_list[0] == "" + ): return 1 return 0 @@ -360,7 +364,9 @@ def find_configs(self) -> List[ConfigFileInfo]: # Collect all configs from specified directories, sorted from most recently modified to least for config_dir in filter(lambda d: os.path.exists(d), self.dir_paths): # Find all json files in dir and subdirs to specified depth - json_files = sleap_utils.find_files_by_suffix(config_dir, ".json", depth=self.search_depth) + json_files = sleap_utils.find_files_by_suffix( + config_dir, ".json", depth=self.search_depth + ) # Sort files, starting with most recently modified json_files.sort(key=lambda f: f.stat().st_mtime, reverse=True) @@ -372,8 +378,9 @@ def find_configs(self) -> List[ConfigFileInfo]: configs.append(cfg_info) # Push old configs to the end of the list, while preserving the time-based order otherwise - configs = [c for c in configs if not c.filename.startswith('old.')] +\ - [c for c in configs if c.filename.startswith('old.')] + configs = [c for c in configs if not c.filename.startswith("old.")] + [ + c for c in configs if c.filename.startswith("old.") + ] return configs diff --git a/sleap/gui/learning/datagen.py b/sleap/gui/learning/datagen.py index b951b4d26..6ee6ea3b0 100644 --- a/sleap/gui/learning/datagen.py +++ b/sleap/gui/learning/datagen.py @@ -145,7 +145,7 @@ def make_datagen_results(reader: LabelsReader, cfg: TrainingJobConfig) -> np.nda sigma=cfg.model.heads.multi_instance.pafs.sigma, output_stride=cfg.model.heads.multi_instance.pafs.output_stride, skeletons=reader.labels.skeletons, - flatten_channels=True + flatten_channels=True, ) ds = pipeline.make_dataset() diff --git a/sleap/gui/learning/dialog.py b/sleap/gui/learning/dialog.py index 48be758fe..12f4b74e9 100644 --- a/sleap/gui/learning/dialog.py +++ b/sleap/gui/learning/dialog.py @@ -3,22 +3,22 @@ """ import cattr import os +import shutil +import atexit +import tempfile +from pathlib import Path -import networkx as nx - +import sleap from sleap import Labels, Video from sleap.gui.dialogs.filedialog import FileDialog from sleap.gui.dialogs.formbuilder import YamlFormWidget from sleap.gui.learning import runners, scopedkeydict, configs, datagen, receptivefield -from typing import Dict, List, Optional, Text +from typing import Dict, List, Optional, Text, Optional from PySide2 import QtWidgets, QtCore -# Debug option to skip the training run -SKIP_TRAINING = False - # List of fields which should show list of skeleton nodes NODE_LIST_FIELDS = [ "data.instance_cropping.center_on_part", @@ -84,13 +84,20 @@ def __init__( # Layout for buttons buttons = QtWidgets.QDialogButtonBox() - self.cancel_button = buttons.addButton(QtWidgets.QDialogButtonBox.Cancel) self.save_button = buttons.addButton( - "Save configuration files...", QtWidgets.QDialogButtonBox.ApplyRole + "Save configuration files...", QtWidgets.QDialogButtonBox.ActionRole ) - self.run_button = buttons.addButton( - "Run", QtWidgets.QDialogButtonBox.AcceptRole + self.export_button = buttons.addButton( + "Export training job package...", QtWidgets.QDialogButtonBox.ActionRole + ) + self.cancel_button = buttons.addButton(QtWidgets.QDialogButtonBox.Cancel) + self.run_button = buttons.addButton("Run", QtWidgets.QDialogButtonBox.ApplyRole) + + self.save_button.setToolTip("Save scripts and configuration to run pipeline.") + self.export_button.setToolTip( + "Export data, configuration, and scripts for remote training and inference." ) + self.run_button.setToolTip("Run pipeline locally (GPU recommended).") buttons_layout = QtWidgets.QHBoxLayout() buttons_layout.addWidget(buttons, alignment=QtCore.Qt.AlignTop) @@ -132,9 +139,10 @@ def __init__( self.connect_signals() # Connect actions for buttons - buttons.accepted.connect(self.run) - buttons.rejected.connect(self.reject) - buttons.clicked.connect(self.on_button_click) + self.save_button.clicked.connect(self.save) + self.export_button.clicked.connect(self.export_package) + self.cancel_button.clicked.connect(self.reject) + self.run_button.clicked.connect(self.run) # Connect button for previewing the training data if "_view_datagen" in self.pipeline_form_widget.buttons: @@ -450,8 +458,6 @@ def get_items_for_inference(self, pipeline_form_data) -> runners.ItemsForInferen frame_count = self.count_total_frames_for_selection_option(frame_selection) if predict_frames_choice.startswith("user"): - # For inference on user labeled frames, we'll have a single - # inference item. items_for_inference = runners.ItemsForInference( items=[ runners.DatasetItemForInference( @@ -461,8 +467,6 @@ def get_items_for_inference(self, pipeline_form_data) -> runners.ItemsForInferen total_frame_count=frame_count, ) elif predict_frames_choice.startswith("suggested"): - # For inference on all suggested frames, we'll have a single - # inference item. items_for_inference = runners.ItemsForInference( items=[ runners.DatasetItemForInference( @@ -472,7 +476,6 @@ def get_items_for_inference(self, pipeline_form_data) -> runners.ItemsForInferen total_frame_count=frame_count, ) else: - # Otherwise, make an inference item for each video with list of frames. items_for_inference = runners.ItemsForInference.from_video_frames_dict( frame_selection, total_frame_count=frame_count ) @@ -491,24 +494,37 @@ def _validate_pipeline(self): ] if untrained: can_run = False - message = f"Cannot run inference with untrained models ({', '.join(untrained)})." + message = ( + "Cannot run inference with untrained models " + f"({', '.join(untrained)})." + ) # Make sure skeleton will be valid for bottom-up inference. if self.mode == "training" and self.current_pipeline == "bottom-up": skeleton = self.labels.skeletons[0] if not skeleton.is_arborescence: - message += "Cannot run bottom-up pipeline when skeleton is not an arborescence." + message += ( + "Cannot run bottom-up pipeline when skeleton is not an " + "arborescence." + ) root_names = [n.name for n in skeleton.root_nodes] over_max_in_degree = [n.name for n in skeleton.in_degree_over_one] cycles = skeleton.cycles if len(root_names) > 1: - message += f" There are multiple root nodes: {', '.join(root_names)} (there should be exactly one node which is not a target)." + message += ( + f" There are multiple root nodes: {', '.join(root_names)} " + "(there should be exactly one node which is not a target)." + ) if over_max_in_degree: - message += f" There are nodes which are target in multiple edges: {', '.join(over_max_in_degree)} (maximum in-degree should be 1)." + message += ( + " There are nodes which are target in multiple edges: " + f"{', '.join(over_max_in_degree)} (maximum in-degree should be " + "1)." + ) if cycles: cycle_strings = [] @@ -577,27 +593,128 @@ def run(self): win.setWindowTitle("Inference Results") win.exec_() - def save(self): - models_dir = os.path.join(os.path.dirname(self.labels_filename), "/models") - output_dir = FileDialog.openDir( - None, directory=models_dir, caption="Select directory to save scripts" - ) + def save( + self, output_dir: Optional[str] = None, labels_filename: Optional[str] = None + ): + """Save scripts and configs to run pipeline.""" + if output_dir is None: + models_dir = os.path.join(os.path.dirname(self.labels_filename), "/models") + output_dir = FileDialog.openDir( + None, directory=models_dir, caption="Select directory to save scripts" + ) - if not output_dir: - return + if not output_dir: + return pipeline_form_data = self.pipeline_form_widget.get_form_data() items_for_inference = self.get_items_for_inference(pipeline_form_data) config_info_list = self.get_every_head_config_data(pipeline_form_data) + if labels_filename is None: + labels_filename = self.labels_filename + runners.write_pipeline_files( output_dir=output_dir, - labels_filename=self.labels_filename, + labels_filename=labels_filename, config_info_list=config_info_list, inference_params=pipeline_form_data, items_for_inference=items_for_inference, ) + def export_package(self, output_path: Optional[str] = None, gui: bool = True): + """Export training job package.""" + # TODO: Warn if self.mode != "training"? + if output_path is None: + # Prompt for output path. + output_path, _ = FileDialog.save( + caption="Export Training Job Package...", + dir=f"{self.labels_filename}.training_job.zip", + filter="Training Job Package (*.zip)", + ) + if len(output_path) == 0: + return + + # Create temp dir before packaging. + tmp_dir = tempfile.TemporaryDirectory() + + # Remove the temp dir when program exits in case something goes wrong. + # atexit.register(shutil.rmtree, tmp_dir.name, ignore_errors=True) + + # Check if we need to include suggestions. + include_suggestions = False + items_for_inference = self.get_items_for_inference( + self.pipeline_form_widget.get_form_data() + ) + for item in items_for_inference.items: + if ( + isinstance(item, runners.DatasetItemForInference) + and item.frame_filter == "suggested" + ): + include_suggestions = True + + # Save dataset with images. + labels_pkg_filename = str( + Path(self.labels_filename).with_suffix(".pkg.slp").name + ) + if gui: + ret = sleap.gui.commands.export_dataset_gui( + self.labels, + tmp_dir.name + "/" + labels_pkg_filename, + all_labeled=False, + suggested=include_suggestions, + ) + if ret == "canceled": + # Quit if user canceled during export. + tmp_dir.cleanup() + return + else: + self.labels.save( + tmp_dir.name + "/" + labels_pkg_filename, + with_images=True, + embed_all_labeled=False, + embed_suggested=include_suggestions, + ) + + # Save config and scripts. + self.save(tmp_dir.name, labels_filename=labels_pkg_filename) + + # Package everything. + shutil.make_archive( + base_name=str(Path(output_path).with_suffix("")), + format="zip", + root_dir=tmp_dir.name, + ) + + msg = f"Saved training job package to: {output_path}" + print(msg) + + # Close training editor. + self.accept() + + if gui: + msgBox = QtWidgets.QMessageBox(text=f"Created training job package:") + msgBox.setDetailedText(output_path) + msgBox.setWindowTitle("Training Job Package") + okButton = msgBox.addButton(QtWidgets.QMessageBox.Ok) + openFolderButton = msgBox.addButton( + "Open containing folder", QtWidgets.QMessageBox.ActionRole + ) + colabButton = msgBox.addButton( + "Go to Colab", QtWidgets.QMessageBox.ActionRole + ) + msgBox.exec_() + + if msgBox.clickedButton() == openFolderButton: + sleap.gui.commands.open_file(str(Path(output_path).resolve().parent)) + elif msgBox.clickedButton() == colabButton: + # TODO: Update this to more workflow-tailored notebook. + sleap.gui.commands.copy_to_clipboard(output_path) + sleap.gui.commands.open_website( + "https://colab.research.google.com/github/murthylab/sleap-notebooks/blob/master/Training_and_inference_using_Google_Drive.ipynb" + ) + + tmp_dir.cleanup() + class TrainingPipelineWidget(QtWidgets.QWidget): """ @@ -619,7 +736,8 @@ def __init__( if hasattr(skeleton, "node_names"): for field_name in NODE_LIST_FIELDS: self.form_widget.set_field_options( - field_name, skeleton.node_names, + field_name, + skeleton.node_names, ) # Connect actions for change to pipeline @@ -733,7 +851,8 @@ def __init__( for field_name in NODE_LIST_FIELDS: form_name = field_name.split(".")[0] self.form_widgets[form_name].set_field_options( - field_name, skeleton.node_names, + field_name, + skeleton.node_names, ) if self._video: @@ -806,7 +925,7 @@ def __init__( @classmethod def from_trained_config(cls, cfg_info: configs.ConfigFileInfo): - widget = cls(require_trained=True) + widget = cls(require_trained=True, head=cfg_info.head_name) widget.acceptSelectedConfigInfo(cfg_info) widget.setWindowTitle(cfg_info.path_dir) return widget @@ -890,7 +1009,9 @@ def _update_use_trained(self, check_state=0): def _set_head(self): if self.head: self.set_fields_from_key_val_dict( - {"_heads_name": self.head,} + { + "_heads_name": self.head, + } ) self.form_widgets["model"].set_field_enabled("_heads_name", False) diff --git a/sleap/gui/learning/runners.py b/sleap/gui/learning/runners.py index 6a55128c1..4eda57ff1 100644 --- a/sleap/gui/learning/runners.py +++ b/sleap/gui/learning/runners.py @@ -4,9 +4,12 @@ import abc import attr import os -import subprocess as sub +import psutil +import json +import subprocess import tempfile import time +import shutil from datetime import datetime from typing import Any, Callable, Dict, List, Optional, Text, Tuple @@ -17,7 +20,17 @@ from sleap.nn import training from sleap.nn.config import TrainingJobConfig -SKIP_TRAINING = False + +def kill_process(pid: int): + """Force kill a running process and any child processes. + + Args: + proc: A process ID. + """ + proc_ = psutil.Process(pid) + for subproc_ in proc_.children(recursive=True): + subproc_.kill() + proc_.kill() @attr.s(auto_attribs=True) @@ -151,7 +164,10 @@ class InferenceTask: results: List[LabeledFrame] = attr.ib(default=attr.Factory(list)) def make_predict_cli_call( - self, item_for_inference: ItemForInference, output_path: Optional[str] = None + self, + item_for_inference: ItemForInference, + output_path: Optional[str] = None, + gui: bool = True, ) -> List[Text]: """Makes list of CLI arguments needed for running inference.""" cli_args = ["sleap-track"] @@ -222,6 +238,11 @@ def make_predict_cli_call( cli_args.extend(("-o", output_path)) + if gui: + cli_args.extend(("--verbosity", "json")) + + cli_args.extend(("--no-empty-frames",)) + return cli_args, output_path def predict_subprocess( @@ -229,23 +250,43 @@ def predict_subprocess( item_for_inference: ItemForInference, append_results: bool = False, waiting_callback: Optional[Callable] = None, + gui: bool = True, ) -> Tuple[Text, bool]: """Runs inference in a subprocess.""" - cli_args, output_path = self.make_predict_cli_call(item_for_inference) + cli_args, output_path = self.make_predict_cli_call(item_for_inference, gui=gui) print("Command line call:") - print(" \\\n".join(cli_args)) + print(" ".join(cli_args)) print() - with sub.Popen(cli_args) as proc: + # Run inference CLI capturing output. + with subprocess.Popen(cli_args, stdout=subprocess.PIPE) as proc: + + # Poll until finished. while proc.poll() is None: - if waiting_callback is not None: - if waiting_callback() == -1: - # -1 signals user cancellation - return "", False + # Read line. + line = proc.stdout.readline() + line = line.decode().rstrip() + + if line.startswith("{"): + # Parse line. + line_data = json.loads(line) + else: + # Pass through non-json output. + print(line) + line_data = {} - time.sleep(0.1) + if waiting_callback is not None: + # Pass line data to callback. + ret = waiting_callback(**line_data) + + if ret == "cancel": + # Stop if callback returned cancel signal. + kill_process(proc.pid) + print(f"Killed PID: {proc.pid}") + return "", "canceled" + time.sleep(0.05) print(f"Process return code: {proc.returncode}") success = proc.returncode == 0 @@ -255,7 +296,9 @@ def predict_subprocess( new_inference_labels = Labels.load_file(output_path, match_to=self.labels) self.results.extend(new_inference_labels.labeled_frames) - return output_path, success + # Return "success" or return code if failed. + ret = "success" if success else proc.returncode + return output_path, ret def merge_results(self): """Merges result frames into labels dataset.""" @@ -370,7 +413,8 @@ def write_pipeline_files( # Get list of cli args cli_args, _ = inference_task.make_predict_cli_call( - item_for_inference=item_for_inference, output_path=prediction_output_path, + item_for_inference=item_for_inference, + output_path=prediction_output_path, ) # And join them into a single call to inference inference_script += " ".join(cli_args) + "\n" @@ -420,6 +464,9 @@ def run_learning_pipeline( if None in trained_job_paths.values(): return -1 + if len(items_for_inference) == 0: + return 0 + inference_task = InferenceTask( labels=labels, labels_filename=labels_filename, @@ -497,7 +544,8 @@ def run_gui_training( os.path.dirname(labels_filename), "models" ) training.setup_new_run_folder( - job.outputs, base_run_name=f"{model_type}.{len(labels)}" + job.outputs, + base_run_name=f"{model_type}.n={len(labels.user_labeled_frames)}", ) if gui: @@ -517,9 +565,11 @@ def run_gui_training( def waiting(): if gui: QtWidgets.QApplication.instance().processEvents() + if win.canceled: + return "cancel" # Run training - trained_job_path, success = train_subprocess( + trained_job_path, ret = train_subprocess( job_config=job, labels_filename=labels_filename, video_paths=video_path_list, @@ -527,10 +577,17 @@ def waiting(): save_viz=save_viz, ) - if success: + if ret == "success": # get the path to the resulting TrainingJob file trained_job_paths[model_type] = trained_job_path print(f"Finished training {str(model_type)}.") + elif ret == "canceled": + if gui: + win.close() + print("Deleting canceled run data:", trained_job_path) + shutil.rmtree(trained_job_path, ignore_errors=True) + trained_job_paths[model_type] = None + break else: if gui: win.close() @@ -567,48 +624,93 @@ def run_gui_inference( """ if gui: - # show message while running inference progress = QtWidgets.QProgressDialog( - f"Running inference on {len(items_for_inference)} videos...", + "Initializing...", "Cancel", 0, - len(items_for_inference), + 1, ) progress.show() QtWidgets.QApplication.instance().processEvents() # Make callback to process events while running inference - def waiting(done_count): + def waiting( + n_processed: Optional[int] = None, + n_total: Optional[int] = None, + elapsed: Optional[float] = None, + rate: Optional[float] = None, + eta: Optional[float] = None, + current_item: Optional[int] = None, + total_items: Optional[int] = None, + **kwargs, + ) -> str: if gui: QtWidgets.QApplication.instance().processEvents() - progress.setValue(done_count) + if n_total is not None: + progress.setMaximum(n_total) + if n_processed is not None: + progress.setValue(n_processed) + + msg = "Predicting..." + # if current_item is not None and total_items is not None: + # msg += f" [{current_item + 1}/{total_items}]" + + if n_processed is not None and n_total is not None: + + prc = (n_processed / n_total) * 100 + msg = f"Predicted: {n_processed+1}/{n_total} ({prc:.1f}%)" + + # Show time elapsed? + if rate is not None and eta is not None: + eta_mins, eta_secs = divmod(eta, 60) + if eta_mins > 60: + eta_hours, eta_mins = divmod(eta, 60) + eta_str = f"{int(eta_hours)}:{int(eta_mins):02}:{int(eta_secs):02}" + else: + eta_str = f"{int(eta_mins)}:{int(eta_secs):02}" + msg += f"
ETA: {eta_str} FPS: {rate:.1f}" + + msg = msg.replace(" ", " ") + + progress.setLabelText(msg) + QtWidgets.QApplication.instance().processEvents() + if progress.wasCanceled(): - return -1 + return "cancel" for i, item_for_inference in enumerate(items_for_inference.items): - # Run inference for desired frames in this video - predictions_path, success = inference_task.predict_subprocess( - item_for_inference, append_results=True, waiting_callback=lambda: waiting(i) + + def waiting_item(**kwargs): + kwargs["current_item"] = i + kwargs["total_items"] = len(items_for_inference.items) + return waiting(**kwargs) + + # Run inference for desired frames in this video. + predictions_path, ret = inference_task.predict_subprocess( + item_for_inference, + append_results=True, + waiting_callback=waiting_item, + gui=gui, ) - if not success: + if gui: + progress.close() + + if ret == "success": + inference_task.merge_results() + return len(inference_task.results) + elif ret == "canceled": + return -1 + else: if gui: - progress.close() QtWidgets.QMessageBox( - text="An error occcured during inference. Your command line " - "terminal may have more information about the error." + text=( + "An error occcured during inference. Your command line " + "terminal may have more information about the error." + ) ).exec_() return -1 - inference_task.merge_results() - - # close message window - if gui: - progress.close() - - # return total_new_lf_count - return len(inference_task.results) - def train_subprocess( job_config: TrainingJobConfig, @@ -618,8 +720,6 @@ def train_subprocess( save_viz: bool = False, ): """Runs training inside subprocess.""" - - # run_name = job_config.outputs.run_name run_path = job_config.outputs.run_path success = False @@ -654,20 +754,26 @@ def train_subprocess( print(cli_args) - if not SKIP_TRAINING: - # Run training in a subprocess - with sub.Popen(cli_args) as proc: - - # Wait till training is done, calling a callback if given. - while proc.poll() is None: - if waiting_callback is not None: - if waiting_callback() == -1: - # -1 signals user cancellation - return "", False - time.sleep(0.1) - - success = proc.returncode == 0 + # Run training in a subprocess. + proc = subprocess.Popen(cli_args) + + # Wait till training is done, calling a callback if given. + while proc.poll() is None: + if waiting_callback is not None: + ret = waiting_callback() + if ret == "cancel": + print("Canceling training...") + kill_process(proc.pid) + print(f"Killed PID: {proc.pid}") + return run_path, "canceled" + time.sleep(0.1) + + # Check return code. + if proc.returncode == 0: + ret = "success" + else: + ret = proc.returncode print("Run Path:", run_path) - return run_path, success + return run_path, ret diff --git a/sleap/gui/learning/scopedkeydict.py b/sleap/gui/learning/scopedkeydict.py index 7d298d5e3..d10867ddc 100644 --- a/sleap/gui/learning/scopedkeydict.py +++ b/sleap/gui/learning/scopedkeydict.py @@ -119,6 +119,18 @@ def apply_cfg_transforms_to_key_val_dict(key_val_dict: dict): key_val_dict[f"model.backbone.{backbone_name}.output_stride"] = output_stride key_val_dict[f"model.backbone.{backbone_name}.max_stride"] = max_stride + # Convert random flip dropdown selection to config. + random_flip = key_val_dict.get( + "optimization.augmentation_config.random_flip", "none" + ) + if random_flip == "none": + key_val_dict["optimization.augmentation_config.random_flip"] = False + else: + key_val_dict["optimization.augmentation_config.random_flip"] = True + key_val_dict["optimization.augmentation_config.flip_horizontal"] = ( + random_flip == "horizontal" + ) + def find_backbone_name_from_key_val_dict(key_val_dict: dict): """Find the backbone model name from the config dictionary.""" diff --git a/sleap/gui/overlays/confmaps.py b/sleap/gui/overlays/confmaps.py index 71d45dcc5..8f6ca1cc7 100644 --- a/sleap/gui/overlays/confmaps.py +++ b/sleap/gui/overlays/confmaps.py @@ -57,13 +57,11 @@ def __init__( ) def boundingRect(self) -> QtCore.QRectF: - """Method required by Qt. - """ + """Method required by Qt.""" return self.rect def paint(self, painter, option, widget=None): - """Method required by Qt. - """ + """Method required by Qt.""" pass diff --git a/sleap/gui/overlays/instance.py b/sleap/gui/overlays/instance.py index 998e15b6f..7ec8ae894 100644 --- a/sleap/gui/overlays/instance.py +++ b/sleap/gui/overlays/instance.py @@ -43,7 +43,9 @@ def add_to_scene(self, video, frame_idx): has_user = any((True for inst in instances if not hasattr(inst, "score"))) for instance in instances: - self.player.addInstance(instance=instance) + self.player.addInstance( + instance=instance, markerRadius=self.state.get("marker size", 4) + ) self.player.showLabels(self.state.get("show labels", default=True)) self.player.showEdges(self.state.get("show edges", default=True)) diff --git a/sleap/gui/overlays/pafs.py b/sleap/gui/overlays/pafs.py index 3c5ca82d6..55c51c869 100644 --- a/sleap/gui/overlays/pafs.py +++ b/sleap/gui/overlays/pafs.py @@ -82,13 +82,11 @@ def __init__( self.affinity_field.append(aff_field_item) def boundingRect(self) -> QtCore.QRectF: - """Method required by Qt. - """ + """Method required by Qt.""" return QtCore.QRectF() def paint(self, painter, option, widget=None): - """Method required by Qt. - """ + """Method required by Qt.""" pass diff --git a/sleap/gui/release_checker.py b/sleap/gui/release_checker.py index 4a57b1ab7..6c167e7d0 100644 --- a/sleap/gui/release_checker.py +++ b/sleap/gui/release_checker.py @@ -65,9 +65,7 @@ class ReleaseChecker: """ repo_id: str = REPO_ID - releases: List[Release] = attr.ib( - factory=list, converter=filter_test_releases - ) + releases: List[Release] = attr.ib(factory=list, converter=filter_test_releases) checked: bool = attr.ib(default=False, init=False) def check_for_releases(self) -> bool: diff --git a/sleap/gui/shortcuts.py b/sleap/gui/shortcuts.py index 280081d9f..ec76e45a3 100644 --- a/sleap/gui/shortcuts.py +++ b/sleap/gui/shortcuts.py @@ -3,9 +3,7 @@ """ from typing import Dict, Union - from PySide2.QtGui import QKeySequence - from sleap import util @@ -100,6 +98,11 @@ def save(self): util.save_config_yaml("shortcuts.yaml", data) + def reset_to_default(self): + """Reset shortcuts to default and save.""" + self._shortcuts = util.get_config_yaml("shortcuts.yaml", get_defaults=True) + self.save() + def __getitem__(self, idx: Union[slice, int, str]) -> Union[str, Dict[str, str]]: """ Returns shortcut value, accessed by range, index, or key. diff --git a/sleap/gui/widgets/multicheck.py b/sleap/gui/widgets/multicheck.py index 5af7137e5..08e2a5c57 100644 --- a/sleap/gui/widgets/multicheck.py +++ b/sleap/gui/widgets/multicheck.py @@ -91,11 +91,9 @@ def setSelected(self, selected: list): check_button.setChecked(False) def boundingRect(self) -> QRectF: - """Method required by Qt. - """ + """Method required by Qt.""" return QRectF() def paint(self, painter, option, widget=None): - """Method required by Qt. - """ + """Method required by Qt.""" pass diff --git a/sleap/gui/widgets/slider.py b/sleap/gui/widgets/slider.py index 54bc7af74..876dd5724 100644 --- a/sleap/gui/widgets/slider.py +++ b/sleap/gui/widgets/slider.py @@ -487,7 +487,10 @@ def _update_selection_box_positions(self, box_object, a: float, b: float): start_pos = self._toPos(start, center=True) end_pos = self._toPos(end, center=True) box_rect = QtCore.QRect( - start_pos, self._header_height, end_pos - start_pos, self.box_rect.height(), + start_pos, + self._header_height, + end_pos - start_pos, + self.box_rect.height(), ) box_object.setRect(box_rect) diff --git a/sleap/gui/widgets/video.py b/sleap/gui/widgets/video.py index 4b62d6ebc..34789bb96 100644 --- a/sleap/gui/widgets/video.py +++ b/sleap/gui/widgets/video.py @@ -396,8 +396,7 @@ def load_video(self, video: Video, plot=True): self.plot() def reset(self): - """ Reset viewer by removing all video data. - """ + """Reset viewer by removing all video data.""" self.video = None self.state["frame_idx"] = None self.view.clear() @@ -466,7 +465,7 @@ def plot(self, *args): self._video_image_loader.request(idx) def showLabels(self, show): - """ Show/hide node labels for all instances in viewer. + """Show/hide node labels for all instances in viewer. Args: show: Show if True, hide otherwise. @@ -475,7 +474,7 @@ def showLabels(self, show): inst.showLabels(show) def showEdges(self, show): - """ Show/hide node edges for all instances in viewer. + """Show/hide node edges for all instances in viewer. Args: show: Show if True, hide otherwise. @@ -489,8 +488,7 @@ def highlightPredictions(self, highlight_text: str = ""): inst.highlight_text = highlight_text def zoomToFit(self): - """ Zoom view to fit all instances. - """ + """Zoom view to fit all instances.""" zoom_rect = self.view.instancesBoundingRect(margin=20) if not zoom_rect.size().isEmpty(): self.view.zoomToRect(zoom_rect) @@ -755,13 +753,11 @@ def __init__(self, state=None, player=None, *args, **kwargs): self.setTransformationAnchor(anchor_mode) def hasImage(self) -> bool: - """ Returns whether or not the scene contains an image pixmap. - """ + """Returns whether or not the scene contains an image pixmap.""" return self._pixmapHandle is not None def clear(self): - """ Clears the displayed frame from the scene. - """ + """Clears the displayed frame from the scene.""" if self._pixmapHandle: # get the pixmap currently shown @@ -890,7 +886,7 @@ def selectInstance(self, select: Union[Instance, int]): self.updatedSelection.emit() def getSelectionIndex(self) -> Optional[int]: - """ Returns the index of the currently selected instance. + """Returns the index of the currently selected instance. If no instance selected, returns None. """ instances = self.all_instances @@ -901,7 +897,7 @@ def getSelectionIndex(self) -> Optional[int]: return idx def getSelectionInstance(self) -> Optional[Instance]: - """ Returns the currently selected instance. + """Returns the currently selected instance. If no instance selected, returns None. """ instances = self.all_instances @@ -928,13 +924,11 @@ def is_selectable(item): return None def resizeEvent(self, event): - """ Maintain current zoom on resize. - """ + """Maintain current zoom on resize.""" self.updateViewer() def mousePressEvent(self, event): - """ Start mouse pan or zoom mode. - """ + """Start mouse pan or zoom mode.""" scenePos = self.mapToScene(event.pos()) # keep track of click location self._down_pos = event.pos() @@ -962,8 +956,7 @@ def mousePressEvent(self, event): QGraphicsView.mousePressEvent(self, event) def mouseReleaseEvent(self, event): - """ Stop mouse pan or zoom mode (apply zoom if valid). - """ + """Stop mouse pan or zoom mode (apply zoom if valid).""" QGraphicsView.mouseReleaseEvent(self, event) scenePos = self.mapToScene(event.pos()) @@ -1037,8 +1030,7 @@ def zoomToRect(self, zoom_rect: QRectF): self.centerOn(zoom_rect.center()) def clearZoom(self): - """ Clear zoom stack. Doesn't update display. - """ + """Clear zoom stack. Doesn't update display.""" self.zoomFactor = 1 @staticmethod @@ -1077,8 +1069,7 @@ def instancesBoundingRect(self, margin: float = 0) -> QRectF: return GraphicsView.getInstancesBoundingRect(self.all_instances, margin=margin) def mouseDoubleClickEvent(self, event: QMouseEvent): - """ Custom event handler, clears zoom. - """ + """Custom event handler, clears zoom.""" scenePos = self.mapToScene(event.pos()) if event.button() == Qt.LeftButton: @@ -1093,8 +1084,7 @@ def mouseDoubleClickEvent(self, event: QMouseEvent): QGraphicsView.mouseDoubleClickEvent(self, event) def wheelEvent(self, event): - """ Custom event handler. Zoom in/out based on scroll wheel change. - """ + """Custom event handler. Zoom in/out based on scroll wheel change.""" # zoom on wheel when no mouse buttons are pressed if event.buttons() == Qt.NoButton: angle = event.angleDelta().y() @@ -1166,7 +1156,7 @@ def __init__( self.adjustStyle() def adjustPos(self, *args, **kwargs): - """ Update the position of the label based on the position of the node. + """Update the position of the label based on the position of the node. Args: Accepts arbitrary arguments so we can connect to various signals. @@ -1220,8 +1210,7 @@ def adjustPos(self, *args, **kwargs): self.adjustStyle() def adjustStyle(self): - """ Update visual display of the label and its node. - """ + """Update visual display of the label and its node.""" complete_color = ( QColor(80, 194, 159) if self.node.point.complete else QColor(232, 45, 32) @@ -1248,35 +1237,29 @@ def adjustStyle(self): self.setDefaultTextColor(complete_color) # redish def boundingRect(self): - """ Method required by Qt. - """ + """Method required by Qt.""" return super(QtNodeLabel, self).boundingRect() def paint(self, *args, **kwargs): - """ Method required by Qt. - """ + """Method required by Qt.""" super(QtNodeLabel, self).paint(*args, **kwargs) def mousePressEvent(self, event): - """ Pass events along so that clicking label is like clicking node. - """ + """Pass events along so that clicking label is like clicking node.""" self.setCursor(Qt.ArrowCursor) self.node.mousePressEvent(event) def mouseMoveEvent(self, event): - """ Pass events along so that clicking label is like clicking node. - """ + """Pass events along so that clicking label is like clicking node.""" self.node.mouseMoveEvent(event) def mouseReleaseEvent(self, event): - """ Pass events along so that clicking label is like clicking node. - """ + """Pass events along so that clicking label is like clicking node.""" self.unsetCursor() self.node.mouseReleaseEvent(event) def wheelEvent(self, event): - """ Pass events along so that clicking label is like clicking node. - """ + """Pass events along so that clicking label is like clicking node.""" self.node.wheelEvent(event) @@ -1375,8 +1358,7 @@ def __init__( self.updatePoint(user_change=False) def calls(self): - """ Method to call all callbacks. - """ + """Method to call all callbacks.""" for callback in self.callbacks: if callable(callback): callback(self) @@ -1437,8 +1419,7 @@ def toggleVisibility(self): self.point.visible = visible def mousePressEvent(self, event): - """ Custom event handler for mouse press. - """ + """Custom event handler for mouse press.""" # Do nothing if node is from predicted instance if self.parentObject().predicted: return @@ -1474,8 +1455,7 @@ def mousePressEvent(self, event): pass def mouseMoveEvent(self, event): - """ Custom event handler for mouse move. - """ + """Custom event handler for mouse move.""" # print(event) if self.dragParent: self.parentObject().mouseMoveEvent(event) @@ -1486,8 +1466,7 @@ def mouseMoveEvent(self, event): ) # don't count change until mouse release def mouseReleaseEvent(self, event): - """ Custom event handler for mouse release. - """ + """Custom event handler for mouse release.""" # print(event) self.unsetCursor() if self.dragParent: @@ -1543,12 +1522,18 @@ def __init__( self.show_non_visible = show_non_visible super(QtEdge, self).__init__( - polygon=QPolygonF(), parent=parent, *args, **kwargs, + polygon=QPolygonF(), + parent=parent, + *args, + **kwargs, ) self.setLine( QLineF( - self.src.point.x, self.src.point.y, self.dst.point.x, self.dst.point.y, + self.src.point.x, + self.src.point.y, + self.dst.point.x, + self.dst.point.y, ) ) @@ -1935,13 +1920,11 @@ def showEdges(self, show: bool): self.edges_shown = show def boundingRect(self): - """ Method required Qt to determine bounding rect for item. - """ + """Method required Qt to determine bounding rect for item.""" return self._bounding_rect def paint(self, painter, option, widget=None): - """ Method required by Qt. - """ + """Method required by Qt.""" pass @@ -1957,13 +1940,11 @@ def __init__(self, *args, **kwargs): self.setFlag(QGraphicsItem.ItemIgnoresTransformations) def boundingRect(self): - """ Method required by Qt. - """ + """Method required by Qt.""" return super(QtTextWithBackground, self).boundingRect() def paint(self, painter, option, *args, **kwargs): - """ Method required by Qt. - """ + """Method required by Qt.""" text_color = self.defaultTextColor() brush = painter.brush() background_color = "white" if text_color.lightnessF() < 0.4 else "black" diff --git a/sleap/info/feature_suggestions.py b/sleap/info/feature_suggestions.py index 22ca4ca5e..51f9038a5 100644 --- a/sleap/info/feature_suggestions.py +++ b/sleap/info/feature_suggestions.py @@ -344,7 +344,7 @@ def sample(self, per_group: int, unique_samples: bool = True): class ItemStack(object): """ Container for items, each item can "own" one or more rows of data. - + Attributes: items: The list of items data: An ndarray with rows of data corresponding to items. @@ -352,7 +352,7 @@ class ItemStack(object): items. meta: List which stores metadata about each operation on stack. group_sets: List of GroupSets of items. - + """ items: List = attr.ib(default=attr.Factory(list)) diff --git a/sleap/info/metrics.py b/sleap/info/metrics.py index 3412d2626..2ac61d339 100644 --- a/sleap/info/metrics.py +++ b/sleap/info/metrics.py @@ -213,7 +213,7 @@ def compare_instance_lists( instances_b: List[Union[Instance, PredictedInstance]], ) -> np.ndarray: """Given two lists of corresponding Instances, returns - (instances * nodes) matrix of distances between corresponding nodes.""" + (instances * nodes) matrix of distances between corresponding nodes.""" paired_points_array_distances = [] for inst_a, inst_b in zip(instances_a, instances_b): diff --git a/sleap/instance.py b/sleap/instance.py index 3399ccf2e..214fc8f25 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -319,7 +319,7 @@ class Track: name: A name given to this track for identifying purposes. """ - spawned_on: int = attr.ib(converter=int) + spawned_on: int = attr.ib(default=0, converter=int) name: str = attr.ib(default="", converter=str) def matches(self, other: "Track"): @@ -341,10 +341,9 @@ def matches(self, other: "Track"): # that are created in post init so they are not serialized. -@attr.s(eq=False, order=False, slots=True) +@attr.s(eq=False, order=False, slots=True, repr=False, str=False) class Instance: - """ - The class :class:`Instance` represents a labelled instance of a skeleton. + """This class represents a labeled instance. Args: skeleton: The skeleton that this instance is associated with. @@ -375,8 +374,7 @@ class Instance: def _validate_from_predicted_( self, attribute, from_predicted: Optional["PredictedInstance"] ): - """ - Validation method called by attrs. + """Validation method called by attrs. Checks that from_predicted is None or :class:`PredictedInstance` @@ -391,13 +389,13 @@ def _validate_from_predicted_( """ if from_predicted is not None and type(from_predicted) != PredictedInstance: raise TypeError( - f"Instance.from_predicted type must be PredictedInstance (not {type(from_predicted)})" + f"Instance.from_predicted type must be PredictedInstance (not " + "{type(from_predicted)})" ) @_points.validator def _validate_all_points(self, attribute, points: Union[dict, PointArray]): - """ - Validation method called by attrs. + """Validation method called by attrs. Checks that all the _points defined for the skeleton are found in the skeleton. @@ -425,12 +423,12 @@ def _validate_all_points(self, attribute, points: Union[dict, PointArray]): elif isinstance(points, PointArray): if len(points) != len(self.skeleton.nodes): raise ValueError( - "PointArray does not have the same number of rows as skeleton nodes." + "PointArray does not have the same number of rows as skeleton " + "nodes." ) def __attrs_post_init__(self): - """ - Method called by attrs after __init__() + """Method called by attrs after __init__(). Initializes points if none were specified when creating object, caches list of nodes so what we can still find points in array @@ -441,12 +439,8 @@ def __attrs_post_init__(self): Raises: ValueError: If object has no `Skeleton`. - - Returns: - None """ - - if not self.skeleton: + if self.skeleton is None: raise ValueError("No skeleton set for Instance") # If the user did not pass a points list initialize a point array for future @@ -472,8 +466,7 @@ def __attrs_post_init__(self): def _points_dict_to_array( points: Dict[Union[str, Node], Point], parray: PointArray, skeleton: Skeleton ): - """ - Sets values in given :class:`PointsArray` from dictionary. + """Set values in given :class:`PointsArray` from dictionary. Args: points: The dictionary of points. Keys can be either node @@ -485,11 +478,7 @@ def _points_dict_to_array( Raises: ValueError: If dictionary keys are not either all strings or all :class:`Node`s. - - Returns: - None """ - # Check if the dict contains all strings is_string_dict = set(map(type, points)) == {str} @@ -522,8 +511,7 @@ def _points_dict_to_array( pass def _node_to_index(self, node: Union[str, Node]) -> int: - """ - Helper method to get the index of a node from its name. + """Helper method to get the index of a node from its name. Args: node: Node name or :class:`Node` object. @@ -534,10 +522,10 @@ def _node_to_index(self, node: Union[str, Node]) -> int: return self.skeleton.node_to_index(node) def __getitem__( - self, node: Union[List[Union[str, Node]], Union[str, Node]] - ) -> Union[List[Point], Point]: - """ - Get the Points associated with particular skeleton node(s). + self, + node: Union[List[Union[str, Node, int]], Union[str, Node, int], np.ndarray], + ) -> Union[List[Point], Point, np.ndarray]: + """Get the Points associated with particular skeleton node(s). Args: node: A single node or list of nodes within the skeleton @@ -552,26 +540,29 @@ def __getitem__( to each node. """ - - # If the node is a list of nodes, use get item recursively and return a list of _points. - if type(node) is list: - ret_list = [] + # If the node is a list of nodes, use get item recursively and return a list of + # _points. + if isinstance(node, (list, tuple, np.ndarray)): + pts = [] for n in node: - ret_list.append(self.__getitem__(n)) + pts.append(self.__getitem__(n)) - return ret_list + if isinstance(node, np.ndarray): + return np.array([[pt.x, pt.y] for pt in pts]) + else: + return pts - try: - node = self._node_to_index(node) - return self._points[node] - except ValueError: - raise KeyError( - f"The underlying skeleton ({self.skeleton}) has no node '{node}'" - ) + if isinstance(node, (Node, str)): + try: + node = self._node_to_index(node) + except ValueError: + raise KeyError( + f"The underlying skeleton ({self.skeleton}) has no node '{node}'" + ) + return self._points[node] - def __contains__(self, node: Union[str, Node]) -> bool: - """ - Whether this instance has a point with the specified node. + def __contains__(self, node: Union[str, Node, int]) -> bool: + """Whether this instance has a point with the specified node. Args: node: Node name or :class:`Node` object. @@ -584,21 +575,20 @@ def __contains__(self, node: Union[str, Node]) -> bool: if isinstance(node, Node): node = node.name - if node not in self.skeleton: - return False - - node_idx = self._node_to_index(node) + if isinstance(node, str): + if node not in self.skeleton: + return False + node = self._node_to_index(node) # If the points are nan, then they haven't been allocated. - return not self._points[node_idx].isnan() + return not self._points[node].isnan() def __setitem__( self, - node: Union[List[Union[str, Node]], Union[str, Node]], - value: Union[List[Point], Point], + node: Union[List[Union[str, Node, int]], Union[str, Node, int], np.ndarray], + value: Union[List[Point], Point, np.ndarray], ): - """ - Set the point(s) for given node(s). + """Set the point(s) for given node(s). Args: node: Either node (by name or `Node`) or list of nodes. @@ -608,39 +598,38 @@ def __setitem__( IndexError: If lengths of lists don't match, or if exactly one of the inputs is a list. KeyError: If skeleton does not have (one of) the node(s). - - Returns: - None """ - # Make sure node and value, if either are lists, are of compatible size - if type(node) is not list and type(value) is list and len(value) != 1: - raise IndexError( - "Node list for indexing must be same length and value list." - ) - - if type(node) is list and type(value) is not list and len(node) != 1: - raise IndexError( - "Node list for indexing must be same length and value list." - ) + if isinstance(node, (list, np.ndarray)): + if not isinstance(value, (list, np.ndarray)) or len(value) != len(node): + raise IndexError( + "Node list for indexing must be same length and value list." + ) - # If we are dealing with lists, do multiple assignment recursively, this should be ok because - # skeletons and instances are small. - if type(node) is list: for n, v in zip(node, value): self.__setitem__(n, v) else: - try: - node_idx = self._node_to_index(node) - self._points[node_idx] = value - except ValueError: - raise KeyError( - f"The underlying skeleton ({self.skeleton}) has no node '{node}'" - ) + if isinstance(node, (Node, str)): + try: + node_idx = self._node_to_index(node) + except ValueError: + raise KeyError( + f"The skeleton ({self.skeleton}) has no node '{node}'." + ) + else: + node_idx = node + + if not isinstance(value, Point): + if hasattr(value, "__len__") and len(value) == 2: + value = Point(x=value[0], y=value[1]) + else: + raise ValueError( + "Instance point values must be (x, y) coordinates." + ) + self._points[node_idx] = value def __delitem__(self, node: Union[str, Node]): - """ - Delete node key and points associated with that node. + """Delete node key and points associated with that node. Args: node: Node name or :class:`Node` object. @@ -660,9 +649,24 @@ def __delitem__(self, node: Union[str, Node]): f"The underlying skeleton ({self.skeleton}) has no node '{node}'" ) + def __repr__(self) -> str: + """Return string representation of this object.""" + pts = [] + for node, pt in self.nodes_points: + pts.append(f"{node.name}: ({pt.x:.1f}, {pt.y:.1f})") + pts = ", ".join(pts) + + return ( + "Instance(" + f"video={self.video}, " + f"frame_idx={self.frame_idx}, " + f"points=[{pts}], " + f"track={self.track}" + ")" + ) + def matches(self, other: "Instance") -> bool: - """ - Whether two instances match by value. + """Whether two instances match by value. Checks the types, points, track, and frame index. @@ -695,9 +699,7 @@ def matches(self, other: "Instance") -> bool: @property def nodes(self) -> Tuple[Node, ...]: - """ - The tuple of nodes that have been labelled for this instance. - """ + """Return nodes that have been labelled for this instance.""" self._fix_array() return tuple( self._nodes[i] @@ -707,32 +709,23 @@ def nodes(self) -> Tuple[Node, ...]: @property def nodes_points(self) -> List[Tuple[Node, Point]]: - """ - The list of (node, point) tuples for all labelled points. - """ + """Return a list of (node, point) tuples for all labeled points.""" names_to_points = dict(zip(self.nodes, self.points)) return names_to_points.items() @property def points(self) -> Tuple[Point, ...]: - """ - The tuple of labelled points, in order they were labelled. - """ + """Return a tuple of labelled points, in the order they were labelled.""" self._fix_array() return tuple(point for point in self._points if not point.isnan()) def _fix_array(self): - """ - Fixes PointArray after nodes have been added or removed. + """Fix PointArray after nodes have been added or removed. This updates the PointArray as required by comparing the cached list of nodes to the nodes in the `Skeleton` object (which may have changed). - - Returns: - None """ - # Check if cached skeleton nodes are different than current nodes if self._nodes != self.skeleton.nodes: # Create new PointArray (or PredictedPointArray) @@ -751,8 +744,7 @@ def _fix_array(self): def get_points_array( self, copy: bool = True, invisible_as_nan: bool = False, full: bool = False ) -> Union[np.ndarray, np.recarray]: - """ - Return the instance's points in array form. + """Return the instance's points in array form. Args: copy: If True, the return a copy of the points array as an ndarray. @@ -795,14 +787,13 @@ def get_points_array( @property def points_array(self) -> np.ndarray: - """ - Nx2 array of x and y for visible points. + """Return array of x and y coordinates for visible points. - Row in array corresponds to order of points in skeleton. - Invisible points will have nans. + Row in array corresponds to order of points in skeleton. Invisible points will + be denoted by NaNs. Returns: - ndarray of visible point coordinates. + A numpy array of of shape `(n_nodes, 2)` point coordinates. """ return self.get_points_array(invisible_as_nan=True) @@ -812,13 +803,19 @@ def numpy(self) -> np.ndarray: Alias for `points_array`. Returns: - Array of shape (n_nodes, 2) of dtype float32 containing the coordinates of - the instance's nodes. Missing/not visible nodes will be replaced with NaNs. + Array of shape `(n_nodes, 2)` of dtype `float32` containing the coordinates + of the instance's nodes. Missing/not visible nodes will be replaced with + `NaN`. """ return self.points_array def transform_points(self, transformation_matrix): - """Applies transformation matrix to points.""" + """Apply affine transformation matrix to points in the instance. + + Args: + transformation_matrix: Affine transformation matrix as a numpy array of + shape `(3, 3)`. + """ points = self.get_points_array(copy=True, full=False, invisible_as_nan=False) if transformation_matrix.shape[1] == 3: @@ -835,36 +832,50 @@ def transform_points(self, transformation_matrix): @property def centroid(self) -> np.ndarray: - """Returns instance centroid as (x,y) numpy row vector.""" + """Return instance centroid as an array of `(x, y)` coordinates + + Notes: + This computes the centroid as the median of the visible points. + """ points = self.points_array centroid = np.nanmedian(points, axis=0) return centroid @property def bounding_box(self) -> np.ndarray: - """Returns the instance's containing bounding box in [y1, x1, y2, x2] format.""" + """Return bounding box containing all points in `[y1, x1, y2, x2]` format.""" points = self.points_array bbox = np.concatenate( [np.nanmin(points, axis=0)[::-1], np.nanmax(points, axis=0)[::-1]] ) return bbox + @property + def midpoint(self) -> np.ndarray: + """Return the center of the bounding box of the instance points.""" + y1, x1, y2, x2 = self.bounding_box + return np.array([(x2 - x1) / 2, (y2 - y1) / 2]) + @property def n_visible_points(self) -> int: - """Returns the count of points that are visible in this instance.""" + """Return the number of visible points in this instance.""" return sum(~np.isnan(self.points_array[:, 0])) - @property - def frame_idx(self) -> Optional[int]: - """ - Get the index of the frame that this instance was found on. + def __len__(self) -> int: + """Return the number of visible points in this instance.""" + return self.n_visible_points - This is a convenience method for Instance.frame.frame_idx. + @property + def video(self) -> Optional[Video]: + """Return the video of the labeled frame this instance is associated with.""" + if self.frame is None: + return None + else: + return self.frame.video - Returns: - The frame number this instance was found on, or None if the - instance is not associated with frame. - """ + @property + def frame_idx(self) -> Optional[int]: + """Return the index of the labeled frame this instance is associated with.""" if self.frame is None: return None else: @@ -872,19 +883,20 @@ def frame_idx(self) -> Optional[int]: @classmethod def from_pointsarray( - cls, points: np.ndarray, skeleton: Skeleton, track: Optional[Track] = None, + cls, points: np.ndarray, skeleton: Skeleton, track: Optional[Track] = None ) -> "Instance": - """Create an instance from pointsarray. + """Create an instance from an array of points. Args: - points: A numpy array of shape (n_nodes, 2) and dtype float32 that contains - the points in (x, y) coordinates of each node. Missing nodes should be - represented as NaNs. - skeleton: A sleap.Skeleton instance with n_nodes nodes to associate with the - predicted instance. + points: A numpy array of shape `(n_nodes, 2)` and dtype `float32` that + contains the points in (x, y) coordinates of each node. Missing nodes + should be represented as `NaN`. + skeleton: A `sleap.Skeleton` instance with `n_nodes` nodes to associate with + the instance. + track: Optional `sleap.Track` object to associate with the instance. Returns: - A new Instance. + A new `Instance` object. """ predicted_points = dict() for point, node_name in zip(points, skeleton.node_names): @@ -895,8 +907,30 @@ def from_pointsarray( return cls(points=predicted_points, skeleton=skeleton, track=track) + @classmethod + def from_numpy( + cls, points: np.ndarray, skeleton: Skeleton, track: Optional[Track] = None + ) -> "Instance": + """Create an instance from a numpy array. + + Args: + points: A numpy array of shape `(n_nodes, 2)` and dtype `float32` that + contains the points in (x, y) coordinates of each node. Missing nodes + should be represented as `NaN`. + skeleton: A `sleap.Skeleton` instance with `n_nodes` nodes to associate with + the instance. + track: Optional `sleap.Track` object to associate with the instance. -@attr.s(eq=False, order=False, slots=True) + Returns: + A new `Instance` object. + + Notes: + This is an alias for `Instance.from_pointsarray()`. + """ + return cls.from_pointsarray(points, skeleton, track=track) + + +@attr.s(eq=False, order=False, slots=True, repr=False, str=False) class PredictedInstance(Instance): """ A predicted instance is an output of the inference procedure. @@ -918,35 +952,55 @@ def __attrs_post_init__(self): if self.from_predicted is not None: raise ValueError("PredictedInstance should not have from_predicted.") + def __repr__(self) -> str: + """Return string representation of this object.""" + pts = [] + for node, pt in self.nodes_points: + pts.append(f"{node.name}: ({pt.x:.1f}, {pt.y:.1f}, {pt.score:.2f})") + pts = ", ".join(pts) + + return ( + "PredictedInstance(" + f"video={self.video}, " + f"frame_idx={self.frame_idx}, " + f"points=[{pts}], " + f"score={self.score:.2f}, " + f"track={self.track}, " + f"tracking_score={self.tracking_score:.2f}" + ")" + ) + @property def points_and_scores_array(self) -> np.ndarray: - """ - (N, 3) array of (x, y, score) for predicted points. + """Return the instance points and scores as an array. - Row in arrow corresponds to order of points in skeleton. - Invisible points will have NaNs. + This will be a `(n_nodes, 3)` array of `(x, y, score)` for each predicted point. - Returns: - ndarray of visible point coordinates and scores. + Rows in the array correspond to the order of points in skeleton. Invisible + points will be represented as NaNs. """ pts = self.get_points_array(full=True, copy=True, invisible_as_nan=True) return pts[:, (0, 1, 4)] # (x, y, score) + @property + def scores(self) -> np.ndarray: + """Return point scores for each predicted node.""" + return self.points_and_scores_array[:, 2] + @classmethod def from_instance(cls, instance: Instance, score: float) -> "PredictedInstance": - """ - Create a :class:`PredictedInstance` from an :class:`Instance`. + """Create a `PredictedInstance` from an `Instance`. - The fields are copied in a shallow manner with the exception of - points. For each point in the instance a :class:`PredictedPoint` - is created with score set to default value. + The fields are copied in a shallow manner with the exception of points. For each + point in the instance a `PredictedPoint` is created with score set to default + value. Args: - instance: The Instance object to shallow copy data from. + instance: The `Instance` object to shallow copy data from. score: The score for this instance. Returns: - A PredictedInstance for the given Instance. + A `PredictedInstance` for the given `Instance`. """ kw_args = attr.asdict( instance, @@ -969,18 +1023,19 @@ def from_arrays( """Create a predicted instance from data arrays. Args: - points: A numpy array of shape (n_nodes, 2) and dtype float32 that contains - the points in (x, y) coordinates of each node. Missing nodes should be - represented as NaNs. - point_confidences: A numpy array of shape (n_nodes,) and dtype float32 that - contains the confidence/score of the points. + points: A numpy array of shape `(n_nodes, 2)` and dtype `float32` that + contains the points in `(x, y)` coordinates of each node. Missing nodes + should be represented as `NaN`. + point_confidences: A numpy array of shape `(n_nodes,)` and dtype `float32` + that contains the confidence/score of the points. instance_score: Scalar float representing the overall instance score, e.g., the PAF grouping score. skeleton: A sleap.Skeleton instance with n_nodes nodes to associate with the predicted instance. + track: Optional `sleap.Track` to associate with the instance. Returns: - A new PredictedInstance. + A new `PredictedInstance`. """ predicted_points = dict() for point, confidence, node_name in zip( @@ -1002,18 +1057,15 @@ def from_arrays( def make_instance_cattr() -> cattr.Converter: - """ - Create a cattr converter for Lists of Instances/PredictedInstances. + """Create a cattr converter for Lists of Instances/PredictedInstances. - This is required because cattrs doesn't automatically detect the - class when the attributes of one class are a subset of another. + This is required because cattrs doesn't automatically detect the class when the + attributes of one class are a subset of another. Returns: - A cattr converter with hooks registered for structuring and - unstructuring :class:`Instance` objects and - :class:`PredictedInstance`s. + A cattr converter with hooks registered for structuring and unstructuring + `Instance` and `PredictedInstance` objects. """ - converter = cattr.Converter() #### UNSTRUCTURE HOOKS @@ -1094,12 +1146,12 @@ def structure_point_array(x, t): @attr.s(auto_attribs=True, eq=False, repr=False, str=False) class LabeledFrame: - """ - Holds labeled data for a single frame of a video. + """Holds labeled data for a single frame of a video. Args: video: The :class:`Video` associated with this frame. frame_idx: The index of frame in video. + instances: List of instances associated with the frame. """ video: Video = attr.ib() @@ -1109,8 +1161,7 @@ class LabeledFrame: ) def __attrs_post_init__(self): - """ - Called by attrs. + """Called by attrs. Updates :attribute:`Instance.frame` for each instance associated with this :class:`LabeledFrame`. @@ -1121,19 +1172,19 @@ def __attrs_post_init__(self): instance.frame = self def __len__(self) -> int: - """Returns number of instances associated with frame.""" + """Return number of instances associated with frame.""" return len(self.instances) def __getitem__(self, index) -> Instance: - """Returns instance (retrieved by index).""" + """Return instance (retrieved by index).""" return self.instances.__getitem__(index) def index(self, value: Instance) -> int: - """Returns index of given :class:`Instance`.""" + """Return index of given :class:`Instance`.""" return self.instances.index(value) def __delitem__(self, index): - """Removes instance (by index) from frame.""" + """Remove instance (by index) from frame.""" value = self.instances.__getitem__(index) self.instances.__delitem__(index) @@ -1151,8 +1202,7 @@ def __repr__(self) -> str: ) def insert(self, index: int, value: Instance): - """ - Adds instance to frame. + """Add instance to frame. Args: index: The index in list of frame instances where we should @@ -1168,8 +1218,7 @@ def insert(self, index: int, value: Instance): value.frame = self def __setitem__(self, index, value: Instance): - """ - Sets nth instance in frame to the given instance. + """Set nth instance in frame to the given instance. Args: index: The index of instance to replace with new instance. @@ -1186,8 +1235,7 @@ def __setitem__(self, index, value: Instance): def find( self, track: Optional[Union[Track, int]] = -1, user: bool = False ) -> List[Instance]: - """ - Retrieves instances (if any) matching specifications. + """Retrieve instances (if any) matching specifications. Args: track: The :class:`Track` to match. Note that None will only @@ -1207,13 +1255,12 @@ def find( @property def instances(self) -> List[Instance]: - """Returns list of all instances associated with this frame.""" + """Return list of all instances associated with this frame.""" return self._instances @instances.setter def instances(self, instances: List[Instance]): - """ - Sets the list of instances associated with this frame. + """Set the list of instances associated with this frame. Updates the `frame` attribute on each instance to the :class:`LabeledFrame` which will contain the instance. @@ -1235,14 +1282,12 @@ def instances(self, instances: List[Instance]): @property def user_instances(self) -> List[Instance]: - """Returns list of user instances associated with this frame.""" - return [ - inst for inst in self._instances if not isinstance(inst, PredictedInstance) - ] + """Return list of user instances associated with this frame.""" + return [inst for inst in self._instances if type(inst) == Instance] @property def training_instances(self) -> List[Instance]: - """Returns list of user instances with points for training.""" + """Return list of user instances with points for training.""" return [ inst for inst in self._instances @@ -1251,23 +1296,22 @@ def training_instances(self) -> List[Instance]: @property def predicted_instances(self) -> List[PredictedInstance]: - """Returns list of predicted instances associated with frame.""" - return [inst for inst in self._instances if isinstance(inst, PredictedInstance)] + """Return list of predicted instances associated with frame.""" + return [inst for inst in self._instances if type(inst) == PredictedInstance] @property def has_user_instances(self) -> bool: - """Whether the frame contains any user instances.""" + """Return whether the frame contains any user instances.""" return len(self.user_instances) > 0 @property def has_predicted_instances(self) -> bool: - """Whether the frame contains any predicted instances.""" + """Return whether the frame contains any predicted instances.""" return len(self.predicted_instances) > 0 @property def unused_predictions(self) -> List[Instance]: - """ - Returns list of "unused" :class:`PredictedInstance` objects in frame. + """Return a list of "unused" :class:`PredictedInstance` objects in frame. This is all the :class:`PredictedInstance` objects which do not have a corresponding :class:`Instance` in the same track in frame. @@ -1305,8 +1349,7 @@ def unused_predictions(self) -> List[Instance]: @property def instances_to_show(self) -> List[Instance]: - """ - Return a list of instances to show in GUI for this frame. + """Return a list of instances to show in GUI for this frame. This list will not include any predicted instances for which there's a corresponding regular instance. @@ -1331,7 +1374,7 @@ def instances_to_show(self) -> List[Instance]: def merge_frames( labeled_frames: List["LabeledFrame"], video: "Video", remove_redundant=True ) -> List["LabeledFrame"]: - """Merged LabeledFrames for same video and frame index. + """Return merged LabeledFrames for same video and frame index. Args: labeled_frames: List of :class:`LabeledFrame` objects to merge. @@ -1379,8 +1422,7 @@ def merge_frames( def complex_merge_between( cls, base_labels: "Labels", new_frames: List["LabeledFrame"] ) -> Tuple[Dict[Video, Dict[int, List[Instance]]], List[Instance], List[Instance]]: - """ - Merge data from new frames into a :class:`Labels` object. + """Merge data from new frames into a :class:`Labels` object. Everything that can be merged cleanly is merged, any conflicts are returned. @@ -1433,8 +1475,7 @@ def complex_merge_between( def complex_frame_merge( cls, base_frame: "LabeledFrame", new_frame: "LabeledFrame" ) -> Tuple[List[Instance], List[Instance], List[Instance]]: - """ - Merge two frames, return conflicts if any. + """Merge two frames, return conflicts if any. A conflict occurs when * each frame has Instances which don't perfectly match those @@ -1531,34 +1572,40 @@ def image(self) -> np.ndarray: """Return the image for this frame of shape (height, width, channels).""" return self.video.get_frame(self.frame_idx) - def plot(self, image: bool = True): + def numpy(self) -> np.ndarray: + """Return the instances as an array of shape (instances, nodes, 2).""" + return np.stack([inst.numpy() for inst in self.instances], axis=0) + + def plot(self, image: bool = True, scale: float = 1.0): """Plot the frame with all instances. Args: image: If False, only the instances will be plotted without loading the original image. + scale: Relative scaling for the figure. Notes: - See sleap.nn.viz.plot_img and sleap.nn.viz.plot_instances for more plotting - options. + See `sleap.nn.viz.plot_img` and `sleap.nn.viz.plot_instances` for more + plotting options. """ if image: - sleap.nn.viz.plot_img(self.image) + sleap.nn.viz.plot_img(self.image, scale=scale) sleap.nn.viz.plot_instances(self.instances) - def plot_predicted(self, image: bool = True): + def plot_predicted(self, image: bool = True, scale: float = 1.0): """Plot the frame with all predicted instances. Args: image: If False, only the instances will be plotted without loading the original image. + scale: Relative scaling for the figure. Notes: - See sleap.nn.viz.plot_img and sleap.nn.viz.plot_instances for more plotting - options. + See `sleap.nn.viz.plot_img` and `sleap.nn.viz.plot_instances` for more + plotting options. """ if image: - sleap.nn.viz.plot_img(self.image) + sleap.nn.viz.plot_img(self.image, scale=scale) sleap.nn.viz.plot_instances( self.predicted_instances, color_by_track=(len(self.predicted_instances) > 0) diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index 59c055515..a6ee370a5 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -51,6 +51,7 @@ Iterable, Any, Set, + Callable, ) import attr @@ -203,7 +204,7 @@ def get_video_track_occupancy(self, video: Video) -> Dict[Track, RangeList]: return self._track_occupancy[video] def remove_frame(self, frame: LabeledFrame): - """Remvoe frame and update cache as needed.""" + """Remove frame and update cache as needed.""" self._lf_by_video[frame.video].remove(frame) # We'll assume that there's only a single LabeledFrame for this video and # frame_idx, and remove the frame_idx from the cache. @@ -412,7 +413,7 @@ class Labels(MutableSequence): skeletons: List[Skeleton] = attr.ib(default=attr.Factory(list)) nodes: List[Node] = attr.ib(default=attr.Factory(list)) tracks: List[Track] = attr.ib(default=attr.Factory(list)) - suggestions: List["SuggestionFrame"] = attr.ib(default=attr.Factory(list)) + suggestions: List[SuggestionFrame] = attr.ib(default=attr.Factory(list)) negative_anchors: Dict[Video, list] = attr.ib(default=attr.Factory(dict)) provenance: Dict[Text, Union[str, int, float, bool]] = attr.ib( default=attr.Factory(dict) @@ -510,8 +511,8 @@ def _update_from_labels(self, merge: bool = False): self.tracks.extend(new_tracks) def _update_containers(self, new_label: LabeledFrame): - """ Ensure that top-level containers are kept updated with new - instances of objects that come along with new labels. """ + """Ensure that top-level containers are kept updated with new + instances of objects that come along with new labels.""" if new_label.video not in self.videos: self.videos.append(new_label.video) @@ -644,6 +645,8 @@ def __getitem__(self, key, *args) -> Union[LabeledFrame, List[LabeledFrame]]: scalar key was provided. """ if len(args) > 0: + if type(key) != tuple: + key = (key,) key = key + tuple(args) if isinstance(key, int): @@ -685,6 +688,46 @@ def __getitem__(self, key, *args) -> Union[LabeledFrame, List[LabeledFrame]]: else: raise KeyError("Invalid label indexing arguments.") + def get(self, *args) -> Union[LabeledFrame, List[LabeledFrame]]: + """Get an item from the labels or return `None` if not found. + + This is a safe version of `labels[...]` that will not raise an exception if the + item is not found. + """ + try: + return self.__getitem__(*args) + except KeyError: + return None + + def extract(self, inds) -> "Labels": + """Extract labeled frames from indices and return a new `Labels` object. + + Args: + inds: Any valid indexing keys, e.g., a range, slice, list of label indices, + numpy array, `Video`, etc. See `__getitem__` for full list. + + Returns: + A new `Labels` object with the specified labeled frames. + + This will preserve the other data structures even if they are not found in + the extracted labels, including: + - `Labels.videos` + - `Labels.skeletons` + - `Labels.tracks` + - `Labels.suggestions` + - `Labels.provenance` + """ + lfs = self.__getitem__(inds) + new_labels = type(self)( + labeled_frames=lfs, + videos=self.videos, + skeletons=self.skeletons, + tracks=self.tracks, + suggestions=self.suggestions, + provenance=self.provenance, + ) + return new_labels + def __setitem__(self, index, value: LabeledFrame): """Set labeled frame at given index.""" # TODO: Maybe we should remove this method altogether? @@ -733,6 +776,13 @@ def remove_frames(self, lfs: List[LabeledFrame]): self.labeled_frames = [lf for lf in self.labeled_frames if lf not in to_remove] self.update_cache() + def remove_empty_frames(self): + """Remove frames with no instances.""" + self.labeled_frames = [ + lf for lf in self.labeled_frames if len(lf.instances) > 0 + ] + self.update_cache() + def find( self, video: Video, @@ -830,10 +880,19 @@ def find_last( return label @property - def user_labeled_frames(self): + def user_labeled_frames(self) -> List[LabeledFrame]: """Return all labeled frames with user (non-predicted) instances.""" return [lf for lf in self.labeled_frames if lf.has_user_instances] + @property + def user_labeled_frame_inds(self) -> List[int]: + """Return a list of indices of frames with user labeled instances.""" + return [i for i, lf in enumerate(self.labeled_frames) if lf.has_user_instances] + + def with_user_labels_only(self) -> "Labels": + """Return a new `Labels` object with only user labels.""" + return self.extract(self.user_labeled_frame_inds) + def get_labeled_frame_count(self, video: Optional[Video] = None, filter: Text = ""): return self._cache.get_frame_count(video, filter) @@ -855,46 +914,37 @@ def all_instances(self) -> List[Instance]: @property def user_instances(self) -> List[Instance]: """Return list of all user (non-predicted) instances.""" - return [inst for inst in self.all_instances if isinstance(inst, Instance)] + return [inst for inst in self.all_instances if type(inst) == Instance] @property def predicted_instances(self) -> List[PredictedInstance]: """Return list of all predicted instances.""" - return [ - inst for inst in self.all_instances if isinstance(inst, PredictedInstance) - ] + return [inst for inst in self.all_instances if type(inst) == PredictedInstance] def describe(self): """Print basic statistics about the labels dataset.""" - print(f"Videos: {len(self.videos)}") - n_user_inst = len(self.user_instances) - n_predicted_inst = len(self.predicted_instances) - print( - f"Instances: {n_user_inst:,} (user-labeled), " - f"{n_predicted_inst:,} (predicted), " - f"{n_user_inst + n_predicted_inst:,} (total)" - ) - n_user_only = 0 - n_pred_only = 0 - n_both = 0 + print(f"Skeleton: {self.skeleton}") + print(f"Videos: {[v.filename for v in self.videos]}") + n_user = 0 + n_pred = 0 + n_user_inst = 0 + n_pred_inst = 0 for lf in self.labeled_frames: - has_user = lf.has_user_instances - has_pred = lf.has_predicted_instances - if has_user and not has_pred: - n_user_only += 1 - elif not has_user and has_pred: - n_pred_only += 1 - elif has_user and has_pred: - n_both += 1 - n_total = len(self.labeled_frames) - print( - f"Frames: {n_user_only:,} (user-labeled), " - f"{n_pred_only:,} (predicted), " - f"{n_both:,} (both), " - f"{n_total:,} (total)" - ) - - def instances(self, video: Video = None, skeleton: Skeleton = None): + if lf.has_user_instances: + n_user += 1 + n_user_inst += len(lf.user_instances) + if lf.has_predicted_instances: + n_pred += 1 + n_pred_inst += len(lf.predicted_instances) + print(f"Frames (user/predicted): {n_user:,}/{n_pred:,}") + print(f"Instances (user/predicted): {n_user_inst:,}/{n_pred_inst:,}") + print("Tracks:", self.tracks) + print(f"Suggestions: {len(self.suggestions):,}") + print("Provenance:", self.provenance) + + def instances( + self, video: Optional[Video] = None, skeleton: Optional[Skeleton] = None + ): """Iterate over instances in the labels, optionally with filters. Args: @@ -1086,16 +1136,54 @@ def does_track_match(inst, tr, labeled_frame): ] return track_frame_inst - def get_video_suggestions(self, video: Video) -> List[int]: + def add_suggestion(self, video: Video, frame_idx: int): + """Add a suggested frame to the labels. + + Args: + video: `sleap.Video` instance of the suggestion. + frame_idx: Index of the frame of the suggestion. + """ + for suggestion in self.suggestions: + if suggestion.video == video and suggestion.frame_idx == frame_idx: + return + self.suggestions.append(SuggestionFrame(video=video, frame_idx=frame_idx)) + + def remove_suggestion(self, video: Video, frame_idx: int): + """Remove a suggestion from the list by video and frame index. + + Args: + video: `sleap.Video` instance of the suggestion. + frame_idx: Index of the frame of the suggestion. + """ + for suggestion in self.suggestions: + if suggestion.video == video and suggestion.frame_idx == frame_idx: + self.suggestions.remove(suggestion) + return + + def get_video_suggestions( + self, video: Video, user_labeled: bool = True + ) -> List[int]: """Return a list of suggested frame indices. Args: video: Video to get suggestions for. + user_labeled: If `True` (the default), return frame indices for suggestions + that already have user labels. If `False`, only suggestions with no user + labeled instances will be returned. Returns: - Indices of the labeled frames for for the specified video. + Indices of the suggested frames for for the specified video. """ - return [item.frame_idx for item in self.suggestions if item.video == video] + frame_indices = [] + for suggestion in self.suggestions: + if suggestion.video == video: + fidx = suggestion.frame_idx + if not user_labeled: + lf = self.get((video, fidx)) + if lf is not None and lf.has_user_instances: + continue + frame_indices.append(fidx) + return frame_indices def get_suggestions(self) -> List[SuggestionFrame]: """Return all suggestions as a list of SuggestionFrame items.""" @@ -1170,6 +1258,47 @@ def delete_suggestions(self, video): """Delete suggestions for specified video.""" self.suggestions = [item for item in self.suggestions if item.video != video] + def clear_suggestions(self): + """Delete all suggestions.""" + self.suggestions = [] + + @property + def unlabeled_suggestions(self) -> List[SuggestionFrame]: + """Return suggestions without user labels.""" + unlabeled_suggestions = [] + for suggestion in self.suggestions: + lf = self.get(suggestion.video, suggestion.frame_idx) + if lf is None or not lf.has_user_instances: + unlabeled_suggestions.append(suggestion) + return unlabeled_suggestions + + def get_unlabeled_suggestion_inds(self) -> List[int]: + """Find labeled frames for unlabeled suggestions and return their indices. + + This is useful for generating a list of example indices for inference on + unlabeled suggestions. + + Returns: + List of indices of the labeled frames that correspond to the suggestions + that do not have user instances. + + If a labeled frame corresponding to a suggestion does not exist, an empty + one will be created. + + See also: `Labels.remove_empty_frames` + """ + inds = [] + for suggestion in self.unlabeled_suggestions: + lf = self.get((suggestion.video, suggestion.frame_idx)) + if lf is None: + self.append( + LabeledFrame(video=suggestion.video, frame_idx=suggestion.frame_idx) + ) + inds.append(len(self.labeled_frames) - 1) + else: + inds.append(self.index(lf)) + return inds + def add_video(self, video: Video): """Add a video to the labels if it is not already in it. @@ -1643,6 +1772,29 @@ def save( suggested=embed_suggested, ) + def export(self, filename: str): + """Export labels to analysis HDF5 format. + + This expects the labels to contain data for a single video (e.g., predictions). + + Args: + filename: Path to output HDF5 file. + + Notes: + This will write the contents of the labels out as a HDF5 file without + complete metadata. + + The resulting file will have datasets: + - `/node_names`: List of skeleton node names. + - `/track_names`: List of track names. + - `/tracks`: All coordinates of the instances in the labels. + - `/track_occupancy`: Mask denoting which instances are present in each + frame. + """ + from sleap.io.format.sleap_analysis import SleapAnalysisAdaptor + + SleapAnalysisAdaptor.write(filename, self) + @classmethod def load_json(cls, filename: str, *args, **kwargs) -> "Labels": from .format import read @@ -1679,6 +1831,18 @@ def load_deeplabcut(cls, filename: str) -> "Labels": return read(filename, for_object="labels", as_format="deeplabcut") + @classmethod + def load_deeplabcut_folder(cls, filename: str) -> "Labels": + csv_files = glob(f"{filename}/*/*.csv") + merged_labels = None + for csv_file in csv_files: + labels = cls.load_file(csv_file, as_format="deeplabcut") + if merged_labels is None: + merged_labels = labels + else: + merged_labels.extend_from(labels, unify=True) + return merged_labels + @classmethod def load_coco( cls, filename: str, img_dir: str, use_missing_gui: bool = False @@ -1750,6 +1914,7 @@ def save_frame_data_hdf5( user_labeled: bool = True, all_labeled: bool = False, suggested: bool = False, + progress_callback: Optional[Callable[[int, int], None]] = None, ) -> List[HDF5Video]: """Write images for labeled frames from all videos to hdf5 file. @@ -1765,12 +1930,21 @@ def save_frame_data_hdf5( Defaults to `False`. suggested: Include suggested frames even if they do not have instances. Useful for inference after training. Defaults to `False`. + progress_callback: If provided, this function will be called to report the + progress of the frame data saving. This function should be a callable + of the form: `fn(n, n_total)` where `n` is the number of frames saved so + far and `n_total` is the total number of frames that will be saved. This + is called after each video is processed. If the function has a return + value and it returns `False`, saving will be canceled and the output + deleted. Returns: A list of :class:`HDF5Video` objects with the stored frames. """ - new_vids = [] - for v_idx, video in enumerate(self.videos): + # Build list of frames to save. + vids = [] + frame_idxs = [] + for video in self.videos: lfs_v = self.find(video) frame_nums = [ lf.frame_idx @@ -1784,13 +1958,29 @@ def save_frame_data_hdf5( if suggestion.video == video ] frame_nums = sorted(list(set(frame_nums))) + vids.append(video) + frame_idxs.append(frame_nums) + + n_total = sum([len(x) for x in frame_idxs]) + n = 0 + # Save images for each video. + new_vids = [] + for v_idx, (video, frame_nums) in enumerate(zip(vids, frame_idxs)): vid = video.to_hdf5( path=output_path, dataset=f"video{v_idx}", format=format, frame_numbers=frame_nums, ) + n += len(frame_nums) + if progress_callback is not None: + # Notify update callback. + ret = progress_callback(n, n_total) + if ret == False: + vid.close() + return [] + vid.close() new_vids.append(vid) @@ -1837,6 +2027,57 @@ def to_pipeline( pipeline += pipelines.Prefetcher() return pipeline + def numpy( + self, video: Optional[Video] = None, all_frames: bool = True + ) -> np.ndarray: + """Construct a numpy array from tracked instance points. + + Args: + video: Video to convert to numpy arrays. If `None` (the default), uses the + first video. + all_frames: If `True` (the default), allocate array of the same number of + frames as the video. If `False`, only return data between the first and + last frame with data. + + Returns: + An array of tracks of shape `(n_frames, n_tracks, n_nodes, 2)`. + + Missing data will be replaced with `np.nan`. + + Notes: + This method assumes that instances have tracks assigned and is intended to + function primarily for single-video prediction results. + """ + if video is None: + video = self.videos[0] + lfs = self.find(video=video) + + if all_frames: + first_frame, last_frame = None, None + for lf in lfs: + if first_frame is None: + first_frame = lf.frame_idx + if last_frame is None: + last_frame = lf.frame_idx + first_frame = min(first_frame, lf.frame_idx) + last_frame = max(last_frame, lf.frame_idx) + else: + first_frame, last_frame = 0, video.shape[0] - 1 + + n_frames = last_frame - first_frame + 1 + n_tracks = len(self.tracks) + n_nodes = len(self.skeleton.nodes) + + tracks = np.full((n_frames, n_tracks, n_nodes, 2), np.nan, dtype="float32") + for lf in lfs: + i = lf.frame_idx - first_frame + for inst in lf: + if inst.track is not None: + j = self.tracks.index(inst.track) + tracks[i, j] = inst.numpy() + + return tracks + @classmethod def make_gui_video_callback(cls, search_paths: Optional[List] = None) -> Callable: return cls.make_video_callback(search_paths=search_paths, use_gui=True) @@ -1964,9 +2205,16 @@ def load_file( filename: Text, detect_videos: bool = True, search_paths: Optional[Union[List[Text], Text]] = None, + match_to: Optional[Labels] = None, ) -> Labels: """Load a SLEAP labels file. + SLEAP labels files (`.slp`) contain all the metadata for a labeling project or the + predicted labels from a video. This includes the skeleton, videos, labeled frames, + user-labeled and predicted instances, suggestions and tracks. + + See `sleap.io.dataset.Labels` for more detailed information. + Args: filename: Path to a SLEAP labels (.slp) file. detect_videos: If True, will attempt to detect missing videos by searching for @@ -1976,6 +2224,9 @@ def load_file( be the direct path to the video file or its containing folder. If not specified, defaults to searching for the videos in the same folder as the labels. + match_to: If a `sleap.Labels` object is provided, attempt to match and reuse + video and skeleton objects when loading. This is useful when comparing the + contents across sets of labels. Returns: The loaded `Labels` instance. @@ -1990,6 +2241,6 @@ def load_file( if detect_videos: if search_paths is None: search_paths = os.path.dirname(filename) - return Labels.load_file(filename, search_paths) + return Labels.load_file(filename, search_paths, match_to=match_to) else: - return Labels.load_file(filename) + return Labels.load_file(filename, match_to=match_to) diff --git a/sleap/io/format/deeplabcut.py b/sleap/io/format/deeplabcut.py index 77fe9ee1e..700518f9c 100644 --- a/sleap/io/format/deeplabcut.py +++ b/sleap/io/format/deeplabcut.py @@ -67,9 +67,17 @@ def does_write(self) -> bool: @classmethod def read( - cls, file: FileHandle, full_video: Optional[Video] = None, *args, **kwargs, + cls, + file: FileHandle, + full_video: Optional[Video] = None, + *args, + **kwargs, ) -> Labels: - return Labels(labeled_frames=cls.read_frames(file=file, full_video=full_video, *args, **kwargs)) + return Labels( + labeled_frames=cls.read_frames( + file=file, full_video=full_video, *args, **kwargs + ) + ) @classmethod def make_video_for_image_list(cls, image_dir, filenames) -> Video: @@ -232,7 +240,12 @@ def does_write(self) -> bool: return False @classmethod - def read(cls, file: FileHandle, *args, **kwargs,) -> Labels: + def read( + cls, + file: FileHandle, + *args, + **kwargs, + ) -> Labels: filename = file.filename # Load data from the YAML file diff --git a/sleap/io/format/deepposekit.py b/sleap/io/format/deepposekit.py index 9ad4dd29b..0207d2590 100644 --- a/sleap/io/format/deepposekit.py +++ b/sleap/io/format/deepposekit.py @@ -50,7 +50,12 @@ def does_write(self) -> bool: @classmethod def read( - cls, file: FileHandle, video_path: str, skeleton_path: str, *args, **kwargs, + cls, + file: FileHandle, + video_path: str, + skeleton_path: str, + *args, + **kwargs, ) -> Labels: f = file.file diff --git a/sleap/io/format/hdf5.py b/sleap/io/format/hdf5.py index bca4bda7a..7ba60cd79 100644 --- a/sleap/io/format/hdf5.py +++ b/sleap/io/format/hdf5.py @@ -225,6 +225,7 @@ def write( frame_data_format: str = "png", all_labeled: bool = False, suggested: bool = False, + progress_callback: Optional[Callable[[int, int], None]] = None, ): labels = source_object @@ -245,6 +246,7 @@ def write( user_labeled=True, all_labeled=all_labeled, suggested=suggested, + progress_callback=progress_callback, ) # Replace path to video file with "." (which indicates that the @@ -266,7 +268,6 @@ def write( new_videos.append(video) d["videos"] = Video.cattr().unstructure(new_videos) - with h5py.File(filename, "a") as f: # Add all the JSON metadata diff --git a/sleap/io/format/labels_json.py b/sleap/io/format/labels_json.py index 181eaf9f4..69b49189c 100644 --- a/sleap/io/format/labels_json.py +++ b/sleap/io/format/labels_json.py @@ -396,6 +396,7 @@ def from_json_data( # if we're given a Labels object to match, use its objects when they match if match_to is not None: + # Match skeletons for idx, sk in enumerate(skeletons): for old_sk in match_to.skeletons: if sk.matches(old_sk): @@ -406,6 +407,8 @@ def from_json_data( # use skeleton from match skeletons[idx] = old_sk break + + # Match videos for idx, vid in enumerate(videos): for old_vid in match_to.videos: @@ -433,6 +436,13 @@ def from_json_data( if is_match: break + # Match tracks + for idx, track in enumerate(tracks): + for old_track in match_to.tracks: + if track.name == old_track.name: + tracks[idx] = old_track + break + suggestions = [] if "suggestions" in dicts: suggestions_cattr = cattr.Converter() diff --git a/sleap/io/format/leap_matlab.py b/sleap/io/format/leap_matlab.py index 26385aeaf..2536440a5 100644 --- a/sleap/io/format/leap_matlab.py +++ b/sleap/io/format/leap_matlab.py @@ -54,7 +54,11 @@ def does_write(self) -> bool: @classmethod def read( - cls, file: FileHandle, gui: bool = True, *args, **kwargs, + cls, + file: FileHandle, + gui: bool = True, + *args, + **kwargs, ): filename = file.filename diff --git a/sleap/io/format/sleap_analysis.py b/sleap/io/format/sleap_analysis.py index d2e218efc..4b413752b 100644 --- a/sleap/io/format/sleap_analysis.py +++ b/sleap/io/format/sleap_analysis.py @@ -58,7 +58,11 @@ def does_write(self) -> bool: @classmethod def read( - cls, file: FileHandle, video: Union[Video, str], *args, **kwargs, + cls, + file: FileHandle, + video: Union[Video, str], + *args, + **kwargs, ) -> Labels: connect_adj_nodes = False diff --git a/sleap/io/video.py b/sleap/io/video.py index 0203f91b7..5b2ffb603 100644 --- a/sleap/io/video.py +++ b/sleap/io/video.py @@ -982,7 +982,13 @@ def shape(self) -> Tuple[int, int, int, int]: def __str__(self) -> str: """Informal string representation (for print or format).""" - return type(self).__name__ + " ([%d x %d x %d x %d])" % self.shape + return ( + "Video(" + f"filename={self.filename}, " + f"shape={self.shape}, " + f"backend={type(self.backend).__name__}" + ")" + ) def __len__(self) -> int: """Return the length of the video as the number of frames.""" @@ -1172,7 +1178,7 @@ def from_filename(cls, filename: str, *args, **kwargs) -> "Video": backend_class = HDF5Video elif filename.endswith(("npy")): backend_class = NumpyVideo - elif filename.lower().endswith(("mp4", "avi")): + elif filename.lower().endswith(("mp4", "avi", "mov")): backend_class = MediaVideo elif os.path.isdir(filename) or "metadata.yaml" in filename: backend_class = ImgStoreVideo @@ -1512,3 +1518,51 @@ def fixup_path( if raise_warning: logger.warning(f"Cannot find a video file: {path}") return path + + +def load_video( + filename: str, + grayscale: Optional[bool] = None, + dataset: Optional[str] = None, + channels_first: bool = False, +) -> Video: + """Open a video from disk. + + Args: + filename: Path to a video file. The video reader backend will be determined by + the file extension. Support extensions include: `.mp4`, `.avi`, `.h5`, + `.hdf5` and `.slp` (for embedded images in a labels file). If the path to a + folder is provided, images within that folder will be treated as video + frames. + grayscale: Read frames as a single channel grayscale images. If `None` (the + default), this will be auto-detected. + dataset: Name of the dataset that contains the video if loading a video stored + in an HDF5 file. This has no effect for non-HDF5 inputs. + channels_first: If `False` (the default), assume the data in the HDF5 dataset + are formatted in `(frames, height, width, channels)` order. If `False`, + assume the data are in `(frames, channels, width, height)` format. This has + no effect for non-HDF5 inputs. + + Returns: + A `sleap.Video` instance with the appropriate backend for its format. + + This enables numpy-like access to video data. + + Example: + >>> video = sleap.load_video("centered_pair_small.mp4") + >>> video.shape + (1100, 384, 384, 1) + >>> imgs = video[0:3] + >>> imgs.shape + (3, 384, 384, 1) + + See also: + sleap.io.video.Video + """ + kwargs = {} + if grayscale is not None: + kwargs["grayscale"] = grayscale + if dataset is not None: + kwargs["dataset"] = dataset + kwargs["input_format"] = "channels_first" if channels_first else "channels_last" + return Video.from_filename(filename, **kwargs) diff --git a/sleap/io/videowriter.py b/sleap/io/videowriter.py index 449ec0bb4..53fe9db29 100644 --- a/sleap/io/videowriter.py +++ b/sleap/io/videowriter.py @@ -76,7 +76,9 @@ def __init__(self, filename, height, width, fps): fps = str(fps) self._writer = skvideo.io.FFmpegWriter( filename, - inputdict={"-r": fps,}, + inputdict={ + "-r": fps, + }, outputdict={ "-c:v": "libx264", "-preset": "superfast", diff --git a/sleap/io/visuals.py b/sleap/io/visuals.py index 6f848c456..215cd9879 100644 --- a/sleap/io/visuals.py +++ b/sleap/io/visuals.py @@ -84,7 +84,10 @@ def reader(out_q: Queue, video: Video, frames: List[int], scale: float = 1.0): def writer( - in_q: Queue, progress_queue: Queue, filename: str, fps: float, + in_q: Queue, + progress_queue: Queue, + filename: str, + fps: float, ): """Write annotated images to video. @@ -224,7 +227,8 @@ def marker(self): t0 = clock() imgs = self._mark_images( - frame_indices=frames_idx_chunk, frame_images=video_frame_images, + frame_indices=frames_idx_chunk, + frame_images=video_frame_images, ) elapsed = clock() - t0 @@ -264,7 +268,9 @@ def _mark_single_frame(self, video_frame: np.ndarray, frame_idx: int) -> np.ndar return self._plot_instances_cv(video_frame, frame_idx) def _plot_instances_cv( - self, img: np.ndarray, frame_idx: int, + self, + img: np.ndarray, + frame_idx: int, ) -> Optional[np.ndarray]: """Adds visuals annotations to single frame image. @@ -474,7 +480,10 @@ def save_labeled_video( crop_size_xy=crop_size_xy, color_manager=color_manager, ) - thread_write = Thread(target=writer, args=(q2, progress_queue, filename, fps),) + thread_write = Thread( + target=writer, + args=(q2, progress_queue, filename, fps), + ) thread_read.start() thread_mark.start() diff --git a/sleap/message.py b/sleap/message.py index ee358cf64..3be8d3e83 100644 --- a/sleap/message.py +++ b/sleap/message.py @@ -21,6 +21,7 @@ @attr.s(auto_attribs=True) class BaseMessageParticipant: """Base class for simple Sender and Receiver.""" + address: Text = "tcp://127.0.0.1:9001" context: Optional[zmq.Context] = None _socket: Optional[zmq.Socket] = None @@ -87,7 +88,7 @@ def check_message(self, timeout: int = 10, fresh: bool = False) -> Any: def check_messages(self, timeout: int = 10, times_to_check: int = 10) -> List[dict]: """ Attempt to receive multiple messages. - + This method allows us to keep up with the messages by getting multiple messages that have been sent since the last check. It keeps checking until limit is reached *or* we check without @@ -139,7 +140,7 @@ def send_array( """Sends dictionary + numpy ndarray.""" if self._socket is None: self.setup() - + header_data["dtype"] = str(A.dtype) header_data["shape"] = A.shape diff --git a/sleap/nn/architectures/__init__.py b/sleap/nn/architectures/__init__.py index 82b5df52c..108aeb819 100644 --- a/sleap/nn/architectures/__init__.py +++ b/sleap/nn/architectures/__init__.py @@ -11,4 +11,4 @@ from sleap.nn.architectures.hourglass import Hourglass from sleap.nn.architectures.resnet import ResNetv1, ResNet50, ResNet101, ResNet152 from sleap.nn.architectures.common import IntermediateFeature -from sleap.nn.architectures.pretrained_encoders import UnetPretrainedEncoder \ No newline at end of file +from sleap.nn.architectures.pretrained_encoders import UnetPretrainedEncoder diff --git a/sleap/nn/architectures/encoder_decoder.py b/sleap/nn/architectures/encoder_decoder.py index 7433d933b..450485169 100644 --- a/sleap/nn/architectures/encoder_decoder.py +++ b/sleap/nn/architectures/encoder_decoder.py @@ -290,7 +290,7 @@ def make_block( prefix: String that will be added to the name of every layer in the block. If not specified, instantiating this block multiple times may result in name conflicts if existing layers have the same name. - + Returns: The output tensor after applying all operations in the block. """ @@ -475,12 +475,11 @@ def decoder_features_stride(self) -> int: def maximum_stride(self) -> int: """Return the maximum stride that the input must be divisible by.""" return self.encoder_features_stride - + @property def output_stride(self) -> int: """Return stride of the output of the backbone.""" return self.decoder_features_stride - def make_stem(self, x_in: tf.Tensor, prefix: Text = "stem") -> tf.Tensor: """Instantiate the stem layers defined by the stem block configuration. diff --git a/sleap/nn/architectures/hourglass.py b/sleap/nn/architectures/hourglass.py index b0e85c131..d9fc199a6 100644 --- a/sleap/nn/architectures/hourglass.py +++ b/sleap/nn/architectures/hourglass.py @@ -90,13 +90,12 @@ def make_block(self, x_in: tf.Tensor, prefix: Text = "stem") -> tf.Tensor: prefix=prefix + "_conv7x7", ) x = conv(x, filters=2 * self.filters, prefix=prefix + "_conv3x3") - + x = tf.keras.layers.MaxPool2D( strides=2 if (self.pool and self.pooling_stride > 1) else 1, padding="same", - name=prefix + "_pool")( - x - ) + name=prefix + "_pool", + )(x) x = conv(x, filters=self.output_filters, prefix=prefix + "_conv3x3_out") return x @@ -197,7 +196,7 @@ class Hourglass(encoder_decoder.EncoderDecoder): """Encoder-decoder definition of the (stacked) hourglass network backbone. This implements the architecture of the `Associative Embedding paper - `_, which improves upon the architecture in the + `_, which improves upon the architecture in the `original hourglass paper `_. The primary changes are to replace the residual block with simple convolutions and modify the filter sizes. diff --git a/sleap/nn/architectures/hrnet.py b/sleap/nn/architectures/hrnet.py index ce6f7264c..39fae5b39 100644 --- a/sleap/nn/architectures/hrnet.py +++ b/sleap/nn/architectures/hrnet.py @@ -492,9 +492,13 @@ def make_first_stage( return x -def make_hrnet_backbone(x_in, C=32, initial_downsampling_steps=2, stem_filters=64, bottleneck=False): +def make_hrnet_backbone( + x_in, C=32, initial_downsampling_steps=2, stem_filters=64, bottleneck=False +): - x = make_stem(x_in, filters=stem_filters, downsampling_steps=initial_downsampling_steps) + x = make_stem( + x_in, filters=stem_filters, downsampling_steps=initial_downsampling_steps + ) x = make_first_stage( x, bottleneck=False, block_filters=64, blocks=4, output_filters=C diff --git a/sleap/nn/architectures/pretrained_encoders.py b/sleap/nn/architectures/pretrained_encoders.py index ce569b329..862e0732d 100644 --- a/sleap/nn/architectures/pretrained_encoders.py +++ b/sleap/nn/architectures/pretrained_encoders.py @@ -131,6 +131,8 @@ class UnetPretrainedEncoder: pretrained: If `True` (the default), load pretrained weights for the encoder. If `False`, the same model architecture will be used for the encoder but the weights will be randomly initialized. + decoder_batchnorm: If `False` (the default), do not use batch normalization in + the decoder layers. """ encoder: str = attr.ib( @@ -138,6 +140,7 @@ class UnetPretrainedEncoder: ) decoder_filters: Tuple[int] = (256, 256, 128, 128) pretrained: bool = True + decoder_batchnorm: bool = True @classmethod def from_config(cls, config: PretrainedEncoderConfig) -> "UnetPretrainedEncoder": @@ -160,6 +163,7 @@ def from_config(cls, config: PretrainedEncoderConfig) -> "UnetPretrainedEncoder" encoder=config.encoder, pretrained=config.pretrained, decoder_filters=tuple(decoder_filters), + decoder_batchnorm=config.decoder_batchnorm, ) @property @@ -219,7 +223,7 @@ def make_backbone( encoder_weights="imagenet" if self.pretrained else None, decoder_block_type="upsampling", decoder_filters=self.decoder_filters, - decoder_use_batchnorm=True, + decoder_use_batchnorm=self.decoder_batchnorm, layers=tf.keras.layers, models=tf.keras.models, backend=tf.keras.backend, diff --git a/sleap/nn/architectures/resnet.py b/sleap/nn/architectures/resnet.py index e3bad3e76..bc6833ddb 100644 --- a/sleap/nn/architectures/resnet.py +++ b/sleap/nn/architectures/resnet.py @@ -57,7 +57,7 @@ def make_resnet_model( """Instantiate the ResNet, ResNetV2 (TODO), and ResNeXt (TODO) architecture. Optionally loads weights pre-trained on ImageNet. - + Args: backbone_fn: a function that returns output tensor for the stacked residual blocks. diff --git a/sleap/nn/architectures/unet.py b/sleap/nn/architectures/unet.py index 6731f0daf..8e088c75b 100644 --- a/sleap/nn/architectures/unet.py +++ b/sleap/nn/architectures/unet.py @@ -173,7 +173,7 @@ def encoder_stack(self) -> List[encoder_decoder.SimpleConvBlock]: use_bias=True, batch_norm=False, activation="relu", - block_prefix="_middle_expand" + block_prefix="_middle_expand", ) ) @@ -200,7 +200,7 @@ def encoder_stack(self) -> List[encoder_decoder.SimpleConvBlock]: use_bias=True, batch_norm=False, activation="relu", - block_prefix="_middle_contract" + block_prefix="_middle_contract", ) ) @@ -213,12 +213,18 @@ def decoder_stack(self) -> List[encoder_decoder.SimpleUpsamplingBlock]: for block in range(self.up_blocks): block_filters_in = int( self.filters - * (self.filters_rate ** (self.down_blocks + self.stem_blocks - 1 - block)) + * ( + self.filters_rate + ** (self.down_blocks + self.stem_blocks - 1 - block) + ) ) if self.block_contraction: block_filters_out = int( self.filters - * (self.filters_rate ** (self.down_blocks + self.stem_blocks - 2 - block)) + * ( + self.filters_rate + ** (self.down_blocks + self.stem_blocks - 2 - block) + ) ) else: block_filters_out = block_filters_in diff --git a/sleap/nn/callbacks.py b/sleap/nn/callbacks.py index 364c9f910..40384e04b 100644 --- a/sleap/nn/callbacks.py +++ b/sleap/nn/callbacks.py @@ -56,7 +56,7 @@ def on_batch_end(self, batch, logs=None): self.set_lr(msg["lr"]) def set_lr(self, lr): - """ Adjust the model learning rate. + """Adjust the model learning rate. This is the based off of the implementation used in the native learning rate scheduling callbacks. diff --git a/sleap/nn/config/__init__.py b/sleap/nn/config/__init__.py index 5eb8e2709..af262315b 100644 --- a/sleap/nn/config/__init__.py +++ b/sleap/nn/config/__init__.py @@ -12,7 +12,9 @@ PartAffinityFieldsHeadConfig, MultiInstanceConfig, ClassMapsHeadConfig, - MultiClassConfig, + MultiClassBottomUpConfig, + ClassVectorsHeadConfig, + MultiClassTopDownConfig, HeadsConfig, LEAPConfig, UNetConfig, @@ -36,4 +38,4 @@ ZMQConfig, OutputsConfig, ) -from sleap.nn.config.training_job import TrainingJobConfig +from sleap.nn.config.training_job import TrainingJobConfig, load_config diff --git a/sleap/nn/config/data.py b/sleap/nn/config/data.py index bf25bfcd6..37cf181e8 100644 --- a/sleap/nn/config/data.py +++ b/sleap/nn/config/data.py @@ -26,6 +26,15 @@ class LabelsConfig: can be computed from these data during model optimization. This is also useful to explicitly keep track of the test set that should be used when multiple splits are created for training. + split_by_inds: If `True`, splits used for training will be determined by the + lists below by indexing into the labels in `training_labels`. If this is + `False`, the indices below will not be used even if specified. This is + useful for specifying the fixed split sets from examples within a single + labels file. If splits are generated automatically (using + `validation_fraction`), the selected indices are stored below for reference. + training_inds: List of indices of the training split labels. + validation_inds: List of indices of the validation split labels. + test_inds: List of indices of the test split labels. search_path_hints: List of paths to use for searching for missing data. This is useful when labels and data are moved across computers, network storage, or operating systems that may have different absolute paths than those stored @@ -40,6 +49,10 @@ class LabelsConfig: validation_labels: Optional[Text] = None validation_fraction: float = 0.1 test_labels: Optional[Text] = None + split_by_inds: bool = False + training_inds: Optional[List[int]] = None + validation_inds: Optional[List[int]] = None + test_inds: Optional[List[int]] = None search_path_hints: List[Text] = attr.ib(factory=list) skeletons: List[sleap.Skeleton] = attr.ib(factory=list) @@ -76,6 +89,14 @@ class PreprocessingConfig: max stride (typically 32). This padding will be ignored when instance cropping inputs since the crop size should already be divisible by the model's max stride. + resize_and_pad_to_target: If True, will resize and pad all images in the dataset + to match target dimensions. This is useful when preprocessing datasets with + mixed image dimensions (from different video resolutions). Aspect ratio is + preserved, and padding applied (if needed) to bottom or right of image only. + target_height: Target image height for 'resize_and_pad_to_target'. When not + explicitly provided, inferred as the max image height from the dataset. + target_width: Target image width for 'resize_and_pad_to_target'. When not + explicitly provided, inferred as the max image width from the dataset. """ ensure_rgb: bool = False @@ -88,6 +109,9 @@ class PreprocessingConfig: ) input_scaling: float = 1.0 pad_to_stride: Optional[int] = None + resize_and_pad_to_target: bool = True + target_height: Optional[int] = None + target_width: Optional[int] = None @attr.s(auto_attribs=True) @@ -132,4 +156,3 @@ class DataConfig: labels: LabelsConfig = attr.ib(factory=LabelsConfig) preprocessing: PreprocessingConfig = attr.ib(factory=PreprocessingConfig) instance_cropping: InstanceCroppingConfig = attr.ib(factory=InstanceCroppingConfig) - diff --git a/sleap/nn/config/model.py b/sleap/nn/config/model.py index 13837b88e..ca4573266 100644 --- a/sleap/nn/config/model.py +++ b/sleap/nn/config/model.py @@ -28,6 +28,9 @@ class SingleInstanceConfmapsHeadConfig: results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution. + loss_weight: Scalar float used to weigh the loss term for this head during + training. Increase this to encourage the optimization to focus on improving + this specific output in multi-head models. offset_refinement: If `True`, model will also output an offset refinement map used to achieve subpixel localization of peaks during inference. This can improve the localization accuracy of the model at the cost of additional @@ -40,6 +43,7 @@ class SingleInstanceConfmapsHeadConfig: part_names: Optional[List[Text]] = None sigma: float = 5.0 output_stride: int = 1 + loss_weight: float = 1.0 offset_refinement: bool = False @@ -72,6 +76,9 @@ class CentroidsHeadConfig: results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution. + loss_weight: Scalar float used to weigh the loss term for this head during + training. Increase this to encourage the optimization to focus on improving + this specific output in multi-head models. offset_refinement: If `True`, model will also output an offset refinement map used to achieve subpixel localization of peaks during inference. This can improve the localization accuracy of the model at the cost of additional @@ -84,6 +91,7 @@ class CentroidsHeadConfig: anchor_part: Optional[Text] = None sigma: float = 5.0 output_stride: int = 1 + loss_weight: float = 1.0 offset_refinement: bool = False @@ -129,6 +137,9 @@ class CenteredInstanceConfmapsHeadConfig: results in confidence maps that are 0.5x the size of the input. Increasing this value can considerably speed up model performance and decrease memory requirements, at the cost of decreased spatial resolution. + loss_weight: Scalar float used to weigh the loss term for this head during + training. Increase this to encourage the optimization to focus on improving + this specific output in multi-head models. offset_refinement: If `True`, model will also output an offset refinement map used to achieve subpixel localization of peaks during inference. This can improve the localization accuracy of the model at the cost of additional @@ -142,6 +153,7 @@ class CenteredInstanceConfmapsHeadConfig: part_names: Optional[List[Text]] = None sigma: float = 5.0 output_stride: int = 1 + loss_weight: float = 1.0 offset_refinement: bool = False @@ -305,7 +317,7 @@ class ClassMapsHeadConfig: @attr.s(auto_attribs=True) -class MultiClassConfig: +class MultiClassBottomUpConfig: """Configuration for multi-instance confidence map and class map models. This configuration specifies a multi-head model that outputs both multi-instance @@ -330,6 +342,69 @@ class MultiClassConfig: class_maps: ClassMapsHeadConfig = attr.ib(factory=ClassMapsHeadConfig) +@attr.s(auto_attribs=True) +class ClassVectorsHeadConfig: + """Configurations for class vectors heads. + + These heads are used in top-down multi-instance models that classify detected + points using a fixed set of learned classes (e.g., animal identities). + + Class vectors represent the probability that the image is associated with each of + the specified classes. This is similar to a standard classification task. + + Attributes: + classes: List of string names of the classes that this head will predict. + num_fc_layers: Number of fully-connected layers before the classification output + layer. These can help in transforming general image features into + classification-specific features. + num_fc_units: Number of units (dimensions) in the fully-connected layers before + classification. Increasing this can improve the representational capacity in + the pre-classification layers. + output_stride: The stride of the output class maps relative to the input image. + This is the reciprocal of the resolution, e.g., an output stride of 2 + results in maps that are 0.5x the size of the input. This should be the same + size as the confidence maps they are associated with. + loss_weight: Scalar float used to weigh the loss term for this head during + training. Increase this to encourage the optimization to focus on improving + this specific output in multi-head models. + """ + + classes: Optional[List[Text]] = None + num_fc_layers: int = 1 + num_fc_units: int = 64 + global_pool: bool = True + output_stride: int = 1 + loss_weight: float = 1.0 + + +@attr.s(auto_attribs=True) +class MultiClassTopDownConfig: + """Configuration for centered-instance confidence map and class map models. + + This configuration specifies a multi-head model that outputs both centered-instance + confidence maps and class vectors, which together enable multi-instance pose + tracking in a top-down fashion, i.e., instance-centered crops followed by pose + estimation and classification. + + The limitation with this approach is that the classes, e.g., animal identities, must + be labeled in the training data and cannot be generalized beyond those classes. This + is still useful for applications in which the animals are uniquely identifiable and + tracking their identities at inference time is critical, e.g., for closed loop + experiments. + + Attributes: + confmaps: Part confidence map configuration (see the description in + `CenteredInstanceConfmapsHeadConfig`). + class_vectors: Class map configuration (see the description in + `ClassVectorsHeadConfig`). + """ + + confmaps: CenteredInstanceConfmapsHeadConfig = attr.ib( + factory=CenteredInstanceConfmapsHeadConfig + ) + class_vectors: ClassVectorsHeadConfig = attr.ib(factory=ClassVectorsHeadConfig) + + @oneof @attr.s(auto_attribs=True) class HeadsConfig: @@ -342,14 +417,16 @@ class HeadsConfig: centroid: An instance of `CentroidsHeadConfig`. centered_instance: An instance of `CenteredInstanceConfmapsHeadConfig`. multi_instance: An instance of `MultiInstanceConfig`. - multi_class: An instance of `MultiClassConfig`. + multi_class_bottomup: An instance of `MultiClassBottomUpConfig`. + multi_class_topdown: An instance of `MultiClassTopDownConfig`. """ single_instance: Optional[SingleInstanceConfmapsHeadConfig] = None centroid: Optional[CentroidsHeadConfig] = None centered_instance: Optional[CenteredInstanceConfmapsHeadConfig] = None multi_instance: Optional[MultiInstanceConfig] = None - multi_class: Optional[MultiClassConfig] = None + multi_class_bottomup: Optional[MultiClassBottomUpConfig] = None + multi_class_topdown: Optional[MultiClassTopDownConfig] = None @attr.s(auto_attribs=True) @@ -517,13 +594,27 @@ class PretrainedEncoderConfig: """Configuration for UNet backbone with pretrained encoder. Attributes: - encoder: Name of the network architecture to use as the encoder. + encoder: Name of the network architecture to use as the encoder. Valid encoder + names are: + - `"vgg16", "vgg19",` + - `"resnet18", "resnet34", "resnet50", "resnet101", "resnet152"` + - `"resnext50", "resnext101"` + - `"inceptionv3", "inceptionresnetv2"` + - `"densenet121", "densenet169", "densenet201"` + - `"seresnet18", "seresnet34", "seresnet50", "seresnet101", "seresnet152",` + `"seresnext50", "seresnext101", "senet154"` + - `"mobilenet", "mobilenetv2"` + - `"efficientnetb0", "efficientnetb1", "efficientnetb2", "efficientnetb3",` + `"efficientnetb4", "efficientnetb5", "efficientnetb6", "efficientnetb7"` + Defaults to `"efficientnetb0"`. pretrained: If `True`, use initialized with weights pretrained on ImageNet. decoder_filters: Base number of filters for the upsampling blocks in the decoder. decoder_filters_rate: Factor to scale the number of filters by at each consecutive upsampling block in the decoder. output_stride: Stride of the final output. + decoder_batchnorm: If `True` (the default), use batch normalization in the + decoder layers. """ encoder: Text = attr.ib(default="efficientnetb0") @@ -531,6 +622,7 @@ class PretrainedEncoderConfig: decoder_filters: int = 256 decoder_filters_rate: float = 1.0 output_stride: int = 2 + decoder_batchnorm: bool = True @oneof diff --git a/sleap/nn/config/optimization.py b/sleap/nn/config/optimization.py index 725526bba..cd43136b7 100644 --- a/sleap/nn/config/optimization.py +++ b/sleap/nn/config/optimization.py @@ -52,6 +52,11 @@ class AugmentationConfig: the augmentations above. random_crop_width: Width of random crops. random_crop_height: Height of random crops. + random_flip: If `True`, images will be randomly reflected. The coordinates of + the instances will be adjusted accordingly. Body parts that are left/right + symmetric must be marked on the skeleton in order to be swapped correctly. + flip_horizontal: If `True`, flip images left/right when randomly reflecting + them. If `False`, flipping is down up/down instead. """ rotate: bool = False @@ -78,6 +83,8 @@ class AugmentationConfig: random_crop: bool = False random_crop_height: int = 256 random_crop_width: int = 256 + random_flip: bool = False + flip_horizontal: bool = True @attr.s(auto_attribs=True) diff --git a/sleap/nn/config/outputs.py b/sleap/nn/config/outputs.py index 5976b651c..ffb0d76e4 100644 --- a/sleap/nn/config/outputs.py +++ b/sleap/nn/config/outputs.py @@ -112,7 +112,7 @@ class ZMQConfig: @attr.s(auto_attribs=True) class OutputsConfig: """Configuration of training outputs. - + Attributes: save_outputs: If True, file system-based outputs will be saved. If False, nothing will be written to disk, which may be useful for interactive @@ -151,6 +151,11 @@ class OutputsConfig: save_visualizations: If True, will render and save visualizations of the model predictions as PNGs to "{run_folder}/viz/{split}.{epoch:04d}.png", where the split is one of "train", "validation", "test". + delete_viz_images: If True, delete the saved visualizations after training + completes. This is useful to reduce the model folder size if you do not need + to keep the visualization images. + zip_outputs: If True, compress the run folder to a zip file. This will be named + "{run_folder}.zip". log_to_csv: If True, loss and metrics will be saved to a simple CSV after each epoch to "{run_folder}/training_log.csv" checkpointing: Configuration options related to model checkpointing. @@ -165,6 +170,8 @@ class OutputsConfig: runs_folder: Text = "models" tags: List[Text] = attr.ib(factory=list) save_visualizations: bool = True + delete_viz_images: bool = True + zip_outputs: bool = False log_to_csv: bool = True checkpointing: CheckpointingConfig = attr.ib(factory=CheckpointingConfig) tensorboard: TensorBoardConfig = attr.ib(factory=TensorBoardConfig) @@ -191,7 +198,8 @@ def run_path(self) -> Text: """ if self.run_name is None: raise ValueError( - "Run path cannot be determined when the run name is not set.") + "Run path cannot be determined when the run name is not set." + ) folder_name = self.run_name_prefix + self.run_name if self.run_name_suffix is not None: folder_name += self.run_name_suffix diff --git a/sleap/nn/config/training_job.py b/sleap/nn/config/training_job.py index fb7721afc..680fd8f15 100644 --- a/sleap/nn/config/training_job.py +++ b/sleap/nn/config/training_job.py @@ -27,13 +27,14 @@ import os import attr import cattr +import sleap from sleap.nn.config.data import DataConfig from sleap.nn.config.model import ModelConfig from sleap.nn.config.optimization import OptimizationConfig from sleap.nn.config.outputs import OutputsConfig import json from jsmin import jsmin -from typing import Text, Dict, Any +from typing import Text, Dict, Any, Optional @attr.s(auto_attribs=True) @@ -45,13 +46,20 @@ class TrainingJobConfig: model: Configuration options related to the model architecture. optimization: Configuration options related to the training. outputs: Configuration options related to outputs during training. + name: Optional name for this configuration profile. + description: Optional description of the configuration. + sleap_version: Version of SLEAP that generated this configuration. + filename: Path to this config file if it was loaded from disk. """ data: DataConfig = attr.ib(factory=DataConfig) model: ModelConfig = attr.ib(factory=ModelConfig) optimization: OptimizationConfig = attr.ib(factory=OptimizationConfig) outputs: OutputsConfig = attr.ib(factory=OutputsConfig) - # TODO: store fixed config format version + SLEAP version? + name: Optional[Text] = "" + description: Optional[Text] = "" + sleap_version: Optional[Text] = sleap.__version__ + filename: Optional[Text] = "" @classmethod def from_json_dicts(cls, json_data_dicts: Dict[Text, Any]) -> "TrainingJobConfig": @@ -82,16 +90,27 @@ def from_json(cls, json_data: Text) -> "TrainingJobConfig": return cls.from_json_dicts(json_data_dicts) @classmethod - def load_json(cls, filename: Text) -> "TrainingJobConfig": + def load_json( + cls, filename: Text, load_training_config: bool = True + ) -> "TrainingJobConfig": """Load a training job configuration from a file. Arguments: filename: Path to a training job configuration JSON file or a directory containing `"training_job.json"`. + load_training_config: If `True` (the default), prefer `training_job.json` + over `initial_config.json` if it is present in the same folder. Returns: A TrainingJobConfig instance parsed from the file. """ + if load_training_config and filename.endswith("initial_config.json"): + training_config_path = os.path.join( + os.path.dirname(filename), "training_config.json" + ) + if os.path.exists(training_config_path): + filename = training_config_path + # Use stored configuration if a directory was provided. if os.path.isdir(filename): filename = os.path.join(filename, "training_config.json") @@ -100,7 +119,9 @@ def load_json(cls, filename: Text) -> "TrainingJobConfig": with open(filename, "r") as f: json_data = f.read() - return cls.from_json(json_data) + obj = cls.from_json(json_data) + obj.filename = filename + return obj def to_json(self) -> str: """Serialize the configuration into JSON-encoded string format. @@ -117,5 +138,22 @@ def save_json(self, filename: Text): Arguments: filename: Path to save the training job file to. """ + self.filename = filename with open(filename, "w") as f: f.write(self.to_json()) + + +def load_config(filename: Text, load_training_config: bool = True) -> TrainingJobConfig: + """Load a training job configuration for a model run. + + Args: + filename: Path to a JSON file or directory containing `training_job.json`. + load_training_config: If `True` (the default), prefer `training_job.json` over + `initial_config.json` if it is present in the same folder. + + Returns: + The parsed `TrainingJobConfig`. + """ + return TrainingJobConfig.load_json( + filename, load_training_config=load_training_config + ) diff --git a/sleap/nn/data/augmentation.py b/sleap/nn/data/augmentation.py index 86088bfa5..186ea5db4 100644 --- a/sleap/nn/data/augmentation.py +++ b/sleap/nn/data/augmentation.py @@ -7,16 +7,109 @@ if hasattr(numpy.random, "_bit_generator"): numpy.random.bit_generator = numpy.random._bit_generator +import sleap import numpy as np import tensorflow as tf import attr -from typing import List, Text +from typing import List, Text, Optional import imgaug as ia import imgaug.augmenters as iaa from sleap.nn.config import AugmentationConfig from sleap.nn.data.instance_cropping import crop_bboxes +def flip_instances_lr( + instances: tf.Tensor, img_width: int, symmetric_inds: Optional[tf.Tensor] = None +) -> tf.Tensor: + """Flip a set of instance points horizontally with symmetric node adjustment. + + Args: + instances: Instance points as a `tf.Tensor` of shape `(n_instances, n_nodes, 2)` + and dtype `tf.float32`. + img_width: Width of image in the same units as `instances`. + symmetric_inds: Indices of symmetric pairs of nodes as a `tf.Tensor` of shape + `(n_symmetries, 2)` and dtype `tf.int32`. Each row contains the indices of + nodes that are mirror symmetric, e.g., left/right body parts. The ordering + of the list or which node comes first (e.g., left/right vs right/left) does + not matter. Each pair of nodes will be swapped to account for the + reflection if this is not `None` (the default). + + Returns: + The instance points with x-coordinates flipped horizontally. + """ + instances = (tf.cast([[[img_width - 1, 0]]], tf.float32) - instances) * tf.cast( + [[[1, -1]]], tf.float32 + ) + + if symmetric_inds is not None: + n_instances = tf.shape(instances)[0] + n_symmetries = tf.shape(symmetric_inds)[0] + + inst_inds = tf.reshape(tf.repeat(tf.range(n_instances), n_symmetries), [-1, 1]) + sym_inds1 = tf.reshape(tf.gather(symmetric_inds, 0, axis=1), [-1, 1]) + sym_inds2 = tf.reshape(tf.gather(symmetric_inds, 1, axis=1), [-1, 1]) + + inst_inds = tf.cast(inst_inds, tf.int32) + sym_inds1 = tf.cast(sym_inds1, tf.int32) + sym_inds2 = tf.cast(sym_inds2, tf.int32) + + subs1 = tf.concat([inst_inds, tf.tile(sym_inds1, [n_instances, 1])], axis=1) + subs2 = tf.concat([inst_inds, tf.tile(sym_inds2, [n_instances, 1])], axis=1) + + pts1 = tf.gather_nd(instances, subs1) + pts2 = tf.gather_nd(instances, subs2) + instances = tf.tensor_scatter_nd_update(instances, subs1, pts2) + instances = tf.tensor_scatter_nd_update(instances, subs2, pts1) + + return instances + + +def flip_instances_ud( + instances: tf.Tensor, img_height: int, symmetric_inds: Optional[tf.Tensor] = None +) -> tf.Tensor: + """Flip a set of instance points vertically with symmetric node adjustment. + + Args: + instances: Instance points as a `tf.Tensor` of shape `(n_instances, n_nodes, 2)` + and dtype `tf.float32`. + img_height: Height of image in the same units as `instances`. + symmetric_inds: Indices of symmetric pairs of nodes as a `tf.Tensor` of shape + `(n_symmetries, 2)` and dtype `tf.int32`. Each row contains the indices of + nodes that are mirror symmetric, e.g., left/right body parts. The ordering + of the list or which node comes first (e.g., left/right vs right/left) does + not matter. Each pair of nodes will be swapped to account for the + reflection if this is not `None` (the default). + + Returns: + The instance points with y-coordinates flipped horizontally. + """ + instances = (tf.cast([[[0, img_height - 1]]], tf.float32) - instances) * tf.cast( + [[[-1, 1]]], tf.float32 + ) + + if symmetric_inds is not None: + n_instances = tf.shape(instances)[0] + n_symmetries = tf.shape(symmetric_inds)[0] + + inst_inds = tf.reshape(tf.repeat(tf.range(n_instances), n_symmetries), [-1, 1]) + sym_inds1 = tf.reshape(tf.gather(symmetric_inds, 0, axis=1), [-1, 1]) + sym_inds2 = tf.reshape(tf.gather(symmetric_inds, 1, axis=1), [-1, 1]) + + inst_inds = tf.cast(inst_inds, tf.int32) + sym_inds1 = tf.cast(sym_inds1, tf.int32) + sym_inds2 = tf.cast(sym_inds2, tf.int32) + + subs1 = tf.concat([inst_inds, tf.tile(sym_inds1, [n_instances, 1])], axis=1) + subs2 = tf.concat([inst_inds, tf.tile(sym_inds2, [n_instances, 1])], axis=1) + + pts1 = tf.gather_nd(instances, subs1) + pts2 = tf.gather_nd(instances, subs2) + instances = tf.tensor_scatter_nd_update(instances, subs1, pts2) + instances = tf.tensor_scatter_nd_update(instances, subs2, pts1) + + return instances + + @attr.s(auto_attribs=True) class ImgaugAugmenter: """Data transformer based on the `imgaug` library. @@ -249,3 +342,94 @@ def random_crop(ex): return ex return input_ds.map(random_crop) + + +@attr.s(auto_attribs=True) +class RandomFlipper: + """Data transformer for applying random flipping to input images. + + This class can generate a `tf.data.Dataset` from an existing one that generates + image and instance data. Elements of the output dataset will have random horizontal + flips applied. + + Attributes: + symmetric_inds: Indices of symmetric pairs of nodes as a an array of shape + `(n_symmetries, 2)`. Each row contains the indices of nodes that are mirror + symmetric, e.g., left/right body parts. The ordering of the list or which + node comes first (e.g., left/right vs right/left) does not matter. Each pair + of nodes will be swapped to account for the reflection if this is not `None` + (the default). + horizontal: If `True` (the default), flips are applied horizontally instead of + vertically. + probability: The probability that the augmentation should be applied. + """ + + symmetric_inds: Optional[np.ndarray] = None + horizontal: bool = True + probability: float = 0.5 + + @classmethod + def from_skeleton( + cls, skeleton: sleap.Skeleton, horizontal: bool = True, probability: float = 0.5 + ) -> "RandomFlipper": + """Create an instance of `RandomFlipper` from a skeleton. + + Args: + skeleton: A `sleap.Skeleton` that may define symmetric nodes. + horizontal: If `True` (the default), flips are applied horizontally instead + of vertically. + probability: The probability that the augmentation should be applied. + + Returns: + An instance of `RandomFlipper`. + """ + return cls( + symmetric_inds=skeleton.symmetric_inds, + horizontal=horizontal, + probability=probability, + ) + + @property + def input_keys(self): + return ["image", "instances"] + + @property + def output_keys(self): + return self.input_keys + + def transform_dataset(self, input_ds: tf.data.Dataset): + """Create a `tf.data.Dataset` with elements containing augmented data. + + Args: + input_ds: A dataset with elements that contain the keys `"image"` and + `"instances"`. This is typically raw data from a data provider. + + Returns: + A `tf.data.Dataset` with the same keys as the input, but with images and + instance points updated with the applied random flip. + """ + symmetric_inds = self.symmetric_inds + if symmetric_inds is not None: + symmetric_inds = np.array(symmetric_inds) + if len(symmetric_inds) == 0: + symmetric_inds = None + + def random_flip(ex): + """Apply random flip to an example.""" + p = tf.random.uniform((), minval=0, maxval=1.0) + if p <= self.probability: + if self.horizontal: + img_width = tf.shape(ex["image"])[1] + ex["instances"] = flip_instances_lr( + ex["instances"], img_width, symmetric_inds=symmetric_inds + ) + ex["image"] = tf.image.flip_left_right(ex["image"]) + else: + img_height = tf.shape(ex["image"])[0] + ex["instances"] = flip_instances_ud( + ex["instances"], img_height, symmetric_inds=symmetric_inds + ) + ex["image"] = tf.image.flip_up_down(ex["image"]) + return ex + + return input_ds.map(random_flip) diff --git a/sleap/nn/data/confidence_maps.py b/sleap/nn/data/confidence_maps.py index e4cac3e1b..8c8dc83d9 100644 --- a/sleap/nn/data/confidence_maps.py +++ b/sleap/nn/data/confidence_maps.py @@ -535,9 +535,7 @@ def generate_confmaps(example): if self.with_offsets: example["offsets"] = mask_offsets( - make_offsets( - example["points"], xv, yv, stride=self.output_stride - ), + make_offsets(example["points"], xv, yv, stride=self.output_stride), example["confidence_maps"], self.offsets_threshold, ) diff --git a/sleap/nn/data/dataset_ops.py b/sleap/nn/data/dataset_ops.py index 8d6392e27..6c46fae38 100644 --- a/sleap/nn/data/dataset_ops.py +++ b/sleap/nn/data/dataset_ops.py @@ -275,7 +275,7 @@ def transform_dataset(self, ds_input: tf.data.Dataset) -> tf.data.Dataset: @attr.s(auto_attribs=True) class Preloader: """Preload elements of the underlying dataset to generate in-memory examples. - + This transformer can lead to considerable performance improvements at the cost of memory consumption. diff --git a/sleap/nn/data/identity.py b/sleap/nn/data/identity.py index 816124a2e..7d5c57be3 100644 --- a/sleap/nn/data/identity.py +++ b/sleap/nn/data/identity.py @@ -62,7 +62,7 @@ class vectors weighed by the relative contribution of each instance. # Normalize instance mask. mask = confmaps / tf.reduce_sum(confmaps, axis=2, keepdims=True) - mask = tf.where(confmaps > threshold, mask, 0.) # (h, w, n_instances) + mask = tf.where(confmaps > threshold, mask, 0.0) # (h, w, n_instances) mask = tf.expand_dims(mask, axis=3) # (h, w, n_instances, 1) # Apply mask to vectors to create class maps. @@ -70,6 +70,47 @@ class vectors weighed by the relative contribution of each instance. return class_maps +@attr.s(auto_attribs=True) +class ClassVectorGenerator: + """Transformer to generate class probability vectors from track indices.""" + + @property + def input_keys(self) -> List[Text]: + """Return the keys that incoming elements are expected to have.""" + return ["track_inds", "n_tracks"] + + @property + def output_keys(self) -> List[Text]: + """Return the keys that outgoing elements will have.""" + return self.input_keys + ["class_vectors"] + + def transform_dataset(self, input_ds: tf.data.Dataset) -> tf.data.Dataset: + """Create a dataset that contains the generated class identity vectors. + + Args: + input_ds: A dataset with elements that contain the keys`"track_inds"` and + `"n_tracks"`. + + Returns: + A `tf.data.Dataset` with the same keys as the input, as well as a `"class"` + key containing the generated class vectors. + """ + + def generate_class_vectors(example): + """Local processing function for dataset mapping.""" + example["class_vectors"] = tf.cast( + make_class_vectors(example["track_inds"], example["n_tracks"]), + tf.float32, + ) + return example + + # Map transformation. + output_ds = input_ds.map( + generate_class_vectors, num_parallel_calls=tf.data.experimental.AUTOTUNE + ) + return output_ds + + @attr.s(auto_attribs=True) class ClassMapGenerator: """Transformer to generate class maps from track indices. diff --git a/sleap/nn/data/inference.py b/sleap/nn/data/inference.py index 2f0e24ef6..772ac3f8b 100644 --- a/sleap/nn/data/inference.py +++ b/sleap/nn/data/inference.py @@ -172,7 +172,8 @@ def find_peaks(example): centroid = example["centroid"] / example["scale"] all_peaks = example[self.all_peaks_in_key] # (n_instances, n_nodes, 2) dists = tf.reduce_min( - tf.norm(all_peaks - tf.reshape(centroid, [1, 1, 2]), axis=-1), axis=1, + tf.norm(all_peaks - tf.reshape(centroid, [1, 1, 2]), axis=-1), + axis=1, ) # (n_instances,) instance_ind = tf.argmin(dists) center_instance = tf.gather(all_peaks, instance_ind) diff --git a/sleap/nn/data/instance_cropping.py b/sleap/nn/data/instance_cropping.py index fff2e7ea2..1cfb1542b 100644 --- a/sleap/nn/data/instance_cropping.py +++ b/sleap/nn/data/instance_cropping.py @@ -279,7 +279,7 @@ def from_config( @property def input_keys(self) -> List[Text]: """Return the keys that incoming elements are expected to have.""" - return [self.image_key, self.instances_key, self.centroids_key] + return [self.image_key, self.instances_key, self.centroids_key, "track_inds"] @property def output_keys(self) -> List[Text]: @@ -289,6 +289,7 @@ def output_keys(self) -> List[Text]: "bbox", "center_instance", "center_instance_ind", + "track_ind", "all_instances", "centroid", "full_image_height", @@ -310,6 +311,7 @@ def transform_dataset(self, input_ds: tf.data.Dataset) -> tf.data.Dataset: (n_instances, n_nodes, 2). "centroids": The computed centroid for each instance in a tf.float32 tensor of shape (n_instances, 2). + "track_inds": The track indices of the indices if available. Any additional keys present will be replicated in each output. Returns: @@ -331,6 +333,7 @@ def transform_dataset(self, input_ds: tf.data.Dataset) -> tf.data.Dataset: "center_instance_ind": Scalar tf.int32 index of the centered instance relative to all the instances in the frame. This can be used to index into additional keys that may contain data from all instances. + "track_ind": Index of the track the instance belongs to if available. "all_instances": The points of all instances in the frame in image coordinates in the "instance_image". This will be a tf.float32 tensor of shape (n_instances, n_nodes, 2). This is useful for multi- @@ -359,6 +362,8 @@ def transform_dataset(self, input_ds: tf.data.Dataset) -> tf.data.Dataset: keys_to_expand = [ key for key in test_example.keys() if key not in self.input_keys ] + if "class_vectors" in keys_to_expand: + keys_to_expand.remove("class_vectors") img_channels = test_example[self.image_key].shape[-1] if self.keep_full_image: keys_to_expand.append(self.image_key) @@ -408,6 +413,7 @@ def crop_instances(frame_data): "bbox": bboxes, "center_instance": center_instances, "center_instance_ind": tf.range(n_instances, dtype=tf.int32), + "track_ind": frame_data["track_inds"], "all_instances": all_instances, "centroid": frame_data[self.centroids_key], "full_image_height": tf.repeat( @@ -417,6 +423,8 @@ def crop_instances(frame_data): tf.shape(frame_data[self.image_key])[1], n_instances ), } + if "class_vectors" in frame_data: + instances_data["class_vectors"] = frame_data["class_vectors"] if self.mock_centroid_confidence: instances_data["centroid_confidence"] = tf.ones( [n_instances], dtype=tf.float32 diff --git a/sleap/nn/data/offset_regression.py b/sleap/nn/data/offset_regression.py index 437745812..c9e56b8c2 100644 --- a/sleap/nn/data/offset_regression.py +++ b/sleap/nn/data/offset_regression.py @@ -3,7 +3,9 @@ import tensorflow as tf -def make_offsets(points: tf.Tensor, xv: tf.Tensor, yv: tf.Tensor, stride: int = 1) -> tf.Tensor: +def make_offsets( + points: tf.Tensor, xv: tf.Tensor, yv: tf.Tensor, stride: int = 1 +) -> tf.Tensor: """Make point offset maps on a grid. Args: diff --git a/sleap/nn/data/pipelines.py b/sleap/nn/data/pipelines.py index db6aae5e5..daf9d8f23 100644 --- a/sleap/nn/data/pipelines.py +++ b/sleap/nn/data/pipelines.py @@ -20,9 +20,10 @@ AugmentationConfig, ImgaugAugmenter, RandomCropper, + RandomFlipper, ) from sleap.nn.data.normalization import Normalizer -from sleap.nn.data.resizing import Resizer, PointsRescaler +from sleap.nn.data.resizing import Resizer, PointsRescaler, SizeMatcher from sleap.nn.data.instance_centroids import InstanceCentroidFinder from sleap.nn.data.instance_cropping import InstanceCropper, PredictedInstanceCropper from sleap.nn.data.confidence_maps import ( @@ -31,7 +32,7 @@ SingleInstanceConfidenceMapGenerator, ) from sleap.nn.data.edge_maps import PartAffinityFieldsGenerator -from sleap.nn.data.identity import ClassMapGenerator +from sleap.nn.data.identity import ClassVectorGenerator, ClassMapGenerator from sleap.nn.data.dataset_ops import ( Shuffler, Batcher, @@ -59,6 +60,7 @@ CentroidConfmapsHead, CenteredInstanceConfmapsHead, ClassMapsHead, + ClassVectorsHead, SingleInstanceConfmapsHead, OffsetRefinementHead, ) @@ -70,12 +72,14 @@ RandomCropper, Normalizer, Resizer, + SizeMatcher, InstanceCentroidFinder, InstanceCropper, MultiConfidenceMapGenerator, InstanceConfidenceMapGenerator, PartAffinityFieldsGenerator, SingleInstanceConfidenceMapGenerator, + ClassVectorGenerator, ClassMapGenerator, Shuffler, Batcher, @@ -96,6 +100,7 @@ KeyDeviceMover, PointsRescaler, LambdaMap, + RandomFlipper, ) Provider = TypeVar("Provider", *PROVIDERS) Transformer = TypeVar("Transformer", *TRANSFORMERS) @@ -126,7 +131,7 @@ def from_blocks( """Create a pipeline from a sequence of providers and transformers. Args: - sequence: List or tuple of providers and transformer instances. + blocks: List or tuple of providers and transformer instances. Returns: An instantiated pipeline with all blocks chained. @@ -341,7 +346,6 @@ class SingleInstanceConfmapsPipeline: optimization_config: OptimizationConfig single_instance_confmap_head: SingleInstanceConfmapsHead offsets_head: Optional[OffsetRefinementHead] = None - def make_base_pipeline(self, data_provider: Provider) -> Pipeline: """Create base pipeline with input data only. @@ -355,6 +359,11 @@ def make_base_pipeline(self, data_provider: Provider) -> Pipeline: """ pipeline = Pipeline(providers=data_provider) pipeline += Normalizer.from_config(self.data_config.preprocessing) + if self.data_config.preprocessing.resize_and_pad_to_target: + pipeline += SizeMatcher.from_config( + config=self.data_config.preprocessing, + provider=data_provider, + ) pipeline += Resizer.from_config(self.data_config.preprocessing) if self.optimization_config.augmentation_config.random_crop: pipeline += RandomCropper( @@ -386,6 +395,11 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline: if self.optimization_config.online_shuffling: pipeline += Shuffler(self.optimization_config.shuffle_buffer_size) + if self.optimization_config.augmentation_config.random_flip: + pipeline += RandomFlipper.from_skeleton( + self.data_config.labels.skeletons[0], + horizontal=self.optimization_config.augmentation_config.flip_horizontal, + ) pipeline += ImgaugAugmenter.from_config( self.optimization_config.augmentation_config ) @@ -395,13 +409,19 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline: crop_width=self.optimization_config.augmentation_config.random_crop_width, ) pipeline += Normalizer.from_config(self.data_config.preprocessing) + if self.data_config.preprocessing.resize_and_pad_to_target: + pipeline += SizeMatcher.from_config( + config=self.data_config.preprocessing, + provider=data_provider, + ) pipeline += Resizer.from_config(self.data_config.preprocessing) - pipeline += SingleInstanceConfidenceMapGenerator( sigma=self.single_instance_confmap_head.sigma, output_stride=self.single_instance_confmap_head.output_stride, with_offsets=self.offsets_head is not None, - offsets_threshold=self.offsets_head.sigma_threshold if self.offsets_head is not None else 1.0 + offsets_threshold=self.offsets_head.sigma_threshold + if self.offsets_head is not None + else 1.0, ) if len(data_provider) >= self.optimization_config.batch_size: @@ -485,6 +505,11 @@ def make_base_pipeline(self, data_provider: Provider) -> Pipeline: """ pipeline = Pipeline(providers=data_provider) pipeline += Normalizer.from_config(self.data_config.preprocessing) + if self.data_config.preprocessing.resize_and_pad_to_target: + pipeline += SizeMatcher.from_config( + config=self.data_config.preprocessing, + provider=data_provider, + ) pipeline += Resizer.from_config(self.data_config.preprocessing) if self.optimization_config.augmentation_config.random_crop: pipeline += RandomCropper( @@ -521,7 +546,11 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline: pipeline += Shuffler( shuffle=True, buffer_size=self.optimization_config.shuffle_buffer_size ) - + if self.optimization_config.augmentation_config.random_flip: + pipeline += RandomFlipper.from_skeleton( + self.data_config.labels.skeletons[0], + horizontal=self.optimization_config.augmentation_config.flip_horizontal, + ) pipeline += ImgaugAugmenter.from_config( self.optimization_config.augmentation_config ) @@ -531,8 +560,12 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline: crop_width=self.optimization_config.augmentation_config.random_crop_width, ) pipeline += Normalizer.from_config(self.data_config.preprocessing) + if self.data_config.preprocessing.resize_and_pad_to_target: + pipeline += SizeMatcher.from_config( + config=self.data_config.preprocessing, + provider=data_provider, + ) pipeline += Resizer.from_config(self.data_config.preprocessing) - pipeline += InstanceCentroidFinder.from_config( self.data_config.instance_cropping, skeletons=self.data_config.labels.skeletons, @@ -542,7 +575,9 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline: output_stride=self.centroid_confmap_head.output_stride, centroids=True, with_offsets=self.offsets_head is not None, - offsets_threshold=self.offsets_head.sigma_threshold if self.offsets_head is not None else 1.0 + offsets_threshold=self.offsets_head.sigma_threshold + if self.offsets_head is not None + else 1.0, ) if len(data_provider) >= self.optimization_config.batch_size: @@ -637,6 +672,11 @@ def make_base_pipeline(self, data_provider: Provider) -> Pipeline: """ pipeline = Pipeline(providers=data_provider) pipeline += Normalizer.from_config(self.data_config.preprocessing) + if self.data_config.preprocessing.resize_and_pad_to_target: + pipeline += SizeMatcher.from_config( + config=self.data_config.preprocessing, + provider=data_provider, + ) pipeline += Resizer.from_config(self.data_config.preprocessing) pipeline += InstanceCentroidFinder.from_config( self.data_config.instance_cropping, @@ -669,13 +709,21 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline: pipeline += Shuffler( shuffle=True, buffer_size=self.optimization_config.shuffle_buffer_size ) - + if self.optimization_config.augmentation_config.random_flip: + pipeline += RandomFlipper.from_skeleton( + self.data_config.labels.skeletons[0], + horizontal=self.optimization_config.augmentation_config.flip_horizontal, + ) pipeline += ImgaugAugmenter.from_config( self.optimization_config.augmentation_config ) pipeline += Normalizer.from_config(self.data_config.preprocessing) + if self.data_config.preprocessing.resize_and_pad_to_target: + pipeline += SizeMatcher.from_config( + config=self.data_config.preprocessing, + provider=data_provider, + ) pipeline += Resizer.from_config(self.data_config.preprocessing) - pipeline += InstanceCentroidFinder.from_config( self.data_config.instance_cropping, skeletons=self.data_config.labels.skeletons, @@ -686,7 +734,9 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline: output_stride=self.instance_confmap_head.output_stride, all_instances=False, with_offsets=self.offsets_head is not None, - offsets_threshold=self.offsets_head.sigma_threshold if self.offsets_head is not None else 1.0 + offsets_threshold=self.offsets_head.sigma_threshold + if self.offsets_head is not None + else 1.0, ) if len(data_provider) >= self.optimization_config.batch_size: @@ -766,6 +816,11 @@ def make_base_pipeline(self, data_provider: Provider) -> Pipeline: """ pipeline = Pipeline(providers=data_provider) pipeline += Normalizer.from_config(self.data_config.preprocessing) + if self.data_config.preprocessing.resize_and_pad_to_target: + pipeline += SizeMatcher.from_config( + config=self.data_config.preprocessing, + provider=data_provider, + ) pipeline += Resizer.from_config(self.data_config.preprocessing) if self.optimization_config.augmentation_config.random_crop: pipeline += RandomCropper( @@ -800,6 +855,11 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline: ) aug_config = self.optimization_config.augmentation_config + if aug_config.random_flip: + pipeline += RandomFlipper.from_skeleton( + self.data_config.labels.skeletons[0], + horizontal=aug_config.flip_horizontal, + ) pipeline += ImgaugAugmenter.from_config(aug_config) if aug_config.random_crop: pipeline += RandomCropper( @@ -807,14 +867,20 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline: crop_width=aug_config.random_crop_width, ) pipeline += Normalizer.from_config(self.data_config.preprocessing) + if self.data_config.preprocessing.resize_and_pad_to_target: + pipeline += SizeMatcher.from_config( + config=self.data_config.preprocessing, + provider=data_provider, + ) pipeline += Resizer.from_config(self.data_config.preprocessing) - pipeline += MultiConfidenceMapGenerator( sigma=self.confmaps_head.sigma, output_stride=self.confmaps_head.output_stride, centroids=False, with_offsets=self.offsets_head is not None, - offsets_threshold=self.offsets_head.sigma_threshold if self.offsets_head is not None else 1.0 + offsets_threshold=self.offsets_head.sigma_threshold + if self.offsets_head is not None + else 1.0, ) pipeline += PartAffinityFieldsGenerator( sigma=self.pafs_head.sigma, @@ -870,7 +936,7 @@ def make_viz_pipeline( model_output_keys=[ "predicted_confidence_maps", "predicted_part_affinity_fields", - ] + ], ) pipeline += LocalPeakFinder( confmaps_stride=self.confmaps_head.output_stride, @@ -963,7 +1029,9 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline: output_stride=self.confmaps_head.output_stride, centroids=False, with_offsets=self.offsets_head is not None, - offsets_threshold=self.offsets_head.sigma_threshold if self.offsets_head is not None else 1.0 + offsets_threshold=self.offsets_head.sigma_threshold + if self.offsets_head is not None + else 1.0, ) pipeline += ClassMapGenerator( sigma=self.class_maps_head.sigma, @@ -1018,7 +1086,7 @@ def make_viz_pipeline( model_output_keys=[ "predicted_confidence_maps", "predicted_class_maps", - ] + ], ) pipeline += LocalPeakFinder( confmaps_stride=self.confmaps_head.output_stride, @@ -1030,3 +1098,139 @@ def make_viz_pipeline( peak_channel_inds_key="predicted_peak_channel_inds", ) return pipeline + + +@attr.s(auto_attribs=True) +class TopDownMultiClassPipeline: + """Pipeline builder for confidence maps and class maps models. + + Attributes: + data_config: Data-related configuration. + optimization_config: Optimization-related configuration. + confmaps_head: Instantiated head describing the output confidence maps tensor. + class_vectors_head: Instantiated head describing the output class vectors + tensor. + offsets_head: Optional head describing the offset refinement maps. + """ + + data_config: DataConfig + optimization_config: OptimizationConfig + instance_confmap_head: CenteredInstanceConfmapsHead + class_vectors_head: ClassVectorsHead + offsets_head: Optional[OffsetRefinementHead] = None + + def make_base_pipeline(self, data_provider: Provider) -> Pipeline: + """Create base pipeline with input data only. + + Args: + data_provider: A `Provider` that generates data examples, typically a + `LabelsReader` instance. + + Returns: + A `Pipeline` instance configured to produce input examples. + """ + pipeline = Pipeline(providers=data_provider) + pipeline += Normalizer.from_config(self.data_config.preprocessing) + pipeline += Resizer.from_config(self.data_config.preprocessing) + pipeline += InstanceCentroidFinder.from_config( + self.data_config.instance_cropping, + skeletons=self.data_config.labels.skeletons, + ) + pipeline += InstanceCropper.from_config(self.data_config.instance_cropping) + return pipeline + + def make_training_pipeline(self, data_provider: Provider) -> Pipeline: + """Create full training pipeline. + + Args: + data_provider: A `Provider` that generates data examples, typically a + `LabelsReader` instance. + + Returns: + A `Pipeline` instance configured to produce all data keys required for + training. + + Notes: + This does not remap keys to model outputs. Use `KeyMapper` to pull out keys + with the appropriate format for the instantiated `tf.keras.Model`. + """ + pipeline = Pipeline(providers=data_provider) + + if self.optimization_config.preload_data: + pipeline += Preloader() + + if self.optimization_config.online_shuffling: + pipeline += Shuffler( + shuffle=True, buffer_size=self.optimization_config.shuffle_buffer_size + ) + + pipeline += ImgaugAugmenter.from_config( + self.optimization_config.augmentation_config + ) + pipeline += Normalizer.from_config(self.data_config.preprocessing) + pipeline += Resizer.from_config(self.data_config.preprocessing) + + pipeline += ClassVectorGenerator() + pipeline += InstanceCentroidFinder.from_config( + self.data_config.instance_cropping, + skeletons=self.data_config.labels.skeletons, + ) + pipeline += InstanceCropper.from_config(self.data_config.instance_cropping) + pipeline += InstanceConfidenceMapGenerator( + sigma=self.instance_confmap_head.sigma, + output_stride=self.instance_confmap_head.output_stride, + all_instances=False, + with_offsets=self.offsets_head is not None, + offsets_threshold=self.offsets_head.sigma_threshold + if self.offsets_head is not None + else 1.0, + ) + + if len(data_provider) >= self.optimization_config.batch_size: + # Batching before repeating is preferred since it preserves epoch boundaries + # such that no sample is repeated within the epoch. But this breaks if there + # are fewer samples than the batch size. + pipeline += Batcher( + batch_size=self.optimization_config.batch_size, + drop_remainder=True, + unrag=True, + ) + pipeline += Repeater() + + else: + pipeline += Repeater() + pipeline += Batcher( + batch_size=self.optimization_config.batch_size, + drop_remainder=True, + unrag=True, + ) + + if self.optimization_config.prefetch: + pipeline += Prefetcher() + + return pipeline + + def make_viz_pipeline( + self, data_provider: Provider, keras_model: tf.keras.Model + ) -> Pipeline: + """Create visualization pipeline. + + Args: + data_provider: A `Provider` that generates data examples, typically a + `LabelsReader` instance. + keras_model: A `tf.keras.Model` that can be used for inference. + + Returns: + A `Pipeline` instance configured to fetch data and run inference to generate + predictions useful for visualization during training. + """ + pipeline = Pipeline(data_provider) + pipeline += Normalizer.from_config(self.data_config.preprocessing) + pipeline += InstanceCentroidFinder.from_config( + self.data_config.instance_cropping, + skeletons=self.data_config.labels.skeletons, + ) + pipeline += InstanceCropper.from_config(self.data_config.instance_cropping) + pipeline += Repeater() + pipeline += Prefetcher() + return pipeline diff --git a/sleap/nn/data/providers.py b/sleap/nn/data/providers.py index 613560759..d5e9e9c3a 100644 --- a/sleap/nn/data/providers.py +++ b/sleap/nn/data/providers.py @@ -3,7 +3,7 @@ import numpy as np import tensorflow as tf import attr -from typing import Text, Optional, List, Sequence, Union +from typing import Text, Optional, List, Sequence, Union, Tuple import sleap @@ -23,10 +23,13 @@ class LabelsReader: the entire labels dataset will be read. These indices will be applicable to the labeled frames in `labels` attribute, which may have changed in ordering or filtered. + user_instances_only: If `True`, load only user labeled instances. If `False`, + all instances will be loaded. """ labels: sleap.Labels example_indices: Optional[Union[Sequence[int], np.ndarray]] = None + user_instances_only: bool = False @classmethod def from_user_instances(cls, labels: sleap.Labels) -> "LabelsReader": @@ -36,17 +39,40 @@ def from_user_instances(cls, labels: sleap.Labels) -> "LabelsReader": labels: A `sleap.Labels` instance containing user instances. Returns: - A `LabelsReader` instance that can create a dataset for pipelining. Note - that the examples may change in ordering relative to the input `labels`, so - be sure to use the `labels` attribute in the returned instance. + A `LabelsReader` instance that can create a dataset for pipelining. """ - user_labels = sleap.Labels( - [ - sleap.LabeledFrame(lf.video, lf.frame_idx, lf.training_instances) - for lf in labels.user_labeled_frames - ] - ) - return cls(labels=user_labels) + obj = cls.from_user_labeled_frames(labels) + obj.user_instances_only = True + return obj + + @classmethod + def from_user_labeled_frames(cls, labels: sleap.Labels) -> "LabelsReader": + """Create a `LabelsReader` using the user labeled frames in a `Labels` set. + + Args: + labels: A `sleap.Labels` instance containing user labeled frames. + + Returns: + A `LabelsReader` instance that can create a dataset for pipelining. + + Note that this constructor will load ALL instances in frames that have user + instances. To load only user labeled indices, use + `LabelsReader.from_user_instances`. + """ + return cls(labels=labels, example_indices=labels.user_labeled_frame_inds) + + @classmethod + def from_unlabeled_suggestions(cls, labels: sleap.Labels) -> "LabelsReader": + """Create a `LabelsReader` using the unlabeled suggestions in a `Labels` set. + + Args: + labels: A `sleap.Labels` instance containing unlabeled suggestions. + + Returns: + A `LabelsReader` instance that can create a dataset for pipelining. + """ + inds = labels.get_unlabeled_suggestion_inds() + return cls(labels=labels, example_indices=inds) @classmethod def from_filename( @@ -100,6 +126,21 @@ def tracks(self) -> List[sleap.Track]: """Return the list of tracks that `track_inds` in examples match up with.""" return self.labels.tracks + @property + def max_height_and_width(self) -> Tuple[int, int]: + """Return `(height, width)` that is the maximum of all videos.""" + return max(video.shape[1] for video in self.videos), max( + video.shape[2] for video in self.videos + ) + + @property + def is_from_multi_size_videos(self) -> bool: + """Return `True` if labels contain videos with different sizes.""" + return ( + len(set(v.shape[1] for v in self.videos)) > 1 + or len(set(v.shape[2] for v in self.videos)) > 1 + ) + def make_dataset( self, ds_index: Optional[tf.data.Dataset] = None ) -> tf.data.Dataset: @@ -135,28 +176,40 @@ def make_dataset( specifies the index of the instance track identity. If not specified, in the labels, this is set to -1. """ - # Grab an image to test for the dtype. - test_lf = self.labels[0] - test_image = tf.convert_to_tensor(test_lf.image) - image_dtype = test_image.dtype + # Grab the first image to capture dtype and number of color channels. + first_image = tf.convert_to_tensor(self.labels[0].image) + image_dtype = first_image.dtype + image_num_channels = first_image.shape[-1] def py_fetch_lf(ind): """Local function that will not be autographed.""" lf = self.labels[int(ind.numpy())] + video_ind = np.array(self.videos.index(lf.video)).astype("int32") frame_ind = np.array(lf.frame_idx).astype("int64") + raw_image = lf.image raw_image_size = np.array(raw_image.shape).astype("int32") - instances = np.stack( - [inst.points_array.astype("float32") for inst in lf.instances], axis=0 - ) + + if self.user_instances_only: + insts = lf.user_instances + else: + insts = lf.instances + insts = [inst for inst in insts if len(inst) > 0] + n_instances = len(insts) + n_nodes = len(insts[0].skeleton) if n_instances > 0 else 0 + + instances = np.full((n_instances, n_nodes, 2), np.nan, dtype="float32") + for i, instance in enumerate(insts): + instances[i] = instance.numpy() + skeleton_inds = np.array( - [self.labels.skeletons.index(inst.skeleton) for inst in lf.instances] + [self.labels.skeletons.index(inst.skeleton) for inst in insts] ).astype("int32") track_inds = np.array( [ self.tracks.index(inst.track) if inst.track is not None else -1 - for inst in lf.instances + for inst in insts ] ).astype("int32") n_tracks = np.array(len(self.tracks)).astype("int32") @@ -197,7 +250,14 @@ def fetch_lf(ind): tf.int32, ], ) - image = tf.ensure_shape(image, test_image.shape) + + # Ensure shape with constant or variable height/width, based on whether or + # not the videos have mixed sizes. + if self.is_from_multi_size_videos: + image = tf.ensure_shape(image, (None, None, image_num_channels)) + else: + image = tf.ensure_shape(image, first_image.shape) + instances = tf.ensure_shape(instances, tf.TensorShape([None, None, 2])) skeleton_inds = tf.ensure_shape(skeleton_inds, tf.TensorShape([None])) track_inds = tf.ensure_shape(track_inds, tf.TensorShape([None])) diff --git a/sleap/nn/data/resizing.py b/sleap/nn/data/resizing.py index 3e1c4c87d..56a2c0315 100644 --- a/sleap/nn/data/resizing.py +++ b/sleap/nn/data/resizing.py @@ -58,7 +58,9 @@ def pad_to_stride(image: tf.Tensor, max_stride: int) -> tf.Tensor: paddings = tf.cast([[0, pad_bottom], [0, pad_right], [0, 0]], tf.int32) else: # tf.rank(image) == 4: - paddings = tf.cast([[0, 0], [0, pad_bottom], [0, pad_right], [0, 0]], tf.int32) + paddings = tf.cast( + [[0, 0], [0, pad_bottom], [0, pad_right], [0, 0]], tf.int32 + ) image = tf.pad(image, paddings, mode="CONSTANT", constant_values=0) return image @@ -140,7 +142,7 @@ def from_config( pad_to_stride: Optional[int] = None, keep_full_image: bool = False, full_image_key: Text = "full_image", - points_key: Optional[Text] = "instances" + points_key: Optional[Text] = "instances", ) -> "Resizer": """Build an instance of this class from its configuration options. @@ -249,6 +251,182 @@ def resize(example): return ds_output +@attr.s(auto_attribs=True) +class SizeMatcher: + """Data transformer that ensures output images have uniform shape by resizing/padding smaller images. + + Attributes: + image_key: String name of the key containing the images to resize. + scale_key: String name of the key containing the scale of the images. + points_key: String name of the key containing points to adjust for the resizing + operation. + keep_full_image: If True, keeps the (original size) full image in the examples. + This is useful for multi-scale inference. + full_image_key: String name of the key containing the full images. + max_image_height: int The target height to which all smaller images will be resized/padded to. + max_image_width: int The target width to which all smaller images will be resized/padded to. + """ + + image_key: Text = "image" + scale_key: Text = "scale" + points_key: Optional[Text] = "instances" + keep_full_image: bool = False + full_image_key: Text = "full_image" + max_image_height: int = None + max_image_width: int = None + + @classmethod + def from_config( + cls, + config: PreprocessingConfig, + provider: Optional = None, + update_config: bool = True, + image_key: Text = "image", + scale_key: Text = "scale", + keep_full_image: bool = False, + full_image_key: Text = "full_image", + points_key: Optional[Text] = "instances", + ) -> "SizeMatcher": + """Build an instance of this class from configuration. + + Args: + config: An `PreprocessingConfig` instance with the desired parameters. If + `config.resize_and_pad_to_target` is True and 'target_height' / 'target_width' are not set, provider + needs to be set that implements 'max_height_and_width'. + provider: Data provider. + update_config: If True, the input model configuration will be updated with + values inferred from other fields. + image_key: String name of the key containing the images to resize. + scale_key: String name of the key containing the scale of the images. + pad_to_stride: An integer specifying the `pad_to_stride` if + `config.pad_to_stride` is not an explicit integer (e.g., set to None). + keep_full_image: If True, keeps the (original size) full image in the + examples. This is useful for multi-scale inference. + full_image_key: String name of the key containing the full images. + points_key: String name of the key containing points to adjust for the + resizing operation. + Returns: + An instance of this class. + + Raises: + ValueError: If `provider` is not set or does not implement `max_height_and_width`. + """ + if config.resize_and_pad_to_target: + if config.target_height is not None and config.target_width is not None: + max_height = config.target_height + max_width = config.target_width + else: + try: + max_height, max_width = provider.max_height_and_width + except: + raise ValueError( + "target_height / target_width could not be determined" + ) + if update_config: + config.target_height = max_height + config.target_width = max_width + else: + max_height, max_width = None, None + + return cls( + image_key=image_key, + points_key=points_key, + scale_key=scale_key, + keep_full_image=keep_full_image, + full_image_key=full_image_key, + max_image_height=max_height, + max_image_width=max_width, + ) + + @property + def input_keys(self) -> List[Text]: + """Return the keys that incoming elements are expected to have.""" + input_keys = [self.image_key, self.scale_key] + if self.points_key is not None: + input_keys.append(self.points_key) + return input_keys + + @property + def output_keys(self) -> List[Text]: + """Return the keys that outgoing elements will have.""" + output_keys = self.input_keys + if self.keep_full_image: + output_keys.append(self.full_image_key) + return output_keys + + def transform_dataset(self, ds_input: tf.data.Dataset) -> tf.data.Dataset: + """Transform a dataset with potentially different size images into one with equal sized images. + + Args: + ds_input: A dataset with the image specified in the `image_key` attribute, + points specified in the `points_key` attribute, and the "scale" key for + tracking scaling transformations. + + Returns: + A `tf.data.Dataset` with elements containing the same images and points of equal size. + + If the `keep_full_image` attribute is True, a key specified by + `full_image_key` will be added with the to the example containing the image + before any processing. + """ + + # mapping function: match to max height width by resizing and padding bottom/right accordingly + def resize_and_pad(example): + image = example[self.image_key] + if self.keep_full_image: + example[self.full_image_key] = image + + current_shape = tf.shape(image) + + # Only apply this transform if image shape differs from target + if ( + current_shape[-3] != self.max_image_height + or current_shape[-2] != self.max_image_width + ): + # Calculate target height and width for resizing the image (no padding yet) + hratio = self.max_image_height / tf.cast(current_shape[-3], tf.float32) + wratio = self.max_image_width / tf.cast(current_shape[-2], tf.float32) + if hratio > wratio: + # The bottleneck is width, scale to fit width first then pad to height + target_height = tf.cast( + tf.cast(current_shape[-3], tf.float32) * wratio, tf.int32 + ) + target_width = self.max_image_width + example[self.scale_key] = example[self.scale_key] * wratio + else: + # The bottleneck is height, scale to fit height first then pad to width + target_height = self.max_image_height + target_width = tf.cast( + tf.cast(current_shape[-2], tf.float32) * hratio, tf.int32 + ) + example[self.scale_key] = example[self.scale_key] * hratio + # Resize the image to fill one of the dimensions by preserving aspect ratio + image = tf.image.resize_with_pad( + image, target_height=target_height, target_width=target_width + ) + # Pad the image on bottom/right with zeroes to match specified dimensions + image = tf.image.pad_to_bounding_box( + image, + offset_height=0, + offset_width=0, + target_height=self.max_image_height, + target_width=self.max_image_width, + ) + example[self.image_key] = tf.cast(image, example[self.image_key].dtype) + # Scale the instance points accordingly + if self.points_key: + example[self.points_key] = ( + example[self.points_key] * example[self.scale_key] + ) + + return example + + ds_output = ds_input.map( + resize_and_pad, num_parallel_calls=tf.data.experimental.AUTOTUNE + ) + return ds_output + + @attr.s(auto_attribs=True) class PointsRescaler: """Transformer to apply or invert scaling operations on points.""" diff --git a/sleap/nn/data/training.py b/sleap/nn/data/training.py index 1003107f4..cb6177375 100644 --- a/sleap/nn/data/training.py +++ b/sleap/nn/data/training.py @@ -7,6 +7,58 @@ from sleap.nn.data.utils import expand_to_rank, ensure_list import attr from typing import List, Text, Optional, Any, Union, Dict, Tuple, Sequence +from sklearn.model_selection import train_test_split + + +def split_labels_train_val( + labels: sleap.Labels, validation_fraction: float +) -> Tuple[sleap.Labels, List[int], sleap.Labels, List[int]]: + """Make a train/validation split from a labels dataset. + + Args: + labels: A `sleap.Labels` dataset with labeled frames. + validation_fraction: Fraction of frames to use for validation. + + Returns: + A tuple of `(labels_train, idx_train, labels_val, idx_val)`. + + `labels_train` and `labels_val` are `sleap.Label` objects containing the + selected frames for each split. Their `videos`, `tracks` and `provenance` + attributes are identical to `labels` even if the split does not contain + instances with a particular video or track. + + `idx_train` and `idx_val` are list indices of the labeled frames within the + input labels that were assigned to each split, i.e.: + + `labels[idx_train] == labels_train[:]` + + If there is only one labeled frame in `labels`, both of the labels will contain + the same frame. + + If `validation_fraction` would result in fewer than one label for either split, + it will be rounded to ensure there is at least one label in each. + """ + if len(labels) == 1: + return labels, [0], labels, [0] + + # Split indices. + n_val = round(len(labels) * validation_fraction) + n_val = max(min(n_val, len(labels) - 1), 1) + + idx_train, idx_val = train_test_split(list(range(len(labels))), test_size=n_val) + + # Create labels and keep original metadata. + labels_train = sleap.Labels(labels[idx_train]) + labels_train.videos = labels.videos + labels_train.tracks = labels.tracks + labels_train.provenance = labels.provenance + + labels_val = sleap.Labels(labels[idx_val]) + labels_val.videos = labels.videos + labels_val.tracks = labels.tracks + labels_val.provenance = labels.provenance + + return labels_train, idx_train, labels_val, idx_val def split_labels( @@ -181,6 +233,7 @@ def transform_dataset(self, ds_input: tf.data.Dataset) -> tf.data.Dataset: A dataset that generates examples with the tensors in `input_keys` mapped to keys in `output_keys` according to the structure in `key_maps`. """ + def map_keys(example): """Local processing function for dataset mapping.""" output_keys = [] diff --git a/sleap/nn/data/utils.py b/sleap/nn/data/utils.py index 08ff59542..16166f957 100644 --- a/sleap/nn/data/utils.py +++ b/sleap/nn/data/utils.py @@ -96,16 +96,29 @@ def describe_tensors( Returns: String description if `return_description` is `True`, otherwise `None`. """ + if isinstance(example, (tuple, list)): + return describe_tensors( + {f"x[{i}]": v for i, v in enumerate(example)}, + return_description=return_description, + ) + desc = [] key_length = max(len(k) for k in example.keys()) for key, val in example.items(): - dtype = str(val.dtype) if isinstance(val.dtype, np.dtype) else repr(val.dtype) - desc.append( - f"{key.rjust(key_length)}: type={type(val).__name__}, " - f"shape={val.shape}, " - f"dtype={dtype}, " - f"device={val.device if hasattr(val, 'device') else 'N/A'}" - ) + key_desc = f"{key.rjust(key_length)}: " + if isinstance(val, (tuple, list, dict)): + key_desc += describe_tensors(val, return_description=True) + else: + dtype = ( + str(val.dtype) if isinstance(val.dtype, np.dtype) else repr(val.dtype) + ) + key_desc += ( + f"type={type(val).__name__}, " + f"shape={val.shape}, " + f"dtype={dtype}, " + f"device={val.device if hasattr(val, 'device') else 'N/A'}" + ) + desc.append(key_desc) desc = "\n".join(desc) if return_description: diff --git a/sleap/nn/evals.py b/sleap/nn/evals.py index eb10edefb..a3296108c 100644 --- a/sleap/nn/evals.py +++ b/sleap/nn/evals.py @@ -32,14 +32,15 @@ CentroidsHeadConfig, CenteredInstanceConfmapsHeadConfig, MultiInstanceConfig, - MultiClassConfig, + MultiClassBottomUpConfig, + MultiClassTopDownConfig, SingleInstanceConfmapsHeadConfig, ) from sleap.nn.model import Model from sleap.nn.data.pipelines import LabelsReader from sleap.nn.inference import ( - TopdownPredictor, - BottomupPredictor, + TopDownPredictor, + BottomUpPredictor, BottomUpMultiClassPredictor, SingleInstancePredictor, ) @@ -673,31 +674,39 @@ def evaluate_model( # Setup predictor for evaluation. head_config = cfg.model.heads.which_oneof() if isinstance(head_config, CentroidsHeadConfig): - predictor = TopdownPredictor( + predictor = TopDownPredictor( centroid_config=cfg, centroid_model=model, confmap_config=None, confmap_model=None, ) elif isinstance(head_config, CenteredInstanceConfmapsHeadConfig): - predictor = TopdownPredictor( + predictor = TopDownPredictor( centroid_config=None, centroid_model=None, confmap_config=cfg, confmap_model=model, ) elif isinstance(head_config, MultiInstanceConfig): - predictor = sleap.nn.inference.BottomupPredictor( + predictor = sleap.nn.inference.BottomUpPredictor( bottomup_config=cfg, bottomup_model=model ) - elif isinstance(head_config, MultiClassConfig): - predictor = sleap.nn.inference.BottomUpMultiClassPredictor( - config=cfg, model=model - ) elif isinstance(head_config, SingleInstanceConfmapsHeadConfig): predictor = sleap.nn.inference.SingleInstancePredictor( confmap_config=cfg, confmap_model=model ) + elif isinstance(head_config, MultiClassBottomUpConfig): + predictor = sleap.nn.inference.BottomUpMultiClassPredictor( + config=cfg, + model=model, + ) + elif isinstance(head_config, MultiClassTopDownConfig): + predictor = sleap.nn.inference.TopDownMultiClassPredictor( + centroid_config=None, + centroid_model=None, + confmap_config=cfg, + confmap_model=model, + ) else: raise ValueError("Unrecognized model type:", head_config) @@ -720,7 +729,9 @@ def evaluate_model( logger.info("Saved predictions: %s", labels_pr_path) if metrics is not None: - metrics_path = os.path.join(cfg.outputs.run_path, f"metrics.{split_name}.npz") + metrics_path = os.path.join( + cfg.outputs.run_path, f"metrics.{split_name}.npz" + ) np.savez_compressed(metrics_path, **{"metrics": metrics}) logger.info("Saved metrics: %s", metrics_path) diff --git a/sleap/nn/heads.py b/sleap/nn/heads.py index 05804498c..5c463cf70 100644 --- a/sleap/nn/heads.py +++ b/sleap/nn/heads.py @@ -1,7 +1,9 @@ """Model head definitions for defining model output types.""" +import tensorflow as tf import attr from typing import Optional, Text, List, Sequence, Tuple, Union +from abc import ABC, abstractmethod from sleap.nn.config import ( CentroidsHeadConfig, @@ -10,12 +12,67 @@ MultiInstanceConfmapsHeadConfig, PartAffinityFieldsHeadConfig, ClassMapsHeadConfig, + ClassVectorsHeadConfig, ) @attr.s(auto_attribs=True) -class SingleInstanceConfmapsHead: - """Head for specifying single instance confidence maps.""" +class Head(ABC): + """Base class for model output heads.""" + + output_stride: int = 1 + loss_weight: float = 1.0 + + @property + @abstractmethod + def channels(self) -> int: + """Return the number of channels in the tensor output by this head.""" + pass + + @property + def activation(self) -> str: + """Return the activation function of the head output layer.""" + return "linear" + + @property + def loss_function(self) -> str: + """Return the name of the loss function to use for this head.""" + return "mse" + + def make_head(self, x_in: tf.Tensor, name: Optional[Text] = None) -> tf.Tensor: + """Make head output tensor from input feature tensor. + + Args: + x_in: An input `tf.Tensor`. + name: If provided, specifies the name of the output layer. If not (the + default), uses the name of the head as the layer name. + + Returns: + A `tf.Tensor` with the correct shape for the head. + """ + if name is None: + name = f"{type(self).__name__}" + return tf.keras.layers.Conv2D( + filters=self.channels, + kernel_size=1, + strides=1, + padding="same", + activation=self.activation, + name=name, + )(x_in) + + +@attr.s(auto_attribs=True) +class SingleInstanceConfmapsHead(Head): + """Head for specifying single instance confidence maps. + + Attributes: + part_names: List of strings specifying the part names associated with channels. + sigma: Spread of the confidence maps. + output_stride: Stride of the output head tensor. The input tensor is expected to + be at the same stride. + loss_weight: Weight of the loss term for this head during optimization. + """ part_names: List[Text] sigma: float = 5.0 @@ -52,12 +109,22 @@ def from_config( part_names=part_names, sigma=config.sigma, output_stride=config.output_stride, + loss_weight=config.loss_weight, ) @attr.s(auto_attribs=True) -class CentroidConfmapsHead: - """Head for specifying instance centroid confidence maps.""" +class CentroidConfmapsHead(Head): + """Head for specifying instance centroid confidence maps. + + Attributes: + anchor_part: Name of the part to use as an anchor node. If not specified, the + bounding box centroid will be used. + sigma: Spread of the confidence maps. + output_stride: Stride of the output head tensor. The input tensor is expected to + be at the same stride. + loss_weight: Weight of the loss term for this head during optimization. + """ anchor_part: Optional[Text] = None sigma: float = 5.0 @@ -83,12 +150,23 @@ def from_config(cls, config: CentroidsHeadConfig) -> "CentroidConfmapsHead": anchor_part=config.anchor_part, sigma=config.sigma, output_stride=config.output_stride, + loss_weight=config.loss_weight, ) @attr.s(auto_attribs=True) -class CenteredInstanceConfmapsHead: - """Head for specifying centered instance confidence maps.""" +class CenteredInstanceConfmapsHead(Head): + """Head for specifying centered instance confidence maps. + + Attributes: + part_names: List of strings specifying the part names associated with channels. + anchor_part: Name of the part to use as an anchor node. If not specified, the + bounding box centroid will be used. + sigma: Spread of the confidence maps. + output_stride: Stride of the output head tensor. The input tensor is expected to + be at the same stride. + loss_weight: Weight of the loss term for this head during optimization. + """ part_names: List[Text] anchor_part: Optional[Text] = None @@ -127,12 +205,21 @@ def from_config( anchor_part=config.anchor_part, sigma=config.sigma, output_stride=config.output_stride, + loss_weight=config.loss_weight, ) @attr.s(auto_attribs=True) -class MultiInstanceConfmapsHead: - """Head for specifying multi-instance confidence maps.""" +class MultiInstanceConfmapsHead(Head): + """Head for specifying multi-instance confidence maps. + + Attributes: + part_names: List of strings specifying the part names associated with channels. + sigma: Spread of the confidence maps. + output_stride: Stride of the output head tensor. The input tensor is expected to + be at the same stride. + loss_weight: Weight of the loss term for this head during optimization. + """ part_names: List[Text] sigma: float = 5.0 @@ -174,8 +261,16 @@ def from_config( @attr.s(auto_attribs=True) -class PartAffinityFieldsHead: - """Head for specifying multi-instance part affinity fields.""" +class PartAffinityFieldsHead(Head): + """Head for specifying multi-instance part affinity fields. + + Attributes: + edges: List of tuples of `(source, destination)` node names. + sigma: Spread of the part affinity fields. + output_stride: Stride of the output head tensor. The input tensor is expected to + be at the same stride. + loss_weight: Weight of the loss term for this head during optimization. + """ edges: Sequence[Tuple[Text, Text]] sigma: float = 15.0 @@ -216,8 +311,16 @@ def from_config( @attr.s(auto_attribs=True) -class ClassMapsHead: - """Head for specifying class identity maps.""" +class ClassMapsHead(Head): + """Head for specifying class identity maps. + + Attributes: + classes: List of string names of the classes. + sigma: Spread of the class maps around each node. + output_stride: Stride of the output head tensor. The input tensor is expected to + be at the same stride. + loss_weight: Weight of the loss term for this head during optimization. + """ classes: List[Text] sigma: float = 5.0 @@ -229,6 +332,11 @@ def channels(self) -> int: """Return the number of channels in the tensor output by this head.""" return len(self.classes) + @property + def activation(self) -> str: + """Return the activation function of the head output layer.""" + return "sigmoid" + @classmethod def from_config( cls, @@ -256,6 +364,102 @@ def from_config( ) +@attr.s(auto_attribs=True) +class ClassVectorsHead(Head): + """Head for specifying classification heads. + + Attributes: + classes: List of string names of the classes. + num_fc_layers: Number of fully connected layers after flattening input features. + num_fc_units: Number of units (dimensions) in fully connected layers prior to + classification output. + output_stride: Stride of the output head tensor. The input tensor is expected to + be at the same stride. + loss_weight: Weight of the loss term for this head during optimization. + """ + + classes: List[Text] + num_fc_layers: int = 1 + num_fc_units: int = 64 + global_pool: bool = True + output_stride: int = 1 + loss_weight: float = 1.0 + + @property + def channels(self) -> int: + """Return the number of channels in the tensor output by this head.""" + return len(self.classes) + + @property + def activation(self) -> str: + """Return the activation function of the head output layer.""" + return "softmax" + + @property + def loss_function(self) -> str: + """Return the name of the loss function to use for this head.""" + return "categorical_crossentropy" + + @classmethod + def from_config( + cls, + config: ClassVectorsHeadConfig, + classes: Optional[List[Text]] = None, + ) -> "ClassVectorsHead": + """Create this head from a set of configurations. + + Attributes: + config: A `ClassVectorsHeadConfig` instance specifying the head parameters. + classes: List of string names of the classes that this head will predict. + This must be set if the `classes` attribute of the configuration is not + set. + + Returns: + The instantiated head with the specified configuration options. + """ + if config.classes is not None: + classes = config.classes + return cls( + classes=classes, + num_fc_layers=config.num_fc_layers, + num_fc_units=config.num_fc_units, + global_pool=config.global_pool, + output_stride=config.output_stride, + loss_weight=config.loss_weight, + ) + + def make_head(self, x_in: tf.Tensor, name: Optional[Text] = None) -> tf.Tensor: + """Make head output tensor from input feature tensor. + + Args: + x_in: An input `tf.Tensor`. + name: If provided, specifies the name of the output layer. If not (the + default), uses the name of the head as the layer name. + + Returns: + A `tf.Tensor` with the correct shape for the head. + """ + if name is None: + name = f"{type(self).__name__}" + x = x_in + if self.global_pool: + x = tf.keras.layers.GlobalMaxPool2D(name="pre_classification_global_pool")( + x + ) + x = tf.keras.layers.Flatten(name="pre_classification_flatten")(x) + for i in range(self.num_fc_layers): + x = tf.keras.layers.Dense( + self.num_fc_units, name=f"pre_classification{i}_fc" + )(x) + x = tf.keras.layers.Activation("relu", name=f"pre_classification{i}_relu")( + x + ) + x = tf.keras.layers.Dense(self.channels, activation=self.activation, name=name)( + x + ) + return x + + ConfmapConfig = Union[ CentroidsHeadConfig, SingleInstanceConfmapsHeadConfig, @@ -265,12 +469,21 @@ def from_config( @attr.s(auto_attribs=True) -class OffsetRefinementHead: - """Head for specifying offset refinement maps.""" +class OffsetRefinementHead(Head): + """Head for specifying offset refinement maps. + + Attributes: + part_names: List of strings specifying the part names associated with channels. + sigma_threshold: Threshold of confidence map values to use for defining the + boundary of the offset maps. + output_stride: Stride of the output head tensor. The input tensor is expected to + be at the same stride. + loss_weight: Weight of the loss term for this head during optimization. + """ part_names: List[Text] - output_stride: int = 1 sigma_threshold: float = 0.2 + output_stride: int = 1 loss_weight: float = 1.0 @property @@ -308,7 +521,7 @@ def from_config( part_names = [config.anchor_part] return cls( part_names=part_names, - output_stride=config.output_stride, sigma_threshold=sigma_threshold, + output_stride=config.output_stride, loss_weight=loss_weight, ) diff --git a/sleap/nn/identity.py b/sleap/nn/identity.py index 7d7b36248..61ebcef21 100644 --- a/sleap/nn/identity.py +++ b/sleap/nn/identity.py @@ -46,7 +46,7 @@ def group_class_peaks( algorithm. Peaks that are assigned to classes that are not the highest probability for each class are removed from the matches. - See also: classify_peaks + See also: classify_peaks_from_maps, classify_peaks_from_vectors """ peak_sample_inds = tf.cast(peak_sample_inds, tf.int32) peak_channel_inds = tf.cast(peak_channel_inds, tf.int32) @@ -94,7 +94,7 @@ def group_class_peaks( return peak_inds, class_inds -def classify_peaks( +def classify_peaks_from_maps( class_maps: tf.Tensor, peak_points: tf.Tensor, peak_vals: tf.Tensor, @@ -177,3 +177,74 @@ def classify_peaks( ) return points, point_vals, class_probs + + +def classify_peaks_from_vectors( + peak_points: tf.Tensor, + peak_vals: tf.Tensor, + peak_class_probs: tf.Tensor, + crop_sample_inds: tf.Tensor, + n_samples: int, +) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: + """Group peaks by classification probabilities. + + This is used in top-down classification models. + + Args: + peak_points: + peak_vals: + peak_class_probs: + crop_sample_inds: + n_samples: Number of samples in the batch. + + Returns: + A tuple of `(points, point_vals, class_probs)`. + + `points`: Class-grouped peaks as a `tf.Tensor` of dtype `tf.float32` and shape + `(n_samples, n_classes, n_channels, 2)`. Missing points will be denoted by + NaNs. + + `point_vals`: The confidence map values for each point as a `tf.Tensor` of dtype + `tf.float32` and shape `(n_samples, n_classes, n_channels)`. + + `class_probs`: Classification probabilities for matched points as a `tf.Tensor` + of dtype `tf.float32` and shape `(n_samples, n_classes, n_channels)`. + """ + crop_sample_inds = tf.cast(crop_sample_inds, tf.int32) + n_samples = tf.cast(n_samples, tf.int32) + n_channels = tf.shape(peak_points)[1] + n_instances = tf.shape(peak_class_probs)[1] + + peak_inds, class_inds = group_class_peaks( + peak_class_probs, + crop_sample_inds, + tf.zeros_like(crop_sample_inds), + n_samples, + 1, + ) + + # Assign the results to fixed size tensors. + subs = tf.stack( + [ + tf.gather(crop_sample_inds, peak_inds), + class_inds, + ], + axis=1, + ) + points = tf.tensor_scatter_nd_update( + tf.fill([n_samples, n_instances, n_channels, 2], np.nan), + subs, + tf.gather(peak_points, peak_inds), + ) + point_vals = tf.tensor_scatter_nd_update( + tf.fill([n_samples, n_instances, n_channels], np.nan), + subs, + tf.gather(peak_vals, peak_inds), + ) + class_probs = tf.tensor_scatter_nd_update( + tf.fill([n_samples, n_instances], np.nan), + subs, + tf.gather_nd(peak_class_probs, tf.stack([peak_inds, class_inds], axis=1)), + ) + + return points, point_vals, class_probs diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 0c3f0ad26..d0b006361 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -22,23 +22,32 @@ """ import attr +import argparse import logging import warnings import os -import time +import tempfile +import platform +import shutil +import atexit +import rich.progress +from collections import deque +import json +from time import time +from datetime import datetime +from pathlib import Path + from abc import ABC, abstractmethod -from typing import Text, Optional, List, Dict, Union, Iterator +from typing import Text, Optional, List, Dict, Union, Iterator, Tuple import tensorflow as tf import numpy as np import sleap -from sleap import util from sleap.nn.config import TrainingJobConfig from sleap.nn.model import Model -from sleap.nn.tracking import Tracker, run_tracker +from sleap.nn.tracking import Tracker from sleap.nn.paf_grouping import PAFScorer -from sleap.nn.data.grouping import group_examples_iter from sleap.nn.data.pipelines import ( Provider, Pipeline, @@ -47,71 +56,165 @@ Normalizer, Resizer, Prefetcher, - LambdaFilter, KerasModelPredictor, - LocalPeakFinder, - PredictedInstanceCropper, - InstanceCentroidFinder, - InstanceCropper, - GlobalPeakFinder, - MockGlobalPeakFinder, - KeyFilter, - KeyRenamer, - KeyDeviceMover, - PredictedCenterInstanceNormalizer, - PointsRescaler, ) logger = logging.getLogger(__name__) -def safely_generate(ds: tf.data.Dataset, progress: bool = True): - """Yields examples from dataset, catching and logging exceptions.""" - # Unsafe generating: - # for example in ds: - # yield example - - ds_iter = iter(ds) - - i = 0 - wall_t0 = time.time() - done = False - while not done: - try: - next_val = next(ds_iter) - yield next_val - except StopIteration: - done = True - except Exception as e: - logger.info(f"ERROR in sample index {i}") - logger.info(e) - logger.info("") - finally: - if not done: - i += 1 - - # Show the current progress (frames, time, fps) - if progress: - if (i and i % 1000 == 0) or done: - elapsed_time = time.time() - wall_t0 - logger.info( - f"Finished {i} examples in {elapsed_time:.2f} seconds " - "(inference + postprocessing)" - ) - if elapsed_time: - logger.info(f"examples/s = {i/elapsed_time}") +def get_keras_model_path(path: Text) -> str: + """Utility method for finding the path to a saved Keras model. + Args: + path: Path to a model run folder or job file. -def get_keras_model_path(path: Text) -> Text: + Returns: + Path to `best_model.h5` in the run folder. + """ + # TODO: Move this to TrainingJobConfig or Model? if path.endswith(".json"): path = os.path.dirname(path) return os.path.join(path, "best_model.h5") +class RateColumn(rich.progress.ProgressColumn): + """Renders the progress rate.""" + + def render(self, task: "Task") -> rich.progress.Text: + """Show progress rate.""" + speed = task.speed + if speed is None: + return rich.progress.Text("?", style="progress.data.speed") + return rich.progress.Text(f"{speed:.1f} FPS", style="progress.data.speed") + + @attr.s(auto_attribs=True) class Predictor(ABC): """Base interface class for predictors.""" + verbosity: str = attr.ib( + validator=attr.validators.in_(["none", "rich", "json"]), + default="rich", + kw_only=True, + ) + report_rate: float = attr.ib(default=2.0, kw_only=True) + model_paths: List[str] = attr.ib(factory=list, kw_only=True) + + @property + def report_period(self) -> float: + """Time between progress reports in seconds.""" + return 1.0 / self.report_rate + + @classmethod + def from_model_paths( + cls, + model_paths: Union[str, List[str]], + peak_threshold: float = 0.2, + integral_refinement: bool = True, + integral_patch_size: int = 5, + batch_size: int = 4, + ) -> "Predictor": + """Create the appropriate `Predictor` subclass from a list of model paths. + + Args: + model_paths: A single or list of trained model paths. + peak_threshold: Minimum confidence map value to consider a peak as valid. + integral_refinement: If `True`, peaks will be refined with integral + regression. If `False`, `"local"`, peaks will be refined with quarter + pixel local gradient offset. This has no effect if the model has an + offset regression head. + integral_patch_size: Size of patches to crop around each rough peak for + integral refinement as an integer scalar. + batch_size: The default batch size to use when loading data for inference. + Higher values increase inference speed at the cost of higher memory + usage. + + Returns: + A subclass of `Predictor`. + + See also: `SingleInstancePredictor`, `TopDownPredictor`, `BottomUpPredictor` + """ + # Read configs and find model types. + if isinstance(model_paths, str): + model_paths = [model_paths] + model_configs = [sleap.load_config(model_path) for model_path in model_paths] + model_paths = [cfg.filename for cfg in model_configs] + model_types = [ + cfg.model.heads.which_oneof_attrib_name() for cfg in model_configs + ] + + if "single_instance" in model_types: + predictor = SingleInstancePredictor.from_trained_models( + model_path=model_paths[model_types.index("single_instance")], + peak_threshold=peak_threshold, + integral_refinement=integral_refinement, + integral_patch_size=integral_patch_size, + batch_size=batch_size, + ) + + elif ( + "centroid" in model_types + or "centered_instance" in model_types + or "multi_class_topdown" in model_types + ): + centroid_model_path = None + if "centroid" in model_types: + centroid_model_path = model_paths[model_types.index("centroid")] + + confmap_model_path = None + if "centered_instance" in model_types: + confmap_model_path = model_paths[model_types.index("centered_instance")] + + td_multiclass_model_path = None + if "multi_class_topdown" in model_types: + td_multiclass_model_path = model_paths[ + model_types.index("multi_class_topdown") + ] + + if td_multiclass_model_path is not None: + predictor = TopDownMultiClassPredictor.from_trained_models( + centroid_model_path=centroid_model_path, + confmap_model_path=td_multiclass_model_path, + batch_size=batch_size, + peak_threshold=peak_threshold, + integral_refinement=integral_refinement, + integral_patch_size=integral_patch_size, + ) + else: + predictor = TopDownPredictor.from_trained_models( + centroid_model_path=centroid_model_path, + confmap_model_path=confmap_model_path, + batch_size=batch_size, + peak_threshold=peak_threshold, + integral_refinement=integral_refinement, + integral_patch_size=integral_patch_size, + ) + + elif "multi_instance" in model_types: + predictor = BottomUpPredictor.from_trained_models( + model_path=model_paths[model_types.index("multi_instance")], + peak_threshold=peak_threshold, + integral_refinement=integral_refinement, + integral_patch_size=integral_patch_size, + batch_size=batch_size, + ) + + elif "multi_class_bottomup" in model_types: + predictor = BottomUpMultiClassPredictor.from_trained_models( + model_path=model_paths[model_types.index("multi_class_bottomup")], + peak_threshold=peak_threshold, + integral_refinement=integral_refinement, + integral_patch_size=integral_patch_size, + batch_size=batch_size, + ) + + else: + raise ValueError( + "Could not create predictor from model paths:" + "\n".join(model_paths) + ) + predictor.model_paths = model_paths + return predictor + @classmethod @abstractmethod def from_trained_models(cls, *args, **kwargs): @@ -119,14 +222,11 @@ def from_trained_models(cls, *args, **kwargs): def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline: """Make a data loading pipeline. - Args: data_provider: If not `None`, the pipeline will be created with an instance of a `sleap.pipelines.Provider`. - Returns: The created `sleap.pipelines.Pipeline` with batching and prefetching. - Notes: This method also updates the class attribute for the pipeline and will be called automatically when predicting on data from a new source. @@ -174,15 +274,12 @@ def _predict_generator( # Update the data provider source. self.pipeline.providers = [data_provider] - # Loop over data batches. - for ex in self.pipeline.make_dataset(): + def process_batch(ex): # Run inference on current batch. - # preds = self.inference_model.predict(ex) preds = self.inference_model.predict_on_batch(ex) + + # Add model outputs to the input data example. ex.update(preds) - # ex["instance_peaks"] = preds["instance_peaks"] - # ex["instance_peak_vals"] = preds["instance_peak_vals"] - # ex["instance_scores"] = preds["instance_scores"] # Convert to numpy arrays if not already. if isinstance(ex["video_ind"], tf.Tensor): @@ -190,7 +287,82 @@ def _predict_generator( if isinstance(ex["frame_ind"], tf.Tensor): ex["frame_ind"] = ex["frame_ind"].numpy().flatten() - yield ex + return ex + + # Loop over data batches with optional progress reporting. + if self.verbosity == "rich": + with rich.progress.Progress( + "{task.description}", + rich.progress.BarColumn(), + "[progress.percentage]{task.percentage:>3.0f}%", + "ETA:", + rich.progress.TimeRemainingColumn(), + RateColumn(), + auto_refresh=False, + refresh_per_second=self.report_rate, + speed_estimate_period=5, + ) as progress: + task = progress.add_task("Predicting...", total=len(data_provider)) + last_report = time() + for ex in self.pipeline.make_dataset(): + ex = process_batch(ex) + progress.update(task, advance=len(ex["frame_ind"])) + + # Handle refreshing manually to support notebooks. + elapsed_since_last_report = time() - last_report + if elapsed_since_last_report > self.report_period: + progress.refresh() + + # Return results. + yield ex + + elif self.verbosity == "json": + n_processed = 0 + n_total = len(data_provider) + n_recent = deque(maxlen=30) + elapsed_recent = deque(maxlen=30) + last_report = time() + t0_all = time() + t0_batch = time() + for ex in self.pipeline.make_dataset(): + # Process batch of examples. + ex = process_batch(ex) + + # Track timing and progress. + elapsed_batch = time() - t0_batch + t0_batch = time() + n_batch = len(ex["frame_ind"]) + n_processed += n_batch + elapsed_all = time() - t0_all + + # Compute recent rate. + n_recent.append(n_batch) + elapsed_recent.append(elapsed_batch) + rate = sum(n_recent) / sum(elapsed_recent) + eta = (n_total - n_processed) / rate + + # Report. + elapsed_since_last_report = time() - last_report + if elapsed_since_last_report > self.report_period: + print( + json.dumps( + { + "n_processed": n_processed, + "n_total": n_total, + "elapsed": elapsed_all, + "rate": rate, + "eta": eta, + } + ), + flush=True, + ) + last_report = time() + + # Return results. + yield ex + else: + for ex in self.pipeline.make_dataset(): + yield process_batch(ex) def predict( self, data: Union[Provider, sleap.Labels, sleap.Video], make_labels: bool = True @@ -230,59 +402,7 @@ def predict( return list(generator) -@attr.s(auto_attribs=True) -class MockPredictor(Predictor): - labels: sleap.Labels - - @classmethod - def from_trained_models(cls, labels_path: Text): - labels = sleap.Labels.load_file(labels_path) - return cls(labels=labels) - - def make_pipeline(self): - pass - - def predict(self, data_provider: Provider): - - prediction_video = None - - # Try to match specified video by its full path - prediction_video_path = os.path.abspath(data_provider.video.filename) - for video in self.labels.videos: - if os.path.abspath(video.filename) == prediction_video_path: - prediction_video = video - break - - if prediction_video is None: - # Try to match on filename (without path) - prediction_video_path = os.path.basename(data_provider.video.filename) - for video in self.labels.videos: - if os.path.basename(video.filename) == prediction_video_path: - prediction_video = video - break - - if prediction_video is None: - # Default to first video in labels file - prediction_video = self.labels.videos[0] - - # Get specified frames from labels file (or use None for all frames) - frame_idx_list = ( - list(data_provider.example_indices) - if data_provider.example_indices - else None - ) - - frames = self.labels.find(video=prediction_video, frame_idx=frame_idx_list) - - # Run tracker as specified - if self.tracker: - frames = run_tracker(tracker=self.tracker, frames=frames) - self.tracker.final_pass(frames) - - # Return frames (there are no "raw" predictions we could return) - return frames - - +# TODO: Rewrite this class. @attr.s(auto_attribs=True) class VisualPredictor(Predictor): """Predictor class for generating the visual output of model.""" @@ -352,6 +472,42 @@ def make_pipeline(self): self.pipeline = pipeline + def safely_generate(self, ds: tf.data.Dataset, progress: bool = True): + """Yields examples from dataset, catching and logging exceptions.""" + # Unsafe generating: + # for example in ds: + # yield example + + ds_iter = iter(ds) + + i = 0 + wall_t0 = time() + done = False + while not done: + try: + next_val = next(ds_iter) + yield next_val + except StopIteration: + done = True + except Exception as e: + logger.info(f"ERROR in sample index {i}") + logger.info(e) + logger.info("") + finally: + if not done: + i += 1 + + # Show the current progress (frames, time, fps) + if progress: + if (i and i % 1000 == 0) or done: + elapsed_time = time() - wall_t0 + logger.info( + f"Finished {i} examples in {elapsed_time:.2f} seconds " + "(inference + postprocessing)" + ) + if elapsed_time: + logger.info(f"examples/s = {i/elapsed_time}") + def predict_generator(self, data_provider: Provider): if self.pipeline is None: # Pass in data provider when mocking one of the models. @@ -360,7 +516,7 @@ def predict_generator(self, data_provider: Provider): self.pipeline.providers = [data_provider] # Yield each example from dataset, catching and logging exceptions - return safely_generate(self.pipeline.make_dataset()) + return self.safely_generate(self.pipeline.make_dataset()) def predict(self, data_provider: Provider): generator = self.predict_generator(data_provider) @@ -1014,7 +1170,6 @@ def from_trained_models( obj._initialize_inference_model() return obj - def _make_labeled_frames_from_generator( self, generator: Iterator[Dict[str, np.ndarray]], data_provider: Provider ) -> List[sleap.LabeledFrame]: @@ -1580,7 +1735,7 @@ def call( @attr.s(auto_attribs=True) -class TopdownPredictor(Predictor): +class TopDownPredictor(Predictor): """Top-down multi-instance predictor. This high-level class handles initialization, preprocessing and tracking using a @@ -1686,7 +1841,7 @@ def from_trained_models( peak_threshold: float = 0.2, integral_refinement: bool = True, integral_patch_size: int = 5, - ) -> "TopdownPredictor": + ) -> "TopDownPredictor": """Create predictor from saved models. Args: @@ -1709,7 +1864,7 @@ def from_trained_models( integral refinement as an integer scalar. Returns: - An instance of `TopdownPredictor` with the loaded models. + An instance of `TopDownPredictor` with the loaded models. One of the two models can be left as `None` to perform inference with ground truth data. This will only work with `LabelsReader` as the provider. @@ -2176,7 +2331,7 @@ def call(self, example): @attr.s(auto_attribs=True) -class BottomupPredictor(Predictor): +class BottomUpPredictor(Predictor): """Bottom-up multi-instance predictor. This high-level class handles initialization, preprocessing and tracking using a @@ -2253,7 +2408,7 @@ def from_trained_models( peak_threshold: float = 0.2, integral_refinement: bool = True, integral_patch_size: int = 5, - ) -> "BottomupPredictor": + ) -> "BottomUpPredictor": """Create predictor from a saved model. Args: @@ -2273,7 +2428,7 @@ def from_trained_models( integral refinement as an integer scalar. Returns: - An instance of `BottomupPredictor` with the loaded model. + An instance of `BottomUpPredictor` with the loaded model. """ # Load bottomup model. bottomup_config = TrainingJobConfig.load_json(model_path) @@ -2584,7 +2739,7 @@ def call(self, data): predicted_instances, predicted_peak_scores, predicted_instance_scores, - ) = sleap.nn.identity.classify_peaks( + ) = sleap.nn.identity.classify_peaks_from_maps( class_maps, peaks, peak_vals, @@ -2746,7 +2901,7 @@ def from_trained_models( integral refinement as an integer scalar. Returns: - An instance of `BottomupPredictor` with the loaded model. + An instance of `BottomUpPredictor` with the loaded model. """ # Load bottomup model. config = TrainingJobConfig.load_json(model_path) @@ -2791,8 +2946,11 @@ def _make_labeled_frames_from_generator( if tracks is None: if hasattr(data_provider, "tracks"): tracks = data_provider.tracks - elif self.config.model.heads.multi_class.class_maps.classes is not None: - names = self.config.model.heads.multi_class.class_maps.classes + elif ( + self.config.model.heads.multi_class_bottomup.class_maps.classes + is not None + ): + names = self.config.model.heads.multi_class_bottomup.class_maps.classes tracks = [sleap.Track(name=n, spawned_on=0) for n in names] # Loop over batches. @@ -2840,307 +2998,630 @@ def _make_labeled_frames_from_generator( return predicted_frames -CLI_PREDICTORS = { - "topdown": TopdownPredictor, - "bottomup": BottomupPredictor, - "bottomup_multiclass": BottomUpMultiClassPredictor, - "single": SingleInstancePredictor, -} - - -def make_cli_parser(): - import argparse - from sleap.util import frame_list - - parser = argparse.ArgumentParser() - - # Add args for entire pipeline - parser.add_argument( - "video_path", type=str, nargs="?", default="", help="Path to video file" - ) - parser.add_argument( - "-m", - "--model", - dest="models", - action="append", - help="Path to trained model directory (with training_config.json). " - "Multiple models can be specified, each preceded by --model.", - ) - - parser.add_argument( - "--frames", - type=frame_list, - default="", - help="List of frames to predict. Either comma separated list (e.g. 1,2,3) or " - "a range separated by hyphen (e.g. 1-3, for 1,2,3). (default is entire video)", - ) - parser.add_argument( - "--only-labeled-frames", - action="store_true", - default=False, - help="Only run inference on labeled frames (when running on labels dataset file).", - ) - parser.add_argument( - "--only-suggested-frames", - action="store_true", - default=False, - help="Only run inference on suggested frames (when running on labels dataset file).", - ) - parser.add_argument( - "-o", - "--output", - type=str, - default=None, - help="The output filename to use for the predicted data.", - ) - parser.add_argument( - "--labels", - type=str, - default=None, - help="Path to labels dataset file (for inference on multiple videos or for re-tracking pre-existing predictions).", - ) - - # TODO: better video parameters +class TopDownMultiClassFindPeaks(InferenceLayer): + """Keras layer that predicts and classifies peaks from images using a trained model. - parser.add_argument( - "--video.dataset", type=str, default="", help="The dataset for HDF5 videos." - ) + This layer encapsulates all of the inference operations required for generating + predictions from a centered instance confidence map and multi-class model. This + includes preprocessing, model forward pass, peak finding, coordinate adjustment, and + classification. - parser.add_argument( - "--video.input_format", - type=str, - default="", - help="The input_format for HDF5 videos.", - ) + Attributes: + keras_model: A `tf.keras.Model` that accepts rank-4 images as input and predicts + rank-4 confidence maps as output. This should be a model that is trained on + centered instance confidence maps and classification. + input_scale: Float indicating if the images should be resized before being + passed to the model. + output_stride: Output stride of the model, denoting the scale of the output + confidence maps relative to the images (after input scaling). This is used + for adjusting the peak coordinates to the image grid. This will be inferred + from the model shapes if not provided. + peak_threshold: Minimum confidence map value to consider a global peak as valid. + refinement: If `None`, returns the grid-aligned peaks with no refinement. If + `"integral"`, peaks will be refined with integral regression. If `"local"`, + peaks will be refined with quarter pixel local gradient offset. This has no + effect if the model has an offset regression head. + integral_patch_size: Size of patches to crop around each rough peak for integral + refinement as an integer scalar. + return_confmaps: If `True`, the confidence maps will be returned together with + the predicted peaks. This will result in slower inference times since the + data must be copied off of the GPU, but is useful for visualizing the raw + output of the model. + return_class_vectors: If `True`, the classification probabilities will be + returned together with the predicted peaks. This will not line up with the + grouped instances, for which the associtated class probabilities will always + be returned in `"instance_scores"`. + confmaps_ind: Index of the output tensor of the model corresponding to + confidence maps. If `None` (the default), this will be detected + automatically by searching for the first tensor that contains + `"CenteredInstanceConfmapsHead"` in its name. + offsets_ind: Index of the output tensor of the model corresponding to + offset regression maps. If `None` (the default), this will be detected + automatically by searching for the first tensor that contains + `"OffsetRefinementHead"` in its name. If the head is not present, the method + specified in the `refinement` attribute will be used. + class_vectors_ind: Index of the output tensor of the model corresponding to the + classification vectors. If `None` (the default), this will be detected + automatically by searching for the first tensor that contains + `"ClassVectorsHead"` in its name. + """ - device_group = parser.add_mutually_exclusive_group(required=False) - device_group.add_argument( - "--cpu", - action="store_true", - help="Run inference only on CPU. If not specified, will use available GPU.", - ) - device_group.add_argument( - "--first-gpu", - action="store_true", - help="Run inference on the first GPU, if available.", - ) - device_group.add_argument( - "--last-gpu", - action="store_true", - help="Run inference on the last GPU, if available.", - ) - device_group.add_argument( - "--gpu", type=int, default=0, help="Run inference on the i-th GPU specified." - ) + def __init__( + self, + keras_model: tf.keras.Model, + input_scale: float = 1.0, + output_stride: Optional[int] = None, + peak_threshold: float = 0.2, + refinement: Optional[str] = "local", + integral_patch_size: int = 5, + return_confmaps: bool = False, + return_class_vectors: bool = False, + confmaps_ind: Optional[int] = None, + offsets_ind: Optional[int] = None, + class_vectors_ind: Optional[int] = None, + **kwargs, + ): + super().__init__( + keras_model=keras_model, input_scale=input_scale, pad_to_stride=1, **kwargs + ) + self.peak_threshold = peak_threshold + self.refinement = refinement + self.integral_patch_size = integral_patch_size + self.return_confmaps = return_confmaps + self.return_class_vectors = return_class_vectors + self.confmaps_ind = confmaps_ind + self.class_vectors_ind = class_vectors_ind + self.offsets_ind = offsets_ind - # Add args for each predictor class - for predictor_name, predictor_class in CLI_PREDICTORS.items(): - if "peak_threshold" in attr.fields_dict(predictor_class): - # get the default value to show in help string, although we'll - # use None as default so that unspecified vals won't be passed to - # builder. - default_val = attr.fields_dict(predictor_class)["peak_threshold"].default - - parser.add_argument( - f"--{predictor_name}.peak_threshold", - type=float, - default=None, - help=f"Threshold to use when finding peaks in {predictor_class.__name__} (default: {default_val}).", + if self.confmaps_ind is None: + self.confmaps_ind = find_head( + self.keras_model, "CenteredInstanceConfmapsHead" + ) + if self.confmaps_ind is None: + raise ValueError( + "Index of the confidence maps output tensor must be specified if not " + "named 'CenteredInstanceConfmapsHead'." ) - if "batch_size" in attr.fields_dict(predictor_class): - default_val = attr.fields_dict(predictor_class)["batch_size"].default - parser.add_argument( - f"--{predictor_name}.batch_size", - type=int, - default=4, - help=f"Batch size to use for model inference in {predictor_class.__name__} (default: {default_val}).", + if self.class_vectors_ind is None: + self.class_vectors_ind = find_head(self.keras_model, "ClassVectorsHead") + if self.class_vectors_ind is None: + raise ValueError( + "Index of the classifications output tensor must be specified if not " + "named 'ClassVectorsHead'." ) - # Add args for tracking - Tracker.add_cli_parser_args(parser, arg_scope="tracking") + if self.offsets_ind is None: + self.offsets_ind = find_head(self.keras_model, "OffsetRefinementHead") - parser.add_argument( - "--test-pipeline", - default=False, - action="store_true", - help="Test pipeline construction without running anything.", - ) + if output_stride is None: + # Attempt to automatically infer the output stride. + output_stride = get_model_output_stride( + self.keras_model, 0, self.confmaps_ind + ) + self.output_stride = output_stride - return parser + def call( + self, inputs: Union[Dict[str, tf.Tensor], tf.Tensor] + ) -> Dict[str, tf.Tensor]: + """Predict confidence maps and infer peak coordinates. + This layer can be chained with a `CentroidCrop` layer to create a top-down + inference function from full images. -def make_video_readers_from_cli(args) -> List[VideoReader]: - if args.video_path: - # TODO: better support for video params - video_kwargs = dict( - dataset=vars(args).get("video.dataset"), - input_format=vars(args).get("video.input_format"), - ) + Args: + inputs: Instance-centered images as a `tf.Tensor` of shape + `(samples, height, width, channels)` or `tf.RaggedTensor` of shape + `(samples, ?, height, width, channels)` where images are grouped by + sample and may contain a variable number of crops, or a dictionary with + keys: + `"crops"`: Cropped images in either format above. + `"crop_offsets"`: (Optional) Coordinates of the top-left of the crops as + `(x, y)` offsets of shape `(samples, ?, 2)` for adjusting the + predicted peak coordinates. No adjustment is performed if not + provided. + `"centroids"`: (Optional) If provided, will be passed through to the + output. + `"centroid_vals"`: (Optional) If provided, will be passed through to the + output. - video_reader = VideoReader.from_filepath( - filename=args.video_path, example_indices=args.frames, **video_kwargs - ) + Returns: + A dictionary of outputs with keys: - return [video_reader] + `"instance_peaks"`: The predicted peaks for each instance in the batch as a + `tf.Tensor` of shape `(samples, n_classes, nodes, 2)`. Instances will + be ordered by class and will be filled with `NaN` where not found. + `"instance_peak_vals"`: The value of the confidence maps at the predicted + peaks for each instance in the batch as a `tf.Tensor` of shape + `(samples, n_classes, nodes)`. - if args.labels: - # TODO: Replace with LabelsReader. - labels = sleap.Labels.load_file(args.labels) + If provided (e.g., from an input `CentroidCrop` layer), the centroids that + generated the crops will also be included in the keys `"centroids"` and + `"centroid_vals"`. - readers = [] + If the `return_confmaps` attribute is set to `True`, the output will also + contain a key named `"instance_confmaps"` containing a `tf.RaggedTensor` of + shape `(samples, ?, output_height, output_width, nodes)` containing the + confidence maps predicted by the model. - if args.only_labeled_frames: - user_labeled_frames = labels.user_labeled_frames + If the `return_class_vectors` attribe is set to `True`, the output will also + contain a key named `"class_vectors"` containing the full classification + probabilities for all crops. + """ + if isinstance(inputs, dict): + crops = inputs["crops"] else: - user_labeled_frames = [] + # Tensor input provided. We'll infer the extra fields in the expected input + # dictionary. + crops = inputs + inputs = {} - for video in labels.videos: - if args.only_labeled_frames: - frame_indices = [ - lf.frame_idx for lf in user_labeled_frames if lf.video == video - ] - readers.append(VideoReader(video=video, example_indices=frame_indices)) - elif args.only_suggested_frames: - readers.append( - VideoReader( - video=video, example_indices=labels.get_video_suggestions(video) - ) - ) + if isinstance(crops, tf.RaggedTensor): + crops = inputs["crops"] # (samples, ?, height, width, channels) + + # Flatten crops into (n_peaks, height, width, channels) + crop_sample_inds = crops.value_rowids() # (n_peaks,) + samples = crops.nrows() + crops = crops.merge_dims(0, 1) + + else: + if "crop_sample_inds" in inputs: + # Crops provided as a regular tensor, use the metadata are in the input. + samples = inputs["samples"] + crop_sample_inds = inputs["crop_sample_inds"] else: - readers.append(VideoReader(video=video)) + # Assuming crops is (samples, height, width, channels). + samples = tf.shape(crops)[0] + crop_sample_inds = tf.range(samples, dtype=tf.int32) - return readers + # Preprocess inputs (scaling, padding, colorspace, int to float). + crops = self.preprocess(crops) - raise ValueError("You must specify either video_path or labels dataset path.") + # Network forward pass. + out = self.keras_model(crops) + # Sort outputs. + cms = out[self.confmaps_ind] + peak_class_probs = out[self.class_vectors_ind] + offsets = None + if self.offsets_ind is not None: + offsets = out[self.offsets_ind] -def make_predictor_from_paths(paths, **kwargs) -> Predictor: - """Build predictor object from a list of model paths.""" - return make_predictor_from_models(find_heads_for_model_paths(paths), **kwargs) + # Find peaks. + if self.offsets_ind is None: + # Use deterministic refinement. + peak_points, peak_vals = sleap.nn.peak_finding.find_global_peaks( + cms, + threshold=self.peak_threshold, + refinement=self.refinement, + integral_patch_size=self.integral_patch_size, + ) + else: + # Use learned offsets. + ( + peak_points, + peak_vals, + ) = sleap.nn.peak_finding.find_global_peaks_with_offsets( + cms, offsets, threshold=self.peak_threshold + ) + # Adjust for stride and scale. + peak_points = peak_points * self.output_stride + if self.input_scale != 1.0: + # Note: We add 0.5 here to offset TensorFlow's weird image resizing. This + # may not always(?) be the most correct approach. + # See: https://github.com/tensorflow/tensorflow/issues/6720 + peak_points = (peak_points / self.input_scale) + 0.5 -def find_heads_for_model_paths(paths) -> Dict[str, str]: - """Given list of models paths, returns dict with path keyed by head name.""" - trained_model_paths = dict() + # Adjust for crop offsets if provided. + if "crop_offsets" in inputs: + # Flatten (samples, ?, 2) -> (n_peaks, 2). + crop_offsets = inputs["crop_offsets"].merge_dims(0, 1) + peak_points = peak_points + tf.expand_dims(crop_offsets, axis=1) - if paths is None: - return trained_model_paths + # Group peaks from classification probabilities. + points, point_vals, class_probs = sleap.nn.identity.classify_peaks_from_vectors( + peak_points, peak_vals, peak_class_probs, crop_sample_inds, samples + ) - for model_path in paths: - # Load the model config - cfg = TrainingJobConfig.load_json(model_path) + # Build outputs. + outputs = { + "instance_peaks": points, + "instance_peak_vals": point_vals, + "instance_scores": class_probs, + } + if "centroids" in inputs: + outputs["centroids"] = inputs["centroids"] + if "centroids" in inputs: + outputs["centroid_vals"] = inputs["centroid_vals"] + if self.return_confmaps: + cms = tf.RaggedTensor.from_value_rowids( + cms, crop_sample_inds, nrows=samples + ) + outputs["instance_confmaps"] = cms + if self.return_class_vectors: + outputs["class_vectors"] = peak_class_probs + return outputs - # Get the head from the model (i.e., what the model will predict) - key = cfg.model.heads.which_oneof_attrib_name() - # If path is to config file json, then get the path to parent dir - if model_path.endswith(".json"): - model_path = os.path.dirname(model_path) +class TopDownMultiClassInferenceModel(InferenceModel): + """Top-down instance prediction model. - trained_model_paths[key] = model_path + This model encapsulates the top-down approach where instances are first detected by + local peak detection of an anchor point and then cropped. These instance-centered + crops are then passed to an instance peak detector which is trained to detect all + remaining body parts for the instance that is centered within the crop. - return trained_model_paths + Attributes: + centroid_crop: A centroid cropping layer. This can be either `CentroidCrop` or + `CentroidCropGroundTruth`. This layer takes the full image as input and + outputs a set of centroids and cropped boxes. + instance_peaks: A instance peak detection and classification layer, an instance + of `TopDownMultiClassFindPeaks`. This layer takes as input the output of the + centroid cropper and outputs the detected peaks and classes for the + instances within each crop. + """ + def __init__( + self, + centroid_crop: Union[CentroidCrop, CentroidCropGroundTruth], + instance_peaks: TopDownMultiClassFindPeaks, + ): + super().__init__() + self.centroid_crop = centroid_crop + self.instance_peaks = instance_peaks -def make_predictor_from_models( - trained_model_paths: Dict[str, str], - labels_path: Optional[str] = None, - policy_args: Optional[dict] = None, - **kwargs, -) -> Predictor: - """Given dict of paths keyed by head name, returns appropriate predictor.""" + def call( + self, example: Union[Dict[str, tf.Tensor], tf.Tensor] + ) -> Dict[str, tf.Tensor]: + """Predict instances for one batch of images. - def get_relevant_args(key): - if policy_args is not None and key in policy_args: - return policy_args[key] - return dict() + Args: + example: This may be either a single batch of images as a 4-D tensor of + shape `(batch_size, height, width, channels)`, or a dictionary + containing the image batch in the `"images"` key. If using a ground + truth model for either centroid cropping or instance peaks, the full + example from a `Pipeline` is required for providing the metadata. - if "multi_instance" in trained_model_paths: - predictor = BottomupPredictor.from_trained_models( - trained_model_paths["multi_instance"], - **get_relevant_args("bottomup"), - **kwargs, - ) - elif "single_instance" in trained_model_paths: - predictor = SingleInstancePredictor.from_trained_models( - trained_model_paths["single_instance"], - **get_relevant_args("single"), - **kwargs, + Returns: + The predicted instances as a dictionary of tensors with keys: + + `"centroids": (batch_size, n_instances, 2)`: Instance centroids. + `"centroid_vals": (batch_size, n_instances)`: Instance centroid confidence + values. + `"instance_peaks": (batch_size, n_instances, n_nodes, 2)`: Instance skeleton + points. + `"instance_peak_vals": (batch_size, n_instances, n_nodes)`: Confidence + values for the instance skeleton points. + """ + if isinstance(example, tf.Tensor): + example = dict(image=example) + + crop_output = self.centroid_crop(example) + peaks_output = self.instance_peaks(crop_output) + return peaks_output + + +@attr.s(auto_attribs=True) +class TopDownMultiClassPredictor(Predictor): + """Top-down multi-instance predictor with classification. + + This high-level class handles initialization, preprocessing and tracking using a + trained top-down multi-instance classification SLEAP model. + + This should be initialized using the `from_trained_models()` constructor or the + high-level API (`sleap.load_model`). + + Attributes: + centroid_config: The `sleap.nn.config.TrainingJobConfig` containing the metadata + for the trained centroid model. If `None`, ground truth centroids will be + used if available from the data source. + centroid_model: A `sleap.nn.model.Model` instance created from the trained + centroid model. If `None`, ground truth centroids will be used if available + from the data source. + confmap_config: The `sleap.nn.config.TrainingJobConfig` containing the metadata + for the trained centered instance model. If `None`, ground truth instances + will be used if available from the data source. + confmap_model: A `sleap.nn.model.Model` instance created from the trained + centered-instance model. If `None`, ground truth instances will be used if + available from the data source. + inference_model: A `TopDownMultiClassInferenceModel` that wraps a trained + `tf.keras.Model` to implement preprocessing, centroid detection, cropping, + peak finding and classification. + pipeline: A `sleap.nn.data.Pipeline` that loads the data and batches input data. + This will be updated dynamically if new data sources are used. + tracker: A `sleap.nn.tracking.Tracker` that will be called to associate + detections over time. Predicted instances will not be assigned to tracks if + if this is `None`. + batch_size: The default batch size to use when loading data for inference. + Higher values increase inference speed at the cost of higher memory usage. + peak_threshold: Minimum confidence map value to consider a local peak as valid. + integral_refinement: If `True`, peaks will be refined with integral regression. + If `False`, `"local"`, peaks will be refined with quarter pixel local + gradient offset. This has no effect if the model has an offset regression + head. + integral_patch_size: Size of patches to crop around each rough peak for integral + refinement as an integer scalar. + tracks: If provided, instances will be created using these track instances. If + not, instances will be assigned tracks from the provider if possible. + """ + + centroid_config: Optional[TrainingJobConfig] = attr.ib(default=None) + centroid_model: Optional[Model] = attr.ib(default=None) + confmap_config: Optional[TrainingJobConfig] = attr.ib(default=None) + confmap_model: Optional[Model] = attr.ib(default=None) + inference_model: Optional[TopDownMultiClassInferenceModel] = attr.ib(default=None) + pipeline: Optional[Pipeline] = attr.ib(default=None, init=False) + tracker: Optional[Tracker] = attr.ib(default=None, init=False) + batch_size: int = 4 + peak_threshold: float = 0.2 + integral_refinement: bool = True + integral_patch_size: int = 5 + tracks: Optional[List[sleap.Track]] = None + + def _initialize_inference_model(self): + """Initialize the inference model from the trained models and configuration.""" + use_gt_centroid = self.centroid_config is None + use_gt_confmap = self.confmap_config is None # TODO + + if use_gt_centroid: + centroid_crop_layer = CentroidCropGroundTruth( + crop_size=self.confmap_config.data.instance_cropping.crop_size + ) + else: + # if use_gt_confmap: + # crop_size = 1 + # else: + crop_size = self.confmap_config.data.instance_cropping.crop_size + centroid_crop_layer = CentroidCrop( + keras_model=self.centroid_model.keras_model, + crop_size=crop_size, + input_scale=self.centroid_config.data.preprocessing.input_scaling, + pad_to_stride=self.centroid_config.data.preprocessing.pad_to_stride, + output_stride=self.centroid_config.model.heads.centroid.output_stride, + peak_threshold=self.peak_threshold, + refinement="integral" if self.integral_refinement else "local", + integral_patch_size=self.integral_patch_size, + return_confmaps=False, + ) + + # if use_gt_confmap: + # instance_peaks_layer = FindInstancePeaksGroundTruth() + # else: + cfg = self.confmap_config + instance_peaks_layer = TopDownMultiClassFindPeaks( + keras_model=self.confmap_model.keras_model, + input_scale=cfg.data.preprocessing.input_scaling, + peak_threshold=self.peak_threshold, + output_stride=cfg.model.heads.multi_class_topdown.confmaps.output_stride, + refinement="integral" if self.integral_refinement else "local", + integral_patch_size=self.integral_patch_size, + return_confmaps=False, ) - elif "multi_class" in trained_model_paths: - predictor = BottomUpMultiClassPredictor.from_trained_models( - trained_model_paths["multi_class"], - **get_relevant_args("bottomup_multiclass"), - **kwargs, + + self.inference_model = TopDownMultiClassInferenceModel( + centroid_crop=centroid_crop_layer, instance_peaks=instance_peaks_layer ) - elif ( - "centroid" in trained_model_paths and "centered_instance" in trained_model_paths - ): - predictor = TopdownPredictor.from_trained_models( - centroid_model_path=trained_model_paths["centroid"], - confmap_model_path=trained_model_paths["centered_instance"], - **get_relevant_args("topdown"), - **kwargs, + + @classmethod + def from_trained_models( + cls, + centroid_model_path: Optional[Text] = None, + confmap_model_path: Optional[Text] = None, + batch_size: int = 4, + peak_threshold: float = 0.2, + integral_refinement: bool = True, + integral_patch_size: int = 5, + ) -> "TopDownMultiClassPredictor": + """Create predictor from saved models. + + Args: + centroid_model_path: Path to a centroid model folder or training job JSON + file inside a model folder. This folder should contain + `training_config.json` and `best_model.h5` files for a trained model. + confmap_model_path: Path to a centered instance model folder or training job + JSON file inside a model folder. This folder should contain + `training_config.json` and `best_model.h5` files for a trained model. + batch_size: The default batch size to use when loading data for inference. + Higher values increase inference speed at the cost of higher memory + usage. + peak_threshold: Minimum confidence map value to consider a local peak as + valid. + integral_refinement: If `True`, peaks will be refined with integral + regression. If `False`, `"local"`, peaks will be refined with quarter + pixel local gradient offset. This has no effect if the model has an + offset regression head. + integral_patch_size: Size of patches to crop around each rough peak for + integral refinement as an integer scalar. + + Returns: + An instance of `TopDownPredictor` with the loaded models. + + One of the two models can be left as `None` to perform inference with ground + truth data. This will only work with `LabelsReader` as the provider. + """ + if centroid_model_path is None and confmap_model_path is None: + raise ValueError( + "Either the centroid or topdown confidence map model must be provided." + ) + + if centroid_model_path is not None: + # Load centroid model. + centroid_config = TrainingJobConfig.load_json(centroid_model_path) + centroid_keras_model_path = get_keras_model_path(centroid_model_path) + centroid_model = Model.from_config(centroid_config.model) + centroid_model.keras_model = tf.keras.models.load_model( + centroid_keras_model_path, compile=False + ) + else: + centroid_config = None + centroid_model = None + + if confmap_model_path is not None: + # Load confmap model. + confmap_config = TrainingJobConfig.load_json(confmap_model_path) + confmap_keras_model_path = get_keras_model_path(confmap_model_path) + confmap_model = Model.from_config(confmap_config.model) + confmap_model.keras_model = tf.keras.models.load_model( + confmap_keras_model_path, compile=False + ) + else: + confmap_config = None + confmap_model = None + + obj = cls( + centroid_config=centroid_config, + centroid_model=centroid_model, + confmap_config=confmap_config, + confmap_model=confmap_model, + batch_size=batch_size, + peak_threshold=peak_threshold, + integral_refinement=integral_refinement, + integral_patch_size=integral_patch_size, ) - elif len(trained_model_paths) == 0 and labels_path: - predictor = MockPredictor.from_trained_models(labels_path=labels_path) - else: - raise ValueError( - f"Unable to run inference with {list(trained_model_paths.keys())} heads." + obj._initialize_inference_model() + return obj + + def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline: + """Make a data loading pipeline. + + Args: + data_provider: If not `None`, the pipeline will be created with an instance + of a `sleap.pipelines.Provider`. + + Returns: + The created `sleap.pipelines.Pipeline` with batching and prefetching. + + Notes: + This method also updates the class attribute for the pipeline and will be + called automatically when predicting on data from a new source. + """ + pipeline = Pipeline() + if data_provider is not None: + pipeline.providers = [data_provider] + + if self.centroid_model is None: + anchor_part = self.confmap_config.data.instance_cropping.center_on_part + pipeline += sleap.nn.data.pipelines.InstanceCentroidFinder( + center_on_anchor_part=anchor_part is not None, + anchor_part_names=anchor_part, + skeletons=self.confmap_config.data.labels.skeletons, + ) + + pipeline += sleap.nn.data.pipelines.Batcher( + batch_size=self.batch_size, drop_remainder=False, unrag=False ) - return predictor + pipeline += Prefetcher() + self.pipeline = pipeline -def make_tracker_from_cli(policy_args): - if "tracking" in policy_args: - tracker = Tracker.make_tracker_by_name(**policy_args["tracking"]) - return tracker + return pipeline - return None + def _make_labeled_frames_from_generator( + self, generator: Iterator[Dict[str, np.ndarray]], data_provider: Provider + ) -> List[sleap.LabeledFrame]: + """Create labeled frames from a generator that yields inference results. + This method converts pure arrays into SLEAP-specific data structures and runs + them through the tracker if it is specified. -def save_predictions_from_cli(args, predicted_frames, prediction_metadata=None): - from sleap import Labels + Args: + generator: A generator that returns dictionaries with inference results. + This should return dictionaries with keys `"image"`, `"video_ind"`, + `"frame_ind"`, `"instance_peaks"`, `"instance_peak_vals"`, and + `"centroid_vals"`. This can be created using the `_predict_generator()` + method. + data_provider: The `sleap.pipelines.Provider` that the predictions are being + created from. This is used to retrieve the `sleap.Video` instance + associated with each inference result. - if args.output: - output_path = args.output - elif args.video_path: - out_dir = os.path.dirname(args.video_path) - out_name = os.path.basename(args.video_path) + ".predictions.slp" - output_path = os.path.join(out_dir, out_name) - elif args.labels: - out_dir = os.path.dirname(args.labels) - out_name = os.path.basename(args.labels) + ".predictions.slp" - output_path = os.path.join(out_dir, out_name) - else: - # We shouldn't ever get here but if we do, just save in working dir. - output_path = "predictions.slp" + Returns: + A list of `sleap.LabeledFrame`s with `sleap.PredictedInstance`s created from + arrays returned from the inference result generator. + """ + if self.confmap_config is not None: + skeleton = self.confmap_config.data.labels.skeletons[0] + else: + skeleton = self.centroid_config.data.labels.skeletons[0] + + tracks = self.tracks + if tracks is None: + if hasattr(data_provider, "tracks"): + tracks = data_provider.tracks + elif ( + self.confmap_config.model.heads.multi_class_topdown.class_vectors.classes + is not None + ): + names = ( + self.confmap_config.model.heads.multi_class_topdown.class_vectors.classes + ) + tracks = [sleap.Track(name=n, spawned_on=0) for n in names] + + # Loop over batches. + predicted_frames = [] + for ex in generator: + + # Loop over frames. + for image, video_ind, frame_ind, points, confidences, scores in zip( + ex["image"], + ex["video_ind"], + ex["frame_ind"], + ex["instance_peaks"], + ex["instance_peak_vals"], + ex["instance_scores"], + ): + + # Loop over instances. + predicted_instances = [] + for i, (pts, confs, score) in enumerate( + zip(points, confidences, scores) + ): + if np.isnan(pts).all(): + continue + track = None + if tracks is not None and len(tracks) >= (i - 1): + track = tracks[i] + predicted_instances.append( + sleap.instance.PredictedInstance.from_arrays( + points=pts, + point_confidences=confs, + instance_score=np.nanmean(score), + skeleton=skeleton, + track=track, + ) + ) - labels = Labels(labeled_frames=predicted_frames, provenance=prediction_metadata) + predicted_frames.append( + sleap.LabeledFrame( + video=data_provider.videos[video_ind], + frame_idx=frame_ind, + instances=predicted_instances, + ) + ) - print(f"Saving: {output_path}") - Labels.save_file(labels, output_path) + return predicted_frames def load_model( model_path: Union[str, List[str]], batch_size: int = 4, - peak_threshold: float = 0.7, + peak_threshold: float = 0.2, refinement: str = "integral", tracker: Optional[str] = None, tracker_window: int = 5, tracker_max_instances: Optional[int] = None, + disable_gpu_preallocation: bool = True, + progress_reporting: str = "rich", ) -> Predictor: """Load a trained SLEAP model. - Args: model_path: Path to model or list of path to models that were trained by SLEAP. These should be the directories that contain `training_job.json` and `best_model.h5`. batch_size: Number of frames to predict at a time. Larger values result in faster inference speeds, but require more memory. + peak_threshold: Minimum confidence map value to consider a peak as valid. refinement: If `"integral"`, peak locations will be refined with integral regression. If `"local"`, peaks will be refined with quarter pixel local gradient offset. This has no effect if the model has an offset regression @@ -3152,32 +3633,59 @@ def load_model( `tracker` is `None`. tracker_max_instances: If not `None`, discard instances beyond this count when tracking. No effect when `tracker` is `None`. - + disable_gpu_preallocation: If `True` (the default), initialize the GPU and + disable preallocation of memory. This is necessary to prevent freezing on + some systems with low GPU memory and has negligible impact on performance. + If `False`, no GPU initialization is performed. No effect if running in + CPU-only mode. + progress_reporting: Mode of inference progress reporting. If `"rich"` (the + default), an updating progress bar is displayed in the console or notebook. + If `"json"`, a JSON-serialized message is printed out which can be captured + for programmatic progress monitoring. If `"none"`, nothing is displayed + during inference -- this is recommended when running on clusters or headless + machines where the output is captured to a log file. Returns: An instance of a `Predictor` based on which model type was detected. - If this is a top-down model, paths to the centroids model as well as the - centered instance model must be provided. A `TopdownPredictor` instance will be + centered instance model must be provided. A `TopDownPredictor` instance will be returned. - - If this is a bottom-up model, a `BottomupPredictor` will be returned. - + If this is a bottom-up model, a `BottomUpPredictor` will be returned. If this is a single-instance model, a `SingleInstancePredictor` will be returned. - If a `tracker` is specified, the predictor will also run identity tracking over time. - - See also: TopdownPredictor, BottomupPredictor, SingleInstancePredictor + See also: TopDownPredictor, BottomUpPredictor, SingleInstancePredictor """ if isinstance(model_path, str): - model_path = [model_path] - predictor = make_predictor_from_paths( - model_path, - batch_size=batch_size, - integral_refinement=refinement == "integral", + model_paths = [model_path] + else: + model_paths = model_path + + # Uncompress ZIP packaged models. + tmp_dirs = [] + for i, model_path in enumerate(model_paths): + if model_path.endswith(".zip"): + # Create temp dir on demand. + tmp_dir = tempfile.TemporaryDirectory() + tmp_dirs.append(tmp_dir) + + # Remove the temp dir when program exits in case something goes wrong. + atexit.register(shutil.rmtree, tmp_dir.name, ignore_errors=True) + + # Extract and replace in the list. + shutil.unpack_archive(model_path, extract_dir=tmp_dir.name) + model_paths[i] = tmp_dir.name + + if disable_gpu_preallocation: + sleap.disable_preallocation() + + predictor = Predictor.from_model_paths( + model_paths, peak_threshold=peak_threshold, + integral_refinement=refinement == "integral", + batch_size=batch_size, ) + predictor.verbosity = progress_reporting if tracker is not None: predictor.tracker = Tracker.make_tracker_by_name( tracker=tracker, @@ -3185,15 +3693,322 @@ def load_model( post_connect_single_breaks=True, clean_instance_count=tracker_max_instances, ) + + # Remove temp dirs. + for tmp_dir in tmp_dirs: + tmp_dir.cleanup() + + return predictor + + +def _make_cli_parser() -> argparse.ArgumentParser: + """Create argument parser for CLI. + + Returns: + The `argparse.ArgumentParser` that defines the CLI options. + """ + parser = argparse.ArgumentParser() + + parser.add_argument( + "data_path", + type=str, + nargs="?", + default="", + help=( + "Path to data to predict on. This can be a labels (.slp) file or any " + "supported video format." + ), + ) + parser.add_argument( + "-m", + "--model", + dest="models", + action="append", + help=( + "Path to trained model directory (with training_config.json). " + "Multiple models can be specified, each preceded by --model." + ), + ) + parser.add_argument( + "--frames", + type=sleap.util.frame_list, + default="", + help=( + "List of frames to predict when running on a video. Can be specified as a " + "comma separated list (e.g. 1,2,3) or a range separated by hyphen (e.g., " + "1-3, for 1,2,3). If not provided, defaults to predicting on the entire " + "video." + ), + ) + parser.add_argument( + "--only-labeled-frames", + action="store_true", + default=False, + help=( + "Only run inference on user labeled frames when running on labels dataset. " + "This is useful for generating predictions to compare against ground truth." + ), + ) + parser.add_argument( + "--only-suggested-frames", + action="store_true", + default=False, + help=( + "Only run inference on unlabeled suggested frames when running on labels " + "dataset. This is useful for generating predictions for initialization " + "during labeling." + ), + ) + parser.add_argument( + "-o", + "--output", + type=str, + default=None, + help=( + "The output filename to use for the predicted data. If not provided, " + "defaults to '[data_path].predictions.slp'." + ), + ) + parser.add_argument( + "--no-empty-frames", + action="store_true", + default=False, + help=( + "Clear any empty frames that did not have any detected instances before " + "saving to output." + ), + ) + parser.add_argument( + "--verbosity", + type=str, + choices=["none", "rich", "json"], + default="rich", + help=( + "Verbosity of inference progress reporting. 'none' does not output " + "anything during inference, 'rich' displays an updating progress bar, " + "and 'json' outputs the progress as a JSON encoded response to the " + "console." + ), + ) + parser.add_argument( + "--video.dataset", type=str, default=None, help="The dataset for HDF5 videos." + ) + parser.add_argument( + "--video.input_format", + type=str, + default="channels_last", + help="The input_format for HDF5 videos.", + ) + device_group = parser.add_mutually_exclusive_group(required=False) + device_group.add_argument( + "--cpu", + action="store_true", + help="Run inference only on CPU. If not specified, will use available GPU.", + ) + device_group.add_argument( + "--first-gpu", + action="store_true", + help="Run inference on the first GPU, if available.", + ) + device_group.add_argument( + "--last-gpu", + action="store_true", + help="Run inference on the last GPU, if available.", + ) + device_group.add_argument( + "--gpu", type=int, default=0, help="Run inference on the i-th GPU specified." + ) + + parser.add_argument( + "--peak_threshold", + type=float, + default=0.2, + help="Minimum confidence map value to consider a peak as valid.", + ) + parser.add_argument( + "--batch_size", + type=int, + default=4, + help=( + "Number of frames to predict at a time. Larger values result in faster " + "inference speeds, but require more memory." + ), + ) + + # Deprecated legacy args. These will still be parsed for backward compatibility but + # are hidden from the CLI help. + parser.add_argument( + "--labels", + type=str, + default=argparse.SUPPRESS, + help=argparse.SUPPRESS, + ) + parser.add_argument( + "--single.peak_threshold", + type=float, + default=argparse.SUPPRESS, + help=argparse.SUPPRESS, + ) + parser.add_argument( + "--topdown.peak_threshold", + type=float, + default=argparse.SUPPRESS, + help=argparse.SUPPRESS, + ) + parser.add_argument( + "--bottomup.peak_threshold", + type=float, + default=argparse.SUPPRESS, + help=argparse.SUPPRESS, + ) + parser.add_argument( + "--single.batch_size", + type=float, + default=argparse.SUPPRESS, + help=argparse.SUPPRESS, + ) + parser.add_argument( + "--topdown.batch_size", + type=float, + default=argparse.SUPPRESS, + help=argparse.SUPPRESS, + ) + parser.add_argument( + "--bottomup.batch_size", + type=float, + default=argparse.SUPPRESS, + help=argparse.SUPPRESS, + ) + + # Add tracker args. + Tracker.add_cli_parser_args(parser, arg_scope="tracking") + + return parser + + +def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]: + """Make data provider from parsed CLI args. + + Args: + args: Parsed CLI namespace. + + Returns: + A tuple of `(provider, data_path)` with the data `Provider` and path to the data + that was specified in the args. + """ + # Figure out which input path to use. + labels_path = getattr(args, "labels", None) + if labels_path is not None: + data_path = labels_path + else: + data_path = args.data_path + + if data_path is None: + raise ValueError("You must specify a path to a video or a labels dataset.") + + if data_path.endswith(".slp"): + labels = sleap.Labels.load_file(data_path) + + if args.only_labeled_frames: + provider = LabelsReader.from_user_labeled_frames(labels) + elif args.only_suggested_frames: + provider = LabelsReader.from_unlabeled_suggestions(labels) + else: + provider = LabelsReader(labels) + + else: + # TODO: Clean this up. + video_kwargs = dict( + dataset=vars(args).get("video.dataset"), + input_format=vars(args).get("video.input_format"), + ) + provider = VideoReader.from_filepath( + filename=data_path, example_indices=args.frames, **video_kwargs + ) + + return provider, data_path + + +def _make_predictor_from_cli(args: argparse.Namespace) -> Predictor: + """Make predictor from parsed CLI args. + + Args: + args: Parsed CLI namespace. + + Returns: + The `Predictor` created from loaded models. + """ + peak_threshold = None + for deprecated_arg in [ + "single.peak_threshold", + "topdown.peak_threshold", + "bottomup.peak_threshold", + ]: + val = getattr(args, deprecated_arg, None) + if val is not None: + peak_threshold = val + if peak_threshold is None: + peak_threshold = args.peak_threshold + + batch_size = None + for deprecated_arg in [ + "single.batch_size", + "topdown.batch_size", + "bottomup.batch_size", + ]: + val = getattr(args, deprecated_arg, None) + if val is not None: + batch_size = val + if batch_size is None: + batch_size = args.batch_size + + predictor = Predictor.from_model_paths( + args.models, + peak_threshold=peak_threshold, + integral_refinement=True, + batch_size=batch_size, + ) + predictor.verbosity = args.verbosity return predictor +def _make_tracker_from_cli(args: argparse.Namespace) -> Optional[Tracker]: + """Make tracker from parsed CLI arguments. + + Args: + args: Parsed CLI namespace. + + Returns: + An instance of `Tracker` or `None` if tracking method was not specified. + """ + policy_args = sleap.util.make_scoped_dictionary(vars(args), exclude_nones=True) + if "tracking" in policy_args: + tracker = Tracker.make_tracker_by_name(**policy_args["tracking"]) + return tracker + return None + + def main(): - """CLI for running inference.""" - parser = make_cli_parser() + """Entrypoint for `sleap-track` CLI for running inference.""" + t0 = time() + start_timestamp = str(datetime.now()) + print("Started inference at:", start_timestamp) + + # Setup CLI. + parser = _make_cli_parser() + + # Parse inputs. args, _ = parser.parse_known_args() - print(args) + args_msg = ["Args:"] + for name, val in vars(args).items(): + if name == "frames" and val is not None: + args_msg.append(f" frames: {min(val)}-{max(val)} ({len(val)})") + else: + args_msg.append(f" {name}: {val}") + print("\n".join(args_msg)) + + # Setup devices. if args.cpu or not sleap.nn.system.is_gpu_system(): sleap.nn.system.use_cpu_only() else: @@ -3203,69 +4018,57 @@ def main(): sleap.nn.system.use_last_gpu() else: sleap.nn.system.use_gpu(args.gpu) - sleap.nn.system.disable_preallocation() + sleap.disable_preallocation() + + print("Versions:") + sleap.versions() + print() print("System:") sleap.nn.system.summary() + print() - video_readers = make_video_readers_from_cli(args) + # Setup data loader. + provider, data_path = _make_provider_from_cli(args) - # Find the specified models - model_paths_by_head = find_heads_for_model_paths(args.models) + # Setup models. + predictor = _make_predictor_from_cli(args) - # Make a scoped dictionary with args specified from cli - policy_args = util.make_scoped_dictionary(vars(args), exclude_nones=True) - - # Create appropriate predictor given these models - predictor = make_predictor_from_models( - model_paths_by_head, labels_path=args.labels, policy_args=policy_args - ) - - # Make the tracker - tracker = make_tracker_from_cli(policy_args) + # Setup tracker. + tracker = _make_tracker_from_cli(args) predictor.tracker = tracker - if args.test_pipeline: - print() - - print(policy_args) - print() - - print(predictor) - print() - - predictor.make_pipeline() - print("===pipeline transformers===") - print() - for transformer in predictor.pipeline.transformers: - print(transformer.__class__.__name__) - print(f"\t-> {transformer.input_keys}") - print(f"\t {transformer.output_keys} ->") - print() - - print("--test-pipeline arg set so stopping here.") - return - # Run inference! - t0 = time.time() - predicted_frames = [] - - for video_reader in video_readers: - video_predicted_frames = predictor.predict(video_reader).labeled_frames - predicted_frames.extend(video_predicted_frames) - - # Create dictionary of metadata we want to save with predictions - prediction_metadata = dict() - for head, path in model_paths_by_head.items(): - prediction_metadata[f"model.{head}.path"] = os.path.abspath(path) - for scope in policy_args.keys(): - for key, val in policy_args[scope].items(): - prediction_metadata[f"{scope}.{key}"] = val - prediction_metadata["video.path"] = args.video_path - prediction_metadata["sleap.version"] = sleap.__version__ - - save_predictions_from_cli(args, predicted_frames, prediction_metadata) - print(f"Total Time: {time.time() - t0}") + labels_pr = predictor.predict(provider) + + if args.no_empty_frames: + # Clear empty frames if specified. + labels_pr.remove_empty_frames() + + finish_timestamp = str(datetime.now()) + total_elapsed = time() - t0 + print("Finished inference at:", finish_timestamp) + print(f"Total runtime: {total_elapsed} secs") + print(f"Predicted frames: {len(labels_pr)}/{len(provider)}") + + output_path = args.output + if output_path is None: + output_path = data_path + ".predictions.slp" + + # Add provenance metadata to predictions. + labels_pr.provenance["sleap_version"] = sleap.__version__ + labels_pr.provenance["platform"] = platform.platform() + labels_pr.provenance["data_path"] = data_path + labels_pr.provenance["model_paths"] = predictor.model_paths + labels_pr.provenance["output_path"] = output_path + labels_pr.provenance["predictor"] = type(predictor).__name__ + labels_pr.provenance["total_elapsed"] = total_elapsed + labels_pr.provenance["start_timestamp"] = start_timestamp + labels_pr.provenance["finish_timestamp"] = finish_timestamp + + # Save results. + labels_pr.save(output_path) + print("Saved output:", output_path) if __name__ == "__main__": diff --git a/sleap/nn/model.py b/sleap/nn/model.py index 8b7a13d3d..b71774e3b 100644 --- a/sleap/nn/model.py +++ b/sleap/nn/model.py @@ -22,12 +22,14 @@ IntermediateFeature, ) from sleap.nn.heads import ( + Head, CentroidConfmapsHead, SingleInstanceConfmapsHead, CenteredInstanceConfmapsHead, MultiInstanceConfmapsHead, PartAffinityFieldsHead, ClassMapsHead, + ClassVectorsHead, OffsetRefinementHead, ) from sleap.nn.config import ( @@ -40,8 +42,8 @@ CentroidsHeadConfig, CenteredInstanceConfmapsHeadConfig, MultiInstanceConfig, - ClassMapsHeadConfig, - MultiClassConfig, + MultiClassBottomUpConfig, + MultiClassTopDownConfig, BackboneConfig, HeadsConfig, ModelConfig, @@ -70,17 +72,6 @@ PretrainedEncoderConfig: UnetPretrainedEncoder, } -HEADS = [ - CentroidConfmapsHead, - SingleInstanceConfmapsHead, - CenteredInstanceConfmapsHead, - MultiInstanceConfmapsHead, - PartAffinityFieldsHead, - ClassMapsHead, - OffsetRefinementHead, -] -Head = TypeVar("Head", *HEADS) - @attr.s(auto_attribs=True) class Model: @@ -118,7 +109,11 @@ def from_config( """ # Figure out which backbone class to use. backbone_config = config.backbone.which_oneof() - backbone_cls = BACKBONE_CONFIG_TO_CLS[type(backbone_config)] + backbone_cls = BACKBONE_CONFIG_TO_CLS.get(type(backbone_config), None) + if backbone_cls is None: + raise ValueError( + "Backbone architecture (config.model.backbone) was not specified." + ) # Figure out which head class to use. head_config = config.heads.which_oneof() @@ -148,9 +143,7 @@ def from_config( heads = [CentroidConfmapsHead.from_config(head_config)] output_stride = heads[0].output_stride if head_config.offset_refinement: - heads.append( - OffsetRefinementHead.from_config(head_config) - ) + heads.append(OffsetRefinementHead.from_config(head_config)) elif isinstance(head_config, CenteredInstanceConfmapsHeadConfig): part_names = head_config.part_names @@ -212,7 +205,7 @@ def from_config( ) ) - elif isinstance(head_config, MultiClassConfig): + elif isinstance(head_config, MultiClassBottomUpConfig): part_names = head_config.confmaps.part_names if part_names is None: if skeleton is None: @@ -250,6 +243,50 @@ def from_config( ) ) + elif isinstance(head_config, MultiClassTopDownConfig): + part_names = head_config.confmaps.part_names + if part_names is None: + if skeleton is None: + raise ValueError( + "Skeleton must be provided when the head configuration is " + "incomplete." + ) + part_names = skeleton.node_names + if update_config: + head_config.confmaps.part_names = part_names + + classes = head_config.class_vectors.classes + if classes is None: + if tracks is None: + raise ValueError( + "Classes must be provided when the head configuration is " + "incomplete." + ) + classes = [t.name for t in tracks] + if update_config: + head_config.class_vectors.classes = classes + + heads = [ + CenteredInstanceConfmapsHead.from_config( + head_config.confmaps, part_names=part_names + ), + ClassVectorsHead.from_config( + head_config.class_vectors, classes=classes + ), + ] + output_stride = min(heads[0].output_stride, heads[1].output_stride) + output_stride = heads[0].output_stride + if head_config.confmaps.offset_refinement: + heads.append( + OffsetRefinementHead.from_config( + head_config.confmaps, part_names=part_names + ) + ) + else: + raise ValueError( + "Head configuration (config.model.heads) was not specified." + ) + backbone_config.output_stride = output_stride return cls(backbone=backbone_cls.from_config(backbone_config), heads=heads) @@ -284,25 +321,12 @@ def make_model(self, input_shape: Tuple[int, int, int]) -> tf.keras.Model: # Build output layers for each head. x_outs = [] for output in self.heads: - if isinstance(output, ClassMapsHead): - activation = "sigmoid" - else: - activation = "linear" x_head = [] if output.output_stride == self.backbone.output_stride: # The main output has the same stride as the head, so build output layer # from that tensor. for i, x in enumerate(x_main): - x_head.append( - tf.keras.layers.Conv2D( - filters=output.channels, - kernel_size=1, - strides=1, - padding="same", - activation=activation, - name=f"{type(output).__name__}_{i}", - )(x) - ) + x_head.append(output.make_head(x)) else: # Look for an intermediate activation that has the correct stride. @@ -311,16 +335,7 @@ def make_model(self, input_shape: Tuple[int, int, int]) -> tf.keras.Model: assert all([feat.stride == feats[0].stride for feat in feats]) if feats[0].stride == output.output_stride: for i, feat in enumerate(feats): - x_head.append( - tf.keras.layers.Conv2D( - filters=output.channels, - kernel_size=1, - strides=1, - padding="same", - activation=activation, - name=f"{type(output).__name__}_{i}", - )(feat.tensor) - ) + x_head.append(output.make_head(feat.tensor)) break if len(x_head) == 0: diff --git a/sleap/nn/monitor.py b/sleap/nn/monitor.py index 40fae2a0e..0e60469df 100644 --- a/sleap/nn/monitor.py +++ b/sleap/nn/monitor.py @@ -28,6 +28,8 @@ def __init__( self.show_controller = show_controller self.stop_button = None + self.cancel_button = None + self.canceled = False self.redraw_batch_interval = 40 self.batches_to_show = -1 # -1 to show all @@ -162,9 +164,12 @@ def reset(self, what=""): control_layout.addStretch(1) - self.stop_button = QtWidgets.QPushButton("Stop Training") + self.stop_button = QtWidgets.QPushButton("Stop Early") self.stop_button.clicked.connect(self.stop) control_layout.addWidget(self.stop_button) + self.cancel_button = QtWidgets.QPushButton("Cancel Training") + self.cancel_button.clicked.connect(self.cancel) + control_layout.addWidget(self.cancel_button) widget = QtWidgets.QWidget() widget.setLayout(control_layout) @@ -242,12 +247,19 @@ def setup_zmq(self, zmq_context: Optional[zmq.Context]): self.timer.timeout.connect(self.check_messages) self.timer.start(20) + def cancel(self): + """Set the cancel flag.""" + self.canceled = True + if self.cancel_button is not None: + self.cancel_button.setText("Canceling...") + self.cancel_button.setEnabled(False) + def stop(self): """Action to stop training.""" if self.zmq_ctrl is not None: # send command to stop training - logger.info("Sending command to stop training") + logger.info("Sending command to stop training.") self.zmq_ctrl.send_string(jsonpickle.encode(dict(command="stop"))) # Disable the button diff --git a/sleap/nn/paf_grouping.py b/sleap/nn/paf_grouping.py index 57aec2179..3ee32c75a 100644 --- a/sleap/nn/paf_grouping.py +++ b/sleap/nn/paf_grouping.py @@ -1532,7 +1532,7 @@ def predict( `tf.float32`. Notes: - This is a high level API for grouping peaks into instances using PAFs. + This is a high level API for grouping peaks into instances using PAFs. See the `PAFScorer` class documentation for more details on the algorithm. diff --git a/sleap/nn/peak_finding.py b/sleap/nn/peak_finding.py index 6034cdfcd..c4d971788 100644 --- a/sleap/nn/peak_finding.py +++ b/sleap/nn/peak_finding.py @@ -688,11 +688,11 @@ def find_local_peaks_with_offsets( ], axis=1, ) - + # Expand last axes of offsets. shape = tf.shape(offsets) offsets = tf.reshape(offsets, [shape[0], shape[1], shape[2], -1, 2]) - + # Extract offsets at the peak locations. peak_offsets = tf.gather_nd(offsets, subs) diff --git a/sleap/nn/tracker/components.py b/sleap/nn/tracker/components.py index 91f11c63d..e307df1ff 100644 --- a/sleap/nn/tracker/components.py +++ b/sleap/nn/tracker/components.py @@ -436,7 +436,10 @@ def from_candidate_instances( # candidates. track_instances = candidate_instances_by_track[candidate_track] track_matching_similarities = [ - similarity_function(untracked_instance, candidate_instance,) + similarity_function( + untracked_instance, + candidate_instance, + ) for candidate_instance in track_instances ] diff --git a/sleap/nn/tracker/kalman.py b/sleap/nn/tracker/kalman.py index 60a68be5a..2b0343927 100644 --- a/sleap/nn/tracker/kalman.py +++ b/sleap/nn/tracker/kalman.py @@ -350,7 +350,7 @@ def get_instance_points_weight( def get_too_close_checking_function( self, instances: List[InstanceType], dist_thresh: float ) -> Callable: - """" + """ " Returns a function which determines if two instances are too close. Args: @@ -583,7 +583,9 @@ def matches_from_match_tuples( def remove_second_bests_from_cost_matrix( - cost_matrix: np.ndarray, thresh: float, invalid_val: float = np.nan, + cost_matrix: np.ndarray, + thresh: float, + invalid_val: float = np.nan, ) -> np.ndarray: """ Removes unclear matches from cost matrix. diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index cbde7e718..7ad93ee8a 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -212,7 +212,11 @@ def flow_shift_instances( None, winSize=(window_size, window_size), maxLevel=max_levels, - criteria=(cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 30, 0.01,), + criteria=( + cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, + 30, + 0.01, + ), ) shifted_pts /= scale @@ -264,13 +268,21 @@ def get_candidates( return candidate_instances -tracker_policies = dict(simple=SimpleCandidateMaker, flow=FlowCandidateMaker,) +tracker_policies = dict( + simple=SimpleCandidateMaker, + flow=FlowCandidateMaker, +) similarity_policies = dict( - instance=instance_similarity, centroid=centroid_distance, iou=instance_iou, + instance=instance_similarity, + centroid=centroid_distance, + iou=instance_iou, ) -match_policies = dict(hungarian=hungarian_matching, greedy=greedy_matching,) +match_policies = dict( + hungarian=hungarian_matching, + greedy=greedy_matching, +) @attr.s(auto_attribs=True) @@ -420,7 +432,9 @@ def track( # Build a pool of matchable candidate instances. candidate_instances = self.candidate_maker.get_candidates( - track_matching_queue=self.track_matching_queue, t=t, img=img, + track_matching_queue=self.track_matching_queue, + t=t, + img=img, ) # Determine matches for untracked instances in current frame. @@ -462,7 +476,9 @@ def update_matched_instance_tracks(matches: List[Match]) -> List[InstanceType]: # Assign to track and save. inst_list.append( attr.evolve( - match.instance, track=match.track, tracking_score=match.score, + match.instance, + track=match.track, + tracking_score=match.score, ) ) return inst_list @@ -489,13 +505,13 @@ def spawn_for_untracked_instances( def final_pass(self, frames: List[LabeledFrame]): """Called after tracking has run on all frames to do any post-processing.""" if self.cleaner: - # print( - # "DEPRECATION WARNING: " - # "--clean_instance_count is deprecated (but still applied to " - # "clean results *after* tracking). Use --target_instance_count " - # "and --pre_cull_to_target instead to cull instances *before* " - # "tracking." - # ) + # print( + # "DEPRECATION WARNING: " + # "--clean_instance_count is deprecated (but still applied to " + # "clean results *after* tracking). Use --target_instance_count " + # "and --pre_cull_to_target instead to cull instances *before* " + # "tracking." + # ) self.cleaner.run(frames) elif self.target_instance_count and self.post_connect_single_breaks: connect_single_track_breaks(frames, self.target_instance_count) @@ -647,16 +663,12 @@ def get_by_name_factory_options(cls): option = dict(name="clean_instance_count", default=0) option["type"] = int - option[ - "help" - ] = "Target number of instances to clean *after* tracking." + option["help"] = "Target number of instances to clean *after* tracking." options.append(option) option = dict(name="clean_iou_threshold", default=0) option["type"] = float - option[ - "help" - ] = "IOU to use when culling instances *after* tracking." + option["help"] = "IOU to use when culling instances *after* tracking." options.append(option) option = dict(name="similarity", default="instance") @@ -707,9 +719,7 @@ def int_list_func(s): option = dict(name="kf_node_indices", default="") option["type"] = int_list_func - option[ - "help" - ] = "For Kalman filter: Indices of nodes to track." + option["help"] = "For Kalman filter: Indices of nodes to track." options.append(option) option = dict(name="kf_init_frame_count", default="0") @@ -735,7 +745,9 @@ def add_cli_parser_args(cls, parser, arg_scope: str = ""): arg_name = arg["name"] parser.add_argument( - f"--{arg_name}", type=arg["type"], help=help_string, + f"--{arg_name}", + type=arg["type"], + help=help_string, ) diff --git a/sleap/nn/training.py b/sleap/nn/training.py index a04d5e494..4078f4779 100644 --- a/sleap/nn/training.py +++ b/sleap/nn/training.py @@ -5,6 +5,7 @@ from datetime import datetime from time import time import logging +import shutil import tensorflow as tf import numpy as np @@ -27,7 +28,8 @@ CentroidsHeadConfig, CenteredInstanceConfmapsHeadConfig, MultiInstanceConfig, - MultiClassConfig, + MultiClassBottomUpConfig, + MultiClassTopDownConfig, ) # Model @@ -43,9 +45,10 @@ TopdownConfmapsPipeline, BottomUpPipeline, BottomUpMultiClassPipeline, + TopDownMultiClassPipeline, KeyMapper, ) -from sleap.nn.data.training import split_labels +from sleap.nn.data.training import split_labels_train_val # Optimization from sleap.nn.config import OptimizationConfig @@ -118,6 +121,11 @@ def from_config( if test is None: test = labels_config.test_labels + if video_search_paths is None: + video_search_paths = [] + if labels_config.search_path_hints is not None: + video_search_paths.extend(labels_config.search_path_hints) + # Update the config fields with arguments (if not a full sleap.Labels instance). if update_config: if isinstance(training, Text): @@ -128,14 +136,16 @@ def from_config( labels_config.validation_fraction = validation if isinstance(test, Text): labels_config.test_labels = test + labels_config.search_path_hints = video_search_paths # Build class. - # TODO: use labels_config.search_path_hints for loading return cls.from_labels( training=training, validation=validation, test=test, video_search_paths=video_search_paths, + labels_config=labels_config, + update_config=update_config, ) @classmethod @@ -145,22 +155,72 @@ def from_labels( validation: Union[Text, sleap.Labels, float], test: Optional[Union[Text, sleap.Labels]] = None, video_search_paths: Optional[List[Text]] = None, + labels_config: Optional[LabelsConfig] = None, + update_config: bool = False, ) -> "DataReaders": """Create data readers from sleap.Labels datasets as data providers.""" - if isinstance(training, str): - print("video search paths: ", video_search_paths) + logger.info(f"Loading training labels from: {training}") training = sleap.Labels.load_file(training, video_search=video_search_paths) - print(training.videos) + + if labels_config is not None and labels_config.split_by_inds: + # First try to split by indices if specified in config. + if ( + labels_config.validation_inds is not None + and len(labels_config.validation_inds) > 0 + ): + logger.info( + "Creating validation split from explicit indices " + f"(n = {len(labels_config.validation_inds)})." + ) + validation = training[labels_config.validation_inds] + + if labels_config.test_inds is not None and len(labels_config.test_inds) > 0: + logger.info( + "Creating test split from explicit indices " + f"(n = {len(labels_config.test_inds)})." + ) + test = training[labels_config.test_inds] + + if ( + labels_config.training_inds is not None + and len(labels_config.training_inds) > 0 + ): + logger.info( + "Creating training split from explicit indices " + f"(n = {len(labels_config.training_inds)})." + ) + training = training[labels_config.training_inds] if isinstance(validation, str): + # If validation is still a path, load it. + logger.info(f"Loading validation labels from: {validation}") validation = sleap.Labels.load_file( validation, video_search=video_search_paths ) elif isinstance(validation, float): - training, validation = split_labels(training, [-1, validation]) + logger.info( + "Creating training and validation splits from " + f"validation fraction: {validation}" + ) + # If validation is still a float, create the split from training. + ( + training, + training_inds, + validation, + validation_inds, + ) = split_labels_train_val(training.with_user_labels_only(), validation) + logger.info( + f" Splits: Training = {len(training_inds)} /" + f" Validation = {len(validation_inds)}" + ) + if update_config and labels_config is not None: + labels_config.training_inds = training_inds + labels_config.validation_inds = validation_inds if isinstance(test, str): + # If test is still a path, load it. + logger.info(f"Loading test labels from: {test}") test = sleap.Labels.load_file(test, video_search=video_search_paths) test_reader = None @@ -209,7 +269,9 @@ def setup_optimizer(config: OptimizationConfig) -> tf.keras.optimizers.Optimizer return optimizer -def setup_losses(config: OptimizationConfig) -> Callable[[tf.Tensor], tf.Tensor]: +def setup_losses( + config: OptimizationConfig, +) -> Callable[[tf.Tensor, tf.Tensor], tf.Tensor]: """Set up model loss function from config.""" losses = [tf.keras.losses.MeanSquaredError()] @@ -440,14 +502,13 @@ def setup_visualization( ) callbacks.append( MatplotlibSaver( - save_folder=os.path.join(run_path, "viz"), plot_fn=viz_fn, prefix=name + save_folder=os.path.join(run_path, "viz"), + plot_fn=viz_fn, + prefix=name, ) ) - if ( - config.tensorboard.write_logs - and config.tensorboard.visualizations - ): + if config.tensorboard.write_logs and config.tensorboard.visualizations: try: matplotlib.use("Qt5Agg") except ImportError: @@ -549,7 +610,7 @@ def from_config( video_search_paths: Optional[List[Text]] = None, ) -> "Trainer": """Initialize the trainer from a training job configuration. - + Args: config: A `TrainingJobConfig` instance. training_labels: Training labels to use instead of the ones in the config, @@ -563,6 +624,9 @@ def from_config( # Copy input config before we make any changes. initial_config = copy.deepcopy(config) + # Store SLEAP version on the training process. + config.sleap_version = sleap.__version__ + # Create data readers and store loaded skeleton. data_readers = DataReaders.from_config( config.data.labels, @@ -590,8 +654,11 @@ def from_config( trainer_cls = TopdownConfmapsModelTrainer elif isinstance(head_config, MultiInstanceConfig): trainer_cls = BottomUpModelTrainer - elif isinstance(head_config, MultiClassConfig): + elif isinstance(head_config, MultiClassBottomUpConfig): trainer_cls = BottomUpMultiClassModelTrainer + elif isinstance(head_config, MultiClassTopDownConfig): + trainer_cls = TopDownMultiClassModelTrainer + pass elif isinstance(head_config, SingleInstanceConfmapsHeadConfig): trainer_cls = SingleInstanceModelTrainer else: @@ -656,7 +723,10 @@ def _setup_model(self): logger.info(f" Parameters: {self.model.keras_model.count_params():3,d}") logger.info(" Heads: ") for i, head in enumerate(self.model.heads): - logger.info(f" heads[{i}] = {head}") + logger.info(f" [{i}] = {head}") + logger.info(" Outputs: ") + for i, output in enumerate(self.model.keras_model.outputs): + logger.info(f" [{i}] = {output}") @property def keras_model(self) -> tf.keras.Model: @@ -688,14 +758,18 @@ def _setup_pipelines(self): ) + key_mapper ) - logger.info(f"Training set: n = {len(self.data_readers.training_labels)}") + logger.info( + f"Training set: n = {len(self.data_readers.training_labels_reader)}" + ) self.validation_pipeline = ( self.pipeline_builder.make_training_pipeline( self.data_readers.validation_labels_reader ) + key_mapper ) - logger.info(f"Validation set: n = {len(self.data_readers.validation_labels)}") + logger.info( + f"Validation set: n = {len(self.data_readers.validation_labels_reader)}" + ) def _setup_optimization(self): """Set up optimizer, loss functions and compile the model.""" @@ -704,7 +778,10 @@ def _setup_optimization(self): # TODO: Implement general part loss reporting. part_names = None - if isinstance(self.pipeline_builder, TopdownConfmapsPipeline) and self.pipeline_builder.offsets_head is None: + if ( + isinstance(self.pipeline_builder, TopdownConfmapsPipeline) + and self.pipeline_builder.offsets_head is None + ): part_names = [ sanitize_scope_name(name) for name in self.model.heads[0].part_names ] @@ -729,9 +806,14 @@ def _setup_optimization(self): def _setup_outputs(self): """Set up output-related functionality.""" if self.config.outputs.save_outputs: - # Build path to run folder. + # Build path to run folder. Timestamp will be added automatically. + # Example: 210204_041707.centroid.n=300 + model_type = self.config.model.heads.which_oneof_attrib_name() + n = len(self.data_readers.training_labels_reader) + len( + self.data_readers.validation_labels_reader + ) self.run_path = setup_new_run_folder( - self.config.outputs, base_run_name=type(self.model.backbone).__name__ + self.config.outputs, base_run_name=f"{model_type}.n={n}" ) # Setup output callbacks. @@ -829,30 +911,60 @@ def train(self): ) logger.info(f"Finished training loop. [{(time() - t0) / 60:.1f} min]") - # Save predictions and evaluations. + # Run post-training actions. if self.config.outputs.save_outputs: + if ( + self.config.outputs.save_visualizations + and self.config.outputs.delete_viz_images + ): + self.cleanup() + + self.evaluate() + + if self.config.outputs.zip_outputs: + self.package() + + def evaluate(self): + """Compute evaluation metrics on data splits and save them.""" + logger.info("Saving evaluation metrics to model folder...") + sleap.nn.evals.evaluate_model( + cfg=self.config, + labels_reader=self.data_readers.training_labels_reader, + model=self.model, + save=True, + split_name="train", + ) + sleap.nn.evals.evaluate_model( + cfg=self.config, + labels_reader=self.data_readers.validation_labels_reader, + model=self.model, + save=True, + split_name="val", + ) + if self.data_readers.test_labels_reader is not None: sleap.nn.evals.evaluate_model( cfg=self.config, - labels_reader=self.data_readers.training_labels_reader, - model=self.model, - save=True, - split_name="train", - ) - sleap.nn.evals.evaluate_model( - cfg=self.config, - labels_reader=self.data_readers.validation_labels_reader, + labels_reader=self.data_readers.test_labels_reader, model=self.model, save=True, - split_name="val", + split_name="test", ) - if self.data_readers.test_labels_reader is not None: - sleap.nn.evals.evaluate_model( - cfg=self.config, - labels_reader=self.data_readers.test_labels_reader, - model=self.model, - save=True, - split_name="test", - ) + + def cleanup(self): + """Delete visualization images subdirectory.""" + viz_path = os.path.join(self.run_path, "viz") + if os.path.exists(viz_path): + logger.info(f"Deleting visualization directory: {viz_path}") + shutil.rmtree(viz_path) + + def package(self): + """Package model folder into a zip file for portability.""" + if self.config.outputs.delete_viz_images: + self.cleanup() + logger.info(f"Packaging results to: {self.run_path}.zip") + shutil.make_archive( + base_name=self.run_path, root_dir=self.run_path, format="zip" + ) @attr.s(auto_attribs=True) @@ -899,7 +1011,7 @@ def _setup_pipeline_builder(self): data_config=self.config.data, optimization_config=self.config.optimization, single_instance_confmap_head=self.model.heads[0], - offsets_head=self.model.heads[1] if self.has_offsets else None + offsets_head=self.model.heads[1] if self.has_offsets else None, ) @property @@ -1003,7 +1115,7 @@ def _setup_pipeline_builder(self): data_config=self.config.data, optimization_config=self.config.optimization, centroid_confmap_head=self.model.heads[0], - offsets_head=self.model.heads[1] if self.has_offsets else None + offsets_head=self.model.heads[1] if self.has_offsets else None, ) @property @@ -1119,7 +1231,7 @@ def _setup_pipeline_builder(self): data_config=self.config.data, optimization_config=self.config.optimization, instance_confmap_head=self.model.heads[0], - offsets_head=self.model.heads[1] if self.has_offsets else None + offsets_head=self.model.heads[1] if self.has_offsets else None, ) @property @@ -1235,7 +1347,7 @@ def _setup_pipeline_builder(self): optimization_config=self.config.optimization, confmaps_head=self.model.heads[0], pafs_head=self.model.heads[1], - offsets_head=self.model.heads[2] if self.has_offsets else None + offsets_head=self.model.heads[2] if self.has_offsets else None, ) @property @@ -1344,7 +1456,7 @@ class BottomUpMultiClassModelTrainer(Trainer): @property def has_offsets(self) -> bool: """Whether model is configured to output refinement offsets.""" - return self.config.model.heads.multi_class.confmaps.offset_refinement + return self.config.model.heads.multi_class_bottomup.confmaps.offset_refinement def _update_config(self): """Update the configuration with inferred values.""" @@ -1376,7 +1488,7 @@ def _setup_pipeline_builder(self): optimization_config=self.config.optimization, confmaps_head=self.model.heads[0], class_maps_head=self.model.heads[1], - offsets_head=self.model.heads[2] if self.has_offsets else None + offsets_head=self.model.heads[2] if self.has_offsets else None, ) @property @@ -1466,12 +1578,151 @@ def visualize_class_maps_example(example): setup_visualization( self.config.outputs, run_path=self.run_path, - viz_fn=lambda: visualize_class_maps_example(next(validation_viz_ds_iter)), + viz_fn=lambda: visualize_class_maps_example( + next(validation_viz_ds_iter) + ), name=f"validation_class_maps", ) ) +@attr.s(auto_attribs=True) +class TopDownMultiClassModelTrainer(Trainer): + """Trainer for models that output multi-instance confidence maps and class maps.""" + + pipeline_builder: TopDownMultiClassPipeline = attr.ib(init=False) + + @property + def has_offsets(self) -> bool: + """Whether model is configured to output refinement offsets.""" + return self.config.model.heads.multi_class_topdown.confmaps.offset_refinement + + def _update_config(self): + """Update the configuration with inferred values.""" + if self.config.data.preprocessing.pad_to_stride is None: + self.config.data.preprocessing.pad_to_stride = self.model.maximum_stride + + if self.config.optimization.batches_per_epoch is None: + n_training_examples = len(self.data_readers.training_labels) + n_training_batches = ( + n_training_examples // self.config.optimization.batch_size + ) + self.config.optimization.batches_per_epoch = max( + self.config.optimization.min_batches_per_epoch, n_training_batches + ) + + if self.config.optimization.val_batches_per_epoch is None: + n_validation_examples = len(self.data_readers.validation_labels) + n_validation_batches = ( + n_validation_examples // self.config.optimization.batch_size + ) + self.config.optimization.val_batches_per_epoch = max( + self.config.optimization.min_val_batches_per_epoch, n_validation_batches + ) + + def _setup_pipeline_builder(self): + """Initialize pipeline builder.""" + self.pipeline_builder = TopDownMultiClassPipeline( + data_config=self.config.data, + optimization_config=self.config.optimization, + instance_confmap_head=self.model.heads[0], + class_vectors_head=self.model.heads[1], + offsets_head=self.model.heads[2] if self.has_offsets else None, + ) + + @property + def input_keys(self) -> List[Text]: + """Return example keys to be mapped to model inputs.""" + return ["instance_image"] + + @property + def output_keys(self) -> List[Text]: + """Return example keys to be mapped to model outputs.""" + output_keys = ["instance_confidence_maps", "class_vectors"] + if self.has_offsets: + output_keys.append("offsets") + return output_keys + + def _setup_optimization(self): + """Set up optimizer, loss functions and compile the model.""" + optimizer = setup_optimizer(self.config.optimization) + # loss_fn = setup_losses(self.config.optimization) + + # part_names = None + # metrics = setup_metrics(self.config.optimization, part_names=None) + metrics = {"ClassVectorsHead": "accuracy"} + + self.optimization_callbacks = setup_optimization_callbacks( + self.config.optimization + ) + + self.keras_model.compile( + optimizer=optimizer, + loss={ + output_name: head.loss_function + for output_name, head in zip( + self.keras_model.output_names, self.model.heads + ) + }, + metrics=metrics, + loss_weights={ + output_name: head.loss_weight + for output_name, head in zip( + self.keras_model.output_names, self.model.heads + ) + }, + ) + + def _setup_visualization(self): + """Set up visualization pipelines and callbacks.""" + # Create visualization/inference pipelines. + self.training_viz_pipeline = self.pipeline_builder.make_viz_pipeline( + self.data_readers.training_labels_reader, self.keras_model + ) + self.validation_viz_pipeline = self.pipeline_builder.make_viz_pipeline( + self.data_readers.validation_labels_reader, self.keras_model + ) + + # Create static iterators. + training_viz_ds_iter = iter(self.training_viz_pipeline.make_dataset()) + validation_viz_ds_iter = iter(self.validation_viz_pipeline.make_dataset()) + + def visualize_confmaps_example(example): + img = example["image"].numpy() + cms = example["predicted_confidence_maps"].numpy() + pts_gt = example["instances"].numpy() + pts_pr = example["predicted_peaks"].numpy() + + scale = 1.0 + if img.shape[0] < 512: + scale = 2.0 + if img.shape[0] < 256: + scale = 4.0 + fig = plot_img(img, dpi=72 * scale, scale=scale) + plot_confmaps(cms, output_scale=cms.shape[0] / img.shape[0]) + plt.xlim(plt.xlim()) + plt.ylim(plt.ylim()) + plot_peaks(pts_gt, pts_pr, paired=False) + return fig + + self.visualization_callbacks.extend( + setup_visualization( + self.config.outputs, + run_path=self.run_path, + viz_fn=lambda: visualize_confmaps_example(next(training_viz_ds_iter)), + name=f"train", + ) + ) + self.visualization_callbacks.extend( + setup_visualization( + self.config.outputs, + run_path=self.run_path, + viz_fn=lambda: visualize_confmaps_example(next(validation_viz_ds_iter)), + name=f"validation", + ) + ) + + def main(): """Create CLI for training and run.""" import argparse @@ -1480,35 +1731,63 @@ def main(): parser.add_argument( "training_job_path", help="Path to training job profile JSON file." ) - parser.add_argument("labels_path", help="Path to labels file to use for training.") + parser.add_argument( + "labels_path", + nargs="?", + default="", + help=( + "Path to labels file to use for training. If specified, overrides the path " + "specified in the training job config." + ), + ) parser.add_argument( "--video-paths", type=str, default="", - help="List of paths for finding videos in case paths inside labels file need fixing.", + help=( + "List of paths for finding videos in case paths inside labels file are " + "not accessible." + ), ) parser.add_argument( "--val_labels", "--val", - help="Path to labels file to use for validation (overrides training job path if set).", + help=( + "Path to labels file to use for validation. If specified, overrides the " + "path specified in the training job config." + ), ) parser.add_argument( "--test_labels", "--test", - help="Path to labels file to use for test (overrides training job path if set).", + help=( + "Path to labels file to use for test. If specified, overrides the path " + "specified in the training job config." + ), ) parser.add_argument( "--tensorboard", action="store_true", - help="Enables TensorBoard logging to the run path.", + help=( + "Enable TensorBoard logging to the run path if not already specified in " + "the training job config." + ), ) parser.add_argument( "--save_viz", action="store_true", - help="Enables saving of prediction visualizations to the run folder.", + help=( + "Enable saving of prediction visualizations to the run folder if not " + "already specified in the training job config." + ), ) parser.add_argument( - "--zmq", action="store_true", help="Enables ZMQ logging (for GUI)." + "--zmq", + action="store_true", + help=( + "Enable ZMQ logging (for GUI) if not already specified in the training " + "job config." + ), ) parser.add_argument( "--run_name", @@ -1535,16 +1814,21 @@ def main(): # Override config settings for CLI-based training. job_config.outputs.save_outputs = True - job_config.outputs.tensorboard.write_logs = args.tensorboard - job_config.outputs.zmq.publish_updates = args.zmq - job_config.outputs.zmq.subscribe_to_controller = args.zmq + job_config.outputs.tensorboard.write_logs |= args.tensorboard + job_config.outputs.zmq.publish_updates |= args.zmq + job_config.outputs.zmq.subscribe_to_controller |= args.zmq if args.run_name != "": job_config.outputs.run_name = args.run_name if args.prefix != "": job_config.outputs.run_name_prefix = args.prefix if args.suffix != "": job_config.outputs.run_name_suffix = args.suffix - job_config.outputs.save_visualizations = args.save_viz + job_config.outputs.save_visualizations |= args.save_viz + if args.labels_path == "": + args.labels_path = None + + logger.info("Versions:") + sleap.versions() logger.info(f"Training labels file: {args.labels_path}") logger.info(f"Training profile: {job_filename}") diff --git a/sleap/prefs.py b/sleap/prefs.py index d39c86a16..5269ccd5d 100644 --- a/sleap/prefs.py +++ b/sleap/prefs.py @@ -12,18 +12,18 @@ class Preferences(object): _prefs = None _defaults = { - "medium step size": 4, + "medium step size": 10, "large step size": 100, "color predicted": False, + "propagate track labels": True, "palette": "standard", "bold lines": False, "trail length": 0, "trail width": 4.0, "trail node count": 1, - "hide videos dock": False, - "hide skeleton dock": False, - "hide instances dock": False, - "hide labeling suggestions dock": False, + "marker size": 4, + "edge style": "Line", + "window state": b"", } _filename = "preferences.yaml" @@ -48,6 +48,11 @@ def save(self): """Save preferences to file.""" util.save_config_yaml(self._filename, self._prefs) + def reset_to_default(self): + """Reset preferences to default.""" + util.save_config_yaml(self._filename, self._defaults) + self.load() + def _validate_key(self, key): if key not in self._defaults: raise KeyError(f"No preference matching '{key}'") diff --git a/sleap/skeleton.py b/sleap/skeleton.py index 28e1a0cc7..1b9a1d740 100644 --- a/sleap/skeleton.py +++ b/sleap/skeleton.py @@ -14,6 +14,7 @@ import h5py import copy +import operator from enum import Enum from itertools import count from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Text @@ -22,12 +23,14 @@ from networkx.readwrite import json_graph from scipy.io import loadmat + NodeRef = Union[str, "Node"] H5FileRef = Union[str, h5py.File] class EdgeType(Enum): - """ + """Type of edge in the skeleton graph. + The skeleton graph can store different types of edges to represent different things. All edges must specify one or more of the following types: @@ -44,9 +47,13 @@ class EdgeType(Enum): @attr.s(auto_attribs=True, slots=True, eq=False, order=False) class Node: - """ - The class :class:`Node` represents a potential skeleton node. - (But note that nodes can exist without being part of a skeleton.) + """This class represents node in the skeleton graph, i.e., a body part. + + Note: Nodes can exist without being part of a skeleton. + + Attributes: + name: String name of the node. + weight: Weight of the node (not currently used). """ name: str @@ -92,8 +99,7 @@ class Skeleton: _skeleton_idx = count(0) def __init__(self, name: str = None): - """ - Initialize an empty skeleton object. + """Initialize an empty skeleton object. Skeleton objects, once created, can be modified by adding nodes and edges. @@ -101,7 +107,6 @@ def __init__(self, name: str = None): Args: name: A name for this skeleton. """ - # If no skeleton was create, try to create a unique name for this Skeleton. if name is None or not isinstance(name, str) or not name: name = "Skeleton-" + str(next(self._skeleton_idx)) @@ -114,23 +119,35 @@ def __init__(self, name: str = None): def __repr__(self) -> str: """Return full description of the skeleton.""" - return (f"Skeleton(name='{self.name}', " - f"nodes={self.node_names}, edges={self.edge_names})") + return ( + f"Skeleton(name='{self.name}', " + f"nodes={self.node_names}, " + f"edges={self.edge_names}, " + f"symmetries={self.symmetry_names}" + ")" + ) def __str__(self) -> str: """Return short readable description of the skeleton.""" - return f"Skeleton(nodes={len(self.nodes)}, edges={len(self.edges)})" + nodes = ", ".join(self.node_names) + edges = ", ".join([f"{s}->{d}" for (s, d) in self.edge_names]) + symm = ", ".join([f"{s}<->{d}" for (s, d) in self.symmetry_names]) + return ( + "Skeleton(" + f"nodes=[{nodes}], " + f"edges=[{edges}], " + f"symmetries=[{symm}]" + ")" + ) def matches(self, other: "Skeleton") -> bool: - """ - Compare this `Skeleton` to another, ignoring skeleton name and - the identities of the `Node` objects in each graph. + """Compare this `Skeleton` to another, ignoring name and node identities. Args: other: The other skeleton. Returns: - True if match, False otherwise. + `True` if the skeleton graphs are isomorphic and node names. """ def dict_match(dict1, dict2): @@ -144,17 +161,16 @@ def dict_match(dict1, dict2): if not is_isomorphic: return False - # Now check that the nodes have the same labels and order. They can have - # different weights I guess?! + # Now check that the nodes have the same labels and order. for node1, node2 in zip(self._graph.nodes, other._graph.nodes): if node1.name != node2.name: return False - # Check if the two graphs are equal return True @property def is_arborescence(self) -> bool: + """Return whether this skeleton graph forms an arborescence.""" return nx.algorithms.tree.recognition.is_arborescence(self._graph) @property @@ -171,7 +187,7 @@ def cycles(self) -> List[List[Node]]: @property def graph(self): - """Returns subgraph of BODY edges for skeleton.""" + """Return subgraph of BODY edges for skeleton.""" edges = [ (src, dst, key) for src, dst, key, edge_type in self._graph.edges(keys=True, data="type") @@ -184,7 +200,7 @@ def graph(self): @property def graph_symmetry(self): - """Returns subgraph of symmetric edges for skeleton.""" + """Return subgraph of symmetric edges for skeleton.""" edges = [ (src, dst, key) for src, dst, key, edge_type in self._graph.edges(keys=True, data="type") @@ -194,8 +210,7 @@ def graph_symmetry(self): @staticmethod def find_unique_nodes(skeletons: List["Skeleton"]) -> List[Node]: - """ - Find all unique nodes from a list of skeletons. + """Find all unique nodes from a list of skeletons. Args: skeletons: The list of skeletons. @@ -207,8 +222,7 @@ def find_unique_nodes(skeletons: List["Skeleton"]) -> List[Node]: @staticmethod def make_cattr(idx_to_node: Dict[int, Node] = None) -> cattr.Converter: - """ - Make cattr.Convert() for `Skeleton`. + """Make cattr.Convert() for `Skeleton`. Make a cattr.Converter() that registers structure/unstructure hooks for Skeleton objects to handle serialization of skeletons. @@ -246,7 +260,8 @@ def name(self) -> str: @name.setter def name(self, name: str): - """ + """Set skeleton name (no-op). + A skeleton object cannot change its name. This property is immutable because it is used to hash skeletons. @@ -271,8 +286,7 @@ def name(self, name: str): @classmethod def rename_skeleton(cls, skeleton: "Skeleton", name: str) -> "Skeleton": - """ - Make copy of skeleton with new name. + """Make copy of skeleton with new name. This property is immutable because it is used to hash skeletons. If you want to rename a Skeleton you must use this class method. @@ -358,7 +372,6 @@ def edge_inds(self) -> List[Tuple[int, int]]: A list of (src_node_ind, dst_node_ind), where indices are subscripts into the Skeleton.nodes list. """ - return [ (self.nodes.index(src_node), self.nodes.index(dst_node)) for src_node, dst_node in self.edges @@ -391,9 +404,16 @@ def symmetries(self) -> List[Tuple[Node, Node]]: if edge_type == EdgeType.SYMMETRY ] # Get rid of duplicates - symmetries = list(set([tuple(set(e)) for e in symmetries])) + symmetries = list( + set([tuple(sorted(e, key=operator.attrgetter("name"))) for e in symmetries]) + ) return symmetries + @property + def symmetry_names(self) -> List[Tuple[str, str]]: + """List of symmetry edges as tuples of node names.""" + return [(s.name, d.name) for (s, d) in self.symmetries] + @property def symmetries_full(self) -> List[Tuple[Node, Node, Any, Any]]: """Get a list of all symmetries with keys and attributes. @@ -411,9 +431,18 @@ def symmetries_full(self) -> List[Tuple[Node, Node, Any, Any]]: if attr["type"] == EdgeType.SYMMETRY ] + @property + def symmetric_inds(self) -> np.ndarray: + """Return the symmetric nodes as an array of indices.""" + return np.array( + [ + [self.nodes.index(node1), self.nodes.index(node2)] + for node1, node2 in self.symmetries + ] + ) + def node_to_index(self, node: NodeRef) -> int: - """ - Return the index of the node, accepts either `Node` or name. + """Return the index of the node, accepts either `Node` or name. Args: node: The name of the node or the Node object. @@ -430,8 +459,8 @@ def node_to_index(self, node: NodeRef) -> int: except ValueError: return node_list.index(self.find_node(node)) - def edge_to_index(self, source: NodeRef, destination: NodeRef): - """Returns the index of edge from source to destination.""" + def edge_to_index(self, source: NodeRef, destination: NodeRef) -> int: + """Return the index of edge from source to destination.""" source = self.find_node(source) destination = self.find_node(destination) edge = (source, destination) @@ -449,9 +478,6 @@ def add_node(self, name: str): Raises: ValueError: If name is not unique. - - Returns: - None """ if not isinstance(name, str): raise TypeError("Cannot add nodes to the skeleton that are not str") @@ -462,14 +488,10 @@ def add_node(self, name: str): self._graph.add_node(Node(name)) def add_nodes(self, name_list: List[str]): - """ - Add a list of nodes representing animal parts to the skeleton. + """Add a list of nodes representing animal parts to the skeleton. Args: name_list: List of strings representing the nodes. - - Returns: - None """ for node in name_list: self.add_node(node) @@ -617,8 +639,7 @@ def delete_edge(self, source: str, destination: str): self._graph.remove_edge(source_node, destination_node) def clear_edges(self): - """Deletes all edges in skeleton.""" - + """Delete all edges in skeleton.""" for src, dst in self.edges: self.delete_edge(src, dst) @@ -642,8 +663,9 @@ def add_symmetry(self, node1: str, node2: str): """ node1_node, node2_node = self.find_node(node1), self.find_node(node2) - # We will represent symmetric pairs in the skeleton via additional edges in the _graph - # These edges will have a special attribute signifying they are not part of the skeleton itself + # We will represent symmetric pairs in the skeleton via additional edges in the + # _graph. These edges will have a special attribute signifying they are not part + # of the skeleton itself if node1 == node2: raise ValueError("Cannot add symmetry to the same node.") @@ -662,8 +684,7 @@ def add_symmetry(self, node1: str, node2: str): self._graph.add_edge(node2_node, node1_node, type=EdgeType.SYMMETRY) def delete_symmetry(self, node1: NodeRef, node2: NodeRef): - """ - Deletes a previously established symmetry between two nodes. + """Delete a previously established symmetry between two nodes. Args: node1: One node (by `Node` object or name) in symmetric pair. @@ -694,8 +715,7 @@ def delete_symmetry(self, node1: NodeRef, node2: NodeRef): self._graph.remove_edges_from(edges) def get_symmetry(self, node: NodeRef) -> Optional[Node]: - """ - Returns the node symmetric with the specified node. + """Return the node symmetric with the specified node. Args: node: Node (by `Node` object or name) to query. @@ -722,8 +742,7 @@ def get_symmetry(self, node: NodeRef) -> Optional[Node]: raise ValueError(f"{node} has more than one symmetry.") def get_symmetry_name(self, node: NodeRef) -> Optional[str]: - """ - Returns the name of the node symmetric with the specified node. + """Return the name of the node symmetric with the specified node. Args: node: Node (by `Node` object or name) to query. @@ -735,8 +754,7 @@ def get_symmetry_name(self, node: NodeRef) -> Optional[str]: return None if symmetric_node is None else symmetric_node.name def __getitem__(self, node_name: str) -> dict: - """ - Retrieves the node data associated with skeleton node. + """Retrieve the node data associated with skeleton node. Args: node_name: The name from which to retrieve data. @@ -755,8 +773,7 @@ def __getitem__(self, node_name: str) -> dict: return self._graph.nodes.data()[node] def __contains__(self, node_name: str) -> bool: - """ - Checks if specified node exists in skeleton. + """Check if specified node exists in skeleton. Args: node_name: the node name to query @@ -766,6 +783,10 @@ def __contains__(self, node_name: str) -> bool: """ return self.has_node(node_name) + def __len__(self) -> int: + """Return the number of nodes in the skeleton.""" + return len(self.nodes) + def relabel_node(self, old_name: str, new_name: str): """ Relabel a single node to a new name. diff --git a/sleap/training_profiles/baseline.centroid.json b/sleap/training_profiles/baseline.centroid.json index 1393b8f7f..d46a3694f 100755 --- a/sleap/training_profiles/baseline.centroid.json +++ b/sleap/training_profiles/baseline.centroid.json @@ -72,7 +72,9 @@ "contrast_max_gamma": 2.0, "brightness": false, "brightness_min_val": 0.0, - "brightness_max_val": 10.0 + "brightness_max_val": 10.0, + "random_flip": false, + "flip_horizontal": true }, "online_shuffling": true, "shuffle_buffer_size": 128, diff --git a/sleap/training_profiles/baseline_large_rf.bottomup.json b/sleap/training_profiles/baseline_large_rf.bottomup.json index f54063945..eef85ec56 100644 --- a/sleap/training_profiles/baseline_large_rf.bottomup.json +++ b/sleap/training_profiles/baseline_large_rf.bottomup.json @@ -81,7 +81,9 @@ "contrast_max_gamma": 2.0, "brightness": false, "brightness_min_val": 0.0, - "brightness_max_val": 10.0 + "brightness_max_val": 10.0, + "random_flip": false, + "flip_horizontal": true }, "online_shuffling": true, "shuffle_buffer_size": 128, diff --git a/sleap/training_profiles/baseline_large_rf.single.json b/sleap/training_profiles/baseline_large_rf.single.json index 222d60558..f35f81be2 100644 --- a/sleap/training_profiles/baseline_large_rf.single.json +++ b/sleap/training_profiles/baseline_large_rf.single.json @@ -72,7 +72,9 @@ "contrast_max_gamma": 2.0, "brightness": false, "brightness_min_val": 0.0, - "brightness_max_val": 10.0 + "brightness_max_val": 10.0, + "random_flip": false, + "flip_horizontal": true }, "online_shuffling": true, "shuffle_buffer_size": 128, diff --git a/sleap/training_profiles/baseline_large_rf.topdown.json b/sleap/training_profiles/baseline_large_rf.topdown.json index 9c75a09c7..1156b64a0 100644 --- a/sleap/training_profiles/baseline_large_rf.topdown.json +++ b/sleap/training_profiles/baseline_large_rf.topdown.json @@ -73,7 +73,9 @@ "contrast_max_gamma": 2.0, "brightness": false, "brightness_min_val": 0.0, - "brightness_max_val": 10.0 + "brightness_max_val": 10.0, + "random_flip": false, + "flip_horizontal": true }, "online_shuffling": true, "shuffle_buffer_size": 128, diff --git a/sleap/training_profiles/baseline_medium_rf.bottomup.json b/sleap/training_profiles/baseline_medium_rf.bottomup.json index ae53a9481..c75e54e7e 100644 --- a/sleap/training_profiles/baseline_medium_rf.bottomup.json +++ b/sleap/training_profiles/baseline_medium_rf.bottomup.json @@ -81,7 +81,9 @@ "contrast_max_gamma": 2.0, "brightness": false, "brightness_min_val": 0.0, - "brightness_max_val": 10.0 + "brightness_max_val": 10.0, + "random_flip": false, + "flip_horizontal": true }, "online_shuffling": true, "shuffle_buffer_size": 128, diff --git a/sleap/training_profiles/baseline_medium_rf.single.json b/sleap/training_profiles/baseline_medium_rf.single.json index b01000c0f..152fdbb9a 100644 --- a/sleap/training_profiles/baseline_medium_rf.single.json +++ b/sleap/training_profiles/baseline_medium_rf.single.json @@ -72,7 +72,9 @@ "contrast_max_gamma": 2.0, "brightness": false, "brightness_min_val": 0.0, - "brightness_max_val": 10.0 + "brightness_max_val": 10.0, + "random_flip": false, + "flip_horizontal": true }, "online_shuffling": true, "shuffle_buffer_size": 128, diff --git a/sleap/training_profiles/baseline_medium_rf.topdown.json b/sleap/training_profiles/baseline_medium_rf.topdown.json index fb54f6cf3..0cdfe1cca 100755 --- a/sleap/training_profiles/baseline_medium_rf.topdown.json +++ b/sleap/training_profiles/baseline_medium_rf.topdown.json @@ -73,7 +73,9 @@ "contrast_max_gamma": 2.0, "brightness": false, "brightness_min_val": 0.0, - "brightness_max_val": 10.0 + "brightness_max_val": 10.0, + "random_flip": false, + "flip_horizontal": true }, "online_shuffling": true, "shuffle_buffer_size": 128, diff --git a/sleap/training_profiles/pretrained.bottomup.json b/sleap/training_profiles/pretrained.bottomup.json index 22e2abf71..3b0e20112 100644 --- a/sleap/training_profiles/pretrained.bottomup.json +++ b/sleap/training_profiles/pretrained.bottomup.json @@ -78,7 +78,9 @@ "contrast_max_gamma": 2.0, "brightness": false, "brightness_min_val": 0.0, - "brightness_max_val": 10.0 + "brightness_max_val": 10.0, + "random_flip": false, + "flip_horizontal": true }, "online_shuffling": true, "shuffle_buffer_size": 128, diff --git a/sleap/training_profiles/pretrained.centroid.json b/sleap/training_profiles/pretrained.centroid.json index 45a2d9ea5..a535688e6 100644 --- a/sleap/training_profiles/pretrained.centroid.json +++ b/sleap/training_profiles/pretrained.centroid.json @@ -69,7 +69,9 @@ "contrast_max_gamma": 2.0, "brightness": false, "brightness_min_val": 0.0, - "brightness_max_val": 10.0 + "brightness_max_val": 10.0, + "random_flip": false, + "flip_horizontal": true }, "online_shuffling": true, "shuffle_buffer_size": 128, diff --git a/sleap/training_profiles/pretrained.single.json b/sleap/training_profiles/pretrained.single.json index 572b84ccd..1dfb8453f 100644 --- a/sleap/training_profiles/pretrained.single.json +++ b/sleap/training_profiles/pretrained.single.json @@ -69,7 +69,9 @@ "contrast_max_gamma": 2.0, "brightness": false, "brightness_min_val": 0.0, - "brightness_max_val": 10.0 + "brightness_max_val": 10.0, + "random_flip": false, + "flip_horizontal": true }, "online_shuffling": true, "shuffle_buffer_size": 128, diff --git a/sleap/training_profiles/pretrained.topdown.json b/sleap/training_profiles/pretrained.topdown.json index c49e06a63..25cf1c54b 100644 --- a/sleap/training_profiles/pretrained.topdown.json +++ b/sleap/training_profiles/pretrained.topdown.json @@ -70,7 +70,9 @@ "contrast_max_gamma": 2.0, "brightness": false, "brightness_min_val": 0.0, - "brightness_max_val": 10.0 + "brightness_max_val": 10.0, + "random_flip": false, + "flip_horizontal": true }, "online_shuffling": true, "shuffle_buffer_size": 128, diff --git a/sleap/util.py b/sleap/util.py index 3b70b6395..0cc47328f 100644 --- a/sleap/util.py +++ b/sleap/util.py @@ -310,12 +310,13 @@ def get_config_file( def get_config_yaml(shortname: str, get_defaults: bool = False) -> dict: config_path = get_config_file(shortname, get_defaults=get_defaults) with open(config_path, "r") as f: - return yaml.load(f, Loader=yaml.SafeLoader) + return yaml.load(f, Loader=yaml.Loader) def save_config_yaml(shortname: str, data: Any) -> dict: yaml_path = get_config_file(shortname, ignore_file_not_found=True) with open(yaml_path, "w") as f: + print(f"Saving config: {yaml_path}") yaml.dump(data, f) @@ -379,16 +380,3 @@ def find_files_by_suffix( ) return matching_files - - -def open_file(filename): - """ - Opens file (as if double-clicked by user). - - https://stackoverflow.com/questions/17317219/is-there-an-platform-independent-equivalent-of-os-startfile/17317468#17317468 - """ - if sys.platform == "win32": - os.startfile(filename) - else: - opener = "open" if sys.platform == "darwin" else "xdg-open" - subprocess.call([opener, filename]) diff --git a/sleap/version.py b/sleap/version.py index 108d51f53..96c547661 100644 --- a/sleap/version.py +++ b/sleap/version.py @@ -10,4 +10,23 @@ Must be a semver string, "aN" should be appended for alpha releases. """ -__version__ = "1.1.0a9" + + +__version__ = "1.1.0a10" + + +def versions(): + """Print versions of SLEAP and other libraries.""" + import tensorflow as tf + import numpy as np + import platform + + vers = {} + vers["SLEAP"] = __version__ + vers["TensorFlow"] = tf.__version__ + vers["Numpy"] = np.__version__ + vers["Python"] = platform.python_version() + vers["OS"] = platform.platform() + + msg = "\n".join([f"{k}: {v}" for k, v in vers.items()]) + print(msg) diff --git a/tests/data/models/min_tracks_2node.UNet.bottomup_multiclass/initial_config.json b/tests/data/models/min_tracks_2node.UNet.bottomup_multiclass/initial_config.json index d2c7f53b6..7e52d1703 100644 --- a/tests/data/models/min_tracks_2node.UNet.bottomup_multiclass/initial_config.json +++ b/tests/data/models/min_tracks_2node.UNet.bottomup_multiclass/initial_config.json @@ -43,7 +43,7 @@ "centroid": null, "centered_instance": null, "multi_instance": null, - "multi_class": { + "multi_class_bottomup": { "confmaps": { "part_names": null, "sigma": 5.0, diff --git a/tests/data/models/min_tracks_2node.UNet.bottomup_multiclass/training_config.json b/tests/data/models/min_tracks_2node.UNet.bottomup_multiclass/training_config.json index 3ff85fc92..bcb2f26d5 100644 --- a/tests/data/models/min_tracks_2node.UNet.bottomup_multiclass/training_config.json +++ b/tests/data/models/min_tracks_2node.UNet.bottomup_multiclass/training_config.json @@ -100,7 +100,7 @@ "centroid": null, "centered_instance": null, "multi_instance": null, - "multi_class": { + "multi_class_bottomup": { "confmaps": { "part_names": [ "head", diff --git a/tests/fixtures/models.py b/tests/fixtures/models.py index 8c146ff79..1a30b7a4e 100644 --- a/tests/fixtures/models.py +++ b/tests/fixtures/models.py @@ -16,10 +16,12 @@ def min_centered_instance_model_path(): def min_bottomup_model_path(): return "tests/data/models/minimal_instance.UNet.bottomup" + @pytest.fixture def min_single_instance_robot_model_path(): return "tests/data/models/minimal_robot.UNet.single_instance" + @pytest.fixture def min_bottomup_multiclass_model_path(): return "tests/data/models/min_tracks_2node.UNet.bottomup_multiclass" diff --git a/tests/gui/test_commands.py b/tests/gui/test_commands.py index 2efae2055..6528e6677 100644 --- a/tests/gui/test_commands.py +++ b/tests/gui/test_commands.py @@ -1,5 +1,10 @@ -from sleap.gui.commands import CommandContext, ImportDeepLabCutFolder +from sleap.gui.commands import ( + CommandContext, + ImportDeepLabCutFolder, + get_new_version_filename, +) from sleap.io.pathutils import fix_path_separator +from pathlib import PurePath def test_delete_user_dialog(centered_pair_predictions): @@ -19,10 +24,12 @@ def test_delete_user_dialog(centered_pair_predictions): def test_import_labels_from_dlc_folder(): - csv_files = ImportDeepLabCutFolder.find_dlc_files_in_folder('tests/data/dlc_multiple_datasets') + csv_files = ImportDeepLabCutFolder.find_dlc_files_in_folder( + "tests/data/dlc_multiple_datasets" + ) assert set([fix_path_separator(f) for f in csv_files]) == { - 'tests/data/dlc_multiple_datasets/video2/dlc_dataset_2.csv', - 'tests/data/dlc_multiple_datasets/video1/dlc_dataset_1.csv', + "tests/data/dlc_multiple_datasets/video2/dlc_dataset_2.csv", + "tests/data/dlc_multiple_datasets/video1/dlc_dataset_1.csv", } labels = ImportDeepLabCutFolder.import_labels_from_dlc_files(csv_files) @@ -33,10 +40,26 @@ def test_import_labels_from_dlc_folder(): assert len(labels.nodes) == 3 assert len(labels.tracks) == 0 - assert set([fix_path_separator(l.video.backend.filename) for l in labels.labeled_frames]) == { - 'tests/data/dlc_multiple_datasets/video2/img002.jpg', - 'tests/data/dlc_multiple_datasets/video1/img000.jpg', - 'tests/data/dlc_multiple_datasets/video1/img000.jpg', + assert set( + [fix_path_separator(l.video.backend.filename) for l in labels.labeled_frames] + ) == { + "tests/data/dlc_multiple_datasets/video2/img002.jpg", + "tests/data/dlc_multiple_datasets/video1/img000.jpg", + "tests/data/dlc_multiple_datasets/video1/img000.jpg", } - assert set([l.frame_idx for l in labels.labeled_frames]) == {0, 0, 1} \ No newline at end of file + assert set([l.frame_idx for l in labels.labeled_frames]) == {0, 0, 1} + + +def test_get_new_version_filename(): + assert get_new_version_filename("labels.slp") == "labels copy.slp" + assert get_new_version_filename("labels.v0.slp") == "labels.v1.slp" + assert get_new_version_filename("/a/b/labels.slp") == str( + PurePath("/a/b/labels copy.slp") + ) + assert get_new_version_filename("/a/b/labels.v0.slp") == str( + PurePath("/a/b/labels.v1.slp") + ) + assert get_new_version_filename("/a/b/labels.v01.slp") == str( + PurePath("/a/b/labels.v02.slp") + ) diff --git a/tests/gui/test_inference_gui.py b/tests/gui/test_inference_gui.py index 74039595a..0a58c63e7 100644 --- a/tests/gui/test_inference_gui.py +++ b/tests/gui/test_inference_gui.py @@ -22,7 +22,10 @@ def test_config_list_order(): # Check that all 'old' configs (if any) are last in the collected configs list for i in range(len(configs) - 1): # if current config is 'old', next must be 'old' as well - assert not configs[i].filename.startswith('old.') or configs[i + 1].filename.startswith('old.') + assert not configs[i].filename.startswith("old.") or configs[ + i + 1 + ].filename.startswith("old.") + def test_scoped_key_dict(): d = {"foo": 1, "bar": {"cat": {"dog": 2}, "elephant": 3}} @@ -42,7 +45,8 @@ def test_inference_cli_builder(): ) item_for_inference = runners.VideoItemForInference( - video=Video.from_filename("video.mp4"), frames=[1, 2, 3], + video=Video.from_filename("video.mp4"), + frames=[1, 2, 3], ) cli_args, output_path = inference_task.make_predict_cli_call(item_for_inference) @@ -60,16 +64,19 @@ def test_inference_cli_builder(): def test_inference_cli_output_path(): inference_task = runners.InferenceTask( - trained_job_paths=["model1", "model2"], inference_params=dict(), + trained_job_paths=["model1", "model2"], + inference_params=dict(), ) item_for_inference = runners.VideoItemForInference( - video=Video.from_filename("video.mp4"), frames=[1, 2, 3], + video=Video.from_filename("video.mp4"), + frames=[1, 2, 3], ) # Try with specified output path cli_args, output_path = inference_task.make_predict_cli_call( - item_for_inference, output_path="another_output_path.slp", + item_for_inference, + output_path="another_output_path.slp", ) assert output_path == "another_output_path.slp" diff --git a/tests/gui/test_video_player.py b/tests/gui/test_video_player.py index cf16f6934..28bcaae49 100644 --- a/tests/gui/test_video_player.py +++ b/tests/gui/test_video_player.py @@ -1,6 +1,11 @@ import numpy as np from sleap import Instance, Skeleton -from sleap.gui.widgets.video import QtVideoPlayer, GraphicsView, QtInstance, QtVideoPlayer +from sleap.gui.widgets.video import ( + QtVideoPlayer, + GraphicsView, + QtInstance, + QtVideoPlayer, +) import PySide2.QtCore as QtCore diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index aab1875d7..6ec1e5568 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -513,7 +513,9 @@ def test_merge_with_package(min_labels_robot, tmpdir): labels_pkg = sleap.load_file(pkg_path) assert isinstance(labels_pkg.video.backend, sleap.io.video.HDF5Video) assert labels_pkg.video.backend.has_embedded_images - assert isinstance(labels_pkg.video.backend._source_video.backend, sleap.io.video.MediaVideo) + assert isinstance( + labels_pkg.video.backend._source_video.backend, sleap.io.video.MediaVideo + ) assert len(labels_pkg.predicted_instances) == 0 # Add prediction. @@ -521,10 +523,12 @@ def test_merge_with_package(min_labels_robot, tmpdir): inst_pr = sleap.PredictedInstance.from_pointsarray( inst.numpy(), skeleton=labels_pkg.skeleton ) - labels_pkg.append(sleap.LabeledFrame( - video=labels_pkg.suggestions[0].video, - frame_idx=labels_pkg.suggestions[0].frame_idx, - instances=[inst_pr]) + labels_pkg.append( + sleap.LabeledFrame( + video=labels_pkg.suggestions[0].video, + frame_idx=labels_pkg.suggestions[0].frame_idx, + instances=[inst_pr], + ) ) # Save labels without image data. @@ -537,7 +541,9 @@ def test_merge_with_package(min_labels_robot, tmpdir): # Merge with base labels. base_video_path = labels.video.backend.filename - merged, extra_base, extra_new = sleap.Labels.complex_merge_between(labels, labels_pr) + merged, extra_base, extra_new = sleap.Labels.complex_merge_between( + labels, labels_pr + ) assert len(labels.videos) == 1 assert labels.video.backend.filename == base_video_path assert len(labels.predicted_instances) == 1 @@ -550,7 +556,9 @@ def test_merge_with_package(min_labels_robot, tmpdir): labels_pr = sleap.load_file(preds_path) assert len(labels_pkg.predicted_instances) == 0 base_video_path = labels_pkg.video.backend.filename - merged, extra_base, extra_new = sleap.Labels.complex_merge_between(labels_pkg, labels_pr) + merged, extra_base, extra_new = sleap.Labels.complex_merge_between( + labels_pkg, labels_pr + ) assert len(labels_pkg.videos) == 1 assert labels_pkg.video.backend.filename == base_video_path assert len(labels_pkg.predicted_instances) == 1 @@ -771,25 +779,41 @@ def test_save_frame_data_hdf5(min_labels_slp, tmpdir): fn = os.path.join(tmpdir, "test_user_only.slp") labels.save_frame_data_hdf5( - fn, format="png", user_labeled=True, all_labeled=False, suggested=False, + fn, + format="png", + user_labeled=True, + all_labeled=False, + suggested=False, ) assert Video.from_filename(fn, dataset="video0").embedded_frame_inds == [0] fn = os.path.join(tmpdir, "test_all_labeled.slp") labels.save_frame_data_hdf5( - fn, format="png", user_labeled=False, all_labeled=True, suggested=False, + fn, + format="png", + user_labeled=False, + all_labeled=True, + suggested=False, ) assert Video.from_filename(fn, dataset="video0").embedded_frame_inds == [0, 1] fn = os.path.join(tmpdir, "test_suggested.slp") labels.save_frame_data_hdf5( - fn, format="png", user_labeled=False, all_labeled=False, suggested=True, + fn, + format="png", + user_labeled=False, + all_labeled=False, + suggested=True, ) assert Video.from_filename(fn, dataset="video0").embedded_frame_inds == [2] fn = os.path.join(tmpdir, "test_all.slp") labels.save_frame_data_hdf5( - fn, format="png", user_labeled=False, all_labeled=True, suggested=True, + fn, + format="png", + user_labeled=False, + all_labeled=True, + suggested=True, ) assert Video.from_filename(fn, dataset="video0").embedded_frame_inds == [0, 1, 2] @@ -801,25 +825,37 @@ def test_save_labels_with_images(min_labels_slp, tmpdir): fn = os.path.join(tmpdir, "test_user_only.slp") labels.save( - fn, with_images=True, embed_all_labeled=False, embed_suggested=False, + fn, + with_images=True, + embed_all_labeled=False, + embed_suggested=False, ) assert Labels.load_file(fn).video.embedded_frame_inds == [0] fn = os.path.join(tmpdir, "test_all_labeled.slp") labels.save( - fn, with_images=True, embed_all_labeled=True, embed_suggested=False, + fn, + with_images=True, + embed_all_labeled=True, + embed_suggested=False, ) assert Labels.load_file(fn).video.embedded_frame_inds == [0, 1] fn = os.path.join(tmpdir, "test_suggested.slp") labels.save( - fn, with_images=True, embed_all_labeled=False, embed_suggested=True, + fn, + with_images=True, + embed_all_labeled=False, + embed_suggested=True, ) assert Labels.load_file(fn).video.embedded_frame_inds == [0, 2] fn = os.path.join(tmpdir, "test_all.slp") labels.save( - fn, with_images=True, embed_all_labeled=True, embed_suggested=True, + fn, + with_images=True, + embed_all_labeled=True, + embed_suggested=True, ) assert Labels.load_file(fn).video.embedded_frame_inds == [0, 1, 2] diff --git a/tests/io/test_formats.py b/tests/io/test_formats.py index 4793f54ba..137287516 100644 --- a/tests/io/test_formats.py +++ b/tests/io/test_formats.py @@ -119,7 +119,12 @@ def test_analysis_hdf5(tmpdir, centered_pair_predictions): write_analysis(centered_pair_predictions, output_path=filename, all_frames=True) - labels = read(filename, for_object="labels", as_format="analysis", video=video,) + labels = read( + filename, + for_object="labels", + as_format="analysis", + video=video, + ) assert len(labels) == len(centered_pair_predictions) assert len(labels.tracks) == len(centered_pair_predictions.tracks) diff --git a/tests/io/test_video.py b/tests/io/test_video.py index dde703661..767ae60ce 100644 --- a/tests/io/test_video.py +++ b/tests/io/test_video.py @@ -4,12 +4,13 @@ import numpy as np -from sleap.io.video import Video, HDF5Video, MediaVideo, DummyVideo +from sleap.io.video import Video, HDF5Video, MediaVideo, DummyVideo, load_video from tests.fixtures.videos import ( TEST_H5_FILE, TEST_SMALL_ROBOT_MP4_FILE, TEST_H5_DSET, TEST_H5_INPUT_FORMAT, + TEST_SMALL_CENTERED_PAIR_VID, ) # FIXME: @@ -117,7 +118,7 @@ def test_is_missing(): assert vid.is_missing vid = Video.from_numpy( Video.from_media(TEST_SMALL_ROBOT_MP4_FILE).get_frames((3, 7, 9)) - ) + ) assert not vid.is_missing @@ -399,3 +400,8 @@ def test_safe_frame_loading_all_invalid(): assert idxs == [] assert frames is None + +def test_load_video(): + video = load_video(TEST_SMALL_CENTERED_PAIR_VID) + assert video.shape == (1100, 384, 384, 1) + assert video[:3].shape == (3, 384, 384, 1) diff --git a/tests/io/test_visuals.py b/tests/io/test_visuals.py index cf8262f8a..c09f0747c 100644 --- a/tests/io/test_visuals.py +++ b/tests/io/test_visuals.py @@ -41,7 +41,8 @@ def test_serial_pipeline(centered_pair_predictions, tmpdir): # Make sure we can mark images marked_image_list = marker_thread._mark_images( - frame_indices=frames, frame_images=small_images, + frame_indices=frames, + frame_images=small_images, ) # There's a point at 201, 186 (i.e. 50.25, 46.5), so make sure it got marked diff --git a/tests/nn/architectures/test_common.py b/tests/nn/architectures/test_common.py index c764a84f3..a40d621ef 100644 --- a/tests/nn/architectures/test_common.py +++ b/tests/nn/architectures/test_common.py @@ -1,6 +1,8 @@ import numpy as np import tensorflow as tf -from sleap.nn.system import use_cpu_only; use_cpu_only() # hide GPUs for test +from sleap.nn.system import use_cpu_only + +use_cpu_only() # hide GPUs for test from sleap.nn.architectures import common diff --git a/tests/nn/architectures/test_encoder_decoder.py b/tests/nn/architectures/test_encoder_decoder.py index 7373ae50c..3ce019371 100644 --- a/tests/nn/architectures/test_encoder_decoder.py +++ b/tests/nn/architectures/test_encoder_decoder.py @@ -1,6 +1,8 @@ import numpy as np import tensorflow as tf -from sleap.nn.system import use_cpu_only; use_cpu_only() # hide GPUs for test +from sleap.nn.system import use_cpu_only + +use_cpu_only() # hide GPUs for test from sleap.nn.architectures import encoder_decoder @@ -122,9 +124,9 @@ def test_simple_conv_block_pool_before_convs(self): def test_simple_upsampling_block(self): block = encoder_decoder.SimpleUpsamplingBlock( upsampling_stride=2, - transposed_conv = False, - interp_method = "bilinear", - refine_convs = 0, + transposed_conv=False, + interp_method="bilinear", + refine_convs=0, ) x_in = tf.keras.Input((8, 8, 1)) x = block.make_block(x_in) @@ -139,14 +141,14 @@ def test_simple_upsampling_block(self): def test_simple_upsampling_block_trans_conv(self): block = encoder_decoder.SimpleUpsamplingBlock( upsampling_stride=2, - transposed_conv = True, - transposed_conv_filters = 8, - transposed_conv_kernel_size = 3, - transposed_conv_use_bias = True, - transposed_conv_batch_norm = True, - transposed_conv_batch_norm_before_activation = True, - transposed_conv_activation = "relu", - refine_convs = 0, + transposed_conv=True, + transposed_conv_filters=8, + transposed_conv_kernel_size=3, + transposed_conv_use_bias=True, + transposed_conv_batch_norm=True, + transposed_conv_batch_norm_before_activation=True, + transposed_conv_activation="relu", + refine_convs=0, ) x_in = tf.keras.Input((8, 8, 1)) x = block.make_block(x_in) @@ -163,14 +165,14 @@ def test_simple_upsampling_block_trans_conv(self): def test_simple_upsampling_block_trans_conv_bn_post(self): block = encoder_decoder.SimpleUpsamplingBlock( upsampling_stride=2, - transposed_conv = True, - transposed_conv_filters = 8, - transposed_conv_kernel_size = 3, - transposed_conv_use_bias = True, - transposed_conv_batch_norm = True, - transposed_conv_batch_norm_before_activation = False, - transposed_conv_activation = "relu", - refine_convs = 0, + transposed_conv=True, + transposed_conv_filters=8, + transposed_conv_kernel_size=3, + transposed_conv_use_bias=True, + transposed_conv_batch_norm=True, + transposed_conv_batch_norm_before_activation=False, + transposed_conv_activation="relu", + refine_convs=0, ) x_in = tf.keras.Input((8, 8, 1)) x = block.make_block(x_in) @@ -187,11 +189,11 @@ def test_simple_upsampling_block_trans_conv_bn_post(self): def test_simple_upsampling_block_ignore_skip_source(self): block = encoder_decoder.SimpleUpsamplingBlock( upsampling_stride=2, - transposed_conv = False, - interp_method = "bilinear", - skip_connection = False, - skip_add = False, - refine_convs = 0, + transposed_conv=False, + interp_method="bilinear", + skip_connection=False, + skip_add=False, + refine_convs=0, ) x_in = tf.keras.Input((8, 8, 1)) skip_src = tf.keras.Input((16, 16, 1)) @@ -207,11 +209,11 @@ def test_simple_upsampling_block_ignore_skip_source(self): def test_simple_upsampling_block_skip_add(self): block = encoder_decoder.SimpleUpsamplingBlock( upsampling_stride=2, - transposed_conv = False, - interp_method = "bilinear", - skip_connection = True, - skip_add = True, - refine_convs = 0, + transposed_conv=False, + interp_method="bilinear", + skip_connection=True, + skip_add=True, + refine_convs=0, ) x_in = tf.keras.Input((8, 8, 1)) skip_src = tf.ones((1, 16, 16, 1)) @@ -229,11 +231,11 @@ def test_simple_upsampling_block_skip_add(self): def test_simple_upsampling_block_skip_add_adjust_channels(self): block = encoder_decoder.SimpleUpsamplingBlock( upsampling_stride=2, - transposed_conv = False, - interp_method = "bilinear", - skip_connection = True, - skip_add = True, - refine_convs = 0, + transposed_conv=False, + interp_method="bilinear", + skip_connection=True, + skip_add=True, + refine_convs=0, ) x_in = tf.keras.Input((8, 8, 1)) skip_src = tf.keras.Input((16, 16, 4)) @@ -242,7 +244,7 @@ def test_simple_upsampling_block_skip_add_adjust_channels(self): self.assertEqual(len(model.layers), 5) self.assertEqual(len(model.trainable_weights), 2) - self.assertEqual(model.count_params(), 1+4) + self.assertEqual(model.count_params(), 1 + 4) self.assertAllEqual(model.output.shape, (None, 16, 16, 1)) self.assertIsInstance(model.layers[3], tf.keras.layers.UpSampling2D) self.assertIsInstance(model.layers[2], tf.keras.layers.Conv2D) @@ -251,11 +253,11 @@ def test_simple_upsampling_block_skip_add_adjust_channels(self): def test_simple_upsampling_block_skip_concat(self): block = encoder_decoder.SimpleUpsamplingBlock( upsampling_stride=2, - transposed_conv = False, - interp_method = "bilinear", - skip_connection = True, - skip_add = False, - refine_convs = 0, + transposed_conv=False, + interp_method="bilinear", + skip_connection=True, + skip_add=False, + refine_convs=0, ) x_in = tf.keras.Input((8, 8, 1)) skip_src = tf.keras.Input((16, 16, 4)) @@ -272,16 +274,16 @@ def test_simple_upsampling_block_skip_concat(self): def test_simple_upsampling_block_refine_convs(self): block = encoder_decoder.SimpleUpsamplingBlock( upsampling_stride=2, - transposed_conv = False, - interp_method = "bilinear", - skip_connection = True, - refine_convs = 2, - refine_convs_filters = 16, - refine_convs_use_bias = True, - refine_convs_kernel_size = 3, - refine_convs_batch_norm = True, - refine_convs_batch_norm_before_activation = True, - refine_convs_activation = "relu", + transposed_conv=False, + interp_method="bilinear", + skip_connection=True, + refine_convs=2, + refine_convs_filters=16, + refine_convs_use_bias=True, + refine_convs_kernel_size=3, + refine_convs_batch_norm=True, + refine_convs_batch_norm_before_activation=True, + refine_convs_activation="relu", ) x_in = tf.keras.Input((8, 8, 1)) x = block.make_block(x_in) @@ -299,16 +301,16 @@ def test_simple_upsampling_block_refine_convs(self): def test_simple_upsampling_block_refine_convs_bn_post(self): block = encoder_decoder.SimpleUpsamplingBlock( upsampling_stride=2, - transposed_conv = False, - interp_method = "bilinear", - skip_connection = True, - refine_convs = 2, - refine_convs_filters = 16, - refine_convs_use_bias = True, - refine_convs_kernel_size = 3, - refine_convs_batch_norm = True, - refine_convs_batch_norm_before_activation = False, - refine_convs_activation = "relu", + transposed_conv=False, + interp_method="bilinear", + skip_connection=True, + refine_convs=2, + refine_convs_filters=16, + refine_convs_use_bias=True, + refine_convs_kernel_size=3, + refine_convs_batch_norm=True, + refine_convs_batch_norm_before_activation=False, + refine_convs_activation="relu", ) x_in = tf.keras.Input((8, 8, 1)) x = block.make_block(x_in) @@ -322,4 +324,3 @@ def test_simple_upsampling_block_refine_convs_bn_post(self): self.assertIsInstance(model.layers[2], tf.keras.layers.Conv2D) self.assertIsInstance(model.layers[3], tf.keras.layers.Activation) self.assertIsInstance(model.layers[4], tf.keras.layers.BatchNormalization) - diff --git a/tests/nn/architectures/test_hourglass.py b/tests/nn/architectures/test_hourglass.py index 6763a6321..4efe79a1c 100644 --- a/tests/nn/architectures/test_hourglass.py +++ b/tests/nn/architectures/test_hourglass.py @@ -1,6 +1,8 @@ import numpy as np import tensorflow as tf -from sleap.nn.system import use_cpu_only; use_cpu_only() # hide GPUs for test +from sleap.nn.system import use_cpu_only + +use_cpu_only() # hide GPUs for test from sleap.nn.architectures import hourglass from sleap.nn.config import HourglassConfig @@ -17,7 +19,7 @@ def test_hourglass_reference(self): filters=256, filter_increase=128, interp_method="nearest", - stacks=3 + stacks=3, ) x_in = tf.keras.layers.Input((256, 256, 1)) x, x_mid = arch.make_backbone(x_in) @@ -28,8 +30,8 @@ def test_hourglass_reference(self): with self.subTest("output shape"): self.assertAllEqual( - [out.shape for out in model.output], - [(None, 64, 64, 256)] * 3) + [out.shape for out in model.output], [(None, 64, 64, 256)] * 3 + ) with self.subTest("encoder stride"): self.assertEqual(arch.encoder_features_stride, 64) with self.subTest("decoder stride"): @@ -47,15 +49,17 @@ def test_hourglass_reference(self): def test_hourglass_reference_from_config(self): # Reference implementation from the original paper. - arch = hourglass.Hourglass.from_config(HourglassConfig( - stem_stride=4, - max_stride=64, - output_stride=4, - stem_filters=128, - filters=256, - filter_increase=128, - stacks=3, - )) + arch = hourglass.Hourglass.from_config( + HourglassConfig( + stem_stride=4, + max_stride=64, + output_stride=4, + stem_filters=128, + filters=256, + filter_increase=128, + stacks=3, + ) + ) x_in = tf.keras.layers.Input((256, 256, 1)) x, x_mid = arch.make_backbone(x_in) model = tf.keras.Model(x_in, x) @@ -65,8 +69,8 @@ def test_hourglass_reference_from_config(self): with self.subTest("output shape"): self.assertAllEqual( - [out.shape for out in model.output], - [(None, 64, 64, 256)] * 3) + [out.shape for out in model.output], [(None, 64, 64, 256)] * 3 + ) with self.subTest("encoder stride"): self.assertEqual(arch.encoder_features_stride, 64) with self.subTest("decoder stride"): diff --git a/tests/nn/architectures/test_leap.py b/tests/nn/architectures/test_leap.py index e2f265edc..edf07396b 100644 --- a/tests/nn/architectures/test_leap.py +++ b/tests/nn/architectures/test_leap.py @@ -1,10 +1,13 @@ import numpy as np import tensorflow as tf -from sleap.nn.system import use_cpu_only; use_cpu_only() # hide GPUs for test +from sleap.nn.system import use_cpu_only + +use_cpu_only() # hide GPUs for test from sleap.nn.architectures import leap from sleap.nn.config import LEAPConfig + class LeapTests(tf.test.TestCase): def test_leap_cnn_reference(self): # Reference implementation from the original paper. @@ -80,13 +83,14 @@ def test_leap_cnn_interp(self): ) def test_leap_cnn_reference_from_config(self): - arch = leap.LeapCNN.from_config(LEAPConfig( - max_stride=8, - output_stride=1, - filters=64, - filters_rate=2, - up_interpolate=False, - stacks=1 + arch = leap.LeapCNN.from_config( + LEAPConfig( + max_stride=8, + output_stride=1, + filters=64, + filters_rate=2, + up_interpolate=False, + stacks=1, ) ) x_in = tf.keras.layers.Input((192, 192, 1)) diff --git a/tests/nn/architectures/test_pretrained_encoders.py b/tests/nn/architectures/test_pretrained_encoders.py index 8111618e7..f318754ac 100644 --- a/tests/nn/architectures/test_pretrained_encoders.py +++ b/tests/nn/architectures/test_pretrained_encoders.py @@ -1,7 +1,9 @@ import numpy as np import tensorflow as tf import pytest -from sleap.nn.system import use_cpu_only; use_cpu_only() # hide GPUs for test +from sleap.nn.system import use_cpu_only + +use_cpu_only() # hide GPUs for test from sleap.nn.architectures import UnetPretrainedEncoder from sleap.nn.config import PretrainedEncoderConfig diff --git a/tests/nn/architectures/test_resnet.py b/tests/nn/architectures/test_resnet.py index 685d95716..965ea3b72 100644 --- a/tests/nn/architectures/test_resnet.py +++ b/tests/nn/architectures/test_resnet.py @@ -1,6 +1,8 @@ import numpy as np import tensorflow as tf -from sleap.nn.system import use_cpu_only; use_cpu_only() # hide GPUs for test +from sleap.nn.system import use_cpu_only + +use_cpu_only() # hide GPUs for test from sleap.nn.architectures import upsampling from sleap.nn.architectures import resnet @@ -51,7 +53,9 @@ def test_resnet50_stride16(self): with self.subTest("feature output stride"): self.assertEqual(model.get_layer("conv5_block1_1_conv").strides, (1, 1)) with self.subTest("feature output dilation rate"): - self.assertEqual(model.get_layer("conv5_block1_1_conv").dilation_rate, (2, 2)) + self.assertEqual( + model.get_layer("conv5_block1_1_conv").dilation_rate, (2, 2) + ) def test_resnet50_upsampling(self): resnet50 = resnet.ResNet50( @@ -167,13 +171,15 @@ def test_resnet152(self): self.assertEqual(model.count_params(), 58364672) def test_resnet50_from_config(self): - resnet50 = resnet.ResNet50.from_config(ResNetConfig( - version="ResNet50", - weights="random", - upsampling=None, - max_stride=32, - output_stride=32, - )) + resnet50 = resnet.ResNet50.from_config( + ResNetConfig( + version="ResNet50", + weights="random", + upsampling=None, + max_stride=32, + output_stride=32, + ) + ) x_in = tf.keras.layers.Input((160, 160, 1)) x, x_mid = resnet50.make_backbone(x_in) model = tf.keras.Model(x_in, x) diff --git a/tests/nn/architectures/test_upsampling.py b/tests/nn/architectures/test_upsampling.py index daf8c6fac..ca1526fdc 100644 --- a/tests/nn/architectures/test_upsampling.py +++ b/tests/nn/architectures/test_upsampling.py @@ -1,6 +1,8 @@ import numpy as np import tensorflow as tf -from sleap.nn.system import use_cpu_only; use_cpu_only() # hide GPUs for test +from sleap.nn.system import use_cpu_only + +use_cpu_only() # hide GPUs for test from sleap.nn.architectures import upsampling from sleap.nn.config import UpsamplingConfig @@ -182,16 +184,18 @@ def test_upsampling_stack_upsampling_add(self): ) def test_upsampling_stack_upsampling_concat(self): - upsampling_stack = upsampling.UpsamplingStack.from_config(UpsamplingConfig( + upsampling_stack = upsampling.UpsamplingStack.from_config( + UpsamplingConfig( method="transposed_conv", skip_connections="concatenate", block_stride=2, filters=64, - filters_rate=1., + filters_rate=1.0, refine_convs=1, batch_norm=True, transposed_conv_kernel_size=4, - ), output_stride=4 + ), + output_stride=4, ) x, intermediate_feats = upsampling_stack.make_stack( tf.keras.Input((8, 8, 32)), current_stride=16 diff --git a/tests/nn/data/test_augmentation.py b/tests/nn/data/test_augmentation.py index 59d63ff6c..d565a80b1 100644 --- a/tests/nn/data/test_augmentation.py +++ b/tests/nn/data/test_augmentation.py @@ -1,6 +1,9 @@ import numpy as np import tensorflow as tf -from sleap.nn.system import use_cpu_only; use_cpu_only() # hide GPUs for test +import sleap +from sleap.nn.system import use_cpu_only + +use_cpu_only() # hide GPUs for test from sleap.nn.data import providers from sleap.nn.data import augmentation @@ -49,5 +52,117 @@ def test_random_cropper(min_labels): assert "crop_bbox" in example offset = tf.stack([example["crop_bbox"][0, 1], example["crop_bbox"][0, 0]], axis=-1) assert tf.reduce_all( - example["instances"] == ( - example_preaug["instances"] - tf.expand_dims(offset, axis=0))) + example["instances"] + == (example_preaug["instances"] - tf.expand_dims(offset, axis=0)) + ) + + +def test_flip_instances_lr(): + insts = tf.cast( + [ + [[0, 1], [2, 3]], + [[4, 5], [6, 7]], + ], + tf.float32, + ) + + insts_flip = augmentation.flip_instances_lr(insts, 8) + np.testing.assert_array_equal(insts_flip, [[[7, 1], [5, 3]], [[3, 5], [1, 7]]]) + + insts_flip1 = augmentation.flip_instances_lr(insts, 8, [[0, 1]]) + insts_flip2 = augmentation.flip_instances_lr(insts, 8, [[1, 0]]) + np.testing.assert_array_equal(insts_flip1, [[[5, 3], [7, 1]], [[1, 7], [3, 5]]]) + np.testing.assert_array_equal(insts_flip1, insts_flip2) + + +def test_flip_instances_ud(): + insts = tf.cast( + [ + [[0, 1], [2, 3]], + [[4, 5], [6, 7]], + ], + tf.float32, + ) + + insts_flip = augmentation.flip_instances_ud(insts, 8) + np.testing.assert_array_equal(insts_flip, [[[0, 6], [2, 4]], [[4, 2], [6, 0]]]) + + insts_flip1 = augmentation.flip_instances_ud(insts, 8, [[0, 1]]) + insts_flip2 = augmentation.flip_instances_ud(insts, 8, [[1, 0]]) + np.testing.assert_array_equal(insts_flip1, [[[2, 4], [0, 6]], [[6, 0], [4, 2]]]) + np.testing.assert_array_equal(insts_flip1, insts_flip2) + + +def test_random_flipper(): + vid = sleap.Video.from_filename( + "tests/data/json_format_v1/centered_pair_low_quality.mp4" + ) + skel = sleap.Skeleton.from_names_and_edge_inds(["A", "BL", "BR"], [[0, 1], [0, 2]]) + labels = sleap.Labels( + [ + sleap.LabeledFrame( + video=vid, + frame_idx=0, + instances=[ + sleap.Instance.from_pointsarray( + [[25, 50], [50, 25], [25, 25]], skeleton=skel + ), + sleap.Instance.from_pointsarray( + [[125, 150], [150, 125], [125, 125]], skeleton=skel + ), + ], + ) + ] + ) + + p = labels.to_pipeline() + p += sleap.nn.data.augmentation.RandomFlipper.from_skeleton( + skel, horizontal=True, probability=1.0 + ) + ex = p.peek() + np.testing.assert_array_equal(ex["image"], vid[0][0][:, ::-1]) + np.testing.assert_array_equal( + ex["instances"], + [ + [[358.0, 50.0], [333.0, 25.0], [358.0, 25.0]], + [[258.0, 150.0], [233.0, 125.0], [258.0, 125.0]], + ], + ) + + skel.add_symmetry("BL", "BR") + + p = labels.to_pipeline() + p += sleap.nn.data.augmentation.RandomFlipper.from_skeleton( + skel, horizontal=True, probability=1.0 + ) + ex = p.peek() + np.testing.assert_array_equal(ex["image"], vid[0][0][:, ::-1]) + np.testing.assert_array_equal( + ex["instances"], + [ + [[358.0, 50.0], [358.0, 25.0], [333.0, 25.0]], + [[258.0, 150.0], [258.0, 125.0], [233.0, 125.0]], + ], + ) + + p = labels.to_pipeline() + p += sleap.nn.data.augmentation.RandomFlipper.from_skeleton( + skel, horizontal=True, probability=0.0 + ) + ex = p.peek() + np.testing.assert_array_equal(ex["image"], vid[0][0]) + np.testing.assert_array_equal( + ex["instances"], + [[[25, 50], [50, 25], [25, 25]], [[125, 150], [150, 125], [125, 125]]], + ) + + p = labels.to_pipeline() + p += sleap.nn.data.augmentation.RandomFlipper.from_skeleton( + skel, horizontal=False, probability=1.0 + ) + ex = p.peek() + np.testing.assert_array_equal(ex["image"], vid[0][0][::-1, :]) + np.testing.assert_array_equal( + ex["instances"], + [[[25, 333], [25, 358], [50, 358]], [[125, 233], [125, 258], [150, 258]]], + ) diff --git a/tests/nn/data/test_confidence_maps.py b/tests/nn/data/test_confidence_maps.py index 750600c58..ada8ce88b 100644 --- a/tests/nn/data/test_confidence_maps.py +++ b/tests/nn/data/test_confidence_maps.py @@ -1,6 +1,8 @@ import numpy as np import tensorflow as tf -from sleap.nn.system import use_cpu_only; use_cpu_only() # hide GPUs for test +from sleap.nn.system import use_cpu_only + +use_cpu_only() # hide GPUs for test from sleap.nn.data import providers from sleap.nn.data import instance_centroids @@ -18,57 +20,65 @@ def test_make_confmaps(): xv, yv = make_grid_vectors(image_height=4, image_width=5, output_stride=1) - points = tf.cast([[0.5, 1.], - [3, 3.5], - [2., 2.]], tf.float32) - cm = make_confmaps(points, xv, yv, sigma=1.) + points = tf.cast([[0.5, 1.0], [3, 3.5], [2.0, 2.0]], tf.float32) + cm = make_confmaps(points, xv, yv, sigma=1.0) assert cm.dtype == tf.float32 assert cm.shape == (4, 5, 3) np.testing.assert_allclose( cm, - [[[0.535, 0. , 0.018], - [0.535, 0. , 0.082], - [0.197, 0.001, 0.135], - [0.027, 0.002, 0.082], - [0.001, 0.001, 0.018],], - [[0.882, 0. , 0.082], - [0.882, 0.006, 0.368], - [0.325, 0.027, 0.607], - [0.044, 0.044, 0.368], - [0.002, 0.027, 0.082],], - [[0.535, 0.004, 0.135], - [0.535, 0.044, 0.607], - [0.197, 0.197, 1. ], - [0.027, 0.325, 0.607], - [0.001, 0.197, 0.135],], - [[0.119, 0.01 , 0.082], - [0.119, 0.119, 0.368], - [0.044, 0.535, 0.607], - [0.006, 0.882, 0.368], - [0. , 0.535, 0.082],]], - atol=1e-3 - ) + [ + [ + [0.535, 0.0, 0.018], + [0.535, 0.0, 0.082], + [0.197, 0.001, 0.135], + [0.027, 0.002, 0.082], + [0.001, 0.001, 0.018], + ], + [ + [0.882, 0.0, 0.082], + [0.882, 0.006, 0.368], + [0.325, 0.027, 0.607], + [0.044, 0.044, 0.368], + [0.002, 0.027, 0.082], + ], + [ + [0.535, 0.004, 0.135], + [0.535, 0.044, 0.607], + [0.197, 0.197, 1.0], + [0.027, 0.325, 0.607], + [0.001, 0.197, 0.135], + ], + [ + [0.119, 0.01, 0.082], + [0.119, 0.119, 0.368], + [0.044, 0.535, 0.607], + [0.006, 0.882, 0.368], + [0.0, 0.535, 0.082], + ], + ], + atol=1e-3, + ) # Grid aligned peak points = tf.cast([[2, 3]], tf.float32) - cm = make_confmaps(points, xv, yv, sigma=1.) + cm = make_confmaps(points, xv, yv, sigma=1.0) assert cm.shape == (4, 5, 1) assert cm[3, 2] == 1.0 # Output stride xv, yv = make_grid_vectors(image_height=8, image_width=8, output_stride=2) points = tf.cast([[2, 4]], tf.float32) - cm = make_confmaps(points, xv, yv, sigma=1.) + cm = make_confmaps(points, xv, yv, sigma=1.0) assert cm.shape == (4, 4, 1) assert cm[2, 1] == 1.0 # Missing points xv, yv = make_grid_vectors(image_height=8, image_width=8, output_stride=2) points = tf.cast([[2, 4]], tf.float32) - cm = make_confmaps(points, xv, yv, sigma=1.) + cm = make_confmaps(points, xv, yv, sigma=1.0) points_with_nan = tf.cast([[2, 4], [np.nan, np.nan]], tf.float32) - cm_with_nan = make_confmaps(points_with_nan, xv, yv, sigma=1.) + cm_with_nan = make_confmaps(points_with_nan, xv, yv, sigma=1.0) assert cm_with_nan.shape == (4, 4, 2) assert cm_with_nan.dtype == tf.float32 np.testing.assert_array_equal(cm_with_nan[:, :, 0], cm[:, :, 0]) @@ -77,34 +87,39 @@ def test_make_confmaps(): def test_make_multi_confmaps(): xv, yv = make_grid_vectors(image_height=4, image_width=5, output_stride=1) - instances = tf.cast([ - [[0.5, 1.], [2., 2.]], - [[1.5, 1.], [2., 3.]], - [[np.nan, np.nan], [-1., 5.]], - ], tf.float32) - cms = make_multi_confmaps(instances, xv=xv, yv=yv, sigma=1.) + instances = tf.cast( + [ + [[0.5, 1.0], [2.0, 2.0]], + [[1.5, 1.0], [2.0, 3.0]], + [[np.nan, np.nan], [-1.0, 5.0]], + ], + tf.float32, + ) + cms = make_multi_confmaps(instances, xv=xv, yv=yv, sigma=1.0) assert cms.shape == (4, 5, 2) assert cms.dtype == tf.float32 - cm0 = make_confmaps(instances[0], xv=xv, yv=yv, sigma=1.) - cm1 = make_confmaps(instances[1], xv=xv, yv=yv, sigma=1.) - cm2 = make_confmaps(instances[2], xv=xv, yv=yv, sigma=1.) + cm0 = make_confmaps(instances[0], xv=xv, yv=yv, sigma=1.0) + cm1 = make_confmaps(instances[1], xv=xv, yv=yv, sigma=1.0) + cm2 = make_confmaps(instances[2], xv=xv, yv=yv, sigma=1.0) np.testing.assert_array_equal( - cms, - tf.reduce_max(tf.stack([cm0, cm1, cm2], axis=-1), axis=-1) + cms, tf.reduce_max(tf.stack([cm0, cm1, cm2], axis=-1), axis=-1) ) def test_make_multi_confmaps_with_offsets(): xv, yv = make_grid_vectors(image_height=4, image_width=5, output_stride=1) - instances = tf.cast([ - [[0.5, 1.], [2., 2.]], - [[1.5, 1.], [2., 3.]], - [[np.nan, np.nan], [-1., 5.]], - ], tf.float32) + instances = tf.cast( + [ + [[0.5, 1.0], [2.0, 2.0]], + [[1.5, 1.0], [2.0, 3.0]], + [[np.nan, np.nan], [-1.0, 5.0]], + ], + tf.float32, + ) cms, offsets = make_multi_confmaps_with_offsets( - instances, xv, yv, stride=1, sigma=1., offsets_threshold=0.2 + instances, xv, yv, stride=1, sigma=1.0, offsets_threshold=0.2 ) assert offsets.shape == (4, 5, 2, 2) @@ -131,6 +146,7 @@ def test_single_instance_confidence_map_generator(min_labels_robot): assert example["offsets"].shape == (320 // 2, 560 // 2, 4) assert example["offsets"].dtype == tf.float32 + def test_multi_confidence_map_generator(min_labels): labels_reader = providers.LabelsReader(min_labels) multi_confmap_generator = MultiConfidenceMapGenerator( @@ -290,11 +306,8 @@ def test_instance_confidence_map_generator_with_all_instances(min_labels): np.testing.assert_allclose( all_cms[y, x, :], - [[[0.91393119, 0.], - [0., 0.94459903]], - [[0., 0.], - [0., 0.]]], - atol=1e-6 + [[[0.91393119, 0.0], [0.0, 0.94459903]], [[0.0, 0.0], [0.0, 0.0]]], + atol=1e-6, ) instance_confmap_generator.with_offsets = True diff --git a/tests/nn/data/test_data_training.py b/tests/nn/data/test_data_training.py index 031fdbf7e..eb79464e0 100644 --- a/tests/nn/data/test_data_training.py +++ b/tests/nn/data/test_data_training.py @@ -1,50 +1,75 @@ import numpy as np -import tensorflow as tf -from sleap.nn.system import use_cpu_only; use_cpu_only() # hide GPUs for test - import sleap -from sleap.nn.data.providers import LabelsReader -from sleap.nn.data import training - - -def test_split_labels_reader(min_labels): - labels = sleap.Labels([min_labels[0], min_labels[0], min_labels[0], min_labels[0]]) - labels_reader = LabelsReader(labels) - reader1, reader2 = training.split_labels_reader(labels_reader, [0.5, 0.5]) - assert len(reader1) == 2 - assert len(reader2) == 2 - assert len(set(reader1.example_indices).intersection(set(reader2.example_indices))) == 0 - - reader1, reader2 = training.split_labels_reader(labels_reader, [0.1, 0.5]) - assert len(reader1) == 1 - assert len(reader2) == 2 - assert len(set(reader1.example_indices).intersection(set(reader2.example_indices))) == 0 - - reader1, reader2 = training.split_labels_reader(labels_reader, [0.1, -1]) - assert len(reader1) == 1 - assert len(reader2) == 3 - assert len(set(reader1.example_indices).intersection(set(reader2.example_indices))) == 0 - - labels = sleap.Labels([min_labels[0], min_labels[0], min_labels[0], min_labels[0]]) - labels_reader = LabelsReader(labels, example_indices=[1, 2, 3]) - reader1, reader2 = training.split_labels_reader(labels_reader, [0.1, -1]) - assert len(reader1) == 1 - assert len(reader2) == 2 - assert len(set(reader1.example_indices).intersection(set(reader2.example_indices))) == 0 - assert 0 not in reader1.example_indices - assert 0 not in reader2.example_indices - - -def test_keymapper(): - ds = tf.data.Dataset.from_tensors({"a": 0, "b": 1}) - mapper = training.KeyMapper(key_maps={"a": "x", "b": "y"}) - ds = mapper.transform_dataset(ds) - np.testing.assert_array_equal(next(iter(ds)), {"x": 0, "y": 1}) - assert mapper.input_keys == ["a", "b"] - assert mapper.output_keys == ["x", "y"] - - ds = tf.data.Dataset.from_tensors({"a": 0, "b": 1}) - ds = training.KeyMapper(key_maps=[{"a": "x"}, {"b": "y"}]).transform_dataset(ds) - np.testing.assert_array_equal(next(iter(ds)), ({"x": 0}, {"y": 1})) - assert mapper.input_keys == ["a", "b"] - assert mapper.output_keys == ["x", "y"] +from sleap.nn.data.training import split_labels_train_val + + +sleap.use_cpu_only() # hide GPUs for test + + +def test_split_labels_train_val(): + vid = sleap.Video(backend=sleap.io.video.MediaVideo) + labels = sleap.Labels([sleap.LabeledFrame(video=vid, frame_idx=0)]) + + train, train_inds, val, val_inds = split_labels_train_val(labels, 0) + assert len(train) == 1 + assert len(val) == 1 + + train, train_inds, val, val_inds = split_labels_train_val(labels, 0.1) + assert len(train) == 1 + assert len(val) == 1 + + train, train_inds, val, val_inds = split_labels_train_val(labels, 0.5) + assert len(train) == 1 + assert len(val) == 1 + + train, train_inds, val, val_inds = split_labels_train_val(labels, 1.0) + assert len(train) == 1 + assert len(val) == 1 + + labels = sleap.Labels( + [ + sleap.LabeledFrame(video=vid, frame_idx=0), + sleap.LabeledFrame(video=vid, frame_idx=1), + ] + ) + train, train_inds, val, val_inds = split_labels_train_val(labels, 0) + assert len(train) == 1 + assert len(val) == 1 + assert train[0].frame_idx != val[0].frame_idx + + train, train_inds, val, val_inds = split_labels_train_val(labels, 0.1) + assert len(train) == 1 + assert len(val) == 1 + assert train[0].frame_idx != val[0].frame_idx + + train, train_inds, val, val_inds = split_labels_train_val(labels, 0.5) + assert len(train) == 1 + assert len(val) == 1 + assert train[0].frame_idx != val[0].frame_idx + + train, train_inds, val, val_inds = split_labels_train_val(labels, 1.0) + assert len(train) == 1 + assert len(val) == 1 + assert train[0].frame_idx != val[0].frame_idx + + labels = sleap.Labels( + [ + sleap.LabeledFrame(video=vid, frame_idx=0), + sleap.LabeledFrame(video=vid, frame_idx=1), + sleap.LabeledFrame(video=vid, frame_idx=2), + ] + ) + train, train_inds, val, val_inds = split_labels_train_val(labels, 0) + assert len(train) == 2 + assert len(val) == 1 + + train, train_inds, val, val_inds = split_labels_train_val(labels, 0.1) + assert len(train) == 2 + assert len(val) == 1 + + train, train_inds, val, val_inds = split_labels_train_val(labels, 0.5) + assert len(train) + len(val) == 3 + + train, train_inds, val, val_inds = split_labels_train_val(labels, 1.0) + assert len(train) == 1 + assert len(val) == 2 diff --git a/tests/nn/data/test_dataset_ops.py b/tests/nn/data/test_dataset_ops.py index 2429744be..bea4d730a 100644 --- a/tests/nn/data/test_dataset_ops.py +++ b/tests/nn/data/test_dataset_ops.py @@ -1,6 +1,8 @@ import numpy as np import tensorflow as tf -from sleap.nn.system import use_cpu_only; use_cpu_only() # hide GPUs for test +from sleap.nn.system import use_cpu_only + +use_cpu_only() # hide GPUs for test from sleap.nn.data import dataset_ops diff --git a/tests/nn/data/test_edge_maps.py b/tests/nn/data/test_edge_maps.py index 9dd68a44b..295360538 100644 --- a/tests/nn/data/test_edge_maps.py +++ b/tests/nn/data/test_edge_maps.py @@ -1,6 +1,8 @@ import numpy as np import tensorflow as tf -from sleap.nn.system import use_cpu_only; use_cpu_only() # hide GPUs for test +from sleap.nn.system import use_cpu_only + +use_cpu_only() # hide GPUs for test from sleap.nn.data import providers from sleap.nn.data import edge_maps @@ -14,23 +16,18 @@ def test_distance_to_edge(): sigma = 1.0 sampling_grid = tf.stack(tf.meshgrid(xv, yv), axis=-1) # (height, width, 2) - distances = edge_maps.distance_to_edge(sampling_grid, - edge_source=edge_source, edge_destination=edge_destination) + distances = edge_maps.distance_to_edge( + sampling_grid, edge_source=edge_source, edge_destination=edge_destination + ) np.testing.assert_allclose( distances, [ - [[1.25, 0. ], - [0.25, 0.5 ], - [1.25, 2. ]], - [[1. , 0.5 ], - [0. , 0. ], - [1. , 0.5 ]], - [[1.25, 2. ], - [0.25, 0.5 ], - [1.25, 0. ]] + [[1.25, 0.0], [0.25, 0.5], [1.25, 2.0]], + [[1.0, 0.5], [0.0, 0.0], [1.0, 0.5]], + [[1.25, 2.0], [0.25, 0.5], [1.25, 0.0]], ], - atol=1e-3 + atol=1e-3, ) @@ -40,23 +37,22 @@ def test_edge_confidence_map(): edge_destination = tf.cast([[1, 1.5], [2, 2]], tf.float32) sigma = 1.0 - edge_confidence_map = edge_maps.make_edge_maps(xv=xv, yv=yv, - edge_source=edge_source, edge_destination=edge_destination, sigma=sigma) + edge_confidence_map = edge_maps.make_edge_maps( + xv=xv, + yv=yv, + edge_source=edge_source, + edge_destination=edge_destination, + sigma=sigma, + ) np.testing.assert_allclose( edge_confidence_map, [ - [[0.458, 1.000], - [0.969, 0.882], - [0.458, 0.135]], - [[0.607, 0.882], - [1.000, 1.000], - [0.607, 0.882]], - [[0.458, 0.135], - [0.969, 0.882], - [0.458, 1.000]] + [[0.458, 1.000], [0.969, 0.882], [0.458, 0.135]], + [[0.607, 0.882], [1.000, 1.000], [0.607, 0.882]], + [[0.458, 0.135], [0.969, 0.882], [0.458, 1.000]], ], - atol=1e-3 + atol=1e-3, ) @@ -66,68 +62,84 @@ def test_make_pafs(): edge_destination = tf.cast([[1, 1.5], [2, 2]], tf.float32) sigma = 1.0 - pafs = edge_maps.make_pafs(xv=xv, yv=yv, - edge_source=edge_source, edge_destination=edge_destination, sigma=sigma) + pafs = edge_maps.make_pafs( + xv=xv, + yv=yv, + edge_source=edge_source, + edge_destination=edge_destination, + sigma=sigma, + ) np.testing.assert_allclose( pafs, - [[[[0. , 0.458], - [0.707, 0.707]], - [[0. , 0.969], - [0.624, 0.624]], - [[0. , 0.458], - [0.096, 0.096]]], - [[[0. , 0.607], - [0.624, 0.624]], - [[0. , 1. ], - [0.707, 0.707]], - [[0. , 0.607], - [0.624, 0.624]]], - [[[0. , 0.458], - [0.096, 0.096]], - [[0. , 0.969], - [0.624, 0.624]], - [[0. , 0.458], - [0.707, 0.707]]]], - atol=1e-3) + [ + [ + [[0.0, 0.458], [0.707, 0.707]], + [[0.0, 0.969], [0.624, 0.624]], + [[0.0, 0.458], [0.096, 0.096]], + ], + [ + [[0.0, 0.607], [0.624, 0.624]], + [[0.0, 1.0], [0.707, 0.707]], + [[0.0, 0.607], [0.624, 0.624]], + ], + [ + [[0.0, 0.458], [0.096, 0.096]], + [[0.0, 0.969], [0.624, 0.624]], + [[0.0, 0.458], [0.707, 0.707]], + ], + ], + atol=1e-3, + ) def test_make_multi_pafs(): xv, yv = make_grid_vectors(image_height=3, image_width=3, output_stride=1) - edge_source = tf.cast([ + edge_source = tf.cast( + [ [[1, 0.5], [0, 0]], [[1, 0.5], [0, 0]], - ], tf.float32) - edge_destination = tf.cast([ + ], + tf.float32, + ) + edge_destination = tf.cast( + [ [[1, 1.5], [2, 2]], [[1, 1.5], [2, 2]], - ], tf.float32) + ], + tf.float32, + ) sigma = 1.0 - pafs = edge_maps.make_multi_pafs(xv=xv, yv=yv, - edge_sources=edge_source, edge_destinations=edge_destination, sigma=sigma) + pafs = edge_maps.make_multi_pafs( + xv=xv, + yv=yv, + edge_sources=edge_source, + edge_destinations=edge_destination, + sigma=sigma, + ) np.testing.assert_allclose( pafs, - [[[[0. , 0.916], - [1.414, 1.414]], - [[0. , 1.938], - [1.248, 1.248]], - [[0. , 0.916], - [0.191, 0.191]]], - [[[0. , 1.213], - [1.248, 1.248]], - [[0. , 2. ], - [1.414, 1.414]], - [[0. , 1.213], - [1.248, 1.248]]], - [[[0. , 0.916], - [0.191, 0.191]], - [[0. , 1.938], - [1.248, 1.248]], - [[0. , 0.916], - [1.414, 1.414]]]], - atol=1e-3) + [ + [ + [[0.0, 0.916], [1.414, 1.414]], + [[0.0, 1.938], [1.248, 1.248]], + [[0.0, 0.916], [0.191, 0.191]], + ], + [ + [[0.0, 1.213], [1.248, 1.248]], + [[0.0, 2.0], [1.414, 1.414]], + [[0.0, 1.213], [1.248, 1.248]], + ], + [ + [[0.0, 0.916], [0.191, 0.191]], + [[0.0, 1.938], [1.248, 1.248]], + [[0.0, 0.916], [1.414, 1.414]], + ], + ], + atol=1e-3, + ) def test_get_edge_points(): @@ -136,40 +148,29 @@ def test_get_edge_points(): edge_sources, edge_destinations = edge_maps.get_edge_points(instances, edge_inds) np.testing.assert_array_equal( edge_sources, - [[[ 0, 1], - [ 2, 3], - [ 0, 1]], - [[ 6, 7], - [ 8, 9], - [ 6, 7]], - [[12, 13], - [14, 15], - [12, 13]], - [[18, 19], - [20, 21], - [18, 19]]] + [ + [[0, 1], [2, 3], [0, 1]], + [[6, 7], [8, 9], [6, 7]], + [[12, 13], [14, 15], [12, 13]], + [[18, 19], [20, 21], [18, 19]], + ], ) np.testing.assert_array_equal( edge_destinations, - [[[ 2, 3], - [ 4, 5], - [ 4, 5]], - [[ 8, 9], - [10, 11], - [10, 11]], - [[14, 15], - [16, 17], - [16, 17]], - [[20, 21], - [22, 23], - [22, 23]]] + [ + [[2, 3], [4, 5], [4, 5]], + [[8, 9], [10, 11], [10, 11]], + [[14, 15], [16, 17], [16, 17]], + [[20, 21], [22, 23], [22, 23]], + ], ) def test_part_affinity_fields_generator(min_labels): labels_reader = providers.LabelsReader(min_labels) paf_generator = edge_maps.PartAffinityFieldsGenerator( - sigma=8, output_stride=2, skeletons=labels_reader.labels.skeletons) + sigma=8, output_stride=2, skeletons=labels_reader.labels.skeletons + ) ds = labels_reader.make_dataset() ds = paf_generator.transform_dataset(ds) @@ -180,5 +181,5 @@ def test_part_affinity_fields_generator(min_labels): np.testing.assert_allclose( example["part_affinity_fields"][196 // 2, 250 // 2, :, :], - [[0.9600351, 0.20435576]] + [[0.9600351, 0.20435576]], ) diff --git a/tests/nn/data/test_identity.py b/tests/nn/data/test_identity.py index a7209161b..52d25dd1b 100644 --- a/tests/nn/data/test_identity.py +++ b/tests/nn/data/test_identity.py @@ -5,6 +5,7 @@ from sleap.nn.data.identity import ( make_class_vectors, make_class_maps, + ClassVectorGenerator, ClassMapGenerator, ) @@ -27,6 +28,29 @@ def test_make_class_maps(): ) +def test_class_vector_generator(min_tracks_2node_labels): + labels = min_tracks_2node_labels + + gen = ClassVectorGenerator() + + p = labels.to_pipeline() + ds = p.make_dataset() + ds = gen.transform_dataset(ds) + ex = next(iter(ds)) + + np.testing.assert_array_equal(ex["class_vectors"], [[0, 1], [1, 0]]) + assert ex["class_vectors"].dtype == tf.float32 + + p = labels.to_pipeline() + p += gen + p += sleap.pipelines.InstanceCentroidFinder() + p += sleap.pipelines.InstanceCropper(32, 32) + ex = p.peek() + + np.testing.assert_array_equal(ex["class_vectors"], [0, 1]) + assert ex["class_vectors"].dtype == tf.float32 + + def test_class_map_generator(min_tracks_2node_labels): labels = min_tracks_2node_labels @@ -47,7 +71,9 @@ def test_class_map_generator(min_tracks_2node_labels): // 4 ) np.testing.assert_allclose( - tf.gather_nd(ex["class_maps"], subs), [[0, 1], [0, 1], [1, 0], [1, 0]], atol=1e-2 + tf.gather_nd(ex["class_maps"], subs), + [[0, 1], [0, 1], [1, 0], [1, 0]], + atol=1e-2, ) gen = ClassMapGenerator( diff --git a/tests/nn/data/test_instance_centroids.py b/tests/nn/data/test_instance_centroids.py index 0d3784f5b..78dee251c 100644 --- a/tests/nn/data/test_instance_centroids.py +++ b/tests/nn/data/test_instance_centroids.py @@ -1,7 +1,9 @@ import pytest import numpy as np import tensorflow as tf -from sleap.nn.system import use_cpu_only; use_cpu_only() # hide GPUs for test +from sleap.nn.system import use_cpu_only + +use_cpu_only() # hide GPUs for test import sleap from sleap.nn.data import providers @@ -10,29 +12,19 @@ def test_find_points_bbox_midpoint(): - pts = tf.convert_to_tensor([ - [1, 2], - [2, 3]], dtype=tf.float32) + pts = tf.convert_to_tensor([[1, 2], [2, 3]], dtype=tf.float32) mid_pt = instance_centroids.find_points_bbox_midpoint(pts) np.testing.assert_array_equal(mid_pt, [1.5, 2.5]) - pts = tf.convert_to_tensor([ - [1, 2], - [np.nan, np.nan], - [2, 3]], dtype=tf.float32) + pts = tf.convert_to_tensor([[1, 2], [np.nan, np.nan], [2, 3]], dtype=tf.float32) mid_pt = instance_centroids.find_points_bbox_midpoint(pts) np.testing.assert_array_equal(mid_pt, [1.5, 2.5]) def test_get_instance_anchors(): - instances = tf.convert_to_tensor([ - [[0, 1], - [2, 3], - [4, 5]], - [[6, 7], - [8, 9], - [10, 11]] - ]) + instances = tf.convert_to_tensor( + [[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]] + ) anchor_inds = tf.convert_to_tensor([0, 1], tf.int32) anchors = instance_centroids.get_instance_anchors(instances, anchor_inds) np.testing.assert_array_equal(anchors, [[0, 1], [8, 9]]) @@ -43,43 +35,47 @@ def test_instance_centroid_finder(min_labels): labels_ds = labels_reader.make_dataset() instance_centroid_finder = instance_centroids.InstanceCentroidFinder( - center_on_anchor_part=False) + center_on_anchor_part=False + ) ds = instance_centroid_finder.transform_dataset(labels_ds) example = next(iter(ds)) assert example["centroids"].dtype == tf.float32 - np.testing.assert_allclose(example["centroids"], - [[122.49705, 180.57481], - [242.28264, 195.62775]]) + np.testing.assert_allclose( + example["centroids"], [[122.49705, 180.57481], [242.28264, 195.62775]] + ) + def test_instance_centroid_finder_anchored(min_labels): labels_reader = providers.LabelsReader(min_labels) labels_ds = labels_reader.make_dataset() instance_centroid_finder = instance_centroids.InstanceCentroidFinder( - center_on_anchor_part=True, anchor_part_names="A", - skeletons=labels_reader.labels.skeletons) + center_on_anchor_part=True, + anchor_part_names="A", + skeletons=labels_reader.labels.skeletons, + ) ds = instance_centroid_finder.transform_dataset(labels_ds) example = next(iter(ds)) assert example["centroids"].dtype == tf.float32 - np.testing.assert_allclose(example["centroids"], - [[92.65221, 202.72598], - [205.93005, 187.88963]]) + np.testing.assert_allclose( + example["centroids"], [[92.65221, 202.72598], [205.93005, 187.88963]] + ) + def test_instance_centroid_finder_from_config(): finder = instance_centroids.InstanceCentroidFinder.from_config( - config=InstanceCroppingConfig(center_on_part=None), - skeletons=None - ) + config=InstanceCroppingConfig(center_on_part=None), skeletons=None + ) assert finder.center_on_anchor_part == False finder = instance_centroids.InstanceCentroidFinder.from_config( config=InstanceCroppingConfig(center_on_part="A"), skeletons=sleap.Skeleton.from_names_and_edge_inds(["A", "B"]), - ) + ) assert finder.center_on_anchor_part == True assert finder.anchor_part_names == ["A"] assert finder.skeletons[0].node_names == ["A", "B"] @@ -88,4 +84,4 @@ def test_instance_centroid_finder_from_config(): finder = instance_centroids.InstanceCentroidFinder.from_config( config=InstanceCroppingConfig(center_on_part="A"), skeletons=None, - ) \ No newline at end of file + ) diff --git a/tests/nn/data/test_instance_cropping.py b/tests/nn/data/test_instance_cropping.py index 67860d0c6..b54fb0e99 100644 --- a/tests/nn/data/test_instance_cropping.py +++ b/tests/nn/data/test_instance_cropping.py @@ -1,7 +1,9 @@ import pytest import numpy as np import tensorflow as tf -from sleap.nn.system import use_cpu_only; use_cpu_only() # hide GPUs for test +from sleap.nn.system import use_cpu_only + +use_cpu_only() # hide GPUs for test from sleap.nn.data import providers from sleap.nn.data import instance_centroids @@ -20,11 +22,13 @@ def test_normalize_bboxes(): def test_make_centered_bboxes(): bbox = instance_cropping.make_centered_bboxes( - tf.convert_to_tensor([[1, 1]], tf.float32), box_height=3, box_width=3) + tf.convert_to_tensor([[1, 1]], tf.float32), box_height=3, box_width=3 + ) np.testing.assert_array_equal(bbox, [[0, 0, 2, 2]]) bbox = instance_cropping.make_centered_bboxes( - tf.convert_to_tensor([[2, 2]], tf.float32), box_height=4, box_width=4) + tf.convert_to_tensor([[2, 2]], tf.float32), box_height=4, box_width=4 + ) np.testing.assert_array_equal(bbox, [[0.5, 0.5, 3.5, 3.5]]) @@ -35,40 +39,40 @@ def test_crop_bboxes(): img = tf.stack([XX, YY], axis=-1) centroids = tf.convert_to_tensor([[1, 1]], tf.float32) - bboxes = instance_cropping.make_centered_bboxes(centroids, - box_height=3, box_width=3) + bboxes = instance_cropping.make_centered_bboxes( + centroids, box_height=3, box_width=3 + ) crops = instance_cropping.crop_bboxes(img, bboxes) - patch_xx = [[0, 1, 2], - [0, 1, 2], - [0, 1, 2]] - patch_yy = [[0, 0, 0], - [1, 1, 1], - [2, 2, 2]] + patch_xx = [[0, 1, 2], [0, 1, 2], [0, 1, 2]] + patch_yy = [[0, 0, 0], [1, 1, 1], [2, 2, 2]] expected_patch = np.expand_dims(np.stack([patch_xx, patch_yy], axis=-1), axis=0) np.testing.assert_array_equal(crops, expected_patch) np.testing.assert_array_equal(crops, np.expand_dims(img.numpy()[:3, :3, :], axis=0)) assert crops.dtype == img.dtype + def test_crop_bboxes_rounding(): # Test for rounding truncation bug when computing bounding box size for cropping. bboxes = instance_cropping.make_centered_bboxes( - tf.cast([[464.42838, 550.14276]], tf.float32), - box_height=100, box_width=100 + tf.cast([[464.42838, 550.14276]], tf.float32), box_height=100, box_width=100 ) crops = instance_cropping.crop_bboxes( - tf.zeros([16, 16, 1], tf.float32), - bboxes=bboxes + tf.zeros([16, 16, 1], tf.float32), bboxes=bboxes ) assert crops.shape == (1, 100, 100, 1) + def test_instance_cropper(min_labels): labels_reader = providers.LabelsReader(min_labels) instance_centroid_finder = instance_centroids.InstanceCentroidFinder( - center_on_anchor_part=True, anchor_part_names="A", - skeletons=labels_reader.labels.skeletons) + center_on_anchor_part=True, + anchor_part_names="A", + skeletons=labels_reader.labels.skeletons, + ) instance_cropper = instance_cropping.InstanceCropper( - crop_width=160, crop_height=160, keep_full_image=False) + crop_width=160, crop_height=160, keep_full_image=False + ) ds = instance_centroid_finder.transform_dataset(labels_reader.make_dataset()) ds = instance_cropper.transform_dataset(ds) @@ -117,10 +121,13 @@ def test_instance_cropper(min_labels): def test_instance_cropper_keeping_full_image(min_labels): labels_reader = providers.LabelsReader(min_labels) instance_centroid_finder = instance_centroids.InstanceCentroidFinder( - center_on_anchor_part=True, anchor_part_names="A", - skeletons=labels_reader.labels.skeletons) + center_on_anchor_part=True, + anchor_part_names="A", + skeletons=labels_reader.labels.skeletons, + ) instance_cropper = instance_cropping.InstanceCropper( - crop_width=160, crop_height=160, keep_full_image=True) + crop_width=160, crop_height=160, keep_full_image=True + ) ds = instance_centroid_finder.transform_dataset(labels_reader.make_dataset()) ds = instance_cropper.transform_dataset(ds) diff --git a/tests/nn/data/test_normalization.py b/tests/nn/data/test_normalization.py index f93ef2c45..20a1df4ec 100644 --- a/tests/nn/data/test_normalization.py +++ b/tests/nn/data/test_normalization.py @@ -145,7 +145,10 @@ def test_normalizer_from_config(): def test_ensure_grayscale_from_provider(small_robot_mp4_vid): - video = providers.VideoReader(video=small_robot_mp4_vid, example_indices=[0],) + video = providers.VideoReader( + video=small_robot_mp4_vid, + example_indices=[0], + ) normalizer = normalization.Normalizer(image_key="image", ensure_grayscale=True) @@ -157,7 +160,10 @@ def test_ensure_grayscale_from_provider(small_robot_mp4_vid): def test_ensure_rgb_from_provider(centered_pair_vid): - video = providers.VideoReader(video=centered_pair_vid, example_indices=[0],) + video = providers.VideoReader( + video=centered_pair_vid, + example_indices=[0], + ) normalizer = normalization.Normalizer(image_key="image", ensure_rgb=True) diff --git a/tests/nn/data/test_offset_regression.py b/tests/nn/data/test_offset_regression.py index 61e8a9e57..31e688839 100644 --- a/tests/nn/data/test_offset_regression.py +++ b/tests/nn/data/test_offset_regression.py @@ -23,12 +23,12 @@ def test_make_offsets(): def test_mask_offsets(): - points = np.array([[1., 1.]], "float32") + points = np.array([[1.0, 1.0]], "float32") xv, yv = sleap.nn.data.confidence_maps.make_grid_vectors(4, 4, output_stride=1) off = offset_regression.make_offsets(points, xv, yv, stride=1) cm = sleap.nn.data.confidence_maps.make_confmaps(points, xv, yv, sigma=1) off_mask = offset_regression.mask_offsets(off, cm, threshold=0.2) np.testing.assert_array_equal(off_mask[:3, :3], off[:3, :3]) - np.testing.assert_array_equal(off_mask[3:, :], 0.) - np.testing.assert_array_equal(off_mask[:, 3:], 0.) \ No newline at end of file + np.testing.assert_array_equal(off_mask[3:, :], 0.0) + np.testing.assert_array_equal(off_mask[:, 3:], 0.0) diff --git a/tests/nn/data/test_pipelines.py b/tests/nn/data/test_pipelines.py index 032497311..30b67e13c 100644 --- a/tests/nn/data/test_pipelines.py +++ b/tests/nn/data/test_pipelines.py @@ -1,15 +1,21 @@ import pytest import numpy as np import tensorflow as tf -from sleap.nn.system import use_cpu_only; use_cpu_only() # hide GPUs for test +from sleap.nn.system import use_cpu_only + +use_cpu_only() # hide GPUs for test import sleap from sleap.nn.data import pipelines def test_pipeline_concatenation(): - A = pipelines.Pipeline.from_blocks(pipelines.InstanceCentroidFinder(center_on_anchor_part=False)) - B = pipelines.Pipeline.from_blocks(pipelines.InstanceCropper(crop_width=16, crop_height=16)) + A = pipelines.Pipeline.from_blocks( + pipelines.InstanceCentroidFinder(center_on_anchor_part=False) + ) + B = pipelines.Pipeline.from_blocks( + pipelines.InstanceCropper(crop_width=16, crop_height=16) + ) C = A + B assert len(C.transformers) == 2 @@ -21,25 +27,37 @@ def test_pipeline_concatenation(): assert isinstance(C.transformers[0], pipelines.InstanceCentroidFinder) assert isinstance(C.transformers[1], pipelines.InstanceCropper) - D = pipelines.Pipeline.from_blocks(pipelines.InstanceCentroidFinder(center_on_anchor_part=False)) - D += pipelines.Pipeline.from_blocks(pipelines.InstanceCropper(crop_width=16, crop_height=16)) + D = pipelines.Pipeline.from_blocks( + pipelines.InstanceCentroidFinder(center_on_anchor_part=False) + ) + D += pipelines.Pipeline.from_blocks( + pipelines.InstanceCropper(crop_width=16, crop_height=16) + ) assert len(D.transformers) == 2 assert isinstance(D.transformers[0], pipelines.InstanceCentroidFinder) assert isinstance(D.transformers[1], pipelines.InstanceCropper) - E = pipelines.Pipeline.from_blocks(pipelines.InstanceCentroidFinder(center_on_anchor_part=False)) - E |= pipelines.Pipeline.from_blocks(pipelines.InstanceCropper(crop_width=16, crop_height=16)) + E = pipelines.Pipeline.from_blocks( + pipelines.InstanceCentroidFinder(center_on_anchor_part=False) + ) + E |= pipelines.Pipeline.from_blocks( + pipelines.InstanceCropper(crop_width=16, crop_height=16) + ) assert len(E.transformers) == 2 assert isinstance(E.transformers[0], pipelines.InstanceCentroidFinder) assert isinstance(E.transformers[1], pipelines.InstanceCropper) - F = pipelines.Pipeline.from_blocks(pipelines.InstanceCentroidFinder(center_on_anchor_part=False)) + F = pipelines.Pipeline.from_blocks( + pipelines.InstanceCentroidFinder(center_on_anchor_part=False) + ) F += pipelines.InstanceCropper(crop_width=16, crop_height=16) assert len(F.transformers) == 2 assert isinstance(F.transformers[0], pipelines.InstanceCentroidFinder) assert isinstance(F.transformers[1], pipelines.InstanceCropper) - G = pipelines.Pipeline.from_blocks(pipelines.InstanceCentroidFinder(center_on_anchor_part=False)) + G = pipelines.Pipeline.from_blocks( + pipelines.InstanceCentroidFinder(center_on_anchor_part=False) + ) G |= pipelines.InstanceCropper(crop_width=16, crop_height=16) assert len(G.transformers) == 2 assert isinstance(G.transformers[0], pipelines.InstanceCentroidFinder) diff --git a/tests/nn/data/test_providers.py b/tests/nn/data/test_providers.py index 2d384c5a9..6c559bbc4 100644 --- a/tests/nn/data/test_providers.py +++ b/tests/nn/data/test_providers.py @@ -11,6 +11,8 @@ def test_labels_reader(min_labels): labels_reader = providers.LabelsReader.from_user_instances(min_labels) ds = labels_reader.make_dataset() + assert not labels_reader.is_from_multi_size_videos + example = next(iter(ds)) assert len(labels_reader) == 1 @@ -68,6 +70,8 @@ def test_labels_reader_no_visible_points(min_labels): labels_reader = providers.LabelsReader.from_user_instances(min_labels) ds = labels_reader.make_dataset() + assert not labels_reader.is_from_multi_size_videos + example = next(iter(ds)) # There should be two instances in the labels dataset @@ -155,3 +159,49 @@ def test_video_reader_hdf5(): assert example["raw_image_size"].dtype == tf.int32 np.testing.assert_array_equal(example["raw_image_size"], (512, 512, 1)) + + +def test_labels_reader_multi_size(): + # Create some fake data using two different size videos. + skeleton = sleap.Skeleton.from_names_and_edge_inds(["A"]) + labels = sleap.Labels( + [ + sleap.LabeledFrame( + frame_idx=0, + video=sleap.Video.from_filename( + TEST_SMALL_ROBOT_MP4_FILE, grayscale=True + ), + instances=[ + sleap.Instance.from_pointsarray( + np.array([[128, 128]]), skeleton=skeleton + ) + ], + ), + sleap.LabeledFrame( + frame_idx=0, + video=sleap.Video.from_filename( + TEST_H5_FILE, dataset="/box", input_format="channels_first" + ), + instances=[ + sleap.Instance.from_pointsarray( + np.array([[128, 128]]), skeleton=skeleton + ) + ], + ), + ] + ) + + # Create a loader for those labels. + labels_reader = providers.LabelsReader(labels) + ds = labels_reader.make_dataset() + ds_iter = iter(ds) + + # Check LabelReader can provide different shapes of individual samples + assert next(ds_iter)["image"].shape == (320, 560, 1) + assert next(ds_iter)["image"].shape == (512, 512, 1) + + # Check util functions + h, w = labels_reader.max_height_and_width + assert h == 512 + assert w == 560 + assert labels_reader.is_from_multi_size_videos diff --git a/tests/nn/data/test_resizing.py b/tests/nn/data/test_resizing.py index ee548b9e1..891bbb189 100644 --- a/tests/nn/data/test_resizing.py +++ b/tests/nn/data/test_resizing.py @@ -1,45 +1,69 @@ import pytest import numpy as np import tensorflow as tf -from sleap.nn.system import use_cpu_only; use_cpu_only() # hide GPUs for test +from sleap.nn.system import use_cpu_only +use_cpu_only() # hide GPUs for test + +import sleap +from sleap.nn.system import use_cpu_only + +use_cpu_only() # hide GPUs for test from sleap.nn.data import resizing from sleap.nn.data import providers +from sleap.nn.data.resizing import SizeMatcher + +from tests.fixtures.videos import TEST_H5_FILE, TEST_SMALL_ROBOT_MP4_FILE def test_find_padding_for_stride(): assert resizing.find_padding_for_stride( - image_height=127, image_width=129, max_stride=32) == (1, 31) + image_height=127, image_width=129, max_stride=32 + ) == (1, 31) assert resizing.find_padding_for_stride( - image_height=128, image_width=128, max_stride=32) == (0, 0) + image_height=128, image_width=128, max_stride=32 + ) == (0, 0) def test_pad_to_stride(): np.testing.assert_array_equal( resizing.pad_to_stride(tf.ones([3, 5, 1]), max_stride=2), - tf.expand_dims([ - [1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 0]], axis=-1) - ) - assert resizing.pad_to_stride( - tf.ones([3, 5, 1], dtype=tf.uint8), max_stride=2).dtype == tf.uint8 - assert resizing.pad_to_stride( - tf.ones([3, 5, 1], dtype=tf.float32), max_stride=2).dtype == tf.float32 - assert resizing.pad_to_stride( - tf.ones([4, 4, 1]), max_stride=2).shape == (4, 4, 1) + tf.expand_dims( + [ + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0], + ], + axis=-1, + ), + ) + assert ( + resizing.pad_to_stride(tf.ones([3, 5, 1], dtype=tf.uint8), max_stride=2).dtype + == tf.uint8 + ) + assert ( + resizing.pad_to_stride(tf.ones([3, 5, 1], dtype=tf.float32), max_stride=2).dtype + == tf.float32 + ) + assert resizing.pad_to_stride(tf.ones([4, 4, 1]), max_stride=2).shape == (4, 4, 1) def test_resize_image(): assert resizing.resize_image( - tf.ones([4, 8, 1], dtype=tf.uint8), scale=[0.25, 3]).shape == (12, 2, 1) - assert resizing.resize_image( - tf.ones([4, 8, 1], dtype=tf.uint8), scale=0.5).shape == (2, 4, 1) + tf.ones([4, 8, 1], dtype=tf.uint8), scale=[0.25, 3] + ).shape == (12, 2, 1) assert resizing.resize_image( - tf.ones([4, 8, 1], dtype=tf.uint8), scale=0.5).dtype == tf.uint8 - assert resizing.resize_image( - tf.ones([4, 8, 1], dtype=tf.float32), scale=0.5).dtype == tf.float32 + tf.ones([4, 8, 1], dtype=tf.uint8), scale=0.5 + ).shape == (2, 4, 1) + assert ( + resizing.resize_image(tf.ones([4, 8, 1], dtype=tf.uint8), scale=0.5).dtype + == tf.uint8 + ) + assert ( + resizing.resize_image(tf.ones([4, 8, 1], dtype=tf.float32), scale=0.5).dtype + == tf.float32 + ) def test_resizer(min_labels): @@ -71,7 +95,8 @@ def test_resizer(min_labels): def test_resizer_from_config(): resizer = resizing.Resizer.from_config( - config=resizing.PreprocessingConfig(input_scaling=0.5, pad_to_stride=32)) + config=resizing.PreprocessingConfig(input_scaling=0.5, pad_to_stride=32) + ) assert resizer.image_key == "image" assert resizer.points_key == "instances" assert resizer.scale == 0.5 @@ -79,7 +104,8 @@ def test_resizer_from_config(): resizer = resizing.Resizer.from_config( config=resizing.PreprocessingConfig(input_scaling=0.5, pad_to_stride=32), - pad_to_stride=16) + pad_to_stride=16, + ) assert resizer.image_key == "image" assert resizer.points_key == "instances" assert resizer.scale == 0.5 @@ -87,7 +113,8 @@ def test_resizer_from_config(): resizer = resizing.Resizer.from_config( config=resizing.PreprocessingConfig(input_scaling=0.5, pad_to_stride=None), - pad_to_stride=32) + pad_to_stride=32, + ) assert resizer.image_key == "image" assert resizer.points_key == "instances" assert resizer.scale == 0.5 @@ -95,4 +122,90 @@ def test_resizer_from_config(): with pytest.raises(ValueError): resizer = resizing.Resizer.from_config( - config=resizing.PreprocessingConfig(input_scaling=0.5, pad_to_stride=None)) + config=resizing.PreprocessingConfig(input_scaling=0.5, pad_to_stride=None) + ) + + +def test_size_matcher(): + # Create some fake data using two different size videos. + skeleton = sleap.Skeleton.from_names_and_edge_inds(["A"]) + labels = sleap.Labels( + [ + sleap.LabeledFrame( + frame_idx=0, + video=sleap.Video.from_filename( + TEST_SMALL_ROBOT_MP4_FILE, grayscale=True + ), + instances=[ + sleap.Instance.from_pointsarray( + np.array([[128, 128]]), skeleton=skeleton + ) + ], + ), + sleap.LabeledFrame( + frame_idx=0, + video=sleap.Video.from_filename( + TEST_H5_FILE, dataset="/box", input_format="channels_first" + ), + instances=[ + sleap.Instance.from_pointsarray( + np.array([[128, 128]]), skeleton=skeleton + ) + ], + ), + ] + ) + + # Create a loader for those labels. + labels_reader = providers.LabelsReader(labels) + ds = labels_reader.make_dataset() + ds_iter = iter(ds) + assert next(ds_iter)["image"].shape == (320, 560, 1) + assert next(ds_iter)["image"].shape == (512, 512, 1) + + def check_padding(image, from_y, to_y, from_x, to_x): + for y in range(from_y, to_y): + for x in range(from_x, to_x): + assert image[y][x] == 0 + + # Check SizeMatcher when target dims is not strictly larger than actual image dims + size_matcher = SizeMatcher(max_image_height=560, max_image_width=560) + transform_iter = iter(size_matcher.transform_dataset(ds)) + im1 = next(transform_iter)["image"] + assert im1.shape == (560, 560, 1) + # padding should be on the bottom + check_padding(im1, 321, 560, 0, 560) + im2 = next(transform_iter)["image"] + assert im2.shape == (560, 560, 1) + + # Variant 2 + size_matcher = SizeMatcher(max_image_height=320, max_image_width=560) + transform_iter = iter(size_matcher.transform_dataset(ds)) + im1 = next(transform_iter)["image"] + assert im1.shape == (320, 560, 1) + im2 = next(transform_iter)["image"] + assert im2.shape == (320, 560, 1) + # padding should be on the right + check_padding(im2, 0, 320, 321, 560) + + # Check SizeMatcher when target is 'max' in both dimensions + size_matcher = SizeMatcher(max_image_height=512, max_image_width=560) + transform_iter = iter(size_matcher.transform_dataset(ds)) + im1 = next(transform_iter)["image"] + assert im1.shape == (512, 560, 1) + # Check padding is on the bottom + check_padding(im1, 320, 512, 0, 560) + im2 = next(transform_iter)["image"] + assert im2.shape == (512, 560, 1) + # Check padding is on the right + check_padding(im2, 0, 512, 512, 560) + + # Check SizeMatcher when target is larger in both dimensions + size_matcher = SizeMatcher(max_image_height=750, max_image_width=750) + transform_iter = iter(size_matcher.transform_dataset(ds)) + im1 = next(transform_iter)["image"] + assert im1.shape == (750, 750, 1) + # Check padding is on the bottom + check_padding(im1, 700, 750, 0, 750) + im2 = next(transform_iter)["image"] + assert im2.shape == (750, 750, 1) diff --git a/tests/nn/data/test_utils.py b/tests/nn/data/test_utils.py index 46760ff79..213e357e8 100644 --- a/tests/nn/data/test_utils.py +++ b/tests/nn/data/test_utils.py @@ -1,6 +1,8 @@ import numpy as np import tensorflow as tf -from sleap.nn.system import use_cpu_only; use_cpu_only() # hide GPUs for test +from sleap.nn.system import use_cpu_only + +use_cpu_only() # hide GPUs for test from sleap.nn.data import utils @@ -28,10 +30,8 @@ def test_expand_to_rank(): [[0, 1, 2]], ) np.testing.assert_array_equal( - utils.expand_to_rank( - tf.reshape(tf.range(2 * 3 * 4), [2, 3, 4]), - target_rank=2), - tf.reshape(tf.range(2 * 3 * 4), [2, 3, 4]) + utils.expand_to_rank(tf.reshape(tf.range(2 * 3 * 4), [2, 3, 4]), target_rank=2), + tf.reshape(tf.range(2 * 3 * 4), [2, 3, 4]), ) @@ -62,33 +62,36 @@ def test_describe_tensors(): dict( tens=tf.ones((1, 2), dtype=tf.uint8), rag=tf.ragged.stack([tf.ones((3,)), tf.ones((4,))], axis=0), - np=np.array([1, 2], dtype="int32") - ), return_description=True, + np=np.array([1, 2], dtype="int32"), + ), + return_description=True, + ) + assert desc == "\n".join( + [ + "tens: type=EagerTensor, shape=(1, 2), dtype=tf.uint8, " + "device=/job:localhost/replica:0/task:0/device:CPU:0", + " rag: type=RaggedTensor, shape=(2, None), dtype=tf.float32, device=N/A", + " np: type=ndarray, shape=(2,), dtype=int32, device=N/A", + ] ) - assert desc == "\n".join([ - "tens: type=EagerTensor, shape=(1, 2), dtype=tf.uint8, " - "device=/job:localhost/replica:0/task:0/device:CPU:0", - " rag: type=RaggedTensor, shape=(2, None), dtype=tf.float32, device=N/A", - " np: type=ndarray, shape=(2,), dtype=int32, device=N/A", - ]) def test_unrag_example(): ex = { "not_ragged": tf.ones([1, 2]), - "ragged_float": tf.ragged.stack([ - tf.ones([2], dtype=tf.float32), - tf.ones([1], dtype=tf.float32)], axis=0), - "ragged_int": tf.ragged.stack([ - tf.ones([2], dtype=tf.uint8), - tf.ones([1], dtype=tf.uint8)], axis=0) + "ragged_float": tf.ragged.stack( + [tf.ones([2], dtype=tf.float32), tf.ones([1], dtype=tf.float32)], axis=0 + ), + "ragged_int": tf.ragged.stack( + [tf.ones([2], dtype=tf.uint8), tf.ones([1], dtype=tf.uint8)], axis=0 + ), } ex2 = utils.unrag_example(ex) assert all(isinstance(v, tf.Tensor) for v in ex2.values()) assert (ex2["not_ragged"].numpy() == [[1, 1]]).all() - np.testing.assert_array_equal(ex2["ragged_float"], [[1., 1.], [1., np.nan]]) + np.testing.assert_array_equal(ex2["ragged_float"], [[1.0, 1.0], [1.0, np.nan]]) assert (ex2["ragged_int"].numpy() == [[1, 1], [1, 0]]).all() diff --git a/tests/nn/test_heads.py b/tests/nn/test_heads.py new file mode 100644 index 000000000..02fbc2737 --- /dev/null +++ b/tests/nn/test_heads.py @@ -0,0 +1,292 @@ +import tensorflow as tf + +import sleap +from sleap.nn.heads import ( + Head, + SingleInstanceConfmapsHead, + CentroidConfmapsHead, + CenteredInstanceConfmapsHead, + MultiInstanceConfmapsHead, + PartAffinityFieldsHead, + ClassMapsHead, + OffsetRefinementHead, +) +from sleap.nn.config import ( + SingleInstanceConfmapsHeadConfig, + CentroidsHeadConfig, + CenteredInstanceConfmapsHeadConfig, + MultiInstanceConfmapsHeadConfig, + PartAffinityFieldsHeadConfig, + ClassMapsHeadConfig, +) + + +sleap.use_cpu_only() + + +def test_single_instance_confmaps_head(): + head = SingleInstanceConfmapsHead( + part_names=["a", "b", "c"], + sigma=1.0, + output_stride=1, + loss_weight=1.0, + ) + + x_in = tf.keras.layers.Input([4, 4, 4]) + x = head.make_head(x_in) + + assert head.channels == 3 + assert tuple(x.shape) == (None, 4, 4, 3) + assert tf.keras.Model(x_in, x).output_names[0] == "SingleInstanceConfmapsHead" + assert tf.keras.Model(x_in, x).layers[-1].activation.__name__ == "linear" + + head = SingleInstanceConfmapsHead.from_config( + SingleInstanceConfmapsHeadConfig( + part_names=None, + sigma=1.5, + output_stride=2, + loss_weight=2.0, + offset_refinement=False, + ), + part_names=["c", "b", "a"], + ) + assert head.part_names == ["c", "b", "a"] + assert head.sigma == 1.5 + assert head.output_stride == 2 + assert head.loss_weight == 2.0 + x = head.make_head(x_in) + + +def test_centroid_confmaps_head(): + head = CentroidConfmapsHead( + anchor_part="a", + sigma=1.0, + output_stride=1, + loss_weight=1.0, + ) + + x_in = tf.keras.layers.Input([4, 4, 4]) + x = head.make_head(x_in) + + assert head.channels == 1 + assert tuple(x.shape) == (None, 4, 4, 1) + assert tf.keras.Model(x_in, x).output_names[0] == "CentroidConfmapsHead" + assert tf.keras.Model(x_in, x).layers[-1].activation.__name__ == "linear" + + head = CentroidConfmapsHead.from_config( + CentroidsHeadConfig( + anchor_part="a", + sigma=1.5, + output_stride=2, + loss_weight=2.0, + offset_refinement=False, + ) + ) + assert head.anchor_part == "a" + assert head.sigma == 1.5 + assert head.output_stride == 2 + assert head.loss_weight == 2.0 + x = head.make_head(x_in) + + +def test_centroid_confmaps_head(): + head = CentroidConfmapsHead( + anchor_part="a", + sigma=1.0, + output_stride=1, + loss_weight=1.0, + ) + + x_in = tf.keras.layers.Input([4, 4, 4]) + x = head.make_head(x_in) + + assert head.channels == 1 + assert tuple(x.shape) == (None, 4, 4, 1) + assert tf.keras.Model(x_in, x).output_names[0] == "CentroidConfmapsHead" + assert tf.keras.Model(x_in, x).layers[-1].activation.__name__ == "linear" + + head = CentroidConfmapsHead.from_config( + CentroidsHeadConfig( + anchor_part="a", + sigma=1.5, + output_stride=2, + loss_weight=2.0, + offset_refinement=False, + ) + ) + assert head.anchor_part == "a" + assert head.sigma == 1.5 + assert head.output_stride == 2 + assert head.loss_weight == 2.0 + x = head.make_head(x_in) + + +def test_centered_instance_confmaps_head(): + head = CenteredInstanceConfmapsHead( + part_names=["a", "b", "c"], + anchor_part="a", + sigma=1.0, + output_stride=1, + loss_weight=1.0, + ) + + x_in = tf.keras.layers.Input([4, 4, 4]) + x = head.make_head(x_in) + + assert head.channels == 3 + assert tuple(x.shape) == (None, 4, 4, 3) + assert tf.keras.Model(x_in, x).output_names[0] == "CenteredInstanceConfmapsHead" + assert tf.keras.Model(x_in, x).layers[-1].activation.__name__ == "linear" + + head = CenteredInstanceConfmapsHead.from_config( + CenteredInstanceConfmapsHeadConfig( + part_names=None, + anchor_part="a", + sigma=1.5, + output_stride=2, + loss_weight=2.0, + offset_refinement=False, + ), + part_names=["c", "b", "a"], + ) + assert head.part_names == ["c", "b", "a"] + assert head.sigma == 1.5 + assert head.output_stride == 2 + assert head.loss_weight == 2.0 + + +def test_multi_instance_confmaps_head(): + head = MultiInstanceConfmapsHead( + part_names=["a", "b", "c"], + sigma=1.0, + output_stride=1, + loss_weight=1.0, + ) + + x_in = tf.keras.layers.Input([4, 4, 4]) + x = head.make_head(x_in) + + assert head.channels == 3 + assert tuple(x.shape) == (None, 4, 4, 3) + assert tf.keras.Model(x_in, x).output_names[0] == "MultiInstanceConfmapsHead" + assert tf.keras.Model(x_in, x).layers[-1].activation.__name__ == "linear" + + head = MultiInstanceConfmapsHead.from_config( + MultiInstanceConfmapsHeadConfig( + part_names=None, + sigma=1.5, + output_stride=2, + loss_weight=2.0, + offset_refinement=False, + ), + part_names=["c", "b", "a"], + ) + assert head.part_names == ["c", "b", "a"] + assert head.sigma == 1.5 + assert head.output_stride == 2 + assert head.loss_weight == 2.0 + + +def test_part_affinity_fields_head(): + head = PartAffinityFieldsHead( + edges=[("a", "b"), ("b", "c")], + sigma=1.0, + output_stride=1, + loss_weight=1.0, + ) + + x_in = tf.keras.layers.Input([4, 4, 5]) + x = head.make_head(x_in) + + assert head.channels == 4 + assert tuple(x.shape) == (None, 4, 4, 4) + assert tf.keras.Model(x_in, x).output_names[0] == "PartAffinityFieldsHead" + assert tf.keras.Model(x_in, x).layers[-1].activation.__name__ == "linear" + + head = PartAffinityFieldsHead.from_config( + PartAffinityFieldsHeadConfig( + edges=None, + sigma=1.5, + output_stride=2, + loss_weight=2.0, + ), + edges=[("a", "b"), ("b", "c")], + ) + assert head.edges == [("a", "b"), ("b", "c")] + assert head.sigma == 1.5 + assert head.output_stride == 2 + assert head.loss_weight == 2.0 + + +def test_class_maps_head(): + head = ClassMapsHead( + classes=["1", "2"], + sigma=1.0, + output_stride=1, + loss_weight=1.0, + ) + + x_in = tf.keras.layers.Input([4, 4, 4]) + x = head.make_head(x_in) + + assert head.channels == 2 + assert tuple(x.shape) == (None, 4, 4, 2) + assert tf.keras.Model(x_in, x).output_names[0] == "ClassMapsHead" + assert tf.keras.Model(x_in, x).layers[-1].activation.__name__ == "sigmoid" + + head = ClassMapsHead.from_config( + ClassMapsHeadConfig( + classes=None, + sigma=1.5, + output_stride=2, + loss_weight=2.0, + ), + classes=["1", "2"], + ) + assert head.classes == ["1", "2"] + assert head.sigma == 1.5 + assert head.output_stride == 2 + assert head.loss_weight == 2.0 + + +def test_offset_refinement_head(): + head = OffsetRefinementHead( + part_names=["a", "b", "c"], + sigma_threshold=0.3, + output_stride=1, + loss_weight=1.0, + ) + + x_in = tf.keras.layers.Input([4, 4, 8]) + x = head.make_head(x_in) + + assert head.channels == 6 + assert tuple(x.shape) == (None, 4, 4, 6) + assert tf.keras.Model(x_in, x).output_names[0] == "OffsetRefinementHead" + assert tf.keras.Model(x_in, x).layers[-1].activation.__name__ == "linear" + + head = OffsetRefinementHead.from_config( + MultiInstanceConfmapsHeadConfig( + part_names=["a", "b"], + sigma=1.5, + output_stride=2, + loss_weight=2.0, + offset_refinement=False, + ), + sigma_threshold=0.4, + ) + assert head.part_names == ["a", "b"] + assert head.output_stride == 2 + assert head.sigma_threshold == 0.4 + + head = OffsetRefinementHead.from_config( + MultiInstanceConfmapsHeadConfig(), part_names=["a", "b"] + ) + assert head.part_names == ["a", "b"] + + head = OffsetRefinementHead.from_config(CentroidsHeadConfig(anchor_part="a")) + assert head.part_names == ["a"] + + head = OffsetRefinementHead.from_config(CentroidsHeadConfig(anchor_part=None)) + assert head.part_names == [None] + assert head.channels == 2 diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index b92ce3ab2..00526ef7c 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -23,8 +23,8 @@ FindInstancePeaksGroundTruth, FindInstancePeaks, TopDownInferenceModel, - TopdownPredictor, - BottomupPredictor, + TopDownPredictor, + BottomUpPredictor, BottomUpMultiClassPredictor, load_model, ) @@ -493,7 +493,7 @@ def test_single_instance_predictor( def test_topdown_predictor_centroid(min_labels, min_centroid_model_path): - predictor = TopdownPredictor.from_trained_models( + predictor = TopDownPredictor.from_trained_models( centroid_model_path=min_centroid_model_path ) labels_pr = predictor.predict(min_labels) @@ -513,7 +513,7 @@ def test_topdown_predictor_centroid(min_labels, min_centroid_model_path): def test_topdown_predictor_centered_instance( min_labels, min_centered_instance_model_path ): - predictor = TopdownPredictor.from_trained_models( + predictor = TopDownPredictor.from_trained_models( confmap_model_path=min_centered_instance_model_path ) labels_pr = predictor.predict(min_labels) @@ -530,8 +530,8 @@ def test_topdown_predictor_centered_instance( assert_allclose(points_gt[inds1.numpy()], points_pr[inds2.numpy()], atol=1.5) -def test_topdown_predictor_bottomup(min_labels, min_bottomup_model_path): - predictor = BottomupPredictor.from_trained_models( +def test_bottomup_predictor(min_labels, min_bottomup_model_path): + predictor = BottomUpPredictor.from_trained_models( model_path=min_bottomup_model_path ) labels_pr = predictor.predict(min_labels) @@ -548,7 +548,7 @@ def test_topdown_predictor_bottomup(min_labels, min_bottomup_model_path): assert_allclose(points_gt[inds1.numpy()], points_pr[inds2.numpy()], atol=1.75) -def test_topdown_predictor_bottomup_multiclass( +def test_bottomup_multiclass_predictor( min_tracks_2node_labels, min_bottomup_multiclass_model_path ): labels_gt = sleap.Labels(min_tracks_2node_labels[[0]]) @@ -573,8 +573,9 @@ def test_topdown_predictor_bottomup_multiclass( labels_gt[0][inds1[1]].numpy(), labels_pr[0][inds2[1]].numpy(), rtol=0.02 ) - labels_pr = predictor.predict(sleap.nn.data.pipelines.VideoReader( - labels_gt.video, example_indices=[0])) + labels_pr = predictor.predict( + sleap.nn.data.pipelines.VideoReader(labels_gt.video, example_indices=[0]) + ) labels_pr[0][0].track.name == "female" labels_pr[0][1].track.name == "male" @@ -590,10 +591,10 @@ def test_load_model( assert isinstance(predictor, SingleInstancePredictor) predictor = load_model([min_centroid_model_path, min_centered_instance_model_path]) - assert isinstance(predictor, TopdownPredictor) + assert isinstance(predictor, TopDownPredictor) predictor = load_model(min_bottomup_model_path) - assert isinstance(predictor, BottomupPredictor) + assert isinstance(predictor, BottomUpPredictor) predictor = load_model(min_bottomup_multiclass_model_path) assert isinstance(predictor, BottomUpMultiClassPredictor) diff --git a/tests/nn/test_inference_identity.py b/tests/nn/test_inference_identity.py index b0fec15df..22be152ea 100644 --- a/tests/nn/test_inference_identity.py +++ b/tests/nn/test_inference_identity.py @@ -5,7 +5,7 @@ import sleap from sleap.nn.identity import ( group_class_peaks, - classify_peaks, + classify_peaks_from_maps, ) @@ -13,9 +13,17 @@ def test_group_class_peaks(): - peak_class_probs = np.array([ - [0.1, 0.9], [0.9, 0.1], [0.95, 0.05], [0.8, 0.2], - [0.9, 0.1], [0.85, 0.15], [0.1, 0.9]]) + peak_class_probs = np.array( + [ + [0.1, 0.9], + [0.9, 0.1], + [0.95, 0.05], + [0.8, 0.2], + [0.9, 0.1], + [0.85, 0.15], + [0.1, 0.9], + ] + ) peak_sample_inds = np.array([0, 0, 0, 0, 1, 1, 1]) peak_channel_inds = np.array([0, 0, 1, 1, 0, 0, 0]) peak_inds, class_inds = group_class_peaks( @@ -30,10 +38,18 @@ def test_group_class_peaks(): assert_array_equal(class_inds, [1, 0, 0, 0, 1]) -def test_classify_peaks(): - peak_class_probs = np.array([ - [0.1, 0.9], [0.91, 0.09], [0.95, 0.05], [0.8, 0.2], - [0.92, 0.08], [0.85, 0.15], [0.07, 0.93]]) +def test_classify_peaks_from_maps(): + peak_class_probs = np.array( + [ + [0.1, 0.9], + [0.91, 0.09], + [0.95, 0.05], + [0.8, 0.2], + [0.92, 0.08], + [0.85, 0.15], + [0.07, 0.93], + ] + ) peak_sample_inds = np.array([0, 0, 0, 0, 1, 1, 1]) peak_channel_inds = np.array([0, 0, 1, 1, 0, 0, 0]) peak_points = tf.reshape(tf.range(7 * 2, dtype=tf.float32), [7, 2]) @@ -44,9 +60,13 @@ def test_classify_peaks(): class_maps[s, int(y), int(x), :] = pr class_maps = tf.cast(class_maps, tf.float32) - points, point_vals, class_probs = classify_peaks( - class_maps, peak_points, peak_vals, peak_sample_inds, peak_channel_inds, - n_channels=2 + points, point_vals, class_probs = classify_peaks_from_maps( + class_maps, + peak_points, + peak_vals, + peak_sample_inds, + peak_channel_inds, + n_channels=2, ) assert_array_equal(points[0][0], peak_points.numpy()[[1, 2]]) diff --git a/tests/nn/test_kalman.py b/tests/nn/test_kalman.py index 3e2939403..db5176194 100644 --- a/tests/nn/test_kalman.py +++ b/tests/nn/test_kalman.py @@ -48,7 +48,12 @@ def test_first_choice_matching(): # another cost matrix # make sure we get *best* match for each track, regardless of row order - cost_matrix = np.array([[50, 100], [10, 150],]) + cost_matrix = np.array( + [ + [50, 100], + [10, 150], + ] + ) match_by_track = k.match_dict_from_match_function( cost_matrix=cost_matrix, row_items=instances, @@ -121,7 +126,12 @@ def test_track_instance_matches(): # best match is instance a -> track a # next match is instance b -> track b # but instance b would prefer track a - cost_matrix = np.array([[10, 100], [50, 150],]) + cost_matrix = np.array( + [ + [10, 100], + [50, 150], + ] + ) matches = k.get_track_instance_matches( cost_matrix=cost_matrix, @@ -143,7 +153,12 @@ def test_track_instance_matches(): # best match is instance b -> track a (cost 10) # next match is instance a -> track b (cost 100) # each instance gets its first choice so "too close" check shouldn't apply - cost_matrix = np.array([[50, 100], [10, 150],]) + cost_matrix = np.array( + [ + [50, 100], + [10, 150], + ] + ) matches = k.get_track_instance_matches( cost_matrix=cost_matrix, diff --git a/tests/nn/test_model.py b/tests/nn/test_model.py new file mode 100644 index 000000000..329e5528f --- /dev/null +++ b/tests/nn/test_model.py @@ -0,0 +1,50 @@ +import pytest +import sleap + +from sleap.nn.model import Model +from sleap.nn.config import ( + SingleInstanceConfmapsHeadConfig, + CentroidsHeadConfig, + CenteredInstanceConfmapsHeadConfig, + MultiInstanceConfmapsHeadConfig, + PartAffinityFieldsHeadConfig, + ClassMapsHeadConfig, + HeadsConfig, + UNetConfig, + BackboneConfig, + ModelConfig, +) + +sleap.use_cpu_only() + + +def test_model_from_config(): + skel = sleap.Skeleton() + skel.add_node("a") + cfg = ModelConfig() + cfg.backbone.unet = UNetConfig(filters=4, max_stride=4, output_stride=2) + with pytest.raises(ValueError): + Model.from_config(cfg, skeleton=skel) + cfg.heads.single_instance = SingleInstanceConfmapsHeadConfig( + part_names=None, + sigma=1.5, + output_stride=2, + loss_weight=2.0, + offset_refinement=True, + ) + model = Model.from_config(cfg, skeleton=skel) + + assert isinstance(model.heads[0], sleap.nn.heads.SingleInstanceConfmapsHead) + assert isinstance(model.heads[1], sleap.nn.heads.OffsetRefinementHead) + + keras_model = model.make_model(input_shape=(16, 16, 1)) + assert keras_model.output_names == [ + "SingleInstanceConfmapsHead", + "OffsetRefinementHead", + ] + assert tuple(keras_model.outputs[0].shape) == (None, 8, 8, 1) + assert tuple(keras_model.outputs[1].shape) == (None, 8, 8, 2) + + cfg.heads.single_instance = None + with pytest.raises(ValueError): + Model.from_config(cfg, skeleton=skel) diff --git a/tests/nn/test_nn_utils.py b/tests/nn/test_nn_utils.py index 872a20543..8644f6cbe 100644 --- a/tests/nn/test_nn_utils.py +++ b/tests/nn/test_nn_utils.py @@ -16,10 +16,7 @@ def test_tf_linear_sum_assignment(): def test_match_points(): - inds1, inds2 = match_points( - [[0, 0], [1, 2]], - [[1, 2], [0, 0]] - ) + inds1, inds2 = match_points([[0, 0], [1, 2]], [[1, 2], [0, 0]]) assert_array_equal(inds1, [0, 1]) assert_array_equal(inds2, [1, 0]) diff --git a/tests/nn/test_paf_grouping.py b/tests/nn/test_paf_grouping.py index 029783262..d9bb48561 100644 --- a/tests/nn/test_paf_grouping.py +++ b/tests/nn/test_paf_grouping.py @@ -26,22 +26,14 @@ def test_get_connection_candidates(): n_nodes = 4 edge_inds, edge_peak_inds = get_connection_candidates( - peak_channel_inds_sample, - skeleton_edges, - n_nodes + peak_channel_inds_sample, skeleton_edges, n_nodes ) assert_array_equal(edge_inds, [0, 0, 0, 0, 0, 0, 1, 1]) - assert_array_equal(edge_peak_inds, - [[0, 3], - [0, 4], - [1, 3], - [1, 4], - [2, 3], - [2, 4], - [3, 5], - [4, 5]]) + assert_array_equal( + edge_peak_inds, [[0, 3], [0, 4], [1, 3], [1, 4], [2, 3], [2, 4], [3, 5], [4, 5]] + ) def test_make_line_subs(): @@ -50,19 +42,12 @@ def test_make_line_subs(): edge_inds = tf.constant([0], tf.int32) line_subs = make_line_subs( - peaks_sample, - edge_peak_inds, - edge_inds, - n_line_points=3, - pafs_stride=2 + peaks_sample, edge_peak_inds, edge_inds, n_line_points=3, pafs_stride=2 + ) + assert_array_equal( + line_subs, + [[[[0, 0, 0], [0, 0, 1]], [[2, 1, 0], [2, 1, 1]], [[4, 2, 0], [4, 2, 1]]]], ) - assert_array_equal(line_subs, - [[[[0, 0, 0], - [0, 0, 1]], - [[2, 1, 0], - [2, 1, 1]], - [[4, 2, 0], - [4, 2, 1]]]]) def test_paf_lines(): @@ -77,12 +62,9 @@ def test_paf_lines(): edge_peak_inds, edge_inds, n_line_points=3, - pafs_stride=2 + pafs_stride=2, ) - assert_array_equal(paf_lines, - [[[ 0, 1], - [18, 19], - [36, 37]]]) + assert_array_equal(paf_lines, [[[0, 1], [18, 19], [36, 37]]]) def test_score_paf_lines(): @@ -90,7 +72,14 @@ def test_score_paf_lines(): peaks_sample = tf.constant([[0, 0], [4, 8]], tf.float32) edge_peak_inds = tf.constant([[0, 1]], tf.int32) edge_inds = tf.constant([0], tf.int32) - paf_lines = get_paf_lines(pafs_sample, peaks_sample, edge_peak_inds, edge_inds, n_line_points=3, pafs_stride=2) + paf_lines = get_paf_lines( + pafs_sample, + peaks_sample, + edge_peak_inds, + edge_inds, + n_line_points=3, + pafs_stride=2, + ) scores = score_paf_lines(paf_lines, peaks_sample, edge_peak_inds, max_edge_length=2) assert_allclose(scores, [24.27], atol=1e-2) @@ -114,7 +103,7 @@ def test_score_paf_lines_batch(): n_line_points, pafs_stride, max_edge_length_ratio, - n_nodes + n_nodes, ) assert_array_equal(edge_inds.to_list(), [[0]]) assert_array_equal(edge_peak_inds.to_list(), [[[0, 1]]]) @@ -133,10 +122,7 @@ def test_match_candidates_sample(): match_dst_peak_inds, match_line_scores, ) = match_candidates_sample( - edge_inds_sample, - edge_peak_inds_sample, - line_scores_sample, - n_edges + edge_inds_sample, edge_peak_inds_sample, line_scores_sample, n_edges ) src_peak_inds_k, _ = tf.unique(edge_peak_inds_sample[:, 0]) @@ -145,16 +131,22 @@ def test_match_candidates_sample(): assert_array_equal(match_edge_inds, [0]) assert_array_equal(match_src_peak_inds, [1]) assert_array_equal(match_dst_peak_inds, [0]) - assert_array_equal(match_line_scores, [1.]) + assert_array_equal(match_line_scores, [1.0]) assert tf.gather(src_peak_inds_k, match_src_peak_inds)[0] == 2 assert tf.gather(dst_peak_inds_k, match_dst_peak_inds)[0] == 1 def test_match_candidates_batch(): row_ids = tf.constant([0, 0], dtype=tf.int32) - edge_inds = tf.RaggedTensor.from_value_rowids(tf.constant([0, 0], dtype=tf.int32), row_ids) - edge_peak_inds = tf.RaggedTensor.from_value_rowids(tf.constant([[0, 1], [2, 1]], dtype=tf.int32), row_ids) - line_scores = tf.RaggedTensor.from_value_rowids(tf.constant([-0.5, 1.0], dtype=tf.float32), row_ids) + edge_inds = tf.RaggedTensor.from_value_rowids( + tf.constant([0, 0], dtype=tf.int32), row_ids + ) + edge_peak_inds = tf.RaggedTensor.from_value_rowids( + tf.constant([[0, 1], [2, 1]], dtype=tf.int32), row_ids + ) + line_scores = tf.RaggedTensor.from_value_rowids( + tf.constant([-0.5, 1.0], dtype=tf.float32), row_ids + ) n_edges = 1 ( @@ -162,12 +154,7 @@ def test_match_candidates_batch(): match_src_peak_inds, match_dst_peak_inds, match_line_scores, - ) = match_candidates_batch( - edge_inds, - edge_peak_inds, - line_scores, - n_edges - ) + ) = match_candidates_batch(edge_inds, edge_peak_inds, line_scores, n_edges) assert isinstance(match_edge_inds, tf.RaggedTensor) assert isinstance(match_src_peak_inds, tf.RaggedTensor) @@ -176,7 +163,7 @@ def test_match_candidates_batch(): assert_array_equal(match_edge_inds.flat_values, [0]) assert_array_equal(match_src_peak_inds.flat_values, [1]) assert_array_equal(match_dst_peak_inds.flat_values, [0]) - assert_array_equal(match_line_scores.flat_values, [1.]) + assert_array_equal(match_line_scores.flat_values, [1.0]) def test_group_instances_sample(): @@ -195,7 +182,7 @@ def test_group_instances_sample(): ( predicted_instances, predicted_peak_scores, - predicted_instance_scores + predicted_instance_scores, ) = group_instances_sample( peaks_sample, peak_scores_sample, @@ -212,35 +199,43 @@ def test_group_instances_sample(): assert_array_equal( predicted_instances, - [[[ 0., 1.], - [ 2., 3.], - [ 4., 5.]], - - [[ 6., 7.], - [ 8., 9.], - [ np.nan, np.nan],]] - ) - assert_array_equal( - predicted_peak_scores, - [[0., 1., 2.], - [3., 4., np.nan]] - ) - assert_array_equal( - predicted_instance_scores, - [1., 2.] + [ + [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]], + [ + [6.0, 7.0], + [8.0, 9.0], + [np.nan, np.nan], + ], + ], ) + assert_array_equal(predicted_peak_scores, [[0.0, 1.0, 2.0], [3.0, 4.0, np.nan]]) + assert_array_equal(predicted_instance_scores, [1.0, 2.0]) def test_group_instances_batch(): row_ids = tf.zeros([5], dtype=tf.int32) - peaks = tf.RaggedTensor.from_value_rowids(tf.reshape(tf.range(5 * 2, dtype=tf.float32), [5, 2]), row_ids) - peak_scores = tf.RaggedTensor.from_value_rowids(tf.range(5, dtype=tf.float32), row_ids) - peak_channel_inds = tf.RaggedTensor.from_value_rowids(tf.constant([0, 1, 2, 0, 1], tf.int32), row_ids) + peaks = tf.RaggedTensor.from_value_rowids( + tf.reshape(tf.range(5 * 2, dtype=tf.float32), [5, 2]), row_ids + ) + peak_scores = tf.RaggedTensor.from_value_rowids( + tf.range(5, dtype=tf.float32), row_ids + ) + peak_channel_inds = tf.RaggedTensor.from_value_rowids( + tf.constant([0, 1, 2, 0, 1], tf.int32), row_ids + ) row_ids_edges = tf.zeros([3], dtype=tf.int32) - match_edge_inds = tf.RaggedTensor.from_value_rowids(tf.constant([0, 1, 0], tf.int32), row_ids_edges) - match_src_peak_inds = tf.RaggedTensor.from_value_rowids(tf.constant([0, 0, 1], tf.int32), row_ids_edges) - match_dst_peak_inds = tf.RaggedTensor.from_value_rowids(tf.constant([0, 0, 1], tf.int32), row_ids_edges) - match_line_scores = tf.RaggedTensor.from_value_rowids(tf.range(3, dtype=tf.float32), row_ids_edges) + match_edge_inds = tf.RaggedTensor.from_value_rowids( + tf.constant([0, 1, 0], tf.int32), row_ids_edges + ) + match_src_peak_inds = tf.RaggedTensor.from_value_rowids( + tf.constant([0, 0, 1], tf.int32), row_ids_edges + ) + match_dst_peak_inds = tf.RaggedTensor.from_value_rowids( + tf.constant([0, 0, 1], tf.int32), row_ids_edges + ) + match_line_scores = tf.RaggedTensor.from_value_rowids( + tf.range(3, dtype=tf.float32), row_ids_edges + ) n_nodes = 3 n_edges = 2 edge_types = [EdgeType(0, 1), EdgeType(1, 2)] @@ -249,7 +244,7 @@ def test_group_instances_batch(): ( predicted_instances, predicted_peak_scores, - predicted_instance_scores + predicted_instance_scores, ) = group_instances_batch( peaks, peak_scores, @@ -270,20 +265,16 @@ def test_group_instances_batch(): assert_array_equal( predicted_instances.flat_values, - [[[ 0., 1.], - [ 2., 3.], - [ 4., 5.]], - - [[ 6., 7.], - [ 8., 9.], - [ np.nan, np.nan],]] - ) - assert_array_equal( - predicted_peak_scores.flat_values, - [[0., 1., 2.], - [3., 4., np.nan]] + [ + [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]], + [ + [6.0, 7.0], + [8.0, 9.0], + [np.nan, np.nan], + ], + ], ) assert_array_equal( - predicted_instance_scores.flat_values, - [1., 2.] + predicted_peak_scores.flat_values, [[0.0, 1.0, 2.0], [3.0, 4.0, np.nan]] ) + assert_array_equal(predicted_instance_scores.flat_values, [1.0, 2.0]) diff --git a/tests/nn/test_peak_finding.py b/tests/nn/test_peak_finding.py index 653be7d11..9e8f8c590 100644 --- a/tests/nn/test_peak_finding.py +++ b/tests/nn/test_peak_finding.py @@ -26,21 +26,25 @@ def test_find_local_offsets(): - offsets = find_offsets_local_direction(np.array( - [[0., 1., 0.], - [1., 3., 2.], - [0., 1., 0.]]).reshape(1, 3, 3, 1), 0.25) + offsets = find_offsets_local_direction( + np.array([[0.0, 1.0, 0.0], [1.0, 3.0, 2.0], [0.0, 1.0, 0.0]]).reshape( + 1, 3, 3, 1 + ), + 0.25, + ) assert tuple(offsets.shape) == (1, 2) assert offsets[0][0] == 0.25 - assert offsets[0][1] == 0. + assert offsets[0][1] == 0.0 + + offsets = find_offsets_local_direction( + np.array([[0.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 0.0]]).reshape( + 1, 3, 3, 1 + ), + 0.25, + ) + assert offsets[0][0] == 0.0 + assert offsets[0][1] == 0.0 - offsets = find_offsets_local_direction(np.array( - [[0., 1., 0.], - [1., 3., 1.], - [0., 1., 0.]]).reshape(1, 3, 3, 1), 0.25) - assert offsets[0][0] == 0. - assert offsets[0][1] == 0. - def test_find_global_peaks_rough(): xv, yv = make_grid_vectors(image_height=8, image_width=8, output_stride=1) @@ -73,8 +77,11 @@ def test_find_global_peaks_integral(): cm = make_confmaps(points, xv, yv, sigma=1.0) peaks, peak_vals = find_global_peaks( - tf.expand_dims(cm, axis=0), threshold=0.1, refinement="integral", - integral_patch_size=5) + tf.expand_dims(cm, axis=0), + threshold=0.1, + refinement="integral", + integral_patch_size=5, + ) assert peaks.shape == (1, 3, 2) assert peak_vals.shape == (1, 3) @@ -82,23 +89,31 @@ def test_find_global_peaks_integral(): assert_allclose(peak_vals[0].numpy(), [1, 1, 1], atol=0.3) peaks, peak_vals = find_global_peaks( - tf.zeros((1, 8, 8, 3), dtype=tf.float32), threshold=0.1, refinement="integral", - integral_patch_size=5) + tf.zeros((1, 8, 8, 3), dtype=tf.float32), + threshold=0.1, + refinement="integral", + integral_patch_size=5, + ) assert peaks.shape == (1, 3, 2) assert peak_vals.shape == (1, 3) assert tf.reduce_all(tf.math.is_nan(peaks)) assert_array_equal(peak_vals, [[0, 0, 0]]) peaks, peak_vals = find_global_peaks( - tf.stack([tf.zeros([12, 12, 3], dtype=tf.float32), cm], axis=0), threshold=0.1, - refinement="integral", integral_patch_size=5) + tf.stack([tf.zeros([12, 12, 3], dtype=tf.float32), cm], axis=0), + threshold=0.1, + refinement="integral", + integral_patch_size=5, + ) assert peaks.shape == (2, 3, 2) assert tf.reduce_all(tf.math.is_nan(peaks[0])) assert_allclose(peaks[1].numpy(), points.numpy(), atol=0.1) peaks, peak_vals = find_global_peaks_integral( - tf.stack([tf.zeros([12, 12, 3], dtype=tf.float32), cm], axis=0), threshold=0.1, - crop_size=5) + tf.stack([tf.zeros([12, 12, 3], dtype=tf.float32), cm], axis=0), + threshold=0.1, + crop_size=5, + ) assert peaks.shape == (2, 3, 2) assert tf.reduce_all(tf.math.is_nan(peaks[0])) assert_allclose(peaks[1].numpy(), points.numpy(), atol=0.1) @@ -110,31 +125,36 @@ def test_find_global_peaks_local(): cm = make_confmaps(points, xv, yv, sigma=1.0) peaks, peak_vals = find_global_peaks( - tf.expand_dims(cm, axis=0), threshold=0.1, refinement="local") + tf.expand_dims(cm, axis=0), threshold=0.1, refinement="local" + ) assert peaks.shape == (1, 3, 2) assert peak_vals.shape == (1, 3) - assert_allclose(peaks[0].numpy(), np.array([[1.75, 2.75], [3.75, 4.75], [5.75, 6.75]])) + assert_allclose( + peaks[0].numpy(), np.array([[1.75, 2.75], [3.75, 4.75], [5.75, 6.75]]) + ) assert_allclose(peak_vals[0].numpy(), [1, 1, 1], atol=0.3) def test_find_local_peaks_rough(): xv, yv = make_grid_vectors(image_height=16, image_width=16, output_stride=1) - instances = tf.cast([ - [[1, 2], [3, 4]], - [[5, 6], [7, 8]], - [[np.nan, np.nan], [11, 12]], - ], tf.float32) - cms = make_multi_confmaps(instances, xv=xv, yv=yv, sigma=1.) - instances2 = tf.cast([ - [[2, 3], [4, 5]], - [[6, 7], [8, 9]] - ], tf.float32) - cms = tf.stack([cms, - make_multi_confmaps(instances2, xv=xv, yv=yv, sigma=1.)], axis=0) + instances = tf.cast( + [ + [[1, 2], [3, 4]], + [[5, 6], [7, 8]], + [[np.nan, np.nan], [11, 12]], + ], + tf.float32, + ) + cms = make_multi_confmaps(instances, xv=xv, yv=yv, sigma=1.0) + instances2 = tf.cast([[[2, 3], [4, 5]], [[6, 7], [8, 9]]], tf.float32) + cms = tf.stack( + [cms, make_multi_confmaps(instances2, xv=xv, yv=yv, sigma=1.0)], axis=0 + ) peak_points, peak_vals, peak_sample_inds, peak_channel_inds = find_local_peaks( - cms, threshold=0.1, refinement=None) + cms, threshold=0.1, refinement=None + ) assert peak_points.shape == (9, 2) assert peak_vals.shape == (9,) @@ -153,14 +173,15 @@ def test_find_local_peaks_rough(): [4, 5], [6, 7], [8, 9], - ]) - assert_array_equal(peak_vals, [1, 1, 1, 1, 1, 1, 1, 1, 1]) - assert_array_equal(peak_sample_inds, [0, 0, 0, 0, 0, 1, 1, 1, 1]) + ], + ) + assert_array_equal(peak_vals, [1, 1, 1, 1, 1, 1, 1, 1, 1]) + assert_array_equal(peak_sample_inds, [0, 0, 0, 0, 0, 1, 1, 1, 1]) assert_array_equal(peak_channel_inds, [0, 1, 0, 1, 1, 0, 1, 0, 1]) - peak_points, peak_vals, peak_sample_inds, peak_channel_inds = find_local_peaks( - tf.zeros([1, 4, 4, 3], tf.float32), threshold=0.1, refinement=None) + tf.zeros([1, 4, 4, 3], tf.float32), threshold=0.1, refinement=None + ) assert peak_points.shape == (0, 2) assert peak_vals.shape == (0,) assert peak_sample_inds.shape == (0,) @@ -169,21 +190,27 @@ def test_find_local_peaks_rough(): def test_find_local_peaks_integral(): xv, yv = make_grid_vectors(image_height=32, image_width=32, output_stride=1) - instances = tf.cast([ - [[1, 2], [3, 4]], - [[5, 6], [7, 8]], - [[np.nan, np.nan], [11, 12]], - ], tf.float32) * 2 + 0.3 - cms = make_multi_confmaps(instances, xv=xv, yv=yv, sigma=1.) - instances2 = tf.cast([ - [[2, 3], [4, 5]], - [[6, 7], [8, 9]] - ], tf.float32) * 2 + 0.3 - cms = tf.stack([cms, - make_multi_confmaps(instances2, xv=xv, yv=yv, sigma=1.)], axis=0) + instances = ( + tf.cast( + [ + [[1, 2], [3, 4]], + [[5, 6], [7, 8]], + [[np.nan, np.nan], [11, 12]], + ], + tf.float32, + ) + * 2 + + 0.3 + ) + cms = make_multi_confmaps(instances, xv=xv, yv=yv, sigma=1.0) + instances2 = tf.cast([[[2, 3], [4, 5]], [[6, 7], [8, 9]]], tf.float32) * 2 + 0.3 + cms = tf.stack( + [cms, make_multi_confmaps(instances2, xv=xv, yv=yv, sigma=1.0)], axis=0 + ) peak_points, peak_vals, peak_sample_inds, peak_channel_inds = find_local_peaks( - cms, threshold=0.1, refinement="integral", integral_patch_size=5) + cms, threshold=0.1, refinement="integral", integral_patch_size=5 + ) assert peak_points.shape == (9, 2) assert peak_vals.shape == (9,) @@ -192,31 +219,41 @@ def test_find_local_peaks_integral(): assert_allclose( peak_points.numpy(), - np.array([ - [1, 2], - [3, 4], - [5, 6], - [7, 8], - [11, 12], - [2, 3], - [4, 5], - [6, 7], - [8, 9], - ]) * 2 + 0.3, atol=0.2) + np.array( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + [11, 12], + [2, 3], + [4, 5], + [6, 7], + [8, 9], + ] + ) + * 2 + + 0.3, + atol=0.2, + ) assert_allclose(peak_vals, [1, 1, 1, 1, 1, 1, 1, 1, 1], atol=0.1) - assert_array_equal(peak_sample_inds, [0, 0, 0, 0, 0, 1, 1, 1, 1]) + assert_array_equal(peak_sample_inds, [0, 0, 0, 0, 0, 1, 1, 1, 1]) assert_array_equal(peak_channel_inds, [0, 1, 0, 1, 1, 0, 1, 0, 1]) - peak_points, peak_vals, peak_sample_inds, peak_channel_inds = find_local_peaks( - tf.zeros([1, 4, 4, 3], tf.float32), refinement="integral", integral_patch_size=5) + tf.zeros([1, 4, 4, 3], tf.float32), refinement="integral", integral_patch_size=5 + ) assert peak_points.shape == (0, 2) assert peak_vals.shape == (0,) assert peak_sample_inds.shape == (0,) assert peak_channel_inds.shape == (0,) - peak_points, peak_vals, peak_sample_inds, peak_channel_inds = find_local_peaks_integral( - tf.zeros([1, 4, 4, 3], tf.float32), crop_size=5) + ( + peak_points, + peak_vals, + peak_sample_inds, + peak_channel_inds, + ) = find_local_peaks_integral(tf.zeros([1, 4, 4, 3], tf.float32), crop_size=5) assert peak_points.shape == (0, 2) assert peak_vals.shape == (0,) assert peak_sample_inds.shape == (0,) @@ -225,21 +262,27 @@ def test_find_local_peaks_integral(): def test_find_local_peaks_local(): xv, yv = make_grid_vectors(image_height=32, image_width=32, output_stride=1) - instances = tf.cast([ - [[1, 2], [3, 4]], - [[5, 6], [7, 8]], - [[np.nan, np.nan], [11, 12]], - ], tf.float32) * 2 + 0.25 - cms = make_multi_confmaps(instances, xv=xv, yv=yv, sigma=1.) - instances2 = tf.cast([ - [[2, 3], [4, 5]], - [[6, 7], [8, 9]] - ], tf.float32) * 2 + 0.25 - cms = tf.stack([cms, - make_multi_confmaps(instances2, xv=xv, yv=yv, sigma=1.)], axis=0) + instances = ( + tf.cast( + [ + [[1, 2], [3, 4]], + [[5, 6], [7, 8]], + [[np.nan, np.nan], [11, 12]], + ], + tf.float32, + ) + * 2 + + 0.25 + ) + cms = make_multi_confmaps(instances, xv=xv, yv=yv, sigma=1.0) + instances2 = tf.cast([[[2, 3], [4, 5]], [[6, 7], [8, 9]]], tf.float32) * 2 + 0.25 + cms = tf.stack( + [cms, make_multi_confmaps(instances2, xv=xv, yv=yv, sigma=1.0)], axis=0 + ) peak_points, peak_vals, peak_sample_inds, peak_channel_inds = find_local_peaks( - cms, threshold=0.1, refinement="local") + cms, threshold=0.1, refinement="local" + ) assert peak_points.shape == (9, 2) assert peak_vals.shape == (9,) @@ -248,19 +291,24 @@ def test_find_local_peaks_local(): assert_allclose( peak_points.numpy(), - np.array([ - [1, 2], - [3, 4], - [5, 6], - [7, 8], - [11, 12], - [2, 3], - [4, 5], - [6, 7], - [8, 9], - ]) * 2 + 0.25) + np.array( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + [11, 12], + [2, 3], + [4, 5], + [6, 7], + [8, 9], + ] + ) + * 2 + + 0.25, + ) assert_allclose(peak_vals, [1, 1, 1, 1, 1, 1, 1, 1, 1], atol=0.1) - assert_array_equal(peak_sample_inds, [0, 0, 0, 0, 0, 1, 1, 1, 1]) + assert_array_equal(peak_sample_inds, [0, 0, 0, 0, 0, 1, 1, 1, 1]) assert_array_equal(peak_channel_inds, [0, 1, 0, 1, 1, 0, 1, 0, 1]) @@ -268,7 +316,9 @@ def test_find_local_peaks_local(): def test_find_global_peaks_with_offsets(output_stride, min_labels): p = min_labels.to_pipeline() p += sleap.pipelines.InstanceCentroidFinder( - center_on_anchor_part=True, anchor_part_names="A", skeletons=min_labels.skeletons + center_on_anchor_part=True, + anchor_part_names="A", + skeletons=min_labels.skeletons, ) p += sleap.pipelines.InstanceCropper(crop_width=192, crop_height=192) p += sleap.pipelines.InstanceConfidenceMapGenerator( @@ -292,10 +342,7 @@ def test_find_global_peaks_with_offsets(output_stride, min_labels): def test_find_local_peaks_with_offsets(output_stride, min_labels): p = min_labels.to_pipeline() p += sleap.pipelines.MultiConfidenceMapGenerator( - sigma=1.5, - output_stride=output_stride, - centroids=False, - with_offsets=True + sigma=1.5, output_stride=output_stride, centroids=False, with_offsets=True ) p += sleap.pipelines.Batcher(batch_size=2) @@ -307,7 +354,7 @@ def test_find_local_peaks_with_offsets(output_stride, min_labels): refined_peaks, peak_vals, peak_sample_inds, - peak_channel_inds + peak_channel_inds, ) = find_local_peaks_with_offsets(cms, offs) refined_peaks *= output_stride diff --git a/tests/nn/test_training.py b/tests/nn/test_training.py index 7b3ec4c89..6e1d03dea 100644 --- a/tests/nn/test_training.py +++ b/tests/nn/test_training.py @@ -34,66 +34,58 @@ def cfg(): def test_train_single_instance(min_labels_robot, cfg): - cfg.model.heads.single_instance = ( - sleap.nn.config.SingleInstanceConfmapsHeadConfig( - sigma=1.5, output_stride=1, offset_refinement=False - ) + cfg.model.heads.single_instance = sleap.nn.config.SingleInstanceConfmapsHeadConfig( + sigma=1.5, output_stride=1, offset_refinement=False ) trainer = sleap.nn.training.SingleInstanceModelTrainer.from_config( cfg, training_labels=min_labels_robot ) trainer.setup() trainer.train() - assert trainer.keras_model.output_names[0] == "SingleInstanceConfmapsHead_0" + assert trainer.keras_model.output_names[0] == "SingleInstanceConfmapsHead" assert tuple(trainer.keras_model.outputs[0].shape) == (None, 320, 560, 2) def test_train_single_instance_with_offset(min_labels_robot, cfg): - cfg.model.heads.single_instance = ( - sleap.nn.config.SingleInstanceConfmapsHeadConfig( - sigma=1.5, output_stride=1, offset_refinement=True - ) + cfg.model.heads.single_instance = sleap.nn.config.SingleInstanceConfmapsHeadConfig( + sigma=1.5, output_stride=1, offset_refinement=True ) trainer = sleap.nn.training.SingleInstanceModelTrainer.from_config( cfg, training_labels=min_labels_robot ) trainer.setup() trainer.train() - assert trainer.keras_model.output_names[0] == "SingleInstanceConfmapsHead_0" + assert trainer.keras_model.output_names[0] == "SingleInstanceConfmapsHead" assert tuple(trainer.keras_model.outputs[0].shape) == (None, 320, 560, 2) - assert trainer.keras_model.output_names[1] == "OffsetRefinementHead_0" + assert trainer.keras_model.output_names[1] == "OffsetRefinementHead" assert tuple(trainer.keras_model.outputs[1].shape) == (None, 320, 560, 4) def test_train_centroids(training_labels, cfg): - cfg.model.heads.centroid = ( - sleap.nn.config.CentroidsHeadConfig( - sigma=1.5, output_stride=1, offset_refinement=False - ) + cfg.model.heads.centroid = sleap.nn.config.CentroidsHeadConfig( + sigma=1.5, output_stride=1, offset_refinement=False ) trainer = sleap.nn.training.CentroidConfmapsModelTrainer.from_config( cfg, training_labels=training_labels ) trainer.setup() trainer.train() - assert trainer.keras_model.output_names[0] == "CentroidConfmapsHead_0" + assert trainer.keras_model.output_names[0] == "CentroidConfmapsHead" assert tuple(trainer.keras_model.outputs[0].shape) == (None, 384, 384, 1) def test_train_centroids_with_offset(training_labels, cfg): - cfg.model.heads.centroid = ( - sleap.nn.config.CentroidsHeadConfig( - sigma=1.5, output_stride=1, offset_refinement=True - ) + cfg.model.heads.centroid = sleap.nn.config.CentroidsHeadConfig( + sigma=1.5, output_stride=1, offset_refinement=True ) trainer = sleap.nn.training.CentroidConfmapsModelTrainer.from_config( cfg, training_labels=training_labels ) trainer.setup() trainer.train() - assert trainer.keras_model.output_names[0] == "CentroidConfmapsHead_0" - assert trainer.keras_model.output_names[1] == "OffsetRefinementHead_0" + assert trainer.keras_model.output_names[0] == "CentroidConfmapsHead" + assert trainer.keras_model.output_names[1] == "OffsetRefinementHead" assert tuple(trainer.keras_model.outputs[0].shape) == (None, 384, 384, 1) assert tuple(trainer.keras_model.outputs[1].shape) == (None, 384, 384, 2) @@ -109,7 +101,7 @@ def test_train_topdown(training_labels, cfg): ) trainer.setup() trainer.train() - assert trainer.keras_model.output_names[0] == "CenteredInstanceConfmapsHead_0" + assert trainer.keras_model.output_names[0] == "CenteredInstanceConfmapsHead" assert tuple(trainer.keras_model.outputs[0].shape) == (None, 96, 96, 2) @@ -125,8 +117,8 @@ def test_train_topdown_with_offset(training_labels, cfg): trainer.setup() trainer.train() - assert trainer.keras_model.output_names[0] == "CenteredInstanceConfmapsHead_0" - assert trainer.keras_model.output_names[1] == "OffsetRefinementHead_0" + assert trainer.keras_model.output_names[0] == "CenteredInstanceConfmapsHead" + assert trainer.keras_model.output_names[1] == "OffsetRefinementHead" assert tuple(trainer.keras_model.outputs[0].shape) == (None, 96, 96, 2) assert tuple(trainer.keras_model.outputs[1].shape) == (None, 96, 96, 4) @@ -134,8 +126,9 @@ def test_train_topdown_with_offset(training_labels, cfg): def test_train_bottomup(training_labels, cfg): cfg.model.heads.multi_instance = sleap.nn.config.MultiInstanceConfig( confmaps=sleap.nn.config.MultiInstanceConfmapsHeadConfig( - output_stride=1, offset_refinement=False), - pafs=sleap.nn.config.PartAffinityFieldsHeadConfig(output_stride=2) + output_stride=1, offset_refinement=False + ), + pafs=sleap.nn.config.PartAffinityFieldsHeadConfig(output_stride=2), ) trainer = sleap.nn.training.TopdownConfmapsModelTrainer.from_config( cfg, training_labels=training_labels @@ -143,8 +136,8 @@ def test_train_bottomup(training_labels, cfg): trainer.setup() trainer.train() - assert trainer.keras_model.output_names[0] == "MultiInstanceConfmapsHead_0" - assert trainer.keras_model.output_names[1] == "PartAffinityFieldsHead_0" + assert trainer.keras_model.output_names[0] == "MultiInstanceConfmapsHead" + assert trainer.keras_model.output_names[1] == "PartAffinityFieldsHead" assert tuple(trainer.keras_model.outputs[0].shape) == (None, 384, 384, 2) assert tuple(trainer.keras_model.outputs[1].shape) == (None, 192, 192, 2) @@ -152,8 +145,9 @@ def test_train_bottomup(training_labels, cfg): def test_train_bottomup_with_offset(training_labels, cfg): cfg.model.heads.multi_instance = sleap.nn.config.MultiInstanceConfig( confmaps=sleap.nn.config.MultiInstanceConfmapsHeadConfig( - output_stride=1, offset_refinement=True), - pafs=sleap.nn.config.PartAffinityFieldsHeadConfig(output_stride=2) + output_stride=1, offset_refinement=True + ), + pafs=sleap.nn.config.PartAffinityFieldsHeadConfig(output_stride=2), ) trainer = sleap.nn.training.TopdownConfmapsModelTrainer.from_config( cfg, training_labels=training_labels @@ -161,9 +155,9 @@ def test_train_bottomup_with_offset(training_labels, cfg): trainer.setup() trainer.train() - assert trainer.keras_model.output_names[0] == "MultiInstanceConfmapsHead_0" - assert trainer.keras_model.output_names[1] == "PartAffinityFieldsHead_0" - assert trainer.keras_model.output_names[2] == "OffsetRefinementHead_0" + assert trainer.keras_model.output_names[0] == "MultiInstanceConfmapsHead" + assert trainer.keras_model.output_names[1] == "PartAffinityFieldsHead" + assert trainer.keras_model.output_names[2] == "OffsetRefinementHead" assert tuple(trainer.keras_model.outputs[0].shape) == (None, 384, 384, 2) assert tuple(trainer.keras_model.outputs[1].shape) == (None, 192, 192, 2) assert tuple(trainer.keras_model.outputs[2].shape) == (None, 384, 384, 4) @@ -172,11 +166,11 @@ def test_train_bottomup_with_offset(training_labels, cfg): def test_train_bottomup_multiclass(min_tracks_2node_labels, cfg): labels = min_tracks_2node_labels cfg.data.preprocessing.input_scaling = 0.5 - cfg.model.heads.multi_class = sleap.nn.config.MultiClassConfig( + cfg.model.heads.multi_class_bottomup = sleap.nn.config.MultiClassBottomUpConfig( confmaps=sleap.nn.config.MultiInstanceConfmapsHeadConfig( - output_stride=2, offset_refinement=False), - class_maps=sleap.nn.config.ClassMapsHeadConfig( - output_stride=2) + output_stride=2, offset_refinement=False + ), + class_maps=sleap.nn.config.ClassMapsHeadConfig(output_stride=2), ) trainer = sleap.nn.training.BottomUpMultiClassModelTrainer.from_config( cfg, training_labels=labels @@ -184,7 +178,7 @@ def test_train_bottomup_multiclass(min_tracks_2node_labels, cfg): trainer.setup() trainer.train() - assert trainer.keras_model.output_names[0] == "MultiInstanceConfmapsHead_0" - assert trainer.keras_model.output_names[1] == "ClassMapsHead_0" + assert trainer.keras_model.output_names[0] == "MultiInstanceConfmapsHead" + assert trainer.keras_model.output_names[1] == "ClassMapsHead" assert tuple(trainer.keras_model.outputs[0].shape) == (None, 256, 256, 2) assert tuple(trainer.keras_model.outputs[1].shape) == (None, 256, 256, 2) diff --git a/tests/test_instance.py b/tests/test_instance.py index 7c4f1f67c..90fe2e62d 100644 --- a/tests/test_instance.py +++ b/tests/test_instance.py @@ -34,12 +34,20 @@ def test_instance_node_get_set_item(skeleton): thorax_point = instance["thorax"] assert math.isnan(thorax_point.x) and math.isnan(thorax_point.y) + instance[0] = [-20, -50] + assert instance["head"].x == -20 + assert instance["head"].y == -50 + + instance[0] = np.array([-21, -51]) + assert instance["head"].x == -21 + assert instance["head"].y == -51 + def test_instance_node_multi_get_set_item(skeleton): """ Test basic get item and set item functionality of instances. """ - node_names = ["left-wing", "head", "right-wing"] + node_names = ["head", "left-wing", "right-wing"] points = {"head": Point(1, 4), "left-wing": Point(2, 5), "right-wing": Point(3, 6)} instance1 = Instance(skeleton=skeleton, points=points) @@ -52,6 +60,23 @@ def test_instance_node_multi_get_set_item(skeleton): assert np.allclose(x_values, [1, 2, 3]) assert np.allclose(y_values, [4, 5, 6]) + np.testing.assert_array_equal( + instance1[np.array([0, 2, 4])], [[1, 4], [np.nan, np.nan], [3, 6]] + ) + + instance1[np.array([0, 1])] = [[1, 2], [3, 4]] + np.testing.assert_array_equal(instance1[np.array([0, 1])], [[1, 2], [3, 4]]) + + instance1[[0, 1]] = [[4, 3], [2, 1]] + np.testing.assert_array_equal(instance1[np.array([0, 1])], [[4, 3], [2, 1]]) + + instance1[["left-wing", "right-wing"]] = [[-4, -3], [-2, -1]] + np.testing.assert_array_equal(instance1[np.array([3, 4])], [[-4, -3], [-2, -1]]) + assert instance1["left-wing"].x == -4 + assert instance1["left-wing"].y == -3 + assert instance1["right-wing"].x == -2 + assert instance1["right-wing"].y == -1 + def test_non_exist_node(skeleton): """ @@ -258,9 +283,14 @@ def test_instance_from_pointsarray(skeleton): def test_frame_merge_predicted_and_user(skeleton, centered_pair_vid): - user_inst = Instance(skeleton=skeleton, points={skeleton.nodes[0]: Point(1, 2)},) + user_inst = Instance( + skeleton=skeleton, + points={skeleton.nodes[0]: Point(1, 2)}, + ) user_frame = LabeledFrame( - video=centered_pair_vid, frame_idx=0, instances=[user_inst], + video=centered_pair_vid, + frame_idx=0, + instances=[user_inst], ) pred_inst = PredictedInstance( @@ -269,7 +299,9 @@ def test_frame_merge_predicted_and_user(skeleton, centered_pair_vid): score=1.0, ) pred_frame = LabeledFrame( - video=centered_pair_vid, frame_idx=0, instances=[pred_inst], + video=centered_pair_vid, + frame_idx=0, + instances=[pred_inst], ) LabeledFrame.complex_frame_merge(user_frame, pred_frame) @@ -282,9 +314,18 @@ def test_frame_merge_predicted_and_user(skeleton, centered_pair_vid): def test_frame_merge_between_predicted_and_user(skeleton, centered_pair_vid): - user_inst = Instance(skeleton=skeleton, points={skeleton.nodes[0]: Point(1, 2)},) + user_inst = Instance( + skeleton=skeleton, + points={skeleton.nodes[0]: Point(1, 2)}, + ) user_labels = Labels( - [LabeledFrame(video=centered_pair_vid, frame_idx=0, instances=[user_inst],)] + [ + LabeledFrame( + video=centered_pair_vid, + frame_idx=0, + instances=[user_inst], + ) + ] ) pred_inst = PredictedInstance( @@ -293,7 +334,13 @@ def test_frame_merge_between_predicted_and_user(skeleton, centered_pair_vid): score=1.0, ) pred_labels = Labels( - [LabeledFrame(video=centered_pair_vid, frame_idx=0, instances=[pred_inst],)] + [ + LabeledFrame( + video=centered_pair_vid, + frame_idx=0, + instances=[pred_inst], + ) + ] ) # Merge predictions into current labels dataset diff --git a/tests/test_skeleton.py b/tests/test_skeleton.py index acdda25b1..e35aa5bec 100644 --- a/tests/test_skeleton.py +++ b/tests/test_skeleton.py @@ -144,6 +144,14 @@ def test_symmetry(): s1.add_symmetry("1", "5") s1.add_symmetry("3", "6") + assert (s1.nodes[0], s1.nodes[4]) in s1.symmetries + assert (s1.nodes[2], s1.nodes[5]) in s1.symmetries + assert len(s1.symmetries) == 2 + + assert (0, 4) in s1.symmetric_inds + assert (2, 5) in s1.symmetric_inds + assert len(s1.symmetric_inds) == 2 + assert s1.get_symmetry("1").name == "5" assert s1.get_symmetry("5").name == "1" @@ -388,13 +396,3 @@ def test_arborescence(): assert len(skeleton.cycles) == 0 assert len(skeleton.root_nodes) == 1 assert len(skeleton.in_degree_over_one) == 1 - - -def test_repr_str(): - skel = Skeleton(name="skel") - skel.add_node("A") - skel.add_node("B") - skel.add_edge("A", "B") - - assert repr(skel) == "Skeleton(name='skel', nodes=['A', 'B'], edges=[('A', 'B')])" - assert str(skel) == "Skeleton(nodes=2, edges=1)"