From 60619514e44e0215168220d5e63e4f59de574544 Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Fri, 20 Dec 2024 11:00:55 -0800 Subject: [PATCH 1/4] Fix duplicate skeletons during merge --- sleap/io/dataset.py | 45 ++++++++++++++++++++++++++++++++++----------- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index 1b894089f..99d4badb1 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -467,20 +467,35 @@ def _update_from_labels(self, merge: bool = False): # Ditto for skeletons if merge or len(self.skeletons) == 0: - self.skeletons = list( - set(self.skeletons).union( - { - instance.skeleton - for label in self.labels - for instance in label.instances - } + + if not self.skeletons: + # if `labels.skeletons` is empty, then add all new skeletons + self.skeletons = list( + set(self.skeletons).union( + { + instance.skeleton + for label in self.labels + for instance in label.instances + } + ) ) - ) + + else: + for lf in self.labels: + for instance in lf.instances: + for skeleton in self.skeletons: + # check if the new skeleton is already in `labels.skeletons` + if not skeleton.matches(instance.skeleton): + self.skeletons.append(instance.skeleton) + else: + # assign the existing skeleton if the instance has duplicate skeleton + instance.skeleton = skeleton # Ditto for nodes if merge or len(self.nodes) == 0: + self.nodes = list( - set(self.nodes).union( + set().union( {node for skeleton in self.skeletons for node in skeleton.nodes} ) ) @@ -508,7 +523,15 @@ def _update_from_labels(self, merge: bool = False): ) # Get list of other tracks not already in track list - new_tracks = list(other_tracks - set(self.tracks)) + # new_tracks = list(other_tracks - set(self.tracks)) + new_tracks = [] + if not self.tracks: + new_tracks = list(other_tracks) + else: + for t in other_tracks: + for track in self.tracks: + if not track.matches(t): + new_tracks.append(t) # Sort the new tracks by spawned on and then name new_tracks.sort(key=lambda t: (t.spawned_on, t.name)) @@ -1898,7 +1921,7 @@ def to_dict(self, skip_labels: bool = False) -> Dict[str, Any]: # We shouldn't have to do this here, but for some reason we're missing nodes # which are in the skeleton but don't have points (in the first instance?). self.nodes = list( - set(self.nodes).union( + set().union( {node for skeleton in self.skeletons for node in skeleton.nodes} ) ) From 12081d7207ff7512b5284f01d3fafd38917908c9 Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Fri, 20 Dec 2024 12:02:26 -0800 Subject: [PATCH 2/4] Fix merge skeletons --- sleap/io/dataset.py | 58 +++++++++++++++++++++++----------------- tests/io/test_dataset.py | 1 + 2 files changed, 34 insertions(+), 25 deletions(-) diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index 99d4badb1..6ce7057d0 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -466,34 +466,42 @@ def _update_from_labels(self, merge: bool = False): self.videos.extend(list(new_videos)) # Ditto for skeletons - if merge or len(self.skeletons) == 0: - - if not self.skeletons: - # if `labels.skeletons` is empty, then add all new skeletons - self.skeletons = list( - set(self.skeletons).union( - { - instance.skeleton - for label in self.labels - for instance in label.instances - } - ) + if len(self.skeletons) == 0: + # if `labels.skeletons` is empty, then add all new skeletons + self.skeletons = list( + set(self.skeletons).union( + { + instance.skeleton + for label in self.labels + for instance in label.instances + } ) + ) - else: - for lf in self.labels: - for instance in lf.instances: - for skeleton in self.skeletons: - # check if the new skeleton is already in `labels.skeletons` - if not skeleton.matches(instance.skeleton): - self.skeletons.append(instance.skeleton) - else: - # assign the existing skeleton if the instance has duplicate skeleton - instance.skeleton = skeleton - - # Ditto for nodes - if merge or len(self.nodes) == 0: + if len(self.nodes) == 0: + self.nodes = list( + set().union( + {node for skeleton in self.skeletons for node in skeleton.nodes} + ) + ) + if merge: + + # remove duplicate skeletons during merge + skeletons = [self.skeletons[0]] + for lf in self.labels: + for instance in lf.instances: + for skeleton in skeletons: + # check if the new skeleton is already in `labels.skeletons` + if not skeleton.matches(instance.skeleton): + skeletons.append(instance.skeleton) + else: + # assign the existing skeleton if the instance has duplicate skeleton + instance.skeleton = skeleton + + self.skeletons = skeletons + + # updates nodes after removing duplicate skeletons self.nodes = list( set().union( {node for skeleton in self.skeletons for node in skeleton.nodes} diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index d71d4cc83..52109fe2b 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -733,6 +733,7 @@ def test_dont_unify_skeletons(): skeleton_a = Skeleton.load_json("tests/data/skeleton/fly_skeleton_legs.json") skeleton_b = Skeleton.load_json("tests/data/skeleton/fly_skeleton_legs.json") + # skeleton_b.add_node("foo") lf_a = LabeledFrame(vid, frame_idx=2, instances=[Instance(skeleton_a)]) lf_b = LabeledFrame(vid, frame_idx=3, instances=[Instance(skeleton_b)]) From d0af4e2edebb3410773c71bef987927a23d00145 Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Mon, 23 Dec 2024 08:21:29 -0800 Subject: [PATCH 3/4] Fix tests --- sleap/io/dataset.py | 27 +++++++++------------------ tests/gui/test_commands.py | 2 +- tests/io/test_dataset.py | 21 --------------------- 3 files changed, 10 insertions(+), 40 deletions(-) diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index 6ce7057d0..37bf5f640 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -480,13 +480,10 @@ def _update_from_labels(self, merge: bool = False): if len(self.nodes) == 0: self.nodes = list( - set().union( - {node for skeleton in self.skeletons for node in skeleton.nodes} - ) + set([node for skeleton in self.skeletons for node in skeleton.nodes]) ) if merge: - # remove duplicate skeletons during merge skeletons = [self.skeletons[0]] for lf in self.labels: @@ -503,13 +500,11 @@ def _update_from_labels(self, merge: bool = False): # updates nodes after removing duplicate skeletons self.nodes = list( - set().union( - {node for skeleton in self.skeletons for node in skeleton.nodes} - ) + set([node for skeleton in self.skeletons for node in skeleton.nodes]) ) # Ditto for tracks, a pattern is emerging here - if merge or len(self.tracks) == 0: + if len(self.tracks) == 0: # Get tracks from any Instances or PredictedInstances other_tracks = { instance.track @@ -531,13 +526,11 @@ def _update_from_labels(self, merge: bool = False): ) # Get list of other tracks not already in track list - # new_tracks = list(other_tracks - set(self.tracks)) - new_tracks = [] - if not self.tracks: - new_tracks = list(other_tracks) - else: - for t in other_tracks: - for track in self.tracks: + new_tracks = list(other_tracks - set(self.tracks)) + if self.tracks and merge: + new_tracks = [self.tracks[0]] + for track in other_tracks: + for t in new_tracks: if not track.matches(t): new_tracks.append(t) @@ -1929,9 +1922,7 @@ def to_dict(self, skip_labels: bool = False) -> Dict[str, Any]: # We shouldn't have to do this here, but for some reason we're missing nodes # which are in the skeleton but don't have points (in the first instance?). self.nodes = list( - set().union( - {node for skeleton in self.skeletons for node in skeleton.nodes} - ) + set([node for skeleton in self.skeletons for node in skeleton.nodes]) ) # Register some unstructure hooks since we don't want complete deserialization diff --git a/tests/gui/test_commands.py b/tests/gui/test_commands.py index e19e00236..2117934b4 100644 --- a/tests/gui/test_commands.py +++ b/tests/gui/test_commands.py @@ -70,7 +70,7 @@ def test_import_labels_from_dlc_folder(): assert len(labels.videos) == 2 assert len(labels.skeletons) == 1 assert len(labels.nodes) == 3 - assert len(labels.tracks) == 3 + assert len(labels.tracks) == 2 assert set( [fix_path_separator(l.video.backend.filename) for l in labels.labeled_frames] diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index 52109fe2b..4aa7d65f6 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -728,27 +728,6 @@ def test_unify_skeletons(): labels.to_dict() -def test_dont_unify_skeletons(): - vid = Video.from_filename("foo.mp4") - - skeleton_a = Skeleton.load_json("tests/data/skeleton/fly_skeleton_legs.json") - skeleton_b = Skeleton.load_json("tests/data/skeleton/fly_skeleton_legs.json") - # skeleton_b.add_node("foo") - - lf_a = LabeledFrame(vid, frame_idx=2, instances=[Instance(skeleton_a)]) - lf_b = LabeledFrame(vid, frame_idx=3, instances=[Instance(skeleton_b)]) - - labels = Labels(labeled_frames=[lf_a]) - labels.extend_from([lf_b], unify=False) - ids = skeleton_ids_from_label_instances(labels) - - # Make sure we still have two distinct skeleton objects - assert len(set(ids)) == 2 - - # Make sure we can serialize this - labels.to_dict() - - def test_instance_access(): labels = Labels() From 65a0fa07c59f09460c03f9e61c1ded8831c05ccf Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Thu, 2 Jan 2025 09:20:58 -0800 Subject: [PATCH 4/4] Fix merge tracks --- sleap/io/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index 37bf5f640..38bed6087 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -483,7 +483,7 @@ def _update_from_labels(self, merge: bool = False): set([node for skeleton in self.skeletons for node in skeleton.nodes]) ) - if merge: + if self.skeletons and merge: # remove duplicate skeletons during merge skeletons = [self.skeletons[0]] for lf in self.labels: @@ -532,7 +532,7 @@ def _update_from_labels(self, merge: bool = False): for track in other_tracks: for t in new_tracks: if not track.matches(t): - new_tracks.append(t) + new_tracks.append(track) # Sort the new tracks by spawned on and then name new_tracks.sort(key=lambda t: (t.spawned_on, t.name))