Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix formatting to read and save tracking scores #693

Merged
merged 2 commits into from
Apr 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions sleap/io/format/hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@


class LabelsV1Adaptor(format.adaptor.Adaptor):
FORMAT_ID = 1.1
FORMAT_ID = 1.2

# 1.0 points with gridline coordinates, top left corner at (0, 0)
# 1.1 points with midpixel coordinates, top left corner at (-0.5, -0.5)
# 1.2 adds track score to read and write functions

@property
def handles(self):
Expand Down Expand Up @@ -137,14 +138,16 @@ def read(
f = file.file
labels = cls.read_headers(file, video_search, match_to)

format_id = file.format_id

frames_dset = f["frames"][:]
instances_dset = f["instances"][:]
points_dset = f["points"][:]
pred_points_dset = f["pred_points"][:]

# Shift the *non-predicted* points since these used to be saved with a gridline
# coordinate system.
if (file.format_id or 0) < 1.1:
if (format_id or 0) < 1.1:
points_dset[:]["x"] -= 0.5
points_dset[:]["y"] -= 0.5

Expand Down Expand Up @@ -184,6 +187,9 @@ def read(
track=track,
points=pred_points[i["point_id_start"] : i["point_id_end"]],
score=i["score"],
tracking_score=i["tracking_score"]
if (format_id is not None and format_id >= 1.2)
else 0.0,
)
instances.append(instance)

Expand Down Expand Up @@ -349,6 +355,7 @@ def append_unique(old, new):
("score", "f4"),
("point_id_start", "u8"),
("point_id_end", "u8"),
("tracking_score", "f4"),
]
)
frame_dtype = np.dtype(
Expand Down Expand Up @@ -431,9 +438,11 @@ def append_unique(old, new):
if instance_type is PredictedInstance:
score = instance.score
pid = pred_point_id + pred_point_id_offset
tracking_score = instance.tracking_score
else:
score = np.nan
pid = point_id + point_id_offset
tracking_score = np.nan

# Keep track of any from_predicted instance links, we will
# insert the correct instance_id in the dataset after we are
Expand All @@ -453,6 +462,7 @@ def append_unique(old, new):
score,
pid,
pid + len(parray),
tracking_score,
)

# If these are predicted points, copy them to the predicted point
Expand Down
Binary file not shown.
12 changes: 12 additions & 0 deletions tests/fixtures/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
TEST_SLP_MIN_LABELS_ROBOT = "tests/data/slp_hdf5/small_robot_minimal.slp"
TEST_MIN_TRACKS_2NODE_LABELS = "tests/data/tracks/clip.2node.slp"
TEST_MIN_TRACKS_13NODE_LABELS = "tests/data/tracks/clip.slp"
TEST_HDF5_PREDICTIONS = "tests/data/hdf5_format_v1/centered_pair_predictions.h5"
TEST_SLP_PREDICTIONS = "tests/data/hdf5_format_v1/centered_pair_predictions.slp"


@pytest.fixture
Expand Down Expand Up @@ -213,3 +215,13 @@ def multi_skel_vid_labels(hdf5_vid, small_robot_mp4_vid, skeleton, stickman):
labels = Labels(labels)

return labels


@pytest.fixture
def centered_pair_predictions_hdf5_path():
return TEST_HDF5_PREDICTIONS


@pytest.fixture
def centered_pair_predictions_slp_path():
return TEST_SLP_PREDICTIONS
37 changes: 31 additions & 6 deletions tests/io/test_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ def test_missing_file():
disp.read("missing_file.txt")


def test_hdf5_v1(tmpdir):
filename = "tests/data/hdf5_format_v1/centered_pair_predictions.h5"
def test_hdf5_v1(tmpdir, centered_pair_predictions_hdf5_path):
filename = centered_pair_predictions_hdf5_path
disp = dispatch.Dispatch.make_dispatcher(adaptor.SleapObjectType.labels)

# Make sure reading works
Expand All @@ -99,8 +99,9 @@ def test_hdf5_v1(tmpdir):
assert len(y.labeled_frames) == 1100


def test_hdf5_v1_filehandle():
filename = "tests/data/hdf5_format_v1/centered_pair_predictions.h5"
def test_hdf5_v1_filehandle(centered_pair_predictions_hdf5_path):

filename = centered_pair_predictions_hdf5_path

labels = hdf5.LabelsV1Adaptor.read_headers(filehandle.FileHandle(filename))

Expand Down Expand Up @@ -142,11 +143,11 @@ def test_json_v1(tmpdir, centered_pair_labels):
assert len(y.labeled_frames) == len(centered_pair_labels.labeled_frames)


def test_matching_adaptor():
def test_matching_adaptor(centered_pair_predictions_hdf5_path):
from sleap.io.format import read

read(
"tests/data/hdf5_format_v1/centered_pair_predictions.h5",
centered_pair_predictions_hdf5_path,
for_object="labels",
as_format="*",
)
Expand Down Expand Up @@ -188,3 +189,27 @@ def test_madlc(test_data):
assert_array_equal(labels[1][1].numpy(), [[17, 18], [np.nan, np.nan], [20, 21]])
assert_array_equal(labels[2][0].numpy(), [[22, 23], [24, 25], [26, 27]])
assert labels[2].frame_idx == 3


def test_tracking_scores(tmpdir, centered_pair_predictions_slp_path):

# test reading
filename = centered_pair_predictions_slp_path

fh = filehandle.FileHandle(filename)

assert fh.format_id is not None

labels = hdf5.LabelsV1Adaptor.read(fh)

for instance in labels.instances():
assert hasattr(instance, "tracking_score")

# test writing
filename = os.path.join(tmpdir, "test.slp")
labels.save(filename)

labels = hdf5.LabelsV1Adaptor.read(filehandle.FileHandle(filename))

for instance in labels.instances():
assert hasattr(instance, "tracking_score")