Resolve issue #395 (#396)
* Docstring fix

* Add warning on appropriate problem scale

* Resolve #395

* Clean up

* Remove incorrect claim regarding in-place operations

* Improve notes on gradient of norm function
bwohlberg authored Apr 11, 2023
1 parent 013d696 commit f6a4fc6
Showing 3 changed files with 25 additions and 18 deletions.
28 changes: 17 additions & 11 deletions docs/source/notes.rst
@@ -193,15 +193,28 @@ the gradient of ``f`` at 0, :func:`scico.grad` returns ``nan``:
>>> scico.grad(f)(snp.zeros(2, dtype=snp.float32)) # doctest: +SKIP
Array([nan, nan], dtype=float32)

This can be fixed by defining the squared :math:`\ell_2` norm directly as
``g = lambda x: snp.sum(x**2)``. The gradient will work as expected:
This can be fixed (assuming real-valued arrays only) by defining the
squared :math:`\ell_2` norm directly as ``g = lambda x: snp.sum(x**2)``.
The gradient will work as expected:

::

>>> g = lambda x: snp.sum(x**2)
>>> scico.grad(g)(snp.zeros(2, dtype=snp.float32)) #doctest: +SKIP
Array([0., 0.], dtype=float32)

If complex-valued arrays also need to be supported, a minor modification is
necessary:

::

>>> g = lambda x: snp.sum(snp.abs(x)**2)
>>> scico.grad(g)(snp.zeros(2, dtype=snp.float32)) #doctest: +SKIP
Array([0., 0.], dtype=float32)
>>> scico.grad(g)(snp.zeros(2, dtype=snp.complex64)) #doctest: +SKIP
Array([0.-0.j, 0.-0.j], dtype=complex64)


An alternative is to define a `custom derivative rule
<https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html#enforcing-a-differentiation-convention>`_
to enforce a particular derivative convention at a point.
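A minimal sketch of such a rule (assuming real-valued inputs; the names
``l2norm`` and ``l2norm_jvp`` are illustrative, not part of SCICO):

::

    import jax
    import jax.numpy as jnp

    @jax.custom_jvp
    def l2norm(x):
        return jnp.linalg.norm(x)

    @l2norm.defjvp
    def l2norm_jvp(primals, tangents):
        (x,), (x_dot,) = primals, tangents
        y = jnp.linalg.norm(x)
        # Convention: define the gradient at x = 0 to be 0 rather than nan.
        safe_y = jnp.where(y == 0, 1.0, y)
        g = jnp.where(y == 0, jnp.zeros_like(x), x / safe_y)
        return y, jnp.vdot(g, x_dot)

    # jax.grad(l2norm)(jnp.zeros(2, dtype=jnp.float32)) now evaluates to
    # Array([0., 0.], dtype=float32) rather than Array([nan, nan], ...).
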
@@ -259,13 +272,6 @@ DeviceArrays are Immutable

Unlike standard NumPy arrays, JAX arrays are immutable: once they have
been created, they cannot be changed. This prohibits in-place updating
of JAX arrays.

JAX provides special syntax for updating individual array elements
through the `indexed update operators
of JAX arrays. JAX provides special syntax for updating individual
array elements through the `indexed update operators
<https://jax.readthedocs.io/en/latest/jax.ops.html#syntactic-sugar-for-indexed-update-operators>`_.

In-place operations such as `x += y` must be replaced with the
out-of-place version `x = x + y`. Note that these operations will be
optimized if they are placed inside of a `jitted function
<https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#to-jit-or-not-to-jit>`_.
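A minimal example of the indexed update syntax (the values here are
illustrative only):

::

    import jax.numpy as jnp

    x = jnp.zeros(3)
    # x[0] = 1.0 would raise an error, since JAX arrays are immutable.
    x = x.at[0].set(1.0)   # functional update: returns a new array
    x = x.at[1].add(2.0)   # functional counterpart of x[1] += 2.0
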
13 changes: 7 additions & 6 deletions scico/optimize/_admmaux.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2020-2022 by SCICO Developers
# Copyright (C) 2020-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
@@ -22,7 +22,6 @@
from scico.linop import CircularConvolve, Identity, LinearOperator
from scico.loss import SquaredL2Loss
from scico.numpy import BlockArray
from scico.numpy.linalg import norm
from scico.numpy.util import ensure_on_device, is_real_dtype
from scico.solver import cg as scico_cg
from scico.solver import minimize
@@ -63,6 +62,8 @@ def internal_init(self, admm: soa.ADMM):
class GenericSubproblemSolver(SubproblemSolver):
"""Solver for generic problem without special structure.
Note that this solver is only suitable for small-scale problems.
Attributes:
admm (:class:`.ADMM`): ADMM solver object to which the solver is
attached.
@@ -98,9 +99,9 @@ def obj(x):
for rhoi, Ci, zi, ui in zip(
self.admm.rho_list, self.admm.C_list, self.admm.z_list, self.admm.u_list
):
out = out + 0.5 * rhoi * norm(zi - ui - Ci(x)) ** 2
out += 0.5 * rhoi * snp.sum(snp.abs(zi - ui - Ci(x)) ** 2)
if self.admm.f is not None:
out = out + self.admm.f(x)
out += self.admm.f(x)
return out

res = minimize(obj, x0, **self.minimize_kwargs)
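For reference, a standalone sketch of the objective assembled above (the
helper name ``admm_x_step_objective`` is hypothetical, not part of the
module; it assumes ``rho_list``, ``C_list``, ``z_list``, and ``u_list`` are
matching sequences and ``f`` is a functional or ``None``):

    import scico.numpy as snp

    def admm_x_step_objective(x, f, rho_list, C_list, z_list, u_list):
        # f(x) + sum_i (rho_i / 2) || z_i - u_i - C_i(x) ||_2^2
        out = 0.0
        for rhoi, Ci, zi, ui in zip(rho_list, C_list, z_list, u_list):
            # snp.abs(.)**2 rather than norm(.)**2 keeps the term
            # differentiable at zero and correct for complex-valued arrays.
            out += 0.5 * rhoi * snp.sum(snp.abs(zi - ui - Ci(x)) ** 2)
        if f is not None:
            out += f(x)
        return out
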
@@ -211,7 +212,7 @@ def internal_init(self, admm: soa.ADMM):
)
if admm.f is not None:
# hessian = A.T @ W @ A; W may be identity
lhs_op = lhs_op + admm.f.hessian
lhs_op += admm.f.hessian

lhs_op.jit()
self.lhs_op = lhs_op
@@ -240,7 +241,7 @@ def compute_rhs(self) -> Union[JaxArray, BlockArray]:
for rhoi, Ci, zi, ui in zip(
self.admm.rho_list, self.admm.C_list, self.admm.z_list, self.admm.u_list
):
rhs = rhs + rhoi * Ci.adj(zi - ui)
rhs += rhoi * Ci.adj(zi - ui)
return rhs
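For orientation, a rough sketch of the system the two loops above assemble
(``assemble_rhs`` and ``ATWy`` are hypothetical names; it assumes ``f`` is a
squared :math:`\ell_2` loss and each ``Ci`` is a :class:`.LinearOperator`
with an ``adj`` method, as in the code above):

    # With f(x) = (1/2) ||A x - y||_W^2, the x-update solves, roughly,
    #     (A^T W A + sum_i rho_i C_i^T C_i) x
    #         = A^T W y + sum_i rho_i C_i^T (z_i - u_i),
    # where the left-hand operator corresponds to lhs_op (f.hessian plus the
    # C_i terms) and the right-hand side to the value returned by compute_rhs.
    def assemble_rhs(ATWy, rho_list, C_list, z_list, u_list):
        rhs = ATWy  # data-fidelity term A^T W y (zero if f is None)
        for rhoi, Ci, zi, ui in zip(rho_list, C_list, z_list, u_list):
            rhs += rhoi * Ci.adj(zi - ui)  # rho_i C_i^T (z_i - u_i)
        return rhs
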

def solve(self, x0: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]:
2 changes: 1 addition & 1 deletion scico/optimize/_primaldual.py
@@ -44,7 +44,7 @@ class PDHG(Optimizer):
where :math:`f` and :math:`g` are instances of :class:`.Functional`,
(in most cases :math:`f` will, more specifically, be an instance
of :class:`.Loss`), and :math:`C` is an instance of
:class:`.LinearOperator`.
:class:`.Operator` or :class:`.LinearOperator`.
When `C` is a :class:`.LinearOperator`, the algorithm iterations are
