Skip to content

Commit

Permalink
Add object keypoint similarity method
Browse files Browse the repository at this point in the history
  • Loading branch information
getzze committed Jul 23, 2024
1 parent 3e2bd25 commit b4374b3
Show file tree
Hide file tree
Showing 7 changed files with 246 additions and 16 deletions.
44 changes: 38 additions & 6 deletions sleap/config/pipeline_form.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ training:
This pipeline uses two models: a "<u>centroid</u>" model to
locate and crop around each animal in the frame, and a
"<u>centered-instance confidence map</u>" model for predicted node locations
for each individual animal predicted by the centroid model.'
for each individual animal predicted by the centroid model.'
- label: Max Instances
name: max_instances
type: optional_int
Expand Down Expand Up @@ -217,7 +217,7 @@ training:
- name: controller_port
label: Controller Port
type: int
default: 9000
default: 9000
range: 1024,65535

- name: publish_port
Expand Down Expand Up @@ -388,7 +388,7 @@ inference:
tracking-only:

- name: batch_size
label: Batch Size
label: Batch Size
type: int
default: 4
range: 1,512
Expand Down Expand Up @@ -439,7 +439,7 @@ inference:
label: Similarity Method
type: list
default: instance
options: instance,centroid,iou
options: "instance,centroid,iou,object keypoint"
- name: tracking.match
label: Matching Method
type: list
Expand Down Expand Up @@ -478,6 +478,22 @@ inference:
label: Nodes to use for Tracking
type: string
default: 0,1,2
- type: text
text: '<b>Object keypoint similarity options</b>:<br />
Only used if this similarity method is selected.'
- name: tracking.oks_errors
label: Keypoints errors in pixels
help: 'Standard error in pixels of the distance for each keypoint.
If the list is empty, defaults to 1. If singleton list, each keypoint has
the same error. Otherwise, the length should be the same as the number of
keypoints in the skeleton.'
type: string
default:
- name: tracking.oks_score_weighting
label: Use prediction score for weighting
help: 'Use prediction scores to weight the similarity of each keypoint'
type: bool
default: false
- type: text
text: '<b>Post-tracker data cleaning</b>:'
- name: tracking.post_connect_single_breaks
Expand Down Expand Up @@ -521,8 +537,8 @@ inference:
- name: tracking.similarity
label: Similarity Method
type: list
default: iou
options: instance,centroid,iou
default: instance
options: "instance,centroid,iou,object keypoint"
- name: tracking.match
label: Matching Method
type: list
Expand Down Expand Up @@ -557,6 +573,22 @@ inference:
label: Nodes to use for Tracking
type: string
default: 0,1,2
- type: text
text: '<b>Object keypoint similarity options</b>:<br />
Only used if this similarity method is selected.'
- name: tracking.oks_errors
label: Keypoints errors in pixels
help: 'Standard error in pixels of the distance for each keypoint.
If the list is empty, defaults to 1. If singleton list, each keypoint has
the same error. Otherwise, the length should be the same as the number of
keypoints in the skeleton.'
type: string
default:
- name: tracking.oks_score_weighting
label: Use prediction score for weighting
help: 'Use prediction scores to weight the similarity of each keypoint'
type: bool
default: false
- type: text
text: '<b>Post-tracker data cleaning</b>:'
- name: tracking.post_connect_single_breaks
Expand Down
8 changes: 8 additions & 0 deletions sleap/gui/learning/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,12 +260,20 @@ def make_predict_cli_call(
"tracking.max_tracking",
"tracking.post_connect_single_breaks",
"tracking.save_shifted_instances",
"tracking.oks_score_weighting",
)

for key in bool_items_as_ints:
if key in self.inference_params:
self.inference_params[key] = int(self.inference_params[key])

remove_spaces_items = ("tracking.similarity",)

for key in remove_spaces_items:
if key in self.inference_params:
value = self.inference_params[key]
self.inference_params[key] = value.replace(" ", "_")

for key, val in self.inference_params.items():
if not key.startswith(("_", "outputs.", "model.", "data.")):
cli_args.extend((f"--{key}", str(val)))
Expand Down
94 changes: 93 additions & 1 deletion sleap/nn/tracker/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
"""
import operator
from collections import defaultdict
from typing import List, Tuple, Optional, TypeVar, Callable
import logging
from typing import List, Tuple, Union, Optional, TypeVar, Callable

import attr
import numpy as np
Expand All @@ -23,6 +24,8 @@
from sleap import PredictedInstance, Instance, Track
from sleap.nn import utils

logger = logging.getLogger(__name__)

InstanceType = TypeVar("InstanceType", Instance, PredictedInstance)


Expand All @@ -40,6 +43,95 @@ def instance_similarity(
return similarity


def factory_object_keypoint_similarity(
keypoint_errors: Optional[Union[List, int, float]] = None,
score_weighting: bool = False,
normalization_keypoints: str = "all",
) -> Callable:
"""Factory for similarity function based on object keypoints.
Args:
keypoint_errors: The standard error of the distance between the predicted
keypoint and the true value, in pixels.
If None or empty list, defaults to 1.
If a scalar or singleton list, every keypoint has the same error.
If a list, defines the error for each keypoint, the length should be equal
to the number of keypoints in the skeleton.
score_weighting: If True, use `score` of `PredictedPoint` to weigh
`keypoint_errors`. If False, do not add a weight to `keypoint_errors`.
normalization_keypoints: Determine how to normalize similarity score. One of
["all", "ref", "union"]. If "all", similarity score is normalized by number
of reference points. If "ref", similarity score is normalized by number of
visible reference points. If "union", similarity score is normalized by
number of points both visible in query and reference instance.
Default is "all".
Returns:
Callable that returns object keypoint similarity between two `Instance`s.
"""
keypoint_errors = 1 if keypoint_errors is None else keypoint_errors
with np.errstate(divide="ignore"):
kp_precision = 1 / (2 * np.array(keypoint_errors) ** 2)

def object_keypoint_similarity(
ref_instance: InstanceType, query_instance: InstanceType
) -> float:
nonlocal kp_precision
# Keypoints
ref_points = ref_instance.points_array
query_points = query_instance.points_array
# Keypoint scores
if score_weighting:
ref_scores = getattr(ref_instance, "scores", np.ones(len(ref_points)))
query_scores = getattr(query_instance, "scores", np.ones(len(query_points)))
else:
ref_scores = 1
query_scores = 1
# Number of keypoint for normalization
if normalization_keypoints in ("ref", "union"):
ref_visible = ~(np.isnan(ref_points).any(axis=1))
if normalization_keypoints == "ref":
max_n_keypoints = np.sum(ref_visible)
elif normalization_keypoints == "union":
query_visible = ~(np.isnan(query_points).any(axis=1))
max_n_keypoints = np.sum(np.logical_and(ref_visible, query_visible))
else: # if normalization_keypoints == "all":
max_n_keypoints = len(ref_points)
if max_n_keypoints == 0:
return 0

# Make sure the sizes of kp_precision and n_points match
if kp_precision.size > 1 and 2 * kp_precision.size != ref_points.size:
# Correct kp_precision size to fit number of points
n_points = ref_points.size // 2
mess = (
"keypoint_errors array should have the same size as the number of "
f"keypoints in the instance: {kp_precision.size} != {n_points}"
)

if kp_precision.size > n_points:
kp_precision = kp_precision[:n_points]
mess += "\nTruncating keypoint_errors array."

else: # elif kp_precision.size < n_points:
pad = n_points - kp_precision.size
kp_precision = np.pad(kp_precision, (0, pad), "edge")
mess += "\nPadding keypoint_errors array by repeating the last value."
logger.warning(mess)

# Compute distances
dists = np.sum((query_points - ref_points) ** 2, axis=1) * kp_precision

similarity = (
np.nansum(ref_scores * query_scores * np.exp(-dists)) / max_n_keypoints
)

return similarity

return object_keypoint_similarity


def centroid_distance(
ref_instance: InstanceType, query_instance: InstanceType, cache: dict = dict()
) -> float:
Expand Down
51 changes: 50 additions & 1 deletion sleap/nn/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sleap import Track, LabeledFrame, Skeleton

from sleap.nn.tracker.components import (
factory_object_keypoint_similarity,
instance_similarity,
centroid_distance,
instance_iou,
Expand Down Expand Up @@ -492,6 +493,7 @@ def get_candidates(
instance=instance_similarity,
centroid=centroid_distance,
iou=instance_iou,
object_keypoint=instance_similarity,
)

match_policies = dict(
Expand Down Expand Up @@ -838,6 +840,10 @@ def make_tracker_by_name(
# Max tracking options
max_tracks: Optional[int] = None,
max_tracking: bool = False,
# Object keypoint similarity options
oks_errors: Optional[list] = None,
oks_score_weighting: bool = False,
oks_normalization: str = "all",
**kwargs,
) -> BaseTracker:

Expand All @@ -858,7 +864,14 @@ def make_tracker_by_name(
raise ValueError(f"{match} is not a valid tracker matching function.")

candidate_maker = tracker_policies[tracker](min_points=min_match_points)
similarity_function = similarity_policies[similarity]
if similarity == "object_keypoint":
similarity_function = factory_object_keypoint_similarity(
keypoint_errors=oks_errors,
score_weighting=oks_score_weighting,
normalization_keypoints=oks_normalization,
)
else:
similarity_function = similarity_policies[similarity]
matching_function = match_policies[match]

if tracker == "flow":
Expand Down Expand Up @@ -1054,6 +1067,42 @@ def int_list_func(s):
] = "For Kalman filter: Number of frames to track with other tracker. 0 means no Kalman filters will be used."
options.append(option)

def float_list_func(s):
return [float(x.strip()) for x in s.split(",")] if s else None

option = dict(name="oks_errors", default="1")
option["type"] = float_list_func
option["help"] = (
"For Object Keypoint similarity: the standard error of the distance "
"between the predicted keypoint and the true value, in pixels.\n"
"If None or empty list, defaults to 1. If a scalar or singleton list, "
"every keypoint has the same error. If a list, defines the error for each "
"keypoint, the length should be equal to the number of keypoints in the "
"skeleton."
)
options.append(option)

option = dict(name="oks_score_weighting", default="0")
option["type"] = int
option["help"] = (
"For Object Keypoint similarity: if 0 (default), only the distance between the reference "
"and query keypoint is used to compute the similarity. If 1, each distance is weighted "
"by the prediction scores of the reference and query keypoint."
)
options.append(option)

option = dict(name="oks_normalization", default="all")
option["type"] = str
option["options"] = ["all", "ref", "union"]
option["help"] = (
"For Object Keypoint similarity: Determine how to normalize similarity score. "
"If 'all', similarity score is normalized by number of reference points. "
"If 'ref', similarity score is normalized by number of visible reference points. "
"If 'union', similarity score is normalized by number of points both visible "
"in query and reference instance."
)
options.append(option)

return options

@classmethod
Expand Down
7 changes: 7 additions & 0 deletions tests/fixtures/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ def centered_pair_predictions():
return Labels.load_file(TEST_JSON_PREDICTIONS)


@pytest.fixture
def centered_pair_predictions_sorted(centered_pair_predictions):
labels: Labels = centered_pair_predictions
labels.labeled_frames.sort(key=lambda lf: lf.frame_idx)
return labels


@pytest.fixture
def min_labels():
return Labels.load_file(TEST_JSON_MIN_LABELS)
Expand Down
10 changes: 6 additions & 4 deletions tests/nn/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1373,7 +1373,7 @@ def test_retracking(
# Create sleap-track command
cmd = (
f"{slp_path} --tracking.tracker {tracker_method} --video.index 0 --frames 1-3 "
"--cpu"
"--tracking.similarity object_keypoint --cpu"
)
if tracker_method == "flow":
cmd += " --tracking.save_shifted_instances 1"
Expand All @@ -1393,6 +1393,8 @@ def test_retracking(
parser = _make_cli_parser()
args, _ = parser.parse_known_args(args=args)
tracker = _make_tracker_from_cli(args)
# Additional check for similarity method
assert tracker.similarity_function.__name__ == "object_keypoint_similarity"
output_path = f"{slp_path}.{tracker.get_name()}.slp"

# Assert tracked predictions file exists
Expand Down Expand Up @@ -1747,9 +1749,9 @@ def test_sleap_track_invalid_input(
sleap_track(args=args)


def test_flow_tracker(centered_pair_predictions: Labels, tmpdir):
def test_flow_tracker(centered_pair_predictions_sorted: Labels, tmpdir):
"""Test flow tracker instances are pruned."""
labels: Labels = centered_pair_predictions
labels: Labels = centered_pair_predictions_sorted
track_window = 5

# Setup tracker
Expand All @@ -1759,7 +1761,7 @@ def test_flow_tracker(centered_pair_predictions: Labels, tmpdir):
tracker.candidate_maker = cast(FlowCandidateMaker, tracker.candidate_maker)

# Run tracking
frames = sorted(labels.labeled_frames, key=lambda lf: lf.frame_idx)
frames = labels.labeled_frames

# Run tracking on subset of frames using psuedo-implementation of
# sleap.nn.tracking.run_tracker
Expand Down
Loading

0 comments on commit b4374b3

Please sign in to comment.