Skip to content

Commit

Permalink
improved implementation of ridge filters (bug fixes and reduced memor…
Browse files Browse the repository at this point in the history
…y footprint) (#423)

related to #419

A large overhaul of the ridge filters, addressing inaccuracies and errors has been implemented for scikit-image 0.20. This PR ports the same changes to these functions to cuCIM.

upstream PRs: 
 - scikit-image/scikit-image#6149
 - scikit-image/scikit-image#6440 
 - scikit-image/scikit-image#6446
 - scikit-image/scikit-image#6509

These fix various bugs, simplify the implementation and reduce the memory footprint

Authors:
  - Gregory Lee (https://github.com/grlee77)

Approvers:
  - Gigon Bae (https://github.com/gigony)

URL: #423
  • Loading branch information
grlee77 authored Nov 16, 2022
1 parent f085f29 commit cebd032
Show file tree
Hide file tree
Showing 4 changed files with 430 additions and 290 deletions.
139 changes: 132 additions & 7 deletions python/cucim/src/cucim/skimage/feature/corner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import math
from itertools import combinations_with_replacement
from warnings import warn

import cupy as cp
import numpy as np
Expand All @@ -10,7 +11,7 @@

# from ..transform import integral_image
from .._shared._gradient import gradient
from .._shared.utils import _supported_float_type
from .._shared.utils import _supported_float_type, warn
from .peak import peak_local_max
from .util import _prepare_grayscale_input_nD

Expand Down Expand Up @@ -153,8 +154,96 @@ def structure_tensor(image, sigma=1, mode="constant", cval=0, order=None):
return A_elems


def hessian_matrix(image, sigma=1, mode='constant', cval=0, order='rc'):
"""Compute the Hessian matrix.
def _hessian_matrix_with_gaussian(image, sigma=1, mode='reflect', cval=0,
order='rc'):
"""Compute the Hessian via convolutions with Gaussian derivatives.
In 2D, the Hessian matrix is defined as:
H = [Hrr Hrc]
[Hrc Hcc]
which is computed by convolving the image with the second derivatives
of the Gaussian kernel in the respective r- and c-directions.
The implementation here also supports n-dimensional data.
Parameters
----------
image : ndarray
Input image.
sigma : float or sequence of float, optional
Standard deviation used for the Gaussian kernel, which sets the
amount of smoothing in terms of pixel-distances. It is
advised to not choose a sigma much less than 1.0, otherwise
aliasing artifacts may occur.
mode : {'constant', 'reflect', 'wrap', 'nearest', 'mirror'}, optional
How to handle values outside the image borders.
cval : float, optional
Used in conjunction with mode 'constant', the value outside
the image boundaries.
order : {'rc', 'xy'}, optional
This parameter allows for the use of reverse or forward order of
the image axes in gradient computation. 'rc' indicates the use of
the first axis initially (Hrr, Hrc, Hcc), whilst 'xy' indicates the
usage of the last axis initially (Hxx, Hxy, Hyy)
Returns
-------
H_elems : list of ndarray
Upper-diagonal elements of the hessian matrix for each pixel in the
input image. In 2D, this will be a three element list containing [Hrr,
Hrc, Hcc]. In nD, the list will contain ``(n**2 + n) / 2`` arrays.
"""
image = img_as_float(image)
float_dtype = _supported_float_type(image.dtype)
image = image.astype(float_dtype, copy=False)

if np.isscalar(sigma):
sigma = (sigma,) * image.ndim

# This function uses `scipy.ndimage.gaussian_filter` with the order
# argument to compute convolutions. For example, specifying
# ``order=[1, 0]`` would apply convolution with a first-order derivative of
# the Gaussian along the first axis and simple Gaussian smoothing along the
# second.

# For small sigma, the SciPy Gaussian filter suffers from aliasing and edge
# artifacts, given that the filter will approximate a sinc or sinc
# derivative which only goes to 0 very slowly (order 1/n**2). Thus, we use
# a much larger truncate value to reduce any edge artifacts.
truncate = 8 if all(s > 1 for s in sigma) else 100
sq1_2 = 1 / math.sqrt(2)
sigma_scaled = tuple(sq1_2 * s for s in sigma)
common_kwargs = dict(sigma=sigma_scaled, mode=mode, cval=cval,
truncate=truncate)
gaussian_ = functools.partial(ndi.gaussian_filter, **common_kwargs)

# Apply two successive first order Gaussian derivative operations, as
# detailed in:
# https://dsp.stackexchange.com/questions/78280/are-scipy-second-order-gaussian-derivatives-correct # noqa

# 1.) First order along one axis while smoothing (order=0) along the other
ndim = image.ndim

# orders in 2D = ([1, 0], [0, 1])
# in 3D = ([1, 0, 0], [0, 1, 0], [0, 0, 1])
# etc.
orders = tuple([0] * d + [1] + [0] * (ndim - d - 1) for d in range(ndim))
gradients = [gaussian_(image, order=orders[d]) for d in range(ndim)]

# 2.) apply the derivative along another axis as well
axes = range(ndim)
if order == 'rc':
axes = reversed(axes)
H_elems = [gaussian_(gradients[ax0], order=orders[ax1])
for ax0, ax1 in combinations_with_replacement(axes, 2)]
return H_elems


def hessian_matrix(image, sigma=1, mode='constant', cval=0, order='rc',
use_gaussian_derivatives=None):
r"""Compute the Hessian matrix.
In 2D, the Hessian matrix is defined as::
Expand Down Expand Up @@ -183,6 +272,9 @@ def hessian_matrix(image, sigma=1, mode='constant', cval=0, order='rc'):
the image axes in gradient computation. 'rc' indicates the use of
the first axis initially (Hrr, Hrc, Hcc), whilst 'xy' indicates the
usage of the last axis initially (Hxx, Hxy, Hyy)
use_gaussian_derivatives : boolean, optional
Indicates whether the Hessian is computed by convolving with Gaussian
derivatives, or by a simple finite-difference operation.
Returns
-------
Expand All @@ -191,13 +283,32 @@ def hessian_matrix(image, sigma=1, mode='constant', cval=0, order='rc'):
input image. In 2D, this will be a three element list containing [Hrr,
Hrc, Hcc]. In nD, the list will contain ``(n**2 + n) / 2`` arrays.
Notes
-----
The distributive property of derivatives and convolutions allows us to
restate the derivative of an image, I, smoothed with a Gaussian kernel, G,
as the convolution of the image with the derivative of G.
.. math::
\frac{\partial }{\partial x_i}(I * G) =
I * \left( \frac{\partial }{\partial x_i} G \right)
When ``use_gaussian_derivatives`` is ``True``, this property is used to
compute the second order derivatives that make up the Hessian matrix.
When ``use_gaussian_derivatives`` is ``False``, simple finite differences
on a Gaussian-smoothed image are used instead.
Examples
--------
>>> import cupy as cp
>>> from cucim.skimage.feature import hessian_matrix
>>> square = cp.zeros((5, 5))
>>> square[2, 2] = 4
>>> Hrr, Hrc, Hcc = hessian_matrix(square, sigma=0.1, order='rc')
>>> Hrr, Hrc, Hcc = hessian_matrix(square, sigma=0.1, order='rc',
... use_gaussian_derivatives=False)
>>> Hrc
array([[ 0., 0., 0., 0., 0.],
[ 0., 1., 0., -1., 0.],
Expand All @@ -211,8 +322,20 @@ def hessian_matrix(image, sigma=1, mode='constant', cval=0, order='rc'):
float_dtype = _supported_float_type(image.dtype)
image = image.astype(float_dtype, copy=False)

if use_gaussian_derivatives is None:
use_gaussian_derivatives = False
warn("use_gaussian_derivatives currently defaults to False, but will "
"change to True in a future version. Please specify this "
"argument explicitly to maintain the current behavior",
category=FutureWarning, stacklevel=2)

if use_gaussian_derivatives:
return _hessian_matrix_with_gaussian(image, sigma=sigma, mode=mode,
cval=cval, order=order)

# Autodetection as done internally to Gaussian, but set it here to silence
# a warning.
# TODO: eventually remove this as this behavior of gaussian is deprecated
channel_axis = -1 if (image.ndim == 3 and image.shape[-1] == 3) else None

gaussian_filtered = gaussian(image, sigma=sigma, mode=mode, cval=cval,
Expand Down Expand Up @@ -433,7 +556,8 @@ def hessian_matrix_eigvals(H_elems):
... hessian_matrix_eigvals)
>>> square = cp.zeros((5, 5))
>>> square[2, 2] = 4
>>> H_elems = hessian_matrix(square, sigma=0.1, order='rc')
>>> H_elems = hessian_matrix(square, sigma=0.1, order='rc',
... use_gaussian_derivatives=False)
>>> hessian_matrix_eigvals(H_elems)[0]
array([[ 0., 0., 2., 0., 0.],
[ 0., 1., 0., 1., 0.],
Expand Down Expand Up @@ -510,7 +634,8 @@ def shape_index(image, sigma=1, mode="constant", cval=0):
[ nan, nan, -0.5, nan, nan]])
"""

H = hessian_matrix(image, sigma=sigma, mode=mode, cval=cval, order="rc")
H = hessian_matrix(image, sigma=sigma, mode=mode, cval=cval, order="rc",
use_gaussian_derivatives=False)
l1, l2 = hessian_matrix_eigvals(H)

# don't warn on divide by 0 as occurs in the docstring example
Expand Down
16 changes: 12 additions & 4 deletions python/cucim/src/cucim/skimage/feature/tests/test_corner.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ def test_structure_tensor_sigma(ndim):
def test_hessian_matrix(dtype):
square = cp.zeros((5, 5), dtype=dtype)
square[2, 2] = 4
Hrr, Hrc, Hcc = hessian_matrix(square, sigma=0.1, order="rc")
Hrr, Hrc, Hcc = hessian_matrix(square, sigma=0.1, order="rc",
use_gaussian_derivatives=False)
out_dtype = _supported_float_type(dtype)
assert all(a.dtype == out_dtype for a in (Hrr, Hrc, Hcc))
# fmt: off
Expand All @@ -147,11 +148,17 @@ def test_hessian_matrix(dtype):
[0, 0, 2, 0, 0]])) # noqa
# fmt: on

with expected_warnings(["use_gaussian_derivatives currently defaults"]):
# FutureWarning warning when use_gaussian_derivatives is not
# specified.
hessian_matrix(square, sigma=0.1, order="rc")


def test_hessian_matrix_3d():
cube = cp.zeros((5, 5, 5))
cube[2, 2, 2] = 4
Hs = hessian_matrix(cube, sigma=0.1, order='rc')
Hs = hessian_matrix(cube, sigma=0.1, order='rc',
use_gaussian_derivatives=False)
assert len(Hs) == 6, "incorrect number of Hessian images (%i) for 3D" % len(
Hs
)
Expand Down Expand Up @@ -199,7 +206,8 @@ def test_structure_tensor_eigenvalues_3d():
def test_hessian_matrix_eigvals(dtype):
square = cp.zeros((5, 5), dtype=dtype)
square[2, 2] = 4
H = hessian_matrix(square, sigma=0.1, order='rc')
H = hessian_matrix(square, sigma=0.1, order='rc',
use_gaussian_derivatives=False)
l1, l2 = hessian_matrix_eigvals(H)
out_dtype = _supported_float_type(dtype)
assert all(a.dtype == out_dtype for a in (l1, l2))
Expand All @@ -220,7 +228,7 @@ def test_hessian_matrix_eigvals(dtype):
@pytest.mark.parametrize('dtype', [cp.float16, cp.float32, cp.float64])
def test_hessian_matrix_eigvals_3d(im3d, dtype):
im3d = im3d.astype(dtype, copy=False)
H = hessian_matrix(im3d)
H = hessian_matrix(im3d, use_gaussian_derivatives=False)
E = hessian_matrix_eigvals(H)
E = cp.asnumpy(E)
out_dtype = _supported_float_type(dtype)
Expand Down
Loading

0 comments on commit cebd032

Please sign in to comment.