Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix duplicate skeletons during labels merge #2075

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions sleap/io/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,8 @@ def _update_from_labels(self, merge: bool = False):
self.videos.extend(list(new_videos))

# Ditto for skeletons
if merge or len(self.skeletons) == 0:
if len(self.skeletons) == 0:
# if `labels.skeletons` is empty, then add all new skeletons
gitttt-1234 marked this conversation as resolved.
Show resolved Hide resolved
self.skeletons = list(
set(self.skeletons).union(
{
Expand All @@ -477,16 +478,33 @@ def _update_from_labels(self, merge: bool = False):
)
)

# Ditto for nodes
if merge or len(self.nodes) == 0:
if len(self.nodes) == 0:
self.nodes = list(
set(self.nodes).union(
{node for skeleton in self.skeletons for node in skeleton.nodes}
)
set([node for skeleton in self.skeletons for node in skeleton.nodes])
)

if self.skeletons and merge:
# remove duplicate skeletons during merge
skeletons = [self.skeletons[0]]
for lf in self.labels:
for instance in lf.instances:
for skeleton in skeletons:
# check if the new skeleton is already in `labels.skeletons`
if not skeleton.matches(instance.skeleton):
skeletons.append(instance.skeleton)
else:
# assign the existing skeleton if the instance has duplicate skeleton
instance.skeleton = skeleton

self.skeletons = skeletons

# updates nodes after removing duplicate skeletons
self.nodes = list(
set([node for skeleton in self.skeletons for node in skeleton.nodes])
)

# Ditto for tracks, a pattern is emerging here
if merge or len(self.tracks) == 0:
if len(self.tracks) == 0:
# Get tracks from any Instances or PredictedInstances
other_tracks = {
instance.track
Expand All @@ -509,6 +527,12 @@ def _update_from_labels(self, merge: bool = False):

# Get list of other tracks not already in track list
new_tracks = list(other_tracks - set(self.tracks))
if self.tracks and merge:
new_tracks = [self.tracks[0]]
for track in other_tracks:
for t in new_tracks:
if not track.matches(t):
new_tracks.append(track)

# Sort the new tracks by spawned on and then name
new_tracks.sort(key=lambda t: (t.spawned_on, t.name))
Expand Down Expand Up @@ -1898,9 +1922,7 @@ def to_dict(self, skip_labels: bool = False) -> Dict[str, Any]:
# We shouldn't have to do this here, but for some reason we're missing nodes
# which are in the skeleton but don't have points (in the first instance?).
self.nodes = list(
set(self.nodes).union(
{node for skeleton in self.skeletons for node in skeleton.nodes}
)
set([node for skeleton in self.skeletons for node in skeleton.nodes])
)

# Register some unstructure hooks since we don't want complete deserialization
Expand Down
2 changes: 1 addition & 1 deletion tests/gui/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_import_labels_from_dlc_folder():
assert len(labels.videos) == 2
assert len(labels.skeletons) == 1
assert len(labels.nodes) == 3
assert len(labels.tracks) == 3
assert len(labels.tracks) == 2

assert set(
[fix_path_separator(l.video.backend.filename) for l in labels.labeled_frames]
Expand Down
20 changes: 0 additions & 20 deletions tests/io/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,26 +728,6 @@ def test_unify_skeletons():
labels.to_dict()


def test_dont_unify_skeletons():
vid = Video.from_filename("foo.mp4")

skeleton_a = Skeleton.load_json("tests/data/skeleton/fly_skeleton_legs.json")
skeleton_b = Skeleton.load_json("tests/data/skeleton/fly_skeleton_legs.json")

lf_a = LabeledFrame(vid, frame_idx=2, instances=[Instance(skeleton_a)])
lf_b = LabeledFrame(vid, frame_idx=3, instances=[Instance(skeleton_b)])

labels = Labels(labeled_frames=[lf_a])
labels.extend_from([lf_b], unify=False)
ids = skeleton_ids_from_label_instances(labels)

# Make sure we still have two distinct skeleton objects
assert len(set(ids)) == 2

# Make sure we can serialize this
labels.to_dict()


def test_instance_access():
labels = Labels()

Expand Down
Loading