diff --git a/CHANGES.rst b/CHANGES.rst index b8452e83a..36dc57db3 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -10,6 +10,8 @@ Version 0.0.2 (unreleased) • Move optimization algorithms into ``optimize`` subpackage. • Additional iteration stats columns for iterative ADMM subproblem solvers. • Renamed "Primal Rsdl" to "Prml Rsdl" in displayed iteration stats. +• Move some functions from ``util`` and ``math`` modules to new ``array`` + module. • Bump pinned `jaxlib` and `jax` versions to 0.1.70 and 0.2.19 respectively. diff --git a/examples/scriptcheck.sh b/examples/scriptcheck.sh index 607887941..bd46c45e1 100755 --- a/examples/scriptcheck.sh +++ b/examples/scriptcheck.sh @@ -17,7 +17,7 @@ fi # Set environment variables and paths. This script is assumed to be run # from its root directory. -export PYTHONPATH=$((cd .. && pwd)) +export PYTHONPATH=$(cd .. && pwd) export PYTHONIOENCODING=utf-8 d='/tmp/scriptcheck_'$$ mkdir -p $d diff --git a/examples/scripts/denoise_tv_iso_pgm.py b/examples/scripts/denoise_tv_iso_pgm.py index e3524595a..e11e64b4b 100644 --- a/examples/scripts/denoise_tv_iso_pgm.py +++ b/examples/scripts/denoise_tv_iso_pgm.py @@ -39,10 +39,11 @@ import scico.numpy as snp import scico.random from scico import functional, linop, loss, operator, plot +from scico.array import ensure_on_device from scico.blockarray import BlockArray from scico.optimize.pgm import AcceleratedPGM, RobustLineSearchStepSize from scico.typing import JaxArray -from scico.util import device_info, ensure_on_device +from scico.util import device_info """ Create a ground truth image. diff --git a/scico/_generic_operators.py b/scico/_generic_operators.py index 4e5e3a872..7cb059137 100644 --- a/scico/_generic_operators.py +++ b/scico/_generic_operators.py @@ -23,10 +23,9 @@ import scico.numpy as snp from scico._autograd import linear_adjoint +from scico.array import is_complex_dtype, is_nested from scico.blockarray import BlockArray, block_sizes -from scico.math import is_complex_dtype from scico.typing import BlockShape, DType, JaxArray, Shape -from scico.util import is_nested def _wrap_mul_div_scalar(func): diff --git a/scico/array.py b/scico/array.py new file mode 100644 index 000000000..a6410c16c --- /dev/null +++ b/scico/array.py @@ -0,0 +1,281 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2020-2022 by SCICO Developers +# All rights reserved. BSD 3-clause License. +# This file is part of the SCICO package. Details of the copyright and +# user license can be found in the 'LICENSE' file distributed with the +# package. + +"""Utility functions for arrays, array shapes, array indexing, etc.""" + + +from __future__ import annotations + +import warnings +from typing import Any, List, Optional, Tuple, Union + +import numpy as np + +import jax +from jax.interpreters.pxla import ShardedDeviceArray +from jax.interpreters.xla import DeviceArray + +import scico.blockarray +import scico.numpy as snp +from scico.typing import ArrayIndex, Axes, AxisIndex, JaxArray, Shape + +__author__ = """\n""".join( + [ + "Brendt Wohlberg ", + "Luke Pfister ", + "Thilo Balke ", + "Michael McCann ", + ] +) + + +def ensure_on_device( + *arrays: Union[np.ndarray, JaxArray, scico.blockarray.BlockArray] +) -> Union[JaxArray, scico.blockarray.BlockArray]: + """Cast ndarrays to DeviceArrays. + + Cast ndarrays to DeviceArrays and leaves DeviceArrays, BlockArrays, + and ShardedDeviceArray as is. This is intended to be used when + initializing optimizers and functionals so that all arrays are either + DeviceArrays, BlockArrays, or ShardedDeviceArray. + + Args: + *arrays: One or more input arrays (ndarray, DeviceArray, + BlockArray, or ShardedDeviceArray). + + Returns: + Modified array or arrays. Modified are only those that were + necessary. + + Raises: + TypeError: If the arrays contain something that is neither + ndarray, DeviceArray, BlockArray, nor ShardedDeviceArray. + """ + arrays = list(arrays) + + for i, array in enumerate(arrays): + + if isinstance(array, np.ndarray): + warnings.warn( + f"Argument {i+1} of {len(arrays)} is an np.ndarray. " + f"Will cast it to DeviceArray. " + f"To suppress this warning cast all np.ndarrays to DeviceArray first.", + stacklevel=2, + ) + + arrays[i] = jax.device_put(arrays[i]) + elif not isinstance( + array, + (DeviceArray, scico.blockarray.BlockArray, ShardedDeviceArray), + ): + raise TypeError( + "Each item of `arrays` must be ndarray, DeviceArray, BlockArray, or " + f"ShardedDeviceArray; Argument {i+1} of {len(arrays)} is {type(arrays[i])}." + ) + + if len(arrays) == 1: + return arrays[0] + return arrays + + +def no_nan_divide( + x: Union[BlockArray, JaxArray], y: Union[BlockArray, JaxArray] +) -> Union[BlockArray, JaxArray]: + """Return `x/y`, with 0 instead of NaN where `y` is 0. + + Args: + x: Numerator. + y: Denominator. + + Returns: + `x / y` with 0 wherever `y == 0`. + """ + + return snp.where(y != 0, snp.divide(x, snp.where(y != 0, y, 1)), 0) + + +def parse_axes( + axes: Axes, shape: Optional[Shape] = None, default: Optional[List[int]] = None +) -> List[int]: + """Normalize `axes` to a list and optionally ensure correctness. + + Normalize `axes` to a list and (optionally) ensure that entries refer + to axes that exist in `shape`. + + Args: + axes: User specification of one or more axes: int, list, tuple, + or ``None``. + shape: The shape of the array of which axes are being specified. + If not ``None``, `axes` is checked to make sure its entries + refer to axes that exist in `shape`. + default: Default value to return if `axes` is ``None``. By + default, `list(range(len(shape)))`. + + Returns: + List of axes (never an int, never ``None``). + """ + + if axes is None: + if default is None: + if shape is None: + raise ValueError("`axes` cannot be `None` without a default or shape specified.") + axes = list(range(len(shape))) + else: + axes = default + elif isinstance(axes, (list, tuple)): + axes = axes + elif isinstance(axes, int): + axes = (axes,) + else: + raise ValueError(f"Could not understand axes {axes} as a list of axes") + if shape is not None and max(axes) >= len(shape): + raise ValueError( + f"Invalid axes {axes} specified; each axis must be less than `len(shape)`={len(shape)}." + ) + if len(set(axes)) != len(axes): + raise ValueError(f"Duplicate value in axes {axes}; each axis must be unique.") + return axes + + +def slice_length(length: int, slc: AxisIndex) -> int: + """Determine the length of an array axis after slicing. + + Args: + length: Length of axis being sliced. + slc: Slice/indexing to be applied to axis. + + Returns: + Length of sliced axis. + + Raises: + ValueError: If `slc` is an integer index that is out bounds for + the axis length. + """ + if slc is Ellipsis: + return length + if isinstance(slc, int): + if slc < -length or slc > length - 1: + raise ValueError(f"Index {slc} out of bounds for axis of length {length}.") + return 1 + start, stop, stride = slc.indices(length) + if start > stop: + start = stop + return (stop - start + stride - 1) // stride + + +def indexed_shape(shape: Shape, idx: ArrayIndex) -> Tuple[int]: + """Determine the shape of an array after indexing/slicing. + + Args: + shape: Shape of array. + idx: Indexing expression. + + Returns: + Shape of indexed/sliced array. + + Raises: + ValueError: If `idx` is longer than `shape`. + """ + if not isinstance(idx, tuple): + idx = (idx,) + if len(idx) > len(shape): + raise ValueError(f"Slice {idx} has more dimensions than shape {shape}.") + idx_shape = list(shape) + offset = 0 + for axis, ax_idx in enumerate(idx): + print(axis, offset) + if ax_idx is Ellipsis: + offset = len(shape) - len(idx) + continue + idx_shape[axis + offset] = slice_length(shape[axis + offset], ax_idx) + return tuple(idx_shape) + + +def is_nested(x: Any) -> bool: + """Check if input is a list/tuple containing at least one list/tuple. + + Args: + x: Object to be tested. + + Returns: + True if ``x`` is a list/tuple of list/tuples, False otherwise. + + + Example: + >>> is_nested([1, 2, 3]) + False + >>> is_nested([(1,2), (3,)]) + True + >>> is_nested([[1, 2], 3]) + True + + """ + if isinstance(x, (list, tuple)): + return any([isinstance(_, (list, tuple)) for _ in x]) + return False + + +def is_real_dtype(dtype: DType) -> bool: + """Determine whether a dtype is real. + + Args: + dtype: A numpy or scico.numpy dtype (e.g. np.float32, + snp.complex64). + + Returns: + False if the dtype is complex, otherwise True. + """ + return snp.dtype(dtype).kind != "c" + + +def is_complex_dtype(dtype: DType) -> bool: + """Determine whether a dtype is complex. + + Args: + dtype: A numpy or scico.numpy dtype (e.g. np.float32, + snp.complex64). + + Returns: + True if the dtype is complex, otherwise False. + """ + return snp.dtype(dtype).kind == "c" + + +def real_dtype(dtype: DType) -> DType: + """Construct the corresponding real dtype for a given complex dtype. + + Construct the corresponding real dtype for a given complex dtype, + e.g. the real dtype corresponding to `np.complex64` is + `np.float32`. + + Args: + dtype: A complex numpy or scico.numpy dtype (e.g. np.complex64, + np.complex128). + + Returns: + The real dtype corresponding to the input dtype + """ + + return snp.zeros(1, dtype).real.dtype + + +def complex_dtype(dtype: DType) -> DType: + """Construct the corresponding complex dtype for a given real dtype. + + Construct the corresponding complex dtype for a given real dtype, + e.g. the complex dtype corresponding to `np.float32` is + `np.complex64`. + + Args: + dtype: A real numpy or scico.numpy dtype (e.g. np.float32, + np.float64). + + Returns: + The complex dtype corresponding to the input dtype. + """ + + return (snp.zeros(1, dtype) + 1j).dtype diff --git a/scico/blockarray.py b/scico/blockarray.py index 56efc15cb..1e9b086e3 100644 --- a/scico/blockarray.py +++ b/scico/blockarray.py @@ -464,7 +464,7 @@ from jaxlib.xla_extension import Buffer -from scico import util +from scico import array from scico.typing import Axes, AxisIndex, BlockShape, DType, JaxArray, Shape _arraylikes = (Buffer, DeviceArray, np.ndarray) @@ -521,7 +521,7 @@ def reshape( always returned. """ - if util.is_nested(newshape): + if array.is_nested(newshape): # x is a blockarray return BlockArray.array_from_flattened(a, newshape) @@ -566,13 +566,13 @@ def block_sizes(shape: Union[Shape, BlockShape]) -> Axes: ) out = [] - if util.is_nested(shape): + if array.is_nested(shape): # shape is nested -> at least one element came from a blockarray for y in shape: - if util.is_nested(y): + if array.is_nested(y): # recursively calculate the block size until we arrive at # a tuple (shape of a non-block array) - while util.is_nested(y): + while array.is_nested(y): y = block_sizes(y) out.append(np.sum(y)) # adjacent block sizes are added together else: @@ -630,7 +630,7 @@ def indexed_shape(shape: Shape, idx: Union[int, Tuple(AxisIndex)]) -> Tuple[int] idxblk = len(shape) + idxblk if idxarr is None: return shape[idxblk] - return util.indexed_shape(shape[idxblk], idxarr) + return array.indexed_shape(shape[idxblk], idxarr) def _flatten_blockarrays(inp, *args, **kwargs): diff --git a/scico/functional/_norm.py b/scico/functional/_norm.py index c1c3f52b9..a809a8d93 100644 --- a/scico/functional/_norm.py +++ b/scico/functional/_norm.py @@ -12,8 +12,8 @@ from jax import jit from scico import numpy as snp +from scico.array import no_nan_divide from scico.blockarray import BlockArray -from scico.math import safe_divide from scico.numpy import count_nonzero from scico.numpy.linalg import norm from scico.typing import JaxArray @@ -247,7 +247,7 @@ def prox( """ length = norm(x, axis=self.l2_axis, keepdims=True) - direction = safe_divide(x, length) + direction = no_nan_divide(x, length) new_length = length - lam # set negative values to zero without `if` diff --git a/scico/linop/_convolve.py b/scico/linop/_convolve.py index 8636f2476..9d8480b3a 100644 --- a/scico/linop/_convolve.py +++ b/scico/linop/_convolve.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2020-2021 by SCICO Developers +# Copyright (C) 2020-2022 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -23,7 +23,7 @@ from jax.scipy.signal import convolve import scico.numpy as snp -from scico import util +from scico import array from scico._generic_operators import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar from scico.typing import DType, JaxArray, Shape @@ -68,7 +68,7 @@ def __init__( if h.ndim != len(input_shape): raise ValueError(f"h.ndim = {h.ndim} must equal len(input_shape) = {len(input_shape)}") - self.h = util.ensure_on_device(h) + self.h = array.ensure_on_device(h) if mode not in ["full", "valid", "same"]: raise ValueError(f"Invalid mode={mode}; must be one of 'full', 'valid', 'same'") diff --git a/scico/linop/_diff.py b/scico/linop/_diff.py index 08d0c378a..542bcc06c 100644 --- a/scico/linop/_diff.py +++ b/scico/linop/_diff.py @@ -17,8 +17,8 @@ import numpy as np import scico.numpy as snp +from scico.array import parse_axes from scico.typing import Axes, DType, JaxArray, Shape -from scico.util import parse_axes from ._linop import LinearOperator from ._stack import LinearOperatorStack diff --git a/scico/linop/_linop.py b/scico/linop/_linop.py index f7d76b02d..b3c7c8f85 100644 --- a/scico/linop/_linop.py +++ b/scico/linop/_linop.py @@ -16,7 +16,7 @@ from typing import Optional, Tuple, Union import scico.numpy as snp -from scico import blockarray, util +from scico import array, blockarray from scico._generic_operators import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar from scico.blockarray import BlockArray from scico.random import randn @@ -185,7 +185,7 @@ def __init__( """ - self.diagonal = util.ensure_on_device(diagonal) + self.diagonal = array.ensure_on_device(diagonal) if input_shape is None: input_shape = self.diagonal.shape @@ -193,9 +193,9 @@ def __init__( if input_dtype is None: input_dtype = self.diagonal.dtype - if isinstance(diagonal, BlockArray) and util.is_nested(input_shape): + if isinstance(diagonal, BlockArray) and array.is_nested(input_shape): output_shape = (snp.empty(input_shape) * diagonal).shape - elif not isinstance(diagonal, BlockArray) and not util.is_nested(input_shape): + elif not isinstance(diagonal, BlockArray) and not array.is_nested(input_shape): output_shape = snp.broadcast_shapes(input_shape, self.diagonal.shape) elif isinstance(diagonal, BlockArray): raise ValueError(f"`diagonal` was a BlockArray but `input_shape` was not nested.") @@ -284,7 +284,7 @@ def __init__( """ input_ndim = len(input_shape) - sum_axis = util.parse_axes(sum_axis, shape=input_shape) + sum_axis = array.parse_axes(sum_axis, shape=input_shape) self.sum_axis: Tuple[int, ...] = sum_axis super().__init__(input_shape=input_shape, input_dtype=input_dtype, jit=jit, **kwargs) @@ -322,10 +322,10 @@ def __init__( functions of the LinearOperator. """ - if util.is_nested(input_shape): + if array.is_nested(input_shape): output_shape = blockarray.indexed_shape(input_shape, idx) else: - output_shape = util.indexed_shape(input_shape, idx) + output_shape = array.indexed_shape(input_shape, idx) self.idx: ArrayIndex = idx super().__init__( diff --git a/scico/linop/optics.py b/scico/linop/optics.py index 6f7bfd9d6..195869b0c 100644 --- a/scico/linop/optics.py +++ b/scico/linop/optics.py @@ -25,8 +25,8 @@ import jax import scico.numpy as snp +from scico.array import no_nan_divide from scico.linop import Diagonal, Identity, LinearOperator -from scico.math import safe_divide from scico.typing import Shape from ._dft import DFT @@ -266,7 +266,7 @@ def adequate_sampling(self): def pinv(self, y): """Apply pseudoinverse of Angular Spectrum propagator.""" - diag_inv = safe_divide(1, self.D.diagonal) + diag_inv = no_nan_divide(1, self.D.diagonal) return self.F.inv(diag_inv * self.F(y)) diff --git a/scico/loss.py b/scico/loss.py index 1f6d893cf..9629b3c1b 100644 --- a/scico/loss.py +++ b/scico/loss.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2020-2021 by SCICO Developers +# Copyright (C) 2020-2022 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -15,11 +15,11 @@ import scico.numpy as snp from scico import functional, linop, operator +from scico.array import ensure_on_device from scico.blockarray import BlockArray from scico.scipy.special import gammaln from scico.solver import cg from scico.typing import JaxArray -from scico.util import ensure_on_device __author__ = """\n""".join( ["Luke Pfister ", "Thilo Balke "] diff --git a/scico/math.py b/scico/math.py deleted file mode 100644 index d1615369a..000000000 --- a/scico/math.py +++ /dev/null @@ -1,121 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright (C) 2020-2021 by SCICO Developers -# All rights reserved. BSD 3-clause License. -# This file is part of the SCICO package. Details of the copyright and -# user license can be found in the 'LICENSE' file distributed with the -# package. - -"""Math functions.""" - - -from typing import Union - -import scico.numpy as snp -from scico.blockarray import BlockArray -from scico.typing import DType, JaxArray - -__author__ = """\n""".join( - ["Luke Pfister ", "Brendt Wohlberg "] -) - - -def safe_divide( - x: Union[BlockArray, JaxArray], y: Union[BlockArray, JaxArray] -) -> Union[BlockArray, JaxArray]: - """Return `x/y`, with 0 instead of NaN where `y` is 0. - - Args: - x: Numerator. - y: Denominator. - - Returns: - `x / y` with 0 wherever `y == 0`. - """ - - return snp.where(y != 0, snp.divide(x, snp.where(y != 0, y, 1)), 0) - - -def rel_res(ax: Union[BlockArray, JaxArray], b: Union[BlockArray, JaxArray]) -> float: - r"""Relative residual of the solution to a linear equation. - - The standard relative residual for the linear system - :math:`A \mathbf{x} = \mathbf{b}` is :math:`\|\mathbf{b} - - A \mathbf{x}\|_2 / \|\mathbf{b}\|_2`. This function computes a - variant :math:`\|\mathbf{b} - A \mathbf{x}\|_2 / - \max(\|A\mathbf{x}\|_2, \|\mathbf{b}\|_2)` that is robust to the case - :math:`\mathbf{b} = 0`. - - Args: - ax: Linear component :math:`A \mathbf{x}` of equation. - b: Constant component :math:`\mathbf{b}` of equation. - - Returns: - Relative residual value. - """ - - nrm = max(snp.linalg.norm(ax.ravel()), snp.linalg.norm(b.ravel())) - if nrm == 0.0: - return 0.0 - return snp.linalg.norm((b - ax).ravel()) / nrm - - -def is_real_dtype(dtype: DType) -> bool: - """Determine whether a dtype is real. - - Args: - dtype: A numpy or scico.numpy dtype (e.g. np.float32, - snp.complex64). - - Returns: - False if the dtype is complex, otherwise True. - """ - return snp.dtype(dtype).kind != "c" - - -def is_complex_dtype(dtype: DType) -> bool: - """Determine whether a dtype is complex. - - Args: - dtype: A numpy or scico.numpy dtype (e.g. np.float32, - snp.complex64). - - Returns: - True if the dtype is complex, otherwise False. - """ - return snp.dtype(dtype).kind == "c" - - -def real_dtype(dtype: DType) -> DType: - """Construct the corresponding real dtype for a given complex dtype. - - Construct the corresponding real dtype for a given complex dtype, - e.g. the real dtype corresponding to `np.complex64` is - `np.float32`. - - Args: - dtype: A complex numpy or scico.numpy dtype (e.g. np.complex64, - np.complex128). - - Returns: - The real dtype corresponding to the input dtype - """ - - return snp.zeros(1, dtype).real.dtype - - -def complex_dtype(dtype: DType) -> DType: - """Construct the corresponding complex dtype for a given real dtype. - - Construct the corresponding complex dtype for a given real dtype, - e.g. the complex dtype corresponding to `np.float32` is - `np.complex64`. - - Args: - dtype: A real numpy or scico.numpy dtype (e.g. np.float32, - np.float64). - - Returns: - The complex dtype corresponding to the input dtype. - """ - - return (snp.zeros(1, dtype) + 1j).dtype diff --git a/scico/metric.py b/scico/metric.py index a9f89cd29..4f246feed 100644 --- a/scico/metric.py +++ b/scico/metric.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2021 by SCICO Developers +# Copyright (C) 2021-2022 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -140,3 +140,27 @@ def bsnr(blurry: Union[JaxArray, BlockArray], noisy: Union[JaxArray, BlockArray] with np.errstate(divide="ignore"): rt = blrvar / nsevar return 10.0 * snp.log10(rt) + + +def rel_res(ax: Union[BlockArray, JaxArray], b: Union[BlockArray, JaxArray]) -> float: + r"""Relative residual of the solution to a linear equation. + + The standard relative residual for the linear system + :math:`A \mathbf{x} = \mathbf{b}` is :math:`\|\mathbf{b} - + A \mathbf{x}\|_2 / \|\mathbf{b}\|_2`. This function computes a + variant :math:`\|\mathbf{b} - A \mathbf{x}\|_2 / + \max(\|A\mathbf{x}\|_2, \|\mathbf{b}\|_2)` that is robust to the case + :math:`\mathbf{b} = 0`. + + Args: + ax: Linear component :math:`A \mathbf{x}` of equation. + b: Constant component :math:`\mathbf{b}` of equation. + + Returns: + Relative residual value. + """ + + nrm = max(snp.linalg.norm(ax.ravel()), snp.linalg.norm(b.ravel())) + if nrm == 0.0: + return 0.0 + return snp.linalg.norm((b - ax).ravel()) / nrm diff --git a/scico/numpy/__init__.py b/scico/numpy/__init__.py index bbdffc144..4b5c86dd1 100644 --- a/scico/numpy/__init__.py +++ b/scico/numpy/__init__.py @@ -24,6 +24,8 @@ import jax from jax import numpy as jnp +from scico.array import is_nested + # These functions rely on the definition of a BlockArray and must be in # scico.blockarray to avoid a circular import from scico.blockarray import ( @@ -36,7 +38,6 @@ reshape, ) from scico.typing import BlockShape, JaxArray, Shape -from scico.util import is_nested from ._create import ( empty, diff --git a/scico/numpy/_create.py b/scico/numpy/_create.py index a47366d8a..3754ba477 100644 --- a/scico/numpy/_create.py +++ b/scico/numpy/_create.py @@ -14,9 +14,9 @@ import jax from jax import numpy as jnp +from scico.array import is_nested from scico.blockarray import BlockArray from scico.typing import BlockShape, DType, JaxArray, Shape -from scico.util import is_nested def zeros( diff --git a/scico/operator/biconvolve.py b/scico/operator/biconvolve.py index c3525b74f..f2bbde399 100644 --- a/scico/operator/biconvolve.py +++ b/scico/operator/biconvolve.py @@ -13,10 +13,10 @@ from jax.scipy.signal import convolve from scico._generic_operators import LinearOperator, Operator +from scico.array import is_nested from scico.blockarray import BlockArray from scico.linop import Convolve, ConvolveByX from scico.typing import BlockShape, DType, JaxArray -from scico.util import is_nested __author__ = """Luke Pfister """ diff --git a/scico/optimize/_ladmm.py b/scico/optimize/_ladmm.py index efcfe06f5..1dbaa6a13 100644 --- a/scico/optimize/_ladmm.py +++ b/scico/optimize/_ladmm.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2021 by SCICO Developers +# Copyright (C) 2021-2022 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -14,13 +14,14 @@ from typing import Callable, List, Optional, Union import scico.numpy as snp +from scico.array import ensure_on_device from scico.blockarray import BlockArray from scico.diagnostics import IterationStats from scico.functional import Functional from scico.linop import LinearOperator from scico.numpy.linalg import norm from scico.typing import JaxArray -from scico.util import Timer, ensure_on_device +from scico.util import Timer __author__ = """\n""".join( ["Luke Pfister ", "Brendt Wohlberg "] diff --git a/scico/optimize/_primaldual.py b/scico/optimize/_primaldual.py index 324fc352b..94777d97e 100644 --- a/scico/optimize/_primaldual.py +++ b/scico/optimize/_primaldual.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2021 by SCICO Developers +# Copyright (C) 2021-2022 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -14,13 +14,14 @@ from typing import Callable, Optional, Union import scico.numpy as snp +from scico.array import ensure_on_device from scico.blockarray import BlockArray from scico.diagnostics import IterationStats from scico.functional import Functional from scico.linop import LinearOperator from scico.numpy.linalg import norm from scico.typing import JaxArray -from scico.util import Timer, ensure_on_device +from scico.util import Timer __author__ = "Brendt Wohlberg " diff --git a/scico/optimize/admm.py b/scico/optimize/admm.py index 7a17e3c86..5681d3304 100644 --- a/scico/optimize/admm.py +++ b/scico/optimize/admm.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2020-2021 by SCICO Developers +# Copyright (C) 2020-2022 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -18,17 +18,17 @@ from jax.scipy.sparse.linalg import cg as jax_cg import scico.numpy as snp +from scico.array import ensure_on_device, is_real_dtype from scico.blockarray import BlockArray from scico.diagnostics import IterationStats from scico.functional import Functional from scico.linop import CircularConvolve, Identity, LinearOperator from scico.loss import SquaredL2Loss, WeightedSquaredL2Loss -from scico.math import is_real_dtype from scico.numpy.linalg import norm from scico.solver import cg as scico_cg from scico.solver import minimize from scico.typing import JaxArray -from scico.util import Timer, ensure_on_device +from scico.util import Timer __author__ = """\n""".join( ["Luke Pfister ", "Brendt Wohlberg "] diff --git a/scico/optimize/pgm.py b/scico/optimize/pgm.py index 14ace38e7..de0f553c2 100644 --- a/scico/optimize/pgm.py +++ b/scico/optimize/pgm.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2020-2021 by SCICO Developers +# Copyright (C) 2020-2022 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -16,12 +16,13 @@ import jax import scico.numpy as snp +from scico.array import ensure_on_device from scico.blockarray import BlockArray from scico.diagnostics import IterationStats from scico.functional import Functional from scico.loss import Loss from scico.typing import JaxArray -from scico.util import Timer, ensure_on_device +from scico.util import Timer __author__ = """\n""".join( [ diff --git a/scico/random.py b/scico/random.py index 4ea5f716a..2e56b90a9 100644 --- a/scico/random.py +++ b/scico/random.py @@ -60,9 +60,9 @@ import jax +from scico.array import is_nested from scico.blockarray import BlockArray, block_sizes from scico.typing import BlockShape, DType, JaxArray, PRNGKey, Shape -from scico.util import is_nested def _add_seed(fun): diff --git a/scico/test/linop/test_linop.py b/scico/test/linop/test_linop.py index bd9f2734b..4963b3051 100644 --- a/scico/test/linop/test_linop.py +++ b/scico/test/linop/test_linop.py @@ -29,9 +29,9 @@ def adjoint_test( """Check the validity of A.conj().T as the adjoint for a LinearOperator A. Args: - A : LinearOperator to test - key: PRNGKey for generating `x`. - rtol: Relative tolerance + A: LinearOperator to test. + key: PRNGKey for generating `x`. + rtol: Relative tolerance. """ assert linop.valid_adjoint(A, A.H, key=key, eps=rtol, x=x, y=y) diff --git a/scico/test/test_array.py b/scico/test/test_array.py new file mode 100644 index 000000000..bce83b1cf --- /dev/null +++ b/scico/test/test_array.py @@ -0,0 +1,173 @@ +import warnings + +import numpy as np + +from jax.interpreters.xla import DeviceArray + +import pytest + +import scico.numpy as snp +from scico.array import ( + complex_dtype, + ensure_on_device, + indexed_shape, + is_complex_dtype, + is_nested, + is_real_dtype, + no_nan_divide, + parse_axes, + real_dtype, + slice_length, +) +from scico.blockarray import BlockArray +from scico.random import randn + + +def test_ensure_on_device(): + # Used to restore the warnings after the context is used + with warnings.catch_warnings(): + # Ignores warning raised by ensure_on_device + warnings.filterwarnings(action="ignore", category=UserWarning) + + NP = np.ones(2) + SNP = snp.ones(2) + BA = BlockArray.array([NP, SNP]) + NP_, SNP_, BA_ = ensure_on_device(NP, SNP, BA) + + assert isinstance(NP_, DeviceArray) + + assert isinstance(SNP_, DeviceArray) + assert SNP.unsafe_buffer_pointer() == SNP_.unsafe_buffer_pointer() + + assert isinstance(BA_, BlockArray) + assert BA._data.unsafe_buffer_pointer() == BA_._data.unsafe_buffer_pointer() + + np.testing.assert_raises(TypeError, ensure_on_device, [1, 1, 1]) + + NP_ = ensure_on_device(NP) + assert isinstance(NP_, DeviceArray) + + +def test_no_nan_divide_array(): + x, key = randn((4,), dtype=np.float32) + y, key = randn(x.shape, dtype=np.float32, key=key) + y = y.at[0].set(0) + + res = no_nan_divide(x, y) + + assert res[0] == 0 + idx = y != 0 + np.testing.assert_allclose(res[idx], x[idx] / y[idx]) + + +def test_no_nan_divide_blockarray(): + x, key = randn(((3, 3), (4,)), dtype=np.float32) + + y, key = randn(x.shape, dtype=np.float32, key=key) + y = y.at[1].set(0 * y[1]) + + res = no_nan_divide(x, y) + + assert snp.all(res[1] == 0.0) + np.testing.assert_allclose(res[0], x[0] / y[0]) + + +def test_parse_axes(): + axes = None + np.testing.assert_raises(ValueError, parse_axes, axes) + + axes = None + assert parse_axes(axes, np.shape([[1, 1], [1, 1]])) == [0, 1] + + axes = None + assert parse_axes(axes, np.shape([[1, 1], [1, 1]]), default=[0]) == [0] + + axes = [1, 2] + assert parse_axes(axes) == axes + + axes = 1 + assert parse_axes(axes) == (1,) + + axes = "axes" + np.testing.assert_raises(ValueError, parse_axes, axes) + + axes = 2 + np.testing.assert_raises(ValueError, parse_axes, axes, np.shape([1])) + + axes = (1, 2, 2) + np.testing.assert_raises(ValueError, parse_axes, axes) + + +@pytest.mark.parametrize("length", (4, 5, 8, 16, 17)) +@pytest.mark.parametrize("start", (None, 0, 1, 2, 3)) +@pytest.mark.parametrize("stop", (None, 0, 1, 2, -2, -1)) +@pytest.mark.parametrize("stride", (None, 1, 2, 3)) +def test_slice_length(length, start, stop, stride): + x = np.zeros(length) + slc = slice(start, stop, stride) + assert x[slc].size == slice_length(length, slc) + + +@pytest.mark.parametrize("length", (4, 5)) +@pytest.mark.parametrize("slc", (0, 1, -4, Ellipsis)) +def test_slice_length_other(length, slc): + x = np.zeros(length) + assert x[slc].size == slice_length(length, slc) + + +@pytest.mark.parametrize("shape", ((8, 8, 1), (7, 1, 6, 5))) +@pytest.mark.parametrize( + "slc", + ( + np.s_[0:5], + np.s_[:, 0:4], + np.s_[2:, :, :-2], + np.s_[..., 2:], + np.s_[..., 2:, :], + np.s_[1:, ..., 2:], + ), +) +def test_indexed_shape(shape, slc): + x = np.zeros(shape) + assert x[slc].shape == indexed_shape(shape, slc) + + +def test_is_nested(): + # list + assert is_nested([1, 2, 3]) == False + + # tuple + assert is_nested((1, 2, 3)) == False + + # list of lists + assert is_nested([[1, 2], [4, 5], [3]]) == True + + # list of lists + scalar + assert is_nested([[1, 2], 3]) == True + + # list of tuple + scalar + assert is_nested([(1, 2), 3]) == True + + # tuple of tuple + scalar + assert is_nested(((1, 2), 3)) == True + + # tuple of lists + scalar + assert is_nested(([1, 2], 3)) == True + + +def test_is_real_dtype(): + assert not is_real_dtype(snp.complex64) + assert is_real_dtype(snp.float32) + + +def test_is_complex_dtype(): + assert is_complex_dtype(snp.complex64) + assert not is_complex_dtype(snp.float32) + + +def test_real_dtype(): + assert real_dtype(snp.complex64) == snp.float32 + + +def test_complex_dtype(): + assert complex_dtype(snp.float32) == snp.complex64 diff --git a/scico/test/test_math.py b/scico/test/test_math.py deleted file mode 100644 index 57e8a2ba0..000000000 --- a/scico/test/test_math.py +++ /dev/null @@ -1,68 +0,0 @@ -import numpy as np - -import scico.numpy as snp -from scico.math import ( - complex_dtype, - is_complex_dtype, - is_real_dtype, - real_dtype, - rel_res, - safe_divide, -) -from scico.random import randn - - -def test_safe_divide_array(): - x, key = randn((4,), dtype=np.float32) - y, key = randn(x.shape, dtype=np.float32, key=key) - y = y.at[0].set(0) - - res = safe_divide(x, y) - - assert res[0] == 0 - idx = y != 0 - np.testing.assert_allclose(res[idx], x[idx] / y[idx]) - - -def test_safe_divide_blockarray(): - x, key = randn(((3, 3), (4,)), dtype=np.float32) - - y, key = randn(x.shape, dtype=np.float32, key=key) - y = y.at[1].set(0 * y[1]) - - res = safe_divide(x, y) - - assert snp.all(res[1] == 0.0) - np.testing.assert_allclose(res[0], x[0] / y[0]) - - -def test_rel_res(): - A = snp.array([[2, -1], [1, 0], [-1, 1]], dtype=snp.float32) - x = snp.array([[3], [-2]], dtype=snp.float32) - Ax = snp.matmul(A, x) - b = snp.array([[8], [3], [-5]], dtype=snp.float32) - assert 0.0 == rel_res(Ax, b) - - A = snp.array([[2, -1], [1, 0], [-1, 1]], dtype=snp.float32) - x = snp.array([[0], [0]], dtype=snp.float32) - Ax = snp.matmul(A, x) - b = snp.array([[0], [0], [0]], dtype=snp.float32) - assert 0.0 == rel_res(Ax, b) - - -def test_is_real_dtype(): - assert not is_real_dtype(snp.complex64) - assert is_real_dtype(snp.float32) - - -def test_is_complex_dtype(): - assert is_complex_dtype(snp.complex64) - assert not is_complex_dtype(snp.float32) - - -def test_real_dtype(): - assert real_dtype(snp.complex64) == snp.float32 - - -def test_complex_dtype(): - assert complex_dtype(snp.float32) == snp.complex64 diff --git a/scico/test/test_metric.py b/scico/test/test_metric.py index 6fc1555da..80973b69d 100644 --- a/scico/test/test_metric.py +++ b/scico/test/test_metric.py @@ -1,5 +1,6 @@ import numpy as np +import scico.numpy as snp from scico import metric @@ -53,3 +54,17 @@ def test_bsnr(self): n /= np.sqrt(np.var(n)) y = x + n assert np.abs(metric.bsnr(x, y)) < 1e-6 + + +def test_rel_res(): + A = snp.array([[2, -1], [1, 0], [-1, 1]], dtype=snp.float32) + x = snp.array([[3], [-2]], dtype=snp.float32) + Ax = snp.matmul(A, x) + b = snp.array([[8], [3], [-5]], dtype=snp.float32) + assert 0.0 == metric.rel_res(Ax, b) + + A = snp.array([[2, -1], [1, 0], [-1, 1]], dtype=snp.float32) + x = snp.array([[0], [0]], dtype=snp.float32) + Ax = snp.matmul(A, x) + b = snp.array([[0], [0], [0]], dtype=snp.float32) + assert 0.0 == metric.rel_res(Ax, b) diff --git a/scico/test/test_util.py b/scico/test/test_util.py index ac96ee4a4..bdb6c449f 100644 --- a/scico/test/test_util.py +++ b/scico/test/test_util.py @@ -1,52 +1,14 @@ import socket import urllib.error as urlerror -import warnings import numpy as np import jax -from jax.interpreters.xla import DeviceArray import pytest import scico.numpy as snp -from scico.blockarray import BlockArray -from scico.util import ( - ContextTimer, - Timer, - check_for_tracer, - ensure_on_device, - indexed_shape, - is_nested, - parse_axes, - slice_length, - url_get, -) - - -def test_ensure_on_device(): - # Used to restore the warnings after the context is used - with warnings.catch_warnings(): - # Ignores warning raised by ensure_on_device - warnings.filterwarnings(action="ignore", category=UserWarning) - - NP = np.ones(2) - SNP = snp.ones(2) - BA = BlockArray.array([NP, SNP]) - NP_, SNP_, BA_ = ensure_on_device(NP, SNP, BA) - - assert isinstance(NP_, DeviceArray) - - assert isinstance(SNP_, DeviceArray) - assert SNP.unsafe_buffer_pointer() == SNP_.unsafe_buffer_pointer() - - assert isinstance(BA_, BlockArray) - assert BA._data.unsafe_buffer_pointer() == BA_._data.unsafe_buffer_pointer() - - np.testing.assert_raises(TypeError, ensure_on_device, [1, 1, 1]) - - NP_ = ensure_on_device(NP) - assert isinstance(NP_, DeviceArray) +from scico.util import ContextTimer, Timer, check_for_tracer, url_get # See https://stackoverflow.com/a/33117579 @@ -77,66 +39,6 @@ def test_url_get(): np.testing.assert_raises(ValueError, url_get, url, -1) -def test_parse_axes(): - axes = None - np.testing.assert_raises(ValueError, parse_axes, axes) - - axes = None - assert parse_axes(axes, np.shape([[1, 1], [1, 1]])) == [0, 1] - - axes = None - assert parse_axes(axes, np.shape([[1, 1], [1, 1]]), default=[0]) == [0] - - axes = [1, 2] - assert parse_axes(axes) == axes - - axes = 1 - assert parse_axes(axes) == (1,) - - axes = "axes" - np.testing.assert_raises(ValueError, parse_axes, axes) - - axes = 2 - np.testing.assert_raises(ValueError, parse_axes, axes, np.shape([1])) - - axes = (1, 2, 2) - np.testing.assert_raises(ValueError, parse_axes, axes) - - -@pytest.mark.parametrize("length", (4, 5, 8, 16, 17)) -@pytest.mark.parametrize("start", (None, 0, 1, 2, 3)) -@pytest.mark.parametrize("stop", (None, 0, 1, 2, -2, -1)) -@pytest.mark.parametrize("stride", (None, 1, 2, 3)) -def test_slice_length(length, start, stop, stride): - x = np.zeros(length) - slc = slice(start, stop, stride) - assert x[slc].size == slice_length(length, slc) - - -@pytest.mark.parametrize("length", (4, 5)) -@pytest.mark.parametrize("slc", (0, 1, -4, Ellipsis)) -def test_slice_length_other(length, slc): - x = np.zeros(length) - assert x[slc].size == slice_length(length, slc) - - -@pytest.mark.parametrize("shape", ((8, 8, 1), (7, 1, 6, 5))) -@pytest.mark.parametrize( - "slc", - ( - np.s_[0:5], - np.s_[:, 0:4], - np.s_[2:, :, :-2], - np.s_[..., 2:], - np.s_[..., 2:, :], - np.s_[1:, ..., 2:], - ), -) -def test_indexed_shape(shape, slc): - x = np.zeros(shape) - assert x[slc].shape == indexed_shape(shape, slc) - - def test_check_for_tracer(): # Using examples from Jax documentation @@ -158,29 +60,6 @@ def norm(X): mv(x) -def test_is_nested(): - # list - assert is_nested([1, 2, 3]) == False - - # tuple - assert is_nested((1, 2, 3)) == False - - # list of lists - assert is_nested([[1, 2], [4, 5], [3]]) == True - - # list of lists + scalar - assert is_nested([[1, 2], 3]) == True - - # list of tuple + scalar - assert is_nested([(1, 2), 3]) == True - - # tuple of tuple + scalar - assert is_nested(((1, 2), 3)) == True - - # tuple of lists + scalar - assert is_nested(([1, 2], 3)) == True - - def test_timer_basic(): t = Timer() t.start() diff --git a/scico/util.py b/scico/util.py index 8973e49c1..5d14518c4 100644 --- a/scico/util.py +++ b/scico/util.py @@ -5,9 +5,8 @@ # user license can be found in the 'LICENSE' file distributed with the # package. -"""Utility functions.""" +"""General utility functions.""" -# The timer classes in this module are copied from https://github.com/bwohlberg/sporco from __future__ import annotations @@ -15,27 +14,18 @@ import socket import urllib.error as urlerror import urllib.request as urlrequest -import warnings from functools import wraps from timeit import default_timer as timer -from typing import Any, Callable, List, Optional, Tuple, Union - -import numpy as np +from typing import Callable, List, Optional, Union import jax from jax.interpreters.batching import BatchTracer from jax.interpreters.partial_eval import DynamicJaxprTracer -from jax.interpreters.pxla import ShardedDeviceArray -from jax.interpreters.xla import DeviceArray - -import scico.blockarray -from scico.typing import ArrayIndex, Axes, AxisIndex, JaxArray, Shape __author__ = """\n""".join( [ "Brendt Wohlberg ", "Luke Pfister ", - "Thilo Balke ", ] ) @@ -60,53 +50,28 @@ def device_info(devid: int = 0) -> str: # pragma: no cover return info -def ensure_on_device( - *arrays: Union[np.ndarray, JaxArray, scico.blockarray.BlockArray] -) -> Union[JaxArray, scico.blockarray.BlockArray]: - """Cast ndarrays to DeviceArrays. - - Cast ndarrays to DeviceArrays and leaves DeviceArrays, BlockArrays, - and ShardedDeviceArray as is. This is intended to be used when - initializing optimizers and functionals so that all arrays are either - DeviceArrays, BlockArrays, or ShardedDeviceArray. - - Args: - *arrays: One or more input arrays (ndarray, DeviceArray, - BlockArray, or ShardedDeviceArray). - - Returns: - arrays : Modified array or arrays. Modified are only those that - were necessary. +def check_for_tracer(func: Callable) -> Callable: + """Check if positional arguments to ``func`` are jax tracers. - Raises: - TypeError: If the arrays contain something that is neither - ndarray, DeviceArray, BlockArray, nor ShardedDeviceArray. + This is intended to be used as a decorator for functions that call + external code from within SCICO. At present, external functions + cannot be jit-ed or vmap/pmaped. This decorator checks for signs of + jit/vmap/pmap and raises an appropriate exception. """ - arrays = list(arrays) - for i, array in enumerate(arrays): - - if isinstance(array, np.ndarray): - warnings.warn( - f"Argument {i+1} of {len(arrays)} is an np.ndarray. " - f"Will cast it to DeviceArray. " - f"To suppress this warning cast all np.ndarrays to DeviceArray first.", - stacklevel=2, + @wraps(func) + def wrapper(*args, **kwargs): + if any([isinstance(x, DynamicJaxprTracer) for x in args]): + raise TypeError( + f"DynamicJaxprTracer found in {func.__name__}; did you jit this function?" ) - - arrays[i] = jax.device_put(arrays[i]) - elif not isinstance( - array, - (DeviceArray, scico.blockarray.BlockArray, ShardedDeviceArray), - ): + if any([isinstance(x, BatchTracer) for x in args]): raise TypeError( - "Each item of `arrays` must be ndarray, DeviceArray, BlockArray, or " - f"ShardedDeviceArray; Argument {i+1} of {len(arrays)} is {type(arrays[i])}." + f"BatchTracer found in {func.__name__}; did you vmap/pmap this function?" ) + return func(*args, **kwargs) - if len(arrays) == 1: - return arrays[0] - return arrays + return wrapper def url_get(url: str, maxtry: int = 3, timeout: int = 10) -> io.BytesIO: # pragma: no cover @@ -139,149 +104,7 @@ def url_get(url: str, maxtry: int = 3, timeout: int = 10) -> io.BytesIO: # prag return io.BytesIO(cntnt) -def parse_axes( - axes: Axes, shape: Optional[Shape] = None, default: Optional[List[int]] = None -) -> List[int]: - """Normalize `axes` to a list and optionally ensure correctness. - - Normalize `axes` to a list and (optionally) ensure that entries refer - to axes that exist in `shape`. - - Args: - axes: user specification of one or more axes: int, list, tuple, - or ``None``. - shape: the shape of the array of which axes are being specified. - If not ``None``, `axes` is checked to make sure its entries - refer to axes that exist in `shape`. - default: default value to return if `axes` is ``None``. By - default, `list(range(len(shape)))`. - - Returns: - List of axes (never an int, never ``None``). - """ - - if axes is None: - if default is None: - if shape is None: - raise ValueError("`axes` cannot be `None` without a default or shape specified.") - axes = list(range(len(shape))) - else: - axes = default - elif isinstance(axes, (list, tuple)): - axes = axes - elif isinstance(axes, int): - axes = (axes,) - else: - raise ValueError(f"Could not understand axes {axes} as a list of axes") - if shape is not None and max(axes) >= len(shape): - raise ValueError( - f"Invalid axes {axes} specified; each axis must be less than `len(shape)`={len(shape)}." - ) - if len(set(axes)) != len(axes): - raise ValueError(f"Duplicate value in axes {axes}; each axis must be unique.") - return axes - - -def slice_length(length: int, slc: AxisIndex) -> int: - """Determine the length of an array axis after slicing. - - Args: - length: Length of axis being sliced. - slc: Slice/indexing to be applied to axis. - - Returns: - Length of sliced axis. - - Raises: - ValueError: If `slc` is an integer index that is out bounds for - the axis length. - """ - if slc is Ellipsis: - return length - if isinstance(slc, int): - if slc < -length or slc > length - 1: - raise ValueError(f"Index {slc} out of bounds for axis of length {length}.") - return 1 - start, stop, stride = slc.indices(length) - if start > stop: - start = stop - return (stop - start + stride - 1) // stride - - -def indexed_shape(shape: Shape, idx: ArrayIndex) -> Tuple[int]: - """Determine the shape of an array after indexing/slicing. - - Args: - shape: Shape of array. - idx: Indexing expression. - - Returns: - Shape of indexed/sliced array. - - Raises: - ValueError: If `idx` is longer than `shape`. - """ - if not isinstance(idx, tuple): - idx = (idx,) - if len(idx) > len(shape): - raise ValueError(f"Slice {idx} has more dimensions than shape {shape}.") - idx_shape = list(shape) - offset = 0 - for axis, ax_idx in enumerate(idx): - print(axis, offset) - if ax_idx is Ellipsis: - offset = len(shape) - len(idx) - continue - idx_shape[axis + offset] = slice_length(shape[axis + offset], ax_idx) - return tuple(idx_shape) - - -def is_nested(x: Any) -> bool: - """Check if input is a list/tuple containing at least one list/tuple. - - Args: - x: Object to be tested. - - Returns: - True if ``x`` is a list/tuple of list/tuples, False otherwise. - - - Example: - >>> is_nested([1, 2, 3]) - False - >>> is_nested([(1,2), (3,)]) - True - >>> is_nested([ [1, 2], 3]) - True - - """ - if isinstance(x, (list, tuple)): - return any([isinstance(_, (list, tuple)) for _ in x]) - return False - - -def check_for_tracer(func: Callable) -> Callable: - """Check if positional arguments to ``func`` are jax tracers. - - This is intended to be used as a decorator for functions that call - external code from within SCICO. At present, external functions - cannot be jit-ed or vmap/pmaped. This decorator checks for signs of - jit/vmap/pmap and raises an appropriate exception. - """ - - @wraps(func) - def wrapper(*args, **kwargs): - if any([isinstance(x, DynamicJaxprTracer) for x in args]): - raise TypeError( - f"DynamicJaxprTracer found in {func.__name__}; did you jit this function?" - ) - if any([isinstance(x, BatchTracer) for x in args]): - raise TypeError( - f"BatchTracer found in {func.__name__}; did you vmap/pmap this function?" - ) - return func(*args, **kwargs) - - return wrapper +# Timer classes are copied from https://github.com/bwohlberg/sporco class Timer: