Skip to content

Commit

Permalink
Ensure frames to predict list is unique (#1293)
Browse files Browse the repository at this point in the history
* Ensure frames to predict list is unique

* Ensure frames to predict on are ordered correctly

* Better frame sorting
  • Loading branch information
roomrys authored May 18, 2023
1 parent 1e9026d commit d05e2bf
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
3 changes: 2 additions & 1 deletion sleap/gui/learning/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ def cli_args(self):

# -Y represents endpoint of [X, Y) range but inference cli expects
# [X, Y-1] range (so add 1 since negative).
frame_int_list = [i + 1 if i < 0 else i for i in self.frames]
frame_int_list = list(set([i + 1 if i < 0 else i for i in self.frames]))
frame_int_list.sort(reverse=min(frame_int_list) < 0) # Assumes len of 2 if neg.

arg_list.extend(("--frames", ",".join(map(str, frame_int_list))))

Expand Down
19 changes: 16 additions & 3 deletions tests/gui/test_inference_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,22 @@ def test_scoped_key_dict():


@pytest.mark.parametrize(
"labels_path, video_path", [("labels.slp", "video.mp4"), (None, "video.mp4")]
"labels_path, video_path, frames",
[
("labels.slp", "video.mp4", [0, 1, 2]),
(None, "video.mp4", [0, -1]),
(None, "video.mp4", [1, -4]),
],
)
def test_inference_cli_builder(labels_path, video_path):
def test_inference_cli_builder(labels_path, video_path, frames):

inference_task = runners.InferenceTask(
trained_job_paths=["model1", "model2"],
inference_params={"tracking.tracker": "simple"},
)

item_for_inference = runners.VideoItemForInference(
video=Video.from_filename(video_path), frames=[1, 2, 3], labels_path=labels_path
video=Video.from_filename(video_path), frames=frames, labels_path=labels_path
)

cli_args, output_path = inference_task.make_predict_cli_call(item_for_inference)
Expand All @@ -62,6 +67,14 @@ def test_inference_cli_builder(labels_path, video_path):
assert "model1" in cli_args
assert "model2" in cli_args
assert "--frames" in cli_args

frames_idx = cli_args.index("--frames")
if -1 in frames:
assert cli_args[frames_idx + 1] == "0" # No redundant frames
elif -4 in frames:
assert cli_args[frames_idx + 1] == "1,-3" # Ordered correctly
else:
assert cli_args[frames_idx + 1] == ",".join([str(f) for f in frames])
assert "--tracking.tracker" in cli_args

assert output_path.startswith(data_path)
Expand Down

0 comments on commit d05e2bf

Please sign in to comment.