From 0101d986bd4760ec71ebc8b2be4e36249c466bfc Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 13 Jan 2022 16:35:31 -0700 Subject: [PATCH 01/49] Address mypy errors --- scico/diagnostics.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/scico/diagnostics.py b/scico/diagnostics.py index 8d81eb53c..1fb34af5b 100644 --- a/scico/diagnostics.py +++ b/scico/diagnostics.py @@ -70,23 +70,23 @@ def __init__( if not isinstance(fields, dict): raise TypeError("Parameter fields must be an instance of dict") # Subsampling rate of results that are to be displayed - self.period = period + self.period: int = period # Flag indicating whether to display and overwrite, or not display at all - self.overwrite = overwrite + self.overwrite: bool = overwrite # Number of spaces seperating fields in displayed tables - self.colsep = colsep + self.colsep: int = colsep # Main list of inserted values - self.iterations = [] + self.iterations: List = [] # Total length of header string in displayed tables - self.headlength = 0 + self.headlength: int = 0 # List of field names - self.fieldname = [] + self.fieldname: List[str] = [] # List of field format strings - self.fieldformat = [] + self.fieldformat: List[str] = [] # List of lengths of each field in displayed tables - self.fieldlength = [] + self.fieldlength: List[int] = [] # Names of fields in namedtuple used to record iteration values - self.tuplefields = [] + self.tuplefields: List[str] = [] # Compile regex for decomposing format strings fmre = re.compile(r"%(\+?-?)((?:\d+)?)(\.?)((?:\d+)?)([a-z])") # Iterate over field names @@ -133,7 +133,7 @@ def __init__( self.headlength -= colsep # Construct namedtuple used to record values - self.IterTuple = namedtuple("IterationStatsTuple", self.tuplefields) + self.IterTuple = namedtuple("IterationStatsTuple", self.tuplefields) # type: ignore # Set up table header string display if requested self.display = display From 1f3af7451fc2fa5d5fd27c5ca641ea48f8dd03c4 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 13 Jan 2022 16:46:05 -0700 Subject: [PATCH 02/49] Change action name --- .github/workflows/pytest.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 70b1b2242..6c03feeee 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -1,6 +1,6 @@ # Install scico requirements and run pytest -name: test +name: pytest # Controls when the workflow will run on: From 296c1f6264bb661aad0f3dc7c05d02e6cc2ce056 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 13 Jan 2022 16:56:47 -0700 Subject: [PATCH 03/49] Add mypy github action --- .github/workflows/mypy.yml | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 .github/workflows/mypy.yml diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml new file mode 100644 index 000000000..0b4510b4b --- /dev/null +++ b/.github/workflows/mypy.yml @@ -0,0 +1,36 @@ +# Install and run mypy + +name: mypy + +# Controls when the workflow will run +on: + # Triggers the workflow on push or pull request events but only for the main branch + push: + branches: [ main ] + pull_request: + branches: [ main ] + + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +jobs: + pytest: + # The type of runner that the job will run on + runs-on: ubuntu-latest + + # Steps represent a sequence of tasks that will be executed as part of the job + steps: + # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it + - uses: actions/checkout@v2 + with: + submodules: recursive + - name: Install Python 3 + uses: actions/setup-python@v1 + with: + python-version: 3.8 + - name: Install dependencies + run: | + pip install mypy + - name: Run mypy + run: | + mypy --ignore-missing-imports scico/ From 68f34066bc777d3185f327e41cdfcd52ef340758 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 13 Jan 2022 17:09:55 -0700 Subject: [PATCH 04/49] Fix typing errors --- scico/typing.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/scico/typing.py b/scico/typing.py index 15bc3ad50..1c2c7b3e9 100644 --- a/scico/typing.py +++ b/scico/typing.py @@ -9,6 +9,12 @@ from typing import Any, Tuple, Union +try: + # available in python 3.10 + from typing import EllipsisType # type: ignore +except ImportError: + EllipsisType = Any + import numpy as np import jax @@ -38,7 +44,7 @@ Axes = Union[int, Tuple[int, ...]] """Specification of one or more array axes.""" -AxisIndex = Union[slice, type(Ellipsis), int] +AxisIndex = Union[slice, EllipsisType, int] """An entity suitable for indexing/slicing of a single array axis; either a slice object, Ellipsis, or int.""" From 7c67fec3328f2aa84bb08ff87c0759ad970f03d1 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 13 Jan 2022 17:46:27 -0700 Subject: [PATCH 05/49] Fix typing errors --- scico/blockarray.py | 15 +++++++++------ scico/linop/_circconv.py | 6 +++--- scico/linop/_linop.py | 2 +- scico/typing.py | 2 +- 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/scico/blockarray.py b/scico/blockarray.py index 56efc15cb..7c7e691a1 100644 --- a/scico/blockarray.py +++ b/scico/blockarray.py @@ -496,7 +496,8 @@ def atleast_1d(*arys): # Append docstring from original jax.numpy function atleast_1d.__doc__ = ( - atleast_1d.__doc__.replace("\n ", "\n") # deal with indentation differences + # deal with indentation differences + atleast_1d.__doc__.replace("\n ", "\n") # type: ignore + "\nDocstring for :func:`jax.numpy.atleast_1d`:\n\n" + "\n".join(jax.numpy.atleast_1d.__doc__.split("\n")[2:]) ) @@ -584,7 +585,9 @@ def block_sizes(shape: Union[Shape, BlockShape]) -> Axes: return np.prod(shape) -def _decompose_index(idx: Union[int, Tuple(AxisIndex)]) -> Tuple: +def _decompose_index( + idx: Union[int, Tuple[AxisIndex, ...]] +) -> Tuple[int, Union[None, Tuple[AxisIndex, ...]]]: """Decompose a BlockArray indexing expression into components. Decompose a BlockArray indexing expression into block and array @@ -604,8 +607,8 @@ def _decompose_index(idx: Union[int, Tuple(AxisIndex)]) -> Tuple: TypeError: If the block index is not an integer. """ if isinstance(idx, tuple): - idxblk = idx[0] - idxarr = idx[1:] + idxblk: int = idx[0] + idxarr: Union[None, Tuple[AxisIndex, ...]] = idx[1:] else: idxblk = idx idxarr = None @@ -614,7 +617,7 @@ def _decompose_index(idx: Union[int, Tuple(AxisIndex)]) -> Tuple: return idxblk, idxarr -def indexed_shape(shape: Shape, idx: Union[int, Tuple(AxisIndex)]) -> Tuple[int]: +def indexed_shape(shape: Shape, idx: Union[int, Tuple[AxisIndex, ...]]) -> Tuple[int, ...]: """Determine the shape of the result of indexing a BlockArray. Args: @@ -866,7 +869,7 @@ def __init__(self, aval: _AbstractBlockArray, data: JaxArray): def __repr__(self): return "scico.blockarray.BlockArray: \n" + self._data.__repr__() - def __getitem__(self, idx: Union[int, Tuple(AxisIndex)]) -> JaxArray: + def __getitem__(self, idx: Union[int, Tuple[AxisIndex, ...]]) -> JaxArray: idxblk, idxarr = _decompose_index(idx) if idxblk < 0: idxblk = self.num_blocks + idxblk diff --git a/scico/linop/_circconv.py b/scico/linop/_circconv.py index 1d2df2f4f..e5377d26e 100644 --- a/scico/linop/_circconv.py +++ b/scico/linop/_circconv.py @@ -10,7 +10,7 @@ import math import operator from functools import partial -from typing import Optional +from typing import Optional, Tuple import numpy as np @@ -18,7 +18,7 @@ import scico.numpy as snp from scico._generic_operators import Operator -from scico.typing import DType, JaxArray, Shape +from scico.typing import Array, DType, JaxArray, Shape from ._linop import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar @@ -129,7 +129,7 @@ def __init__( if h_center is not None: offset = -self.h_center - shifts = np.ix_( + shifts: Tuple[Array, ...] = np.ix_( *tuple( np.exp(-1j * k * 2 * np.pi * np.fft.fftfreq(s)) for k, s in zip(offset, input_shape[-self.ndims :]) diff --git a/scico/linop/_linop.py b/scico/linop/_linop.py index f7d76b02d..61a4a93e6 100644 --- a/scico/linop/_linop.py +++ b/scico/linop/_linop.py @@ -286,7 +286,7 @@ def __init__( input_ndim = len(input_shape) sum_axis = util.parse_axes(sum_axis, shape=input_shape) - self.sum_axis: Tuple[int, ...] = sum_axis + self.sum_axis: Union[None, int, Tuple[int, ...]] = sum_axis super().__init__(input_shape=input_shape, input_dtype=input_dtype, jit=jit, **kwargs) def _eval(self, x: JaxArray) -> JaxArray: diff --git a/scico/typing.py b/scico/typing.py index 1c2c7b3e9..7f2f97d62 100644 --- a/scico/typing.py +++ b/scico/typing.py @@ -48,5 +48,5 @@ """An entity suitable for indexing/slicing of a single array axis; either a slice object, Ellipsis, or int.""" -ArrayIndex = Union[AxisIndex, Tuple[AxisIndex]] +ArrayIndex = Union[AxisIndex, Tuple[AxisIndex, ...]] """An entity suitable for indexing/slicing of multi-dimentional arrays.""" From 837214cb5ba21f1df8b74476ef72f5ccc298a0aa Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 13 Jan 2022 21:04:22 -0700 Subject: [PATCH 06/49] Fix typing errors --- scico/linop/_diff.py | 12 ++++++------ scico/linop/radon_astra.py | 4 ++-- scico/linop/radon_svmbir.py | 32 ++++++++++++++++---------------- 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/scico/linop/_diff.py b/scico/linop/_diff.py index 08d0c378a..1e639fd58 100644 --- a/scico/linop/_diff.py +++ b/scico/linop/_diff.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2020-2021 by SCICO Developers +# Copyright (C) 2020-2022 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -17,7 +17,7 @@ import numpy as np import scico.numpy as snp -from scico.typing import Axes, DType, JaxArray, Shape +from scico.typing import Axes, DType, JaxArray, Shape, Union from scico.util import parse_axes from ._linop import LinearOperator @@ -63,7 +63,7 @@ def __init__( operations, this must be `complex64` for proper adjoint and gradient calculation. axes: Axis or axes over which to apply finite difference - operator. If not specified, or `None`, differences are + operator. If not specified, or ``None``, differences are evaluated along all axes. append: Value to append to the input along each axis before taking differences. Zero is a typical choice. If not @@ -78,12 +78,12 @@ def __init__( self.axes = parse_axes(axes, input_shape) if axes is None: - axes_list = range(len(input_shape)) + axes_list: Union[range, list, tuple] = range(len(input_shape)) elif isinstance(axes, (list, tuple)): axes_list = axes else: axes_list = (axes,) - single_kwargs = dict(append=append, circular=circular, jit=False, input_dtype=input_dtype) + single_kwargs = dict(input_dtype=input_dtype, append=append, circular=circular, jit=False) ops = [FiniteDifferenceSingleAxis(axis, input_shape, **single_kwargs) for axis in axes_list] super().__init__( @@ -118,7 +118,7 @@ def __init__( taking differences. Defaults to 0. circular: If ``True``, perform circular differences, i.e., include x[-1] - x[0]. If ``True``, `append` must be - `None`. + ``None``. jit: If ``True``, jit the evaluation, adjoint, and gram functions of the LinearOperator. """ diff --git a/scico/linop/radon_astra.py b/scico/linop/radon_astra.py index f9f97af84..a64beef2b 100644 --- a/scico/linop/radon_astra.py +++ b/scico/linop/radon_astra.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2020-2021 by SCICO Developers +# Copyright (C) 2020-2022 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -97,7 +97,7 @@ def __init__( "for specifics." ) else: - self.vol_geom: dict = astra.create_vol_geom(*input_shape) + self.vol_geom = astra.create_vol_geom(*input_shape) dev0 = jax.devices()[0] if dev0.device_kind == "cpu" or device == "cpu": diff --git a/scico/linop/radon_svmbir.py b/scico/linop/radon_svmbir.py index 34b355233..252c94ab9 100644 --- a/scico/linop/radon_svmbir.py +++ b/scico/linop/radon_svmbir.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2021 by SCICO Developers +# Copyright (C) 2021-2022 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -11,7 +11,7 @@ `svmbir `_ package. """ -from typing import Optional +from typing import Optional, Tuple, Union import numpy as np @@ -20,7 +20,7 @@ import scico.numpy as snp from scico.loss import WeightedSquaredL2Loss -from scico.typing import JaxArray, Shape +from scico.typing import Array, Shape from ._linop import LinearOperator @@ -46,9 +46,9 @@ class ParallelBeamProjector(LinearOperator): def __init__( self, input_shape: Shape, - angles: np.ndarray, + angles: Array, num_channels: int, - is_masked: Optional[bool] = False, + is_masked: bool = False, ): """ Args: @@ -56,7 +56,7 @@ def __init__( angles: Array of projection angles in radians, should be increasing. num_channels: Number of pixels in the sinogram. - is_masked: If True, the valid region of the image is + is_masked: If ``True``, the valid region of the image is determined by a mask defined as the circle inscribed within the image boundary. Otherwise, the whole image array is taken into account by projections. @@ -66,7 +66,7 @@ def __init__( if len(input_shape) == 2: # 2D input self.svmbir_input_shape = (1,) + input_shape - output_shape = (len(angles), num_channels) + output_shape: Tuple[int, ...] = (len(angles), num_channels) self.svmbir_output_shape = output_shape[0:1] + (1,) + output_shape[1:2] elif len(input_shape) == 3: # 3D input self.svmbir_input_shape = input_shape @@ -100,8 +100,8 @@ def __init__( @staticmethod def _proj( - x: JaxArray, angles: JaxArray, num_channels: int, roi_radius: Optional[float] = None - ) -> JaxArray: + x: Array, angles: Array, num_channels: int, roi_radius: Optional[float] = None + ) -> Array: return svmbir.project( np.array(x), np.array(angles), num_channels, verbose=0, roi_radius=roi_radius ) @@ -118,8 +118,8 @@ def _proj_hcb(self, x): @staticmethod def _bproj( - y: JaxArray, - angles: JaxArray, + y: Array, + angles: Array, num_rows: int, num_cols: int, roi_radius: Optional[float] = None, @@ -180,7 +180,7 @@ def __init__( A: Forward operator. scale: Scaling parameter. W: Weighting diagonal operator. Must be non-negative. - If None, defaults to :class:`.Identity`. + If ``None``, defaults to :class:`.Identity`. prox_kwargs: Dictionary of arguments passed to the :meth:`svmbir.recon` prox routine. Note that omitting this dictionary will cause the default dictionary to be @@ -209,13 +209,13 @@ def __init__( self.positivity = positivity - def prox(self, v: JaxArray, lam: float, **kwargs) -> JaxArray: + def prox(self, v: Array, lam: float, **kwargs) -> Array: v = v.reshape(self.A.svmbir_input_shape) y = self.y.reshape(self.A.svmbir_output_shape) weights = self.W.diagonal.reshape(self.A.svmbir_output_shape) sigma_p = snp.sqrt(lam) if "v0" in kwargs and kwargs["v0"] is not None: - v0 = np.reshape(np.array(kwargs["v0"]), self.A.svmbir_input_shape) + v0: Union[float, Array] = np.reshape(np.array(kwargs["v0"]), self.A.svmbir_input_shape) else: v0 = 0.0 @@ -241,8 +241,8 @@ def prox(self, v: JaxArray, lam: float, **kwargs) -> JaxArray: return jax.device_put(result.reshape(self.A.input_shape)) -def _unsqueeze(x: JaxArray, input_shape: Shape) -> JaxArray: - """If x is 2D, make it 3D according to SVMBIR's convention.""" +def _unsqueeze(x: Array, input_shape: Shape) -> Array: + """If x is 2D, make it 3D according to the SVMBIR convention.""" if len(input_shape) == 2: x = x[snp.newaxis, :, :] return x From 9de67238c413dc5b97e724edf989ac74ea829f4e Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 13 Jan 2022 22:15:07 -0700 Subject: [PATCH 07/49] Correct job name --- .github/workflows/mypy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml index 0b4510b4b..100b68789 100644 --- a/.github/workflows/mypy.yml +++ b/.github/workflows/mypy.yml @@ -14,7 +14,7 @@ on: workflow_dispatch: jobs: - pytest: + mypy: # The type of runner that the job will run on runs-on: ubuntu-latest From 7cbea48e4484c284edb91cd3aef9e2c7d619d923 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 14 Jan 2022 18:44:55 -0700 Subject: [PATCH 08/49] Fix typing errors --- scico/array.py | 12 +++++------- scico/examples.py | 19 ++++++++++--------- scico/util.py | 12 ++++++------ 3 files changed, 21 insertions(+), 22 deletions(-) diff --git a/scico/array.py b/scico/array.py index a6410c16c..a3e91191f 100644 --- a/scico/array.py +++ b/scico/array.py @@ -19,9 +19,9 @@ from jax.interpreters.pxla import ShardedDeviceArray from jax.interpreters.xla import DeviceArray -import scico.blockarray import scico.numpy as snp -from scico.typing import ArrayIndex, Axes, AxisIndex, JaxArray, Shape +from scico.blockarray import BlockArray +from scico.typing import Array, ArrayIndex, Axes, AxisIndex, DType, JaxArray, Shape __author__ = """\n""".join( [ @@ -33,9 +33,7 @@ ) -def ensure_on_device( - *arrays: Union[np.ndarray, JaxArray, scico.blockarray.BlockArray] -) -> Union[JaxArray, scico.blockarray.BlockArray]: +def ensure_on_device(*arrays: Union[Array, BlockArray]) -> Union[JaxArray, BlockArray]: """Cast ndarrays to DeviceArrays. Cast ndarrays to DeviceArrays and leaves DeviceArrays, BlockArrays, @@ -70,7 +68,7 @@ def ensure_on_device( arrays[i] = jax.device_put(arrays[i]) elif not isinstance( array, - (DeviceArray, scico.blockarray.BlockArray, ShardedDeviceArray), + (DeviceArray, BlockArray, ShardedDeviceArray), ): raise TypeError( "Each item of `arrays` must be ndarray, DeviceArray, BlockArray, or " @@ -167,7 +165,7 @@ def slice_length(length: int, slc: AxisIndex) -> int: return (stop - start + stride - 1) // stride -def indexed_shape(shape: Shape, idx: ArrayIndex) -> Tuple[int]: +def indexed_shape(shape: Shape, idx: ArrayIndex) -> Tuple[int, ...]: """Determine the shape of an array after indexing/slicing. Args: diff --git a/scico/examples.py b/scico/examples.py index bab59e848..2f7f3ec02 100644 --- a/scico/examples.py +++ b/scico/examples.py @@ -12,6 +12,7 @@ import os import tempfile import zipfile +from typing import Tuple import numpy as np @@ -19,7 +20,7 @@ import scico.numpy as snp from scico import util -from scico.typing import JaxArray +from scico.typing import Array from scipy.ndimage import zoom __author__ = """\n""".join( @@ -27,7 +28,7 @@ ) -def volume_read(path: str, ext: str = "tif") -> JaxArray: +def volume_read(path: str, ext: str = "tif") -> Array: """Read a 3D volume from a set of files in the specified directory. All files with extension `ext` (i.e. matching glob `*.ext`) @@ -106,7 +107,7 @@ def get_epfl_deconv_data(channel: int, path: str, verbose: bool = False): # pra np.savez(npz_file, y=y, psf=psf) -def epfl_deconv_data(channel: int, verbose: bool = False, cache_path: str = None) -> JaxArray: +def epfl_deconv_data(channel: int, verbose: bool = False, cache_path: str = None) -> Array: """Get deconvolution problem data from EPFL Biomedical Imaging Group. If the data has previously been downloaded, it will be retrieved from @@ -144,7 +145,7 @@ def epfl_deconv_data(channel: int, verbose: bool = False, cache_path: str = None return y, psf -def downsample_volume(vol: JaxArray, rate: int) -> JaxArray: +def downsample_volume(vol: Array, rate: int) -> Array: """Downsample a 3D array. Downsample a 3D array. If the volume dimensions can be divided by @@ -173,7 +174,7 @@ def downsample_volume(vol: JaxArray, rate: int) -> JaxArray: return vol -def tile_volume_slices(x: JaxArray, sep_width: int = 10) -> JaxArray: +def tile_volume_slices(x: Array, sep_width: int = 10) -> Array: """Make an image with tiled slices from an input volume. Make an image with tiled `xy`, `xz`, and `yz` slices from an input @@ -190,7 +191,7 @@ def tile_volume_slices(x: JaxArray, sep_width: int = 10) -> JaxArray: """ if x.ndim == 3: - fshape = (x.shape[0], sep_width) + fshape: Tuple[int, ...] = (x.shape[0], sep_width) else: fshape = (x.shape[0], sep_width, 3) out = snp.concatenate( @@ -203,9 +204,9 @@ def tile_volume_slices(x: JaxArray, sep_width: int = 10) -> JaxArray: ) if x.ndim == 3: - fshape0 = (sep_width, out.shape[1]) - fshape1 = (x.shape[2], x.shape[2] + sep_width) - trans = (1, 0) + fshape0: Tuple[int, ...] = (sep_width, out.shape[1]) + fshape1: Tuple[int, ...] = (x.shape[2], x.shape[2] + sep_width) + trans: Tuple[int, ...] = (1, 0) else: fshape0 = (sep_width, out.shape[1], 3) diff --git a/scico/util.py b/scico/util.py index 5d14518c4..0ea1b0e49 100644 --- a/scico/util.py +++ b/scico/util.py @@ -16,7 +16,7 @@ import urllib.request as urlrequest from functools import wraps from timeit import default_timer as timer -from typing import Callable, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union import jax from jax.interpreters.batching import BatchTracer @@ -130,8 +130,8 @@ def __init__( """ # Initialise current and accumulated time dictionaries - self.t0 = {} - self.td = {} + self.t0: Dict[str, Optional[float]] = {} + self.td: Dict[str, float] = {} # Record default label and string indicating all labels self.default_label = default_label self.all_label = all_label @@ -195,7 +195,7 @@ def stop(self, labels: Optional[Union[str, List[str]]] = None): # All timers are affected if label is equal to self.all_label, # otherwise only the timer(s) specified by label if labels == self.all_label: - labels = self.t0.keys() + labels = list(self.t0.keys()) elif not isinstance(labels, (list, tuple)): labels = [ labels, @@ -230,7 +230,7 @@ def reset(self, labels: Optional[Union[str, List[str]]] = None): # All timers are affected if label is equal to self.all_label, # otherwise only the timer(s) specified by label if labels == self.all_label: - labels = self.t0.keys() + labels = list(self.t0.keys()) elif not isinstance(labels, (list, tuple)): labels = [ labels, @@ -291,7 +291,7 @@ def labels(self) -> List[str]: List of timer labels. """ - return self.t0.keys() + return list(self.t0.keys()) def __str__(self) -> str: """Return string representation of object. From 9f8667d81f4225ad1c062da9558e8a884d3ae63f Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 19 Jan 2022 13:31:09 -0700 Subject: [PATCH 09/49] Fix typing errors --- scico/_generic_operators.py | 18 +++++++++++------- scico/array.py | 32 ++++++++++++++++++-------------- scico/random.py | 2 +- scico/util.py | 6 +++--- 4 files changed, 33 insertions(+), 25 deletions(-) diff --git a/scico/_generic_operators.py b/scico/_generic_operators.py index 7cb059137..f784878a2 100644 --- a/scico/_generic_operators.py +++ b/scico/_generic_operators.py @@ -58,7 +58,7 @@ def wrapper(a, b): class Operator: - """Generic Operator class""" + """Generic Operator class.""" def __repr__(self): return f"""{type(self)} @@ -81,8 +81,7 @@ def __init__( jit: bool = False, is_smooth: bool = None, ): - r"""Operator init method. - + r""" Args: input_shape: Shape of input array. output_shape: Shape of output array. @@ -122,19 +121,22 @@ def __init__( #: Consists of (output_size, input_size) self.matrix_shape: Tuple[int, int] - #: Shape of Operator. Consists of (output_shape, input_shape). + #: Shape of Operator. Consists of (output_shape, input_shape). self.shape: Tuple[Union[Shape, BlockShape], Union[Shape, BlockShape]] #: Dtype of input self.input_dtype: DType + #: Dtype of operator + self.dtype: DType + if isinstance(input_shape, int): self.input_shape = (input_shape,) else: self.input_shape = input_shape self.input_dtype = input_dtype - # Allows for dynamic creation of new Operator/LinearOperator, eg for adjoints + # Allows for dynamic creation of new Operator/LinearOperator, eg. for adjoints if eval_fn: self._eval = eval_fn # type: ignore @@ -471,9 +473,9 @@ def __init__( ) if not hasattr(self, "_adj"): - self._adj = None + self._adj: Optional[Callable] = None if not hasattr(self, "_gram"): - self._gram = None + self._gram: Optional[Callable] = None if callable(adj_fn): self._adj = adj_fn self._gram = lambda x: self.adj(self(x)) @@ -621,6 +623,7 @@ def adj( f"""Shapes do not conform: input array with shape {y.shape} does not match LinearOperator output_shape {self.output_shape}""" ) + assert self._adj is not None return self._adj(y) @property @@ -731,6 +734,7 @@ def gram( """ if self._gram is None: self._set_adjoint() + assert self._gram is not None return self._gram(x) diff --git a/scico/array.py b/scico/array.py index a3e91191f..22acdb94a 100644 --- a/scico/array.py +++ b/scico/array.py @@ -33,7 +33,13 @@ ) -def ensure_on_device(*arrays: Union[Array, BlockArray]) -> Union[JaxArray, BlockArray]: +JaxOrBlockArray = Union[JaxArray, BlockArray] +"""A jax array or a BlockArray.""" + + +def ensure_on_device( + *arrays: Union[Array, BlockArray] +) -> Union[JaxOrBlockArray, List[JaxOrBlockArray]]: """Cast ndarrays to DeviceArrays. Cast ndarrays to DeviceArrays and leaves DeviceArrays, BlockArrays, @@ -53,36 +59,34 @@ def ensure_on_device(*arrays: Union[Array, BlockArray]) -> Union[JaxArray, Block TypeError: If the arrays contain something that is neither ndarray, DeviceArray, BlockArray, nor ShardedDeviceArray. """ - arrays = list(arrays) + array_list = list(arrays) - for i, array in enumerate(arrays): + for i, array in enumerate(array_list): if isinstance(array, np.ndarray): warnings.warn( - f"Argument {i+1} of {len(arrays)} is an np.ndarray. " + f"Argument {i+1} of {len(array_list)} is an np.ndarray. " f"Will cast it to DeviceArray. " - f"To suppress this warning cast all np.ndarrays to DeviceArray first.", + f"To suppress this warning cast all np.ndarray_list to DeviceArray first.", stacklevel=2, ) - arrays[i] = jax.device_put(arrays[i]) + array_list[i] = jax.device_put(array_list[i]) elif not isinstance( array, (DeviceArray, BlockArray, ShardedDeviceArray), ): raise TypeError( - "Each item of `arrays` must be ndarray, DeviceArray, BlockArray, or " - f"ShardedDeviceArray; Argument {i+1} of {len(arrays)} is {type(arrays[i])}." + "Each item of `array_list` must be ndarray, DeviceArray, BlockArray, or " + f"ShardedDeviceArray; Argument {i+1} of {len(array_list)} is {type(array_list[i])}." ) - if len(arrays) == 1: - return arrays[0] - return arrays + if len(array_list) == 1: + return array_list[0] + return array_list -def no_nan_divide( - x: Union[BlockArray, JaxArray], y: Union[BlockArray, JaxArray] -) -> Union[BlockArray, JaxArray]: +def no_nan_divide(x: JaxOrBlockArray, y: JaxOrBlockArray) -> JaxOrBlockArray: """Return `x/y`, with 0 instead of NaN where `y` is 0. Args: diff --git a/scico/random.py b/scico/random.py index 2e56b90a9..e629e2124 100644 --- a/scico/random.py +++ b/scico/random.py @@ -210,4 +210,4 @@ def randn( - **x** : (DeviceArray): Generated random array. - **key** : Updated random PRNGKey. """ - return normal(shape, dtype, key, seed) + return normal(shape, dtype, key, seed) # type: ignore diff --git a/scico/util.py b/scico/util.py index 0ea1b0e49..6a1cb6f05 100644 --- a/scico/util.py +++ b/scico/util.py @@ -209,7 +209,7 @@ def stop(self, labels: Optional[Union[str, List[str]]] = None): if self.t0[lbl] is not None: # Increment time accumulator from the elapsed time # since most recent start call - self.td[lbl] += t - self.t0[lbl] + self.td[lbl] += t - self.t0[lbl] # type: ignore # Set start time to None to indicate timer is not running self.t0[lbl] = None @@ -278,7 +278,7 @@ def elapsed(self, label: Optional[str] = None, total: bool = True) -> float: # return just the time since the current start call te = 0.0 if self.t0[label] is not None: - te = t - self.t0[label] + te = t - self.t0[label] # type: ignore if total: te += self.td[label] @@ -320,7 +320,7 @@ def __str__(self) -> str: if self.t0[lbl] is None: ts = " Stopped" else: - ts = f" {(t - self.t0[lbl]):.2e} s" % (t - self.t0[lbl]) + ts = f" {(t - self.t0[lbl]):.2e} s" % (t - self.t0[lbl]) # type: ignore s += f"{lbl:{lfldln}s} {td:.2e} s {ts}\n" return s From f81b40bf3c19f68473f526f255162559668c7f52 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 19 Jan 2022 13:44:26 -0700 Subject: [PATCH 10/49] Fix typing error --- scico/linop/optics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scico/linop/optics.py b/scico/linop/optics.py index 195869b0c..9c79ee06f 100644 --- a/scico/linop/optics.py +++ b/scico/linop/optics.py @@ -49,7 +49,7 @@ def radial_transverse_frequency( Returns: If `len(input_shape)==1`, returns an ndarray containing - corresponding Fourier coordinates. If `len(input_shape) == 2`, + corresponding Fourier coordinates. If `len(input_shape) == 2`, returns an ndarray containing the radial Fourier coordinates :math:`\sqrt{k_x^2 + k_y^2}\,`. """ @@ -59,7 +59,7 @@ def radial_transverse_frequency( raise ValueError("Invalid input dimensions; must be 1 or 2") if np.isscalar(dx): - dx = (dx,) * ndim + dx = (dx,) * ndim # type: ignore else: if len(dx) != ndim: raise ValueError( From 55ce35458114115f35e70e529cf0bb7ac5c679ea Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 27 Jan 2022 16:35:39 -0700 Subject: [PATCH 11/49] Fix merge error --- scico/linop/radon_svmbir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/linop/radon_svmbir.py b/scico/linop/radon_svmbir.py index 413e1567b..5b02120ce 100644 --- a/scico/linop/radon_svmbir.py +++ b/scico/linop/radon_svmbir.py @@ -20,7 +20,7 @@ import scico.numpy as snp from scico.loss import Loss, WeightedSquaredL2Loss -from scico.typing import Array, Shape +from scico.typing import Array, JaxArray, Shape from ._linop import Diagonal, Identity, LinearOperator From ec3fd939b18d0d8ebd1a18131f31296680ed01c6 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 1 Feb 2022 16:02:46 -0700 Subject: [PATCH 12/49] Address test failure --- scico/examples.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/examples.py b/scico/examples.py index 113893aff..d87868923 100644 --- a/scico/examples.py +++ b/scico/examples.py @@ -20,7 +20,7 @@ import scico.numpy as snp from scico import util -from scico.typing import Array +from scico.typing import Array, JaxArray from scipy.ndimage import zoom From 1ff915b4daa80a555ee6dc37af05c5b004d275bf Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 4 Feb 2022 17:17:20 -0700 Subject: [PATCH 13/49] Fix typing errors --- scico/optimize/_ladmm.py | 6 +++--- scico/optimize/_primaldual.py | 6 +++--- scico/optimize/admm.py | 9 +++++---- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/scico/optimize/_ladmm.py b/scico/optimize/_ladmm.py index 971d68a8a..bce3876f6 100644 --- a/scico/optimize/_ladmm.py +++ b/scico/optimize/_ladmm.py @@ -134,7 +134,7 @@ def __init__( "Time": "%8.2e", } itstat_attrib = ["itnum", "timer.elapsed()"] - # objective function can be evaluated if all 'g' functions can be evaluated + # objective function can be evaluated if 'g' function can be evaluated if g.has_eval: itstat_fields.update({"Objective": "%9.3e"}) itstat_attrib.append("objective()") @@ -144,7 +144,7 @@ def __init__( # dynamically create itstat_func; see https://stackoverflow.com/questions/24733831 itstat_return = "return(" + ", ".join(["obj." + attr for attr in itstat_attrib]) + ")" - scope = {} + scope: dict[str, Callable] = {} exec("def itstat_func(obj): " + itstat_return, scope) # determine itstat options and initialize IterationStats object @@ -155,7 +155,7 @@ def __init__( } if itstat_options: default_itstat_options.update(itstat_options) - self.itstat_insert_func = default_itstat_options.pop("itstat_func", None) + self.itstat_insert_func: Callable = default_itstat_options.pop("itstat_func", None) # type: ignore self.itstat_object = IterationStats(**default_itstat_options) if x0 is None: diff --git a/scico/optimize/_primaldual.py b/scico/optimize/_primaldual.py index 17dd93268..4f3ecd48c 100644 --- a/scico/optimize/_primaldual.py +++ b/scico/optimize/_primaldual.py @@ -140,7 +140,7 @@ def __init__( "Time": "%8.2e", } itstat_attrib = ["itnum", "timer.elapsed()"] - # objective function can be evaluated if all 'g' functions can be evaluated + # objective function can be evaluated if 'g' function can be evaluated if g.has_eval: itstat_fields.update({"Objective": "%9.3e"}) itstat_attrib.append("objective()") @@ -150,7 +150,7 @@ def __init__( # dynamically create itstat_func; see https://stackoverflow.com/questions/24733831 itstat_return = "return(" + ", ".join(["obj." + attr for attr in itstat_attrib]) + ")" - scope = {} + scope: dict[str, Callable] = {} exec("def itstat_func(obj): " + itstat_return, scope) # determine itstat options and initialize IterationStats object @@ -161,7 +161,7 @@ def __init__( } if itstat_options: default_itstat_options.update(itstat_options) - self.itstat_insert_func = default_itstat_options.pop("itstat_func", None) + self.itstat_insert_func: Callable = default_itstat_options.pop("itstat_func", None) # type: ignore self.itstat_object = IterationStats(**default_itstat_options) if x0 is None: diff --git a/scico/optimize/admm.py b/scico/optimize/admm.py index 61cc1448a..53ec88c9a 100644 --- a/scico/optimize/admm.py +++ b/scico/optimize/admm.py @@ -80,7 +80,7 @@ def __init__(self, minimize_kwargs: dict = {"options": {"maxiter": 100}}): :func:`scico.solver.minimize`. """ self.minimize_kwargs = minimize_kwargs - self.info = {} + self.info: dict = {} def solve(self, x0: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]: """Solve the ADMM step. @@ -484,7 +484,7 @@ def __init__( # dynamically create itstat_func; see https://stackoverflow.com/questions/24733831 itstat_return = "return(" + ", ".join(["obj." + attr for attr in itstat_attrib]) + ")" - scope = {} + scope: dict[str, Callable] = {} exec("def itstat_func(obj): " + itstat_return, scope) # determine itstat options and initialize IterationStats object @@ -495,7 +495,7 @@ def __init__( } if itstat_options: default_itstat_options.update(itstat_options) - self.itstat_insert_func = default_itstat_options.pop("itstat_func", None) + self.itstat_insert_func: Callable = default_itstat_options.pop("itstat_func", None) # type: ignore self.itstat_object = IterationStats(**default_itstat_options) if x0 is None: @@ -534,6 +534,7 @@ def objective( if x is None: x = self.x z_list = self.z_list + assert z_list is not None out = 0.0 if self.f: out += self.f(x) @@ -598,7 +599,7 @@ def z_init(self, x0: Union[JaxArray, BlockArray]): Args: x0: Initial value of :math:`\mb{x}`. """ - z_list = [Ci(x0) for Ci in self.C_list] + z_list: List[Union[JaxArray, BlockArray]] = [Ci(x0) for Ci in self.C_list] z_list_old = z_list.copy() return z_list, z_list_old From 0ca1c8812ab6927f37c6882ce41aebc9a9d0da94 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 4 Feb 2022 17:19:27 -0700 Subject: [PATCH 14/49] Fix typing errors and clean up itstat default mechanism --- scico/optimize/pgm.py | 50 ++++++++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/scico/optimize/pgm.py b/scico/optimize/pgm.py index 8be2cd539..ce5b0b826 100644 --- a/scico/optimize/pgm.py +++ b/scico/optimize/pgm.py @@ -180,8 +180,8 @@ def __init__(self, kappa: float = 0.5): self.kappa: float = kappa self.xprev: Union[JaxArray, BlockArray] = None self.gradprev: Union[JaxArray, BlockArray] = None - self.Lbb1prev: float = None - self.Lbb2prev: float = None + self.Lbb1prev: Optional[float] = None + self.Lbb2prev: Optional[float] = None def update(self, v: Union[JaxArray, BlockArray]) -> float: """Update the reciprocal of the step size. @@ -442,39 +442,41 @@ def x_step(v, L): self.x_step = jax.jit(x_step) + # iteration number and time fields + itstat_fields = { + "Iter": "%d", + "Time": "%8.2e", + } + itstat_attrib = ["itnum", "timer.elapsed()"] + # objective function can be evaluated if 'g' function can be evaluated if g.has_eval: - itstat_fields = { - "Iter": "%d", - "Time": "%8.2e", - "Objective": "%9.3e", - "L": "%9.3e", - "Residual": "%9.3e", - } - itstat_func = lambda pgm: ( - pgm.itnum, - pgm.timer.elapsed(), - pgm.objective(self.x), - pgm.L, - pgm.norm_residual(), - ) - else: - itstat_fields = {"Iter": "%d", "Time": "%8.2e", "Residual": "%9.3e"} - itstat_func = lambda pgm: (pgm.itnum, pgm.timer.elapsed(), pgm.norm_residual()) - - default_itstat_options = { + itstat_fields.update({"Objective": "%9.3e"}) + itstat_attrib.append("objective()") + # step size and residual fields + itstat_fields.update({"L": "%9.3e", "Residual": "%9.3e"}) + itstat_attrib.extend(["L", "norm_residual()"]) + + # dynamically create itstat_func; see https://stackoverflow.com/questions/24733831 + itstat_return = "return(" + ", ".join(["obj." + attr for attr in itstat_attrib]) + ")" + scope: dict[str, Callable] = {} + exec("def itstat_func(obj): " + itstat_return, scope) + + default_itstat_options: dict[str, Union[dict, Callable, bool]] = { "fields": itstat_fields, - "itstat_func": itstat_func, + "itstat_func": scope["itstat_func"], "display": False, } if itstat_options: default_itstat_options.update(itstat_options) - self.itstat_insert_func = default_itstat_options.pop("itstat_func", None) + self.itstat_insert_func: Callable = default_itstat_options.pop("itstat_func") # type: ignore self.itstat_object = IterationStats(**default_itstat_options) self.x: Union[JaxArray, BlockArray] = ensure_on_device(x0) # current estimate of solution - def objective(self, x) -> float: + def objective(self, x=None) -> float: r"""Evaluate the objective function :math:`f(\mb{x}) + g(\mb{x})`.""" + if x is None: + x = self.x return self.f(x) + self.g(x) def f_quad_approx(self, x, y, L) -> float: From 06b66134ac6a7c59d4ab93fb07df7a11fc4fef56 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 4 Feb 2022 19:11:06 -0700 Subject: [PATCH 15/49] Fix a bug and some typing errors --- scico/functional/_functional.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/scico/functional/_functional.py b/scico/functional/_functional.py index 7b7206e3a..807f3905a 100644 --- a/scico/functional/_functional.py +++ b/scico/functional/_functional.py @@ -59,10 +59,8 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float: x: Point at which to evaluate this functional. """ - if not self.has_eval: - raise NotImplementedError( - f"Functional {type(self)} cannot be evaluated; has_eval={self.has_eval}" - ) + # Functionals that can be evaluated should override this method. + raise NotImplementedError(f"Functional {type(self)} cannot be evaluated.") def prox( self, v: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs @@ -89,10 +87,8 @@ def prox( minimizer in the defintion of :math:`\mathrm{prox}`. """ - if not self.has_prox: - raise NotImplementedError( - f"Functional {type(self)} does not have a prox; has_prox={self.has_prox}" - ) + # Functionals that have a prox should override this method. + raise NotImplementedError(f"Functional {type(self)} does not have a prox.") def conj_prox( self, v: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs @@ -230,7 +226,7 @@ def prox(self, v: BlockArray, lam: float = 1.0, **kwargs) -> BlockArray: if len(v.shape) == len(self.functional_list): return BlockArray.array([fi.prox(vi, lam) for fi, vi in zip(self.functional_list, v)]) raise ValueError( - f"Number of blocks in x, {len(x.shape)}, and length of functional_list, " + f"Number of blocks in v, {len(v.shape)}, and length of functional_list, " f"{len(self.functional_list)}, do not match" ) From b5379273efef70d7eaab254ca8307e378d313216 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Sat, 5 Feb 2022 08:20:25 -0700 Subject: [PATCH 16/49] Exclude modules with dynamically generated functions --- .github/workflows/mypy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml index 100b68789..b87266c23 100644 --- a/.github/workflows/mypy.yml +++ b/.github/workflows/mypy.yml @@ -33,4 +33,4 @@ jobs: pip install mypy - name: Run mypy run: | - mypy --ignore-missing-imports scico/ + mypy --ignore-missing-imports --exclude numpy scico --exclude scipy scico/ From 36cfab2bd2838c0784f678956d1a354883f7c816 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Sat, 5 Feb 2022 08:21:24 -0700 Subject: [PATCH 17/49] Make docstring phrasing imperative --- scico/numpy/_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/numpy/_util.py b/scico/numpy/_util.py index af38cdb1e..3b76eb271 100644 --- a/scico/numpy/_util.py +++ b/scico/numpy/_util.py @@ -65,7 +65,7 @@ def _attach_wrapped_func(funclist, wrapper, module_name, fix_mod_name=False): def _get_module_functions(module): - """Finds functions in module. + """Find functions in module. This function is a slightly modified version of :func:`jax._src.util.get_module_functions`. Unlike the JAX version, From acda6d8f880215d1d2ce557b6ac9a2c3976e699f Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Sat, 5 Feb 2022 08:22:43 -0700 Subject: [PATCH 18/49] Suppress typing errors --- scico/numpy/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scico/numpy/__init__.py b/scico/numpy/__init__.py index c296b23f5..276826082 100644 --- a/scico/numpy/__init__.py +++ b/scico/numpy/__init__.py @@ -144,8 +144,8 @@ def vdot(a, b): ) # divide is just an alias to true_divide -divide = true_divide -conj = conjugate +divide = true_divide # type: ignore +conj = conjugate # type: ignore # Find functions that exist in jax.numpy but not scico.numpy # see jax.numpy.__init__.py From 69103486c9b32566d3a97bba0fd972c9397e114b Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Sat, 5 Feb 2022 08:24:36 -0700 Subject: [PATCH 19/49] Fix typing error and clean up --- scico/linop/_circconv.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/scico/linop/_circconv.py b/scico/linop/_circconv.py index e5377d26e..e5b5251c9 100644 --- a/scico/linop/_circconv.py +++ b/scico/linop/_circconv.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2021 by SCICO Developers +# Copyright (C) 2022 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -22,14 +22,6 @@ from ._linop import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar -__author__ = """\n""".join( - [ - "Brendt Wohlberg ", - "Luke Pfister ", - "Michael McCann ", - ] -) - class CircularConvolve(LinearOperator): r"""A circular convolution linear operator. @@ -127,7 +119,7 @@ def __init__( self.h_dft = snp.fft.fftn(h, s=fft_shape, axes=fft_axes) output_dtype = result_type(h.dtype, input_dtype) - if h_center is not None: + if self.h_center is not None: offset = -self.h_center shifts: Tuple[Array, ...] = np.ix_( *tuple( From 665fdc50459aa136e76f2b5f197dc184c11098b1 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Sat, 5 Feb 2022 08:26:39 -0700 Subject: [PATCH 20/49] Fix typing errors --- scico/linop/optics.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/scico/linop/optics.py b/scico/linop/optics.py index f5c4f0f9a..476de908c 100644 --- a/scico/linop/optics.py +++ b/scico/linop/optics.py @@ -61,11 +61,13 @@ def radial_transverse_frequency( if np.isscalar(dx): dx = (dx,) * ndim # type: ignore else: + assert isinstance(dx, tuple) if len(dx) != ndim: raise ValueError( "dx must be a scalar or have len(dx) == len(input_shape); " f"got len(dx)={len(dx)}, len(input_shape)={ndim}" ) + assert isinstance(dx, tuple) if ndim == 1: kx = 2 * np.pi * np.fft.fftfreq(input_shape[0], dx[0]) @@ -107,13 +109,15 @@ def __init__( raise ValueError("Invalid input dimensions; must be 1 or 2") if np.isscalar(dx): - dx = (dx,) * ndim + dx = (dx,) * ndim # type: ignore else: + assert isinstance(dx, tuple) if len(dx) != ndim: raise ValueError( "dx must be a scalar or have len(dx) == len(input_shape); " f"got len(dx)={len(dx)}, len(input_shape)={ndim}" ) + assert isinstance(dx, tuple) #: Illumination wavenumber; 2π/wavelength self.k0: float = k0 @@ -435,13 +439,15 @@ def __init__( raise ValueError("Invalid input dimensions; must be 1 or 2") if np.isscalar(dx): - dx = (dx,) * ndim + dx = (dx,) * ndim # type: ignore else: + assert isinstance(dx, tuple) if len(dx) != ndim: raise ValueError( "dx must be a scalar or have len(dx) == len(input_shape); " f"got len(dx)={len(dx)}, len(input_shape)={ndim}" ) + assert isinstance(dx, tuple) L: Tuple[float, ...] = tuple(s * d for s, d in zip(input_shape, dx)) @@ -458,7 +464,7 @@ def __init__( self.dx_D: Tuple[float, ...] = tuple(np.abs(2 * np.pi * z / (k0 * l)) for l in L) #: Destination plane side length self.L_D: Tuple[float, ...] = tuple(np.abs(2 * np.pi * z / (k0 * d)) for d in dx) - x_D = tuple(np.r_[-l / 2 : l / 2 : d] for l, d in zip(self.L_D, self.dx_D)) + x_D = tuple(np.r_[int(-l / 2) : int(l / 2) : int(d)] for l, d in zip(self.L_D, self.dx_D)) # set up radial coordinate system; either x^2 or (x^2 + y^2) if ndim == 1: From cf43c1405f33bf73ec10ed326215121a6c4316ec Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Sat, 5 Feb 2022 08:29:32 -0700 Subject: [PATCH 21/49] Fix or suppress typing errors --- scico/functional/_denoiser.py | 4 ++-- scico/functional/_flax.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/scico/functional/_denoiser.py b/scico/functional/_denoiser.py index 981da0e82..e3b750711 100644 --- a/scico/functional/_denoiser.py +++ b/scico/functional/_denoiser.py @@ -56,7 +56,7 @@ def __init__(self, is_rgb: bool = False): super().__init__() - def prox(self, x: JaxArray, lam: float = 1.0, **kwargs) -> JaxArray: + def prox(self, x: JaxArray, lam: float = 1.0, **kwargs) -> JaxArray: # type: ignore r"""Apply BM3D denoiser. Args: @@ -147,7 +147,7 @@ def __init__(self, variant: Optional[str] = "6M"): variables = load_weights(_flax_data_path("dncnn%s.npz" % variant)) super().__init__(model, variables) - def prox(self, x: JaxArray, lam: float = 1, **kwargs) -> JaxArray: + def prox(self, x: JaxArray, lam: float = 1, **kwargs) -> JaxArray: # type: ignore r"""Apply DnCNN denoiser. *Warning*: The `lam` parameter is ignored, and has no effect on diff --git a/scico/functional/_flax.py b/scico/functional/_flax.py index 82e36e825..c5d8cef16 100644 --- a/scico/functional/_flax.py +++ b/scico/functional/_flax.py @@ -7,7 +7,7 @@ """Evaluate NN models implemented in flax.""" -from typing import Any, Callable +from typing import Any from flax import linen as nn from scico.blockarray import BlockArray @@ -24,7 +24,7 @@ class FlaxMap(Functional): has_eval = False has_prox = True - def __init__(self, model: Callable[..., nn.Module], variables: PyTree): + def __init__(self, model: nn.Module, variables: PyTree): r"""Initialize a :class:`FlaxMap` object. Args: @@ -35,7 +35,7 @@ def __init__(self, model: Callable[..., nn.Module], variables: PyTree): self.variables = variables super().__init__() - def prox(self, x: JaxArray, lam: float = 1.0, **kwargs) -> JaxArray: + def prox(self, x: JaxArray, lam: float = 1.0, **kwargs) -> JaxArray: # type: ignore r"""Apply trained flax model. *Warning*: The ``lam`` parameter is ignored, and has no effect on From 8d2b0ccb3f20bea4b9c82274503ebdaa1825b71e Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Sat, 5 Feb 2022 08:32:57 -0700 Subject: [PATCH 22/49] Typo fix --- .github/workflows/mypy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml index b87266c23..d632a19b1 100644 --- a/.github/workflows/mypy.yml +++ b/.github/workflows/mypy.yml @@ -33,4 +33,4 @@ jobs: pip install mypy - name: Run mypy run: | - mypy --ignore-missing-imports --exclude numpy scico --exclude scipy scico/ + mypy --ignore-missing-imports --exclude numpy --exclude scipy scico/ From 67e6b248def8fe635741091f483e3bdbdb171a4d Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Mon, 7 Feb 2022 19:37:27 -0700 Subject: [PATCH 23/49] Revert erroneous attempt to resolve typing error --- scico/linop/optics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/linop/optics.py b/scico/linop/optics.py index 476de908c..48020a0aa 100644 --- a/scico/linop/optics.py +++ b/scico/linop/optics.py @@ -464,7 +464,7 @@ def __init__( self.dx_D: Tuple[float, ...] = tuple(np.abs(2 * np.pi * z / (k0 * l)) for l in L) #: Destination plane side length self.L_D: Tuple[float, ...] = tuple(np.abs(2 * np.pi * z / (k0 * d)) for d in dx) - x_D = tuple(np.r_[int(-l / 2) : int(l / 2) : int(d)] for l, d in zip(self.L_D, self.dx_D)) + x_D = tuple(np.r_[-l / 2 : l / 2 : d] for l, d in zip(self.L_D, self.dx_D)) # type: ignore # set up radial coordinate system; either x^2 or (x^2 + y^2) if ndim == 1: From 58ad1dd60a7bdc94882297f57c821769b18b1376 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 8 Feb 2022 09:49:53 -0700 Subject: [PATCH 24/49] Typing annotation fix and suppress some spurious typing errors --- scico/loss.py | 12 ++++++------ scico/optimize/admm.py | 22 +++++++++++----------- scico/optimize/pgm.py | 2 +- scico/solver.py | 2 +- 4 files changed, 19 insertions(+), 19 deletions(-) diff --git a/scico/loss.py b/scico/loss.py index ba31a8732..b7fc742d4 100644 --- a/scico/loss.py +++ b/scico/loss.py @@ -131,10 +131,10 @@ def __init__( Args: y: Measurement. - A: Forward operator. If None, defaults to :class:`.Identity`. + A: Forward operator. If ``None``, defaults to :class:`.Identity`. scale: Scaling parameter. - W: Weighting diagonal operator. Must be non-negative. - If None, defaults to :class:`.Identity`. + W: Weighting diagonal operator. Must be non-negative. + If ``None``, defaults to :class:`.Identity`. """ y = ensure_on_device(y) @@ -210,8 +210,8 @@ def hessian(self) -> linop.LinearOperator: return linop.LinearOperator( input_shape=A.input_shape, output_shape=A.input_shape, - eval_fn=lambda x: 2 * self.scale * A.adj(W(A(x))), - adj_fn=lambda x: 2 * self.scale * A.adj(W(A(x))), + eval_fn=lambda x: 2 * self.scale * A.adj(W(A(x))), # type: ignore + adj_fn=lambda x: 2 * self.scale * A.adj(W(A(x))), # type: ignore ) raise NotImplementedError( @@ -278,7 +278,7 @@ def __init__( super().__init__(y=y, A=A, scale=scale) #: Constant term in Poisson log likehood; equal to ln(y!) - self.const: float = gammaln(self.y + 1) # ln(y!) + self.const = gammaln(self.y + 1.0) # type: ignore def __call__(self, x: Union[JaxArray, BlockArray]) -> float: Ax = self.A(x) diff --git a/scico/optimize/admm.py b/scico/optimize/admm.py index 53ec88c9a..dca8e6fbe 100644 --- a/scico/optimize/admm.py +++ b/scico/optimize/admm.py @@ -12,7 +12,7 @@ from __future__ import annotations from functools import reduce -from typing import Callable, List, Optional, Union +from typing import Any, Callable, List, Optional, Union import jax from jax.scipy.sparse.linalg import cg as jax_cg @@ -157,7 +157,7 @@ class LinearSubproblemSolver(SubproblemSolver): :math:`\mb{x}` update step. """ - def __init__(self, cg_kwargs: Optional[dict] = None, cg_function: str = "scico"): + def __init__(self, cg_kwargs: Optional[dict[str, Any]] = None, cg_function: str = "scico"): """Initialize a :class:`LinearSubproblemSolver` object. Args: @@ -196,16 +196,16 @@ def __init__(self, cg_kwargs: Optional[dict] = None, cg_function: str = "scico") def internal_init(self, admm): if admm.f is not None: - if not isinstance(admm.f.A, LinearOperator): - raise ValueError( - f"LinearSubproblemSolver requires f.A to be a scico.linop.LinearOperator; " - f"got {type(admm.f.A)}" - ) if not isinstance(admm.f, WeightedSquaredL2Loss): # SquaredL2Loss is subclass raise ValueError( f"LinearSubproblemSolver requires f to be a scico.loss.WeightedSquaredL2Loss" f"or scico.loss.SquaredL2Loss; got {type(admm.f)}" ) + if not isinstance(admm.f.A, LinearOperator): + raise ValueError( + f"LinearSubproblemSolver requires f.A to be a scico.linop.LinearOperator; " + f"got {type(admm.f.A)}" + ) super().internal_init(admm) @@ -239,8 +239,8 @@ def compute_rhs(self) -> Union[JaxArray, BlockArray]: rhs = snp.zeros(C0.input_shape, C0.input_dtype) if self.admm.f is not None: - ATWy = self.admm.f.A.adj(self.admm.f.W.diagonal * self.admm.f.y) - rhs += 2.0 * self.admm.f.scale * ATWy + ATWy = self.admm.f.A.adj(self.admm.f.W.diagonal * self.admm.f.y) # type: ignore + rhs += 2.0 * self.admm.f.scale * ATWy # type: ignore for rhoi, Ci, zi, ui in zip( self.admm.rho_list, self.admm.C_list, self.admm.z_list, self.admm.u_list @@ -259,7 +259,7 @@ def solve(self, x0: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]: """ x0 = ensure_on_device(x0) rhs = self.compute_rhs() - x, self.info = self.cg(self.lhs_op, rhs, x0, **self.cg_kwargs) + x, self.info = self.cg(self.lhs_op, rhs, x0, **self.cg_kwargs) # type: ignore return x @@ -496,7 +496,7 @@ def __init__( if itstat_options: default_itstat_options.update(itstat_options) self.itstat_insert_func: Callable = default_itstat_options.pop("itstat_func", None) # type: ignore - self.itstat_object = IterationStats(**default_itstat_options) + self.itstat_object = IterationStats(**default_itstat_options) # type: ignore if x0 is None: input_shape = C_list[0].input_shape diff --git a/scico/optimize/pgm.py b/scico/optimize/pgm.py index ce5b0b826..8077f7f0b 100644 --- a/scico/optimize/pgm.py +++ b/scico/optimize/pgm.py @@ -469,7 +469,7 @@ def x_step(v, L): if itstat_options: default_itstat_options.update(itstat_options) self.itstat_insert_func: Callable = default_itstat_options.pop("itstat_func") # type: ignore - self.itstat_object = IterationStats(**default_itstat_options) + self.itstat_object = IterationStats(**default_itstat_options) # type: ignore self.x: Union[JaxArray, BlockArray] = ensure_on_device(x0) # current estimate of solution diff --git a/scico/solver.py b/scico/solver.py index 255af9b11..5d02081cb 100644 --- a/scico/solver.py +++ b/scico/solver.py @@ -298,7 +298,7 @@ def cg( maxiter: int = 1000, info: bool = False, M: Optional[Callable] = None, -) -> Union[JaxArray, dict]: +) -> Tuple[JaxArray, dict]: r"""Conjugate Gradient solver. Solve the linear system :math:`A\mb{x} = \mb{b}`, where :math:`A` is From 08f078a505c2a44841262556c42e8f7356ea116d Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 8 Feb 2022 18:20:35 -0700 Subject: [PATCH 25/49] Address typing error and rephrase error messages --- scico/linop/radon_svmbir.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/scico/linop/radon_svmbir.py b/scico/linop/radon_svmbir.py index 5b02120ce..240ae2786 100644 --- a/scico/linop/radon_svmbir.py +++ b/scico/linop/radon_svmbir.py @@ -226,7 +226,7 @@ def __init__( super().__init__(*args, **kwargs) if not isinstance(self.A, ParallelBeamProjector): - raise ValueError("LinearOperator A must be a radon_svmbir.ParallelBeamProjector.") + raise ValueError("A must be a ParallelBeamProjector.") self.has_prox = True @@ -331,10 +331,8 @@ def __init__( """ super().__init__(*args, **kwargs, prox_kwargs=prox_kwargs, positivity=False) - if self.A.is_masked: - raise ValueError( - "is_masked must be false for the ParallelBeamProjector in SVMBIRWeightedSquaredL2Loss." - ) + if not isinstance(self.A, ParallelBeamProjector) or self.A.is_masked: + raise ValueError("A must be a ParallelBeamProjector with is_masked set to False.") def _unsqueeze(x: JaxArray, input_shape: Shape) -> JaxArray: From 9bfeb3c59b851851531f1353fbf85e803ffea5c9 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 11 Feb 2022 17:02:41 -0700 Subject: [PATCH 26/49] Fix some typing errors --- scico/array.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scico/array.py b/scico/array.py index b4e4575da..984af9a6e 100644 --- a/scico/array.py +++ b/scico/array.py @@ -92,7 +92,7 @@ def no_nan_divide(x: JaxOrBlockArray, y: JaxOrBlockArray) -> JaxOrBlockArray: def parse_axes( axes: Axes, shape: Optional[Shape] = None, default: Optional[List[int]] = None -) -> List[int]: +) -> Tuple[int, ...]: """Normalize `axes` to a list and optionally ensure correctness. Normalize `axes` to a list and (optionally) ensure that entries refer @@ -183,14 +183,14 @@ def indexed_shape(shape: Shape, idx: ArrayIndex) -> Tuple[int, ...]: idx = (idx,) if len(idx) > len(shape): raise ValueError(f"Slice {idx} has more dimensions than shape {shape}.") - idx_shape = list(shape) + idx_shape: List[Optional[int]] = list(shape) offset = 0 for axis, ax_idx in enumerate(idx): if ax_idx is Ellipsis: offset = len(shape) - len(idx) continue idx_shape[axis + offset] = slice_length(shape[axis + offset], ax_idx) - return tuple(filter(lambda x: x is not None, idx_shape)) + return tuple(filter(lambda x: x is not None, idx_shape)) # type: ignore def is_nested(x: Any) -> bool: From ebc439225a34eccb6880d84510446d6370648bef Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 11 Feb 2022 17:03:19 -0700 Subject: [PATCH 27/49] Supress some typing errors --- scico/loss.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/scico/loss.py b/scico/loss.py index c994e2da8..f5653c306 100644 --- a/scico/loss.py +++ b/scico/loss.py @@ -64,7 +64,7 @@ def __init__( self.y = ensure_on_device(y) if A is None: # y and x must have same shape - A = linop.Identity(self.y.shape) + A = linop.Identity(self.y.shape) # type: ignore self.A = A self.scale = scale @@ -139,9 +139,9 @@ def __init__( self.W: linop.Diagonal if W is None: - self.W = linop.Identity(y.shape) + self.W = linop.Identity(y.shape) # type: ignore elif isinstance(W, linop.Diagonal): - if snp.all(W.diagonal >= 0): + if snp.all(W.diagonal >= 0): # type: ignore self.W = W else: raise ValueError(f"The weights, W.diagonal, must be non-negative.") @@ -158,7 +158,7 @@ def __init__( self.has_prox = True def __call__(self, x: Union[JaxArray, BlockArray]) -> float: - return self.scale * (self.W.diagonal * snp.abs(self.y - self.A(x)) ** 2).sum() + return self.scale * (self.W.diagonal * snp.abs(self.y - self.A(x)) ** 2).sum() # type: ignore def prox( self, v: Union[JaxArray, BlockArray], lam: float, **kwargs @@ -173,8 +173,8 @@ def prox( c = 2.0 * self.scale * lam A = self.A.diagonal W = self.W.diagonal - lhs = c * A.conj() * W * self.y + v - ATWA = c * A.conj() * W * A + lhs = c * A.conj() * W * self.y + v # type: ignore + ATWA = c * A.conj() * W * A # type: ignore return lhs / (ATWA + 1.0) # prox_{f}(v) = arg min 1/2 || v - x ||^2 + λ α || A x - y ||^2_W From f9b6cda9d23cf66055f504693021de5cc58d61c5 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 11 Feb 2022 17:03:48 -0700 Subject: [PATCH 28/49] Address typing error --- scico/linop/optics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/linop/optics.py b/scico/linop/optics.py index 48020a0aa..276c8302e 100644 --- a/scico/linop/optics.py +++ b/scico/linop/optics.py @@ -138,7 +138,7 @@ def __init__( self.F = DFT(input_shape=input_shape, output_shape=self.padded_shape, jit=False) # Diagonal operator; phase shifting - self.D = Identity(self.kp.shape) + self.D: LinearOperator = Identity(self.kp.shape) super().__init__( input_shape=input_shape, From 3ef307bd309333a71526cd481f72c173219d5810 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 11 Feb 2022 17:04:19 -0700 Subject: [PATCH 29/49] Fix typing errors and docstring style issues --- scico/examples.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/scico/examples.py b/scico/examples.py index 238be19d6..2f9d23318 100644 --- a/scico/examples.py +++ b/scico/examples.py @@ -242,18 +242,20 @@ def tile_volume_slices(x: Array, sep_width: int = 10) -> Array: return out -def create_cone(img_shape: Shape, center: Optional[list] = None): +def create_cone(img_shape: Shape, center: Optional[list[float]] = None) -> Array: """Compute a 2D map of the distance from a center pixel. Args: - img_shape : Shape of the image for which the distance map is being computed. - center : Tuple of center pixel ids. If None, this is set to the center of the image + img_shape: Shape of the image for which the distance map is being + computed. + center: Tuple of center pixel ids. If ``None``, this is set to + the center of the image. Returns: - An image containing a 2D map of the distances + An image containing a 2D map of the distances. """ - if center == None: + if center is None: center = [img_dim // 2 for img_dim in img_shape] coords = [snp.arange(0, img_dim) for img_dim in img_shape] @@ -267,17 +269,18 @@ def create_cone(img_shape: Shape, center: Optional[list] = None): def create_circular_phantom( img_shape: Shape, radius_list: list, val_list: list, center: Optional[list] = None -): - """Construct a circular phantom with given radii and intensities +) -> Array: + """Construct a circular phantom with given radii and intensities. Args: - img_shape : Shape of the phontom to be created - radius_list : List of radii of the rings in the phantom - val_list : List of intensity values of the rings in the phantom - center : Tuple of center pixel ids. If None, this is set to the center of the image + img_shape: Shape of the phantom to be created. + radius_list: List of radii of the rings in the phantom. + val_list: List of intensity values of the rings in the phantom. + center: Tuple of center pixel ids. If ``None``, this is set to + the center of the image. Returns: - The computed circular phantom + The computed circular phantom. """ dist_map = create_cone(img_shape, center) From 88100e50643eac26bcf90b43371387253f71d483 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 11 Feb 2022 17:18:41 -0700 Subject: [PATCH 30/49] Address test failure --- scico/examples.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scico/examples.py b/scico/examples.py index 2f9d23318..82c3c2707 100644 --- a/scico/examples.py +++ b/scico/examples.py @@ -12,7 +12,7 @@ import os import tempfile import zipfile -from typing import Optional, Tuple +from typing import List, Optional import numpy as np @@ -242,7 +242,7 @@ def tile_volume_slices(x: Array, sep_width: int = 10) -> Array: return out -def create_cone(img_shape: Shape, center: Optional[list[float]] = None) -> Array: +def create_cone(img_shape: Shape, center: Optional[List[float]] = None) -> Array: """Compute a 2D map of the distance from a center pixel. Args: From 60cc81a2972df95ad14cbdd24903ea5fa6a93b1d Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 18 Feb 2022 16:14:20 -0700 Subject: [PATCH 31/49] Suppress/address some typing errors --- scico/examples.py | 2 +- scico/linop/_circconv.py | 10 +++++++--- scico/linop/_diff.py | 9 ++++----- scico/linop/_stack.py | 4 ++-- scico/linop/abel.py | 2 +- scico/optimize/_ladmm.py | 8 +++++--- scico/optimize/_primaldual.py | 4 ++-- scico/optimize/admm.py | 8 +++++--- 8 files changed, 27 insertions(+), 20 deletions(-) diff --git a/scico/examples.py b/scico/examples.py index 3241b6a2c..59506f7ac 100644 --- a/scico/examples.py +++ b/scico/examples.py @@ -12,7 +12,7 @@ import os import tempfile import zipfile -from typing import List, Optional +from typing import List, Optional, Tuple import numpy as np diff --git a/scico/linop/_circconv.py b/scico/linop/_circconv.py index f07f10290..adcfd2a42 100644 --- a/scico/linop/_circconv.py +++ b/scico/linop/_circconv.py @@ -18,6 +18,7 @@ import scico.numpy as snp from scico._generic_operators import Operator +from scico.array import is_nested from scico.typing import Array, DType, JaxArray, Shape from ._linop import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar @@ -172,7 +173,7 @@ def _eval(self, x: JaxArray) -> JaxArray: hx = hx.real return hx - def _adj(self, x: JaxArray) -> JaxArray: + def _adj(self, x: JaxArray) -> JaxArray: # type: ignore x_dft = snp.fft.fftn(x, axes=self.ifft_axes) H_adj_x = snp.fft.ifftn( snp.conj(self.h_dft) * x_dft, @@ -251,13 +252,16 @@ def from_operator( jit: If ``True``, jit the resulting `CircularConvolve`. """ + if is_nested(H.input_shape): + raise ValueError("Operator H may not take BlockArray input.") + if ndims is None: ndims = len(H.input_shape) else: ndims = ndims if center is None: - center = tuple(d // 2 for d in H.input_shape[-ndims:]) + center = tuple(d // 2 for d in H.input_shape[-ndims:]) # type: ignore # compute impulse response d = snp.zeros(H.input_shape, H.input_dtype) @@ -267,7 +271,7 @@ def from_operator( # build CircularConvolve return CircularConvolve( Hd, - H.input_shape, + H.input_shape, # type: ignore ndims=ndims, input_dtype=H.input_dtype, h_center=snp.array(center), diff --git a/scico/linop/_diff.py b/scico/linop/_diff.py index d7f13add9..13be24426 100644 --- a/scico/linop/_diff.py +++ b/scico/linop/_diff.py @@ -18,7 +18,7 @@ import scico.numpy as snp from scico.array import parse_axes -from scico.typing import Axes, DType, JaxArray, Shape, Union +from scico.typing import Axes, DType, JaxArray, Shape from ._linop import LinearOperator from ._stack import LinearOperatorStack @@ -73,19 +73,18 @@ def __init__( functions of the LinearOperator. """ - self.axes = parse_axes(axes, input_shape) - if axes is None: - axes_list: Union[range, list, tuple] = range(len(input_shape)) + axes_list = tuple(range(len(input_shape))) elif isinstance(axes, (list, tuple)): axes_list = axes else: axes_list = (axes,) + self.axes = parse_axes(axes_list, input_shape) single_kwargs = dict(input_dtype=input_dtype, append=append, circular=circular, jit=False) ops = [FiniteDifferenceSingleAxis(axis, input_shape, **single_kwargs) for axis in axes_list] super().__init__( - ops, + ops, # type: ignore jit=jit, **kwargs, ) diff --git a/scico/linop/_stack.py b/scico/linop/_stack.py index 8b09642ca..65c427056 100644 --- a/scico/linop/_stack.py +++ b/scico/linop/_stack.py @@ -35,7 +35,7 @@ def __init__( r""" Args: ops: Operators to stack. - collapse: If `True` and the output would be a `BlockArray` + collapse: If ``True`` and the output would be a `BlockArray` with shape ((m, n, ...), (m, n, ...), ...), the output is instead a `DeviceArray` with shape (S, m, n, ...) where S is the length of `ops`. Defaults to True. @@ -87,7 +87,7 @@ def _eval(self, x: JaxArray) -> Union[JaxArray, BlockArray]: return snp.stack([op @ x for op in self.ops]) return BlockArray.array([op @ x for op in self.ops]) - def _adj(self, y: Union[JaxArray, BlockArray]) -> JaxArray: + def _adj(self, y: Union[JaxArray, BlockArray]) -> JaxArray: # type: ignore return sum([op.adj(y_block) for y_block, op in zip(y, self.ops)]) def scale_ops(self, scalars: JaxArray): diff --git a/scico/linop/abel.py b/scico/linop/abel.py index cd0119f27..500048bf8 100644 --- a/scico/linop/abel.py +++ b/scico/linop/abel.py @@ -56,7 +56,7 @@ def _eval(self, x: JaxArray) -> JaxArray: self.output_dtype ) - def _adj(self, x: JaxArray) -> JaxArray: + def _adj(self, x: JaxArray) -> JaxArray: # type: ignore return _pyabel_transform(x, direction="transpose", proj_mat_quad=self.proj_mat_quad).astype( self.input_dtype ) diff --git a/scico/optimize/_ladmm.py b/scico/optimize/_ladmm.py index bce3876f6..cc256e3da 100644 --- a/scico/optimize/_ladmm.py +++ b/scico/optimize/_ladmm.py @@ -156,7 +156,7 @@ def __init__( if itstat_options: default_itstat_options.update(itstat_options) self.itstat_insert_func: Callable = default_itstat_options.pop("itstat_func", None) # type: ignore - self.itstat_object = IterationStats(**default_itstat_options) + self.itstat_object = IterationStats(**default_itstat_options) # type: ignore if x0 is None: input_shape = C.input_shape @@ -237,7 +237,9 @@ def norm_dual_residual(self) -> float: """ return norm(self.C.adj(self.z - self.z_old)) - def z_init(self, x0: Union[JaxArray, BlockArray]): + def z_init( + self, x0: Union[JaxArray, BlockArray] + ) -> Tuple[Union[JaxArray, BlockArray], Union[JaxArray, BlockArray]]: r"""Initialize auxiliary variable :math:`\mb{z}`. Initialized to @@ -254,7 +256,7 @@ def z_init(self, x0: Union[JaxArray, BlockArray]): z_old = z return z, z_old - def u_init(self, x0: Union[JaxArray, BlockArray]): + def u_init(self, x0: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]: r"""Initialize scaled Lagrange multiplier :math:`\mb{u}`. Initialized to diff --git a/scico/optimize/_primaldual.py b/scico/optimize/_primaldual.py index 4f3ecd48c..3bffecae3 100644 --- a/scico/optimize/_primaldual.py +++ b/scico/optimize/_primaldual.py @@ -162,7 +162,7 @@ def __init__( if itstat_options: default_itstat_options.update(itstat_options) self.itstat_insert_func: Callable = default_itstat_options.pop("itstat_func", None) # type: ignore - self.itstat_object = IterationStats(**default_itstat_options) + self.itstat_object = IterationStats(**default_itstat_options) # type: ignore if x0 is None: input_shape = C.input_shape @@ -213,7 +213,7 @@ def norm_primal_residual(self) -> float: Current value of primal residual. """ - return norm(self.x - self.x_old) / self.tau + return norm(self.x - self.x_old) / self.tau # type: ignore def norm_dual_residual(self) -> float: r"""Compute the :math:`\ell_2` norm of the dual residual. diff --git a/scico/optimize/admm.py b/scico/optimize/admm.py index a2be482b2..f9286fdc3 100644 --- a/scico/optimize/admm.py +++ b/scico/optimize/admm.py @@ -12,7 +12,7 @@ from __future__ import annotations from functools import reduce -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, List, Optional, Tuple, Union import jax from jax.scipy.sparse.linalg import cg as jax_cg @@ -585,7 +585,9 @@ def norm_dual_residual(self) -> float: out += norm(Ci.adj(zi - ziold)) ** 2 return snp.sqrt(out) - def z_init(self, x0: Union[JaxArray, BlockArray]): + def z_init( + self, x0: Union[JaxArray, BlockArray] + ) -> Tuple[List[Union[JaxArray, BlockArray]], List[Union[JaxArray, BlockArray]]]: r"""Initialize auxiliary variables :math:`\mb{z}_i`. Initialized to @@ -603,7 +605,7 @@ def z_init(self, x0: Union[JaxArray, BlockArray]): z_list_old = z_list.copy() return z_list, z_list_old - def u_init(self, x0: Union[JaxArray, BlockArray]): + def u_init(self, x0: Union[JaxArray, BlockArray]) -> List[Union[JaxArray, BlockArray]]: r"""Initialize scaled Lagrange multipliers :math:`\mb{u}_i`. Initialized to From 920527919b485eb82b18577ca7c948e57831a4c4 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 3 Mar 2022 10:05:57 -0700 Subject: [PATCH 32/49] Fix merge error --- scico/optimize/admm.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/scico/optimize/admm.py b/scico/optimize/admm.py index 79eb09cb0..6c58488e1 100644 --- a/scico/optimize/admm.py +++ b/scico/optimize/admm.py @@ -193,11 +193,6 @@ def __init__(self, cg_kwargs: Optional[dict[str, Any]] = None, cg_function: str def internal_init(self, admm): if admm.f is not None: - if not isinstance(admm.f, WeightedSquaredL2Loss): # SquaredL2Loss is subclass - raise ValueError( - f"LinearSubproblemSolver requires f to be a scico.loss.WeightedSquaredL2Loss" - f"or scico.loss.SquaredL2Loss; got {type(admm.f)}" - ) if not isinstance(admm.f.A, LinearOperator): raise ValueError( f"LinearSubproblemSolver requires f.A to be a scico.linop.LinearOperator; " From f1aec60c95369e31fc31d05e68b0471e9b15fb29 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 22 Apr 2022 12:45:32 -0600 Subject: [PATCH 33/49] Fix merge error --- scico/numpy/util.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/scico/numpy/util.py b/scico/numpy/util.py index e7c3594a0..165caab66 100644 --- a/scico/numpy/util.py +++ b/scico/numpy/util.py @@ -63,6 +63,10 @@ def ensure_on_device( arrays[i] = jax.device_put(arrays[i]) + if len(arrays) == 1: + return arrays[0] + return arrays + def parse_axes( axes: Axes, shape: Optional[Shape] = None, default: Optional[List[int]] = None From 6f02f7b7f2365834b5c8ba0afa7c945f62c3ca94 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 22 Apr 2022 19:53:13 -0600 Subject: [PATCH 34/49] Improve code style --- scico/random.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/scico/random.py b/scico/random.py index 63b857334..96a803a3e 100644 --- a/scico/random.py +++ b/scico/random.py @@ -108,12 +108,12 @@ def fun_alt(*args, key=None, seed=None, **kwargs): fun_alt.__doc__ = "\n\n".join( lines[0:1] + [ - f" Wrapped version of `jax.random.{fun.__name__} `_. " - "The SCICO version of this function moves the `key` argument to the end of the argument list, " - "adds an additional `seed` argument after that, and allows the `shape` argument " - "to accept a nested list, in which case a `BlockArray` is returned. " - "Always returns a `(result, key)` tuple.", - " Original docstring below.", + f" Wrapped version of `jax.random.{fun.__name__} " + f"`_. " + "The SCICO version of this function moves the `key` argument to the end of the " + "argument list, adds an additional `seed` argument after that, and allows the " + "`shape` argument to accept a nested list, in which case a `BlockArray` is returned. " + "Always returns a `(result, key)` tuple. Original docstring below.", ] + lines[1:] ) From 6ab90d05accda3699fcfe8ca7f8acc7db08bd3fc Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 22 Apr 2022 19:54:44 -0600 Subject: [PATCH 35/49] Resolve mypy errors --- scico/_flax.py | 2 +- scico/_version.py | 4 ++-- scico/examples.py | 16 ++++++++-------- scico/functional/_denoiser.py | 2 +- scico/linop/_linop.py | 18 ++++++++++++++++-- scico/linop/_stack.py | 10 +++++++--- scico/linop/radon_astra.py | 4 ++-- scico/loss.py | 12 ++++++------ scico/operator/biconvolve.py | 9 +++++---- scico/optimize/_ladmm.py | 2 +- 10 files changed, 49 insertions(+), 30 deletions(-) diff --git a/scico/_flax.py b/scico/_flax.py index bda3facc7..cf0d1aa2a 100644 --- a/scico/_flax.py +++ b/scico/_flax.py @@ -203,5 +203,5 @@ def __call__(self, x: JaxArray) -> JaxArray: x = x.reshape((1,) + x.shape + (1,)) elif x.ndim == 3: x = x.reshape((1,) + x.shape) - y = self.model.apply(self.variables, x, train=False, mutable=False) + y = self.model.apply(self.variables, x, train=False, mutable=False) # type: ignore return y.reshape(x_shape) diff --git a/scico/_version.py b/scico/_version.py index e6281e19a..08de1705b 100644 --- a/scico/_version.py +++ b/scico/_version.py @@ -41,7 +41,7 @@ def variable_assign_value(path: str, var: str) -> Any: with open(path) as f: try: # See http://stackoverflow.com/questions/2058802 - value = parse(next(filter(lambda line: line.startswith(var), f))).body[0].value.s + value = parse(next(filter(lambda line: line.startswith(var), f))).body[0].value.s # type: ignore except StopIteration: raise RuntimeError(f"Could not find initialization of variable {var}") return value @@ -70,7 +70,7 @@ def current_git_hash() -> Optional[str]: # nosec pragma: no cover Short git hash of current commit, or ``None`` if no git repo found. """ process = Popen(["git", "rev-parse", "--short", "HEAD"], shell=False, stdout=PIPE, stderr=PIPE) - git_hash = process.communicate()[0].strip().decode("utf-8") + git_hash: Optional[str] = process.communicate()[0].strip().decode("utf-8") if git_hash == "": git_hash = None return git_hash diff --git a/scico/examples.py b/scico/examples.py index e8681fc01..ac38c33e4 100644 --- a/scico/examples.py +++ b/scico/examples.py @@ -300,7 +300,7 @@ def create_3D_foam_phantom( r_std: float = 0.001, pad: float = 0.01, is_random: bool = False, -): +) -> JaxArray: """Construct a 3D phantom with random radii and centers. Args: @@ -316,7 +316,7 @@ def create_3D_foam_phantom( process deterministic. Default ``False``. Returns: - 3D phantom of shape im_shape + 3D phantom of shape `im_shape`. """ c_lo = 0.0 c_hi = 1.0 @@ -331,10 +331,10 @@ def create_3D_foam_phantom( radii = r_std * np.random.randn(N_sphere) + r_mean im = snp.zeros(im_shape) + c_lo - for c, r in zip(centers, radii): + for c, r in zip(centers, radii): # type: ignore dist = snp.sum((x - c) ** 2, axis=-1) if snp.mean(im[dist < r**2] - c_lo) < 0.01 * c_hi: - # In numpy: im[dist < r**2] = c_hi + # equivalent to im[dist < r**2] = c_hi in numpy im = im.at[dist < r**2].set(c_hi) return im @@ -354,13 +354,13 @@ def spnoise(img: Array, nfrac: float, nmin: float = 0.0, nmax: float = 1.0) -> A """ if isinstance(img, np.ndarray): - spm = np.random.uniform(-1.0, 1.0, img.shape) + spm = np.random.uniform(-1.0, 1.0, img.shape) # type: ignore imgn = img.copy() imgn[spm < nfrac - 1.0] = nmin imgn[spm > 1.0 - nfrac] = nmax else: - spm, key = random.uniform(shape=img.shape, minval=-1.0, maxval=1.0, seed=0) + spm, key = random.uniform(shape=img.shape, minval=-1.0, maxval=1.0, seed=0) # type: ignore imgn = img - imgn = imgn.at[spm < nfrac - 1.0].set(nmin) - imgn = imgn.at[spm > 1.0 - nfrac].set(nmax) + imgn = imgn.at[spm < nfrac - 1.0].set(nmin) # type: ignore + imgn = imgn.at[spm > 1.0 - nfrac].set(nmax) # type: ignore return imgn diff --git a/scico/functional/_denoiser.py b/scico/functional/_denoiser.py index e6046b075..dd5982a07 100644 --- a/scico/functional/_denoiser.py +++ b/scico/functional/_denoiser.py @@ -65,7 +65,7 @@ def __init__(self): r"""Initialize a :class:`BM4D` object.""" super().__init__() - def prox(self, x: JaxArray, lam: float = 1.0, **kwargs) -> JaxArray: + def prox(self, x: JaxArray, lam: float = 1.0, **kwargs) -> JaxArray: # type: ignore r"""Apply BM4D denoiser. Args: diff --git a/scico/linop/_linop.py b/scico/linop/_linop.py index 3866b55bf..0412479ea 100644 --- a/scico/linop/_linop.py +++ b/scico/linop/_linop.py @@ -15,6 +15,8 @@ from functools import partial from typing import Any, Callable, Optional, Union +import jax + import scico.numpy as snp from scico._generic_operators import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar from scico.numpy import BlockArray @@ -144,14 +146,26 @@ def valid_adjoint( else: if x.shape != A.input_shape: raise ValueError("Shape of x array not appropriate as an input for operator A") + assert isinstance( + x, (jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray) + ) if y is None: y, key = randn(shape=AT.input_shape, key=key, dtype=AT.input_dtype) else: if y.shape != AT.input_shape: raise ValueError("Shape of y array not appropriate as an input for operator AT") + assert isinstance( + y, (jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray) + ) u = A(x) v = AT(y) + assert isinstance( + u, (jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray) + ) + assert isinstance( + v, (jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray) + ) yTu = snp.dot(y.ravel().conj(), u.ravel()) vTx = snp.dot(v.ravel().conj(), x.ravel()) err = snp.abs(yTu - vTx) / max(snp.abs(yTu), snp.abs(vTx)) @@ -331,13 +345,13 @@ def __init__( **kwargs: Any, ): self._eval = lambda x: f(x, *args, **kwargs) - super().__init__(input_shape, input_dtype=input_dtype, jit=jit) + super().__init__(input_shape, input_dtype=input_dtype, jit=jit) # type: ignore OpClass = type(classname, (LinearOperator,), {"__init__": __init__}) __class__ = OpClass # needed for super() to work OpClass.__doc__ = f"Linear operator version of :func:`{f_name}`." - OpClass.__init__.__doc__ = rf""" + OpClass.__init__.__doc__ = rf""" # type: ignore Args: input_shape: Shape of input array. diff --git a/scico/linop/_stack.py b/scico/linop/_stack.py index 609da2ba5..4b69b34da 100644 --- a/scico/linop/_stack.py +++ b/scico/linop/_stack.py @@ -17,7 +17,7 @@ import scico.numpy as snp from scico.numpy import BlockArray -from scico.typing import JaxArray +from scico.typing import BlockShape, JaxArray, Shape from ._linop import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar @@ -38,7 +38,7 @@ def __init__( collapse: If ``True`` and the output would be a `BlockArray` with shape ((m, n, ...), (m, n, ...), ...), the output is instead a `DeviceArray` with shape (S, m, n, ...) where S - is the length of `ops`. Defaults to True. + is the length of `ops`. Defaults to ``True``. jit: see `jit` in :class:`LinearOperator`. """ @@ -62,11 +62,15 @@ def __init__( ) self.collapse = collapse + output_shape: Union[Shape, BlockShape] output_shape = tuple(op.shape[0] for op in ops) # assumes BlockArray output # check if collapsable and adjust output_shape if needed - self.collapsable = all(output_shape[0] == s for s in output_shape) + self.collapsable = isinstance(output_shape[0], tuple) and all( + output_shape[0] == s for s in output_shape + ) if self.collapsable and self.collapse: + assert isinstance(output_shape[0], tuple) output_shape = (len(ops),) + output_shape[0] # collapse to DeviceArray output_dtypes = [op.output_dtype for op in ops] diff --git a/scico/linop/radon_astra.py b/scico/linop/radon_astra.py index 2819da36e..d97bb9398 100644 --- a/scico/linop/radon_astra.py +++ b/scico/linop/radon_astra.py @@ -107,9 +107,9 @@ def __init__( # Wrap our non-jax function to indicate we will supply fwd/rev mode functions self._eval = jax.custom_vjp(self._proj) - self._eval.defvjp(lambda x: (self._proj(x), None), lambda _, y: (self._bproj(y),)) + self._eval.defvjp(lambda x: (self._proj(x), None), lambda _, y: (self._bproj(y),)) # type: ignore self._adj = jax.custom_vjp(self._bproj) - self._adj.defvjp(lambda y: (self._bproj(y), None), lambda _, x: (self._proj(x),)) + self._adj.defvjp(lambda y: (self._bproj(y), None), lambda _, x: (self._proj(x),)) # type: ignore super().__init__( input_shape=self.input_shape, diff --git a/scico/loss.py b/scico/loss.py index 5cd988d94..8a5c19c7c 100644 --- a/scico/loss.py +++ b/scico/loss.py @@ -97,7 +97,7 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float: return self.scale * self.f(self.A(x) - self.y) def prox( - self, v: Union[JaxArray, BlockArray], lam: float, **kwargs + self, v: Union[JaxArray, BlockArray], lam: float = 1, **kwargs ) -> Union[JaxArray, BlockArray]: r"""Scaled proximal operator of loss function. @@ -120,6 +120,7 @@ def prox( f"prox is not implemented for {type(self)} when A is {type(self.A)}; " "must be Identity" ) + assert self.f is not None return self.f.prox(v - self.y, self.scale * lam, **kwargs) + self.y @_loss_mul_div_wrapper @@ -195,7 +196,6 @@ def __init__( if prox_kwargs: default_prox_kwargs.update(prox_kwargs) self.prox_kwargs = default_prox_kwargs - prox_kwargs: dict = ({"maxiter": 100, "tol": 1e-5},) if isinstance(self.A, linop.LinearOperator): self.has_prox = True @@ -204,7 +204,7 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float: return self.scale * snp.sum(self.W.diagonal * snp.abs(self.y - self.A(x)) ** 2) def prox( - self, v: Union[JaxArray, BlockArray], lam: float, **kwargs + self, v: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs ) -> Union[JaxArray, BlockArray]: if not isinstance(self.A, linop.LinearOperator): raise NotImplementedError( @@ -237,7 +237,7 @@ def prox( hessian = self.hessian # = (2𝛼 A^T W A) lhs = linop.Identity(v.shape) + lam * hessian rhs = v + 2 * lam * 𝛼 * A.adj(W(y)) - x, _ = cg(lhs, rhs, x0, **self.prox_kwargs) + x, _ = cg(lhs, rhs, x0, **self.prox_kwargs) # type: ignore return x @property @@ -357,7 +357,7 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float: return self.scale * (self.W.diagonal * snp.abs(self.y - snp.abs(self.A(x))) ** 2).sum() def prox( - self, v: Union[JaxArray, BlockArray], lam: float, **kwargs + self, v: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs ) -> Union[JaxArray, BlockArray]: if not self.has_prox: raise NotImplementedError(f"prox is not implemented.") @@ -564,7 +564,7 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float: return self.scale * (self.W.diagonal * snp.abs(self.y - snp.abs(self.A(x)) ** 2) ** 2).sum() def prox( - self, v: Union[JaxArray, BlockArray], lam: float, **kwargs + self, v: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs ) -> Union[JaxArray, BlockArray]: if not self.has_prox: raise NotImplementedError(f"prox is not implemented.") diff --git a/scico/operator/biconvolve.py b/scico/operator/biconvolve.py index 77a0be36c..46a3a6326 100644 --- a/scico/operator/biconvolve.py +++ b/scico/operator/biconvolve.py @@ -7,6 +7,7 @@ """Biconvolution operator.""" +from typing import Tuple, cast import numpy as np @@ -16,7 +17,7 @@ from scico.linop import Convolve, ConvolveByX from scico.numpy import BlockArray from scico.numpy.util import is_nested -from scico.typing import BlockShape, DType, JaxArray +from scico.typing import DType, JaxArray, Shape class BiConvolve(Operator): @@ -32,7 +33,7 @@ class BiConvolve(Operator): def __init__( self, - input_shape: BlockShape, + input_shape: Tuple[Shape, Shape], input_dtype: DType = np.float32, mode: str = "full", jit: bool = True, @@ -87,7 +88,7 @@ def freeze(self, argnum: int, val: JaxArray) -> LinearOperator: if argnum == 0: return ConvolveByX( x=val, - input_shape=self.input_shape[1], + input_shape=cast(Shape, self.input_shape[1]), input_dtype=self.input_dtype, output_shape=self.output_shape, mode=self.mode, @@ -95,7 +96,7 @@ def freeze(self, argnum: int, val: JaxArray) -> LinearOperator: if argnum == 1: return Convolve( h=val, - input_shape=self.input_shape[0], + input_shape=cast(Shape, self.input_shape[0]), input_dtype=self.input_dtype, output_shape=self.output_shape, mode=self.mode, diff --git a/scico/optimize/_ladmm.py b/scico/optimize/_ladmm.py index d34f118f7..3af0824c8 100644 --- a/scico/optimize/_ladmm.py +++ b/scico/optimize/_ladmm.py @@ -11,7 +11,7 @@ # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations -from typing import Callable, List, Optional, Union +from typing import Callable, List, Optional, Tuple, Union import scico.numpy as snp from scico.diagnostics import IterationStats From feb1830247950064cd5dd6a45e71ec9e1afb7bc3 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Sat, 23 Apr 2022 06:36:44 -0600 Subject: [PATCH 36/49] Resolve mypy errors --- scico/_generic_operators.py | 7 ++++--- scico/linop/_linop.py | 38 +++++++++++++++++++------------------ scico/linop/_stack.py | 3 ++- scico/linop/radon_svmbir.py | 11 +++++++---- scico/numpy/util.py | 2 +- 5 files changed, 34 insertions(+), 27 deletions(-) diff --git a/scico/_generic_operators.py b/scico/_generic_operators.py index c1cf7e412..5a0191cc5 100644 --- a/scico/_generic_operators.py +++ b/scico/_generic_operators.py @@ -143,7 +143,7 @@ def __init__( if output_shape is None or output_dtype is None: tmp = self(snp.zeros(self.input_shape, dtype=input_dtype)) if output_shape is None: - self.output_shape = tmp.shape + self.output_shape = tmp.shape # type: ignore else: self.output_shape = (output_shape,) if isinstance(output_shape, int) else output_shape @@ -315,10 +315,11 @@ def freeze(self, argnum: int, val: Union[JaxArray, BlockArray]) -> Operator: f"{self.input_shape[argnum]}, got {val.shape}" ) - input_shape = tuple(s for i, s in enumerate(self.input_shape) if i != argnum) + input_shape: Union[Shape, BlockShape] + input_shape = tuple(s for i, s in enumerate(self.input_shape) if i != argnum) # type: ignore if len(input_shape) == 1: - input_shape = input_shape[0] + input_shape = input_shape[0] # type: ignore def concat_args(args): # Creates a blockarray with args and the frozen value in the correct place diff --git a/scico/linop/_linop.py b/scico/linop/_linop.py index 0412479ea..6a7b8a751 100644 --- a/scico/linop/_linop.py +++ b/scico/linop/_linop.py @@ -64,7 +64,6 @@ def operator_norm(A: LinearOperator, maxiter: int = 100, key: Optional[PRNGKey] :math:`A`, .. math:: - \| A \|_2 &= \max \{ \| A \mb{x} \|_2 \, : \, \| \mb{x} \|_2 \leq 1 \} \\ &= \sqrt{ \lambda_{ \mathrm{max} }( A^H A ) } = \sigma_{\mathrm{max}}(A) \;, @@ -273,7 +272,7 @@ class Slice(LinearOperator): def __init__( self, idx: ArrayIndex, - input_shape: Shape, + input_shape: Union[Shape, BlockShape], input_dtype: DType = snp.float32, jit: bool = True, **kwargs, @@ -296,8 +295,9 @@ def __init__( functions of the LinearOperator. """ + output_shape: Union[Shape, BlockShape] if is_nested(input_shape): - output_shape = input_shape[idx] + output_shape = input_shape[idx] # type: ignore else: output_shape = indexed_shape(input_shape, idx) @@ -336,6 +336,22 @@ def linop_from_function(f: Callable, classname: str, f_name: Optional[str] = Non if f_name is None: f_name = f"{f.__module__}.{f.__name__}" + f_doc = rf""" + + Args: + input_shape: Shape of input array. + args: Positional arguments passed to :func:`{f_name}`. + input_dtype: `dtype` for input argument. + Defaults to ``float32``. If `LinearOperator` implements + complex-valued operations, this must be ``complex64`` for + proper adjoint and gradient calculation. + jit: If ``True``, call :meth:`.Operator.jit` on this + `LinearOperator` to jit the forward, adjoint, and gram + functions. Same as calling :meth:`.Operator.jit` after + the `LinearOperator` is created. + kwargs: Keyword arguments passed to :func:`{f_name}`. + """ + def __init__( self, input_shape: Union[Shape, BlockShape], @@ -351,21 +367,7 @@ def __init__( __class__ = OpClass # needed for super() to work OpClass.__doc__ = f"Linear operator version of :func:`{f_name}`." - OpClass.__init__.__doc__ = rf""" # type: ignore - - Args: - input_shape: Shape of input array. - args: Positional arguments passed to :func:`{f_name}`. - input_dtype: `dtype` for input argument. - Defaults to ``float32``. If `LinearOperator` implements - complex-valued operations, this must be ``complex64`` for - proper adjoint and gradient calculation. - jit: If ``True``, call :meth:`.Operator.jit` on this - `LinearOperator` to jit the forward, adjoint, and gram - functions. Same as calling :meth:`.Operator.jit` after - the `LinearOperator` is created. - kwargs: Keyword arguments passed to :func:`{f_name}`. - """ + OpClass.__init__.__doc__ = f_doc # type: ignore return OpClass diff --git a/scico/linop/_stack.py b/scico/linop/_stack.py index 4b69b34da..32475a50b 100644 --- a/scico/linop/_stack.py +++ b/scico/linop/_stack.py @@ -63,7 +63,8 @@ def __init__( self.collapse = collapse output_shape: Union[Shape, BlockShape] - output_shape = tuple(op.shape[0] for op in ops) # assumes BlockArray output + # start by assuming BlockArray output + output_shape = tuple(op.shape[0] for op in ops) # type: ignore # check if collapsable and adjust output_shape if needed self.collapsable = isinstance(output_shape[0], tuple) and all( diff --git a/scico/linop/radon_svmbir.py b/scico/linop/radon_svmbir.py index 3806feedd..a73760c0d 100644 --- a/scico/linop/radon_svmbir.py +++ b/scico/linop/radon_svmbir.py @@ -162,10 +162,10 @@ def __init__( # Set up custom_vjp for _eval and _adj so jax.grad works on them. self._eval = jax.custom_vjp(self._proj_hcb) - self._eval.defvjp(lambda x: (self._proj_hcb(x), None), lambda _, y: (self._bproj_hcb(y),)) + self._eval.defvjp(lambda x: (self._proj_hcb(x), None), lambda _, y: (self._bproj_hcb(y),)) # type: ignore self._adj = jax.custom_vjp(self._bproj_hcb) - self._adj.defvjp(lambda y: (self._bproj_hcb(y), None), lambda _, x: (self._proj_hcb(x),)) + self._adj.defvjp(lambda y: (self._bproj_hcb(y), None), lambda _, x: (self._proj_hcb(x),)) # type: ignore super().__init__( input_shape=input_shape, @@ -305,6 +305,9 @@ class SVMBIRExtendedLoss(Loss): described in class :class:`.TomographicProjector`. """ + A: TomographicProjector + W: Union[Identity, Diagonal] + def __init__( self, *args, @@ -328,7 +331,7 @@ def __init__( W: Weighting diagonal operator. Must be non-negative. If ``None``, defaults to :class:`.Identity`. """ - super().__init__(*args, scale=scale, **kwargs) + super().__init__(*args, scale=scale, **kwargs) # type: ignore if not isinstance(self.A, TomographicProjector): raise ValueError("LinearOperator A must be a radon_svmbir.TomographicProjector.") @@ -367,7 +370,7 @@ def __call__(self, x: JaxArray) -> float: else: return self.scale * (self.W.diagonal * snp.abs(self.y - self.A(x)) ** 2).sum() - def prox(self, v: JaxArray, lam: float, **kwargs) -> JaxArray: + def prox(self, v: JaxArray, lam: float = 1, **kwargs) -> JaxArray: v = v.reshape(self.A.svmbir_input_shape) y = self.y.reshape(self.A.svmbir_output_shape) weights = self.W.diagonal.reshape(self.A.svmbir_output_shape) diff --git a/scico/numpy/util.py b/scico/numpy/util.py index 165caab66..980eca104 100644 --- a/scico/numpy/util.py +++ b/scico/numpy/util.py @@ -187,7 +187,7 @@ def no_nan_divide( return snp.where(y != 0, snp.divide(x, snp.where(y != 0, y, 1)), 0) -def shape_to_size(shape: Union[Shape, BlockShape]) -> Axes: +def shape_to_size(shape: Union[Shape, BlockShape]) -> int: r"""Compute the size corresponding to a (possibly nested) shape. Args: From 07537c5103385c737cefbbab904632c6c854d994 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Sat, 23 Apr 2022 06:40:48 -0600 Subject: [PATCH 37/49] Resolve mypy errors --- scico/loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/loss.py b/scico/loss.py index 8a5c19c7c..3bcd26bc1 100644 --- a/scico/loss.py +++ b/scico/loss.py @@ -19,7 +19,7 @@ from scico import functional, linop, operator from scico.numpy import BlockArray from scico.numpy.util import ensure_on_device, no_nan_divide -from scico.scipy.special import gammaln +from scico.scipy.special import gammaln # type: ignore from scico.solver import cg from scico.typing import JaxArray From d0edf91896636ae3eec8f782523af74453bca1af Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Sat, 23 Apr 2022 06:52:04 -0600 Subject: [PATCH 38/49] Modify mypy configuration in workflow --- .github/workflows/mypy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml index d632a19b1..bb209f2d1 100644 --- a/.github/workflows/mypy.yml +++ b/.github/workflows/mypy.yml @@ -33,4 +33,4 @@ jobs: pip install mypy - name: Run mypy run: | - mypy --ignore-missing-imports --exclude numpy --exclude scipy scico/ + mypy --follow-imports=skip --ignore-missing-imports --exclude "(numpy|test)" scico/ From 8e48de4602950af8c415b68f18d443b595d7e1a0 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Sun, 24 Apr 2022 12:14:37 -0600 Subject: [PATCH 39/49] Trivial edit --- scico/_generic_operators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/_generic_operators.py b/scico/_generic_operators.py index 5a0191cc5..e435c0e3f 100644 --- a/scico/_generic_operators.py +++ b/scico/_generic_operators.py @@ -588,7 +588,7 @@ def adj( input `y`. Args: - y: Point at which to compute adjoint. If `y` is + y: Point at which to compute adjoint. If `y` is :class:`DeviceArray` or :class:`.BlockArray`, must have `shape == self.output_shape`. If `y` is a :class:`.LinearOperator`, must have From 72d2cd427d538dc3e93e9138e9b3f96d24afb97e Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Sun, 24 Apr 2022 12:15:15 -0600 Subject: [PATCH 40/49] Bug fix --- scico/linop/_linop.py | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/scico/linop/_linop.py b/scico/linop/_linop.py index 6a7b8a751..57fbac99c 100644 --- a/scico/linop/_linop.py +++ b/scico/linop/_linop.py @@ -15,8 +15,6 @@ from functools import partial from typing import Any, Callable, Optional, Union -import jax - import scico.numpy as snp from scico._generic_operators import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar from scico.numpy import BlockArray @@ -145,28 +143,16 @@ def valid_adjoint( else: if x.shape != A.input_shape: raise ValueError("Shape of x array not appropriate as an input for operator A") - assert isinstance( - x, (jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray) - ) if y is None: y, key = randn(shape=AT.input_shape, key=key, dtype=AT.input_dtype) else: if y.shape != AT.input_shape: raise ValueError("Shape of y array not appropriate as an input for operator AT") - assert isinstance( - y, (jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray) - ) u = A(x) v = AT(y) - assert isinstance( - u, (jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray) - ) - assert isinstance( - v, (jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray) - ) - yTu = snp.dot(y.ravel().conj(), u.ravel()) - vTx = snp.dot(v.ravel().conj(), x.ravel()) + yTu = snp.dot(y.ravel().conj(), u.ravel()) # type: ignore + vTx = snp.dot(v.ravel().conj(), x.ravel()) # type: ignore err = snp.abs(yTu - vTx) / max(snp.abs(yTu), snp.abs(vTx)) if eps is None: return err @@ -191,7 +177,6 @@ def __init__( broadcast-compatiable with `diagonal.shape`. input_dtype: `dtype` of input argument. The default, ``None``, means `diagonal.dtype`. - """ self.diagonal = ensure_on_device(diagonal) @@ -207,9 +192,9 @@ def __init__( elif not isinstance(diagonal, BlockArray) and not is_nested(input_shape): output_shape = snp.broadcast_shapes(input_shape, self.diagonal.shape) elif isinstance(diagonal, BlockArray): - raise ValueError(f"`diagonal` was a BlockArray but `input_shape` was not nested.") + raise ValueError("`diagonal` was a BlockArray but `input_shape` was not nested.") else: - raise ValueError(f"`diagonal` was a not BlockArray but `input_shape` was nested.") + raise ValueError("`diagonal` was a not BlockArray but `input_shape` was nested.") super().__init__( input_shape=input_shape, From 9fbf68d3e8e7c360490e887a606989c71360cc04 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Sun, 24 Apr 2022 12:15:47 -0600 Subject: [PATCH 41/49] Consistency improvement --- scico/linop/_stack.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/scico/linop/_stack.py b/scico/linop/_stack.py index 32475a50b..a612ae635 100644 --- a/scico/linop/_stack.py +++ b/scico/linop/_stack.py @@ -43,21 +43,21 @@ def __init__( """ if not isinstance(ops, (list, tuple)): - raise ValueError("expected a list of `LinearOperator`") + raise ValueError("Expected a list of `LinearOperator`") self.ops = ops input_shapes = [op.shape[1] for op in ops] if not all(input_shapes[0] == s for s in input_shapes): raise ValueError( - "expected all `LinearOperator`s to have the same input shapes, " + "Expected all `LinearOperator`s to have the same input shapes, " f"but got {input_shapes}" ) input_dtypes = [op.input_dtype for op in ops] if not all(input_dtypes[0] == s for s in input_dtypes): raise ValueError( - "expected all `LinearOperator`s to have the same input dtype, " + "Expected all `LinearOperator`s to have the same input dtype, " f"but got {input_dtypes}." ) @@ -76,7 +76,7 @@ def __init__( output_dtypes = [op.output_dtype for op in ops] if not np.all(output_dtypes[0] == s for s in output_dtypes): - raise ValueError("expected all `LinearOperator`s to have the same output dtype") + raise ValueError("Expected all `LinearOperator`s to have the same output dtype") super().__init__( input_shape=input_shapes[0], From 68d2850390ef84e19eb4afa025795b4df3ac9dc6 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Mon, 25 Apr 2022 09:06:21 -0600 Subject: [PATCH 42/49] Address CodeFactor complex function --- scico/linop/_stack.py | 53 +++++++++++++++++++++++-------------------- 1 file changed, 29 insertions(+), 24 deletions(-) diff --git a/scico/linop/_stack.py b/scico/linop/_stack.py index a612ae635..ca43ce79d 100644 --- a/scico/linop/_stack.py +++ b/scico/linop/_stack.py @@ -42,10 +42,37 @@ def __init__( jit: see `jit` in :class:`LinearOperator`. """ - if not isinstance(ops, (list, tuple)): - raise ValueError("Expected a list of `LinearOperator`") + + LinearOperatorStack.check_if_stackable(ops) self.ops = ops + self.collapse = collapse + + # start by assuming BlockArray output + output_shape: Union[Shape, BlockShape] + output_shape = tuple(op.shape[0] for op in ops) # type: ignore + + # check if collapsable and adjust output_shape if needed + self.collapsable = isinstance(output_shape[0], tuple) and all( + output_shape[0] == s for s in output_shape + ) + if self.collapsable and self.collapse: + output_shape = (len(ops),) + output_shape[0] # collapse to DeviceArray + + super().__init__( + input_shape=ops[0].input_shape, + output_shape=output_shape, + input_dtype=ops[0].input_dtype, + output_dtype=ops[0].output_dtype, + jit=jit, + **kwargs, + ) + + @staticmethod + def check_if_stackable(ops): + """Check that input ops are suitable for stack creation.""" + if not isinstance(ops, (list, tuple)): + raise ValueError("Expected a list of `LinearOperator`") input_shapes = [op.shape[1] for op in ops] if not all(input_shapes[0] == s for s in input_shapes): @@ -61,32 +88,10 @@ def __init__( f"but got {input_dtypes}." ) - self.collapse = collapse - output_shape: Union[Shape, BlockShape] - # start by assuming BlockArray output - output_shape = tuple(op.shape[0] for op in ops) # type: ignore - - # check if collapsable and adjust output_shape if needed - self.collapsable = isinstance(output_shape[0], tuple) and all( - output_shape[0] == s for s in output_shape - ) - if self.collapsable and self.collapse: - assert isinstance(output_shape[0], tuple) - output_shape = (len(ops),) + output_shape[0] # collapse to DeviceArray - output_dtypes = [op.output_dtype for op in ops] if not np.all(output_dtypes[0] == s for s in output_dtypes): raise ValueError("Expected all `LinearOperator`s to have the same output dtype") - super().__init__( - input_shape=input_shapes[0], - output_shape=output_shape, - input_dtype=input_dtypes[0], - output_dtype=output_dtypes[0], - jit=jit, - **kwargs, - ) - def _eval(self, x: JaxArray) -> Union[JaxArray, BlockArray]: if self.collapsable and self.collapse: return snp.stack([op @ x for op in self.ops]) From 7c61e25cbab8a3d70c9319fc3bd3189d15ba661f Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Tue, 26 Apr 2022 07:45:29 -0600 Subject: [PATCH 43/49] Fix type error --- scico/linop/_stack.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/scico/linop/_stack.py b/scico/linop/_stack.py index ca43ce79d..10f020ba7 100644 --- a/scico/linop/_stack.py +++ b/scico/linop/_stack.py @@ -17,7 +17,8 @@ import scico.numpy as snp from scico.numpy import BlockArray -from scico.typing import BlockShape, JaxArray, Shape +from scico.numpy.util import is_nested +from scico.typing import JaxArray from ._linop import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar @@ -48,16 +49,13 @@ def __init__( self.ops = ops self.collapse = collapse - # start by assuming BlockArray output - output_shape: Union[Shape, BlockShape] - output_shape = tuple(op.shape[0] for op in ops) # type: ignore + self.collapsable = all(op.output_shape == ops[0].output_shape for op in ops) - # check if collapsable and adjust output_shape if needed - self.collapsable = isinstance(output_shape[0], tuple) and all( - output_shape[0] == s for s in output_shape - ) + assert not isinstance(ops[0].output_shape, tuple) if self.collapsable and self.collapse: - output_shape = (len(ops),) + output_shape[0] # collapse to DeviceArray + output_shape = (len(ops),) + ops[0].output_shape # collapse to DeviceArray + else: + output_shape = tuple(op.output_shape for op in ops) super().__init__( input_shape=ops[0].input_shape, @@ -69,7 +67,7 @@ def __init__( ) @staticmethod - def check_if_stackable(ops): + def check_if_stackable(ops: List[LinearOperator]): """Check that input ops are suitable for stack creation.""" if not isinstance(ops, (list, tuple)): raise ValueError("Expected a list of `LinearOperator`") @@ -88,9 +86,13 @@ def check_if_stackable(ops): f"but got {input_dtypes}." ) + output_shapes = [op.shape[0] for op in ops] + if any(is_nested(output_shapes)): + raise ValueError("Cannot stack `LinearOperators`s with nested output shapes.") + output_dtypes = [op.output_dtype for op in ops] if not np.all(output_dtypes[0] == s for s in output_dtypes): - raise ValueError("Expected all `LinearOperator`s to have the same output dtype") + raise ValueError("Expected all `LinearOperator`s to have the same output dtype.") def _eval(self, x: JaxArray) -> Union[JaxArray, BlockArray]: if self.collapsable and self.collapse: From 2f5b14ca0acc7e456ced7e0cdf043c750d6f83c8 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Tue, 26 Apr 2022 12:13:23 -0600 Subject: [PATCH 44/49] Switch back to ignore, can't solve this problem without code bloat --- scico/linop/_stack.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/scico/linop/_stack.py b/scico/linop/_stack.py index 10f020ba7..83523d6f3 100644 --- a/scico/linop/_stack.py +++ b/scico/linop/_stack.py @@ -51,15 +51,15 @@ def __init__( self.collapsable = all(op.output_shape == ops[0].output_shape for op in ops) - assert not isinstance(ops[0].output_shape, tuple) + output_shapes = tuple(op.output_shape for op in ops) if self.collapsable and self.collapse: - output_shape = (len(ops),) + ops[0].output_shape # collapse to DeviceArray + output_shape = (len(ops),) + output_shapes[0] # collapse to DeviceArray else: - output_shape = tuple(op.output_shape for op in ops) + output_shape = output_shapes super().__init__( input_shape=ops[0].input_shape, - output_shape=output_shape, + output_shape=output_shape, # type: ignore input_dtype=ops[0].input_dtype, output_dtype=ops[0].output_dtype, jit=jit, @@ -86,8 +86,7 @@ def check_if_stackable(ops: List[LinearOperator]): f"but got {input_dtypes}." ) - output_shapes = [op.shape[0] for op in ops] - if any(is_nested(output_shapes)): + if any([is_nested(op.shape[0]) for op in ops]): raise ValueError("Cannot stack `LinearOperators`s with nested output shapes.") output_dtypes = [op.output_dtype for op in ops] From 1fa46bb5210ac67f544bba258b18e9092120ca54 Mon Sep 17 00:00:00 2001 From: Fernando Davis Date: Thu, 28 Apr 2022 15:12:18 -0400 Subject: [PATCH 45/49] Fixed extra 's' in `LinearOperator`s string. --- scico/linop/_stack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/linop/_stack.py b/scico/linop/_stack.py index 83523d6f3..eebea6a70 100644 --- a/scico/linop/_stack.py +++ b/scico/linop/_stack.py @@ -87,7 +87,7 @@ def check_if_stackable(ops: List[LinearOperator]): ) if any([is_nested(op.shape[0]) for op in ops]): - raise ValueError("Cannot stack `LinearOperators`s with nested output shapes.") + raise ValueError("Cannot stack `LinearOperator`s with nested output shapes.") output_dtypes = [op.output_dtype for op in ops] if not np.all(output_dtypes[0] == s for s in output_dtypes): From e3d8974502bc8ec0691f20efabaca12e43e6340a Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 28 Apr 2022 21:58:54 +0200 Subject: [PATCH 46/49] Trivial edit --- scico/numpy/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/numpy/__init__.py b/scico/numpy/__init__.py index 3cd1d1807..d1eeba82c 100644 --- a/scico/numpy/__init__.py +++ b/scico/numpy/__init__.py @@ -14,8 +14,8 @@ many have been extended to automatically map over block array blocks as described in :mod:`scico.numpy.blockarray`. Also included are additional functions unique to SCICO in :mod:`.util`. - """ + import numpy as np import jax.numpy as jnp From 52eba3ba3bcb2bb43aba8478c961f7610c055dad Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 28 Apr 2022 21:59:31 +0200 Subject: [PATCH 47/49] Use type guard rather than type ignore --- scico/linop/optics.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/scico/linop/optics.py b/scico/linop/optics.py index 7e312ef9c..4b56ab28a 100644 --- a/scico/linop/optics.py +++ b/scico/linop/optics.py @@ -47,18 +47,19 @@ and :math:`y` to axis 1 (columns, increasing to the right). """ - # Needed to annotate a class method that returns the encapsulating class; # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations -from typing import Tuple, Union +from typing import Any, Tuple, Union import numpy as np from numpy.lib.scimath import sqrt # complex sqrt import jax +from typing_extensions import TypeGuard + import scico.numpy as snp from scico.linop import Diagonal, Identity, LinearOperator from scico.numpy.util import no_nan_divide @@ -67,6 +68,11 @@ from ._dft import DFT +def _isscalar(element: Any) -> TypeGuard[Union[int, float]]: + """Type guard interface to `snp.isscalar`.""" + return snp.isscalar(element) + + def radial_transverse_frequency( input_shape: Shape, dx: Union[float, Tuple[float, ...]] ) -> np.ndarray: @@ -89,12 +95,12 @@ def radial_transverse_frequency( :math:`\sqrt{k_x^2 + k_y^2}\,`. """ - ndim = len(input_shape) # 1 or 2 dimensions + ndim: int = len(input_shape) # 1 or 2 dimensions if ndim not in (1, 2): raise ValueError("Invalid input dimensions; must be 1 or 2") - if np.isscalar(dx): - dx = (dx,) * ndim # type: ignore + if _isscalar(dx): + dx = (dx,) * ndim else: assert isinstance(dx, tuple) if len(dx) != ndim: @@ -147,8 +153,8 @@ def __init__( if ndim not in (1, 2): raise ValueError("Invalid input dimensions; must be 1 or 2") - if np.isscalar(dx): - dx = (dx,) * ndim # type: ignore + if _isscalar(dx): + dx = (dx,) * ndim else: assert isinstance(dx, tuple) if len(dx) != ndim: @@ -495,8 +501,8 @@ def __init__( if ndim not in (1, 2): raise ValueError("Invalid input dimensions; must be 1 or 2") - if np.isscalar(dx): - dx = (dx,) * ndim # type: ignore + if _isscalar(dx): + dx = (dx,) * ndim else: assert isinstance(dx, tuple) if len(dx) != ndim: From 9bd02141929b13a2ac93486d5521b348d03e0599 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 29 Apr 2022 08:05:45 +0200 Subject: [PATCH 48/49] Fix EllipsisType import --- scico/typing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scico/typing.py b/scico/typing.py index ee3e79070..7c9208a80 100644 --- a/scico/typing.py +++ b/scico/typing.py @@ -11,9 +11,9 @@ try: # available in python 3.10 - from typing import EllipsisType # type: ignore + from types import EllipsisType # type: ignore except ImportError: - EllipsisType = Any + EllipsisType = Any # type: ignore import numpy as np From 61779ff25994abfc926facf4504373ae23355408 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Sat, 30 Apr 2022 07:49:53 +0200 Subject: [PATCH 49/49] Minor edits for docs style --- scico/numpy/blockarray.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scico/numpy/blockarray.py b/scico/numpy/blockarray.py index 84e409e46..df528f3d3 100644 --- a/scico/numpy/blockarray.py +++ b/scico/numpy/blockarray.py @@ -212,7 +212,7 @@ class BlockArray(list): - """BlockArray class""" + """BlockArray class.""" # Ensure we use BlockArray.__radd__, __rmul__, etc for binary # operations of the form op(np.ndarray, BlockArray) See @@ -233,7 +233,7 @@ def __init__(self, inputs): def dtype(self): """Return the dtype of the blocks, which must currently be homogeneous. - This allows snp.zeros(x.shape, x.dtype) to work without a mechanism + This allows `snp.zeros(x.shape, x.dtype)` to work without a mechanism to handle to lists of dtypes. """ return self[0].dtype