
Problem with W^1/2 weight exponent #78

Merged Nov 12, 2021 (22 commits).
Changes shown below are from 6 of the 22 commits.
examples/scripts/ct_astra_weighted_tv_admm.py (1 addition, 1 deletion)
@@ -139,7 +139,7 @@ def postprocess(x):

weights = counts / Io # scale by Io to balance the data vs regularization term
W = linop.Diagonal(snp.sqrt(weights))
f = loss.WeightedSquaredL2Loss(y=y, A=A, weight_op=W)
f = loss.WeightedSquaredL2Loss(y=y, A=A, W=W)

admm_weighted = ADMM(
f=f,
examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py (1 addition, 3 deletions)
@@ -91,9 +91,7 @@
ρ = 100 # ADMM penalty parameter
σ = density * 0.2 # denoiser sigma

weight_op = Diagonal(weights ** 0.5)

f = SVMBIRWeightedSquaredL2Loss(y=y, A=A, weight_op=weight_op, scale=0.5)
f = SVMBIRWeightedSquaredL2Loss(y=y, A=A, W=Diagonal(weights), scale=0.5)
g0 = σ * ρ * BM3D()
g1 = NonNegativeIndicator()

examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py (1 addition, 3 deletions)
@@ -91,9 +91,7 @@
ρ = 10 # ADMM penalty parameter
σ = density * 0.26 # denoiser sigma

weight_op = Diagonal(weights ** 0.5)

f = SVMBIRWeightedSquaredL2Loss(y=y, A=A, weight_op=weight_op, scale=0.5)
f = SVMBIRWeightedSquaredL2Loss(y=y, A=A, W=Diagonal(weights), scale=0.5)
g0 = σ * ρ * BM3D()
g1 = NonNegativeIndicator()

scico/admm.py (1 addition, 1 deletion)
@@ -208,7 +208,7 @@ def compute_rhs(self) -> Union[JaxArray, BlockArray]:

if self.admm.f is not None:
if isinstance(self.admm.f, WeightedSquaredL2Loss):
ATWy = self.admm.f.A.adj(self.admm.f.weight_op @ self.admm.f.y)
ATWy = self.admm.f.A.adj(self.admm.f.W.diagonal @ self.admm.f.y)
rhs += 2.0 * self.admm.f.scale * ATWy
else:
ATy = self.admm.f.A.adj(self.admm.f.y)
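For reference, a sketch of where the `2.0 * scale * ATWy` term comes from. With the data term written as scale · (y - A x)* W (y - A x), the ADMM x-step has normal equations of the form below; the regularizer splitting terms (C_i, ρ_i, z_i, u_i) are assumed from the usual scaled-form ADMM update and are not visible in this hunk.

```latex
\Big( 2\,\mathrm{scale}\, A^* W A + \textstyle\sum_i \rho_i\, C_i^* C_i \Big)\, \mathbf{x}
  \;=\; 2\,\mathrm{scale}\, A^* W \mathbf{y}
  \;+\; \textstyle\sum_i \rho_i\, C_i^* \big( \mathbf{z}_i - \mathbf{u}_i \big)
```

So the data-fidelity contribution to the right-hand side is exactly `2.0 * self.admm.f.scale * ATWy`, with `ATWy` the adjoint of A applied to the weighted measurements.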
scico/linop/radon_svmbir.py (2 additions, 11 deletions)
@@ -28,7 +28,7 @@
from scico.loss import WeightedSquaredL2Loss
from scico.typing import JaxArray, Shape

from ._linop import Diagonal, LinearOperator
from ._linop import LinearOperator


class ParallelBeamProjector(LinearOperator):
@@ -120,21 +120,12 @@ def __init__(self, *args, **kwargs):
"to instantiate a `SVMBIRWeightedSquaredL2Loss`."
)

if not isinstance(self.weight_op, Diagonal):
raise ValueError(
f"`weight_op` must be `Diagonal` but instead got {type(self.weight_op)}"
)

self.weights = (
snp.conj(self.weight_op.diagonal) * self.weight_op.diagonal
) # because weight_op is W^{1/2}

self.has_prox = True

def prox(self, v: JaxArray, lam: float) -> JaxArray:
v = v.reshape(self.A.svmbir_input_shape)
y = self.y.reshape(self.A.svmbir_output_shape)
weights = self.weights.reshape(self.A.svmbir_output_shape)
weights = self.W.diagonal.reshape(self.A.svmbir_output_shape)
sigma_p = snp.sqrt(lam)
result = svmbir.recon(
np.array(y),
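The removed lines recovered the full weights by squaring the diagonal of `weight_op` (which held W^{1/2}); with `W` now holding the weights themselves, its diagonal is simply reshaped and passed to `svmbir.recon`. A toy NumPy illustration of the equivalence (standalone, not the scico API):

```python
import numpy as np

w = np.random.default_rng(0).uniform(0.1, 2.0, size=12)  # full weights, W.diagonal
w_half = np.sqrt(w)                                       # old convention: weight_op held W^{1/2}

# old code: weights = conj(w_half) * w_half;  new code: weights = W.diagonal
assert np.allclose(np.conj(w_half) * w_half, w)
```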
scico/loss.py (44 additions, 31 deletions)
@@ -119,6 +119,13 @@ def __init__(
A: Optional[Union[Callable, operator.Operator]] = None,
scale: float = 0.5,
):
r"""Initialize a :class:`SquaredL2Loss` object.

Args:
y : Measurements.
A : Forward operator. If None, defaults to :class:`.Identity`.
scale : Scaling parameter.
"""
y = ensure_on_device(y)
self.functional = functional.SquaredL2Norm()
super().__init__(y=y, A=A, scale=scale)
@@ -140,7 +147,7 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
Args:
x : Point at which to evaluate loss.
"""
return self.scale * self.functional(self.y - self.A(x))
return self.scale * (snp.abs(self.y - self.A(x)) ** 2).sum()

Collaborator:
Is this change desirable? Dropping the use of self.functional seems like it may have consequences for derived classes.

Contributor (author):
I don't think it does. Does it?
I have this to keep it consistent with the weighted case, where we can avoid computing the square root of the weights.

Collaborator:
Understood, and that's a worthwhile goal, but we should make sure there aren't any undesirable consequences before we make this change.

Contributor:
I don't mind this change for the time being.
Longer term, consider: why is functional a property at all if it is not used here? Broadly, I think this discussion is a symptom of the existence of Loss. One might hope to implement __call__ at the level of Loss (and therefore using self.functional) to reduce repeated code. But we can't do that, because Loss is so general that it doesn't know it should be A @ x - y.

Collaborator:
Agreed: if this change is made, we should consider removing the functional attribute. With respect to Loss, do you recall why it's so general? If there's a good reason, perhaps we should have a specialization that really is A @ x - y?

Contributor:
I think the reason it is general is because of the Poisson loss. I think the reason it exists at all is that it used to not be a subclass of Functional, and therefore having a base Loss class made sense.
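The consistency the author refers to: weighting the squared residual elementwise gives the same value as squaring after applying W^{1/2}, so the weighted loss never needs the square root of the weights. A standalone NumPy check (not the scico API):

```python
import numpy as np

rng = np.random.default_rng(0)
r = rng.standard_normal(16) + 1j * rng.standard_normal(16)  # residual y - A(x)
w = rng.uniform(0.1, 2.0, size=16)                          # non-negative weights

direct = (w * np.abs(r) ** 2).sum()             # new form: weights applied to |r|^2
via_sqrt = (np.abs(np.sqrt(w) * r) ** 2).sum()  # old form: apply W^{1/2}, then square
assert np.allclose(direct, via_sqrt)
```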

def prox(self, x: Union[JaxArray, BlockArray], lam: float) -> Union[JaxArray, BlockArray]:
if isinstance(self.A, linop.Diagonal):
@@ -154,7 +161,7 @@ def prox(self, x: Union[JaxArray, BlockArray], lam: float) -> Union[JaxArray, BlockArray]:
@property
def hessian(self) -> linop.LinearOperator:
r"""If ``self.A`` is a :class:`.LinearOperator`, returns a new :class:`.LinearOperator` corresponding
to Hessian :math:`\mathrm{A^*A}`.
to Hessian :math:`2 \mathrm{scale} \cdot \mathrm{A^* A}`.

Otherwise not implemented.
"""
@@ -171,11 +178,11 @@ class WeightedSquaredL2Loss(Loss):
Weighted squared :math:`\ell_2` loss.

.. math::
\mathrm{scale} \cdot \norm{\mb{y} - A(\mb{x})}_{\mathrm{W}}^2 =
\mathrm{scale} \cdot \norm{\mathrm{W}^{1/2} \left( \mb{y} - A(\mb{x})\right)}_2^2\;
\mathrm{scale} \cdot \norm{\mb{y} - A(\mb{x})}_W^2 =
\mathrm{scale} \cdot \left(\mb{y} - A(\mb{x})\right)^T W \left(\mb{y} - A(\mb{x})\right)\;

Where :math:`\mathrm{W}` is an instance of :class:`scico.linop.LinearOperator`. If
:math:`\mathrm{W}` is None, reverts to the behavior of :class:`.SquaredL2Loss`.
Where :math:`W` is an instance of :class:`scico.linop.Diagonal`. If
:math:`W` is None, reverts to the behavior of :class:`.SquaredL2Loss`.

"""

@@ -184,30 +191,33 @@ def __init__(
y: Union[JaxArray, BlockArray],
A: Optional[Union[Callable, operator.Operator]] = None,
scale: float = 0.5,
weight_op: Optional[operator.Operator] = None,
W: Optional[linop.Diagonal] = None,
):

r"""Initialize a :class:`WeightedSquaredL2Loss` object.

Args:
y : Measurements
y : Measurements.
Collaborator:
Nitpick: change to "Measurement."

Contributor (author):
Fixed in 7db54df.

A : Forward operator. If None, defaults to :class:`.Identity`.
scale : Scaling parameter
weight_op: Weighting linear operator. Corresponds to :math:`W^{1/2}`
in the standard definition of the weighted squared :math:`\ell_2` loss.
scale : Scaling parameter.
W: Weighting diagonal operator. Must be non-negative.
If None, defaults to :class:`.Identity`.
"""
y = ensure_on_device(y)

self.weight_op: operator.Operator
self.W: linop.Diagonal

self.functional = functional.SquaredL2Norm()
if weight_op is None:
self.weight_op = linop.Identity(y.shape)
elif isinstance(weight_op, linop.LinearOperator):
self.weight_op = weight_op
if W is None:
self.W = linop.Identity(y.shape)
elif isinstance(W, linop.Diagonal):
if snp.all(W.diagonal >= 0):
self.W = W
else:
raise Exception(f"The weights, W.diagonal, must be non-negative.")
else:
raise TypeError(f"weight_op must be None or a LinearOperator, got {type(weight_op)}")
raise TypeError(f"W must be None or a linop.Diagonal, got {type(W)}")

super().__init__(y=y, A=A, scale=scale)

if isinstance(A, operator.Operator):
@@ -218,40 +228,43 @@ def __init__(
if isinstance(self.A, linop.LinearOperator):
self.is_quadratic = True

if isinstance(self.A, linop.Diagonal) and isinstance(self.weight_op, linop.Diagonal):
if isinstance(self.A, linop.Diagonal) and isinstance(self.W, linop.Diagonal):
self.has_prox = True

def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
return self.scale * self.functional(self.weight_op(self.y - self.A(x)))
return self.scale * (self.W.diagonal * snp.abs(self.y - self.A(x)) ** 2).sum()

Collaborator:
See earlier comment on similar lines.

Contributor (author):
CI seems to be failing because the changes would slightly reduce the test coverage percentage. Let's see whether that is still the case when we add the test for the Hessian.

Collaborator:
Still unhappy, it seems. It would be best to address this before we merge.

Collaborator:
I added a test for loss.PoissonLoss which will, I assume, resolve this.
@Michael-T-McCann: Would you not agree that the Loss tests should be in a separate test_loss.py file rather than included in test_functional.py?

Contributor:
Loss is currently a subclass of Functional (it was not always this way), so I think it makes sense for the losses to be tested in test_functional.py, unless the file becomes much too long.

def prox(self, x: Union[JaxArray, BlockArray], lam: float) -> Union[JaxArray, BlockArray]:
if isinstance(self.A, linop.Diagonal):
c = self.scale * lam
c = 2.0 * self.scale * lam
A = self.A.diagonal
W = self.weight_op.diagonal
lhs = c * 2.0 * A.conj() * W * W.conj() * self.y + x
ATWTWA = c * 2.0 * A.conj() * W.conj() * W * A
return lhs / (ATWTWA + 1.0)
W = self.W.diagonal
lhs = c * A.conj() * W * self.y + x
ATWA = c * A.conj() * W * A
return lhs / (ATWA + 1.0)
else:
raise NotImplementedError
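A short derivation of the diagonal-case formula above, which is where the factor `c = 2.0 * self.scale * lam` comes from: the prox is the minimizer of

```latex
\mathrm{prox}_{\lambda f}(\mathbf{v})
  = \arg\min_{\mathbf{x}} \;
    \lambda\, \mathrm{scale}\, (\mathbf{y} - A \mathbf{x})^* W (\mathbf{y} - A \mathbf{x})
    + \tfrac{1}{2} \| \mathbf{x} - \mathbf{v} \|_2^2
```

and setting the gradient to zero, with diagonal A and W, gives elementwise

```latex
\big( c\, \bar{A} W A + 1 \big)\, \mathbf{x} = c\, \bar{A} W \mathbf{y} + \mathbf{v}
\quad \Longrightarrow \quad
\mathbf{x} = \frac{c\, \bar{A} W \mathbf{y} + \mathbf{v}}{c\, \bar{A} W A + 1},
\qquad c = 2\, \mathrm{scale}\, \lambda
```

which matches `lhs / (ATWA + 1.0)` in the updated code. The old version computed the same quantity via `W * W.conj()` because `weight_op` held W^{1/2}.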

@property
def hessian(self) -> linop.LinearOperator:
r"""If ``self.A`` is a :class:`scico.linop.LinearOperator`, returns a
:class:`scico.linop.LinearOperator` corresponding to Hessian :math:`\mathrm{A^* W A}`.
:class:`scico.linop.LinearOperator` corresponding to the Hessian
:math:`2 \mathrm{scale} \cdot \mathrm{A^* W A}`.

Otherwise not implemented.
"""
if isinstance(self.A, linop.LinearOperator):
A = self.A
W = self.W
if isinstance(A, linop.LinearOperator):
return linop.LinearOperator(
input_shape=self.A.input_shape,
output_shape=self.A.input_shape,
eval_fn=lambda x: 2 * self.scale * self.A.adj(self.weight_op(self.A(x))),
adj_fn=lambda x: 2 * self.scale * self.A.adj(self.weight_op(self.A(x))),
input_shape=A.input_shape,
output_shape=A.input_shape,
eval_fn=lambda x: 2 * self.scale * A.adj(W(A(x))),
adj_fn=lambda x: 2 * self.scale * A.adj(W(A(x))),
)
else:
raise NotImplementedError(
f"Hessian is not implemented for {type(self)} when `A` is {type(self.A)}; must be LinearOperator"
f"Hessian is not implemented for {type(self)} when `A` is {type(A)}; must be LinearOperator"
)
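As a sanity check on the 2 · scale · A*WA form of the Hessian, a standalone finite-difference comparison (plain NumPy, real-valued, not the scico API):

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 6, 4
A = rng.standard_normal((m, n))
w = rng.uniform(0.5, 1.5, size=m)   # diagonal of W, non-negative
y = rng.standard_normal(m)
scale = 0.5

def grad(x):
    # gradient of scale * (y - A x)^T W (y - A x)
    return -2.0 * scale * A.T @ (w * (y - A @ x))

H = 2.0 * scale * A.T @ (w[:, None] * A)   # closed-form Hessian: 2 * scale * A^T W A

x0 = rng.standard_normal(n)
v = rng.standard_normal(n)
eps = 1e-6
hv_fd = (grad(x0 + eps * v) - grad(x0 - eps * v)) / (2 * eps)  # finite-difference H @ v
assert np.allclose(H @ v, hv_fd, atol=1e-5)
```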


scico/test/linop/test_radon_svmbir.py (2 additions, 2 deletions)
@@ -89,6 +89,6 @@ def test_prox_weights(Nx, Ny, num_angles, num_channels, is_3d):

# test with weights
weights, _ = scico.random.uniform(sino.shape, dtype=im.dtype)
D = scico.linop.Diagonal(weights)
f = SVMBIRWeightedSquaredL2Loss(y=sino, A=A, weight_op=D)
W = scico.linop.Diagonal(weights)
f = SVMBIRWeightedSquaredL2Loss(y=sino, A=A, W=W)
prox_test(v, f, f.prox, alpha=0.25)
scico/test/test_functional.py (5 additions, 6 deletions)
@@ -339,7 +339,7 @@ def setup_method(self):
W = 0.1 * W + 1.0
self.Ao = linop.MatrixOperator(A)
self.Do = linop.Diagonal(D)
self.Wo = linop.Diagonal(W)
self.W = linop.Diagonal(W)
self.y, key = randn((n,), key=key, dtype=dtype)
self.v, key = randn((n,), key=key, dtype=dtype) # point for prox eval
scalar, key = randn((1,), key=key, dtype=dtype)
@@ -377,14 +377,14 @@ def test_squared_l2(self):
pf = prox_test(self.v, L_d, L_d.prox, 0.75)

def test_weighted_squared_l2(self):
L = loss.WeightedSquaredL2Loss(y=self.y, A=self.Ao, weight_op=self.Wo)
L = loss.WeightedSquaredL2Loss(y=self.y, A=self.Ao, W=self.W)
assert L.is_smooth == True
assert L.has_eval == True
assert L.has_prox == False # not diagonal

# test eval
np.testing.assert_allclose(
L(self.v), 0.5 * ((self.Wo @ (self.Ao @ self.v - self.y)) ** 2).sum()
L(self.v), 0.5 * (self.W @ (self.Ao @ self.v - self.y) ** 2).sum()
)

cL = self.scalar * L
@@ -393,16 +393,15 @@ def test_weighted_squared_l2(self):
assert cL(self.v) == self.scalar * L(self.v)

# SquaredL2 with Diagonal linop has a prox
Wo = self.Wo
L_d = loss.WeightedSquaredL2Loss(y=self.y, A=self.Do, weight_op=Wo)
L_d = loss.WeightedSquaredL2Loss(y=self.y, A=self.Do, W=self.W)

assert L_d.is_smooth == True
assert L_d.has_eval == True
assert L_d.has_prox == True

# test eval
np.testing.assert_allclose(
L_d(self.v), 0.5 * ((self.Wo @ (self.Do @ self.v - self.y)) ** 2).sum()
L_d(self.v), 0.5 * (self.W @ (self.Do @ self.v - self.y) ** 2).sum()
)

cL = self.scalar * L_d
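Finally, a quick end-to-end sketch of the new keyword, assuming the post-PR scico API exactly as shown in the diff above (Diagonal forward operator so the closed-form prox applies; numeric values are arbitrary):

```python
import numpy as np
import jax.numpy as jnp
from scico import linop, loss

rng = np.random.default_rng(0)
n = 8
y = jnp.array(rng.standard_normal(n))
v = jnp.array(rng.standard_normal(n))
a = jnp.array(rng.standard_normal(n))          # diagonal of the forward operator A
w = jnp.array(rng.uniform(0.5, 1.5, size=n))   # non-negative weights (diagonal of W)

A = linop.Diagonal(a)
W = linop.Diagonal(w)
f = loss.WeightedSquaredL2Loss(y=y, A=A, W=W)  # new keyword: full weights, no sqrt

# closed-form prox for diagonal A and W, matching the code in this PR
lam = 0.3
c = 2.0 * f.scale * lam
x_ref = (c * a * w * y + v) / (c * a * w * a + 1.0)
assert jnp.allclose(f.prox(v, lam), x_ref, atol=1e-5)
```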