Skip to content

Commit

Permalink
Log a warning instead of raising an issue
Browse files Browse the repository at this point in the history
  • Loading branch information
getzze committed Nov 15, 2022
1 parent 86d9b4e commit 82472a7
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions sleap/nn/tracker/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,26 @@ def object_keypoint_similarity(
if max_n_keypoints == 0:
return 0

# Compute distances
# Make sure the sizes of kp_precision and n_points match
if kp_precision.size > 1 and 2 * kp_precision.size != ref_points.size:
# Correct kp_precision size to fit number of points
n_points = ref_points.size // 2
mess = (
"keypoint_errors array should have the same size as the number of "
f"keypoints in the instance: {kp_precision.size} != {ref_points.size // 2}"
f"keypoints in the instance: {kp_precision.size} != {n_points}"
)
raise ValueError(mess)

if kp_precision.size > n_points:
kp_precision = kp_precision[:n_points]
mess += "\nTruncating keypoint_errors array."

else: # elif kp_precision.size < n_points:
pad = n_points - kp_precision.size
kp_precision = np.pad(kp_precision, (0, pad), "edge")
mess += "\nPadding keypoint_errors array by repeating the last value."
logger.warning(mess)

# Compute distances
dists = np.sum((query_points - ref_points) ** 2, axis=1) * kp_precision

similarity = (
Expand Down

0 comments on commit 82472a7

Please sign in to comment.