Skip to content

Commit

Permalink
Improve handling of h_center parameter in linop.CircularConvolve (#…
Browse files Browse the repository at this point in the history
…299)

* Support broader range of h_center values

* Add tests for h_center parameter
  • Loading branch information
bwohlberg authored May 17, 2022
1 parent 1d1154b commit 039489b
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 8 deletions.
30 changes: 22 additions & 8 deletions scico/linop/_circconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
import math
import operator
from functools import partial
from typing import Optional, Tuple
from typing import Optional, Sequence, Tuple, Union

import numpy as np

from jax.dtypes import result_type

from jaxlib.xla_extension import DeviceArray

import scico.numpy as snp
from scico._generic_operators import Operator
from scico.numpy.util import is_nested
Expand All @@ -27,10 +29,10 @@
class CircularConvolve(LinearOperator):
r"""A circular convolution linear operator.
This linear operator implements circular, n-dimensional convolution
via pointwise multiplication in the DFT domain. In its simplest form,
it implements a single convolution and can be represented by linear
operator :math:`H` such that
This linear operator implements circular, multi-dimensional
convolution via pointwise multiplication in the DFT domain. In its
simplest form, it implements a single convolution and can be
represented by linear operator :math:`H` such that
.. math::
H \mb{x} = \mb{h} \ast \mb{x} \;,
Expand Down Expand Up @@ -83,7 +85,7 @@ def __init__(
ndims: Optional[int] = None,
input_dtype: DType = snp.float32,
h_is_dft: bool = False,
h_center: Optional[JaxArray] = None,
h_center: Optional[Union[JaxArray, Sequence, float, int]] = None,
jit: bool = True,
**kwargs,
):
Expand All @@ -99,7 +101,8 @@ def __init__(
h_is_dft: Flag indicating whether `h` is in the DFT domain.
h_center: Array of length `ndims` specifying the center of
the filter. Defaults to the upper left corner, i.e.,
`h_center = [0, 0, ..., 0]`, may be noninteger.
`h_center = [0, 0, ..., 0]`, may be noninteger. May be a
``float`` or ``int`` if `h` is one-dimensional.
jit: If ``True``, jit the evaluation, adjoint, and gram
functions of the LinearOperator.
"""
Expand All @@ -124,7 +127,18 @@ def __init__(
output_dtype = result_type(h.dtype, input_dtype)

if self.h_center is not None:
offset = -self.h_center
if isinstance(self.h_center, DeviceArray):
offset = -self.h_center
else:
# support float or int values for h_center
if isinstance(self.h_center, (float, int)):
offset = -snp.array(
[
self.h_center,
]
)
else: # support list/tuple values for h_center
offset = -snp.array(self.h_center)
shifts: Tuple[Array, ...] = np.ix_(
*tuple(
np.exp(-1j * k * 2 * np.pi * np.fft.fftfreq(s))
Expand Down
21 changes: 21 additions & 0 deletions scico/test/linop/test_circconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,27 @@ def test_matches_convolve(self, input_dtype, jit):
desired = A @ x
np.testing.assert_allclose(actual, desired, atol=1e-6)

@pytest.mark.parametrize(
"center",
[
1,
[
1,
],
snp.array([2]),
],
)
def test_center(self, center):
x, key = uniform(minval=-1, maxval=1, shape=(16,), key=self.key)
h = snp.array([0.5, 1.0, 0.25])
A = CircularConvolve(h=h, input_shape=x.shape, h_center=center)
B = CircularConvolve(h=h, input_shape=x.shape)
if isinstance(center, int):
shift = -center
else:
shift = -center[0]
np.testing.assert_allclose(A @ x, snp.roll(B @ x, shift), atol=1e-5)

@pytest.mark.parametrize("axes_shape_spec", SHAPE_SPECS)
@pytest.mark.parametrize("input_dtype", [np.float32, np.complex64])
@pytest.mark.parametrize("jit_old_op", [True, False])
Expand Down

0 comments on commit 039489b

Please sign in to comment.