Skip to content

Commit

Permalink
Prefer user instances when calling Labels.numpy() (#996)
Browse files Browse the repository at this point in the history
  • Loading branch information
roomrys authored Oct 19, 2022
1 parent 455788d commit 9df1816
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 38 deletions.
93 changes: 56 additions & 37 deletions sleap/io/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
Any,
Set,
Callable,
cast,
)

import attr
Expand Down Expand Up @@ -2401,44 +2402,63 @@ def numpy(
This method assumes that instances have tracks assigned and is intended to
function primarily for single-video prediction results.
"""

def set_track(
inst: Union[Instance, PredictedInstance],
track: np.ndarray,
return_confidence: bool,
):
if return_confidence:
if isinstance(inst, PredictedInstance):
track = inst.points_and_scores_array
else:
track[:, :-1] = inst.numpy()
else:
track = inst.numpy()
return track

# Get labeled frames for specified video.
if video is None:
video = 0
if type(video) == int:
video = self.videos[video]
lfs = self.find(video=video)
try:
if video is None:
video = self.videos[0]
if type(video) == int:
video = self.videos[video]
video = cast(Video, video) # video should now be of type Video
except IndexError as e:
raise IndexError(
f"There are no videos in this project. No points matrix to return."
)

lfs: List[LabeledFrame] = self.find(video=video)

# Figure out frame index range.
if all_frames:
first_frame, last_frame = 0, video.shape[0] - 1
else:
first_frame, last_frame = None, None
for lf in lfs:
if first_frame is None:
first_frame = lf.frame_idx
if last_frame is None:
last_frame = lf.frame_idx
first_frame = min(first_frame, lf.frame_idx)
last_frame = max(last_frame, lf.frame_idx)
frame_idxs = [lf.frame_idx for lf in lfs]
frame_idxs.sort()
first_frame = 0 if all_frames else frame_idxs[0]

# Figure out the number of tracks based on number of instances in each frame.
#
# First, let's check the max number of predicted instances (regardless of
# First, let's check the max number of instances (regardless of
# whether they're tracked.
n_preds = 0
for lf in lfs:
n_preds = max(n_preds, lf.n_predicted_instances)
n_insts = max(
[
lf.n_user_instances
if lf.n_user_instances > 0 # take user instances over predicted
else lf.n_predicted_instances
for lf in lfs
]
)

# Case 1: We don't care about order because there's only 1 instance per frame,
# or we're considering untracked instances.
untracked = untracked or n_preds == 1
untracked = untracked or n_insts == 1
if untracked:
n_tracks = n_preds
# Case 1: We don't care about order because there's only 1 instance per
# frame, or we're considering untracked instances.
n_tracks = n_insts
else:
# Case 2: We're considering only tracked instances.
n_tracks = len(self.tracks)

n_frames = last_frame - first_frame + 1
n_frames = frame_idxs[-1] - first_frame + 1
n_nodes = len(self.skeleton.nodes)

if return_confidence:
Expand All @@ -2447,21 +2467,20 @@ def numpy(
tracks = np.full((n_frames, n_tracks, n_nodes, 2), np.nan, dtype="float32")
for lf in lfs:
i = lf.frame_idx - first_frame
lf_insts: Union[List[Instance], List[PredictedInstance]] = (
lf.user_instances if lf.n_user_instances > 0 else lf.predicted_instances
) # Prefer user labeled instances over predicted
if untracked:
for j, inst in enumerate(lf.predicted_instances):
tracks[i, j] = (
inst.points_and_scores_array
if return_confidence
else inst.numpy()
)
# Add instances in any order if untracked
for j, inst in enumerate(lf_insts):
tracks[i, j] = set_track(inst, tracks[i, j], return_confidence)
else:
for inst in lf.tracked_instances:
# Add instances in track-specific order, ignoring instances w/o a track
for inst in lf_insts:
if inst.track is None:
continue
j = self.tracks.index(inst.track)
tracks[i, j] = (
inst.points_and_scores_array
if return_confidence
else inst.numpy()
)
tracks[i, j] = set_track(inst, tracks[i, j], return_confidence)

return tracks

Expand Down
12 changes: 11 additions & 1 deletion tests/io/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1332,7 +1332,7 @@ def test_remove_predictions_with_new_labels(removal_test_labels):
assert labels[1].has_predicted_instances


def test_labels_numpy(centered_pair_predictions):
def test_labels_numpy(centered_pair_predictions: Labels):
trx = centered_pair_predictions.numpy(video=None, all_frames=False, untracked=False)
assert trx.shape == (1100, 27, 24, 2)

Expand Down Expand Up @@ -1366,6 +1366,16 @@ def test_labels_numpy(centered_pair_predictions):
centered_pair_predictions.tracks = []
assert centered_pair_predictions.numpy(untracked=False).shape == (1100, 0, 24, 2)

# Test labels.numpy prefers user instances
skeleton = centered_pair_predictions.skeleton
lf = centered_pair_predictions.labeled_frames[0]
user_inst = Instance(
skeleton=skeleton, points={node: Point(1, 1) for node in skeleton.nodes}
)
lf.instances.append(user_inst)
labels_np = centered_pair_predictions.numpy(untracked=True, return_confidence=True)
np.testing.assert_array_equal(labels_np[lf.frame_idx, 0, :, :-1], user_inst.numpy())


def test_remove_track(centered_pair_predictions):
labels = centered_pair_predictions
Expand Down

0 comments on commit 9df1816

Please sign in to comment.