From ebfc47b300a61bf1a649d8f5eed780039318d73c Mon Sep 17 00:00:00 2001 From: Liezl Maree <38435167+roomrys@users.noreply.github.com> Date: Wed, 26 Jun 2024 10:36:24 -0700 Subject: [PATCH] Add `InstancesList` class to handle backref to `LabeledFrame` (#1807) * Add InstancesList class to handle backref to LabeledFrame * Register structure/unstructure hooks for InstancesList * Add tests for the InstanceList class * Handle case where instance are passed in but labeled_frame is None * Add tests relevant methods in LabeledFrame * Delegate setting frame to InstancesList * Add test for PredictedInstance.frame after complex merge * Add todo comment to not use Instance.frame * Add rtest for InstasnceList.remove * Use normal list for informative `merged_instances` * Add test for copy and clear * Add copy and clear methods, use normal lists in merge method --- sleap/instance.py | 167 ++++++++++++++++++++++++----- tests/test_instance.py | 231 +++++++++++++++++++++++++++++++++++++++-- 2 files changed, 363 insertions(+), 35 deletions(-) diff --git a/sleap/instance.py b/sleap/instance.py index c14038552..67e96f330 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -364,7 +364,7 @@ class Instance: from_predicted: Optional["PredictedInstance"] = attr.ib(default=None) _points: PointArray = attr.ib(default=None) _nodes: List = attr.ib(default=None) - frame: Union["LabeledFrame", None] = attr.ib(default=None) + frame: Union["LabeledFrame", None] = attr.ib(default=None) # TODO(LM): Make private # The underlying Point array type that this instances point array should be. _point_array_type = PointArray @@ -1214,6 +1214,9 @@ def unstructure_instance(x: Instance): converter.register_unstructure_hook(Instance, unstructure_instance) converter.register_unstructure_hook(PredictedInstance, unstructure_instance) + converter.register_unstructure_hook( + InstancesList, lambda x: [converter.unstructure(inst) for inst in x] + ) ## STRUCTURE HOOKS @@ -1247,6 +1250,7 @@ def structure_instances_list(x, type): converter.register_structure_hook( Union[List[Instance], List[PredictedInstance]], structure_instances_list ) + converter.register_structure_hook(InstancesList, structure_instances_list) # Structure forward reference for PredictedInstance for the Instance.from_predicted # attribute. @@ -1278,6 +1282,127 @@ def structure_point_array(x, t): return converter +class InstancesList(list): + """A list of `Instance`s associated with a `LabeledFrame`. + + This class should only be used for the `LabeledFrame.instances` attribute. + """ + + def __init__(self, *args, labeled_frame: Optional["LabeledFrame"] = None): + super(InstancesList, self).__init__(*args) + + # Set the labeled frame for each instance + self.labeled_frame = labeled_frame + + @property + def labeled_frame(self) -> "LabeledFrame": + """Return the `LabeledFrame` associated with this list of instances.""" + + return self._labeled_frame + + @labeled_frame.setter + def labeled_frame(self, labeled_frame: "LabeledFrame"): + """Set the `LabeledFrame` associated with this list of instances. + + This updates the `frame` attribute on each instance. + + Args: + labeled_frame: The `LabeledFrame` to associate with this list of instances. + """ + + try: + # If the labeled frame is the same as the one we're setting, then skip + if self._labeled_frame == labeled_frame: + return + except AttributeError: + # Only happens on init and updates each instance.frame (even if None) + pass + + # Otherwise, update the frame for each instance + self._labeled_frame = labeled_frame + for instance in self: + instance.frame = labeled_frame + + def append(self, instance: Union[Instance, PredictedInstance]): + """Append an `Instance` or `PredictedInstance` to the list, setting the frame. + + Args: + item: The `Instance` or `PredictedInstance` to append to the list. + """ + + if not isinstance(instance, (Instance, PredictedInstance)): + raise ValueError( + f"InstancesList can only contain Instance or PredictedInstance objects," + f" but got {type(instance)}." + ) + instance.frame = self.labeled_frame + super().append(instance) + + def extend(self, instances: List[Union[PredictedInstance, Instance]]): + """Extend the list with a list of `Instance`s or `PredictedInstance`s. + + Args: + instances: A list of `Instance` or `PredictedInstance` objects to add to the + list. + + Returns: + None + """ + for instance in instances: + self.append(instance) + + def __delitem__(self, index): + """Remove instance (by index), and set instance.frame to None.""" + + instance: Instance = self.__getitem__(index) + super().__delitem__(index) + + # Modify the instance to remove reference to the frame + instance.frame = None + + def insert(self, index: int, instance: Union[Instance, PredictedInstance]) -> None: + super().insert(index, instance) + instance.frame = self.labeled_frame + + def __setitem__(self, index, instance: Union[Instance, PredictedInstance]): + """Set nth instance in frame to the given instance. + + Args: + index: The index of instance to replace with new instance. + value: The new instance to associate with frame. + + Returns: + None. + """ + super().__setitem__(index, instance) + instance.frame = self.labeled_frame + + def pop(self, index: int) -> Union[Instance, PredictedInstance]: + """Remove and return instance at index, setting instance.frame to None.""" + + instance = super().pop(index) + instance.frame = None + return instance + + def remove(self, instance: Union[Instance, PredictedInstance]) -> None: + """Remove instance from list, setting instance.frame to None.""" + super().remove(instance) + instance.frame = None + + def clear(self) -> None: + """Remove all instances from list, setting instance.frame to None.""" + for instance in self: + instance.frame = None + super().clear() + + def copy(self) -> list: + """Return a shallow copy of the list of instances as a list. + + Note: This will not return an `InstancesList` object, but a normal list. + """ + return list(self) + + @attr.s(auto_attribs=True, eq=False, repr=False, str=False) class LabeledFrame: """Holds labeled data for a single frame of a video. @@ -1290,9 +1415,7 @@ class LabeledFrame: video: Video = attr.ib() frame_idx: int = attr.ib(converter=int) - _instances: Union[List[Instance], List[PredictedInstance]] = attr.ib( - default=attr.Factory(list) - ) + _instances: InstancesList = attr.ib(default=attr.Factory(InstancesList)) def __attrs_post_init__(self): """Called by attrs. @@ -1302,8 +1425,7 @@ def __attrs_post_init__(self): """ # Make sure all instances have a reference to this frame - for instance in self.instances: - instance.frame = self + self.instances = self._instances def __len__(self) -> int: """Return number of instances associated with frame.""" @@ -1319,13 +1441,8 @@ def index(self, value: Instance) -> int: def __delitem__(self, index): """Remove instance (by index) from frame.""" - value = self.instances.__getitem__(index) - self.instances.__delitem__(index) - # Modify the instance to remove reference to this frame - value.frame = None - def __repr__(self) -> str: """Return a readable representation of the LabeledFrame.""" return ( @@ -1348,9 +1465,6 @@ def insert(self, index: int, value: Instance): """ self.instances.insert(index, value) - # Modify the instance to have a reference back to this frame - value.frame = self - def __setitem__(self, index, value: Instance): """Set nth instance in frame to the given instance. @@ -1363,9 +1477,6 @@ def __setitem__(self, index, value: Instance): """ self.instances.__setitem__(index, value) - # Modify the instance to have a reference back to this frame - value.frame = self - def find( self, track: Optional[Union[Track, int]] = -1, user: bool = False ) -> List[Instance]: @@ -1393,7 +1504,7 @@ def instances(self) -> List[Instance]: return self._instances @instances.setter - def instances(self, instances: List[Instance]): + def instances(self, instances: Union[InstancesList, List[Instance]]): """Set the list of instances associated with this frame. Updates the `frame` attribute on each instance to the @@ -1408,9 +1519,11 @@ def instances(self, instances: List[Instance]): None """ - # Make sure to set the frame for each instance to this LabeledFrame - for instance in instances: - instance.frame = self + # Make sure to set the LabeledFrame for each instance to this frame + if isinstance(instances, InstancesList): + instances.labeled_frame = self + else: + instances = InstancesList(instances, labeled_frame=self) self._instances = instances @@ -1685,22 +1798,20 @@ def complex_frame_merge( * list of conflicting instances from base * list of conflicting instances from new """ - merged_instances = [] - redundant_instances = [] - extra_base_instances = copy(base_frame.instances) - extra_new_instances = [] + merged_instances: List[Instance] = [] # Only used for informing user + redundant_instances: List[Instance] = [] + extra_base_instances: List[Instance] = list(base_frame.instances) + extra_new_instances: List[Instance] = [] for new_inst in new_frame: redundant = False for base_inst in base_frame.instances: if new_inst.matches(base_inst): - base_inst.frame = None extra_base_instances.remove(base_inst) redundant_instances.append(base_inst) redundant = True continue if not redundant: - new_inst.frame = None extra_new_instances.append(new_inst) conflict = False @@ -1732,7 +1843,7 @@ def complex_frame_merge( else: # No conflict, so include all instances in base base_frame.instances.extend(extra_new_instances) - merged_instances = copy(extra_new_instances) + merged_instances: List[Instance] = copy(extra_new_instances) extra_base_instances = [] extra_new_instances = [] diff --git a/tests/test_instance.py b/tests/test_instance.py index 74a8b192e..58a630a8b 100644 --- a/tests/test_instance.py +++ b/tests/test_instance.py @@ -1,19 +1,21 @@ -import os -import math import copy +import math +import os +from typing import List -import pytest import numpy as np +import pytest -from sleap.skeleton import Skeleton +from sleap import Labels from sleap.instance import ( Instance, - PredictedInstance, + InstancesList, + LabeledFrame, Point, + PredictedInstance, PredictedPoint, - LabeledFrame, ) -from sleap import Labels +from sleap.skeleton import Skeleton def test_instance_node_get_set_item(skeleton): @@ -310,6 +312,8 @@ def test_frame_merge_predicted_and_user(skeleton, centered_pair_vid): # and we want to retain both even though they perfectly match. assert user_inst in user_frame.instances assert pred_inst in user_frame.instances + assert user_inst.frame == user_frame + assert pred_inst.frame == user_frame assert len(user_frame.instances) == 2 @@ -529,3 +533,216 @@ def test_instance_structuring_from_predicted(centered_pair_predictions): # Unstructure -> structure labels_copy = labels.copy() + + +def test_instances_list(centered_pair_predictions): + + labels = centered_pair_predictions + + def test_extend(instances: InstancesList, list_of_instances: List[Instance]): + instances.extend(list_of_instances) + assert len(instances) == len(list_of_instances) + for instance in instances: + assert isinstance(instance, PredictedInstance) + if instances.labeled_frame is None: + assert instance.frame is None + else: + assert instance.frame == instances.labeled_frame + + def test_append(instances: InstancesList, instance: Instance): + prev_len = len(instances) + instances.append(instance) + assert len(instances) == prev_len + 1 + assert instances[-1] == instance + assert instance.frame == instances.labeled_frame + + def test_labeled_frame_setter( + instances: InstancesList, labeled_frame: LabeledFrame + ): + instances.labeled_frame = labeled_frame + for instance in instances: + assert instance.frame == labeled_frame + + # Case 1: Create an empty instances list + labeled_frame = labels.labeled_frames[0] + list_of_instances = list(labeled_frame.instances) + instances = InstancesList() + assert len(instances) == 0 + assert instances._labeled_frame is None + assert instances.labeled_frame is None + + # Extend instances list + assert not isinstance(list_of_instances, InstancesList) + assert isinstance(list_of_instances, list) + test_extend(instances, list_of_instances) + + # Set the labeled frame + test_labeled_frame_setter(instances, labeled_frame) + + # Case 2: Create an empy instances list but initialize the labeled frame + instances = InstancesList(labeled_frame=labeled_frame) + assert len(instances) == 0 + assert instances._labeled_frame == labeled_frame + assert instances.labeled_frame == labeled_frame + + # Extend instances to the list from a different labeled frame + labeled_frame = labels.labeled_frames[1] + list_of_instances = list(labeled_frame.instances) + test_extend(instances, list_of_instances) + + # Add instance to the list + instance = list_of_instances[0] + instance.frame = None + test_append(instances, instance) + + # Set the labeled frame + test_labeled_frame_setter(instances, labeled_frame) + + # Test InstancesList.copy + instances_copy = instances.copy() + assert len(instances_copy) == len(instances) + assert not isinstance(instances_copy, InstancesList) + assert isinstance(instances_copy, list) + + # Test InstancesList.clear + instances_in_instances = list(instances) + instances.clear() + assert len(instances) == 0 + for instance in instances_in_instances: + assert instance.frame is None + + # Case 3: Create an instances list with a list of instances + labeled_frame = labels.labeled_frames[0] + list_of_instances = list(labeled_frame.instances) + instances = InstancesList(list_of_instances) + assert len(instances) == len(list_of_instances) + assert instances._labeled_frame is None + assert instances.labeled_frame is None + for instance in instances: + assert instance.frame is None + + # Add instance to the list + instance = list_of_instances[0] + test_append(instances, instance) + + # Case 4: Create an instances list with a list of instances and initialize the frame + labeled_frame_1 = labels.labeled_frames[0] + labeled_frame_2 = labels.labeled_frames[1] + list_of_instances = list(labeled_frame_2.instances) + instances = InstancesList(list_of_instances, labeled_frame=labeled_frame_1) + assert len(instances) == len(list_of_instances) + assert instances._labeled_frame == labeled_frame + assert instances.labeled_frame == labeled_frame + for instance in instances: + assert instance.frame == labeled_frame + + # Test InstancesList.__delitem__ + instance_to_remove = instances[0] + del instances[0] + assert instance_to_remove not in instances + assert instance_to_remove.frame is None + + # Test InstancesList.insert + instances.insert(0, instance_to_remove) + assert instances[0] == instance_to_remove + assert instance_to_remove.frame == instances.labeled_frame + + # Test InstancesList.__setitem__ + new_instance = labeled_frame_1.instances[0] + new_instance.frame = None + instances[0] = new_instance + assert instances[0] == new_instance + assert new_instance.frame == instances.labeled_frame + + # Test InstancesList.pop + popped_instance = instances.pop(0) + assert popped_instance.frame is None + + # Test InstancesList.remove + instance_to_remove = instances[0] + instances.remove(instance_to_remove) + assert instance_to_remove.frame is None + assert instance_to_remove not in instances + + # Case 5: Create an instances list from an instances list + instances_1 = InstancesList(list_of_instances, labeled_frame=labeled_frame_1) + instances = InstancesList(instances_1) + assert len(instances) == len(instances_1) + assert instances._labeled_frame is None + assert instances.labeled_frame is None + for instance in instances: + assert instance.frame is None + + +def test_instances_list_with_labeled_frame(centered_pair_predictions): + labels: Labels = centered_pair_predictions + labels_lf_0: LabeledFrame = labels.labeled_frames[0] + video = labels_lf_0.video + frame_idx = labels_lf_0.frame_idx + + def test_post_init(labeled_frame: LabeledFrame): + for instance in labeled_frame.instances: + assert instance.frame == labeled_frame + + # Create labeled frame from list of instances + instances = list(labels_lf_0.instances) + for instance in instances: + instance.frame = None # Change frame to None to test if it is set correctly + labeled_frame = LabeledFrame(video=video, frame_idx=frame_idx, instances=instances) + assert isinstance(labeled_frame.instances, InstancesList) + assert len(labeled_frame.instances) == len(instances) + test_post_init(labeled_frame) + + # Create labeled frame from instances list + instances = InstancesList(labels_lf_0.instances) + labeled_frame = LabeledFrame(video=video, frame_idx=frame_idx, instances=instances) + assert isinstance(labeled_frame.instances, InstancesList) + assert len(labeled_frame.instances) == len(instances) + test_post_init(labeled_frame) + + # Test LabeledFrame.__len__ + assert len(labeled_frame.instances) == len(instances) + + # Test LabeledFrame.__getitem__ + assert labeled_frame[0] == instances[0] + + # Test LabeledFrame.index + assert labeled_frame.index(instances[0]) == instances.index(instances[0]) == 0 + + # Test LabeledFrame.__delitem__ + instance_to_remove = labeled_frame[0] + del labeled_frame[0] + assert instance_to_remove not in labeled_frame.instances + assert instance_to_remove.frame is None + + # Test LabeledFrame.__repr__ + print(labeled_frame) + + # Test LabeledFrame.insert + labeled_frame.insert(0, instance_to_remove) + assert labeled_frame[0] == instance_to_remove + assert instance_to_remove.frame == labeled_frame + + # Test LabeledFrame.__setitem__ + new_instance = instances[1] + new_instance.frame = None + labeled_frame[0] = new_instance + assert labeled_frame[0] == new_instance + assert new_instance.frame == labeled_frame + + # Test instances.setter (empty list) + labeled_frame.instances = [] + assert len(labeled_frame.instances) == 0 + assert labeled_frame.instances.labeled_frame == labeled_frame + # Test instances.setter (InstancesList) + labeled_frame.instances = labels.labeled_frames[1].instances + assert len(labeled_frame.instances) == len(labels.labeled_frames[1].instances) + assert labeled_frame.instances.labeled_frame == labeled_frame + for instance in labeled_frame.instances: + assert instance.frame == labeled_frame + # Test instances.setter (populated list) + labeled_frame.instances = list(labels.labeled_frames[1].instances) + assert len(labeled_frame.instances) == len(labels.labeled_frames[1].instances) + assert labeled_frame.instances.labeled_frame == labeled_frame + for instance in labeled_frame.instances: + assert instance.frame == labeled_frame