diff --git a/sleap/gui/dataviews.py b/sleap/gui/dataviews.py index 230d8c13a..7cde31c63 100644 --- a/sleap/gui/dataviews.py +++ b/sleap/gui/dataviews.py @@ -691,16 +691,18 @@ class InstanceGroupTableModel(GenericTableModel): item: 'InstanceGroup' which has information about the instance group """ - properties = ("name", "frame index", "cameras", "instances") + properties = ("name", "score", "frame index", "cameras", "instances") def item_to_data(self, obj, item: InstanceGroup): - return { + data = { "name": item.name, + "score": "" if item.score is None else str(round(item.score, 2)), "frame index": item.frame_idx, "cameras": len(item.camera_cluster.cameras), "instances": len(item.instances), } + return data def get_item_color(self, instance_group: InstanceGroup, key: str): color_manager = self.context.app.color_manager diff --git a/sleap/io/cameras.py b/sleap/io/cameras.py index b6d09b09c..d7f0edf47 100644 --- a/sleap/io/cameras.py +++ b/sleap/io/cameras.py @@ -15,7 +15,7 @@ # from sleap.io.dataset import Labels # TODO(LM): Circular import, implement Observer from sleap.instance import Instance, LabeledFrame, PredictedInstance from sleap.io.video import Video -from sleap.util import deep_iterable_converter +from sleap.util import compute_oks, deep_iterable_converter logger = logging.getLogger(__name__) @@ -437,6 +437,9 @@ class InstanceGroup: cameras: List of `Camcorder` objects that have an `Instance` associated. instances: List of `Instance` objects. instance_by_camcorder: Dictionary of `Instance` objects by `Camcorder`. + score: Optional score for the `InstanceGroup`. Setting the score will also + update the score for all `instances` already in the `InstanceGroup`. The + score for `instances` will not be updated upon initialization. """ _name: str = field() @@ -445,6 +448,7 @@ class InstanceGroup: _camcorder_by_instance: Dict[Instance, Camcorder] = field(factory=dict) _dummy_instance: Optional[Instance] = field(default=None) camera_cluster: Optional[CameraCluster] = field(default=None) + _score: Optional[float] = field(default=None) def __attrs_post_init__(self): """Initialize `InstanceGroup` object.""" @@ -566,7 +570,7 @@ def return_unique_name(cls, name_registry: Set[str]) -> str: return new_name @property - def instances(self) -> List[Instance]: + def instances(self) -> List[Union[Instance, PredictedInstance]]: """List of `Instance` objects.""" return list(self._instance_by_camcorder.values()) @@ -580,7 +584,34 @@ def instance_by_camcorder(self) -> Dict[Camcorder, Instance]: """Dictionary of `Instance` objects by `Camcorder`.""" return self._instance_by_camcorder - def numpy(self, pred_as_nan: bool = False, invisible_as_nan=True) -> np.ndarray: + @property + def score(self) -> Optional[float]: + """Score for the `InstanceGroup`.""" + return self._score + + @score.setter + def score(self, score: Optional[float]): + """Set the score for the `InstanceGroup`. + + Also sets the score for all instances in the `InstanceGroup` if they have a + `score` attribute. + + Args: + score: Score to set for the `InstanceGroup`. + """ + + for instance in self.instances: + if hasattr(instance, "score"): + instance.score = score + + self._score = score + + def numpy( + self, + pred_as_nan: bool = False, + invisible_as_nan=True, + cams_to_include: Optional[List[Camcorder]] = None, + ) -> np.ndarray: """Return instances as a numpy array of shape (n_views, n_nodes, 2). 
The ordering of views is based on the ordering of `Camcorder`s in the @@ -594,13 +625,19 @@ def numpy(self, pred_as_nan: bool = False, invisible_as_nan=True) -> np.ndarray: self.dummy_instance. Default is False. invisible_as_nan: If True, then replaces invisible points with nan. Default is True. + cams_to_include: List of `Camcorder`s to include in the numpy array. If + None, then all `Camcorder`s in the `CameraCluster` are included. Default + is None. Returns: Numpy array of shape (n_views, n_nodes, 2). """ instance_numpys: List[np.ndarray] = [] # len(M) x N x 2 - for cam in self.camera_cluster.cameras: + if cams_to_include is None: + cams_to_include = self.camera_cluster.cameras + + for cam in cams_to_include: instance = self.get_instance(cam) # Determine whether to use a dummy (all nan) instance @@ -825,6 +862,11 @@ def update_points( f"Camcorders in `cams_to_include` ({len(cams_to_include)})." ) + # Calculate OKS scores for the points + gt_points = self.numpy( + pred_as_nan=True, invisible_as_nan=True, cams_to_include=cams_to_include + ) # M x N x 2 + oks_scores = np.full((n_views, n_nodes), np.nan) for cam_idx, cam in enumerate(cams_to_include): # Get the instance for the cam instance: Optional[Instance] = self.get_instance(cam) @@ -834,11 +876,22 @@ def update_points( ) continue - # Update the points (and scores) for the (predicted) instance + # Compute the OKS score for the instance if it is a ground truth instance + if not isinstance(instance, PredictedInstance): + instance_oks = compute_oks( + gt_points[cam_idx, :, :], + points[cam_idx, :, :], + ) + oks_scores[cam_idx] = instance_oks + + # Update the points for the instance instance.update_points( points=points[cam_idx, :, :], exclude_complete=exclude_complete ) + # Update the score for the InstanceGroup to be the average OKS score + self.score = np.nanmean(oks_scores) # scalar + def __getitem__( self, idx_or_key: Union[int, Camcorder, Instance] ) -> Union[Camcorder, Instance]: @@ -873,7 +926,8 @@ def __len__(self): def __repr__(self): return ( f"{self.__class__.__name__}(name={self.name}, frame_idx={self.frame_idx}, " - f"instances:{len(self)}, camera_cluster={self.camera_cluster})" + f"score={self.score}, instances:{len(self)}, camera_cluster=" + f"{self.camera_cluster})" ) def __hash__(self) -> int: @@ -885,6 +939,7 @@ def from_instance_by_camcorder_dict( instance_by_camcorder: Dict[Camcorder, Instance], name: str, name_registry: Set[str], + score: Optional[float] = None, ) -> Optional["InstanceGroup"]: """Creates an `InstanceGroup` object from a dictionary. @@ -892,6 +947,8 @@ def from_instance_by_camcorder_dict( instance_by_camcorder: Dictionary with `Camcorder` keys and `Instance` values. name: Name to use for the `InstanceGroup`. name_registry: Set of names to check for uniqueness. + score: Optional score for the `InstanceGroup`. This will NOT update the + score of the `Instance`s within the `InstanceGroup`. Default is None. Raises: ValueError: If the `InstanceGroup` name is already in use. 
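To make the new scoring flow concrete (the `score` setter that propagates to member instances, and `update_points` averaging per-view OKS into a group score via `np.nanmean`), here is a minimal standalone sketch; `ToyInstance` and `ToyGroup` are illustrative stand-ins and are not part of this patch:

from typing import List, Optional

import numpy as np


class ToyInstance:
    """Stand-in for an instance that may carry a score."""

    def __init__(self, score: Optional[float] = None):
        self.score = score


class ToyGroup:
    """Stand-in mirroring the score semantics added to `InstanceGroup`."""

    def __init__(self, instances: List[ToyInstance]):
        self.instances = instances
        self._score: Optional[float] = None

    @property
    def score(self) -> Optional[float]:
        return self._score

    @score.setter
    def score(self, value: Optional[float]):
        # As in the patch: setting the group score also overwrites the score of
        # every member instance that exposes a `score` attribute.
        for inst in self.instances:
            if hasattr(inst, "score"):
                inst.score = value
        self._score = value


group = ToyGroup([ToyInstance(), ToyInstance()])
per_view_oks = np.array([0.9, np.nan, 0.8])  # one OKS value per included view
group.score = float(np.nanmean(per_view_oks))  # nanmean skips views with no OKS
assert all(inst.score == group.score for inst in group.instances)

In the patch itself, `update_points` fills an (n_views, n_nodes) array with NaN, writes per-node OKS values only for user-labeled (non-predicted) instances, and assigns `np.nanmean` of that array to `self.score`, so views without a ground-truth instance simply do not contribute.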
@@ -935,6 +992,7 @@ def from_instance_by_camcorder_dict( frame_idx=frame_idx, camera_cluster=camera_cluster, instance_by_camcorder=instance_by_camcorder_copy, + score=score, ) def to_dict( @@ -962,10 +1020,14 @@ def to_dict( for cam, instance in self._instance_by_camcorder.items() } - return { + instance_group_dict = { "name": self.name, "camcorder_to_lf_and_inst_idx_map": camcorder_to_lf_and_inst_idx_map, } + if self.score is not None: + instance_group_dict["score"] = str(round(self.score, 4)) + + return instance_group_dict @classmethod def from_dict( @@ -989,6 +1051,13 @@ def from_dict( `InstanceGroup` object. """ + # Get the score (if available) + score = ( + float(instance_group_dict["score"]) + if "score" in instance_group_dict + else None + ) + # Get the `Instance` objects camcorder_to_lf_and_inst_idx_map: Dict[ str, Tuple[str, str] @@ -1010,6 +1079,7 @@ def from_dict( instance_by_camcorder=instance_by_camcorder, name=instance_group_dict["name"], name_registry=name_registry, + score=score, ) @@ -1689,7 +1759,7 @@ def numpy( Returns: Numpy array of shape (M, T, N, 2) where M is the number of views (determined - by self.cames_to_include), T is the number of `InstanceGroup`s, N is the + by self.cams_to_include), T is the number of `InstanceGroup`s, N is the number of Nodes, and 2 is for x, y. """ @@ -1705,19 +1775,17 @@ def numpy( f"{self.instance_groups}" ) - instance_group_numpys: List[np.ndarray] = [] # len(T) M=all x N x 2 + instance_group_numpys: List[np.ndarray] = [] # len(T) M=include x N x 2 for instance_group in instance_groups: instance_group_numpy = instance_group.numpy( - pred_as_nan=pred_as_nan - ) # M=all x N x 2 + pred_as_nan=pred_as_nan, + cams_to_include=self.cams_to_include, + ) # M=include x N x 2 instance_group_numpys.append(instance_group_numpy) - frame_group_numpy = np.stack(instance_group_numpys, axis=1) # M=all x T x N x 2 - cams_to_include_mask = np.array( - [cam in self.cams_to_include for cam in self.session.cameras] - ) # M=all x 1 + frame_group_numpy = np.stack(instance_group_numpys, axis=1) # M=include x TxNx2 - return frame_group_numpy[cams_to_include_mask] # M=include x T x N x 2 + return frame_group_numpy # M=include x T x N x 2 def add_instance( self, @@ -2110,7 +2178,6 @@ def upsert_points( This will update the points for existing `Instance`s in the `InstanceGroup`s and also add new `Instance`s if they do not exist. - Included cams are specified by `FrameGroup.cams_to_include`. The ordering of the `InstanceGroup`s in `instance_groups` should match the diff --git a/sleap/nn/evals.py b/sleap/nn/evals.py index 002f8a143..936d0b4de 100644 --- a/sleap/nn/evals.py +++ b/sleap/nn/evals.py @@ -45,6 +45,7 @@ TopDownMultiClassPredictor, SingleInstancePredictor, ) +from sleap.util import compute_oks logger = logging.getLogger(__name__) @@ -113,143 +114,6 @@ def find_frame_pairs( return frame_pairs -def compute_instance_area(points: np.ndarray) -> np.ndarray: - """Compute the area of the bounding box of a set of keypoints. - - Args: - points: A numpy array of coordinates. - - Returns: - The area of the bounding box of the points. - """ - if points.ndim == 2: - points = np.expand_dims(points, axis=0) - - min_pt = np.nanmin(points, axis=-2) - max_pt = np.nanmax(points, axis=-2) - - return np.prod(max_pt - min_pt, axis=-1) - - -def compute_oks( - points_gt: np.ndarray, - points_pr: np.ndarray, - scale: Optional[float] = None, - stddev: float = 0.025, - use_cocoeval: bool = True, -) -> np.ndarray: - """Compute the object keypoints similarity between sets of points. 
- - Args: - points_gt: Ground truth instances of shape (n_gt, n_nodes, n_ed), - where n_nodes is the number of body parts/keypoint types, and n_ed - is the number of Euclidean dimensions (typically 2 or 3). Keypoints - that are missing/not visible should be represented as NaNs. - points_pr: Predicted instance of shape (n_pr, n_nodes, n_ed). - use_cocoeval: Indicates whether the OKS score is calculated like cocoeval - method or not. True indicating the score is calculated using the - cocoeval method (widely used and the code can be found here at - https://github.com/cocodataset/cocoapi/blob/8c9bcc3cf640524c4c20a9c40e89cb6a2f2fa0e9/PythonAPI/pycocotools/cocoeval.py#L192C5-L233C20) - and False indicating the score is calculated using the method exactly - as given in the paper referenced in the Notes below. - scale: Size scaling factor to use when weighing the scores, typically - the area of the bounding box of the instance (in pixels). This - should be of the length n_gt. If a scalar is provided, the same - number is used for all ground truth instances. If set to None, the - bounding box area of the ground truth instances will be calculated. - stddev: The standard deviation associated with the spread in the - localization accuracy of each node/keypoint type. This should be of - the length n_nodes. "Easier" keypoint types will have lower values - to reflect the smaller spread expected in localizing it. - - Returns: - The object keypoints similarity between every pair of ground truth and - predicted instance, a numpy array of of shape (n_gt, n_pr) in the range - of [0, 1.0], with 1.0 denoting a perfect match. - - Notes: - It's important to set the stddev appropriately when accounting for the - difficulty of each keypoint type. For reference, the median value for - all keypoint types in COCO is 0.072. The "easiest" keypoint is the left - eye, with stddev of 0.025, since it is easy to precisely locate the - eyes when labeling. The "hardest" keypoint is the left hip, with stddev - of 0.107, since it's hard to locate the left hip bone without external - anatomical features and since it is often occluded by clothing. - - The implementation here is based off of the descriptions in: - Ronch & Perona. "Benchmarking and Error Diagnosis in Multi-Instance Pose - Estimation." ICCV (2017). - """ - if points_gt.ndim == 2: - points_gt = np.expand_dims(points_gt, axis=0) - if points_pr.ndim == 2: - points_pr = np.expand_dims(points_pr, axis=0) - - if scale is None: - scale = compute_instance_area(points_gt) - - n_gt, n_nodes, n_ed = points_gt.shape # n_ed = 2 or 3 (euclidean dimensions) - n_pr = points_pr.shape[0] - - # If scalar scale was provided, use the same for each ground truth instance. - if np.isscalar(scale): - scale = np.full(n_gt, scale) - - # If scalar standard deviation was provided, use the same for each node. - if np.isscalar(stddev): - stddev = np.full(n_nodes, stddev) - - # Compute displacement between each pair. - displacement = np.reshape(points_gt, (n_gt, 1, n_nodes, n_ed)) - np.reshape( - points_pr, (1, n_pr, n_nodes, n_ed) - ) - assert displacement.shape == (n_gt, n_pr, n_nodes, n_ed) - - # Convert to pairwise Euclidean distances. - distance = (displacement ** 2).sum(axis=-1) # (n_gt, n_pr, n_nodes) - assert distance.shape == (n_gt, n_pr, n_nodes) - - # Compute the normalization factor per keypoint. - if use_cocoeval: - # If use_cocoeval is True, then compute normalization factor according to cocoeval. 
- spread_factor = (2 * stddev) ** 2 - scale_factor = 2 * (scale + np.spacing(1)) - else: - # If use_cocoeval is False, then compute normalization factor according to the paper. - spread_factor = stddev ** 2 - scale_factor = 2 * ((scale + np.spacing(1)) ** 2) - normalization_factor = np.reshape(spread_factor, (1, 1, n_nodes)) * np.reshape( - scale_factor, (n_gt, 1, 1) - ) - assert normalization_factor.shape == (n_gt, 1, n_nodes) - - # Since a "miss" is considered as KS < 0.5, we'll set the - # distances for predicted points that are missing to inf. - missing_pr = np.any(np.isnan(points_pr), axis=-1) # (n_pr, n_nodes) - assert missing_pr.shape == (n_pr, n_nodes) - distance[:, missing_pr] = np.inf - - # Compute the keypoint similarity as per the top of Eq. 1. - ks = np.exp(-(distance / normalization_factor)) # (n_gt, n_pr, n_nodes) - assert ks.shape == (n_gt, n_pr, n_nodes) - - # Set the KS for missing ground truth points to 0. - # This is equivalent to the visibility delta function of the bottom - # of Eq. 1. - missing_gt = np.any(np.isnan(points_gt), axis=-1) # (n_gt, n_nodes) - assert missing_gt.shape == (n_gt, n_nodes) - ks[np.expand_dims(missing_gt, axis=1)] = 0 - - # Compute the OKS. - n_visible_gt = np.sum( - (~missing_gt).astype("float64"), axis=-1, keepdims=True - ) # (n_gt, 1) - oks = np.sum(ks, axis=-1) / n_visible_gt - assert oks.shape == (n_gt, n_pr) - - return oks - - def match_instances( frame_gt: LabeledFrame, frame_pr: LabeledFrame, diff --git a/sleap/util.py b/sleap/util.py index c27cb6c09..75e24b423 100644 --- a/sleap/util.py +++ b/sleap/util.py @@ -82,6 +82,143 @@ def deep_iterable_converter(member_converter, iterable_converter=None): return _DeepIterableConverter(member_converter, iterable_converter) +def compute_instance_area(points: np.ndarray) -> np.ndarray: + """Compute the area of the bounding box of a set of keypoints. + + Args: + points: A numpy array of coordinates. + + Returns: + The area of the bounding box of the points. + """ + if points.ndim == 2: + points = np.expand_dims(points, axis=0) + + min_pt = np.nanmin(points, axis=-2) + max_pt = np.nanmax(points, axis=-2) + + return np.prod(max_pt - min_pt, axis=-1) + + +def compute_oks( + points_gt: np.ndarray, + points_pr: np.ndarray, + scale: Optional[float] = None, + stddev: float = 0.025, + use_cocoeval: bool = True, +) -> np.ndarray: + """Compute the object keypoints similarity between sets of points. + + Args: + points_gt: Ground truth instances of shape (n_gt, n_nodes, n_ed), + where n_nodes is the number of body parts/keypoint types, and n_ed + is the number of Euclidean dimensions (typically 2 or 3). Keypoints + that are missing/not visible should be represented as NaNs. + points_pr: Predicted instance of shape (n_pr, n_nodes, n_ed). + use_cocoeval: Indicates whether the OKS score is calculated like cocoeval + method or not. True indicating the score is calculated using the + cocoeval method (widely used and the code can be found here at + https://github.com/cocodataset/cocoapi/blob/8c9bcc3cf640524c4c20a9c40e89cb6a2f2fa0e9/PythonAPI/pycocotools/cocoeval.py#L192C5-L233C20) + and False indicating the score is calculated using the method exactly + as given in the paper referenced in the Notes below. + scale: Size scaling factor to use when weighing the scores, typically + the area of the bounding box of the instance (in pixels). This + should be of the length n_gt. If a scalar is provided, the same + number is used for all ground truth instances. 
If set to None, the + bounding box area of the ground truth instances will be calculated. + stddev: The standard deviation associated with the spread in the + localization accuracy of each node/keypoint type. This should be of + the length n_nodes. "Easier" keypoint types will have lower values + to reflect the smaller spread expected in localizing it. + + Returns: + The object keypoints similarity between every pair of ground truth and + predicted instance, a numpy array of of shape (n_gt, n_pr) in the range + of [0, 1.0], with 1.0 denoting a perfect match. + + Notes: + It's important to set the stddev appropriately when accounting for the + difficulty of each keypoint type. For reference, the median value for + all keypoint types in COCO is 0.072. The "easiest" keypoint is the left + eye, with stddev of 0.025, since it is easy to precisely locate the + eyes when labeling. The "hardest" keypoint is the left hip, with stddev + of 0.107, since it's hard to locate the left hip bone without external + anatomical features and since it is often occluded by clothing. + + The implementation here is based off of the descriptions in: + Ronch & Perona. "Benchmarking and Error Diagnosis in Multi-Instance Pose + Estimation." ICCV (2017). + """ + if points_gt.ndim == 2: + points_gt = np.expand_dims(points_gt, axis=0) + if points_pr.ndim == 2: + points_pr = np.expand_dims(points_pr, axis=0) + + if scale is None: + scale = compute_instance_area(points_gt) + + n_gt, n_nodes, n_ed = points_gt.shape # n_ed = 2 or 3 (euclidean dimensions) + n_pr = points_pr.shape[0] + + # If scalar scale was provided, use the same for each ground truth instance. + if np.isscalar(scale): + scale = np.full(n_gt, scale) + + # If scalar standard deviation was provided, use the same for each node. + if np.isscalar(stddev): + stddev = np.full(n_nodes, stddev) + + # Compute displacement between each pair. + displacement = np.reshape(points_gt, (n_gt, 1, n_nodes, n_ed)) - np.reshape( + points_pr, (1, n_pr, n_nodes, n_ed) + ) + assert displacement.shape == (n_gt, n_pr, n_nodes, n_ed) + + # Convert to pairwise Euclidean distances. + distance = (displacement ** 2).sum(axis=-1) # (n_gt, n_pr, n_nodes) + assert distance.shape == (n_gt, n_pr, n_nodes) + + # Compute the normalization factor per keypoint. + if use_cocoeval: + # If use_cocoeval is True, then compute normalization factor according to cocoeval. + spread_factor = (2 * stddev) ** 2 + scale_factor = 2 * (scale + np.spacing(1)) + else: + # If use_cocoeval is False, then compute normalization factor according to the paper. + spread_factor = stddev ** 2 + scale_factor = 2 * ((scale + np.spacing(1)) ** 2) + normalization_factor = np.reshape(spread_factor, (1, 1, n_nodes)) * np.reshape( + scale_factor, (n_gt, 1, 1) + ) + assert normalization_factor.shape == (n_gt, 1, n_nodes) + + # Since a "miss" is considered as KS < 0.5, we'll set the + # distances for predicted points that are missing to inf. + missing_pr = np.any(np.isnan(points_pr), axis=-1) # (n_pr, n_nodes) + assert missing_pr.shape == (n_pr, n_nodes) + distance[:, missing_pr] = np.inf + + # Compute the keypoint similarity as per the top of Eq. 1. + ks = np.exp(-(distance / normalization_factor)) # (n_gt, n_pr, n_nodes) + assert ks.shape == (n_gt, n_pr, n_nodes) + + # Set the KS for missing ground truth points to 0. + # This is equivalent to the visibility delta function of the bottom + # of Eq. 1. 
+ missing_gt = np.any(np.isnan(points_gt), axis=-1) # (n_gt, n_nodes) + assert missing_gt.shape == (n_gt, n_nodes) + ks[np.expand_dims(missing_gt, axis=1)] = 0 + + # Compute the OKS. + n_visible_gt = np.sum( + (~missing_gt).astype("float64"), axis=-1, keepdims=True + ) # (n_gt, 1) + oks = np.sum(ks, axis=-1) / n_visible_gt + assert oks.shape == (n_gt, n_pr) + + return oks + + def json_loads(json_str: str) -> Dict: """A simple wrapper around the JSON decoder we are using. diff --git a/tests/io/test_cameras.py b/tests/io/test_cameras.py index c98fa39c7..87614de8e 100644 --- a/tests/io/test_cameras.py +++ b/tests/io/test_cameras.py @@ -404,6 +404,7 @@ def create_instance_group( frame_idx: int, add_dummy: bool = False, name: Optional[str] = None, + score: Optional[float] = None, ) -> Union[ InstanceGroup, Tuple[InstanceGroup, Dict[Camcorder, Instance], Instance, Camcorder] ]: @@ -449,6 +450,7 @@ def create_instance_group( instance_by_camcorder=instance_by_camera, name="test_instance_group", name_registry={}, + score=score, ) return ( (instance_group, instance_by_camera, dummy_instance, cam) @@ -539,6 +541,20 @@ def test_instance_group( assert isinstance(instance_group_dict, dict) assert instance_group_dict["name"] == instance_group.name assert "camcorder_to_lf_and_inst_idx_map" in instance_group_dict + assert "score" not in instance_group_dict + + # Test `score` property (and `to_dict`) + assert instance_group.score is None + instance_group.score = 0.5 + for instance in instance_group.instances: + assert instance.score == 0.5 + instance_group.score = 0.75 + for instance in instance_group.instances: + assert instance.score == 0.75 + instance_group_dict = instance_group.to_dict( + instance_to_lf_and_inst_idx=instance_to_lf_and_inst_idx + ) + assert instance_group_dict["score"] == str(0.75) # Test `from_dict` instance_group_2 = InstanceGroup.from_dict( @@ -590,6 +606,20 @@ def test_instance_group( name="test_instance_group", name_registry={}, ) + instance_by_camera = { + cam: instance_group.get_instance(cam) for cam in instance_group.cameras + } + instance_group_from_dict = InstanceGroup.from_instance_by_camcorder_dict( + instance_by_camcorder=instance_by_camera, + name="test_instance_group", + name_registry={}, + score=0.5, + ) + assert instance_group_from_dict.score == 0.5 + # The score of instances will NOT be updated on initialization. 
+ for instance in instance_group_from_dict.instances: + if isinstance(instance, PredictedInstance): + assert instance.score != instance_group_from_dict.score # Test `__repr__` print(instance_group) @@ -624,9 +654,22 @@ def test_instance_group( # Test `update_points` method assert not np.all(instance_group.numpy(invisible_as_nan=False) == 72317) - instance_group.update_points(np.full((n_views, n_nodes, n_coords), 72317)) + # Remove some Instances to "expose" underlying PredictedInstances + for inst in instance_group.instances[:2]: + lf = inst.frame + labels.remove_instance(lf, inst) + instance_group.update_points(points=np.full((n_views, n_nodes, n_coords), 72317)) + for inst in instance_group.instances: + if isinstance(inst, PredictedInstance): + assert inst.score == instance_group.score + prev_score = instance_group.score + instance_group.update_points(points=np.full((n_views, n_nodes, n_coords), 72317)) + for inst in instance_group.instances: + if isinstance(inst, PredictedInstance): + assert inst.score == instance_group.score instance_group_numpy = instance_group.numpy(invisible_as_nan=False) assert np.all(instance_group_numpy == 72317) + assert instance_group.score == 1.0 # Score should be 1.0 because same points # Test `add_instance`, `replace_instance`, and `remove_instance` cam = instance_group.cameras[0] @@ -858,3 +901,7 @@ def test_frame_group( assert camera in frame_group.cameras assert labeled_frame_created in frame_group.labeled_frames assert labeled_frame in frame_group.session.labels.labeled_frames + + +if __name__ == "__main__": + pytest.main([f"{__file__}::test_instance_group"]) diff --git a/tests/nn/test_evals.py b/tests/nn/test_evals.py index 265994056..a60398623 100644 --- a/tests/nn/test_evals.py +++ b/tests/nn/test_evals.py @@ -13,7 +13,6 @@ from sleap.nn.evals import ( compute_dists, compute_dist_metrics, - compute_oks, load_metrics, evaluate_model, ) @@ -23,48 +22,6 @@ sleap.use_cpu_only() -def test_compute_oks(): - # Test compute_oks function with the cocoutils implementation - inst_gt = np.array([[0, 0], [1, 1], [2, 2]]).astype("float32") - inst_pr = np.array([[0, 0], [1, 1], [2, 2]]).astype("float32") - oks = compute_oks(inst_gt, inst_pr) - np.testing.assert_allclose(oks, 1) - - inst_pr = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32") - oks = compute_oks(inst_gt, inst_pr) - np.testing.assert_allclose(oks, 2 / 3) - - inst_gt = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32") - inst_pr = np.array([[0, 0], [1, 1], [2, 2]]).astype("float32") - oks = compute_oks(inst_gt, inst_pr) - np.testing.assert_allclose(oks, 1) - - inst_gt = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32") - inst_pr = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32") - oks = compute_oks(inst_gt, inst_pr) - np.testing.assert_allclose(oks, 1) - - # Test compute_oks function with the implementation from the paper - inst_gt = np.array([[0, 0], [1, 1], [2, 2]]).astype("float32") - inst_pr = np.array([[0, 0], [1, 1], [2, 2]]).astype("float32") - oks = compute_oks(inst_gt, inst_pr, False) - np.testing.assert_allclose(oks, 1) - - inst_pr = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32") - oks = compute_oks(inst_gt, inst_pr, False) - np.testing.assert_allclose(oks, 2 / 3) - - inst_gt = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32") - inst_pr = np.array([[0, 0], [1, 1], [2, 2]]).astype("float32") - oks = compute_oks(inst_gt, inst_pr, False) - np.testing.assert_allclose(oks, 1) - - inst_gt = np.array([[0, 0], [1, 1], 
[np.nan, np.nan]]).astype("float32")
-    inst_pr = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32")
-    oks = compute_oks(inst_gt, inst_pr, False)
-    np.testing.assert_allclose(oks, 1)
-
-
 def test_compute_dists(instances, predicted_instances):
     # Make some changes to the instances
     error_start = 10
diff --git a/tests/test_util.py b/tests/test_util.py
index a7916d47f..cd4e01aad 100644
--- a/tests/test_util.py
+++ b/tests/test_util.py
@@ -154,3 +154,45 @@ def test_decode_preview_image(flies13_skeleton: Skeleton):
     img_b64 = skeleton.preview_image
     img = decode_preview_image(img_b64)
     assert img.mode == "RGBA"
+
+
+def test_compute_oks():
+    # Test compute_oks function with the cocoeval implementation
+    inst_gt = np.array([[0, 0], [1, 1], [2, 2]]).astype("float32")
+    inst_pr = np.array([[0, 0], [1, 1], [2, 2]]).astype("float32")
+    oks = compute_oks(inst_gt, inst_pr)
+    np.testing.assert_allclose(oks, 1)
+
+    inst_pr = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32")
+    oks = compute_oks(inst_gt, inst_pr)
+    np.testing.assert_allclose(oks, 2 / 3)
+
+    inst_gt = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32")
+    inst_pr = np.array([[0, 0], [1, 1], [2, 2]]).astype("float32")
+    oks = compute_oks(inst_gt, inst_pr)
+    np.testing.assert_allclose(oks, 1)
+
+    inst_gt = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32")
+    inst_pr = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32")
+    oks = compute_oks(inst_gt, inst_pr)
+    np.testing.assert_allclose(oks, 1)
+
+    # Test compute_oks function with the implementation from the paper
+    inst_gt = np.array([[0, 0], [1, 1], [2, 2]]).astype("float32")
+    inst_pr = np.array([[0, 0], [1, 1], [2, 2]]).astype("float32")
+    oks = compute_oks(inst_gt, inst_pr, use_cocoeval=False)
+    np.testing.assert_allclose(oks, 1)
+
+    inst_pr = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32")
+    oks = compute_oks(inst_gt, inst_pr, use_cocoeval=False)
+    np.testing.assert_allclose(oks, 2 / 3)
+
+    inst_gt = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32")
+    inst_pr = np.array([[0, 0], [1, 1], [2, 2]]).astype("float32")
+    oks = compute_oks(inst_gt, inst_pr, use_cocoeval=False)
+    np.testing.assert_allclose(oks, 1)
+
+    inst_gt = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32")
+    inst_pr = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32")
+    oks = compute_oks(inst_gt, inst_pr, use_cocoeval=False)
+    np.testing.assert_allclose(oks, 1)
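As a usage reference for the relocated helper, here is a short sketch that calls `compute_oks` from its new home in `sleap.util` (the import path is the one introduced by this diff; the point values are arbitrary and only for illustration):

import numpy as np

from sleap.util import compute_oks

# One ground-truth and one predicted instance, three nodes each; NaNs mark
# missing/invisible points.
points_gt = np.array([[[0.0, 0.0], [10.0, 10.0], [20.0, 20.0]]])      # (1, 3, 2)
points_pr = np.array([[[0.5, 0.5], [10.0, 11.0], [np.nan, np.nan]]])  # (1, 3, 2)

# Returns an (n_gt, n_pr) matrix in [0, 1]. Missing predicted points count as
# misses, while missing ground-truth points are excluded from the average.
oks = compute_oks(points_gt, points_pr, scale=100.0, stddev=0.025)
print(oks.shape, float(oks[0, 0]))  # (1, 1) and a similarity between 0 and 1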