
Problem with W^1/2 weight exponent #78

Merged
merged 22 commits into main from thilo/svmbir_weight_fix on Nov 12, 2021

Conversation

tbalke
Contributor

@tbalke tbalke commented Nov 9, 2021

Initially I thought svmbir and scico use different notation for weighted norms, so I ignored this. Now, on second thought, I think we need to remove the squaring of the weights.

Our weight_op is defined so that ||y||_W^2 = ||W^(1/2) y||^2 = y^T W y, so there is no need to square the weights before passing them to svmbir. (Admittedly, the notation in the svmbir docs was not really clear until I updated it just now.)
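As a quick sanity check of the identity ||W^(1/2) y||^2 = y^T W y, here is a minimal numeric sketch with diagonal weights (illustrative only, not part of this PR):

    import numpy as np

    rng = np.random.default_rng(0)
    y = rng.standard_normal(5)
    w = rng.random(5) + 0.1               # positive diagonal weights, i.e. diag(W)

    lhs = np.sum((np.sqrt(w) * y) ** 2)   # ||W^(1/2) y||^2
    rhs = y @ (np.diag(w) @ y)            # y^T W y
    assert np.allclose(lhs, rhs)          # so svmbir should receive w, not w**2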

In the example scripts I removed the square root operation on the weights to compensate for the removed squaring. This, however, yields different regularization for the cg example, so it requires retuning. I suggest returning to this later, since I expect more changes to the svmbir interface in the near future that would require another retuning.

@tbalke tbalke added the bug (Something isn't working) and priority: high (High priority) labels Nov 9, 2021
@tbalke
Contributor Author

tbalke commented Nov 9, 2021

It fails the test_prox_weights test now, which may not be a bad thing, as we still think the prox is incorrect due to the rot_radius issue. Opinions?

@tbalke tbalke changed the title Problem with {W}^1/2 weight exponent Problem with W^1/2 weight exponent Nov 9, 2021
@bwohlberg
Collaborator

It fails the test_prox_weights test now, which may not be a bad thing, as we still think the prox is incorrect due to the rot_radius issue. Opinions?

If that were the explanation, then I would expect that the test without weights would also be failing. Perhaps it's simply that the error due to unclear svmbir docs was also propagated into the tests?
@Michael-T-McCann : since you wrote the tests for this (I think), could you please take a look?

@Michael-T-McCann
Contributor

The relevant documentation is https://svmbir.readthedocs.io/en/latest/theory.html and https://scico.readthedocs.io/en/latest/_autosummary/scico.loss.html#scico.loss.WeightedSquaredL2Loss. What I see is that SVMBIR and SCICO use the same definition for the weighted norm, but in SVMBIR you pass the weight matrix ("$\Lambda$ corresponds to weights") and in SCICO you pass the square root of the weight matrix ("weight_op ... Corresponds to $W^{1/2}$").

Hence

    self.weights = (
        snp.conj(self.weight_op.diagonal) * self.weight_op.diagonal
    )  # because weight_op is W^{1/2}

in our wrapper.

@tbalke
Contributor Author

tbalke commented Nov 9, 2021

Now I am confused. Are we really handling this consistently?

In class WeightedSquaredL2Loss(Loss), on one hand we have

scico/scico/loss.py

Lines 227 to 236 in 3852542

    def prox(self, x: Union[JaxArray, BlockArray], lam: float) -> Union[JaxArray, BlockArray]:
        if isinstance(self.A, linop.Diagonal):
            c = self.scale * lam
            A = self.A.diagonal
            W = self.weight_op.diagonal
            lhs = c * 2.0 * A.conj() * W * W.conj() * self.y + x
            ATWTWA = c * 2.0 * A.conj() * W.conj() * W * A
            return lhs / (ATWTWA + 1.0)
        else:
            raise NotImplementedError

So the part lhs = c * 2.0 * A.conj() * W * W.conj() * self.y + x seems to align with @Michael-T-McCann's assessment.
But then we have

scico/scico/loss.py

Lines 239 to 255 in 3852542

    def hessian(self) -> linop.LinearOperator:
        r"""If ``self.A`` is a :class:`scico.linop.LinearOperator`, returns a
        :class:`scico.linop.LinearOperator` corresponding to Hessian :math:`\mathrm{A^* W A}`.
        Otherwise not implemented.
        """
        if isinstance(self.A, linop.LinearOperator):
            return linop.LinearOperator(
                input_shape=self.A.input_shape,
                output_shape=self.A.input_shape,
                eval_fn=lambda x: 2 * self.scale * self.A.adj(self.weight_op(self.A(x))),
                adj_fn=lambda x: 2 * self.scale * self.A.adj(self.weight_op(self.A(x))),
            )
        else:
            raise NotImplementedError(
                f"Hessian is not implemented for {type(self)} when `A` is {type(self.A)}; must be LinearOperator"
            )

so the line eval_fn=lambda x: 2 * self.scale * self.A.adj(self.weight_op(self.A(x))), seems to instead use the other version.

In other words we use A^T W W^T y in one case and A^T W A x in the other, which looks inconsistent.

@Michael-T-McCann
Contributor

Now I am confused. Are we really handling this consistently?

You may be on to something. What would be convincing to me is if you derived, e.g., the Hessian by hand, then wrote a test showing that our Hessian function returns the wrong thing.
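For instance, such a test could compare the analytic Hessian against an autodiff Hessian of the loss. A rough sketch of the idea, with made-up array values and a diagonal A (not the actual test code):

    import jax
    import jax.numpy as jnp

    a = jnp.array([1.0, 2.0, 3.0])   # diagonal of A
    w = jnp.array([0.5, 1.0, 2.0])   # diagonal of W (the weights themselves, not their square root)
    y = jnp.array([1.0, -1.0, 0.5])
    scale = 0.5

    def loss(x):
        # scale * ||Ax - y||_W^2 for diagonal A and W
        return scale * jnp.sum(w * jnp.abs(y - a * x) ** 2)

    x0 = jnp.array([0.3, -0.2, 0.1])
    H_autodiff = jax.hessian(loss)(x0)
    H_expected = jnp.diag(2 * scale * a * w * a)   # hand-derived 2a A^T W A
    assert jnp.allclose(H_autodiff, H_expected)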

@tbalke
Contributor Author

tbalke commented Nov 9, 2021

I am starting to believe that my changes in this PR are incorrect, and you were right, @Michael-T-McCann.

(1) Let's start from the loss being a ||Ax - y||^2_W, where a is the scale. Then according to

scico/scico/loss.py

Lines 224 to 225 in 3852542

    def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
        return self.scale * self.functional(self.weight_op(self.y - self.A(x)))

we have that weight_op.diagonal = W^(1/2).

(2) The prox in

scico/scico/loss.py

Lines 227 to 236 in 3852542

    def prox(self, x: Union[JaxArray, BlockArray], lam: float) -> Union[JaxArray, BlockArray]:
        if isinstance(self.A, linop.Diagonal):
            c = self.scale * lam
            A = self.A.diagonal
            W = self.weight_op.diagonal
            lhs = c * 2.0 * A.conj() * W * W.conj() * self.y + x
            ATWTWA = c * 2.0 * A.conj() * W.conj() * W * A
            return lhs / (ATWTWA + 1.0)
        else:
            raise NotImplementedError

seems to use the weights (*almost) correctly but the choice of variable names is deeply confusing. It would be better to have something like

            sqrtW = self.weight_op.diagonal
            lhs = c * 2.0 * A.conj() * sqrtW * sqrtW.conj() * self.y + x

(*) The prox is actually incorrect when A is diagonal and W is dense, since then there is no self.weight_op.diagonal.

(3) The Hessian seems to be incorrect. With the definition of the loss in (1) the Hessian would be 2a A^T W A, which translates to 2 * scale * A.adj(weight_op(weight_op(A(x)))). However, currently we have

eval_fn=lambda x: 2 * self.scale * self.A.adj(self.weight_op(self.A(x))),
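If weight_op really is W^{1/2}, the fix would presumably amount to applying it twice, roughly (a sketch of the idea, not the final code in this PR):

    eval_fn=lambda x: 2 * self.scale * self.A.adj(
        self.weight_op(self.weight_op(self.A(x)))  # apply W^{1/2} twice, i.e. apply W
    ),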

This poses these issues for me:

  • Why is the weight operation to be kept so general rather than using a matrix operation?

  • If we use a matrix operation, would it make more sense to store W directly (rather than W^(1/2))?

  • The Hessian also never gets tested according to codecov, which should be changed

  • fix prox for non-diagonal weights

  • fix Hessian

  • add test for Hessian

@Michael-T-McCann
Contributor

Michael-T-McCann commented Nov 9, 2021

Nice progress on this! A few thoughts.

* Why is the weight operation to be kept so general rather than using a matrix operation?

I don't know what you mean by "using a matrix operation." That said, it seems to me we only ever use diagonal weightings. I would assume the class was written to allow non-diagonal weightings simply because it is mathematically defined. To discuss: simplifying this to diagonal only?

* If we use a matrix operation, would it make more sense to store `W` directly (rather than `W^(1/2)`)?

I suspect storing the square root was a method to enforce positive definiteness of W. Whether we keep that depends on the first point. If we do keep it, holy cow we need to change the variable name to something like W_sqroot. The documentation is also too subtle.
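(For illustration, reusing the wrapper's naming from above: whatever is stored as W^{1/2}, squaring it yields nonnegative weights by construction.)

    w_sqrt = self.weight_op.diagonal   # stored square-root weights
    w = snp.conj(w_sqrt) * w_sqrt      # = |w_sqrt|**2 >= 0, so W is positive semidefinite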

* The Hessian also never gets tested according to codecov, which should be changed

To take this a step further: if the Hessian is neither tested nor (apparently) used anywhere, we should consider dropping it. If we decide to keep it, agreed it needs tests.

@bwohlberg
Collaborator

Are there really only two example scripts that are affected by the change in the meaning of W?

@@ -140,7 +147,7 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
         Args:
             x : Point at which to evaluate loss.
         """
-        return self.scale * self.functional(self.y - self.A(x))
+        return self.scale * (snp.abs(self.y - self.A(x)) ** 2).sum()

Collaborator

Is this change desirable? Dropping the use of self.functional seems like it may have consequences for derived classes.

Contributor Author

I don't think it does. Does it?

I did this to keep it consistent with the weighted case, where we can avoid computing the square root of the weights.

Collaborator

Understood, and that's a worthwhile goal, but we should make sure there aren't any undesirable consequences before we make this change.

Contributor

I don't mind this change for the time being.

Longer term, consider: Why is functional a property at all if it is not used here? Broadly, I think this discussion is a symptom of the existence of Loss. One might hope to implement __call__ at the level of Loss (and therefore using self.functional) to reduce repeated code. But we can't do that, because Loss is so general it doesn't know that it should be A@x - y.

Collaborator

Agreed: if this change is made, we should consider removing the functional attribute. With respect to Loss, do you recall why it's so general? If there's good reason, perhaps we should have a specialization that really is A@x - y?

Contributor

I think the reason it is general is because of the Poisson. I think the reason it exists at all is that it used to not be a subclass of Functional and therefore having a base Loss class made sense.

scico/loss.py Outdated
     ):
 
         r"""Initialize a :class:`WeightedSquaredL2Loss` object.
 
         Args:
-            y : Measurements
+            y : Measurements.
Collaborator

Nitpick: change to "Measurement."

Contributor Author

fixed in 7db54df

         self.has_prox = True
 
     def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
-        return self.scale * self.functional(self.weight_op(self.y - self.A(x)))
+        return self.scale * (self.W.diagonal * snp.abs(self.y - self.A(x)) ** 2).sum()

Collaborator

See earlier comment on similar lines.

Contributor Author

CI seems to be failing because the changes would slightly reduce the test coverage percentage.

Let's see whether that is still the case when we add the test for the Hessian.

Collaborator

Still unhappy, it seems. It would be best to address this before we merge.

Collaborator

I added a test for loss.PoissonLoss which will, I assume, resolve this.

@Michael-T-McCann : Would you not agree that the Loss tests should be in a separate test_loss.py file rather than included in test_functional.py?

Contributor

Loss is currently a subclass of Functional (was not always this way). Therefore I think it makes sense for the losses to get tested in test_functional.py, unless the file is much too long.

@codecov

codecov bot commented Nov 10, 2021

Codecov Report

Merging #78 (449d0b2) into main (f1a138c) will increase coverage by 0.13%.
The diff coverage is 75.00%.


@@            Coverage Diff             @@
##             main      #78      +/-   ##
==========================================
+ Coverage   90.04%   90.17%   +0.13%     
==========================================
  Files          42       42              
  Lines        2982     2981       -1     
==========================================
+ Hits         2685     2688       +3     
+ Misses        297      293       -4     
| Flag | Coverage Δ |
|------|------------|
| unittests | 90.17% <75.00%> (+0.13%) ⬆️ |

Flags with carried forward coverage won't be shown.

| Impacted Files | Coverage Δ |
|----------------|------------|
| scico/admm.py | 88.77% <0.00%> (ø) |
| scico/loss.py | 84.48% <76.19%> (+2.90%) ⬆️ |
| scico/linop/radon_svmbir.py | 88.52% <100.00%> (+1.02%) ⬆️ |

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@bwohlberg
Collaborator

CI seems to be failing because the changes would slightly reduce the test coverage percentage.

Contributor

@smajee smajee left a comment


What is the reasoning behind removing the diagonal op check in scico/linop/radon_svmbir.py?
svmbir does not support non-diagonal weighting.

Otherwise looks good to me.

@tbalke
Contributor Author

tbalke commented Nov 10, 2021

What is the reasoning behind removing the diagonal op check in scico/linop/radon_svmbir.py? svmbir does not support non-diagonal weighting.

Otherwise looks good to me.

The check is now in the __init__ of WeightedSquaredL2Loss, which always requires a diagonal weight. The check is performed at the super().__init__(*args, **kwargs) line.
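Conceptually, the relocated check amounts to something like the following sketch (names and the exact error are illustrative; the actual code is in this PR's diff):

    class WeightedSquaredL2Loss(Loss):
        def __init__(self, y, A=None, W=None, **kwargs):
            # reject non-diagonal weights up front, so the svmbir wrapper no
            # longer needs its own check
            if W is not None and not isinstance(W, linop.Diagonal):
                raise ValueError("weight operator must be a linop.Diagonal")
            self.W = W
            super().__init__(y=y, A=A, **kwargs)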

scico/loss.py Outdated
        if isinstance(A, operator.Operator):
            self.is_smooth = A.is_smooth
        else:
            self.is_smooth = None
Collaborator

@Michael-T-McCann : Any thoughts on this addition? It's a bit ugly mathematically because we're labeling a non-smooth functional as smooth (because we smoothed it at zero), but this seems to be the right choice from a practical/computational perspective.

Contributor

My vote: document the use of epsilon, set is_smooth=True.

The only way around the epsilon I can come up with is accounting for the dark count rate in the forward model.

Contributor Author

@tbalke tbalke Nov 11, 2021

Even with the epsilon it is still not smooth for Ax < (-epsilon). I think a way to get around the mathematical imprecision is to say that the attribute we are looking for is whether it can be used in pgm or not.

If A(x) = exp(x) for example, then we can use the current PoissonLoss with no problem. But if A(x) is linear then we can get into trouble even when using the epsilon since Ax < (-epsilon) is very likely.

So smoothness of the loss \alpha L(y, A(x)) is not only dependent on L but also on A(.) or perhaps y.

@bwohlberg is the is_smooth attribute used for anything other than the check in pgm?

Collaborator

Even with the epsilon it is still not smooth for Ax < (-epsilon).

Good point. Given that epsilon just shifts the problem, perhaps we should remove it? Alternatively, instead of adding epsilon, set values less than epsilon to epsilon?

@bwohlberg is the is_smooth attribute used for anything other than the check in pgm?

I'm not sure. That may be the only use at the moment. Perhaps the whole is_smooth mechanism is worth a re-think. How is the issue of zeros handled in the Poisson loss example with PGM?
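For reference, the two options under discussion would differ roughly as follows (a hedged sketch; the actual PoissonLoss expression in scico may differ in its details):

    Ax = self.A(x)
    # current approach: shift by epsilon, which only moves the non-smooth point to Ax < -epsilon
    val_shift = self.scale * snp.sum((Ax + epsilon) - self.y * snp.log(Ax + epsilon))
    # alternative: clip, so the argument of the log never drops below epsilon
    Ax_clip = snp.maximum(Ax, epsilon)
    val_clip = self.scale * snp.sum(Ax_clip - self.y * snp.log(Ax_clip))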

Contributor Author

@tbalke tbalke Nov 11, 2021

Given that epsilon just shifts the problem, perhaps we should remove it? Alternatively, instead of adding epsilon, set values less than epsilon to epsilon?

Yes, this may be a better option. But see below.

Perhaps the whole is_smooth mechanism is worth a re-think.

Yes. I think if pgm rejected every L(y, A(x)) that is non-smooth for some x, we would be unnecessarily restrictive. A looser condition, such as L(y, A(x)) being smooth around x_0 or within some feasible set, could be more practical.

How is the issue of zeros handled in the Poisson loss example with PGM?

(a) The initial condition is > 0 (almost surely).
(b) There is a non-negative indicator.
(c) Matrix A does not have negative entries.
However, it seems like initializing to negatives or zero breaks the example code. I guess this kind of problem is not really that realistic. If we assume y to be Poisson, then in what world would A(x) be negative for any x in the feasible set?

I am leaning towards completely removing the epsilon, and then either A or the feasible set needs to be such that A(x) > 0. (I think that currently would not break anything, but I would not bet on it.)

@Michael-T-McCann ?

Contributor

I had assumed the epsilon was necessary for some example or another. If it isn't, sure, remove it.

Contributor Author

Issue opened in #89


@tbalke tbalke mentioned this pull request Nov 12, 2021
@bwohlberg bwohlberg merged commit 4684909 into main Nov 12, 2021
@bwohlberg bwohlberg deleted the thilo/svmbir_weight_fix branch November 12, 2021 22:18
Labels
bug (Something isn't working), priority: high (High priority)