From c7d1b5b5cf6ce58d2f9aef9504355788a109f16f Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 24 Feb 2022 19:13:53 -0700 Subject: [PATCH 01/24] Minor changes and clean up --- scico/loss.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/scico/loss.py b/scico/loss.py index 2975a111f..c80c5d13d 100644 --- a/scico/loss.py +++ b/scico/loss.py @@ -215,23 +215,23 @@ def prox( ATWA = c * A.conj() * W * A return lhs / (ATWA + 1.0) - # prox_{f}(v) = arg min 1/2 || v - x ||^2 + 位 伪 || A x - y ||^2_W + # prox_{f}(v) = arg min 1/2 || v - x ||_2^2 + 位 饾浖 || A x - y ||^2_W # x # solution at: # - # (I + 位 2伪 A^T W A) x = v + 位 2伪 A^T W y + # (I + 位 2饾浖 A^T W A) x = v + 位 2饾浖 A^T W y # W = self.W A = self.A - 伪 = self.scale + 饾浖 = self.scale y = self.y if "x0" in kwargs and kwargs["x0"] is not None: x0 = kwargs["x0"] else: x0 = snp.zeros_like(v) - hessian = self.hessian # = (2伪 A^T W A) + hessian = self.hessian # = (2饾浖 A^T W A) lhs = linop.Identity(v.shape) + lam * hessian - rhs = v + 2 * lam * 伪 * A.adj(W(y)) + rhs = v + 2 * lam * 饾浖 * A.adj(W(y)) x, _ = cg(lhs, rhs, x0, **self.prox_kwargs) return x From 551608033825aa6336eb06da487ab876a0c7f88a Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 25 Feb 2022 10:57:12 -0700 Subject: [PATCH 02/24] Docstring style fix --- scico/operator/biconvolve.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/scico/operator/biconvolve.py b/scico/operator/biconvolve.py index 3c172e3e3..7c137dd96 100644 --- a/scico/operator/biconvolve.py +++ b/scico/operator/biconvolve.py @@ -18,8 +18,6 @@ from scico.linop import Convolve, ConvolveByX from scico.typing import BlockShape, DType, JaxArray -__author__ = """Luke Pfister """ - class BiConvolve(Operator): """BiConvolution operator. @@ -44,7 +42,7 @@ def __init__( input_shape: Shape of input BlockArray. Must correspond to a BlockArray with two blocks of equal ndims. input_dtype: `dtype` for input argument. Defaults to - `float32`. + ``float32``. mode: A string indicating the size of the output. One of "full", "valid", "same". Defaults to "full". jit: If ``True``, jit the evaluation of this Operator. From 61813f99a76d1be1bdb99cb70701015c44e27aa5 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 25 Feb 2022 10:58:57 -0700 Subject: [PATCH 03/24] =?UTF-8?q?Add=20weighted=20=E2=84=932=20of=20absolu?= =?UTF-8?q?te=20value=20loss?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scico/loss.py | 81 ++++++++++++++++++++++++++++++++++++++--- scico/test/test_loss.py | 34 +++++++++++++++++ 2 files changed, 109 insertions(+), 6 deletions(-) diff --git a/scico/loss.py b/scico/loss.py index c80c5d13d..f7311c2fe 100644 --- a/scico/loss.py +++ b/scico/loss.py @@ -237,9 +237,9 @@ def prox( @property def hessian(self) -> linop.LinearOperator: - r"""Compute the hessian of a linear operator. + r"""Compute the Hessian of linear operator `A`. - If ``self.A`` is a :class:`scico.linop.LinearOperator`, returns a + If `self.A` is a :class:`scico.linop.LinearOperator`, returns a :class:`scico.linop.LinearOperator` corresponding to the Hessian :math:`2 \alpha \mathrm{A^H W A}`. Otherwise not implemented. """ @@ -254,8 +254,8 @@ def hessian(self) -> linop.LinearOperator: ) raise NotImplementedError( - f"Hessian is not implemented for {type(self)} when `A` is {type(A)}; " - "must be LinearOperator" + f"Hessian is not implemented for {type(self)} when A is {type(A)}; " + "must be LinearOperator." ) @@ -275,7 +275,7 @@ def __init__( y: Union[JaxArray, BlockArray], A: Optional[Union[Callable, operator.Operator]] = None, scale: float = 0.5, - prox_kwargs: dict = {"maxiter": 1000, "tol": 1e-12}, + prox_kwargs: dict = {"maxiter": 100, "tol": 1e-5}, ): r""" Args: @@ -308,7 +308,7 @@ def __init__( Args: y: Measurement. A: Forward operator. Defaults to ``None``, in which case - ``self.A`` is a :class:`.Identity`. + `self.A` is a :class:`.Identity`. scale: Scaling parameter. Default: 0.5. """ y = ensure_on_device(y) @@ -320,3 +320,72 @@ def __init__( def __call__(self, x: Union[JaxArray, BlockArray]) -> float: Ax = self.A(x) return self.scale * snp.sum(Ax - self.y * snp.log(Ax) + self.const) + + +class WeightedSquaredL2AbsLoss(Loss): + r"""Weighted squared :math:`\ell_2` with absolute value loss. + + Weighted squared :math:`\ell_2` with absolute value loss + + .. math:: + \alpha \norm{\mb{y} - | A(\mb{x}) |}_W^2 = + \alpha \left(\mb{y} - | A(\mb{x} |)\right)^T W \left(\mb{y} - + | A(\mb{x}) |\right) \;, + + where :math:`\alpha` is the scaling parameter and :math:`W` is an + instance of :class:`scico.linop.Diagonal`. + """ + + def __init__( + self, + y: Union[JaxArray, BlockArray], + A: Optional[Union[Callable, operator.Operator]] = None, + scale: float = 0.5, + W: Optional[linop.Diagonal] = None, + prox_kwargs: dict = {"maxiter": 100, "tol": 1e-5}, + ): + + r""" + Args: + y: Measurement. + A: Forward operator. If ``None``, defaults to :class:`.Identity`. + scale: Scaling parameter. + W: Weighting diagonal operator. Must be non-negative. + If ``None``, defaults to :class:`.Identity`. + """ + y = ensure_on_device(y) + + if W is None: + self.W: Union[linop.Diagonal, linop.Identity] = linop.Identity(y.shape) + elif isinstance(W, linop.Diagonal): + if snp.all(W.diagonal >= 0): + self.W = W + else: + raise ValueError(f"The weights, W.diagonal, must be non-negative.") + else: + raise TypeError(f"W must be None or a linop.Diagonal, got {type(W)}.") + + super().__init__(y=y, A=A, scale=scale) + + if prox_kwargs is None: + prox_kwargs = dict + self.prox_kwargs = prox_kwargs + + if isinstance(self.A, linop.Identity) and snp.all(y >= 0): + self.has_prox = True + + def __call__(self, x: Union[JaxArray, BlockArray]) -> float: + return self.scale * (self.W.diagonal * snp.abs(self.y - snp.abs(self.A(x))) ** 2).sum() + + def prox( + self, v: Union[JaxArray, BlockArray], lam: float, **kwargs + ) -> Union[JaxArray, BlockArray]: + if not self.has_prox: + raise NotImplementedError(f"prox is not implemented.") + + 饾浖 = lam * 2.0 * self.scale * self.W.diagonal + y = self.y + r = snp.abs(v) + 饾浗 = (饾浖 * y + r) / ((饾浖 + 1.0) * r) + x = 饾浗 * v + return x diff --git a/scico/test/test_loss.py b/scico/test/test_loss.py index 39e5689c6..f725333c1 100644 --- a/scico/test/test_loss.py +++ b/scico/test/test_loss.py @@ -129,3 +129,37 @@ def test_poisson(self): assert L.scale == 0.5 # hasn't changed assert cL.scale == self.scalar * L.scale assert cL(v) == self.scalar * L(v) + + def test_weighted_squared_l2_abs(self): + L = loss.WeightedSquaredL2AbsLoss(y=self.y, A=self.Ao, W=self.W) + assert L.has_eval + assert not L.has_prox + + # test eval + np.testing.assert_allclose( + L(self.v), 0.5 * (self.W @ (snp.abs(self.Ao @ self.v) - self.y) ** 2).sum() + ) + + cL = self.scalar * L + assert L.scale == 0.5 # hasn't changed + assert cL.scale == self.scalar * L.scale + assert cL(self.v) == self.scalar * L(self.v) + + # Loss has a prox with Identity linop + y = snp.abs(self.y) + L_d = loss.WeightedSquaredL2AbsLoss(y=y, A=None, W=self.W) + + assert L_d.has_eval + assert L_d.has_prox + + # test eval + np.testing.assert_allclose(L_d(self.v), 0.5 * (self.W @ (snp.abs(self.v) - y) ** 2).sum()) + + cL = self.scalar * L_d + assert L_d.scale == 0.5 # hasn't changed + assert cL.scale == self.scalar * L_d.scale + assert cL(self.v) == self.scalar * L_d(self.v) + + pf = prox_test(self.v, L_d, L_d.prox, 0.75) + with pytest.raises(NotImplementedError): + pf = prox_test(self.v, L, L.prox, 0.75) From eac60d4c3747078211cd9d23cf862c48f1ef59d1 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 25 Feb 2022 11:05:14 -0700 Subject: [PATCH 04/24] Docstring style fix --- scico/random.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/random.py b/scico/random.py index ee2d158b0..ee45e46d5 100644 --- a/scico/random.py +++ b/scico/random.py @@ -197,7 +197,7 @@ def randn( shape: Shape of output array. If shape is a tuple, a DeviceArray is returned. If shape is a tuple of tuples, a :class:`.BlockArray` is returned. - key: JAX PRNGKey. Defaults to None, in which case a new key + key: JAX PRNGKey. Defaults to ``None``, in which case a new key is created using the seed arg. seed: Seed for new PRNGKey. Default: 0. dtype: dtype for returned value. Default to ``np.float32``. From 97ce27da4586ab4f73545e8614436c4baa42181c Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 25 Feb 2022 11:05:32 -0700 Subject: [PATCH 05/24] Add test for complex v --- scico/test/test_loss.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/scico/test/test_loss.py b/scico/test/test_loss.py index f725333c1..18bfe6fa2 100644 --- a/scico/test/test_loss.py +++ b/scico/test/test_loss.py @@ -30,6 +30,7 @@ def setup_method(self): self.y, key = randn((n,), key=key, dtype=dtype) self.v, key = randn((n,), key=key, dtype=dtype) # point for prox eval scalar, key = randn((1,), key=key, dtype=dtype) + self.key = key self.scalar = scalar.copy().ravel()[0] def test_generic_squared_l2(self): @@ -160,6 +161,8 @@ def test_weighted_squared_l2_abs(self): assert cL.scale == self.scalar * L_d.scale assert cL(self.v) == self.scalar * L_d(self.v) + v, key = randn(y.shape, key=self.key, dtype=np.complex128) pf = prox_test(self.v, L_d, L_d.prox, 0.75) + pf = prox_test(v, L_d, L_d.prox, 0.75) with pytest.raises(NotImplementedError): pf = prox_test(self.v, L, L.prox, 0.75) From 50da8add77d3c7c65e8fee812902c30c4eb20a35 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 25 Feb 2022 11:21:19 -0700 Subject: [PATCH 06/24] Improve docs --- scico/loss.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/scico/loss.py b/scico/loss.py index f7311c2fe..06257bae6 100644 --- a/scico/loss.py +++ b/scico/loss.py @@ -328,12 +328,15 @@ class WeightedSquaredL2AbsLoss(Loss): Weighted squared :math:`\ell_2` with absolute value loss .. math:: - \alpha \norm{\mb{y} - | A(\mb{x}) |}_W^2 = - \alpha \left(\mb{y} - | A(\mb{x} |)\right)^T W \left(\mb{y} - + \alpha \norm{\mb{y} - | A(\mb{x}) |\,}_W^2 = + \alpha \left(\mb{y} - | A(\mb{x}) |\right)^T W \left(\mb{y} - | A(\mb{x}) |\right) \;, where :math:`\alpha` is the scaling parameter and :math:`W` is an instance of :class:`scico.linop.Diagonal`. + + Proximal operator :meth:`prox` is implemented when :math:`A` is an + instance of :class:`scico.linop.Identity`. """ def __init__( From 62346be6567fb9724df97b2d0560407e3e0c9419 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Fri, 25 Feb 2022 19:40:04 -0700 Subject: [PATCH 07/24] Fix and test v=0 case --- scico/loss.py | 4 ++-- scico/test/test_loss.py | 8 ++++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/scico/loss.py b/scico/loss.py index 06257bae6..7d56014b8 100644 --- a/scico/loss.py +++ b/scico/loss.py @@ -389,6 +389,6 @@ def prox( 饾浖 = lam * 2.0 * self.scale * self.W.diagonal y = self.y r = snp.abs(v) - 饾浗 = (饾浖 * y + r) / ((饾浖 + 1.0) * r) - x = 饾浗 * v + 饾浗 = (饾浖 * y + r) / (饾浖 + 1.0) + x = snp.where(r > 0, (饾浗 / r) * v, 饾浗) return x diff --git a/scico/test/test_loss.py b/scico/test/test_loss.py index 18bfe6fa2..4ef811c4e 100644 --- a/scico/test/test_loss.py +++ b/scico/test/test_loss.py @@ -161,8 +161,12 @@ def test_weighted_squared_l2_abs(self): assert cL.scale == self.scalar * L_d.scale assert cL(self.v) == self.scalar * L_d(self.v) + pf = prox_test(self.v, L_d, L_d.prox, 0.5) # real v v, key = randn(y.shape, key=self.key, dtype=np.complex128) - pf = prox_test(self.v, L_d, L_d.prox, 0.75) - pf = prox_test(v, L_d, L_d.prox, 0.75) + pf = prox_test(v, L_d, L_d.prox, 2.0) # complex v + pf = prox_test((1 + 1j) * snp.zeros(self.v.shape), L_d, L_d.prox, 0.0) # complex zero v + pf = prox_test((1 + 1j) * snp.zeros(self.v.shape), L_d, L_d.prox, 1.0) # complex zero v + pf = prox_test((1 + 1j) * snp.zeros(self.v.shape), L_d, L_d.prox, 2.0) # complex zero v + with pytest.raises(NotImplementedError): pf = prox_test(self.v, L, L.prox, 0.75) From f2e137923467bdfa46867ed41e2c9abf315122e4 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 1 Mar 2022 13:46:42 -0700 Subject: [PATCH 08/24] =?UTF-8?q?Add=20weighted=20squared=20=E2=84=932=20o?= =?UTF-8?q?f=20squared=20absolute=20value=20loss?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/source/references.bib | 15 +++++ scico/loss.py | 128 ++++++++++++++++++++++++++++++++++++- scico/test/test_loss.py | 42 ++++++++++++ 3 files changed, 183 insertions(+), 2 deletions(-) diff --git a/docs/source/references.bib b/docs/source/references.bib index bae66749d..1d3ba1878 100644 --- a/docs/source/references.bib +++ b/docs/source/references.bib @@ -412,6 +412,21 @@ @Article {sauer-1993-local doi = {10.1109/78.193196} } +@Article {soulez-2016-proximity, + author = {Ferr{\'{e}}ol Soulez and {\'{E}}ric Thi{\'{e}}baut + and Antony Schutz and Andr{\'{e}} Ferrari and + Fr{\'{e}}d{\'{e}}ric Courbin and Michael Unser}, + title = {Proximity operators for phase retrieval}, + journal = {Applied Optics}, + doi = {10.1364/ao.55.007412}, + year = 2016, + month = Sep, + volume = 55, + number = 26, + pages = {7412--7421} + +} + @Article {sreehari-2016-plug, author = {Suhas Sreehari and Singanallur V. Venkatakrishnan and Brendt Wohlberg and Gregery T. Buzzard and diff --git a/scico/loss.py b/scico/loss.py index 7d56014b8..0712bfba2 100644 --- a/scico/loss.py +++ b/scico/loss.py @@ -15,7 +15,7 @@ import scico.numpy as snp from scico import functional, linop, operator -from scico.array import ensure_on_device +from scico.array import ensure_on_device, no_nan_divide from scico.blockarray import BlockArray from scico.scipy.special import gammaln from scico.solver import cg @@ -336,7 +336,9 @@ class WeightedSquaredL2AbsLoss(Loss): instance of :class:`scico.linop.Diagonal`. Proximal operator :meth:`prox` is implemented when :math:`A` is an - instance of :class:`scico.linop.Identity`. + instance of :class:`scico.linop.Identity`. This is not proximal + operator according to the strict definition since the loss function + is non-convex (Sec. 3) :cite:`soulez-2016-proximity`. """ def __init__( @@ -392,3 +394,125 @@ def prox( 饾浗 = (饾浖 * y + r) / (饾浖 + 1.0) x = snp.where(r > 0, (饾浗 / r) * v, 饾浗) return x + + +def cbrt(x): + """Compute the cube root of the argument. + + The two standard options for computing the cube root of an array are + :func:`numpy.cbrt`, or raising to the power of (1/3), i.e. `x ** (1/3)`. + The former cannot be used for complex values, and the latter returns + a complex root of a negative real value. This functions can be used + for both real and complex values, and returns the real root of + negative real values. + + Args: + x: Input array + + Returns: + Array of cube roots of input `x`. + """ + s = snp.where(snp.abs(snp.angle(x)) <= 3 * snp.pi / 4, 1, -1) + return s * (s * x) ** (1 / 3) + + +def dep_cubic_root(p, q): + r"""Compute the positive real root of a depressed cubic equation. + + A depressed cubic equation is one that can be written in the form + + .. math:: + x^3 + px + q \;. + + This function finds the positive real root of such an equation via + `Cardano's method `__, + for `p` and `q` such that there is a single positive real root + (see Sec. 3.C of :cite:`soulez-2016-proximity`). + """ + q2 = q / 2 + 螖 = q2 ** 2 + (p / 3) ** 3 + 螖rt = snp.sqrt(螖 + 0j) + u3, v3 = -q2 + 螖rt, -q2 - 螖rt + u, v = cbrt(u3), cbrt(v3) + r = (u + v).real + assert snp.allclose(snp.abs(r ** 3 + p * r + q), 0) + return r + + +class WeightedSquaredL2AbsSquaredLoss(Loss): + r"""Weighted squared :math:`\ell_2` with squared absolute value loss. + + Weighted squared :math:`\ell_2` with squared absolute value loss + + .. math:: + \alpha \norm{\mb{y} - | A(\mb{x}) |^2 \,}_W^2 = + \alpha \left(\mb{y} - | A(\mb{x}) |^2 \right)^T W \left(\mb{y} - + | A(\mb{x}) |^2 \right) \;, + + where :math:`\alpha` is the scaling parameter and :math:`W` is an + instance of :class:`scico.linop.Diagonal`. + + Proximal operator :meth:`prox` is implemented when :math:`A` is an + instance of :class:`scico.linop.Identity`. This is not proximal + operator according to the strict definition since the loss function + is non-convex (Sec. 3) :cite:`soulez-2016-proximity`. + """ + + def __init__( + self, + y: Union[JaxArray, BlockArray], + A: Optional[Union[Callable, operator.Operator]] = None, + scale: float = 0.5, + W: Optional[linop.Diagonal] = None, + prox_kwargs: dict = {"maxiter": 100, "tol": 1e-5}, + ): + + r""" + Args: + y: Measurement. + A: Forward operator. If ``None``, defaults to :class:`.Identity`. + scale: Scaling parameter. + W: Weighting diagonal operator. Must be non-negative. + If ``None``, defaults to :class:`.Identity`. + """ + y = ensure_on_device(y) + + if W is None: + self.W: Union[linop.Diagonal, linop.Identity] = linop.Identity(y.shape) + elif isinstance(W, linop.Diagonal): + if snp.all(W.diagonal >= 0): + self.W = W + else: + raise ValueError(f"The weights, W.diagonal, must be non-negative.") + else: + raise TypeError(f"W must be None or a linop.Diagonal, got {type(W)}.") + + super().__init__(y=y, A=A, scale=scale) + + if prox_kwargs is None: + prox_kwargs = dict + self.prox_kwargs = prox_kwargs + + if isinstance(self.A, linop.Identity) and snp.all(y >= 0): + self.has_prox = True + + def __call__(self, x: Union[JaxArray, BlockArray]) -> float: + return self.scale * (self.W.diagonal * snp.abs(self.y - snp.abs(self.A(x)) ** 2) ** 2).sum() + + def prox( + self, v: Union[JaxArray, BlockArray], lam: float, **kwargs + ) -> Union[JaxArray, BlockArray]: + if not self.has_prox: + raise NotImplementedError(f"prox is not implemented.") + + 饾浖 = lam * 4.0 * self.scale * self.W.diagonal + 饾浗 = snp.abs(v) + p = no_nan_divide(1.0 - 饾浖 * self.y, 饾浖) + q = no_nan_divide(-饾浗, 饾浖) + # r = snp.where(饾浖 > 0, dep_cubic_root(p, q), 饾浗) + r = dep_cubic_root(p, q) + 蠁 = snp.where(饾浗 > 0, v / snp.abs(v), 1.0) + # x = r * 蠁 + x = snp.where(饾浖 > 0, r * 蠁, v) + print("v", v, "y", self.y, "饾浖", 饾浖, "饾浗", 饾浗, "p", p, "q", q, "r", r, "蠁", 蠁, "x", x) + return x diff --git a/scico/test/test_loss.py b/scico/test/test_loss.py index 4ef811c4e..451301d1d 100644 --- a/scico/test/test_loss.py +++ b/scico/test/test_loss.py @@ -170,3 +170,45 @@ def test_weighted_squared_l2_abs(self): with pytest.raises(NotImplementedError): pf = prox_test(self.v, L, L.prox, 0.75) + + def test_weighted_squared_l2_abs_squared(self): + L = loss.WeightedSquaredL2AbsSquaredLoss(y=self.y, A=self.Ao, W=self.W) + assert L.has_eval + assert not L.has_prox + + # test eval + np.testing.assert_allclose( + L(self.v), 0.5 * (self.W @ (snp.abs(self.Ao @ self.v) ** 2 - self.y) ** 2).sum() + ) + + cL = self.scalar * L + assert L.scale == 0.5 # hasn't changed + assert cL.scale == self.scalar * L.scale + assert cL(self.v) == self.scalar * L(self.v) + + # Loss has a prox with Identity linop + y = snp.abs(self.y) + L_d = loss.WeightedSquaredL2AbsSquaredLoss(y=y, A=None, W=self.W) + + assert L_d.has_eval + assert L_d.has_prox + + # test eval + np.testing.assert_allclose( + L_d(self.v), 0.5 * (self.W @ (snp.abs(self.v) ** 2 - y) ** 2).sum() + ) + + cL = self.scalar * L_d + assert L_d.scale == 0.5 # hasn't changed + assert cL.scale == self.scalar * L_d.scale + assert cL(self.v) == self.scalar * L_d(self.v) + + pf = prox_test(self.v, L_d, L_d.prox, 0.5) # real v + v, key = randn(y.shape, key=self.key, dtype=np.complex128) + pf = prox_test(v, L_d, L_d.prox, 2.0) # complex v + pf = prox_test((1 + 1j) * snp.zeros(self.v.shape), L_d, L_d.prox, 0.0) # complex zero v + pf = prox_test((1 + 1j) * snp.zeros(self.v.shape), L_d, L_d.prox, 1.0) # complex zero v + pf = prox_test((1 + 1j) * snp.zeros(self.v.shape), L_d, L_d.prox, 2.0) # complex zero v + + with pytest.raises(NotImplementedError): + pf = prox_test(self.v, L, L.prox, 0.75) From 0b60b6814b58a48c2e628cf55efc9dbed439ec38 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 1 Mar 2022 13:50:38 -0700 Subject: [PATCH 09/24] Clean up --- scico/loss.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/scico/loss.py b/scico/loss.py index 0712bfba2..dd5692cb4 100644 --- a/scico/loss.py +++ b/scico/loss.py @@ -396,7 +396,7 @@ def prox( return x -def cbrt(x): +def _cbrt(x): """Compute the cube root of the argument. The two standard options for computing the cube root of an array are @@ -416,7 +416,7 @@ def cbrt(x): return s * (s * x) ** (1 / 3) -def dep_cubic_root(p, q): +def _dep_cubic_root(p, q): r"""Compute the positive real root of a depressed cubic equation. A depressed cubic equation is one that can be written in the form @@ -509,10 +509,7 @@ def prox( 饾浗 = snp.abs(v) p = no_nan_divide(1.0 - 饾浖 * self.y, 饾浖) q = no_nan_divide(-饾浗, 饾浖) - # r = snp.where(饾浖 > 0, dep_cubic_root(p, q), 饾浗) - r = dep_cubic_root(p, q) + r = _dep_cubic_root(p, q) 蠁 = snp.where(饾浗 > 0, v / snp.abs(v), 1.0) - # x = r * 蠁 x = snp.where(饾浖 > 0, r * 蠁, v) - print("v", v, "y", self.y, "饾浖", 饾浖, "饾浗", 饾浗, "p", p, "q", q, "r", r, "蠁", 蠁, "x", x) return x From 419b3bb529e24cd1b59059c791326abaddb6957d Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 1 Mar 2022 14:09:52 -0700 Subject: [PATCH 10/24] Fix cleanup error --- scico/loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/loss.py b/scico/loss.py index dd5692cb4..4708f8d76 100644 --- a/scico/loss.py +++ b/scico/loss.py @@ -433,7 +433,7 @@ def _dep_cubic_root(p, q): 螖 = q2 ** 2 + (p / 3) ** 3 螖rt = snp.sqrt(螖 + 0j) u3, v3 = -q2 + 螖rt, -q2 - 螖rt - u, v = cbrt(u3), cbrt(v3) + u, v = _cbrt(u3), _cbrt(v3) r = (u + v).real assert snp.allclose(snp.abs(r ** 3 + p * r + q), 0) return r From bdc02597527cc54c22e7d497ce418f865924ef2d Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 1 Mar 2022 14:29:20 -0700 Subject: [PATCH 11/24] Clean up tests --- scico/test/test_loss.py | 121 ++++++++++++++++++++-------------------- 1 file changed, 62 insertions(+), 59 deletions(-) diff --git a/scico/test/test_loss.py b/scico/test/test_loss.py index 451301d1d..cd0179b5e 100644 --- a/scico/test/test_loss.py +++ b/scico/test/test_loss.py @@ -12,6 +12,7 @@ import scico.numpy as snp from scico import functional, linop, loss +from scico.array import complex_dtype from scico.random import randn @@ -131,84 +132,86 @@ def test_poisson(self): assert cL.scale == self.scalar * L.scale assert cL(v) == self.scalar * L(v) - def test_weighted_squared_l2_abs(self): - L = loss.WeightedSquaredL2AbsLoss(y=self.y, A=self.Ao, W=self.W) + +class TestComplexLoss: + + cplx_loss = ( + (loss.WeightedSquaredL2AbsLoss, lambda x: snp.abs(x)), + (loss.WeightedSquaredL2AbsSquaredLoss, lambda x: snp.abs(x) ** 2), + ) + + def setup_method(self): + n = 4 + dtype = np.float64 + A, key = randn((n, n), key=None, dtype=dtype, seed=1234) + W, key = randn((n,), key=key, dtype=dtype) + W = 0.1 * W + 1.0 + self.Ao = linop.MatrixOperator(A) + self.Ao_abs = linop.MatrixOperator(snp.abs(A)) + self.W = linop.Diagonal(W) + self.x, key = randn((n,), key=key, dtype=complex_dtype(dtype)) + self.v, key = randn((n,), key=key, dtype=complex_dtype(dtype)) # point for prox eval + scalar, key = randn((1,), key=key, dtype=dtype) + self.scalar = scalar.copy().ravel()[0] + + @pytest.mark.parametrize("loss_tuple", cplx_loss) + def test_properties(self, loss_tuple): + loss_class = loss_tuple[0] + loss_func = loss_tuple[1] + + y = loss_func(self.Ao(self.x)) + L = loss_class(y=y, A=self.Ao, W=self.W) assert L.has_eval assert not L.has_prox - # test eval - np.testing.assert_allclose( - L(self.v), 0.5 * (self.W @ (snp.abs(self.Ao @ self.v) - self.y) ** 2).sum() - ) - cL = self.scalar * L assert L.scale == 0.5 # hasn't changed assert cL.scale == self.scalar * L.scale assert cL(self.v) == self.scalar * L(self.v) - # Loss has a prox with Identity linop - y = snp.abs(self.y) - L_d = loss.WeightedSquaredL2AbsLoss(y=y, A=None, W=self.W) - - assert L_d.has_eval - assert L_d.has_prox - - # test eval - np.testing.assert_allclose(L_d(self.v), 0.5 * (self.W @ (snp.abs(self.v) - y) ** 2).sum()) - - cL = self.scalar * L_d - assert L_d.scale == 0.5 # hasn't changed - assert cL.scale == self.scalar * L_d.scale - assert cL(self.v) == self.scalar * L_d(self.v) - - pf = prox_test(self.v, L_d, L_d.prox, 0.5) # real v - v, key = randn(y.shape, key=self.key, dtype=np.complex128) - pf = prox_test(v, L_d, L_d.prox, 2.0) # complex v - pf = prox_test((1 + 1j) * snp.zeros(self.v.shape), L_d, L_d.prox, 0.0) # complex zero v - pf = prox_test((1 + 1j) * snp.zeros(self.v.shape), L_d, L_d.prox, 1.0) # complex zero v - pf = prox_test((1 + 1j) * snp.zeros(self.v.shape), L_d, L_d.prox, 2.0) # complex zero v - with pytest.raises(NotImplementedError): - pf = prox_test(self.v, L, L.prox, 0.75) + px = L.prox(self.v, 0.75) - def test_weighted_squared_l2_abs_squared(self): - L = loss.WeightedSquaredL2AbsSquaredLoss(y=self.y, A=self.Ao, W=self.W) - assert L.has_eval - assert not L.has_prox + np.testing.assert_allclose(L(self.x), 0) - # test eval - np.testing.assert_allclose( - L(self.v), 0.5 * (self.W @ (snp.abs(self.Ao @ self.v) ** 2 - self.y) ** 2).sum() - ) + y = loss_func(self.x) + L = loss_class(y=y, A=None, W=self.W) + assert L.has_eval + assert L.has_prox cL = self.scalar * L assert L.scale == 0.5 # hasn't changed assert cL.scale == self.scalar * L.scale assert cL(self.v) == self.scalar * L(self.v) - # Loss has a prox with Identity linop - y = snp.abs(self.y) - L_d = loss.WeightedSquaredL2AbsSquaredLoss(y=y, A=None, W=self.W) + np.testing.assert_allclose(L(self.x), 0) - assert L_d.has_eval - assert L_d.has_prox + @pytest.mark.parametrize("loss_tuple", cplx_loss) + def test_prox(self, loss_tuple): + loss_class = loss_tuple[0] + loss_func = loss_tuple[1] - # test eval - np.testing.assert_allclose( - L_d(self.v), 0.5 * (self.W @ (snp.abs(self.v) ** 2 - y) ** 2).sum() - ) + y = loss_func(self.x) + L = loss_class(y=y, A=None, W=self.W) - cL = self.scalar * L_d - assert L_d.scale == 0.5 # hasn't changed - assert cL.scale == self.scalar * L_d.scale - assert cL(self.v) == self.scalar * L_d(self.v) + pf = prox_test(self.v.real, L, L.prox, 0.5) # real v - pf = prox_test(self.v, L_d, L_d.prox, 0.5) # real v - v, key = randn(y.shape, key=self.key, dtype=np.complex128) - pf = prox_test(v, L_d, L_d.prox, 2.0) # complex v - pf = prox_test((1 + 1j) * snp.zeros(self.v.shape), L_d, L_d.prox, 0.0) # complex zero v - pf = prox_test((1 + 1j) * snp.zeros(self.v.shape), L_d, L_d.prox, 1.0) # complex zero v - pf = prox_test((1 + 1j) * snp.zeros(self.v.shape), L_d, L_d.prox, 2.0) # complex zero v + pf = prox_test(self.v, L, L.prox, 0.0) # complex v + pf = prox_test(self.v, L, L.prox, 0.1) # complex v + pf = prox_test(self.v, L, L.prox, 2.0) # complex v - with pytest.raises(NotImplementedError): - pf = prox_test(self.v, L, L.prox, 0.75) + pf = prox_test((1 + 1j) * snp.zeros(self.v.shape), L, L.prox, 0.0) # complex zero v + pf = prox_test((1 + 1j) * snp.zeros(self.v.shape), L, L.prox, 1.0) # complex zero v + pf = prox_test((1 + 1j) * snp.zeros(self.v.shape), L, L.prox, 2.0) # complex zero v + + # zero y + y = snp.zeros(self.x.shape) + L = loss_class(y=y, A=None, W=self.W) + + pf = prox_test(self.v.real, L, L.prox, 0.5) # real v + + pf = prox_test(self.v, L, L.prox, 0.0) # complex v + pf = prox_test(self.v, L, L.prox, 0.1) # complex v + + pf = prox_test((1 + 1j) * snp.zeros(self.v.shape), L, L.prox, 0.0) # complex zero v + pf = prox_test((1 + 1j) * snp.zeros(self.v.shape), L, L.prox, 1.0) # complex zero v From 20a04e0a402bbc69ed22af7efae75a3efab6a3ea Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 1 Mar 2022 19:30:18 -0700 Subject: [PATCH 12/24] Clean up and address problems with cg with complex values --- scico/solver.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/scico/solver.py b/scico/solver.py index c2ad85668..60a431de4 100644 --- a/scico/solver.py +++ b/scico/solver.py @@ -137,8 +137,8 @@ def _split_real_imag(x: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArr Returns: A real ndarray with stacked real/imaginary parts. If `x` has shape (M, N, ...), the returned array will have shape - (2, M, N, ...) where the first slice contains the ``x.real`` and - the second contains ``x.imag``. If `x` is a BlockArray, this + (2, M, N, ...) where the first slice contains the `x.real` and + the second contains `x.imag`. If `x` is a BlockArray, this function is called on each block and the output is joined into a BlockArray. """ @@ -157,8 +157,8 @@ def _join_real_imag(x: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArra x: Array to join. Returns: - A complex array with real and imaginary parts taken from ``x[0]`` - and ``x[1]`` respectively. + A complex array with real and imaginary parts taken from `x[0]` + and `x[1]` respectively. """ if isinstance(x, BlockArray): return BlockArray.array([_join_real_imag(_) for _ in x]) @@ -310,9 +310,9 @@ def cg( b: Input array :math:`\mb{b}`. x0: Initial solution. tol: Relative residual stopping tolerance. Convergence occurs - when ``norm(residual) <= max(tol * norm(b), atol)``. + when `norm(residual) <= max(tol * norm(b), atol)`. atol: Absolute residual stopping tolerance. Convergence occurs - when ``norm(residual) <= max(tol * norm(b), atol)``. + when `norm(residual) <= max(tol * norm(b), atol)`. maxiter: Maximum iterations. Default: 1000. M: Preconditioner for `A`. The preconditioner should approximate the inverse of `A`. The default, ``None``, uses no @@ -353,4 +353,4 @@ def cg( p = z + beta * p ii += 1 - return (x, {"num_iter": ii, "rel_res": snp.sqrt(num) / bn}) + return (x, {"num_iter": ii, "rel_res": snp.sqrt(num).real / bn}) From 751f93714d142d0b60d3e9db7cb315c1907c36de Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 1 Mar 2022 19:30:49 -0700 Subject: [PATCH 13/24] Bump assertion tolerance --- scico/loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/loss.py b/scico/loss.py index 4708f8d76..fd1571f18 100644 --- a/scico/loss.py +++ b/scico/loss.py @@ -435,7 +435,7 @@ def _dep_cubic_root(p, q): u3, v3 = -q2 + 螖rt, -q2 - 螖rt u, v = _cbrt(u3), _cbrt(v3) r = (u + v).real - assert snp.allclose(snp.abs(r ** 3 + p * r + q), 0) + assert snp.allclose(snp.abs(r ** 3 + p * r + q), 0, atol=1e-4) return r From 33bb1b35703bfff8f748108535600d5af68041f3 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 1 Mar 2022 20:58:20 -0700 Subject: [PATCH 14/24] Remove redundant lambda --- scico/test/test_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scico/test/test_loss.py b/scico/test/test_loss.py index cd0179b5e..f1c87aaae 100644 --- a/scico/test/test_loss.py +++ b/scico/test/test_loss.py @@ -136,7 +136,7 @@ def test_poisson(self): class TestComplexLoss: cplx_loss = ( - (loss.WeightedSquaredL2AbsLoss, lambda x: snp.abs(x)), + (loss.WeightedSquaredL2AbsLoss, snp.abs), (loss.WeightedSquaredL2AbsSquaredLoss, lambda x: snp.abs(x) ** 2), ) From 510899586fdd6edf31400e582d50a80fbea50e2b Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 1 Mar 2022 21:12:35 -0700 Subject: [PATCH 15/24] Add some tests --- scico/test/test_loss.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/scico/test/test_loss.py b/scico/test/test_loss.py index f1c87aaae..ec5d17d59 100644 --- a/scico/test/test_loss.py +++ b/scico/test/test_loss.py @@ -175,7 +175,7 @@ def test_properties(self, loss_tuple): np.testing.assert_allclose(L(self.x), 0) y = loss_func(self.x) - L = loss_class(y=y, A=None, W=self.W) + L = loss_class(y=y, A=None, W=None, prox_kwargs=None) assert L.has_eval assert L.has_prox @@ -186,6 +186,13 @@ def test_properties(self, loss_tuple): np.testing.assert_allclose(L(self.x), 0) + W = -1 * self.W + with pytest.raises(ValueError): + L = loss_class(y=y, W=W) + + with pytest.raises(TypeError): + L = loss_class(y=y, W=linop.Sum(input_shape=W.input_shape)) + @pytest.mark.parametrize("loss_tuple", cplx_loss) def test_prox(self, loss_tuple): loss_class = loss_tuple[0] From 30f18ca4293f0b56720523ae3201d9112ebd7da2 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 1 Mar 2022 22:07:14 -0700 Subject: [PATCH 16/24] Remove unused parameter --- scico/loss.py | 10 ---------- scico/test/test_loss.py | 2 +- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/scico/loss.py b/scico/loss.py index fd1571f18..18b9ee696 100644 --- a/scico/loss.py +++ b/scico/loss.py @@ -347,7 +347,6 @@ def __init__( A: Optional[Union[Callable, operator.Operator]] = None, scale: float = 0.5, W: Optional[linop.Diagonal] = None, - prox_kwargs: dict = {"maxiter": 100, "tol": 1e-5}, ): r""" @@ -372,10 +371,6 @@ def __init__( super().__init__(y=y, A=A, scale=scale) - if prox_kwargs is None: - prox_kwargs = dict - self.prox_kwargs = prox_kwargs - if isinstance(self.A, linop.Identity) and snp.all(y >= 0): self.has_prox = True @@ -464,7 +459,6 @@ def __init__( A: Optional[Union[Callable, operator.Operator]] = None, scale: float = 0.5, W: Optional[linop.Diagonal] = None, - prox_kwargs: dict = {"maxiter": 100, "tol": 1e-5}, ): r""" @@ -489,10 +483,6 @@ def __init__( super().__init__(y=y, A=A, scale=scale) - if prox_kwargs is None: - prox_kwargs = dict - self.prox_kwargs = prox_kwargs - if isinstance(self.A, linop.Identity) and snp.all(y >= 0): self.has_prox = True diff --git a/scico/test/test_loss.py b/scico/test/test_loss.py index ec5d17d59..61be9606c 100644 --- a/scico/test/test_loss.py +++ b/scico/test/test_loss.py @@ -175,7 +175,7 @@ def test_properties(self, loss_tuple): np.testing.assert_allclose(L(self.x), 0) y = loss_func(self.x) - L = loss_class(y=y, A=None, W=None, prox_kwargs=None) + L = loss_class(y=y, A=None, W=None) assert L.has_eval assert L.has_prox From 08e3a59297396fef74206ff46b46a84d963be8d7 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 1 Mar 2022 22:19:12 -0700 Subject: [PATCH 17/24] Improve prox_kwargs defaults handling --- scico/loss.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/scico/loss.py b/scico/loss.py index 18b9ee696..a4366b30f 100644 --- a/scico/loss.py +++ b/scico/loss.py @@ -161,7 +161,7 @@ def __init__( A: Optional[Union[Callable, operator.Operator]] = None, scale: float = 0.5, W: Optional[linop.Diagonal] = None, - prox_kwargs: dict = {"maxiter": 1000, "tol": 1e-12}, + prox_kwargs: Optional[dict] = None, ): r""" @@ -188,9 +188,11 @@ def __init__( super().__init__(y=y, A=A, scale=scale) - if prox_kwargs is None: - prox_kwargs = dict - self.prox_kwargs = prox_kwargs + default_prox_kwargs = {"maxiter": 100, "tol": 1e-5} + if prox_kwargs: + default_prox_kwargs.update(prox_kwargs) + self.prox_kwargs = default_prox_kwargs + prox_kwargs: dict = ({"maxiter": 100, "tol": 1e-5},) if isinstance(self.A, linop.LinearOperator): self.has_prox = True From 4155cd369e9074a312baf5c7fd02c55ee124bad8 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 2 Mar 2022 10:49:36 -0700 Subject: [PATCH 18/24] Minor edit --- scico/loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scico/loss.py b/scico/loss.py index a4366b30f..19aa79b9e 100644 --- a/scico/loss.py +++ b/scico/loss.py @@ -316,8 +316,8 @@ def __init__( y = ensure_on_device(y) super().__init__(y=y, A=A, scale=scale) - #: Constant term in Poisson log likehood; equal to ln(y!) - self.const = gammaln(self.y + 1.0) # ln(y!) + #: Constant term, :math:`\ln(y!)`, in Poisson log likehood. + self.const = gammaln(self.y + 1.0) def __call__(self, x: Union[JaxArray, BlockArray]) -> float: Ax = self.A(x) From eca6c3ac02c74fd73d1abd72700dd02c16730448 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 2 Mar 2022 10:49:54 -0700 Subject: [PATCH 19/24] Improve tests --- scico/test/test_loss.py | 40 ++++++++++++++-------------------------- 1 file changed, 14 insertions(+), 26 deletions(-) diff --git a/scico/test/test_loss.py b/scico/test/test_loss.py index 61be9606c..56f8e24d1 100644 --- a/scico/test/test_loss.py +++ b/scico/test/test_loss.py @@ -17,6 +17,9 @@ class TestLoss: + + l2loss = (loss.SquaredL2Loss, loss.WeightedSquaredL2Loss) + def setup_method(self): n = 4 dtype = np.float64 @@ -53,8 +56,9 @@ def test_generic_exception(self): with pytest.raises(NotImplementedError): L.prox(self.v, self.scalar) - def test_squared_l2(self): - L = loss.SquaredL2Loss(y=self.y, A=self.Ao) + @pytest.mark.parametrize("loss_class", l2loss) + def test_l2_loss(self, loss_class): + L = loss_class(y=self.y, A=self.Ao) assert L.has_eval assert L.has_prox @@ -66,8 +70,8 @@ def test_squared_l2(self): assert cL.scale == self.scalar * L.scale assert cL(self.v) == self.scalar * L(self.v) - # SquaredL2 with Diagonal linop has a prox - L_d = loss.SquaredL2Loss(y=self.y, A=self.Do) + # squared l2 loss with diagonal linop has a prox + L_d = loss_class(y=self.y, A=self.Do) # test eval np.testing.assert_allclose(L_d(self.v), 0.5 * ((self.Do @ self.v - self.y) ** 2).sum()) @@ -87,35 +91,19 @@ def test_weighted_squared_l2(self): L = loss.WeightedSquaredL2Loss(y=self.y, A=self.Ao, W=self.W) assert L.has_eval assert L.has_prox - - # test eval np.testing.assert_allclose( L(self.v), 0.5 * (self.W @ (self.Ao @ self.v - self.y) ** 2).sum() ) + pf = prox_test(self.v, L, L.prox, 0.75) - cL = self.scalar * L - assert L.scale == 0.5 # hasn't changed - assert cL.scale == self.scalar * L.scale - assert cL(self.v) == self.scalar * L(self.v) - - # SquaredL2 with Diagonal linop has a prox + # weighted l2 loss with diagonal linop has a prox L_d = loss.WeightedSquaredL2Loss(y=self.y, A=self.Do, W=self.W) - assert L_d.has_eval assert L_d.has_prox - - # test eval np.testing.assert_allclose( L_d(self.v), 0.5 * (self.W @ (self.Do @ self.v - self.y) ** 2).sum() ) - - cL = self.scalar * L_d - assert L_d.scale == 0.5 # hasn't changed - assert cL.scale == self.scalar * L_d.scale - assert cL(self.v) == self.scalar * L_d(self.v) - pf = prox_test(self.v, L_d, L_d.prox, 0.75) - pf = prox_test(self.v, L, L.prox, 0.75) def test_poisson(self): L = loss.PoissonLoss(y=self.y, A=self.Ao_abs) @@ -133,9 +121,9 @@ def test_poisson(self): assert cL(v) == self.scalar * L(v) -class TestComplexLoss: +class TestAbsLoss: - cplx_loss = ( + abs_loss = ( (loss.WeightedSquaredL2AbsLoss, snp.abs), (loss.WeightedSquaredL2AbsSquaredLoss, lambda x: snp.abs(x) ** 2), ) @@ -154,7 +142,7 @@ def setup_method(self): scalar, key = randn((1,), key=key, dtype=dtype) self.scalar = scalar.copy().ravel()[0] - @pytest.mark.parametrize("loss_tuple", cplx_loss) + @pytest.mark.parametrize("loss_tuple", abs_loss) def test_properties(self, loss_tuple): loss_class = loss_tuple[0] loss_func = loss_tuple[1] @@ -193,7 +181,7 @@ def test_properties(self, loss_tuple): with pytest.raises(TypeError): L = loss_class(y=y, W=linop.Sum(input_shape=W.input_shape)) - @pytest.mark.parametrize("loss_tuple", cplx_loss) + @pytest.mark.parametrize("loss_tuple", abs_loss) def test_prox(self, loss_tuple): loss_class = loss_tuple[0] loss_func = loss_tuple[1] From 2a84d275fbed8c8d6049981c68421a428d5b5338 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 2 Mar 2022 18:00:46 -0700 Subject: [PATCH 20/24] Remove SquaredL2Loss, rename WeightedSquaredL2Loss to SquaredL2Loss --- examples/scripts/ct_astra_weighted_tv_admm.py | 2 +- scico/linop/radon_svmbir.py | 20 ++++----- scico/loss.py | 44 ++++--------------- scico/optimize/admm.py | 25 +++++------ scico/test/linop/test_radon_svmbir.py | 4 +- scico/test/optimize/test_admm.py | 4 +- scico/test/test_loss.py | 18 +++----- 7 files changed, 40 insertions(+), 77 deletions(-) diff --git a/examples/scripts/ct_astra_weighted_tv_admm.py b/examples/scripts/ct_astra_weighted_tv_admm.py index 3d87cd2ca..003df6b75 100644 --- a/examples/scripts/ct_astra_weighted_tv_admm.py +++ b/examples/scripts/ct_astra_weighted_tv_admm.py @@ -144,7 +144,7 @@ def postprocess(x): lambda_weighted = 1.14e2 weights = jax.device_put(counts / Io) -f = loss.WeightedSquaredL2Loss(y=y, A=A, W=linop.Diagonal(weights)) +f = loss.SquaredL2Loss(y=y, A=A, W=linop.Diagonal(weights)) admm_weighted = ADMM( f=f, diff --git a/scico/linop/radon_svmbir.py b/scico/linop/radon_svmbir.py index 437fadaf8..435695d41 100644 --- a/scico/linop/radon_svmbir.py +++ b/scico/linop/radon_svmbir.py @@ -19,7 +19,7 @@ import jax.experimental.host_callback import scico.numpy as snp -from scico.loss import Loss, WeightedSquaredL2Loss +from scico.loss import Loss, SquaredL2Loss from scico.typing import JaxArray, Shape from ._linop import Diagonal, Identity, LinearOperator @@ -178,8 +178,7 @@ def _bproj_hcb(self, y): class SVMBIRExtendedLoss(Loss): - r"""Extended Weighted squared :math:`\ell_2` loss with svmbir CT - projector. + r"""Extended squared :math:`\ell_2` loss with svmbir CT projector. Generalization of the weighted squared :math:`\ell_2` loss of a CT reconstruction problem, @@ -195,13 +194,12 @@ class SVMBIRExtendedLoss(Loss): to :class:`scico.linop.Identity`. The extended loss differs from a typical weighted squared - :math:`\ell_2` loss as follows. - When ``positivity=True``, the prox projects onto the non-negative - orthant and the loss is infinite if any element of the input is - negative. When the ``is_masked`` option of the associated - :class:`.ParallelBeamProjector` is ``True``, the reconstruction is - computed over a masked region of the image as described - in class :class:`.ParallelBeamProjector`. + :math:`\ell_2` loss as follows. When `positivity=True`, the prox + projects onto the non-negative orthant and the loss is infinite if + any element of the input is negative. When the `is_masked` option + of the associated :class:`.ParallelBeamProjector` is ``True``, the + reconstruction is computed over a masked region of the image as + described in class :class:`.ParallelBeamProjector`. """ def __init__( @@ -299,7 +297,7 @@ def prox(self, v: JaxArray, lam: float, **kwargs) -> JaxArray: return jax.device_put(result.reshape(self.A.input_shape)) -class SVMBIRWeightedSquaredL2Loss(SVMBIRExtendedLoss, WeightedSquaredL2Loss): +class SVMBIRWeightedSquaredL2Loss(SVMBIRExtendedLoss, SquaredL2Loss): r"""Weighted squared :math:`\ell_2` loss with svmbir CT projector. Weighted squared :math:`\ell_2` loss of a CT reconstruction problem, diff --git a/scico/loss.py b/scico/loss.py index 19aa79b9e..3cd44c440 100644 --- a/scico/loss.py +++ b/scico/loss.py @@ -90,7 +90,7 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float: """ if self.f is None: raise NotImplementedError( - "Functional l is not defined and __call__ has" " not been overridden" + "Functional f is not defined and __call__ has not been overridden" ) return self.scale * self.f(self.A(x) - self.y) @@ -140,7 +140,7 @@ def set_scale(self, new_scale: float): self.scale = new_scale -class WeightedSquaredL2Loss(Loss): +class SquaredL2Loss(Loss): r"""Weighted squared :math:`\ell_2` loss. Weighted squared :math:`\ell_2` loss @@ -151,8 +151,9 @@ class WeightedSquaredL2Loss(Loss): A(\mb{x})\right) \;, where :math:`\alpha` is the scaling parameter and :math:`W` is an - instance of :class:`scico.linop.Diagonal`. If :math:`W` is None, - reverts to the behavior of :class:`.SquaredL2Loss`. + instance of :class:`scico.linop.Diagonal`. If :math:`W` is ``None``, + the weighting is an identity operator, giving an unweighted squared + :math:`\ell_2` loss. """ def __init__( @@ -261,33 +262,6 @@ def hessian(self) -> linop.LinearOperator: ) -class SquaredL2Loss(WeightedSquaredL2Loss): - r"""Squared :math:`\ell_2` loss. - - Squared :math:`\ell_2` loss - - .. math:: - \alpha \norm{\mb{y} - A(\mb{x})}_2^2 \;, - - where :math:`\alpha` is the scaling parameter. - """ - - def __init__( - self, - y: Union[JaxArray, BlockArray], - A: Optional[Union[Callable, operator.Operator]] = None, - scale: float = 0.5, - prox_kwargs: dict = {"maxiter": 100, "tol": 1e-5}, - ): - r""" - Args: - y: Measurement. - A: Forward operator. If ``None``, defaults to :class:`.Identity`. - scale: Scaling parameter. - """ - super().__init__(y=y, A=A, scale=scale, W=None, prox_kwargs=prox_kwargs) - - class PoissonLoss(Loss): r"""Poisson negative log likelihood loss. @@ -324,7 +298,7 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float: return self.scale * snp.sum(Ax - self.y * snp.log(Ax) + self.const) -class WeightedSquaredL2AbsLoss(Loss): +class SquaredL2AbsLoss(Loss): r"""Weighted squared :math:`\ell_2` with absolute value loss. Weighted squared :math:`\ell_2` with absolute value loss @@ -367,7 +341,7 @@ def __init__( if snp.all(W.diagonal >= 0): self.W = W else: - raise ValueError(f"The weights, W.diagonal, must be non-negative.") + raise ValueError("The weights, W.diagonal, must be non-negative.") else: raise TypeError(f"W must be None or a linop.Diagonal, got {type(W)}.") @@ -436,7 +410,7 @@ def _dep_cubic_root(p, q): return r -class WeightedSquaredL2AbsSquaredLoss(Loss): +class SquaredL2AbsSquaredLoss(Loss): r"""Weighted squared :math:`\ell_2` with squared absolute value loss. Weighted squared :math:`\ell_2` with squared absolute value loss @@ -479,7 +453,7 @@ def __init__( if snp.all(W.diagonal >= 0): self.W = W else: - raise ValueError(f"The weights, W.diagonal, must be non-negative.") + raise ValueError("The weights, W.diagonal, must be non-negative.") else: raise TypeError(f"W must be None or a linop.Diagonal, got {type(W)}.") diff --git a/scico/optimize/admm.py b/scico/optimize/admm.py index aa209206b..a71a730f4 100644 --- a/scico/optimize/admm.py +++ b/scico/optimize/admm.py @@ -23,7 +23,7 @@ from scico.diagnostics import IterationStats from scico.functional import Functional from scico.linop import CircularConvolve, Identity, LinearOperator -from scico.loss import SquaredL2Loss, WeightedSquaredL2Loss +from scico.loss import SquaredL2Loss from scico.numpy.linalg import norm from scico.solver import cg as scico_cg from scico.solver import minimize @@ -120,12 +120,11 @@ class LinearSubproblemSolver(SubproblemSolver): for the case where :code:`f` is an :math:`\ell_2` or weighted :math:`\ell_2` norm, and :code:`f.A` is a linear operator, so that the subproblem involves solving a linear equation. This requires that - ``f.functional`` be an instance of either :class:`.SquaredL2Loss` - or :class:`.WeightedSquaredL2Loss` and for the forward operator - :code:`f.A` to be an instance of :class:`.LinearOperator`. + ``f.functional`` be an instance of :class:`.SquaredL2Loss` and for + the forward operator :code:`f.A` to be an instance of + :class:`.LinearOperator`. - In the case :class:`.WeightedSquaredL2Loss`, the - :math:`\mb{x}`-update step is + The :math:`\mb{x}`-update step is .. math:: @@ -133,9 +132,9 @@ class LinearSubproblemSolver(SubproblemSolver): \norm{\mb{y} - A x}_W^2 + \sum_i \frac{\rho_i}{2} \norm{\mb{z}^{(k)}_i - \mb{u}^{(k)}_i - C_i \mb{x}}_2^2 \;, - where :math:`W` is the weighting :class:`.Diagonal` from the - :class:`.WeightedSquaredL2Loss` instance. This update step - reduces to the solution of the linear system + where :math:`W` a weighting :class:`.Diagonal` operator + or an :class:`.Identity` operator (i.e. no weighting). + This update step reduces to the solution of the linear system .. math:: @@ -144,8 +143,6 @@ class LinearSubproblemSolver(SubproblemSolver): A^H W \mb{y} + \sum_{i=1}^N \rho_i C_i^H ( \mb{z}^{(k)}_i - \mb{u}^{(k)}_i) \;. - In the case of :class:`.SquaredL2Loss`, :math:`W` is set to - the :class:`Identity` operator. Attributes: admm (:class:`.ADMM`): ADMM solver object to which the solver is @@ -201,10 +198,10 @@ def internal_init(self, admm): f"LinearSubproblemSolver requires f.A to be a scico.linop.LinearOperator; " f"got {type(admm.f.A)}" ) - if not isinstance(admm.f, WeightedSquaredL2Loss): # SquaredL2Loss is subclass + if not isinstance(admm.f, SquaredL2Loss): raise ValueError( - f"LinearSubproblemSolver requires f to be a scico.loss.WeightedSquaredL2Loss" - f"or scico.loss.SquaredL2Loss; got {type(admm.f)}" + "LinearSubproblemSolver requires f to be a scico.loss.SquaredL2Loss; " + f"got {type(admm.f)}" ) super().internal_init(admm) diff --git a/scico/test/linop/test_radon_svmbir.py b/scico/test/linop/test_radon_svmbir.py index b1f6b6c9d..440e03a58 100644 --- a/scico/test/linop/test_radon_svmbir.py +++ b/scico/test/linop/test_radon_svmbir.py @@ -7,7 +7,7 @@ import scico import scico.numpy as snp from scico.linop import Diagonal -from scico.loss import WeightedSquaredL2Loss +from scico.loss import SquaredL2Loss from scico.test.linop.test_linop import adjoint_test from scico.test.prox import prox_test @@ -154,7 +154,7 @@ def test_prox_cg(Nx, Ny, num_angles, num_channels, is_3d, weight_type, center_of else: f_sv = SVMBIRWeightedSquaredL2Loss(y=y, A=A, W=Diagonal(W)) - f_wg = WeightedSquaredL2Loss(y=y, A=A, W=Diagonal(W)) + f_wg = SquaredL2Loss(y=y, A=A, W=Diagonal(W)) v, _ = scico.random.normal(im.shape, dtype=im.dtype) v *= im.max() * 0.5 diff --git a/scico/test/optimize/test_admm.py b/scico/test/optimize/test_admm.py index 905b80fe9..a9a1c7529 100644 --- a/scico/test/optimize/test_admm.py +++ b/scico/test/optimize/test_admm.py @@ -197,9 +197,7 @@ def test_admm_quadratic(self): maxiter = 100 蟻 = 1e0 A = linop.MatrixOperator(self.Amx) - f = loss.WeightedSquaredL2Loss( - y=self.y, A=A, W=linop.Diagonal(self.W[:, 0]), scale=self.饾浖 / 2.0 - ) + f = loss.SquaredL2Loss(y=self.y, A=A, W=linop.Diagonal(self.W[:, 0]), scale=self.饾浖 / 2.0) g_list = [(self.位 / 2) * functional.SquaredL2Norm()] C_list = [linop.MatrixOperator(self.Bmx)] rho_list = [蟻] diff --git a/scico/test/test_loss.py b/scico/test/test_loss.py index 56f8e24d1..aebe15a1f 100644 --- a/scico/test/test_loss.py +++ b/scico/test/test_loss.py @@ -17,9 +17,6 @@ class TestLoss: - - l2loss = (loss.SquaredL2Loss, loss.WeightedSquaredL2Loss) - def setup_method(self): n = 4 dtype = np.float64 @@ -56,9 +53,8 @@ def test_generic_exception(self): with pytest.raises(NotImplementedError): L.prox(self.v, self.scalar) - @pytest.mark.parametrize("loss_class", l2loss) - def test_l2_loss(self, loss_class): - L = loss_class(y=self.y, A=self.Ao) + def test_squared_l2(self): + L = loss.SquaredL2Loss(y=self.y, A=self.Ao) assert L.has_eval assert L.has_prox @@ -71,7 +67,7 @@ def test_l2_loss(self, loss_class): assert cL(self.v) == self.scalar * L(self.v) # squared l2 loss with diagonal linop has a prox - L_d = loss_class(y=self.y, A=self.Do) + L_d = loss.SquaredL2Loss(y=self.y, A=self.Do) # test eval np.testing.assert_allclose(L_d(self.v), 0.5 * ((self.Do @ self.v - self.y) ** 2).sum()) @@ -88,7 +84,7 @@ def test_l2_loss(self, loss_class): pf = prox_test(self.v, L, L.prox, 0.75) def test_weighted_squared_l2(self): - L = loss.WeightedSquaredL2Loss(y=self.y, A=self.Ao, W=self.W) + L = loss.SquaredL2Loss(y=self.y, A=self.Ao, W=self.W) assert L.has_eval assert L.has_prox np.testing.assert_allclose( @@ -97,7 +93,7 @@ def test_weighted_squared_l2(self): pf = prox_test(self.v, L, L.prox, 0.75) # weighted l2 loss with diagonal linop has a prox - L_d = loss.WeightedSquaredL2Loss(y=self.y, A=self.Do, W=self.W) + L_d = loss.SquaredL2Loss(y=self.y, A=self.Do, W=self.W) assert L_d.has_eval assert L_d.has_prox np.testing.assert_allclose( @@ -124,8 +120,8 @@ def test_poisson(self): class TestAbsLoss: abs_loss = ( - (loss.WeightedSquaredL2AbsLoss, snp.abs), - (loss.WeightedSquaredL2AbsSquaredLoss, lambda x: snp.abs(x) ** 2), + (loss.SquaredL2AbsLoss, snp.abs), + (loss.SquaredL2AbsSquaredLoss, lambda x: snp.abs(x) ** 2), ) def setup_method(self): From 6887366f536d6ede22437b19ae60d16166dc1d02 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 2 Mar 2022 18:06:03 -0700 Subject: [PATCH 21/24] Rename SVMBIRWeightedSquaredL2Loss to SVMBIRSquaredL2Loss --- examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py | 6 +++--- .../scripts/ct_svmbir_ppp_bm3d_admm_prox.py | 6 +++--- examples/scripts/ct_svmbir_tv_multi.py | 4 ++-- scico/linop/radon_svmbir.py | 17 ++++++++--------- scico/test/linop/test_radon_svmbir.py | 12 ++++++------ 5 files changed, 22 insertions(+), 23 deletions(-) diff --git a/examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py b/examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py index 05554c861..3700dc111 100644 --- a/examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py +++ b/examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py @@ -17,7 +17,7 @@ This version uses the data fidelity term as the ADMM f, and thus the optimization with respect to the data fidelity uses CG rather than the -prox of the SVMBIRWeightedSquaredL2Loss functional. +prox of the SVMBIRSquaredL2Loss functional. """ import numpy as np @@ -32,7 +32,7 @@ from scico import metric, plot from scico.functional import BM3D, NonNegativeIndicator from scico.linop import Diagonal, Identity -from scico.linop.radon_svmbir import ParallelBeamProjector, SVMBIRWeightedSquaredL2Loss +from scico.linop.radon_svmbir import ParallelBeamProjector, SVMBIRSquaredL2Loss from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.util import device_info @@ -92,7 +92,7 @@ 蟻 = 15 # ADMM penalty parameter 蟽 = density * 0.18 # denoiser sigma -f = SVMBIRWeightedSquaredL2Loss(y=y, A=A, W=Diagonal(weights), scale=0.5) +f = SVMBIRSquaredL2Loss(y=y, A=A, W=Diagonal(weights), scale=0.5) g0 = 蟽 * 蟻 * BM3D() g1 = NonNegativeIndicator() diff --git a/examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py b/examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py index 43c1f614e..2131ca6f7 100644 --- a/examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py +++ b/examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py @@ -17,7 +17,7 @@ This version uses the data fidelity term as one of the ADMM g functionals, and thus the optimization with respect to the data fidelity is able to -exploit the internal prox of the SVMBIRWeightedSquaredL2Loss functional. +exploit the internal prox of the SVMBIRSquaredL2Loss functional. """ import numpy as np @@ -32,7 +32,7 @@ from scico import metric, plot from scico.functional import BM3D, NonNegativeIndicator from scico.linop import Diagonal, Identity -from scico.linop.radon_svmbir import ParallelBeamProjector, SVMBIRWeightedSquaredL2Loss +from scico.linop.radon_svmbir import ParallelBeamProjector, SVMBIRSquaredL2Loss from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.util import device_info @@ -92,7 +92,7 @@ 蟻 = 10 # ADMM penalty parameter 蟽 = density * 0.26 # denoiser sigma -f = SVMBIRWeightedSquaredL2Loss( +f = SVMBIRSquaredL2Loss( y=y, A=A, W=Diagonal(weights), scale=0.5, prox_kwargs={"maxiter": 5, "ctol": 0.0} ) g0 = 蟽 * 蟻 * BM3D() diff --git a/examples/scripts/ct_svmbir_tv_multi.py b/examples/scripts/ct_svmbir_tv_multi.py index 4accf6c8d..37988ca2a 100644 --- a/examples/scripts/ct_svmbir_tv_multi.py +++ b/examples/scripts/ct_svmbir_tv_multi.py @@ -24,7 +24,7 @@ import scico.numpy as snp from scico import functional, linop, metric, plot from scico.linop import Diagonal -from scico.linop.radon_svmbir import ParallelBeamProjector, SVMBIRWeightedSquaredL2Loss +from scico.linop.radon_svmbir import ParallelBeamProjector, SVMBIRSquaredL2Loss from scico.optimize import PDHG, LinearizedADMM from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.util import device_info @@ -84,7 +84,7 @@ 位 = 1e-1 # L1 norm regularization parameter -f = SVMBIRWeightedSquaredL2Loss(y=y, A=A, W=Diagonal(weights), scale=0.5) +f = SVMBIRSquaredL2Loss(y=y, A=A, W=Diagonal(weights), scale=0.5) g = 位 * functional.L21Norm() # regularization functional # The append=0 option makes the results of horizontal and vertical finite diff --git a/scico/linop/radon_svmbir.py b/scico/linop/radon_svmbir.py index 435695d41..0f0eed64c 100644 --- a/scico/linop/radon_svmbir.py +++ b/scico/linop/radon_svmbir.py @@ -38,9 +38,9 @@ class ParallelBeamProjector(LinearOperator): ``is_masked`` option selects whether a valid region for projections (pixels outside this region are ignored when performing the projection) is active. This region of validity is also respected by - :meth:`.SVMBIRWeightedSquaredL2Loss.prox` when - :class:`.SVMBIRWeightedSquaredL2Loss` is initialized with a - :class:`ParallelBeamProjector` with this option enabled. + :meth:`.SVMBIRSquaredL2Loss.prox` when :class:`.SVMBIRSquaredL2Loss` + is initialized with a :class:`ParallelBeamProjector` with this option + enabled. """ def __init__( @@ -297,7 +297,7 @@ def prox(self, v: JaxArray, lam: float, **kwargs) -> JaxArray: return jax.device_put(result.reshape(self.A.input_shape)) -class SVMBIRWeightedSquaredL2Loss(SVMBIRExtendedLoss, SquaredL2Loss): +class SVMBIRSquaredL2Loss(SVMBIRExtendedLoss, SquaredL2Loss): r"""Weighted squared :math:`\ell_2` loss with svmbir CT projector. Weighted squared :math:`\ell_2` loss of a CT reconstruction problem, @@ -307,8 +307,8 @@ class SVMBIRWeightedSquaredL2Loss(SVMBIRExtendedLoss, SquaredL2Loss): \alpha \left(\mb{y} - A(\mb{x})\right)^T W \left(\mb{y} - A(\mb{x})\right) \;, - where :math:`A` is a :class:`.ParallelBeamProjector`, - :math:`\alpha` is the scaling parameter and :math:`W` is an instance + where :math:`A` is a :class:`.ParallelBeamProjector`, :math:`\alpha` + is the scaling parameter and :math:`W` is an instance of :class:`scico.linop.Diagonal`. If :math:`W` is ``None``, it is set to :class:`scico.linop.Identity`. """ @@ -319,7 +319,7 @@ def __init__( prox_kwargs: Optional[dict] = None, **kwargs, ): - r"""Initialize a :class:`SVMBIRWeightedSquaredL2Loss` object. + r"""Initialize a :class:`SVMBIRSquaredL2Loss` object. Args: y: Sinogram measurement. @@ -335,8 +335,7 @@ def __init__( if self.A.is_masked: raise ValueError( - "is_masked must be false for the ParallelBeamProjector in " - "SVMBIRWeightedSquaredL2Loss." + "is_masked must be false for the ParallelBeamProjector in " "SVMBIRSquaredL2Loss." ) diff --git a/scico/test/linop/test_radon_svmbir.py b/scico/test/linop/test_radon_svmbir.py index 440e03a58..569e13b0f 100644 --- a/scico/test/linop/test_radon_svmbir.py +++ b/scico/test/linop/test_radon_svmbir.py @@ -17,7 +17,7 @@ from scico.linop.radon_svmbir import ( ParallelBeamProjector, SVMBIRExtendedLoss, - SVMBIRWeightedSquaredL2Loss, + SVMBIRSquaredL2Loss, ) except ImportError as e: pytest.skip("svmbir not installed", allow_module_level=True) @@ -98,7 +98,7 @@ def test_prox(Nx, Ny, num_angles, num_channels, is_3d, center_offset, is_masked) if is_masked: f = SVMBIRExtendedLoss(y=sino, A=A, positivity=False) else: - f = SVMBIRWeightedSquaredL2Loss(y=sino, A=A) + f = SVMBIRSquaredL2Loss(y=sino, A=A) prox_test(v, f, f.prox, alpha=0.25) @@ -123,7 +123,7 @@ def test_prox_weights(Nx, Ny, num_angles, num_channels, is_3d, center_offset, is if is_masked: f = SVMBIRExtendedLoss(y=sino, A=A, W=W, positivity=False) else: - f = SVMBIRWeightedSquaredL2Loss(y=sino, A=A, W=W) + f = SVMBIRSquaredL2Loss(y=sino, A=A, W=W) prox_test(v, f, f.prox, alpha=0.25) @@ -152,7 +152,7 @@ def test_prox_cg(Nx, Ny, num_angles, num_channels, is_3d, weight_type, center_of if is_masked: f_sv = SVMBIRExtendedLoss(y=y, A=A, W=Diagonal(W), positivity=False) else: - f_sv = SVMBIRWeightedSquaredL2Loss(y=y, A=A, W=Diagonal(W)) + f_sv = SVMBIRSquaredL2Loss(y=y, A=A, W=Diagonal(W)) f_wg = SquaredL2Loss(y=y, A=A, W=Diagonal(W)) @@ -187,7 +187,7 @@ def test_approx_prox( if is_masked or positivity: f = SVMBIRExtendedLoss(y=y, A=A, W=Diagonal(W), positivity=positivity) else: - f = SVMBIRWeightedSquaredL2Loss(y=y, A=A, W=Diagonal(W)) + f = SVMBIRSquaredL2Loss(y=y, A=A, W=Diagonal(W)) xprox = snp.array(f.prox(v, lam=位)) @@ -196,7 +196,7 @@ def test_approx_prox( y=y, A=A, W=Diagonal(W), prox_kwargs={"maxiter": 2}, positivity=positivity ) else: - f_approx = SVMBIRWeightedSquaredL2Loss(y=y, A=A, W=Diagonal(W), prox_kwargs={"maxiter": 2}) + f_approx = SVMBIRSquaredL2Loss(y=y, A=A, W=Diagonal(W), prox_kwargs={"maxiter": 2}) xprox_approx = snp.array(f_approx.prox(v, lam=位, v0=xprox)) From e620bf01dd466112e971635002b8d9764157078c Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 2 Mar 2022 18:10:10 -0700 Subject: [PATCH 22/24] Update submodule --- data | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data b/data index 71f16aa2c..031b1d2e3 160000 --- a/data +++ b/data @@ -1 +1 @@ -Subproject commit 71f16aa2ca6e67fa8f6e35006e1a572386306bde +Subproject commit 031b1d2e3717067428979090feb59b23c15334a0 From 393a88abed6841f6cbfe6aee8d0db9a10f4fdfa6 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 2 Mar 2022 18:13:07 -0700 Subject: [PATCH 23/24] Rename loss --- scico/loss.py | 2 +- scico/test/test_loss.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/scico/loss.py b/scico/loss.py index 3cd44c440..26db84669 100644 --- a/scico/loss.py +++ b/scico/loss.py @@ -410,7 +410,7 @@ def _dep_cubic_root(p, q): return r -class SquaredL2AbsSquaredLoss(Loss): +class SquaredL2SquaredAbsLoss(Loss): r"""Weighted squared :math:`\ell_2` with squared absolute value loss. Weighted squared :math:`\ell_2` with squared absolute value loss diff --git a/scico/test/test_loss.py b/scico/test/test_loss.py index aebe15a1f..d286e0064 100644 --- a/scico/test/test_loss.py +++ b/scico/test/test_loss.py @@ -121,7 +121,7 @@ class TestAbsLoss: abs_loss = ( (loss.SquaredL2AbsLoss, snp.abs), - (loss.SquaredL2AbsSquaredLoss, lambda x: snp.abs(x) ** 2), + (loss.SquaredL2SquaredAbsLoss, lambda x: snp.abs(x) ** 2), ) def setup_method(self): From 3d9e71509ea60a082eec09ecbb08326214994e9b Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 2 Mar 2022 18:16:57 -0700 Subject: [PATCH 24/24] Update submodule --- data | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data b/data index 031b1d2e3..d7b0478e3 160000 --- a/data +++ b/data @@ -1 +1 @@ -Subproject commit 031b1d2e3717067428979090feb59b23c15334a0 +Subproject commit d7b0478e3769f0cb8ec72b7cf936eb2a997ee8f2