Skip to content

Commit

Permalink
Faster hessian_matrix_* and structure_tensor_eigvals via analytic…
Browse files Browse the repository at this point in the history
…al eigenvalues for the 3D case (#434)

closes #354

This MR implements faster 2D and 3D pixelwise eigenvalue computations for `hessian_matrix_eigvals`, `structure_tensor_eigvals` and `hessian_matrix_det`. The 2D case already had a fairly fast code path, but it is further improved here by switching from a fused kernel to an elementwise kernel that removed the need for a separate call to `cupy.stack`. In 3D runtime is reduced by ~30x for float32 and >100x for float64. The 3D case also uses MUCH less RAM than previously (>20x reduction). For example computing the eigenvalues for size (128, 128, 128) float32 arrays would run out of memory even on an A6000 (40GB). With the changes here, it works even for 16x larger data of shape (512, 512, 512).

Functions that benefit from this are:
	
- `cucim.skimage.feature.hessian_matrix_det`
- `cucim.skimage.feature.hessian_matrix_eigvals`	
- `cucim.skimage.feature.structure_tensor_eigenvalues`
- `cucim.skimage.feature.shape_index`
- `cucim.skimage.feature.multiscale_basic_features`
- `cucim.skimage.filters.meijering`
- `cucim.skimage.filters.sato`
- `cucim.skimage.filters.frangi`
- `cucim.skimage.filters.hessian`

Independently of the above, the function `cucim.skimage.measure.inertia_tensor_eigvals` was updated with custom kernels so it can operate purely on the GPU for the 2D and 3D cases (formly these used copies to/from the host). These operate on tiny arrays, so they use only a single GPU thread. Despite the lack of paralellism, this is lower overhead than round trip host/device transfer. This will also improve region properties making use of these eigenvalues (e.g. the `axis_major_length` and `axis_minor_length` properties for `regionprops_table`)

Authors:
  - Gregory Lee (https://github.com/grlee77)

Approvers:
  - Gigon Bae (https://github.com/gigony)

URL: #434
  • Loading branch information
grlee77 authored Nov 16, 2022
1 parent f54b560 commit 762d739
Show file tree
Hide file tree
Showing 4 changed files with 395 additions and 86 deletions.
272 changes: 241 additions & 31 deletions python/cucim/src/cucim/skimage/feature/corner.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,43 +399,253 @@ def hessian_matrix_det(image, sigma=1, approximate=True):
# integral = integral_image(image)
# return cp.asarray(_hessian_matrix_det(integral, sigma))
else: # slower brute-force implementation for nD images
if image.ndim in [2, 3]:
# Compute determinant as the product of the eigenvalues.
# This avoids the huge memory overhead of forming
# `_symmetric_image` as in the code below.
# Could optimize further by computing the determinant directly
# using ElementwiseKernels rather than reusing the eigenvalue ones.
H = hessian_matrix(image, sigma)
evs = hessian_matrix_eigvals(H)
return cp.prod(evs, axis=0)
hessian_mat_array = _symmetric_image(hessian_matrix(image, sigma))
return cp.linalg.det(hessian_mat_array)


@cp.fuse()
def _image_orthogonal_matrix22_eigvals(M00, M01, M11):
@cp.memoize()
def _get_real_symmetric_2x2_eigvals_kernel(sort='ascending', abs_sort=False):

operation = """
F tmp1, tmp2;
double m00 = static_cast<double>(M00);
double m01 = static_cast<double>(M01);
double m11 = static_cast<double>(M11);
tmp1 = m01 * m01;
tmp1 *= 4;
tmp2 = m00 - m11;
tmp2 *= tmp2;
tmp2 += tmp1;
tmp2 = sqrt(tmp2);
tmp2 /= 2;
tmp1 = m00 + m11;
tmp1 /= 2;
"""
analytical formula below optimized for in-place computations.
if sort == 'ascending':
operation += """
lam1 = tmp1 - tmp2;
lam2 = tmp1 + tmp2;
"""
if abs_sort:
operation += """
F stmp;
if (abs(lam1) > abs(lam2)) {
stmp = lam1;
lam1 = lam2;
lam2 = stmp;
}
"""
elif sort == 'descending':
operation += """
lam1 = tmp1 + tmp2;
lam2 = tmp1 - tmp2;
"""
if abs_sort:
operation += """
F stmp;
if (abs(lam1) < abs(lam2)) {
stmp = lam1;
lam1 = lam2;
lam2 = stmp;
}
"""
else:
raise ValueError(f"unknown sort type: {sort}")
return cp.ElementwiseKernel(
in_params="F M00, F M01, F M11",
out_params="F lam1, F lam2",
operation=operation,
name="cucim_skimage_symmetric_eig22_kernel")


def _image_orthogonal_matrix22_eigvals(
M00, M01, M11, sort='descending', abs_sort=False
):
r"""Analytical expressions of the eigenvalues of a symmetric 2 x 2 matrix.
It corresponds to::
l1 = (M00 + M11) / 2 + cp.sqrt(4 * M01 ** 2 + (M00 - M11) ** 2) / 2
l2 = (M00 + M11) / 2 - cp.sqrt(4 * M01 ** 2 + (M00 - M11) ** 2) / 2
Parameters
----------
M00, M01, M11 : cp.ndarray
Images corresponding to the individual components of the matrix. For
example, ``M01 = M[0, 1]``.
sort : {"ascending", "descending"}, optional
Eigenvalues should be sorted in the specified order.
abs_sort : boolean, optional
If ``True``, sort based on the absolute values.
References
----------
.. [1] C. Deledalle, L. Denis, S. Tabti, F. Tupin. Closed-form expressions
of the eigen decomposition of 2 x 2 and 3 x 3 Hermitian matrices.
[Research Report] Université de Lyon. 2017.
https://hal.archives-ouvertes.fr/hal-01501221/file/matrix_exp_and_log_formula.pdf
""" # noqa
if M00.dtype.kind != "f":
raise ValueError("expected real-valued floating point matrices")
kernel = _get_real_symmetric_2x2_eigvals_kernel(
sort=sort, abs_sort=abs_sort
)
eigs = cp.empty((2,) + M00.shape, dtype=M00.dtype)
kernel(M00, M01, M11, eigs[0], eigs[1])
return eigs


@cp.memoize()
def _get_real_symmetric_3x3_eigvals_kernel(sort='ascending', abs_sort=False):

operation = """
double x1, x2, phi;
double a = static_cast<double>(aa);
double b = static_cast<double>(bb);
double c = static_cast<double>(cc);
double d = static_cast<double>(dd);
double e = static_cast<double>(ee);
double f = static_cast<double>(ff);
double d_sq = d * d;
double e_sq = e * e;
double f_sq = f * f;
double tmpa = (2*a - b - c);
double tmpb = (2*b - a - c);
double tmpc = (2*c - a - b);
x2 = - tmpa * tmpb * tmpc;
x2 += 9 * (tmpc*d_sq + tmpb*f_sq + tmpa*e_sq);
x2 -= 54 * (d * e * f);
x1 = a*a + b*b + c*c - a*b - a*c - b*c + 3 * (d_sq + e_sq + f_sq);
if (x2 == 0.0) {
phi = M_PI / 2.0;
} else {
// grlee77: added max() here for numerical stability
// (avoid NaN values in test_hessian_matrix_eigvals_3d)
double arg = max(4*x1*x1*x1 - x2*x2, 0.0);
phi = atan(sqrt(arg)/x2);
if (x2 < 0) {
phi += M_PI;
}
}
double x1_term = (2.0 / 3.0) * sqrt(x1);
double abc = (a + b + c) / 3.0;
lam1 = abc - x1_term * cos(phi / 3.0);
lam2 = abc + x1_term * cos((phi - M_PI) / 3.0);
lam3 = abc + x1_term * cos((phi + M_PI) / 3.0);
"""
sort_template = """
F stmp;
if ({prefix}{var1} > {prefix}{var2}) {{
stmp = {var2};
{var2} = {var1};
{var1} = stmp;
}} if ({prefix}{var1} > {prefix}{var3}) {{
stmp = {var3};
{var3} = {var1};
{var1} = stmp;
}} if ({prefix}{var2} > {prefix}{var3}) {{
stmp = {var3};
{var3} = {var2};
{var2} = stmp;
}}
"""
tmp1 = M01 * M01
tmp1 *= 4
if abs_sort:
operation += """
F abs_lam1 = abs(lam1);
F abs_lam2 = abs(lam2);
F abs_lam3 = abs(lam3);
"""
prefix = "abs_"
else:
prefix = ""
if sort == 'ascending':
var1 = "lam1"
var3 = "lam3"
elif sort == 'descending':
var1 = "lam3"
var3 = "lam1"
operation += sort_template.format(
prefix=prefix, var1=var1, var2="lam2", var3=var3
)
return cp.ElementwiseKernel(
in_params="F aa, F bb, F cc, F dd, F ee, F ff",
out_params="F lam1, F lam2, F lam3",
operation=operation,
name="cucim_skimage_symmetric_eig33_kernel")

tmp2 = M00 - M11
tmp2 *= tmp2
tmp2 += tmp1
cp.sqrt(tmp2, out=tmp2)
tmp2 /= 2

tmp1 = M00 + M11
tmp1 /= 2
l1 = tmp1 + tmp2
l2 = tmp1 - tmp2
return l1, l2
def _image_orthogonal_matrix33_eigvals(
a, d, f, b, e, c, sort='descending', abs_sort=False
):
r"""Analytical expressions of the eigenvalues of a symmetric 3 x 3 matrix.
Follows the expressions given for hermitian symmetric 3 x 3 matrices in
[1]_, but simplified to handle real-valued matrices only.
We are computing moments at each voxel of the volume, so each of ``a``,
``d``, ``f``, ``b``, ``e``, and ``c`` will be equal in shape to the 3D
volume.
Invidual arguments correspond to the following moment matrix entries
.. math::
M = \begin{bmatrix}
a & d & f\\
d & b & e\\
f & e & c
\end{bmatrix}
def _symmetric_compute_eigenvalues(S_elems):
Parameters
----------
a, d, f, b, e, c : cp.ndarray
Images corresponding to the individual components of the matrix, `M`,
shown above. For example, ``d = M[0, 1]``.
sort : {"ascending", "descending"}, optional
Eigenvalues should be sorted in the specified order.
abs_sort : boolean, optional
If ``True``, sort based on the absolute values.
References
----------
.. [1] C. Deledalle, L. Denis, S. Tabti, F. Tupin. Closed-form expressions
of the eigen decomposition of 2 x 2 and 3 x 3 Hermitian matrices.
[Research Report] Université de Lyon. 2017.
https://hal.archives-ouvertes.fr/hal-01501221/file/matrix_exp_and_log_formula.pdf
""" # noqa
if a.dtype.kind != "f":
raise ValueError("expected real-valued floating point matrices")
kernel = _get_real_symmetric_3x3_eigvals_kernel(
sort=sort, abs_sort=abs_sort
)
eigs = cp.empty((3,) + a.shape, dtype=a.dtype)
kernel(a, b, c, d, e, f, eigs[0], eigs[1], eigs[2])
return eigs


def _symmetric_compute_eigenvalues(S_elems, sort='descending', abs_sort=False):
"""Compute eigenvalues from the upperdiagonal entries of a symmetric matrix
Parameters
----------
S_elems : list of ndarray
The upper-diagonal elements of the matrix, as returned by
`hessian_matrix` or `structure_tensor`.
sort : {"ascending", "descending"}, optional
Eigenvalues should be sorted in the specified order.
abs_sort : boolean, optional
If ``True``, sort based on the absolute values.
Returns
-------
Expand All @@ -445,14 +655,26 @@ def _symmetric_compute_eigenvalues(S_elems):
ith-largest eigenvalue at position (j, k).
"""

if len(S_elems) == 3: # Use fast Cython code for 2D
eigs = cp.stack(_image_orthogonal_matrix22_eigvals(*S_elems))
if len(S_elems) == 3: # Use fast analytical kernel for 2D
eigs = _image_orthogonal_matrix22_eigvals(
*S_elems, sort=sort, abs_sort=abs_sort
)
elif len(S_elems) == 6: # Use fast analytical kernel for 3D
eigs = _image_orthogonal_matrix33_eigvals(
*S_elems, sort=sort, abs_sort=abs_sort
)
else:
# n-dimensional case. warning: extremely memory inefficient!
matrices = _symmetric_image(S_elems)
# eigvalsh returns eigenvalues in increasing order. We want decreasing
eigs = cp.linalg.eigvalsh(matrices)[..., ::-1]
eigs = cp.linalg.eigvalsh(matrices)
leading_axes = tuple(range(eigs.ndim - 1))
eigs = cp.transpose(eigs, (eigs.ndim - 1,) + leading_axes)
if abs_sort:
# (sort by magnitude)
eigs = cp.take_along_axis(eigs, cp.abs(eigs).argsort(0), 0)
if sort == 'descending':
eigs = eigs[::-1, ...]
return eigs


Expand Down Expand Up @@ -521,18 +743,6 @@ def structure_tensor_eigenvalues(A_elems):
return _symmetric_compute_eigenvalues(A_elems)


"""
TODO: add an _image_symmetric_real33_eigvals() based on:
Oliver K. Smith. 1961.
Eigenvalues of a symmetric 3 × 3 matrix.
Commun. ACM 4, 4 (April 1961), 168.
DOI:https://doi.org/10.1145/355578.366316
def _image_symmetric_real33_eigvals(M00, M01, M02, M11, M12, M22):
"""


def hessian_matrix_eigvals(H_elems):
"""Compute eigenvalues of Hessian matrix.
Expand Down
26 changes: 26 additions & 0 deletions python/cucim/src/cucim/skimage/feature/tests/test_corner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
peak_local_max, shape_index,
structure_tensor,
structure_tensor_eigenvalues)
from cucim.skimage.feature.corner import _symmetric_image
from cucim.skimage.morphology import cube


Expand Down Expand Up @@ -252,6 +253,31 @@ def test_hessian_matrix_eigvals_3d(im3d, dtype):
assert np.max(response0) > 0


def _reference_eigvals_computation(S_elems):
"""Legacy eigenvalue implementation based on cp.linalg.eigvalsh."""
matrices = _symmetric_image(S_elems)
# eigvalsh returns eigenvalues in increasing order. We want decreasing
eigs = cp.linalg.eigvalsh(matrices)[..., ::-1]
leading_axes = tuple(range(eigs.ndim - 1))
eigs = cp.transpose(eigs, (eigs.ndim - 1,) + leading_axes)
return eigs


@pytest.mark.parametrize(
'shape', [(64, 64), (512, 1024), (8, 16, 24)]
)
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_custom_eigvals_kernels_vs_linalg_eigvalsh(shape, dtype):
rng = cp.random.default_rng(seed=5)
img = rng.integers(0, 256, shape)
H = hessian_matrix(img)
H = tuple(h.astype(dtype, copy=False) for h in H)
evs1 = _reference_eigvals_computation(H)
evs2 = hessian_matrix_eigvals(H)
atol = 1e-10
cp.testing.assert_allclose(evs1, evs2, atol=atol)


def test_hessian_matrix_det():
image = cp.zeros((5, 5))
image[2, 2] = 1
Expand Down
Loading

0 comments on commit 762d739

Please sign in to comment.