Skip to content

Commit

Permalink
add tests for Labels and LabeledFrames (#460)
Browse files Browse the repository at this point in the history
* Add tests
- test `Labels.remove_untracked_instances()` for both cases of `remove_empty_frames: bool`
-test `LabeledFrames.remove_untracked()` for both user-labeled and predicted frames
  • Loading branch information
roomrys committed Mar 18, 2022
1 parent fdb2fe6 commit e6dbdfd
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 0 deletions.
38 changes: 38 additions & 0 deletions tests/io/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1420,3 +1420,41 @@ def test_split(centered_pair_predictions):
assert len(labels_a) == 1
assert len(labels_b) == 1
assert labels_a[0] == labels_b[0]


def test_remove_untracked_instances(min_tracks_2node_labels):
"""Test removal of untracked instances and empty frames.
Args:
min_tracks_2node_labels: Labels object which contains user labeled frames with
tracked instances.
"""
# XXX(LM): only test user_labeled instances
# XXX(LM): can we ensure that the datasets will remain unchanged?
labels = min_tracks_2node_labels

# XXX(LM): should I remove multiple tracks and frames?
# XXX(LM): if len(labels.labeled_frames)==1,
# XXX(LM) then will not properly test lf.remove_untracked()
# Preprocessing
labels.labeled_frames[0].instances[0].track = None
labels.labeled_frames[-1].instances = []
assert any(
[(inst.track is None) for lf in labels.labeled_frames for inst in lf.instances]
)
assert any([(len(lf.instances) == 0) for lf in labels.labeled_frames])

# Test function with remove_empty_frames=False
labels.remove_untracked_instances(remove_empty_frames=False)
assert all(
[
(inst.track is not None)
for lf in labels.labeled_frames
for inst in lf.instances
]
)
assert any([(len(lf.instances) == 0) for lf in labels.labeled_frames])

# Test function with remove_empty_frames=True
labels.remove_untracked_instances(remove_empty_frames=True)
assert all([(len(lf.instances) > 0) for lf in labels.labeled_frames])
30 changes: 30 additions & 0 deletions tests/test_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,3 +477,33 @@ def test_labeledframe_instance_counting(min_labels, centered_pair_predictions):
assert lf.n_tracked_instances == 0
assert len(lf.tracked_instances) == 0
assert not lf.has_tracked_instances


def test_labeledframe_remove_untracked(
min_tracks_2node_labels: "Labels", centered_pair_predictions: "Labels"
):
"""Test removal of untracked instances on both user-labeled and predicted frames.
Args:
min_tracks_2node_labels: Labels object which contains user labeled frames with tracked instances.
centered_pair_predictions: Labels object which contains predicted frames with tracked instances.
"""
# Load user-labeled frames.
lf = min_tracks_2node_labels.labeled_frames[0]
assert any([type(inst) == Instance for inst in lf.instances])

lf.instances[0].track = None
assert any([(inst.track is None) for inst in lf.instances])

lf.remove_untracked()
assert all([(inst.track is not None) for inst in lf.instances])

# Load predicted frames.
lf = centered_pair_predictions.labeled_frames[0]
assert any([type(inst) == PredictedInstance for inst in lf.instances])

lf.instances[0].track = None
assert any([(inst.track is None) for inst in lf.instances])

lf.remove_untracked()
assert all([(inst.track is not None) for inst in lf.instances])

0 comments on commit e6dbdfd

Please sign in to comment.