diff --git a/docs/guides/cli.md b/docs/guides/cli.md index 6a9d05806..9e07c0a25 100644 --- a/docs/guides/cli.md +++ b/docs/guides/cli.md @@ -99,9 +99,9 @@ optional arguments: -e [EXPORT_PATH], --export_path [EXPORT_PATH] Path to output directory where the frozen model will be exported to. Defaults to a folder named 'exported_model'. - -u, --unrag UNRAG - Convert ragged tensors into regular tensors with NaN padding. - Defaults to True. + -r, --ragged RAGGED + Keep tensors ragged if present. If ommited, convert + ragged tensors into regular tensors with NaN padding. -n, --max_instances MAX_INSTANCES Limit maximum number of instances in multi-instance models. Not available for ID models. Defaults to None. diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 8183b0c32..ed9b34600 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -49,6 +49,8 @@ import platform import random import re +import traceback +from logging import getLogger from pathlib import Path from typing import Callable, List, Optional, Tuple @@ -85,6 +87,9 @@ from sleap.util import parse_uri_path +logger = getLogger(__name__) + + class MainWindow(QMainWindow): """The SLEAP GUI application. @@ -101,6 +106,7 @@ class MainWindow(QMainWindow): def __init__( self, labels_path: Optional[str] = None, + labels: Optional[Labels] = None, reset: bool = False, no_usage_data: bool = False, *args, @@ -118,7 +124,7 @@ def __init__( self.setAcceptDrops(True) self.state = GuiState() - self.labels = Labels() + self.labels = labels or Labels() self.commands = CommandContext( state=self.state, app=self, update_callback=self.on_data_update @@ -180,8 +186,10 @@ def __init__( print("Restoring GUI state...") self.restoreState(prefs["window state"]) - if labels_path: + if labels_path is not None: self.commands.loadProjectFile(filename=labels_path) + elif labels is not None: + self.commands.loadLabelsObject(labels=labels) else: self.state["project_loaded"] = False @@ -261,7 +269,6 @@ def dragEnterEvent(self, event): event.acceptProposedAction() def dropEvent(self, event): - # Parse filenames filenames = event.mimeData().data("text/uri-list").data().decode() filenames = [parse_uri_path(f.strip()) for f in filenames.strip().split("\n")] @@ -1601,8 +1608,12 @@ def _show_keyboard_shortcuts_window(self): ShortcutDialog().exec_() -def main(args: Optional[list] = None): - """Starts new instance of app.""" +def create_sleap_label_parser(): + """Creates parser for `sleap-label` command line arguments. + + Returns: + argparse.ArgumentParser: The parser. + """ import argparse @@ -1642,6 +1653,23 @@ def main(args: Optional[list] = None): default=False, ) + return parser + + +def create_app(): + """Creates Qt application.""" + + app = QApplication([]) + app.setApplicationName(f"SLEAP v{sleap.version.__version__}") + app.setWindowIcon(QtGui.QIcon(sleap.util.get_package_file("gui/icon.png"))) + + return app + + +def main(args: Optional[list] = None, labels: Optional[Labels] = None): + """Starts new instance of app.""" + + parser = create_sleap_label_parser() args = parser.parse_args(args) if args.nonnative: @@ -1653,17 +1681,26 @@ def main(args: Optional[list] = None): # https://stackoverflow.com/q/64818879 os.environ["QT_MAC_WANTS_LAYER"] = "1" - app = QApplication([]) - app.setApplicationName(f"SLEAP v{sleap.version.__version__}") - app.setWindowIcon(QtGui.QIcon(sleap.util.get_package_file("gui/icon.png"))) + app = create_app() window = MainWindow( - labels_path=args.labels_path, reset=args.reset, no_usage_data=args.no_usage_data + labels_path=args.labels_path, + labels=labels, + reset=args.reset, + no_usage_data=args.no_usage_data, ) window.showMaximized() # Disable GPU in GUI process. This does not affect subprocesses. - sleap.use_cpu_only() + try: + sleap.use_cpu_only() + except RuntimeError: # Visible devices cannot be modified after being initialized + logger.warning( + "Running processes on the GPU. Restarting your GUI should allow switching " + "back to CPU-only mode.\n" + "Received the following error when trying to switch back to CPU-only mode:" + ) + traceback.print_exc() # Print versions. print() diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index 78a8c2a31..8ac4d87fb 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -36,7 +36,7 @@ class which inherits from `AppCommand` (or a more specialized class such as from enum import Enum from glob import glob from pathlib import Path, PurePath -from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type +from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, Union import attr import cv2 @@ -260,16 +260,15 @@ def loadLabelsObject(self, labels: Labels, filename: Optional[str] = None): """ self.execute(LoadLabelsObject, labels=labels, filename=filename) - def loadProjectFile(self, filename: str): + def loadProjectFile(self, filename: Union[str, Labels]): """Loads given labels file into GUI. Args: - filename: The path to the saved labels dataset. If None, - then don't do anything. + filename: The path to the saved labels dataset or the `Labels` object. + If None, then don't do anything. Returns: None - """ self.execute(LoadProjectFile, filename=filename) @@ -647,9 +646,8 @@ def do_action(context: "CommandContext", params: dict): Returns: None. - """ - filename = params["filename"] + filename = params.get("filename", None) # If called with just a Labels object labels: Labels = params["labels"] context.state["labels"] = labels @@ -669,7 +667,9 @@ def do_action(context: "CommandContext", params: dict): context.state["video"] = labels.videos[0] context.state["project_loaded"] = True - context.state["has_changes"] = params.get("changed_on_load", False) + context.state["has_changes"] = params.get("changed_on_load", False) or ( + filename is None + ) # This is not listed as an edit command since we want a clean changestack context.app.on_data_update([UpdateTopic.project, UpdateTopic.all]) @@ -683,17 +683,16 @@ def ask(context: "CommandContext", params: dict): if len(filename) == 0: return - gui_video_callback = Labels.make_gui_video_callback( - search_paths=[os.path.dirname(filename)], context=params - ) - has_loaded = False labels = None - if type(filename) == Labels: + if isinstance(filename, Labels): labels = filename filename = None has_loaded = True else: + gui_video_callback = Labels.make_gui_video_callback( + search_paths=[os.path.dirname(filename)], context=params + ) try: labels = Labels.load_file(filename, video_search=gui_video_callback) has_loaded = True diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 6d7d24f8c..0cabc91bb 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -4939,7 +4939,7 @@ def export_cli(args: Optional[list] = None): export_model( args.models, args.export_path, - unrag_outputs=args.unrag, + unrag_outputs=(not args.ragged), max_instances=args.max_instances, ) @@ -4971,13 +4971,13 @@ def _make_export_cli_parser() -> argparse.ArgumentParser: ), ) parser.add_argument( - "-u", - "--unrag", + "-r", + "--ragged", action="store_true", - default=True, + default=False, help=( - "Convert ragged tensors into regular tensors with NaN padding. " - "Defaults to True." + "Keep tensors ragged if present. If ommited, convert ragged tensors" + " into regular tensors with NaN padding." ), ) parser.add_argument( diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index fe848bb1c..dedf0d324 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -50,6 +50,7 @@ _make_tracker_from_cli, main as sleap_track, export_cli as sleap_export, + _make_export_cli_parser, ) from sleap.nn.tracking import ( MatchedFrameInstance, @@ -925,7 +926,7 @@ def test_load_model(resize_input_shape, model_fixture_name, request): predictor = load_model(model_path, resize_input_layer=resize_input_shape) # Determine predictor type - for (fname, mname, ptype, ishape) in fname_mname_ptype_ishape: + for fname, mname, ptype, ishape in fname_mname_ptype_ishape: if fname in model_fixture_name: expected_model_name = mname expected_predictor_type = ptype @@ -966,7 +967,6 @@ def test_topdown_multi_size_inference( def test_ensure_numpy( min_centroid_model_path, min_centered_instance_model_path, min_labels_slp ): - model = load_model([min_centroid_model_path, min_centered_instance_model_path]) # each frame has same number of instances @@ -1037,7 +1037,6 @@ def test_ensure_numpy( def test_centroid_inference(): - xv, yv = make_grid_vectors(image_height=12, image_width=12, output_stride=1) points = tf.cast([[[1.75, 2.75]], [[3.75, 4.75]], [[5.75, 6.75]]], tf.float32) cms = tf.expand_dims(make_multi_confmaps(points, xv, yv, sigma=1.5), axis=0) @@ -1093,7 +1092,6 @@ def test_centroid_inference(): def export_frozen_graph(model, preds, output_path): - tensors = {} for key, val in preds.items(): @@ -1120,7 +1118,6 @@ def export_frozen_graph(model, preds, output_path): info = json.load(json_file) for tensor_info in info["frozen_model_inputs"] + info["frozen_model_outputs"]: - saved_name = ( tensor_info.split("Tensor(")[1].split(", shape")[0].replace('"', "") ) @@ -1137,7 +1134,6 @@ def export_frozen_graph(model, preds, output_path): def test_single_instance_save(min_single_instance_robot_model_path, tmp_path): - single_instance_model = tf.keras.models.load_model( min_single_instance_robot_model_path + "/best_model.h5", compile=False ) @@ -1152,7 +1148,6 @@ def test_single_instance_save(min_single_instance_robot_model_path, tmp_path): def test_centroid_save(min_centroid_model_path, tmp_path): - centroid_model = tf.keras.models.load_model( min_centroid_model_path + "/best_model.h5", compile=False ) @@ -1171,7 +1166,6 @@ def test_centroid_save(min_centroid_model_path, tmp_path): def test_topdown_save( min_centroid_model_path, min_centered_instance_model_path, min_labels_slp, tmp_path ): - centroid_model = tf.keras.models.load_model( min_centroid_model_path + "/best_model.h5", compile=False ) @@ -1195,7 +1189,6 @@ def test_topdown_save( def test_topdown_id_save( min_centroid_model_path, min_topdown_multiclass_model_path, min_labels_slp, tmp_path ): - centroid_model = tf.keras.models.load_model( min_centroid_model_path + "/best_model.h5", compile=False ) @@ -1217,7 +1210,6 @@ def test_topdown_id_save( def test_single_instance_predictor_save(min_single_instance_robot_model_path, tmp_path): - # directly initialize predictor predictor = SingleInstancePredictor.from_trained_models( min_single_instance_robot_model_path, resize_input_layer=False @@ -1254,10 +1246,33 @@ def test_single_instance_predictor_save(min_single_instance_robot_model_path, tm ) +def test_make_export_cli(): + models_path = r"psuedo/models/path" + export_path = r"psuedo/test/path" + max_instances = 5 + + parser = _make_export_cli_parser() + + # Test default values + args = None + args, _ = parser.parse_known_args(args=args) + assert args.models is None + assert args.export_path == "exported_model" + assert not args.ragged + assert args.max_instances is None + + # Test all arguments + cmd = f"-m {models_path} -e {export_path} -r -n {max_instances}" + args, _ = parser.parse_known_args(args=cmd.split()) + assert args.models == [models_path] + assert args.export_path == export_path + assert args.ragged + assert args.max_instances == max_instances + + def test_topdown_predictor_save( min_centroid_model_path, min_centered_instance_model_path, tmp_path ): - # directly initialize predictor predictor = TopDownPredictor.from_trained_models( centroid_model_path=min_centroid_model_path, @@ -1300,7 +1315,6 @@ def test_topdown_predictor_save( def test_topdown_id_predictor_save( min_centroid_model_path, min_topdown_multiclass_model_path, tmp_path ): - # directly initialize predictor predictor = TopDownMultiClassPredictor.from_trained_models( centroid_model_path=min_centroid_model_path, @@ -1478,7 +1492,6 @@ def test_flow_tracker(centered_pair_predictions: Labels, tmpdir): # Run tracking on subset of frames using psuedo-implementation of # sleap.nn.tracking.run_tracker for lf in frames[:20]: - # Clear the tracks for inst in lf.instances: inst.track = None @@ -1522,7 +1535,6 @@ def test_max_tracks_matching_queue( frames = sorted(labels.labeled_frames, key=lambda lf: lf.frame_idx) for lf in frames[:20]: - # Clear the tracks for inst in lf.instances: inst.track = None