Skip to content

Commit

Permalink
Merge branch 'main' into mike/svd
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg authored Feb 8, 2022
2 parents 3b6dc3a + 2b7a693 commit 2fcfd85
Show file tree
Hide file tree
Showing 12 changed files with 40 additions and 39 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Install scico requirements and run pytest

name: test
name: pytest

# Controls when the workflow will run
on:
Expand Down
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,5 @@ SCICO is distributed as open-source software under a BSD 3-Clause License (see t

LANL open source approval reference C20091.

(c) 2020-2021. Triad National Security, LLC. All rights reserved.
(c) 2020-2022. Triad National Security, LLC. All rights reserved.
This program was produced under U.S. Government contract 89233218CNA000001 for Los Alamos National Laboratory (LANL), which is operated by Triad National Security, LLC for the U.S. Department of Energy/National Nuclear Security Administration. All rights in the program are reserved by Triad National Security, LLC, and the U.S. Department of Energy/National Nuclear Security Administration. The Government has granted for itself and others acting on its behalf a nonexclusive, paid-up, irrevocable worldwide license in this material to reproduce, prepare derivative works, distribute copies to the public, perform publicly and display publicly, and to permit others to do so.
2 changes: 1 addition & 1 deletion docs/source/overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ under a BSD 3-Clause License
for details).
LANL open source approval reference C20091.

© 2020-2021. Triad National Security, LLC. All rights reserved.
© 2020-2022. Triad National Security, LLC. All rights reserved.
This program was produced under
U.S. Government contract 89233218CNA000001
for Los Alamos National Laboratory (LANL),
Expand Down
5 changes: 2 additions & 3 deletions scico/_generic_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def wrapper(a, b):


class Operator:
"""Generic Operator class"""
"""Generic Operator class."""

def __repr__(self):
return f"""{type(self)}
Expand All @@ -80,8 +80,7 @@ def __init__(
output_dtype: Optional[DType] = None,
jit: bool = False,
):
r"""Operator init method.
r"""
Args:
input_shape: Shape of input array.
output_shape: Shape of output array.
Expand Down
2 changes: 1 addition & 1 deletion scico/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def slice_length(length: int, idx: AxisIndex) -> Optional[int]:
return (stop - start + stride - 1) // stride


def indexed_shape(shape: Shape, idx: ArrayIndex) -> Tuple[int]:
def indexed_shape(shape: Shape, idx: ArrayIndex) -> Tuple[int, ...]:
"""Determine the shape of an array after indexing/slicing.
Args:
Expand Down
4 changes: 2 additions & 2 deletions scico/blockarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,7 @@ def _decompose_index(idx: Union[int, Tuple(AxisIndex)]) -> Tuple:
return idxblk, idxarr


def indexed_shape(shape: Shape, idx: Union[int, Tuple(AxisIndex)]) -> Tuple[int]:
def indexed_shape(shape: Shape, idx: Union[int, Tuple[AxisIndex, ...]]) -> Tuple[int, ...]:
"""Determine the shape of the result of indexing a BlockArray.
Args:
Expand Down Expand Up @@ -866,7 +866,7 @@ def __init__(self, aval: _AbstractBlockArray, data: JaxArray):
def __repr__(self):
return "scico.blockarray.BlockArray: \n" + self._data.__repr__()

def __getitem__(self, idx: Union[int, Tuple(AxisIndex)]) -> JaxArray:
def __getitem__(self, idx: Union[int, Tuple[AxisIndex, ...]]) -> JaxArray:
idxblk, idxarr = _decompose_index(idx)
if idxblk < 0:
idxblk = self.num_blocks + idxblk
Expand Down
2 changes: 1 addition & 1 deletion scico/functional/_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def prox(self, v: BlockArray, lam: float = 1.0, **kwargs) -> BlockArray:
if len(v.shape) == len(self.functional_list):
return BlockArray.array([fi.prox(vi, lam) for fi, vi in zip(self.functional_list, v)])
raise ValueError(
f"Number of blocks in x, {len(x.shape)}, and length of functional_list, "
f"Number of blocks in v, {len(v.shape)}, and length of functional_list, "
f"{len(self.functional_list)}, do not match"
)

Expand Down
4 changes: 2 additions & 2 deletions scico/linop/_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(
operations, this must be `complex64` for proper adjoint
and gradient calculation.
axes: Axis or axes over which to apply finite difference
operator. If not specified, or `None`, differences are
operator. If not specified, or ``None``, differences are
evaluated along all axes.
append: Value to append to the input along each axis before
taking differences. Zero is a typical choice. If not
Expand Down Expand Up @@ -116,7 +116,7 @@ def __init__(
taking differences. Defaults to 0.
circular: If ``True``, perform circular differences, i.e.,
include x[-1] - x[0]. If ``True``, `append` must be
`None`.
``None``.
jit: If ``True``, jit the evaluation, adjoint, and gram
functions of the LinearOperator.
"""
Expand Down
6 changes: 3 additions & 3 deletions scico/linop/radon_svmbir.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(
num_channels: Number of pixels in the sinogram.
center_offset: Position of the detector center relative to the center-of-rotation,
in units of pixels
is_masked: If True, the valid region of the image is
is_masked: If ``True``, the valid region of the image is
determined by a mask defined as the circle inscribed
within the image boundary. Otherwise, the whole image
array is taken into account by projections.
Expand Down Expand Up @@ -216,7 +216,7 @@ def __init__(
A: Forward operator.
scale: Scaling parameter.
W: Weighting diagonal operator. Must be non-negative.
If None, defaults to :class:`.Identity`.
If ``None``, defaults to :class:`.Identity`.
prox_kwargs: Dictionary of arguments passed to the
:meth:`svmbir.recon` prox routine. Defaults to
{"maxiter": 1000, "ctol": 0.001}.
Expand Down Expand Up @@ -338,7 +338,7 @@ def __init__(


def _unsqueeze(x: JaxArray, input_shape: Shape) -> JaxArray:
"""If x is 2D, make it 3D according to SVMBIR's convention."""
"""If x is 2D, make it 3D according to the SVMBIR convention."""
if len(input_shape) == 2:
x = x[snp.newaxis, :, :]
return x
2 changes: 1 addition & 1 deletion scico/optimize/_ladmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def __init__(
"Time": "%8.2e",
}
itstat_attrib = ["itnum", "timer.elapsed()"]
# objective function can be evaluated if all 'g' functions can be evaluated
# objective function can be evaluated if 'g' function can be evaluated
if g.has_eval:
itstat_fields.update({"Objective": "%9.3e"})
itstat_attrib.append("objective()")
Expand Down
2 changes: 1 addition & 1 deletion scico/optimize/_primaldual.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __init__(
"Time": "%8.2e",
}
itstat_attrib = ["itnum", "timer.elapsed()"]
# objective function can be evaluated if all 'g' functions can be evaluated
# objective function can be evaluated if 'g' function can be evaluated
if g.has_eval:
itstat_fields.update({"Objective": "%9.3e"})
itstat_attrib.append("objective()")
Expand Down
46 changes: 24 additions & 22 deletions scico/optimize/pgm.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,39 +442,41 @@ def x_step(v, L):

self.x_step = jax.jit(x_step)

# iteration number and time fields
itstat_fields = {
"Iter": "%d",
"Time": "%8.2e",
}
itstat_attrib = ["itnum", "timer.elapsed()"]
# objective function can be evaluated if 'g' function can be evaluated
if g.has_eval:
itstat_fields = {
"Iter": "%d",
"Time": "%8.2e",
"Objective": "%9.3e",
"L": "%9.3e",
"Residual": "%9.3e",
}
itstat_func = lambda pgm: (
pgm.itnum,
pgm.timer.elapsed(),
pgm.objective(self.x),
pgm.L,
pgm.norm_residual(),
)
else:
itstat_fields = {"Iter": "%d", "Time": "%8.2e", "Residual": "%9.3e"}
itstat_func = lambda pgm: (pgm.itnum, pgm.timer.elapsed(), pgm.norm_residual())

default_itstat_options = {
itstat_fields.update({"Objective": "%9.3e"})
itstat_attrib.append("objective()")
# step size and residual fields
itstat_fields.update({"L": "%9.3e", "Residual": "%9.3e"})
itstat_attrib.extend(["L", "norm_residual()"])

# dynamically create itstat_func; see https://stackoverflow.com/questions/24733831
itstat_return = "return(" + ", ".join(["obj." + attr for attr in itstat_attrib]) + ")"
scope: dict[str, Callable] = {}
exec("def itstat_func(obj): " + itstat_return, scope)

default_itstat_options: dict[str, Union[dict, Callable, bool]] = {
"fields": itstat_fields,
"itstat_func": itstat_func,
"itstat_func": scope["itstat_func"],
"display": False,
}
if itstat_options:
default_itstat_options.update(itstat_options)
self.itstat_insert_func = default_itstat_options.pop("itstat_func", None)
self.itstat_insert_func: Callable = default_itstat_options.pop("itstat_func") # type: ignore
self.itstat_object = IterationStats(**default_itstat_options)

self.x: Union[JaxArray, BlockArray] = ensure_on_device(x0) # current estimate of solution

def objective(self, x) -> float:
def objective(self, x=None) -> float:
r"""Evaluate the objective function :math:`f(\mb{x}) + g(\mb{x})`."""
if x is None:
x = self.x
return self.f(x) + self.g(x)

def f_quad_approx(self, x, y, L) -> float:
Expand Down

0 comments on commit 2fcfd85

Please sign in to comment.