Skip to content

Commit

Permalink
Merge pull request #2 from elephant-track/fix-ellipsoid
Browse files Browse the repository at this point in the history
BugFix: avoid returning empty arrays in ellipsoid
  • Loading branch information
ksugar authored Apr 10, 2021
2 parents 5621321 + 990d333 commit b917e93
Showing 1 changed file with 41 additions and 37 deletions.
78 changes: 41 additions & 37 deletions elephant-core/elephant/util/ellipsoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def ellipsoid(center, radii, rotation, scales=None, shape=None, minarea=0):
"""
assert center.shape == (3,)
assert radii.shape == (3,)
assert 0 < radii.max(), "radii should contain at least one positive value"
assert rotation.shape == (3, 3)
if scales is None:
scales = (1.,) * 3
Expand All @@ -106,46 +107,49 @@ def ellipsoid(center, radii, rotation, scales=None, shape=None, minarea=0):
# containing the ellipsoid.
factor = np.array([
[i, j, k] for k in (-1, 1) for j in (-1, 1) for i in (-1, 1)]).T
radii_rot = np.abs(
np.diag(1. / scales).dot(rotation.dot(np.diag(radii).dot(factor)))
).max(axis=1)
# In the original scikit-image code, ceil and floor were replaced.
# https://github.com/scikit-image/scikit-image/blob/master/skimage/draw/draw.py#L127
upper_left_bottom = np.floor(scaled_center - radii_rot).astype(int)
lower_right_top = np.ceil(scaled_center + radii_rot).astype(int)
while True:
radii_rot = np.abs(
np.diag(1. / scales).dot(rotation.dot(np.diag(radii).dot(factor)))
).max(axis=1)
# In the original scikit-image code, ceil and floor were replaced.
# https://github.com/scikit-image/scikit-image/blob/master/skimage/draw/draw.py#L127
upper_left_bottom = np.floor(scaled_center - radii_rot).astype(int)
lower_right_top = np.ceil(scaled_center + radii_rot).astype(int)

if shape is not None:
# Constrain upper_left and lower_ight by shape boundary.
upper_left_bottom = np.maximum(upper_left_bottom, np.array([0, 0, 0]))
lower_right_top = np.minimum(lower_right_top, np.array(shape[:3]) - 1)
if shape is not None:
# Constrain upper_left and lower_ight by shape boundary.
upper_left_bottom = np.maximum(
upper_left_bottom, np.array([0, 0, 0]))
lower_right_top = np.minimum(
lower_right_top, np.array(shape[:3]) - 1)

bounding_shape = lower_right_top - upper_left_bottom + 1
bounding_shape = lower_right_top - upper_left_bottom + 1

d_lim, r_lim, c_lim = np.ogrid[0:float(bounding_shape[0]),
0:float(bounding_shape[1]),
0:float(bounding_shape[2])]
d_org, r_org, c_org = scaled_center - upper_left_bottom
d_rad, r_rad, c_rad = radii
rotation_inv = np.linalg.inv(rotation)
conversion_matrix = rotation_inv.dot(np.diag(scales))
d, r, c = (d_lim - d_org), (r_lim - r_org), (c_lim - c_org)
distances = (
((d * conversion_matrix[0, 0] +
r * conversion_matrix[0, 1] +
c * conversion_matrix[0, 2]) / d_rad) ** 2 +
((d * conversion_matrix[1, 0] +
r * conversion_matrix[1, 1] +
c * conversion_matrix[1, 2]) / r_rad) ** 2 +
((d * conversion_matrix[2, 0] +
r * conversion_matrix[2, 1] +
c * conversion_matrix[2, 2]) / c_rad) ** 2
)
if distances.size < minarea:
print('Skip: minarea ({}) exceeds the size of distances array ({})'
.format(minarea, distances.size))
return (np.empty(0, dtype=int),
np.empty(0, dtype=int),
np.empty(0, dtype=int))
d_lim, r_lim, c_lim = np.ogrid[0:float(bounding_shape[0]),
0:float(bounding_shape[1]),
0:float(bounding_shape[2])]
d_org, r_org, c_org = scaled_center - upper_left_bottom
d_rad, r_rad, c_rad = radii
rotation_inv = np.linalg.inv(rotation)
conversion_matrix = rotation_inv.dot(np.diag(scales))
d, r, c = (d_lim - d_org), (r_lim - r_org), (c_lim - c_org)
distances = (
((d * conversion_matrix[0, 0] +
r * conversion_matrix[0, 1] +
c * conversion_matrix[0, 2]) / d_rad) ** 2 +
((d * conversion_matrix[1, 0] +
r * conversion_matrix[1, 1] +
c * conversion_matrix[1, 2]) / r_rad) ** 2 +
((d * conversion_matrix[2, 0] +
r * conversion_matrix[2, 1] +
c * conversion_matrix[2, 2]) / c_rad) ** 2
)
if distances.size < minarea:
old_radii = radii.copy()
radii *= 1.1
print('Increase radii from ({}) to ({})'.format(old_radii, radii))
else:
break
distance_thresh = 1
while True:
dd, rr, cc = np.nonzero(distances < distance_thresh)
Expand Down

0 comments on commit b917e93

Please sign in to comment.