diff --git a/docs/source/notes.rst b/docs/source/notes.rst
index 35822ad3d..74f63adf4 100644
--- a/docs/source/notes.rst
+++ b/docs/source/notes.rst
@@ -193,8 +193,9 @@ the gradient of ``f`` at 0, :func:`scico.grad` returns ``nan``:
     >>> scico.grad(f)(snp.zeros(2, dtype=snp.float32))  # doctest: +SKIP
     Array([nan, nan], dtype=float32)
 
-This can be fixed by defining the squared :math:`\ell_2` norm directly as
-``g = lambda x: snp.sum(x**2)``. The gradient will work as expected:
+This can be fixed (assuming real-valued arrays only) by defining the
+squared :math:`\ell_2` norm directly as ``g = lambda x: snp.sum(x**2)``.
+The gradient will work as expected:
 
 ::
 
@@ -202,6 +203,18 @@ This can be fixed by defining the squared :math:`\ell_2` norm directly as
     >>> scico.grad(g)(snp.zeros(2, dtype=snp.float32)) #doctest: +SKIP
     Array([0., 0.], dtype=float32)
 
+If complex-valued arrays also need to be supported, a minor modification is
+necessary:
+
+::
+
+    >>> g = lambda x: snp.sum(snp.abs(x)**2)
+    >>> scico.grad(g)(snp.zeros(2, dtype=snp.float32)) #doctest: +SKIP
+    Array([0., 0.], dtype=float32)
+    >>> scico.grad(g)(snp.zeros(2, dtype=snp.complex64)) #doctest: +SKIP
+    Array([0.-0.j, 0.-0.j], dtype=complex64)
+
+
 An alternative is to define a `custom derivative rule
 `_
 to enforce a particular derivative convention at a point.
@@ -259,13 +272,6 @@ DeviceArrays are Immutable
 
 Unlike standard NumPy arrays, JAX arrays are immutable: once they have
 been created, they cannot be changed. This prohibits in-place updating
-of JAX arrays.
-
-JAX provides special syntax for updating individual array elements
-through the `indexed update operators
+of JAX arrays. JAX provides special syntax for updating individual
+array elements through the `indexed update operators
 `_.
-
-In-place operations such as `x += y` must be replaced with the
-out-of-place version `x = x + y`. Note that these operations will be
-optimized if they are placed inside of a `jitted function
-`_.
diff --git a/scico/optimize/_admmaux.py b/scico/optimize/_admmaux.py
index a21dadca6..e34f8ffb1 100644
--- a/scico/optimize/_admmaux.py
+++ b/scico/optimize/_admmaux.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright (C) 2020-2022 by SCICO Developers
+# Copyright (C) 2020-2023 by SCICO Developers
 # All rights reserved. BSD 3-clause License.
 # This file is part of the SCICO package. Details of the copyright and
 # user license can be found in the 'LICENSE' file distributed with the
@@ -22,7 +22,6 @@
 from scico.linop import CircularConvolve, Identity, LinearOperator
 from scico.loss import SquaredL2Loss
 from scico.numpy import BlockArray
-from scico.numpy.linalg import norm
 from scico.numpy.util import ensure_on_device, is_real_dtype
 from scico.solver import cg as scico_cg
 from scico.solver import minimize
@@ -63,6 +62,8 @@ def internal_init(self, admm: soa.ADMM):
 class GenericSubproblemSolver(SubproblemSolver):
     """Solver for generic problem without special structure.
 
+    Note that this solver is only suitable for small-scale problems.
+
     Attributes:
         admm (:class:`.ADMM`): ADMM solver object to which the solver
             is attached.
@@ -98,9 +99,9 @@ def obj(x):
             for rhoi, Ci, zi, ui in zip(
                 self.admm.rho_list, self.admm.C_list, self.admm.z_list, self.admm.u_list
             ):
-                out = out + 0.5 * rhoi * norm(zi - ui - Ci(x)) ** 2
+                out += 0.5 * rhoi * snp.sum(snp.abs(zi - ui - Ci(x)) ** 2)
             if self.admm.f is not None:
-                out = out + self.admm.f(x)
+                out += self.admm.f(x)
             return out
 
         res = minimize(obj, x0, **self.minimize_kwargs)
@@ -211,7 +212,7 @@ def internal_init(self, admm: soa.ADMM):
         )
         if admm.f is not None:
             # hessian = A.T @ W @ A; W may be identity
-            lhs_op = lhs_op + admm.f.hessian
+            lhs_op += admm.f.hessian
         lhs_op.jit()
         self.lhs_op = lhs_op
 
@@ -240,7 +241,7 @@ def compute_rhs(self) -> Union[JaxArray, BlockArray]:
         for rhoi, Ci, zi, ui in zip(
             self.admm.rho_list, self.admm.C_list, self.admm.z_list, self.admm.u_list
         ):
-            rhs = rhs + rhoi * Ci.adj(zi - ui)
+            rhs += rhoi * Ci.adj(zi - ui)
         return rhs
 
     def solve(self, x0: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]:
diff --git a/scico/optimize/_primaldual.py b/scico/optimize/_primaldual.py
index d96e73945..245b6860f 100644
--- a/scico/optimize/_primaldual.py
+++ b/scico/optimize/_primaldual.py
@@ -44,7 +44,7 @@ class PDHG(Optimizer):
     where :math:`f` and :math:`g` are instances of :class:`.Functional`,
     (in most cases :math:`f` will, more specifically be an an instance
     of :class:`.Loss`), and :math:`C` is an instance of
-    :class:`.LinearOperator`.
+    :class:`.Operator` or :class:`.LinearOperator`.
 
     When `C` is a :class:`.LinearOperator`, the algorithm iterations are