Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move some functions to different modules #175

Merged
merged 8 commits into from
Jan 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ Version 0.0.2 (unreleased)
• Move optimization algorithms into ``optimize`` subpackage.
• Additional iteration stats columns for iterative ADMM subproblem solvers.
• Renamed "Primal Rsdl" to "Prml Rsdl" in displayed iteration stats.
• Move some functions from ``util`` and ``math`` modules to new ``array``
module.
• Bump pinned `jaxlib` and `jax` versions to 0.1.70 and 0.2.19 respectively.


Expand Down
2 changes: 1 addition & 1 deletion examples/scriptcheck.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ fi

# Set environment variables and paths. This script is assumed to be run
# from its root directory.
export PYTHONPATH=$((cd .. && pwd))
export PYTHONPATH=$(cd .. && pwd)
export PYTHONIOENCODING=utf-8
d='/tmp/scriptcheck_'$$
mkdir -p $d
Expand Down
3 changes: 2 additions & 1 deletion examples/scripts/denoise_tv_iso_pgm.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,11 @@
import scico.numpy as snp
import scico.random
from scico import functional, linop, loss, operator, plot
from scico.array import ensure_on_device
from scico.blockarray import BlockArray
from scico.optimize.pgm import AcceleratedPGM, RobustLineSearchStepSize
from scico.typing import JaxArray
from scico.util import device_info, ensure_on_device
from scico.util import device_info

"""
Create a ground truth image.
Expand Down
3 changes: 1 addition & 2 deletions scico/_generic_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,9 @@

import scico.numpy as snp
from scico._autograd import linear_adjoint
from scico.array import is_complex_dtype, is_nested
from scico.blockarray import BlockArray, block_sizes
from scico.math import is_complex_dtype
from scico.typing import BlockShape, DType, JaxArray, Shape
from scico.util import is_nested


def _wrap_mul_div_scalar(func):
Expand Down
281 changes: 281 additions & 0 deletions scico/array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,281 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2020-2022 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Utility functions for arrays, array shapes, array indexing, etc."""


from __future__ import annotations

import warnings
from typing import Any, List, Optional, Tuple, Union

import numpy as np

import jax
from jax.interpreters.pxla import ShardedDeviceArray
from jax.interpreters.xla import DeviceArray

import scico.blockarray
import scico.numpy as snp
from scico.typing import ArrayIndex, Axes, AxisIndex, JaxArray, Shape

__author__ = """\n""".join(
[
"Brendt Wohlberg <[email protected]>",
"Luke Pfister <[email protected]>",
"Thilo Balke <[email protected]>",
bwohlberg marked this conversation as resolved.
Show resolved Hide resolved
"Michael McCann <[email protected]>",
]
)


def ensure_on_device(
*arrays: Union[np.ndarray, JaxArray, scico.blockarray.BlockArray]
) -> Union[JaxArray, scico.blockarray.BlockArray]:
"""Cast ndarrays to DeviceArrays.

Cast ndarrays to DeviceArrays and leaves DeviceArrays, BlockArrays,
and ShardedDeviceArray as is. This is intended to be used when
initializing optimizers and functionals so that all arrays are either
DeviceArrays, BlockArrays, or ShardedDeviceArray.

Args:
*arrays: One or more input arrays (ndarray, DeviceArray,
BlockArray, or ShardedDeviceArray).

Returns:
Modified array or arrays. Modified are only those that were
necessary.

Raises:
TypeError: If the arrays contain something that is neither
ndarray, DeviceArray, BlockArray, nor ShardedDeviceArray.
"""
arrays = list(arrays)

for i, array in enumerate(arrays):

if isinstance(array, np.ndarray):
warnings.warn(
f"Argument {i+1} of {len(arrays)} is an np.ndarray. "
f"Will cast it to DeviceArray. "
f"To suppress this warning cast all np.ndarrays to DeviceArray first.",
stacklevel=2,
)

arrays[i] = jax.device_put(arrays[i])
elif not isinstance(
array,
(DeviceArray, scico.blockarray.BlockArray, ShardedDeviceArray),
):
raise TypeError(
"Each item of `arrays` must be ndarray, DeviceArray, BlockArray, or "
f"ShardedDeviceArray; Argument {i+1} of {len(arrays)} is {type(arrays[i])}."
)

if len(arrays) == 1:
return arrays[0]
return arrays


def no_nan_divide(
x: Union[BlockArray, JaxArray], y: Union[BlockArray, JaxArray]
) -> Union[BlockArray, JaxArray]:
"""Return `x/y`, with 0 instead of NaN where `y` is 0.

Args:
x: Numerator.
y: Denominator.

Returns:
`x / y` with 0 wherever `y == 0`.
"""

return snp.where(y != 0, snp.divide(x, snp.where(y != 0, y, 1)), 0)


def parse_axes(
axes: Axes, shape: Optional[Shape] = None, default: Optional[List[int]] = None
) -> List[int]:
"""Normalize `axes` to a list and optionally ensure correctness.

Normalize `axes` to a list and (optionally) ensure that entries refer
to axes that exist in `shape`.

Args:
axes: User specification of one or more axes: int, list, tuple,
or ``None``.
shape: The shape of the array of which axes are being specified.
If not ``None``, `axes` is checked to make sure its entries
refer to axes that exist in `shape`.
default: Default value to return if `axes` is ``None``. By
default, `list(range(len(shape)))`.

Returns:
List of axes (never an int, never ``None``).
"""

if axes is None:
if default is None:
if shape is None:
raise ValueError("`axes` cannot be `None` without a default or shape specified.")
axes = list(range(len(shape)))
else:
axes = default
elif isinstance(axes, (list, tuple)):
axes = axes
elif isinstance(axes, int):
axes = (axes,)
else:
raise ValueError(f"Could not understand axes {axes} as a list of axes")
if shape is not None and max(axes) >= len(shape):
raise ValueError(
f"Invalid axes {axes} specified; each axis must be less than `len(shape)`={len(shape)}."
)
if len(set(axes)) != len(axes):
raise ValueError(f"Duplicate value in axes {axes}; each axis must be unique.")
return axes


def slice_length(length: int, slc: AxisIndex) -> int:
"""Determine the length of an array axis after slicing.

Args:
length: Length of axis being sliced.
slc: Slice/indexing to be applied to axis.

Returns:
Length of sliced axis.

Raises:
ValueError: If `slc` is an integer index that is out bounds for
the axis length.
"""
if slc is Ellipsis:
return length
if isinstance(slc, int):
if slc < -length or slc > length - 1:
raise ValueError(f"Index {slc} out of bounds for axis of length {length}.")
return 1
start, stop, stride = slc.indices(length)
if start > stop:
start = stop
return (stop - start + stride - 1) // stride


def indexed_shape(shape: Shape, idx: ArrayIndex) -> Tuple[int]:
"""Determine the shape of an array after indexing/slicing.

Args:
shape: Shape of array.
idx: Indexing expression.

Returns:
Shape of indexed/sliced array.

Raises:
ValueError: If `idx` is longer than `shape`.
"""
if not isinstance(idx, tuple):
idx = (idx,)
if len(idx) > len(shape):
raise ValueError(f"Slice {idx} has more dimensions than shape {shape}.")
idx_shape = list(shape)
offset = 0
for axis, ax_idx in enumerate(idx):
print(axis, offset)
if ax_idx is Ellipsis:
offset = len(shape) - len(idx)
continue
idx_shape[axis + offset] = slice_length(shape[axis + offset], ax_idx)
return tuple(idx_shape)


def is_nested(x: Any) -> bool:
"""Check if input is a list/tuple containing at least one list/tuple.

Args:
x: Object to be tested.

Returns:
True if ``x`` is a list/tuple of list/tuples, False otherwise.


Example:
>>> is_nested([1, 2, 3])
False
>>> is_nested([(1,2), (3,)])
True
>>> is_nested([[1, 2], 3])
True

"""
if isinstance(x, (list, tuple)):
return any([isinstance(_, (list, tuple)) for _ in x])
return False


def is_real_dtype(dtype: DType) -> bool:
"""Determine whether a dtype is real.

Args:
dtype: A numpy or scico.numpy dtype (e.g. np.float32,
snp.complex64).

Returns:
False if the dtype is complex, otherwise True.
"""
return snp.dtype(dtype).kind != "c"


def is_complex_dtype(dtype: DType) -> bool:
"""Determine whether a dtype is complex.

Args:
dtype: A numpy or scico.numpy dtype (e.g. np.float32,
snp.complex64).

Returns:
True if the dtype is complex, otherwise False.
"""
return snp.dtype(dtype).kind == "c"


def real_dtype(dtype: DType) -> DType:
"""Construct the corresponding real dtype for a given complex dtype.

Construct the corresponding real dtype for a given complex dtype,
e.g. the real dtype corresponding to `np.complex64` is
`np.float32`.

Args:
dtype: A complex numpy or scico.numpy dtype (e.g. np.complex64,
np.complex128).

Returns:
The real dtype corresponding to the input dtype
"""

return snp.zeros(1, dtype).real.dtype


def complex_dtype(dtype: DType) -> DType:
"""Construct the corresponding complex dtype for a given real dtype.

Construct the corresponding complex dtype for a given real dtype,
e.g. the complex dtype corresponding to `np.float32` is
`np.complex64`.

Args:
dtype: A real numpy or scico.numpy dtype (e.g. np.float32,
np.float64).

Returns:
The complex dtype corresponding to the input dtype.
"""

return (snp.zeros(1, dtype) + 1j).dtype
12 changes: 6 additions & 6 deletions scico/blockarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@

from jaxlib.xla_extension import Buffer

from scico import util
from scico import array
from scico.typing import Axes, AxisIndex, BlockShape, DType, JaxArray, Shape

_arraylikes = (Buffer, DeviceArray, np.ndarray)
Expand Down Expand Up @@ -521,7 +521,7 @@ def reshape(
always returned.
"""

if util.is_nested(newshape):
if array.is_nested(newshape):
# x is a blockarray
return BlockArray.array_from_flattened(a, newshape)

Expand Down Expand Up @@ -566,13 +566,13 @@ def block_sizes(shape: Union[Shape, BlockShape]) -> Axes:
)

out = []
if util.is_nested(shape):
if array.is_nested(shape):
# shape is nested -> at least one element came from a blockarray
for y in shape:
if util.is_nested(y):
if array.is_nested(y):
# recursively calculate the block size until we arrive at
# a tuple (shape of a non-block array)
while util.is_nested(y):
while array.is_nested(y):
y = block_sizes(y)
out.append(np.sum(y)) # adjacent block sizes are added together
else:
Expand Down Expand Up @@ -630,7 +630,7 @@ def indexed_shape(shape: Shape, idx: Union[int, Tuple(AxisIndex)]) -> Tuple[int]
idxblk = len(shape) + idxblk
if idxarr is None:
return shape[idxblk]
return util.indexed_shape(shape[idxblk], idxarr)
return array.indexed_shape(shape[idxblk], idxarr)


def _flatten_blockarrays(inp, *args, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions scico/functional/_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from jax import jit

from scico import numpy as snp
from scico.array import no_nan_divide
from scico.blockarray import BlockArray
from scico.math import safe_divide
from scico.numpy import count_nonzero
from scico.numpy.linalg import norm
from scico.typing import JaxArray
Expand Down Expand Up @@ -247,7 +247,7 @@ def prox(
"""

length = norm(x, axis=self.l2_axis, keepdims=True)
direction = safe_divide(x, length)
direction = no_nan_divide(x, length)

new_length = length - lam
# set negative values to zero without `if`
Expand Down
Loading