Skip to content

Commit

Permalink
Merge branch 'develop' into shrivaths/changelog-announcement-1
Browse files Browse the repository at this point in the history
  • Loading branch information
shrivaths16 authored Oct 12, 2023
2 parents 83d34ad + 1e0627a commit 7038b60
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 46 deletions.
6 changes: 3 additions & 3 deletions docs/guides/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ optional arguments:
-e [EXPORT_PATH], --export_path [EXPORT_PATH]
Path to output directory where the frozen model will be exported to.
Defaults to a folder named 'exported_model'.
-u, --unrag UNRAG
Convert ragged tensors into regular tensors with NaN padding.
Defaults to True.
-r, --ragged RAGGED
Keep tensors ragged if present. If ommited, convert
ragged tensors into regular tensors with NaN padding.
-n, --max_instances MAX_INSTANCES
Limit maximum number of instances in multi-instance models.
Not available for ID models. Defaults to None.
Expand Down
57 changes: 47 additions & 10 deletions sleap/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
import platform
import random
import re
import traceback
from logging import getLogger
from pathlib import Path
from typing import Callable, List, Optional, Tuple

Expand Down Expand Up @@ -85,6 +87,9 @@
from sleap.util import parse_uri_path


logger = getLogger(__name__)


class MainWindow(QMainWindow):
"""The SLEAP GUI application.
Expand All @@ -101,6 +106,7 @@ class MainWindow(QMainWindow):
def __init__(
self,
labels_path: Optional[str] = None,
labels: Optional[Labels] = None,
reset: bool = False,
no_usage_data: bool = False,
*args,
Expand All @@ -118,7 +124,7 @@ def __init__(
self.setAcceptDrops(True)

self.state = GuiState()
self.labels = Labels()
self.labels = labels or Labels()

self.commands = CommandContext(
state=self.state, app=self, update_callback=self.on_data_update
Expand Down Expand Up @@ -175,8 +181,10 @@ def __init__(
print("Restoring GUI state...")
self.restoreState(prefs["window state"])

if labels_path:
if labels_path is not None:
self.commands.loadProjectFile(filename=labels_path)
elif labels is not None:
self.commands.loadLabelsObject(labels=labels)
else:
self.state["project_loaded"] = False

Expand Down Expand Up @@ -254,7 +262,6 @@ def dragEnterEvent(self, event):
event.acceptProposedAction()

def dropEvent(self, event):

# Parse filenames
filenames = event.mimeData().data("text/uri-list").data().decode()
filenames = [parse_uri_path(f.strip()) for f in filenames.strip().split("\n")]
Expand Down Expand Up @@ -1594,8 +1601,12 @@ def _show_keyboard_shortcuts_window(self):
ShortcutDialog().exec_()


def main(args: Optional[list] = None):
"""Starts new instance of app."""
def create_sleap_label_parser():
"""Creates parser for `sleap-label` command line arguments.
Returns:
argparse.ArgumentParser: The parser.
"""

import argparse

Expand Down Expand Up @@ -1635,6 +1646,23 @@ def main(args: Optional[list] = None):
default=False,
)

return parser


def create_app():
"""Creates Qt application."""

app = QApplication([])
app.setApplicationName(f"SLEAP v{sleap.version.__version__}")
app.setWindowIcon(QtGui.QIcon(sleap.util.get_package_file("gui/icon.png")))

return app


def main(args: Optional[list] = None, labels: Optional[Labels] = None):
"""Starts new instance of app."""

parser = create_sleap_label_parser()
args = parser.parse_args(args)

if args.nonnative:
Expand All @@ -1646,17 +1674,26 @@ def main(args: Optional[list] = None):
# https://stackoverflow.com/q/64818879
os.environ["QT_MAC_WANTS_LAYER"] = "1"

app = QApplication([])
app.setApplicationName(f"SLEAP v{sleap.version.__version__}")
app.setWindowIcon(QtGui.QIcon(sleap.util.get_package_file("gui/icon.png")))
app = create_app()

window = MainWindow(
labels_path=args.labels_path, reset=args.reset, no_usage_data=args.no_usage_data
labels_path=args.labels_path,
labels=labels,
reset=args.reset,
no_usage_data=args.no_usage_data,
)
window.showMaximized()

# Disable GPU in GUI process. This does not affect subprocesses.
sleap.use_cpu_only()
try:
sleap.use_cpu_only()
except RuntimeError: # Visible devices cannot be modified after being initialized
logger.warning(
"Running processes on the GPU. Restarting your GUI should allow switching "
"back to CPU-only mode.\n"
"Received the following error when trying to switch back to CPU-only mode:"
)
traceback.print_exc()

# Print versions.
print()
Expand Down
25 changes: 12 additions & 13 deletions sleap/gui/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class which inherits from `AppCommand` (or a more specialized class such as
from enum import Enum
from glob import glob
from pathlib import Path, PurePath
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, Union

import attr
import cv2
Expand Down Expand Up @@ -260,16 +260,15 @@ def loadLabelsObject(self, labels: Labels, filename: Optional[str] = None):
"""
self.execute(LoadLabelsObject, labels=labels, filename=filename)

def loadProjectFile(self, filename: str):
def loadProjectFile(self, filename: Union[str, Labels]):
"""Loads given labels file into GUI.
Args:
filename: The path to the saved labels dataset. If None,
then don't do anything.
filename: The path to the saved labels dataset or the `Labels` object.
If None, then don't do anything.
Returns:
None
"""
self.execute(LoadProjectFile, filename=filename)

Expand Down Expand Up @@ -647,9 +646,8 @@ def do_action(context: "CommandContext", params: dict):
Returns:
None.
"""
filename = params["filename"]
filename = params.get("filename", None) # If called with just a Labels object
labels: Labels = params["labels"]

context.state["labels"] = labels
Expand All @@ -669,7 +667,9 @@ def do_action(context: "CommandContext", params: dict):
context.state["video"] = labels.videos[0]

context.state["project_loaded"] = True
context.state["has_changes"] = params.get("changed_on_load", False)
context.state["has_changes"] = params.get("changed_on_load", False) or (
filename is None
)

# This is not listed as an edit command since we want a clean changestack
context.app.on_data_update([UpdateTopic.project, UpdateTopic.all])
Expand All @@ -683,17 +683,16 @@ def ask(context: "CommandContext", params: dict):
if len(filename) == 0:
return

gui_video_callback = Labels.make_gui_video_callback(
search_paths=[os.path.dirname(filename)], context=params
)

has_loaded = False
labels = None
if type(filename) == Labels:
if isinstance(filename, Labels):
labels = filename
filename = None
has_loaded = True
else:
gui_video_callback = Labels.make_gui_video_callback(
search_paths=[os.path.dirname(filename)], context=params
)
try:
labels = Labels.load_file(filename, video_search=gui_video_callback)
has_loaded = True
Expand Down
12 changes: 6 additions & 6 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4939,7 +4939,7 @@ def export_cli(args: Optional[list] = None):
export_model(
args.models,
args.export_path,
unrag_outputs=args.unrag,
unrag_outputs=(not args.ragged),
max_instances=args.max_instances,
)

Expand Down Expand Up @@ -4971,13 +4971,13 @@ def _make_export_cli_parser() -> argparse.ArgumentParser:
),
)
parser.add_argument(
"-u",
"--unrag",
"-r",
"--ragged",
action="store_true",
default=True,
default=False,
help=(
"Convert ragged tensors into regular tensors with NaN padding. "
"Defaults to True."
"Keep tensors ragged if present. If ommited, convert ragged tensors"
" into regular tensors with NaN padding."
),
)
parser.add_argument(
Expand Down
40 changes: 26 additions & 14 deletions tests/nn/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
_make_tracker_from_cli,
main as sleap_track,
export_cli as sleap_export,
_make_export_cli_parser,
)
from sleap.nn.tracking import (
MatchedFrameInstance,
Expand Down Expand Up @@ -925,7 +926,7 @@ def test_load_model(resize_input_shape, model_fixture_name, request):
predictor = load_model(model_path, resize_input_layer=resize_input_shape)

# Determine predictor type
for (fname, mname, ptype, ishape) in fname_mname_ptype_ishape:
for fname, mname, ptype, ishape in fname_mname_ptype_ishape:
if fname in model_fixture_name:
expected_model_name = mname
expected_predictor_type = ptype
Expand Down Expand Up @@ -966,7 +967,6 @@ def test_topdown_multi_size_inference(
def test_ensure_numpy(
min_centroid_model_path, min_centered_instance_model_path, min_labels_slp
):

model = load_model([min_centroid_model_path, min_centered_instance_model_path])

# each frame has same number of instances
Expand Down Expand Up @@ -1037,7 +1037,6 @@ def test_ensure_numpy(


def test_centroid_inference():

xv, yv = make_grid_vectors(image_height=12, image_width=12, output_stride=1)
points = tf.cast([[[1.75, 2.75]], [[3.75, 4.75]], [[5.75, 6.75]]], tf.float32)
cms = tf.expand_dims(make_multi_confmaps(points, xv, yv, sigma=1.5), axis=0)
Expand Down Expand Up @@ -1093,7 +1092,6 @@ def test_centroid_inference():


def export_frozen_graph(model, preds, output_path):

tensors = {}

for key, val in preds.items():
Expand All @@ -1120,7 +1118,6 @@ def export_frozen_graph(model, preds, output_path):
info = json.load(json_file)

for tensor_info in info["frozen_model_inputs"] + info["frozen_model_outputs"]:

saved_name = (
tensor_info.split("Tensor(")[1].split(", shape")[0].replace('"', "")
)
Expand All @@ -1137,7 +1134,6 @@ def export_frozen_graph(model, preds, output_path):


def test_single_instance_save(min_single_instance_robot_model_path, tmp_path):

single_instance_model = tf.keras.models.load_model(
min_single_instance_robot_model_path + "/best_model.h5", compile=False
)
Expand All @@ -1152,7 +1148,6 @@ def test_single_instance_save(min_single_instance_robot_model_path, tmp_path):


def test_centroid_save(min_centroid_model_path, tmp_path):

centroid_model = tf.keras.models.load_model(
min_centroid_model_path + "/best_model.h5", compile=False
)
Expand All @@ -1171,7 +1166,6 @@ def test_centroid_save(min_centroid_model_path, tmp_path):
def test_topdown_save(
min_centroid_model_path, min_centered_instance_model_path, min_labels_slp, tmp_path
):

centroid_model = tf.keras.models.load_model(
min_centroid_model_path + "/best_model.h5", compile=False
)
Expand All @@ -1195,7 +1189,6 @@ def test_topdown_save(
def test_topdown_id_save(
min_centroid_model_path, min_topdown_multiclass_model_path, min_labels_slp, tmp_path
):

centroid_model = tf.keras.models.load_model(
min_centroid_model_path + "/best_model.h5", compile=False
)
Expand All @@ -1217,7 +1210,6 @@ def test_topdown_id_save(


def test_single_instance_predictor_save(min_single_instance_robot_model_path, tmp_path):

# directly initialize predictor
predictor = SingleInstancePredictor.from_trained_models(
min_single_instance_robot_model_path, resize_input_layer=False
Expand Down Expand Up @@ -1254,10 +1246,33 @@ def test_single_instance_predictor_save(min_single_instance_robot_model_path, tm
)


def test_make_export_cli():
models_path = r"psuedo/models/path"
export_path = r"psuedo/test/path"
max_instances = 5

parser = _make_export_cli_parser()

# Test default values
args = None
args, _ = parser.parse_known_args(args=args)
assert args.models is None
assert args.export_path == "exported_model"
assert not args.ragged
assert args.max_instances is None

# Test all arguments
cmd = f"-m {models_path} -e {export_path} -r -n {max_instances}"
args, _ = parser.parse_known_args(args=cmd.split())
assert args.models == [models_path]
assert args.export_path == export_path
assert args.ragged
assert args.max_instances == max_instances


def test_topdown_predictor_save(
min_centroid_model_path, min_centered_instance_model_path, tmp_path
):

# directly initialize predictor
predictor = TopDownPredictor.from_trained_models(
centroid_model_path=min_centroid_model_path,
Expand Down Expand Up @@ -1300,7 +1315,6 @@ def test_topdown_predictor_save(
def test_topdown_id_predictor_save(
min_centroid_model_path, min_topdown_multiclass_model_path, tmp_path
):

# directly initialize predictor
predictor = TopDownMultiClassPredictor.from_trained_models(
centroid_model_path=min_centroid_model_path,
Expand Down Expand Up @@ -1478,7 +1492,6 @@ def test_flow_tracker(centered_pair_predictions: Labels, tmpdir):
# Run tracking on subset of frames using psuedo-implementation of
# sleap.nn.tracking.run_tracker
for lf in frames[:20]:

# Clear the tracks
for inst in lf.instances:
inst.track = None
Expand Down Expand Up @@ -1522,7 +1535,6 @@ def test_max_tracks_matching_queue(
frames = sorted(labels.labeled_frames, key=lambda lf: lf.frame_idx)

for lf in frames[:20]:

# Clear the tracks
for inst in lf.instances:
inst.track = None
Expand Down

0 comments on commit 7038b60

Please sign in to comment.