NaN gradients for sqrt #1101
Comments
I don't think this is really a bug in […]. It's a little like […]
Xref also the discussion here: #1036. It might be possible to regularise all […]
That's a good point. To be honest, I'm not sure wrong gradients are better than an error. I would say that for this, if no perfect solution is available, the best solution might be an informative error: something pointing out that the issue comes from this, and proposing a solution (e.g. just use […]).
I agree that an error is often better than a NaN. Looks like this has been discussed a bit: https://discourse.julialang.org/t/treating-nan-as-error-helping-debugging/36933

Less ambitiously, something like this could potentially be added only to AD, for instance by inserting a function which is by default […]
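A minimal sketch of the kind of AD-only hook suggested here, assuming a hypothetical `nancheck` function with a hand-written Zygote adjoint (this is not an existing Zygote or ChainRules API):

```julia
using Zygote

# Identity outside of AD; only its adjoint does anything special.
nancheck(x) = x

Zygote.@adjoint function nancheck(x)
    function back(Δ)
        # Turn a silent NaN in the backward pass into a loud error.
        any(isnan, Δ) && error("NaN gradient detected at nancheck")
        return (Δ,)
    end
    return nancheck(x), back
end

# Wrapping the *input* traps the NaN produced by the pullback of y * sqrt(y):
gradient(x -> (y = nancheck(x); y * sqrt(y)), 0.0)  # errors instead of returning (NaN,)
```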
Zygote and PyTorch seem to behave similarly in these cases:

```julia
using Zygote

gradient(x -> x * sqrt(x), 0)
# (NaN,)

gradient(x -> x^(1.5), 0)
# (0.0,)
```

PyTorch:

```python
import torch

x = torch.tensor([0.], requires_grad=True); f = x*torch.sqrt(x); f.backward(); x.grad
# tensor([nan])

x = torch.tensor([0.], requires_grad=True); f = x**(1.5); f.backward(); x.grad
# tensor([0.])
```

To me (and to my math professors, as far as I remember :-)) √x and x / √x are just two different functions: √x = 0 for x = 0, but x / √x is undefined for x = 0 (they are equal almost everywhere but still different). Given that the derivative of sqrt is undefined (in the mathematical sense) for x = 0, having NaN as the result seems quite logical to me. I would not expect any symbolic transformation from Zygote (or PyTorch) to lift this pathological case.
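For reference, writing the two derivatives out makes the numerical difference at zero explicit (this is just the product rule versus the power rule, nothing Zygote-specific):

$$
\frac{d}{dx}\bigl(x\sqrt{x}\bigr) = \sqrt{x} + x \cdot \frac{1}{2\sqrt{x}}\;\Big|_{x=0} = 0 + 0 \cdot \infty = \mathrm{NaN},
\qquad
\frac{d}{dx}\,x^{3/2} = \tfrac{3}{2}\sqrt{x}\;\Big|_{x=0} = 0.
$$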
There might be more clever ways: ForwardDiff's nan-safe mode works around some cases where the simple conclusion would be NaN. Today's discussion here: JuliaDiff/ChainRules.jl#576
The derivative of `sqrt` at zero is infinite. Possible solutions to avoid `Inf` gradients for `sqrt`: […]
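One commonly discussed workaround, sketched here only as an illustration (it is not necessarily what the comment above proposed), is to give `sqrt` a regularised adjoint that stays finite at zero, e.g. via a hypothetical `safe_sqrt`:

```julia
using Zygote

# Same primal value as sqrt, but the AD rule clamps the denominator away from
# zero, so the pullback never produces Inf (and hence no 0 * Inf = NaN).
safe_sqrt(x) = sqrt(x)

Zygote.@adjoint safe_sqrt(x) = safe_sqrt(x), Δ -> (Δ / (2 * sqrt(max(x, 1e-12))),)

gradient(x -> x * safe_sqrt(x), 0.0)  # (0.0,) instead of (NaN,)
```

The trade-off is that gradients for inputs very close to zero are biased by the clamping threshold.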
After a long time hunting a bug with @facusapienza21, we have realized that Zygote fails to provide a gradient for the basic `sqrt` function. This has been discussed at length in this Discourse thread. Here's a MWE to reproduce the issue:
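A minimal sketch of such an MWE (the `loss` name, the `θ * sqrt(θ)` form, and the evaluation point `0.0` are illustrative assumptions, not the original code from the issue):

```julia
using Zygote

# Illustrative loss that goes through an explicit sqrt.
loss(θ) = θ * sqrt(θ)

# Direct gradient:
gradient(loss, 0.0)                      # (NaN,)

# The same computation via an explicit pullback:
val, back_θ = Zygote.pullback(loss, 0.0)
back_θ(1.0)                              # (NaN,)
```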
For this last case, the value of `back_θ(1.0)` is `NaN`. However, if we avoid the use of `sqrt()` by defining the loss function as an equivalent power instead (see the sketch below), then Zygote provides the right gradient.
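A sketch of such an alternative definition, again assuming the illustrative `loss` above, rewritten as a plain power:

```julia
# Mathematically identical to θ * sqrt(θ) for θ ≥ 0, but with no explicit sqrt.
loss_pow(θ) = θ^(3 / 2)

val, back_θ = Zygote.pullback(loss_pow, 0.0)
back_θ(1.0)   # (0.0,)
```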
According to @mcabbott, "the reason we get `NaN` is that the slope of `sqrt` at zero is infinite. That infinity multiplies the slope of `0^x` at 4, which is zero. Whereas with the `0^(x/2)` version, the slope is simply zero."

Since `sqrt` is such a basic function, this bug can potentially affect a large number of users.