Skip to content

Commit

Permalink
Resolve #478 (#481)
Browse files Browse the repository at this point in the history
* Update link

* Add standard keyword arguments for operators and linear operators constructed from functions
  • Loading branch information
bwohlberg authored Dec 12, 2023
1 parent 55e4359 commit 98ebb89
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 3 deletions.
10 changes: 9 additions & 1 deletion scico/linop/_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ def linop_from_function(f: Callable, classname: str, f_name: Optional[str] = Non
implements complex-valued operations, this must be a
complex dtype (typically :attr:`~numpy.complex64`) for
correct adjoint and gradient calculation.
output_shape: Shape of output array. Defaults to ``None``.
If ``None``, `output_shape` is determined by evaluating
`self.__call__` on an input array of zeros.
output_dtype: `dtype` for output argument. Defaults to
``None``. If ``None``, `output_dtype` is determined by
evaluating `self.__call__` on an input array of zeros.
jit: If ``True``, call :meth:`~.LinearOperator.jit` on this
:class:`LinearOperator` to jit the forward, adjoint, and
gram functions. Same as calling
Expand All @@ -62,12 +68,14 @@ def __init__(
input_shape: Union[Shape, BlockShape],
*args: Any,
input_dtype: DType = snp.float32,
output_shape: Optional[Union[Shape, BlockShape]] = None,
output_dtype: Optional[DType] = None,
jit: bool = True,
**kwargs: Any,
):
self._eval = lambda x: f(x, *args, **kwargs)
self.kwargs = kwargs
super().__init__(input_shape, input_dtype=input_dtype, jit=jit) # type: ignore
super().__init__(input_shape, input_dtype=input_dtype, output_shape=output_shape, output_dtype=output_dtype, jit=jit) # type: ignore

OpClass = type(classname, (LinearOperator,), {"__init__": __init__})
__class__ = OpClass # needed for super() to work
Expand Down
10 changes: 9 additions & 1 deletion scico/operator/_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ def operator_from_function(f: Callable, classname: str, f_name: Optional[str] =
implements complex-valued operations, this must be a
complex dtype (typically :attr:`~numpy.complex64`) for
correct adjoint and gradient calculation.
output_shape: Shape of output array. Defaults to ``None``.
If ``None``, `output_shape` is determined by evaluating
`self.__call__` on an input array of zeros.
output_dtype: `dtype` for output argument. Defaults to
``None``. If ``None``, `output_dtype` is determined by
evaluating `self.__call__` on an input array of zeros.
jit: If ``True``, call :meth:`.Operator.jit` on this
`Operator` to jit the forward, adjoint, and gram
functions. Same as calling :meth:`.Operator.jit` after
Expand All @@ -65,11 +71,13 @@ def __init__(
input_shape: Union[Shape, BlockShape],
*args: Any,
input_dtype: DType = snp.float32,
output_shape: Optional[Union[Shape, BlockShape]] = None,
output_dtype: Optional[DType] = None,
jit: bool = True,
**kwargs: Any,
):
self._eval = lambda x: f(x, *args, **kwargs)
super().__init__(input_shape, input_dtype=input_dtype, jit=jit) # type: ignore
super().__init__(input_shape, input_dtype=input_dtype, output_shape=output_shape, output_dtype=output_dtype, jit=jit) # type: ignore

OpClass = type(classname, (Operator,), {"__init__": __init__})
__class__ = OpClass # needed for super() to work
Expand Down
2 changes: 1 addition & 1 deletion scico/operator/_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __repr__(self):
output_dtype : {self.output_dtype}
"""

# See https://docs.scipy.org/doc/numpy-1.10.1/user/c-info.beyond-basics.html#ndarray.__array_priority__
# See https://numpy.org/doc/stable/user/c-info.beyond-basics.html#ndarray.__array_priority__
__array_priority__ = 1

def __init__(
Expand Down
10 changes: 10 additions & 0 deletions scico/test/linop/test_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@ def test_transpose():
np.testing.assert_array_equal(H.T @ H @ x, x)


def test_transpose_ext_init():
shape = (1, 2, 3, 4)
perm = (1, 0, 3, 2)
x, _ = randn(shape)
H = linop.Transpose(
shape, perm, input_dtype=snp.float32, output_shape=shape, output_dtype=snp.float32
)
np.testing.assert_array_equal(H @ x, x.transpose(perm))


def test_reshape():
shape = (1, 2, 3, 4)
newshape = (2, 12)
Expand Down
10 changes: 10 additions & 0 deletions scico/test/operator/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,16 @@ def test_make_func_op():
np.testing.assert_array_equal(H(x), snp.abs(x))


def test_make_func_op_ext_init():
AbsVal = operator_from_function(snp.abs, "AbsVal")
shape = (2,)
x, _ = randn(shape, dtype=np.float32)
H = AbsVal(
input_shape=shape, output_shape=shape, input_dtype=np.float32, output_dtype=np.float32
)
np.testing.assert_array_equal(H(x), snp.abs(x))


class TestJacobianProdReal:
def setup_method(self):
N = 7
Expand Down

0 comments on commit 98ebb89

Please sign in to comment.