Skip to content

Commit

Permalink
Add formatting fixtures, tracking_scores test
Browse files Browse the repository at this point in the history
  • Loading branch information
sheridana committed Mar 31, 2022
1 parent 9506b1c commit 7bb4f9f
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 6 deletions.
Binary file not shown.
12 changes: 12 additions & 0 deletions tests/fixtures/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
TEST_SLP_MIN_LABELS_ROBOT = "tests/data/slp_hdf5/small_robot_minimal.slp"
TEST_MIN_TRACKS_2NODE_LABELS = "tests/data/tracks/clip.2node.slp"
TEST_MIN_TRACKS_13NODE_LABELS = "tests/data/tracks/clip.slp"
TEST_HDF5_PREDICTIONS = "tests/data/hdf5_format_v1/centered_pair_predictions.h5"
TEST_SLP_PREDICTIONS = "tests/data/hdf5_format_v1/centered_pair_predictions.slp"


@pytest.fixture
Expand Down Expand Up @@ -213,3 +215,13 @@ def multi_skel_vid_labels(hdf5_vid, small_robot_mp4_vid, skeleton, stickman):
labels = Labels(labels)

return labels


@pytest.fixture
def centered_pair_predictions_hdf5_path():
return TEST_HDF5_PREDICTIONS


@pytest.fixture
def centered_pair_predictions_slp_path():
return TEST_SLP_PREDICTIONS
37 changes: 31 additions & 6 deletions tests/io/test_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ def test_missing_file():
disp.read("missing_file.txt")


def test_hdf5_v1(tmpdir):
filename = "tests/data/hdf5_format_v1/centered_pair_predictions.h5"
def test_hdf5_v1(tmpdir, centered_pair_predictions_hdf5_path):
filename = centered_pair_predictions_hdf5_path
disp = dispatch.Dispatch.make_dispatcher(adaptor.SleapObjectType.labels)

# Make sure reading works
Expand All @@ -99,8 +99,9 @@ def test_hdf5_v1(tmpdir):
assert len(y.labeled_frames) == 1100


def test_hdf5_v1_filehandle():
filename = "tests/data/hdf5_format_v1/centered_pair_predictions.h5"
def test_hdf5_v1_filehandle(centered_pair_predictions_hdf5_path):

filename = centered_pair_predictions_hdf5_path

labels = hdf5.LabelsV1Adaptor.read_headers(filehandle.FileHandle(filename))

Expand Down Expand Up @@ -142,11 +143,11 @@ def test_json_v1(tmpdir, centered_pair_labels):
assert len(y.labeled_frames) == len(centered_pair_labels.labeled_frames)


def test_matching_adaptor():
def test_matching_adaptor(centered_pair_predictions_hdf5_path):
from sleap.io.format import read

read(
"tests/data/hdf5_format_v1/centered_pair_predictions.h5",
centered_pair_predictions_hdf5_path,
for_object="labels",
as_format="*",
)
Expand Down Expand Up @@ -188,3 +189,27 @@ def test_madlc(test_data):
assert_array_equal(labels[1][1].numpy(), [[17, 18], [np.nan, np.nan], [20, 21]])
assert_array_equal(labels[2][0].numpy(), [[22, 23], [24, 25], [26, 27]])
assert labels[2].frame_idx == 3


def test_tracking_scores(tmpdir, centered_pair_predictions_slp_path):

# test reading
filename = centered_pair_predictions_slp_path

fh = filehandle.FileHandle(filename)

assert fh.format_id is not None

labels = hdf5.LabelsV1Adaptor.read(fh)

for instance in labels.instances():
assert hasattr(instance, "tracking_score")

# test writing
filename = os.path.join(tmpdir, "test.slp")
labels.save(filename)

labels = hdf5.LabelsV1Adaptor.read(filehandle.FileHandle(filename))

for instance in labels.instances():
assert hasattr(instance, "tracking_score")

0 comments on commit 7bb4f9f

Please sign in to comment.