Prefer user instances when calling Labels.numpy() #996

Merged 1 commit on Oct 19, 2022
93 changes: 56 additions & 37 deletions sleap/io/dataset.py
@@ -52,6 +52,7 @@
Any,
Set,
Callable,
cast,
)

import attr
@@ -2401,44 +2402,63 @@ def numpy(
This method assumes that instances have tracks assigned and is intended to
function primarily for single-video prediction results.
"""

def set_track(
inst: Union[Instance, PredictedInstance],
track: np.ndarray,
return_confidence: bool,
):
if return_confidence:
if isinstance(inst, PredictedInstance):
track = inst.points_and_scores_array
else:
track[:, :-1] = inst.numpy()
else:
track = inst.numpy()
return track

# Get labeled frames for specified video.
if video is None:
video = 0
if type(video) == int:
video = self.videos[video]
lfs = self.find(video=video)
try:
if video is None:
video = self.videos[0]
if type(video) == int:
video = self.videos[video]
video = cast(Video, video) # video should now be of type Video
except IndexError as e:
raise IndexError(
f"There are no videos in this project. No points matrix to return."
)

lfs: List[LabeledFrame] = self.find(video=video)

# Figure out frame index range.
if all_frames:
first_frame, last_frame = 0, video.shape[0] - 1
else:
first_frame, last_frame = None, None
for lf in lfs:
if first_frame is None:
first_frame = lf.frame_idx
if last_frame is None:
last_frame = lf.frame_idx
first_frame = min(first_frame, lf.frame_idx)
last_frame = max(last_frame, lf.frame_idx)
frame_idxs = [lf.frame_idx for lf in lfs]
frame_idxs.sort()
first_frame = 0 if all_frames else frame_idxs[0]

# Figure out the number of tracks based on number of instances in each frame.
#
# First, let's check the max number of predicted instances (regardless of
# First, let's check the max number of instances (regardless of
# whether they're tracked.
n_preds = 0
for lf in lfs:
n_preds = max(n_preds, lf.n_predicted_instances)
n_insts = max(
[
lf.n_user_instances
if lf.n_user_instances > 0 # take user instances over predicted
else lf.n_predicted_instances
for lf in lfs
]
)

# Case 1: We don't care about order because there's only 1 instance per frame,
# or we're considering untracked instances.
untracked = untracked or n_preds == 1
untracked = untracked or n_insts == 1
if untracked:
n_tracks = n_preds
# Case 1: We don't care about order because there's only 1 instance per
# frame, or we're considering untracked instances.
n_tracks = n_insts
else:
# Case 2: We're considering only tracked instances.
n_tracks = len(self.tracks)

n_frames = last_frame - first_frame + 1
n_frames = frame_idxs[-1] - first_frame + 1
n_nodes = len(self.skeleton.nodes)

if return_confidence:
@@ -2447,21 +2467,20 @@ def numpy(
tracks = np.full((n_frames, n_tracks, n_nodes, 2), np.nan, dtype="float32")
for lf in lfs:
i = lf.frame_idx - first_frame
lf_insts: Union[List[Instance], List[PredictedInstance]] = (
Collaborator: List[Union[Instance, PredictedInstance]] for mixed lists?

Collaborator (Author): If a labeled frame has ANY user instances in the frame, then we ONLY use the user instances (the frame is already completely labeled). The lists are never mixed.
lf.user_instances if lf.n_user_instances > 0 else lf.predicted_instances
) # Prefer user labeled instances over predicted
if untracked:
for j, inst in enumerate(lf.predicted_instances):
tracks[i, j] = (
inst.points_and_scores_array
if return_confidence
else inst.numpy()
)
# Add instances in any order if untracked
for j, inst in enumerate(lf_insts):
tracks[i, j] = set_track(inst, tracks[i, j], return_confidence)
else:
for inst in lf.tracked_instances:
# Add instances in track-specific order, ignoring instances w/o a track
for inst in lf_insts:
if inst.track is None:
continue
j = self.tracks.index(inst.track)
tracks[i, j] = (
inst.points_and_scores_array
if return_confidence
else inst.numpy()
)
tracks[i, j] = set_track(inst, tracks[i, j], return_confidence)

return tracks
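
The review thread above summarizes the rule this PR introduces: when a labeled frame contains any user instances, only those instances populate that frame's rows, and predicted instances are used only for frames with no user labels; the two lists are never mixed within a frame. Below is a minimal, self-contained sketch of that selection rule; the class stubs and function name are illustrative stand-ins, not part of the SLEAP API.

from typing import Sequence, Union


class Instance:
    """Illustrative stand-in for a user-labeled sleap.instance.Instance."""


class PredictedInstance(Instance):
    """Illustrative stand-in for sleap.instance.PredictedInstance."""


def select_frame_instances(
    user_instances: Sequence[Instance],
    predicted_instances: Sequence[PredictedInstance],
) -> Union[Sequence[Instance], Sequence[PredictedInstance]]:
    """Prefer user-labeled instances; fall back to predictions only if none exist."""
    # Mirrors `lf.user_instances if lf.n_user_instances > 0 else
    # lf.predicted_instances` in the diff above: lists are never mixed per frame.
    return user_instances if len(user_instances) > 0 else predicted_instances

With this rule in place, labels.numpy() fills a frame's track slots from user labels whenever that frame has been manually annotated, which is what the new test in tests/io/test_dataset.py below checks against track slot 0.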

12 changes: 11 additions & 1 deletion tests/io/test_dataset.py
@@ -1332,7 +1332,7 @@ def test_remove_predictions_with_new_labels(removal_test_labels):
assert labels[1].has_predicted_instances


def test_labels_numpy(centered_pair_predictions):
def test_labels_numpy(centered_pair_predictions: Labels):
trx = centered_pair_predictions.numpy(video=None, all_frames=False, untracked=False)
assert trx.shape == (1100, 27, 24, 2)

@@ -1366,6 +1366,16 @@ def test_labels_numpy(centered_pair_predictions):
centered_pair_predictions.tracks = []
assert centered_pair_predictions.numpy(untracked=False).shape == (1100, 0, 24, 2)

# Test labels.numpy prefers user instances
skeleton = centered_pair_predictions.skeleton
lf = centered_pair_predictions.labeled_frames[0]
user_inst = Instance(
skeleton=skeleton, points={node: Point(1, 1) for node in skeleton.nodes}
)
lf.instances.append(user_inst)
labels_np = centered_pair_predictions.numpy(untracked=True, return_confidence=True)
np.testing.assert_array_equal(labels_np[lf.frame_idx, 0, :, :-1], user_inst.numpy())
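
For context on why the assertion above slices off the last column: with return_confidence=True the returned array gains a score column, and the new set_track helper writes only the (x, y) columns for user instances, leaving the score at the NaN fill value because only PredictedInstance carries point scores. A small illustrative example of that layout (the values are hypothetical, matching the Point(1, 1) labels created in the test):

import numpy as np

# Hypothetical (n_nodes, 3) slice of labels.numpy(..., return_confidence=True)
# for a user-labeled instance: columns are (x, y, score).
frame_slice = np.array(
    [
        [1.0, 1.0, np.nan],  # user-labeled node: coordinates set, score left as NaN
        [1.0, 1.0, np.nan],
    ],
    dtype="float32",
)

coords = frame_slice[:, :-1]  # drop the score column, as the assertion above does
print(coords)  # [[1. 1.] [1. 1.]]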


def test_remove_track(centered_pair_predictions):
labels = centered_pair_predictions