
Incorrect handling of scale in Loss.grad #468

Closed · bwohlberg opened this issue Nov 9, 2023 · 1 comment · Fixed by #470
Labels: bug (Something isn't working)

bwohlberg commented Nov 9, 2023

There is a bug in Loss.grad's handling of the scale attribute, but only when the scale is set via scalar multiplication:

import jax
from scico.loss import SquaredL2Loss
from scico.functional import L2Norm
import scico.numpy as snp

f = SquaredL2Loss(y=snp.zeros((4,)))
g = SquaredL2Loss(y=snp.zeros((4,)), scale=5)
h = 10 * f

# __call__ is correct
f(snp.ones((4,)))
>> Array(2., dtype=float32)
g(snp.ones((4,)))
>> Array(20., dtype=float32)
h(snp.ones((4,)))
>> Array(20., dtype=float32)

# grad is broken
f.grad(snp.ones((4,)))
>> Array([1., 1., 1., 1.], dtype=float32)
g.grad(snp.ones((4,)))
>> Array([10., 10., 10., 10.], dtype=float32)
h.grad(snp.ones((4,)))
>> Array([1., 1., 1., 1.], dtype=float32)
# expected to match g.grad, i.e. Array([10., 10., 10., 10.], dtype=float32)

The same bug is not present in Functional.grad:

f = L2Norm()
g = 10 * f

f.grad(snp.ones((4,)))
>> Array([0.5, 0.5, 0.5, 0.5], dtype=float32)
g.grad(snp.ones((4,)))
>> Array([5., 5., 5., 5.], dtype=float32)
bwohlberg (Collaborator, Author) commented:

The bug turns out to be due to a combination of this

def __init__(self):
    self._grad = scico.grad(self.__call__)

and this

scico/scico/loss.py, lines 126 to 129 at 5ffd1f9:

def __mul__(self, other):
    new_loss = copy(self)
    new_loss.set_scale(self.scale * other)
    return new_loss

The copy call does not trigger __init__, so the new Loss object ends up with _grad still set to the function that was constructed when __init__ ran for the original, unscaled Loss object.
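
The mechanism can be reproduced with a small self-contained example (a sketch for illustration only; ToyLoss and its attributes are hypothetical stand-ins, not SCICO code):

import copy

import jax
import jax.numpy as jnp


class ToyLoss:
    # Minimal stand-in for Loss (hypothetical) reproducing the mechanism above.
    def __init__(self, scale=0.5):
        self.scale = scale
        # _grad is built from *this* object's __call__ at construction time.
        self._grad = jax.grad(self.__call__)

    def __call__(self, x):
        return self.scale * jnp.sum(x**2)

    def __mul__(self, other):
        # Mirrors Loss.__mul__: copy does not re-run __init__, so the copy's
        # _grad still differentiates the original, unscaled __call__.
        new = copy.copy(self)
        new.scale = self.scale * other
        return new

    __rmul__ = __mul__  # so that `10 * f` works, as in the example above

    def grad(self, x):
        return self._grad(x)


f = ToyLoss()
h = 10 * f
print(h(jnp.ones(4)))       # 20.0 -- scale is applied by __call__
print(h.grad(jnp.ones(4)))  # [1. 1. 1. 1.] -- stale _grad ignores the new scale

Any shallow copy that changes scale without rebuilding _grad will exhibit the same mismatch between __call__ and grad.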

PR #470 has a simple fix, but this issue raises a few broader design questions:

  • Is there any value in initializing a _grad attribute on Functional objects, rather than simply having their grad method compute the gradient directly from __call__ (see the sketch after this list)?
  • Would the Loss implementation be at least slightly simpler if it were derived from ScaledFunctional rather than Functional?
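
For the first question, a lazy definition might look roughly like the following (a hedged sketch only, using jax.grad in place of scico.grad; LazyGradLoss is hypothetical and this is not necessarily what PR #470 does):

import copy

import jax
import jax.numpy as jnp


class LazyGradLoss:
    # Hypothetical sketch: no cached _grad attribute; the gradient is derived
    # from the current __call__ whenever grad is invoked.
    def __init__(self, scale=0.5):
        self.scale = scale

    def __call__(self, x):
        return self.scale * jnp.sum(x**2)

    def __mul__(self, other):
        new = copy.copy(self)
        new.scale = self.scale * other
        return new

    __rmul__ = __mul__

    def grad(self, x):
        # Differentiate whatever __call__ (and scale) this object has *now*.
        return jax.grad(self.__call__)(x)


f = LazyGradLoss()
h = 10 * f
print(h.grad(jnp.ones(4)))  # [10. 10. 10. 10.] -- scaling propagates to the gradient

Deriving the gradient from the current __call__ removes the possibility of a stale cache, at the cost of rebuilding the grad transform on each call.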

bwohlberg added a commit that referenced this issue Nov 15, 2023
* Update change log

* Resolve #468 and add corresponding test

* Shorten comment

* Resolve some oversights in prox definitions

* Minor edit

* Avoid chaining of ScaledFunctional and some code re-organization

* Address review comment