Skip to content

Commit

Permalink
Merging is now aware of the source video if available. (#462)
Browse files Browse the repository at this point in the history
* Merging is now aware of the source video if available.
- Save source video instead of package if saving without images
- Match using either current or source video filename when merging
- Integration test
  • Loading branch information
talmo authored Jan 22, 2021
1 parent b06d4c2 commit 8d287b2
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 7 deletions.
11 changes: 11 additions & 0 deletions sleap/io/format/hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,17 @@ def write(

d["videos"] = Video.cattr().unstructure(new_videos)

else:
# Include the source video metadata if this was a package.
new_videos = []
for video in labels.videos:
if hasattr(video.backend, "_source_video"):
new_videos.append(video.backend._source_video)
else:
new_videos.append(video)
d["videos"] = Video.cattr().unstructure(new_videos)


with h5py.File(filename, "a") as f:

# Add all the JSON metadata
Expand Down
31 changes: 24 additions & 7 deletions sleap/io/format/labels_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def write(

# Ensure that filename ends with .json
# shutil will append .zip
filename = re.sub("(\.json)?(\.zip)?$", ".json", filename)
filename = re.sub("(\\.json)?(\\.zip)?$", ".json", filename)

# Write the json to the tmp directory, we will zip it up with the frame data.
full_out_filename = os.path.join(tmp_dir, os.path.basename(filename))
Expand Down Expand Up @@ -408,12 +408,29 @@ def from_json_data(
break
for idx, vid in enumerate(videos):
for old_vid in match_to.videos:
# compare last three parts of path
if vid.filename == old_vid.filename or weak_filename_match(
vid.filename, old_vid.filename
):
# use video from match
videos[idx] = old_vid

# Try to match videos using either their current or source filename
# if available.
old_vid_paths = [old_vid.filename]
if getattr(old_vid.backend, "has_embedded_images", False):
old_vid_paths.append(old_vid.backend._source_video.filename)

new_vid_paths = [vid.filename]
if getattr(vid.backend, "has_embedded_images", False):
new_vid_paths.append(vid.backend._source_video.filename)

is_match = False
for old_vid_path in old_vid_paths:
for new_vid_path in new_vid_paths:
if old_vid_path == new_vid_path or weak_filename_match(
old_vid_path, new_vid_path
):
is_match = True
videos[idx] = old_vid
break
if is_match:
break
if is_match:
break

suggestions = []
Expand Down
61 changes: 61 additions & 0 deletions tests/io/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
from pathlib import Path

import sleap
from sleap.skeleton import Skeleton
from sleap.instance import Instance, Point, LabeledFrame, PredictedInstance, Track
from sleap.io.video import Video, MediaVideo
Expand Down Expand Up @@ -498,6 +499,66 @@ def test_merge_predictions():
assert not extra_b


def test_merge_with_package(min_labels_robot, tmpdir):
# Add a suggestion and save with images.
labels = min_labels_robot
labels.suggestions.append(
sleap.io.dataset.SuggestionFrame(video=labels.video, frame_idx=1)
)
pkg_path = os.path.join(tmpdir, "test.pkg.slp")
assert len(labels.predicted_instances) == 0
labels.save(pkg_path, with_images=True, embed_suggested=True)

# Load package.
labels_pkg = sleap.load_file(pkg_path)
assert isinstance(labels_pkg.video.backend, sleap.io.video.HDF5Video)
assert labels_pkg.video.backend.has_embedded_images
assert isinstance(labels_pkg.video.backend._source_video.backend, sleap.io.video.MediaVideo)
assert len(labels_pkg.predicted_instances) == 0

# Add prediction.
inst = labels_pkg.user_instances[0]
inst_pr = sleap.PredictedInstance.from_pointsarray(
inst.numpy(), skeleton=labels_pkg.skeleton
)
labels_pkg.append(sleap.LabeledFrame(
video=labels_pkg.suggestions[0].video,
frame_idx=labels_pkg.suggestions[0].frame_idx,
instances=[inst_pr])
)

# Save labels without image data.
preds_path = pkg_path + ".predictions.slp"
labels_pkg.save(preds_path)

# Load predicted labels created from package.
labels_pr = sleap.load_file(preds_path)
assert len(labels_pr.predicted_instances) == 1

# Merge with base labels.
base_video_path = labels.video.backend.filename
merged, extra_base, extra_new = sleap.Labels.complex_merge_between(labels, labels_pr)
assert len(labels.videos) == 1
assert labels.video.backend.filename == base_video_path
assert len(labels.predicted_instances) == 1
assert len(extra_base) == 0
assert len(extra_new) == 0
assert labels.predicted_instances[0].frame.frame_idx == 1

# Merge predictions to package instead.
labels_pkg = sleap.load_file(pkg_path)
labels_pr = sleap.load_file(preds_path)
assert len(labels_pkg.predicted_instances) == 0
base_video_path = labels_pkg.video.backend.filename
merged, extra_base, extra_new = sleap.Labels.complex_merge_between(labels_pkg, labels_pr)
assert len(labels_pkg.videos) == 1
assert labels_pkg.video.backend.filename == base_video_path
assert len(labels_pkg.predicted_instances) == 1
assert len(extra_base) == 0
assert len(extra_new) == 0
assert labels_pkg.predicted_instances[0].frame.frame_idx == 1


def skeleton_ids_from_label_instances(labels):
return list(map(id, (lf.instances[0].skeleton for lf in labels.labeled_frames)))

Expand Down

0 comments on commit 8d287b2

Please sign in to comment.