Resolve issue #395 (#396)

Merged
merged 6 commits on Apr 11, 2023
28 changes: 17 additions & 11 deletions docs/source/notes.rst
@@ -193,15 +193,28 @@ the gradient of ``f`` at 0, :func:`scico.grad` returns ``nan``:
   >>> scico.grad(f)(snp.zeros(2, dtype=snp.float32)) # doctest: +SKIP
   Array([nan, nan], dtype=float32)

-This can be fixed by defining the squared :math:`\ell_2` norm directly as
-``g = lambda x: snp.sum(x**2)``. The gradient will work as expected:
+This can be fixed (assuming real-valued arrays only) by defining the
+squared :math:`\ell_2` norm directly as ``g = lambda x: snp.sum(x**2)``.
+The gradient will work as expected:

::

   >>> g = lambda x: snp.sum(x**2)
   >>> scico.grad(g)(snp.zeros(2, dtype=snp.float32)) #doctest: +SKIP
   Array([0., 0.], dtype=float32)

+If complex-valued arrays also need to be supported, a minor modification is
+necessary:
+
+::
+
+   >>> g = lambda x: snp.sum(snp.abs(x)**2)
+   >>> scico.grad(g)(snp.zeros(2, dtype=snp.float32)) #doctest: +SKIP
+   Array([0., 0.], dtype=float32)
+   >>> scico.grad(g)(snp.zeros(2, dtype=snp.complex64)) #doctest: +SKIP
+   Array([0.-0.j, 0.-0.j], dtype=complex64)


An alternative is to define a `custom derivative rule
<https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html#enforcing-a-differentiation-convention>`_
to enforce a particular derivative convention at a point.
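
For reference, such a custom derivative rule can be implemented with
``jax.custom_jvp``. The following is an illustrative sketch, not part of
this diff, assuming only ``jax`` and real-valued input, with the
convention that the gradient of the :math:`\ell_2` norm at 0 is 0::

   import jax
   import jax.numpy as jnp

   # Illustrative sketch (not SCICO code): an l2 norm with a custom JVP
   # enforcing a derivative convention at x = 0.
   @jax.custom_jvp
   def l2norm(x):
       return jnp.sqrt(jnp.sum(x**2))

   @l2norm.defjvp
   def l2norm_jvp(primals, tangents):
       (x,), (v,) = primals, tangents
       y = l2norm(x)
       # Guard the division so the tangent is well defined at x = 0.
       y_safe = jnp.where(y == 0.0, 1.0, y)
       # Convention: the derivative at 0 is 0 rather than nan.
       return y, jnp.where(y == 0.0, 0.0, jnp.vdot(x, v) / y_safe)

   jax.grad(l2norm)(jnp.zeros(2))  # Array([0., 0.], dtype=float32)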
@@ -259,13 +272,6 @@ DeviceArrays are Immutable

Unlike standard NumPy arrays, JAX arrays are immutable: once they have
been created, they cannot be changed. This prohibits in-place updating
-of JAX arrays. JAX provides special syntax for updating individual
-array elements through the `indexed update operators
+of JAX arrays.
+
+JAX provides special syntax for updating individual array elements
+through the `indexed update operators
<https://jax.readthedocs.io/en/latest/jax.ops.html#syntactic-sugar-for-indexed-update-operators>`_.

In-place operations such as `x += y` must be replaced with the
out-of-place version `x = x + y`. Note that these operations will be
optimized if they are placed inside of a `jitted function
<https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#to-jit-or-not-to-jit>`_.
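
For example (an illustrative sketch assuming a recent JAX release, in
which the indexed update operators are exposed via the ``.at`` property)::

   import jax
   import jax.numpy as jnp

   x = jnp.zeros(3)
   # x[0] = 1.0 would raise an error since JAX arrays are immutable;
   # the indexed update operators instead return a modified copy.
   x = x.at[0].set(1.0)
   # In-place arithmetic is likewise written out-of-place.
   x = x + jnp.ones(3)
   # Inside a jitted function, XLA can often elide the copies implied
   # by such updates.
   f = jax.jit(lambda a: a.at[1].add(2.0))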
13 changes: 7 additions & 6 deletions scico/optimize/_admmaux.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright (C) 2020-2022 by SCICO Developers
+# Copyright (C) 2020-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
@@ -22,7 +22,6 @@
from scico.linop import CircularConvolve, Identity, LinearOperator
from scico.loss import SquaredL2Loss
from scico.numpy import BlockArray
-from scico.numpy.linalg import norm
from scico.numpy.util import ensure_on_device, is_real_dtype
from scico.solver import cg as scico_cg
from scico.solver import minimize
@@ -63,6 +62,8 @@ def internal_init(self, admm: soa.ADMM):
class GenericSubproblemSolver(SubproblemSolver):
    """Solver for generic problem without special structure.

+    Note that this solver is only suitable for small-scale problems.
+
    Attributes:
        admm (:class:`.ADMM`): ADMM solver object to which the solver is
            attached.
@@ -98,9 +99,9 @@ def obj(x):
            for rhoi, Ci, zi, ui in zip(
                self.admm.rho_list, self.admm.C_list, self.admm.z_list, self.admm.u_list
            ):
-                out = out + 0.5 * rhoi * norm(zi - ui - Ci(x)) ** 2
+                out += 0.5 * rhoi * snp.sum(snp.abs(zi - ui - Ci(x)) ** 2)
            if self.admm.f is not None:
-                out = out + self.admm.f(x)
+                out += self.admm.f(x)
            return out

        res = minimize(obj, x0, **self.minimize_kwargs)
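
The replacement of ``norm(...) ** 2`` with ``snp.sum(snp.abs(...) ** 2)``
mirrors the documentation change above: the ``norm``-based form has a
``nan`` gradient at the origin and does not handle complex arrays
correctly. A minimal sketch of the difference (assuming the usual
``import scico.numpy as snp`` convention)::

   import scico
   import scico.numpy as snp
   from scico.numpy.linalg import norm

   f = lambda x: norm(x) ** 2              # form removed by this commit
   g = lambda x: snp.sum(snp.abs(x) ** 2)  # form introduced by this commit

   x0 = snp.zeros(2, dtype=snp.float32)
   scico.grad(f)(x0)  # Array([nan, nan], dtype=float32)
   scico.grad(g)(x0)  # Array([0., 0.], dtype=float32)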
@@ -211,7 +212,7 @@ def internal_init(self, admm: soa.ADMM):
        )
        if admm.f is not None:
            # hessian = A.T @ W @ A; W may be identity
-            lhs_op = lhs_op + admm.f.hessian
+            lhs_op += admm.f.hessian

        lhs_op.jit()
        self.lhs_op = lhs_op
@@ -240,7 +241,7 @@ def compute_rhs(self) -> Union[JaxArray, BlockArray]:
        for rhoi, Ci, zi, ui in zip(
            self.admm.rho_list, self.admm.C_list, self.admm.z_list, self.admm.u_list
        ):
-            rhs = rhs + rhoi * Ci.adj(zi - ui)
+            rhs += rhoi * Ci.adj(zi - ui)
        return rhs

    def solve(self, x0: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]:
2 changes: 1 addition & 1 deletion scico/optimize/_primaldual.py
@@ -44,7 +44,7 @@ class PDHG(Optimizer):
    where :math:`f` and :math:`g` are instances of :class:`.Functional`,
    (in most cases :math:`f` will, more specifically, be an instance
    of :class:`.Loss`), and :math:`C` is an instance of
-    :class:`.LinearOperator`.
+    :class:`.Operator` or :class:`.LinearOperator`.

    When `C` is a :class:`.LinearOperator`, the algorithm iterations are