From 98ebb8905be6c671960664cdad8bbbfd7fc90c02 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 12 Dec 2023 09:51:25 -0700 Subject: [PATCH] Resolve #478 (#481) * Update link * Add standard keyword arguments for operators and linear operators constructed from functions --- scico/linop/_func.py | 10 +++++++++- scico/operator/_func.py | 10 +++++++++- scico/operator/_operator.py | 2 +- scico/test/linop/test_func.py | 10 ++++++++++ scico/test/operator/test_operator.py | 10 ++++++++++ 5 files changed, 39 insertions(+), 3 deletions(-) diff --git a/scico/linop/_func.py b/scico/linop/_func.py index 1b0b8e0d4..55f972c9c 100644 --- a/scico/linop/_func.py +++ b/scico/linop/_func.py @@ -49,6 +49,12 @@ def linop_from_function(f: Callable, classname: str, f_name: Optional[str] = Non implements complex-valued operations, this must be a complex dtype (typically :attr:`~numpy.complex64`) for correct adjoint and gradient calculation. + output_shape: Shape of output array. Defaults to ``None``. + If ``None``, `output_shape` is determined by evaluating + `self.__call__` on an input array of zeros. + output_dtype: `dtype` for output argument. Defaults to + ``None``. If ``None``, `output_dtype` is determined by + evaluating `self.__call__` on an input array of zeros. jit: If ``True``, call :meth:`~.LinearOperator.jit` on this :class:`LinearOperator` to jit the forward, adjoint, and gram functions. Same as calling @@ -62,12 +68,14 @@ def __init__( input_shape: Union[Shape, BlockShape], *args: Any, input_dtype: DType = snp.float32, + output_shape: Optional[Union[Shape, BlockShape]] = None, + output_dtype: Optional[DType] = None, jit: bool = True, **kwargs: Any, ): self._eval = lambda x: f(x, *args, **kwargs) self.kwargs = kwargs - super().__init__(input_shape, input_dtype=input_dtype, jit=jit) # type: ignore + super().__init__(input_shape, input_dtype=input_dtype, output_shape=output_shape, output_dtype=output_dtype, jit=jit) # type: ignore OpClass = type(classname, (LinearOperator,), {"__init__": __init__}) __class__ = OpClass # needed for super() to work diff --git a/scico/operator/_func.py b/scico/operator/_func.py index a863b1487..e1ab0eec5 100644 --- a/scico/operator/_func.py +++ b/scico/operator/_func.py @@ -53,6 +53,12 @@ def operator_from_function(f: Callable, classname: str, f_name: Optional[str] = implements complex-valued operations, this must be a complex dtype (typically :attr:`~numpy.complex64`) for correct adjoint and gradient calculation. + output_shape: Shape of output array. Defaults to ``None``. + If ``None``, `output_shape` is determined by evaluating + `self.__call__` on an input array of zeros. + output_dtype: `dtype` for output argument. Defaults to + ``None``. If ``None``, `output_dtype` is determined by + evaluating `self.__call__` on an input array of zeros. jit: If ``True``, call :meth:`.Operator.jit` on this `Operator` to jit the forward, adjoint, and gram functions. Same as calling :meth:`.Operator.jit` after @@ -65,11 +71,13 @@ def __init__( input_shape: Union[Shape, BlockShape], *args: Any, input_dtype: DType = snp.float32, + output_shape: Optional[Union[Shape, BlockShape]] = None, + output_dtype: Optional[DType] = None, jit: bool = True, **kwargs: Any, ): self._eval = lambda x: f(x, *args, **kwargs) - super().__init__(input_shape, input_dtype=input_dtype, jit=jit) # type: ignore + super().__init__(input_shape, input_dtype=input_dtype, output_shape=output_shape, output_dtype=output_dtype, jit=jit) # type: ignore OpClass = type(classname, (Operator,), {"__init__": __init__}) __class__ = OpClass # needed for super() to work diff --git a/scico/operator/_operator.py b/scico/operator/_operator.py index eabaf356b..eabdf47ee 100644 --- a/scico/operator/_operator.py +++ b/scico/operator/_operator.py @@ -66,7 +66,7 @@ def __repr__(self): output_dtype : {self.output_dtype} """ - # See https://docs.scipy.org/doc/numpy-1.10.1/user/c-info.beyond-basics.html#ndarray.__array_priority__ + # See https://numpy.org/doc/stable/user/c-info.beyond-basics.html#ndarray.__array_priority__ __array_priority__ = 1 def __init__( diff --git a/scico/test/linop/test_func.py b/scico/test/linop/test_func.py index ce6efca1b..2376ca2f5 100644 --- a/scico/test/linop/test_func.py +++ b/scico/test/linop/test_func.py @@ -19,6 +19,16 @@ def test_transpose(): np.testing.assert_array_equal(H.T @ H @ x, x) +def test_transpose_ext_init(): + shape = (1, 2, 3, 4) + perm = (1, 0, 3, 2) + x, _ = randn(shape) + H = linop.Transpose( + shape, perm, input_dtype=snp.float32, output_shape=shape, output_dtype=snp.float32 + ) + np.testing.assert_array_equal(H @ x, x.transpose(perm)) + + def test_reshape(): shape = (1, 2, 3, 4) newshape = (2, 12) diff --git a/scico/test/operator/test_operator.py b/scico/test/operator/test_operator.py index 1c86063ec..80ecc2082 100644 --- a/scico/test/operator/test_operator.py +++ b/scico/test/operator/test_operator.py @@ -233,6 +233,16 @@ def test_make_func_op(): np.testing.assert_array_equal(H(x), snp.abs(x)) +def test_make_func_op_ext_init(): + AbsVal = operator_from_function(snp.abs, "AbsVal") + shape = (2,) + x, _ = randn(shape, dtype=np.float32) + H = AbsVal( + input_shape=shape, output_shape=shape, input_dtype=np.float32, output_dtype=np.float32 + ) + np.testing.assert_array_equal(H(x), snp.abs(x)) + + class TestJacobianProdReal: def setup_method(self): N = 7