From ed77b49164b654bf9e43c29fcc3af73be2f8eb3b Mon Sep 17 00:00:00 2001 From: Liezl Maree <38435167+roomrys@users.noreply.github.com> Date: Fri, 29 Sep 2023 09:42:58 -0400 Subject: [PATCH 1/5] Set default callable for `match_lists_function` (#1520) * Set default for `match_lists_function` * Move test code to official tests * Check using expected values --- sleap/info/metrics.py | 181 ++++++++++++++----------------------- tests/info/test_metrics.py | 55 +++++++++++ 2 files changed, 124 insertions(+), 112 deletions(-) create mode 100644 tests/info/test_metrics.py diff --git a/sleap/info/metrics.py b/sleap/info/metrics.py index 2ac61d339..5bec077e4 100644 --- a/sleap/info/metrics.py +++ b/sleap/info/metrics.py @@ -10,75 +10,6 @@ from sleap.io.dataset import Labels -def matched_instance_distances( - labels_gt: Labels, - labels_pr: Labels, - match_lists_function: Callable, - frame_range: Optional[range] = None, -) -> Tuple[List[int], np.ndarray, np.ndarray, np.ndarray]: - - """ - Distances between ground truth and predicted nodes over a set of frames. - - Args: - labels_gt: the `Labels` object with ground truth data - labels_pr: the `Labels` object with predicted data - match_lists_function: function for determining corresponding instances - Takes two lists of instances and returns "sorted" lists. - frame_range (optional): range of frames for which to compare data - If None, we compare every frame in labels_gt with corresponding - frame in labels_pr. - Returns: - Tuple: - * frame indices map: instance idx (for other matrices) -> frame idx - * distance matrix: (instances * nodes) - * ground truth points matrix: (instances * nodes * 2) - * predicted points matrix: (instances * nodes * 2) - """ - - frame_idxs = [] - points_gt = [] - points_pr = [] - for lf_gt in labels_gt.find(labels_gt.videos[0]): - frame_idx = lf_gt.frame_idx - - # Get instances from ground truth/predicted labels - instances_gt = lf_gt.instances - lfs_pr = labels_pr.find(labels_pr.videos[0], frame_idx=frame_idx) - if len(lfs_pr): - instances_pr = lfs_pr[0].instances - else: - instances_pr = [] - - # Sort ground truth and predicted instances. - # We'll then compare points between corresponding items in lists. - # We can use different "match" functions depending on what we want. - sorted_gt, sorted_pr = match_lists_function(instances_gt, instances_pr) - - # Convert lists of instances to (instances, nodes, 2) matrices. - # This allows match_lists_function to return data as either - # a list of Instances or a (instances, nodes, 2) matrix. - if type(sorted_gt[0]) != np.ndarray: - sorted_gt = list_points_array(sorted_gt) - if type(sorted_pr[0]) != np.ndarray: - sorted_pr = list_points_array(sorted_pr) - - points_gt.append(sorted_gt) - points_pr.append(sorted_pr) - frame_idxs.extend([frame_idx] * len(sorted_gt)) - - # Convert arrays to numpy matrixes - # instances * nodes * (x,y) - points_gt = np.concatenate(points_gt) - points_pr = np.concatenate(points_pr) - - # Calculate distances between corresponding nodes for all corresponding - # ground truth and predicted instances. - D = np.linalg.norm(points_gt - points_pr, axis=2) - - return frame_idxs, D, points_gt, points_pr - - def match_instance_lists( instances_a: List[Union[Instance, PredictedInstance]], instances_b: List[Union[Instance, PredictedInstance]], @@ -165,6 +96,75 @@ def match_instance_lists_nodewise( return instances_a, best_points_array +def matched_instance_distances( + labels_gt: Labels, + labels_pr: Labels, + match_lists_function: Callable = match_instance_lists_nodewise, + frame_range: Optional[range] = None, +) -> Tuple[List[int], np.ndarray, np.ndarray, np.ndarray]: + + """ + Distances between ground truth and predicted nodes over a set of frames. + + Args: + labels_gt: the `Labels` object with ground truth data + labels_pr: the `Labels` object with predicted data + match_lists_function: function for determining corresponding instances + Takes two lists of instances and returns "sorted" lists. + frame_range (optional): range of frames for which to compare data + If None, we compare every frame in labels_gt with corresponding + frame in labels_pr. + Returns: + Tuple: + * frame indices map: instance idx (for other matrices) -> frame idx + * distance matrix: (instances * nodes) + * ground truth points matrix: (instances * nodes * 2) + * predicted points matrix: (instances * nodes * 2) + """ + + frame_idxs = [] + points_gt = [] + points_pr = [] + for lf_gt in labels_gt.find(labels_gt.videos[0]): + frame_idx = lf_gt.frame_idx + + # Get instances from ground truth/predicted labels + instances_gt = lf_gt.instances + lfs_pr = labels_pr.find(labels_pr.videos[0], frame_idx=frame_idx) + if len(lfs_pr): + instances_pr = lfs_pr[0].instances + else: + instances_pr = [] + + # Sort ground truth and predicted instances. + # We'll then compare points between corresponding items in lists. + # We can use different "match" functions depending on what we want. + sorted_gt, sorted_pr = match_lists_function(instances_gt, instances_pr) + + # Convert lists of instances to (instances, nodes, 2) matrices. + # This allows match_lists_function to return data as either + # a list of Instances or a (instances, nodes, 2) matrix. + if type(sorted_gt[0]) != np.ndarray: + sorted_gt = list_points_array(sorted_gt) + if type(sorted_pr[0]) != np.ndarray: + sorted_pr = list_points_array(sorted_pr) + + points_gt.append(sorted_gt) + points_pr.append(sorted_pr) + frame_idxs.extend([frame_idx] * len(sorted_gt)) + + # Convert arrays to numpy matrixes + # instances * nodes * (x,y) + points_gt = np.concatenate(points_gt) + points_pr = np.concatenate(points_pr) + + # Calculate distances between corresponding nodes for all corresponding + # ground truth and predicted instances. + D = np.linalg.norm(points_gt - points_pr, axis=2) + + return frame_idxs, D, points_gt, points_pr + + def point_dist( inst_a: Union[Instance, PredictedInstance], inst_b: Union[Instance, PredictedInstance], @@ -238,46 +238,3 @@ def point_match_count(dist_array: np.ndarray, thresh: float = 5) -> int: def point_nonmatch_count(dist_array: np.ndarray, thresh: float = 5) -> int: """Given an array of distances, returns number which are not <= threshold.""" return dist_array.shape[0] - point_match_count(dist_array, thresh) - - -if __name__ == "__main__": - - labels_gt = Labels.load_json("tests/data/json_format_v1/centered_pair.json") - labels_pr = Labels.load_json( - "tests/data/json_format_v2/centered_pair_predictions.json" - ) - - # OPTION 1 - - # Match each ground truth instance node to the closest corresponding node - # from any predicted instance in the same frame. - - nodewise_matching_func = match_instance_lists_nodewise - - # OPTION 2 - - # Match each ground truth instance to a distinct predicted instance: - # We want to maximize the number of "matching" points between instances, - # where "match" means the points are within some threshold distance. - # Note that each sorted list will be as long as the shorted input list. - - instwise_matching_func = lambda gt_list, pr_list: match_instance_lists( - gt_list, pr_list, point_nonmatch_count - ) - - # PICK THE FUNCTION - - inst_matching_func = nodewise_matching_func - # inst_matching_func = instwise_matching_func - - # Calculate distances - frame_idxs, D, points_gt, points_pr = matched_instance_distances( - labels_gt, labels_pr, inst_matching_func - ) - - # Show mean difference for each node - node_names = labels_gt.skeletons[0].node_names - - for node_idx, node_name in enumerate(node_names): - mean_d = np.nanmean(D[..., node_idx]) - print(f"{node_name}\t\t{mean_d}") diff --git a/tests/info/test_metrics.py b/tests/info/test_metrics.py new file mode 100644 index 000000000..0d2e097e6 --- /dev/null +++ b/tests/info/test_metrics.py @@ -0,0 +1,55 @@ +import numpy as np + +from sleap import Labels +from sleap.info.metrics import ( + match_instance_lists_nodewise, + matched_instance_distances, +) + + +def test_matched_instance_distances(centered_pair_labels, centered_pair_predictions): + labels_gt = centered_pair_labels + labels_pr = centered_pair_predictions + + # Match each ground truth instance node to the closest corresponding node + # from any predicted instance in the same frame. + + inst_matching_func = match_instance_lists_nodewise + + # Calculate distances + frame_idxs, D, points_gt, points_pr = matched_instance_distances( + labels_gt, labels_pr, inst_matching_func + ) + + # Show mean difference for each node + node_names = labels_gt.skeletons[0].node_names + expected_values = { + "head": 0.872426920709296, + "neck": 0.8016280746914615, + "thorax": 0.8602021363390538, + "abdomen": 1.01012200038258, + "wingL": 1.1297727023475939, + "wingR": 1.0869857897008424, + "forelegL1": 0.780584225081443, + "forelegL2": 1.170805798894702, + "forelegL3": 1.1020486509389473, + "forelegR1": 0.9014698776116817, + "forelegR2": 0.9448001033112047, + "forelegR3": 1.308385214215777, + "midlegL1": 0.9095691623265347, + "midlegL2": 1.2203595627907582, + "midlegL3": 0.9813843358470163, + "midlegR1": 0.9871017182813739, + "midlegR2": 1.0209829335569256, + "midlegR3": 1.0990681234096988, + "hindlegL1": 1.0005335192834348, + "hindlegL2": 1.273539518539708, + "hindlegL3": 1.1752245985832817, + "hindlegR1": 1.1402833959265248, + "hindlegR2": 1.3143221301212737, + "hindlegR3": 1.0441458592503365, + } + + for node_idx, node_name in enumerate(node_names): + mean_d = np.nanmean(D[..., node_idx]) + assert np.isclose(mean_d, expected_values[node_name], atol=1e-6) From 79f7fba565b3eb3d44d5f88a7d038a3b7dc16edb Mon Sep 17 00:00:00 2001 From: Liezl Maree <38435167+roomrys@users.noreply.github.com> Date: Wed, 11 Oct 2023 12:45:52 -0700 Subject: [PATCH 2/5] Allow passing in `Labels` to `app.main` (#1524) * Allow passing in `Labels` to `app.main` * Load the labels object through command * Add warning when unable to switch back to CPU mode --- sleap/gui/app.py | 37 +++++++++++++++++++++++++++++++------ sleap/gui/commands.py | 25 ++++++++++++------------- 2 files changed, 43 insertions(+), 19 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index de6ce9fbf..065563e66 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -49,6 +49,8 @@ import platform import random import re +import traceback +from logging import getLogger from pathlib import Path from typing import Callable, List, Optional, Tuple @@ -85,6 +87,9 @@ from sleap.util import parse_uri_path +logger = getLogger(__name__) + + class MainWindow(QMainWindow): """The SLEAP GUI application. @@ -101,6 +106,7 @@ class MainWindow(QMainWindow): def __init__( self, labels_path: Optional[str] = None, + labels: Optional[Labels] = None, reset: bool = False, no_usage_data: bool = False, *args, @@ -118,7 +124,7 @@ def __init__( self.setAcceptDrops(True) self.state = GuiState() - self.labels = Labels() + self.labels = labels or Labels() self.commands = CommandContext( state=self.state, app=self, update_callback=self.on_data_update @@ -175,8 +181,10 @@ def __init__( print("Restoring GUI state...") self.restoreState(prefs["window state"]) - if labels_path: + if labels_path is not None: self.commands.loadProjectFile(filename=labels_path) + elif labels is not None: + self.commands.loadLabelsObject(labels=labels) else: self.state["project_loaded"] = False @@ -1594,8 +1602,7 @@ def _show_keyboard_shortcuts_window(self): ShortcutDialog().exec_() -def main(args: Optional[list] = None): - """Starts new instance of app.""" +def create_parser(): import argparse @@ -1635,6 +1642,13 @@ def main(args: Optional[list] = None): default=False, ) + return parser + + +def main(args: Optional[list] = None, labels: Optional[Labels] = None): + """Starts new instance of app.""" + + parser = create_parser() args = parser.parse_args(args) if args.nonnative: @@ -1651,12 +1665,23 @@ def main(args: Optional[list] = None): app.setWindowIcon(QtGui.QIcon(sleap.util.get_package_file("gui/icon.png"))) window = MainWindow( - labels_path=args.labels_path, reset=args.reset, no_usage_data=args.no_usage_data + labels_path=args.labels_path, + labels=labels, + reset=args.reset, + no_usage_data=args.no_usage_data, ) window.showMaximized() # Disable GPU in GUI process. This does not affect subprocesses. - sleap.use_cpu_only() + try: + sleap.use_cpu_only() + except RuntimeError: # Visible devices cannot be modified after being initialized + logger.warning( + "Running processes on the GPU. Restarting your GUI should allow switching " + "back to CPU-only mode.\n" + "Received the following error when trying to switch back to CPU-only mode:" + ) + traceback.print_exc() # Print versions. print() diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index 78a8c2a31..8ac4d87fb 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -36,7 +36,7 @@ class which inherits from `AppCommand` (or a more specialized class such as from enum import Enum from glob import glob from pathlib import Path, PurePath -from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type +from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, Union import attr import cv2 @@ -260,16 +260,15 @@ def loadLabelsObject(self, labels: Labels, filename: Optional[str] = None): """ self.execute(LoadLabelsObject, labels=labels, filename=filename) - def loadProjectFile(self, filename: str): + def loadProjectFile(self, filename: Union[str, Labels]): """Loads given labels file into GUI. Args: - filename: The path to the saved labels dataset. If None, - then don't do anything. + filename: The path to the saved labels dataset or the `Labels` object. + If None, then don't do anything. Returns: None - """ self.execute(LoadProjectFile, filename=filename) @@ -647,9 +646,8 @@ def do_action(context: "CommandContext", params: dict): Returns: None. - """ - filename = params["filename"] + filename = params.get("filename", None) # If called with just a Labels object labels: Labels = params["labels"] context.state["labels"] = labels @@ -669,7 +667,9 @@ def do_action(context: "CommandContext", params: dict): context.state["video"] = labels.videos[0] context.state["project_loaded"] = True - context.state["has_changes"] = params.get("changed_on_load", False) + context.state["has_changes"] = params.get("changed_on_load", False) or ( + filename is None + ) # This is not listed as an edit command since we want a clean changestack context.app.on_data_update([UpdateTopic.project, UpdateTopic.all]) @@ -683,17 +683,16 @@ def ask(context: "CommandContext", params: dict): if len(filename) == 0: return - gui_video_callback = Labels.make_gui_video_callback( - search_paths=[os.path.dirname(filename)], context=params - ) - has_loaded = False labels = None - if type(filename) == Labels: + if isinstance(filename, Labels): labels = filename filename = None has_loaded = True else: + gui_video_callback = Labels.make_gui_video_callback( + search_paths=[os.path.dirname(filename)], context=params + ) try: labels = Labels.load_file(filename, video_search=gui_video_callback) has_loaded = True From 6b14bcab93b8c1953913dcfbabd6508f78f77e8a Mon Sep 17 00:00:00 2001 From: Liezl Maree <38435167+roomrys@users.noreply.github.com> Date: Wed, 11 Oct 2023 12:55:28 -0700 Subject: [PATCH 3/5] Replace (broken) `--unrag` with `--ragged` (#1539) * Fix unrag always set to true in sleap-export * Replace unrag with ragged * Fix typos --- docs/guides/cli.md | 6 +++--- sleap/nn/inference.py | 12 ++++++------ tests/nn/test_inference.py | 40 +++++++++++++++++++++++++------------- 3 files changed, 35 insertions(+), 23 deletions(-) diff --git a/docs/guides/cli.md b/docs/guides/cli.md index 6a9d05806..9e07c0a25 100644 --- a/docs/guides/cli.md +++ b/docs/guides/cli.md @@ -99,9 +99,9 @@ optional arguments: -e [EXPORT_PATH], --export_path [EXPORT_PATH] Path to output directory where the frozen model will be exported to. Defaults to a folder named 'exported_model'. - -u, --unrag UNRAG - Convert ragged tensors into regular tensors with NaN padding. - Defaults to True. + -r, --ragged RAGGED + Keep tensors ragged if present. If ommited, convert + ragged tensors into regular tensors with NaN padding. -n, --max_instances MAX_INSTANCES Limit maximum number of instances in multi-instance models. Not available for ID models. Defaults to None. diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 6d7d24f8c..0cabc91bb 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -4939,7 +4939,7 @@ def export_cli(args: Optional[list] = None): export_model( args.models, args.export_path, - unrag_outputs=args.unrag, + unrag_outputs=(not args.ragged), max_instances=args.max_instances, ) @@ -4971,13 +4971,13 @@ def _make_export_cli_parser() -> argparse.ArgumentParser: ), ) parser.add_argument( - "-u", - "--unrag", + "-r", + "--ragged", action="store_true", - default=True, + default=False, help=( - "Convert ragged tensors into regular tensors with NaN padding. " - "Defaults to True." + "Keep tensors ragged if present. If ommited, convert ragged tensors" + " into regular tensors with NaN padding." ), ) parser.add_argument( diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index fe848bb1c..dedf0d324 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -50,6 +50,7 @@ _make_tracker_from_cli, main as sleap_track, export_cli as sleap_export, + _make_export_cli_parser, ) from sleap.nn.tracking import ( MatchedFrameInstance, @@ -925,7 +926,7 @@ def test_load_model(resize_input_shape, model_fixture_name, request): predictor = load_model(model_path, resize_input_layer=resize_input_shape) # Determine predictor type - for (fname, mname, ptype, ishape) in fname_mname_ptype_ishape: + for fname, mname, ptype, ishape in fname_mname_ptype_ishape: if fname in model_fixture_name: expected_model_name = mname expected_predictor_type = ptype @@ -966,7 +967,6 @@ def test_topdown_multi_size_inference( def test_ensure_numpy( min_centroid_model_path, min_centered_instance_model_path, min_labels_slp ): - model = load_model([min_centroid_model_path, min_centered_instance_model_path]) # each frame has same number of instances @@ -1037,7 +1037,6 @@ def test_ensure_numpy( def test_centroid_inference(): - xv, yv = make_grid_vectors(image_height=12, image_width=12, output_stride=1) points = tf.cast([[[1.75, 2.75]], [[3.75, 4.75]], [[5.75, 6.75]]], tf.float32) cms = tf.expand_dims(make_multi_confmaps(points, xv, yv, sigma=1.5), axis=0) @@ -1093,7 +1092,6 @@ def test_centroid_inference(): def export_frozen_graph(model, preds, output_path): - tensors = {} for key, val in preds.items(): @@ -1120,7 +1118,6 @@ def export_frozen_graph(model, preds, output_path): info = json.load(json_file) for tensor_info in info["frozen_model_inputs"] + info["frozen_model_outputs"]: - saved_name = ( tensor_info.split("Tensor(")[1].split(", shape")[0].replace('"', "") ) @@ -1137,7 +1134,6 @@ def export_frozen_graph(model, preds, output_path): def test_single_instance_save(min_single_instance_robot_model_path, tmp_path): - single_instance_model = tf.keras.models.load_model( min_single_instance_robot_model_path + "/best_model.h5", compile=False ) @@ -1152,7 +1148,6 @@ def test_single_instance_save(min_single_instance_robot_model_path, tmp_path): def test_centroid_save(min_centroid_model_path, tmp_path): - centroid_model = tf.keras.models.load_model( min_centroid_model_path + "/best_model.h5", compile=False ) @@ -1171,7 +1166,6 @@ def test_centroid_save(min_centroid_model_path, tmp_path): def test_topdown_save( min_centroid_model_path, min_centered_instance_model_path, min_labels_slp, tmp_path ): - centroid_model = tf.keras.models.load_model( min_centroid_model_path + "/best_model.h5", compile=False ) @@ -1195,7 +1189,6 @@ def test_topdown_save( def test_topdown_id_save( min_centroid_model_path, min_topdown_multiclass_model_path, min_labels_slp, tmp_path ): - centroid_model = tf.keras.models.load_model( min_centroid_model_path + "/best_model.h5", compile=False ) @@ -1217,7 +1210,6 @@ def test_topdown_id_save( def test_single_instance_predictor_save(min_single_instance_robot_model_path, tmp_path): - # directly initialize predictor predictor = SingleInstancePredictor.from_trained_models( min_single_instance_robot_model_path, resize_input_layer=False @@ -1254,10 +1246,33 @@ def test_single_instance_predictor_save(min_single_instance_robot_model_path, tm ) +def test_make_export_cli(): + models_path = r"psuedo/models/path" + export_path = r"psuedo/test/path" + max_instances = 5 + + parser = _make_export_cli_parser() + + # Test default values + args = None + args, _ = parser.parse_known_args(args=args) + assert args.models is None + assert args.export_path == "exported_model" + assert not args.ragged + assert args.max_instances is None + + # Test all arguments + cmd = f"-m {models_path} -e {export_path} -r -n {max_instances}" + args, _ = parser.parse_known_args(args=cmd.split()) + assert args.models == [models_path] + assert args.export_path == export_path + assert args.ragged + assert args.max_instances == max_instances + + def test_topdown_predictor_save( min_centroid_model_path, min_centered_instance_model_path, tmp_path ): - # directly initialize predictor predictor = TopDownPredictor.from_trained_models( centroid_model_path=min_centroid_model_path, @@ -1300,7 +1315,6 @@ def test_topdown_predictor_save( def test_topdown_id_predictor_save( min_centroid_model_path, min_topdown_multiclass_model_path, tmp_path ): - # directly initialize predictor predictor = TopDownMultiClassPredictor.from_trained_models( centroid_model_path=min_centroid_model_path, @@ -1478,7 +1492,6 @@ def test_flow_tracker(centered_pair_predictions: Labels, tmpdir): # Run tracking on subset of frames using psuedo-implementation of # sleap.nn.tracking.run_tracker for lf in frames[:20]: - # Clear the tracks for inst in lf.instances: inst.track = None @@ -1522,7 +1535,6 @@ def test_max_tracks_matching_queue( frames = sorted(labels.labeled_frames, key=lambda lf: lf.frame_idx) for lf in frames[:20]: - # Clear the tracks for inst in lf.instances: inst.track = None From 1e0627aae2e69e664d8eca0dd3af04ba8dd9b13c Mon Sep 17 00:00:00 2001 From: Liezl Maree <38435167+roomrys@users.noreply.github.com> Date: Thu, 12 Oct 2023 08:52:48 -0700 Subject: [PATCH 4/5] Add function to create app (#1546) --- sleap/gui/app.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 065563e66..41d696f0c 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -262,7 +262,6 @@ def dragEnterEvent(self, event): event.acceptProposedAction() def dropEvent(self, event): - # Parse filenames filenames = event.mimeData().data("text/uri-list").data().decode() filenames = [parse_uri_path(f.strip()) for f in filenames.strip().split("\n")] @@ -1602,7 +1601,12 @@ def _show_keyboard_shortcuts_window(self): ShortcutDialog().exec_() -def create_parser(): +def create_sleap_label_parser(): + """Creates parser for `sleap-label` command line arguments. + + Returns: + argparse.ArgumentParser: The parser. + """ import argparse @@ -1645,10 +1649,20 @@ def create_parser(): return parser +def create_app(): + """Creates Qt application.""" + + app = QApplication([]) + app.setApplicationName(f"SLEAP v{sleap.version.__version__}") + app.setWindowIcon(QtGui.QIcon(sleap.util.get_package_file("gui/icon.png"))) + + return app + + def main(args: Optional[list] = None, labels: Optional[Labels] = None): """Starts new instance of app.""" - parser = create_parser() + parser = create_sleap_label_parser() args = parser.parse_args(args) if args.nonnative: @@ -1660,9 +1674,7 @@ def main(args: Optional[list] = None, labels: Optional[Labels] = None): # https://stackoverflow.com/q/64818879 os.environ["QT_MAC_WANTS_LAYER"] = "1" - app = QApplication([]) - app.setApplicationName(f"SLEAP v{sleap.version.__version__}") - app.setWindowIcon(QtGui.QIcon(sleap.util.get_package_file("gui/icon.png"))) + app = create_app() window = MainWindow( labels_path=args.labels_path, From 5c3441cc47ad6a5dbad42f1dc67d13bcf9be7baa Mon Sep 17 00:00:00 2001 From: Liezl Maree <38435167+roomrys@users.noreply.github.com> Date: Thu, 19 Oct 2023 09:45:18 -0700 Subject: [PATCH 5/5] Refactor `AddInstance` command (#1561) * Refactor AddInstance command * Add staticmethod wrappers * Return early from set_visible_nodes --- sleap/gui/commands.py | 273 ++++++++++++++++++++++++++++-------------- 1 file changed, 184 insertions(+), 89 deletions(-) diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index 8ac4d87fb..ef6055a45 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -49,7 +49,6 @@ class which inherits from `AppCommand` (or a more specialized class such as from sleap.gui.dialogs.merge import MergeDialog, ReplaceSkeletonTableDialog from sleap.gui.dialogs.message import MessageDialog from sleap.gui.dialogs.missingfiles import MissingFilesDialog -from sleap.gui.dialogs.query import QueryDialog from sleap.gui.state import GuiState from sleap.gui.suggestions import VideoFrameSuggestions from sleap.instance import Instance, LabeledFrame, Point, PredictedInstance, Track @@ -750,7 +749,6 @@ def ask(context: "CommandContext", params: dict) -> bool: class ImportAlphaTracker(AppCommand): @staticmethod def do_action(context: "CommandContext", params: dict): - video_path = params["video_path"] if "video_path" in params else None labels = Labels.load_alphatracker( @@ -790,7 +788,6 @@ def ask(context: "CommandContext", params: dict) -> bool: class ImportNWB(AppCommand): @staticmethod def do_action(context: "CommandContext", params: dict): - labels = Labels.load_nwb(filename=params["filename"]) new_window = context.app.__class__() @@ -823,7 +820,6 @@ def ask(context: "CommandContext", params: dict) -> bool: class ImportDeepPoseKit(AppCommand): @staticmethod def do_action(context: "CommandContext", params: dict): - labels = Labels.from_deepposekit( filename=params["filename"], video_path=params["video_path"], @@ -872,7 +868,6 @@ def ask(context: "CommandContext", params: dict) -> bool: class ImportLEAP(AppCommand): @staticmethod def do_action(context: "CommandContext", params: dict): - labels = Labels.load_leap_matlab( filename=params["filename"], ) @@ -903,7 +898,6 @@ def ask(context: "CommandContext", params: dict) -> bool: class ImportCoco(AppCommand): @staticmethod def do_action(context: "CommandContext", params: dict): - labels = Labels.load_coco( filename=params["filename"], img_dir=params["img_dir"], use_missing_gui=True ) @@ -935,7 +929,6 @@ def ask(context: "CommandContext", params: dict) -> bool: class ImportDeepLabCut(AppCommand): @staticmethod def do_action(context: "CommandContext", params: dict): - labels = Labels.load_deeplabcut(filename=params["filename"]) new_window = context.app.__class__() @@ -1309,7 +1302,6 @@ def do_action(context: CommandContext, params: dict): @staticmethod def ask(context: CommandContext, params: dict) -> bool: - from sleap.gui.dialogs.export_clip import ExportClipDialog dialog = ExportClipDialog() @@ -1585,7 +1577,6 @@ class GoNextSuggestedFrame(NavCommand): @classmethod def do_action(cls, context: CommandContext, params: dict): - next_suggestion_frame = context.labels.get_next_suggestion( context.state["video"], context.state["frame_idx"], cls.seek_direction ) @@ -1771,7 +1762,6 @@ class ReplaceVideo(EditCommand): @staticmethod def do_action(context: CommandContext, params: dict) -> bool: - import_list = params["import_list"] for import_item, video in import_list: @@ -1900,7 +1890,6 @@ def ask(context: CommandContext, params: dict) -> bool: video_file_names = [] total_num_labeled_frames = 0 for idx in row_idxs: - video = videos[idx] if video is None: return False @@ -1945,7 +1934,6 @@ def load_skeleton(filename: str): def compare_skeletons( skeleton: Skeleton, new_skeleton: Skeleton ) -> Tuple[List[str], List[str], List[str]]: - delete_nodes = [] add_nodes = [] if skeleton.node_names != new_skeleton.node_names: @@ -2724,7 +2712,6 @@ class GenerateSuggestions(EditCommand): @classmethod def do_action(cls, context: CommandContext, params: dict): - if len(context.labels.videos) == 0: print("Error: no videos to generate suggestions for") return @@ -2852,21 +2839,6 @@ def ask_and_do(cls, context: CommandContext, params: dict): class AddInstance(EditCommand): topics = [UpdateTopic.frame, UpdateTopic.project_instances, UpdateTopic.suggestions] - @staticmethod - def get_previous_frame_index(context: CommandContext) -> Optional[int]: - frames = context.labels.frames( - context.state["video"], - from_frame_idx=context.state["frame_idx"], - reverse=True, - ) - - try: - next_idx = next(frames).frame_idx - except: - return - - return next_idx - @classmethod def do_action(cls, context: CommandContext, params: dict): copy_instance = params.get("copy_instance", None) @@ -2880,6 +2852,175 @@ def do_action(cls, context: CommandContext, params: dict): if len(context.state["skeleton"]) == 0: return + ( + copy_instance, + from_predicted, + from_prev_frame, + ) = AddInstance.find_instance_to_copy_from( + context, copy_instance=copy_instance, init_method=init_method + ) + + new_instance = AddInstance.create_new_instance( + context=context, + from_predicted=from_predicted, + copy_instance=copy_instance, + mark_complete=mark_complete, + init_method=init_method, + location=location, + from_prev_frame=from_prev_frame, + ) + + # Add the instance + context.labels.add_instance(context.state["labeled_frame"], new_instance) + + if context.state["labeled_frame"] not in context.labels.labels: + context.labels.append(context.state["labeled_frame"]) + + @staticmethod + def create_new_instance( + context: CommandContext, + from_predicted: bool, + copy_instance: Optional[Instance], + mark_complete: bool, + init_method: str, + location: Optional[QtCore.QPoint], + from_prev_frame: bool, + ): + """Create new instance.""" + + # Now create the new instance + new_instance = Instance( + skeleton=context.state["skeleton"], + from_predicted=from_predicted, + frame=context.state["labeled_frame"], + ) + + has_missing_nodes = AddInstance.set_visible_nodes( + context=context, + copy_instance=copy_instance, + new_instance=new_instance, + mark_complete=mark_complete, + ) + + if has_missing_nodes: + AddInstance.fill_missing_nodes( + context=context, + copy_instance=copy_instance, + init_method=init_method, + new_instance=new_instance, + location=location, + ) + + # If we're copying a predicted instance or from another frame, copy the track + if hasattr(copy_instance, "score") or from_prev_frame: + new_instance.track = copy_instance.track + + return new_instance + + @staticmethod + def fill_missing_nodes( + context: CommandContext, + copy_instance: Optional[Instance], + init_method: str, + new_instance: Instance, + location: Optional[QtCore.QPoint], + ): + """Fill in missing nodes for new instance. + + Args: + context: The command context. + copy_instance: The instance to copy from. + init_method: The initialization method. + new_instance: The new instance. + location: The location of the instance. + + Returns: + None + """ + + # mark the node as not "visible" if we're copying from a predicted instance without this node + is_visible = copy_instance is None or (not hasattr(copy_instance, "score")) + + if init_method == "force_directed": + AddMissingInstanceNodes.add_force_directed_nodes( + context=context, + instance=new_instance, + visible=is_visible, + center_point=location, + ) + elif init_method == "random": + AddMissingInstanceNodes.add_random_nodes( + context=context, instance=new_instance, visible=is_visible + ) + elif init_method == "template": + AddMissingInstanceNodes.add_nodes_from_template( + context=context, + instance=new_instance, + visible=is_visible, + center_point=location, + ) + else: + AddMissingInstanceNodes.add_best_nodes( + context=context, instance=new_instance, visible=is_visible + ) + + @staticmethod + def set_visible_nodes( + context: CommandContext, + copy_instance: Optional[Instance], + new_instance: Instance, + mark_complete: bool, + ) -> Tuple[Instance, bool]: + """Sets visible nodes for new instance. + + Args: + context: The command context. + copy_instance: The instance to copy from. + new_instance: The new instance. + mark_complete: Whether to mark the instance as complete. + + Returns: + Whether the new instance has missing nodes. + """ + + if copy_instance is None: + return True + + has_missing_nodes = False + + # go through each node in skeleton + for node in context.state["skeleton"].node_names: + # if we're copying from a skeleton that has this node + if node in copy_instance and not copy_instance[node].isnan(): + # just copy x, y, and visible + # we don't want to copy a PredictedPoint or score attribute + new_instance[node] = Point( + x=copy_instance[node].x, + y=copy_instance[node].y, + visible=copy_instance[node].visible, + complete=mark_complete, + ) + else: + has_missing_nodes = True + + return has_missing_nodes + + @staticmethod + def find_instance_to_copy_from( + context: CommandContext, copy_instance: Optional[Instance], init_method: bool + ) -> Tuple[Optional[Instance], bool, bool]: + """Find instance to copy from. + + Args: + context: The command context. + copy_instance: The instance to copy from. + init_method: The initialization method. + + Returns: + The instance to copy from, whether it's from a predicted instance, and + whether it's from a previous frame. + """ + from_predicted = copy_instance from_prev_frame = False @@ -2905,7 +3046,7 @@ def do_action(cls, context: CommandContext, params: dict): ) or init_method == "prior_frame": # Otherwise, if there are instances in previous frames, # copy the points from one of those instances. - prev_idx = cls.get_previous_frame_index(context) + prev_idx = AddInstance.get_previous_frame_index(context) if prev_idx is not None: prev_instances = context.labels.find( @@ -2931,70 +3072,24 @@ def do_action(cls, context: CommandContext, params: dict): from_predicted = from_predicted if hasattr(from_predicted, "score") else None - # Now create the new instance - new_instance = Instance( - skeleton=context.state["skeleton"], - from_predicted=from_predicted, - frame=context.state["labeled_frame"], - ) - - has_missing_nodes = False + return copy_instance, from_predicted, from_prev_frame - # go through each node in skeleton - for node in context.state["skeleton"].node_names: - # if we're copying from a skeleton that has this node - if ( - copy_instance is not None - and node in copy_instance - and not copy_instance[node].isnan() - ): - # just copy x, y, and visible - # we don't want to copy a PredictedPoint or score attribute - new_instance[node] = Point( - x=copy_instance[node].x, - y=copy_instance[node].y, - visible=copy_instance[node].visible, - complete=mark_complete, - ) - else: - has_missing_nodes = True - - if has_missing_nodes: - # mark the node as not "visible" if we're copying from a predicted instance without this node - is_visible = copy_instance is None or (not hasattr(copy_instance, "score")) - - if init_method == "force_directed": - AddMissingInstanceNodes.add_force_directed_nodes( - context=context, - instance=new_instance, - visible=is_visible, - center_point=location, - ) - elif init_method == "random": - AddMissingInstanceNodes.add_random_nodes( - context=context, instance=new_instance, visible=is_visible - ) - elif init_method == "template": - AddMissingInstanceNodes.add_nodes_from_template( - context=context, - instance=new_instance, - visible=is_visible, - center_point=location, - ) - else: - AddMissingInstanceNodes.add_best_nodes( - context=context, instance=new_instance, visible=is_visible - ) + @staticmethod + def get_previous_frame_index(context: CommandContext) -> Optional[int]: + """Returns index of previous frame.""" - # If we're copying a predicted instance or from another frame, copy the track - if hasattr(copy_instance, "score") or from_prev_frame: - new_instance.track = copy_instance.track + frames = context.labels.frames( + context.state["video"], + from_frame_idx=context.state["frame_idx"], + reverse=True, + ) - # Add the instance - context.labels.add_instance(context.state["labeled_frame"], new_instance) + try: + next_idx = next(frames).frame_idx + except: + return - if context.state["labeled_frame"] not in context.labels.labels: - context.labels.append(context.state["labeled_frame"]) + return next_idx class SetInstancePointLocations(EditCommand):