Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use oks score for reprojections #1836

6 changes: 4 additions & 2 deletions sleap/gui/dataviews.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,16 +691,18 @@ class InstanceGroupTableModel(GenericTableModel):
item: 'InstanceGroup' which has information about the instance group
"""

properties = ("name", "frame index", "cameras", "instances")
properties = ("name", "score", "frame index", "cameras", "instances")

def item_to_data(self, obj, item: InstanceGroup):

return {
data = {
"name": item.name,
"score": "" if item.score is None else str(round(item.score, 2)),
"frame index": item.frame_idx,
"cameras": len(item.camera_cluster.cameras),
"instances": len(item.instances),
}
return data

def get_item_color(self, instance_group: InstanceGroup, key: str):
color_manager = self.context.app.color_manager
Expand Down
101 changes: 84 additions & 17 deletions sleap/io/cameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# from sleap.io.dataset import Labels # TODO(LM): Circular import, implement Observer
from sleap.instance import Instance, LabeledFrame, PredictedInstance
from sleap.io.video import Video
from sleap.util import deep_iterable_converter
from sleap.util import compute_oks, deep_iterable_converter

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -437,6 +437,9 @@ class InstanceGroup:
cameras: List of `Camcorder` objects that have an `Instance` associated.
instances: List of `Instance` objects.
instance_by_camcorder: Dictionary of `Instance` objects by `Camcorder`.
score: Optional score for the `InstanceGroup`. Setting the score will also
update the score for all `instances` already in the `InstanceGroup`. The
score for `instances` will not be updated upon initialization.
"""

_name: str = field()
Expand All @@ -445,6 +448,7 @@ class InstanceGroup:
_camcorder_by_instance: Dict[Instance, Camcorder] = field(factory=dict)
_dummy_instance: Optional[Instance] = field(default=None)
camera_cluster: Optional[CameraCluster] = field(default=None)
_score: Optional[float] = field(default=None)

def __attrs_post_init__(self):
"""Initialize `InstanceGroup` object."""
Expand Down Expand Up @@ -566,7 +570,7 @@ def return_unique_name(cls, name_registry: Set[str]) -> str:
return new_name

@property
def instances(self) -> List[Instance]:
def instances(self) -> List[Union[Instance, PredictedInstance]]:
"""List of `Instance` objects."""
return list(self._instance_by_camcorder.values())

Expand All @@ -580,7 +584,34 @@ def instance_by_camcorder(self) -> Dict[Camcorder, Instance]:
"""Dictionary of `Instance` objects by `Camcorder`."""
return self._instance_by_camcorder

def numpy(self, pred_as_nan: bool = False, invisible_as_nan=True) -> np.ndarray:
@property
def score(self) -> Optional[float]:
"""Score for the `InstanceGroup`."""
return self._score

@score.setter
def score(self, score: Optional[float]):
"""Set the score for the `InstanceGroup`.

Also sets the score for all instances in the `InstanceGroup` if they have a
`score` attribute.

Args:
score: Score to set for the `InstanceGroup`.
"""

for instance in self.instances:
if hasattr(instance, "score"):
instance.score = score

self._score = score

def numpy(
self,
pred_as_nan: bool = False,
invisible_as_nan=True,
cams_to_include: Optional[List[Camcorder]] = None,
) -> np.ndarray:
"""Return instances as a numpy array of shape (n_views, n_nodes, 2).

The ordering of views is based on the ordering of `Camcorder`s in the
Expand All @@ -594,13 +625,19 @@ def numpy(self, pred_as_nan: bool = False, invisible_as_nan=True) -> np.ndarray:
self.dummy_instance. Default is False.
invisible_as_nan: If True, then replaces invisible points with nan. Default
is True.
cams_to_include: List of `Camcorder`s to include in the numpy array. If
None, then all `Camcorder`s in the `CameraCluster` are included. Default
is None.

Returns:
Numpy array of shape (n_views, n_nodes, 2).
"""

instance_numpys: List[np.ndarray] = [] # len(M) x N x 2
for cam in self.camera_cluster.cameras:
if cams_to_include is None:
cams_to_include = self.camera_cluster.cameras

for cam in cams_to_include:
instance = self.get_instance(cam)

# Determine whether to use a dummy (all nan) instance
Expand Down Expand Up @@ -825,6 +862,11 @@ def update_points(
f"Camcorders in `cams_to_include` ({len(cams_to_include)})."
)

# Calculate OKS scores for the points
gt_points = self.numpy(
pred_as_nan=True, invisible_as_nan=True, cams_to_include=cams_to_include
) # M x N x 2
oks_scores = np.full((n_views, n_nodes), np.nan)
for cam_idx, cam in enumerate(cams_to_include):
# Get the instance for the cam
instance: Optional[Instance] = self.get_instance(cam)
Expand All @@ -834,11 +876,22 @@ def update_points(
)
continue

# Update the points (and scores) for the (predicted) instance
# Compute the OKS score for the instance if it is a ground truth instance
if not isinstance(instance, PredictedInstance):
instance_oks = compute_oks(
gt_points[cam_idx, :, :],
points[cam_idx, :, :],
)
oks_scores[cam_idx] = instance_oks

# Update the points for the instance
instance.update_points(
points=points[cam_idx, :, :], exclude_complete=exclude_complete
)

# Update the score for the InstanceGroup to be the average OKS score
self.score = np.nanmean(oks_scores) # scalar

def __getitem__(
self, idx_or_key: Union[int, Camcorder, Instance]
) -> Union[Camcorder, Instance]:
Expand Down Expand Up @@ -873,7 +926,8 @@ def __len__(self):
def __repr__(self):
return (
f"{self.__class__.__name__}(name={self.name}, frame_idx={self.frame_idx}, "
f"instances:{len(self)}, camera_cluster={self.camera_cluster})"
f"score={self.score}, instances:{len(self)}, camera_cluster="
f"{self.camera_cluster})"
)

def __hash__(self) -> int:
Expand All @@ -885,13 +939,16 @@ def from_instance_by_camcorder_dict(
instance_by_camcorder: Dict[Camcorder, Instance],
name: str,
name_registry: Set[str],
score: Optional[float] = None,
) -> Optional["InstanceGroup"]:
"""Creates an `InstanceGroup` object from a dictionary.

Args:
instance_by_camcorder: Dictionary with `Camcorder` keys and `Instance` values.
name: Name to use for the `InstanceGroup`.
name_registry: Set of names to check for uniqueness.
score: Optional score for the `InstanceGroup`. This will NOT update the
score of the `Instance`s within the `InstanceGroup`. Default is None.

Raises:
ValueError: If the `InstanceGroup` name is already in use.
Expand Down Expand Up @@ -935,6 +992,7 @@ def from_instance_by_camcorder_dict(
frame_idx=frame_idx,
camera_cluster=camera_cluster,
instance_by_camcorder=instance_by_camcorder_copy,
score=score,
)

def to_dict(
Expand Down Expand Up @@ -962,10 +1020,14 @@ def to_dict(
for cam, instance in self._instance_by_camcorder.items()
}

return {
instance_group_dict = {
"name": self.name,
"camcorder_to_lf_and_inst_idx_map": camcorder_to_lf_and_inst_idx_map,
}
if self.score is not None:
instance_group_dict["score"] = str(round(self.score, 4))

return instance_group_dict

@classmethod
def from_dict(
Expand All @@ -989,6 +1051,13 @@ def from_dict(
`InstanceGroup` object.
"""

# Get the score (if available)
score = (
float(instance_group_dict["score"])
if "score" in instance_group_dict
else None
)

# Get the `Instance` objects
camcorder_to_lf_and_inst_idx_map: Dict[
str, Tuple[str, str]
Expand All @@ -1010,6 +1079,7 @@ def from_dict(
instance_by_camcorder=instance_by_camcorder,
name=instance_group_dict["name"],
name_registry=name_registry,
score=score,
)


Expand Down Expand Up @@ -1689,7 +1759,7 @@ def numpy(

Returns:
Numpy array of shape (M, T, N, 2) where M is the number of views (determined
by self.cames_to_include), T is the number of `InstanceGroup`s, N is the
by self.cams_to_include), T is the number of `InstanceGroup`s, N is the
number of Nodes, and 2 is for x, y.
"""

Expand All @@ -1705,19 +1775,17 @@ def numpy(
f"{self.instance_groups}"
)

instance_group_numpys: List[np.ndarray] = [] # len(T) M=all x N x 2
instance_group_numpys: List[np.ndarray] = [] # len(T) M=include x N x 2
roomrys marked this conversation as resolved.
Show resolved Hide resolved
for instance_group in instance_groups:
instance_group_numpy = instance_group.numpy(
pred_as_nan=pred_as_nan
) # M=all x N x 2
pred_as_nan=pred_as_nan,
cams_to_include=self.cams_to_include,
) # M=include x N x 2
instance_group_numpys.append(instance_group_numpy)

frame_group_numpy = np.stack(instance_group_numpys, axis=1) # M=all x T x N x 2
cams_to_include_mask = np.array(
[cam in self.cams_to_include for cam in self.session.cameras]
) # M=all x 1
frame_group_numpy = np.stack(instance_group_numpys, axis=1) # M=include x TxNx2

return frame_group_numpy[cams_to_include_mask] # M=include x T x N x 2
return frame_group_numpy # M=include x T x N x 2

def add_instance(
self,
Expand Down Expand Up @@ -2110,7 +2178,6 @@ def upsert_points(
This will update the points for existing `Instance`s in the `InstanceGroup`s and
also add new `Instance`s if they do not exist.


Included cams are specified by `FrameGroup.cams_to_include`.

The ordering of the `InstanceGroup`s in `instance_groups` should match the
Expand Down
Loading
Loading