[WIP] Autodiff stresses #443
Comments
cc #107
Cf. also #47 (in particular a trick to reduce to 6 instead of 9 DOF, although we probably don't care too much)
I'd focus on ForwardDiff for now. We should be able to work around the errors. Michael did the work of making it work for IntervalArithmetic scalar types, so it should hopefully be similar.
Yes, please let me know if you need any help on that. I should easily find some nice examples to get you going in case you need any.
Thanks, starting with ForwardDiff sounds good to me. As I understand it, there are the options of either
My guess would be that both are similar in difficulty, but 2) should be preferable for performance. What are your thoughts?
Performance is not the foremost issue, so following IntervalArithmetic sounds good. However, we've had quite a few issues with the generic FFTs (which are buggy and not actively developed), so if https://github.com/JuliaDiff/ForwardDiff.jl/pull/495/files does the job then great!
Regarding the stacktrace: Only some FFT sizes work for the generic implementation we have and unless you specify an
But I agree with Antoine. The generic FFT stuff only works "meh", so if we can avoid it, that would probably be the better solution long-term.
Tricky bug, nice catch!
Some updates on both ForwardDiff approaches: I have iterated on the examples as discussed:
The inclusion of the AtomicNonLocal() term currently leads to NaN derivative results with ForwardDiff in both approaches, while the other terms seem to work without further errors, at least.
Cool, that's great news! So we can actually use finite diff to debug the AtomicNonLocal term. Some ideas how to debug:
Regarding the stacktraces in the second PR: it appears that, at least for reverse diff, this already happens in the PWBasis setup. I don't fully get why at first glance. Let's discuss tomorrow.
JuliaLang/julia#27705 has a snippet for yielding an error when a NaN is produced
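A minimal sketch of the kind of check this enables. This is not the snippet from the linked issue; `assert_no_nan` and the labels are hypothetical names. The idea is to wrap intermediate quantities so that the first NaN raises an error with a label, which should help narrow down where in a term the NaN first shows up.

```julia
# Hypothetical debugging helper (NOT the snippet from JuliaLang/julia#27705):
# pass a value through unchanged, but throw as soon as it contains a NaN.
function assert_no_nan(label, x)
    any(isnan, x) && error("NaN detected in $label")
    return x
end

# Passes through untouched when everything is finite.
ok = assert_no_nan("finite values", [1.0, 2.0])

# Throws when a NaN is present; here we only record that it did.
threw = try
    assert_no_nan("contains NaN", [1.0, NaN])
    false
catch
    true
end
```

Calls like this can be sprinkled around the intermediate arrays of a single energy term and removed again once the offending operation is found.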
We've found the NaN of AtomicNonLocal: it came in due to a bug/inconsistency of ForwardDiff on
This fixes the stress of AtomicNonlocal for both ForwardDiff approaches, each of which also agrees with FiniteDiff. On Approach 2 I also re-enabled the FFT normalizations and added the required additional Dual rule for ScaledPlan. With this, both ForwardDiff approaches above finally agree on the stress of the example system!
Interesting. Actually this is a structural zero, i.e. it comes about by recip_lattice * zeros(3). So norm always gets called on a vector of 0 + eps * 0, so the non-differentiability of norm at zero is not an issue (at least for forward). Can you check it's OK with ChainRules? If yes, we might as well do a quick workaround here and wait for the next gen of forward diff tools.
This is the current behavior of norm at zero using (Zygote+ChainRules, ForwardDiff) x (Vector, SVector):
using Zygote
using ForwardDiff
using StaticArrays
using LinearAlgebra
x = zeros(3)
Zygote.gradient(norm, x)[1]
# 3-element Vector{Float64}:
# 0.0
# 0.0
# 0.0
ForwardDiff.gradient(norm, x)
# 3-element Vector{Float64}:
# 0.0
# 0.0
# 1.0
y = @SVector zeros(3)
Zygote.gradient(norm, y)[1]
# 3-element SVector{3, Float64} with indices SOneTo(3):
# 0.0
# 0.0
# 0.0
ForwardDiff.gradient(norm, y)
# 3-element SVector{3, Float64} with indices SOneTo(3):
# NaN
# NaN
# NaN
# [f6369f11] ForwardDiff v0.10.18
# [90137ffa] StaticArrays v1.2.3
# [e88e6eb3] Zygote v0.6.13
For our use case all results are OK except the NaN, since it doesn't cancel out in the subsequent multiplication by zero, although I'm surprised by
using ChainRules # [082447d4] ChainRules v0.8.13
ChainRules.unthunk(ChainRules.rrule(norm, x)[2](1.0)[2])
# 3-element Vector{Float64}:
# 0.0
# 0.0
# 0.0
ChainRules.unthunk(ChainRules.rrule(norm, y)[2](1.0)[2])
# 3-element SVector{3, Float64} with indices SOneTo(3):
# 0.0
# 0.0
# 0.0
function onehot(i, n)
    x = zeros(n)
    x[i] = 1.0
    x
end
ChainRules.frule((ChainRules.NoTangent(), onehot(1,3),), norm, x) # (0.0, 0.0)
ChainRules.frule((ChainRules.NoTangent(), onehot(2,3),), norm, x) # (0.0, 0.0)
ChainRules.frule((ChainRules.NoTangent(), onehot(3,3),), norm, x) # (0.0, 0.0)
ChainRules.frule((ChainRules.NoTangent(), onehot(1,3),), norm, y) # (0.0, 0.0)
ChainRules.frule((ChainRules.NoTangent(), onehot(2,3),), norm, y) # (0.0, 0.0)
ChainRules.frule((ChainRules.NoTangent(), onehot(3,3),), norm, y) # (0.0, 0.0)
So a next-gen forward diff picking up on ChainRules should indeed fix the problem.
Yeah, that just looks like a ForwardDiff bug, so either work around it locally or fix it upstream.
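To illustrate one plausible mechanism behind the NaN above, here is a self-contained sketch with a hand-rolled dual number. The `Dual` type, `naive_norm` and `guarded_norm` are hypothetical names for illustration only; this is neither ForwardDiff's internal code nor the workaround that ended up in the PR. The assumption is that the norm is computed as `sqrt(sum(abs2, v))`, so the chain rule for `sqrt` divides a zero tangent by zero at the origin.

```julia
# Minimal forward-mode dual number (illustration only, not ForwardDiff.Dual).
struct Dual
    val::Float64   # primal value
    der::Float64   # tangent (directional derivative)
end

Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)
# Naive chain rule: d/dx sqrt(x) = 1 / (2 sqrt(x)), which blows up at x = 0.
Base.sqrt(a::Dual) = Dual(sqrt(a.val), a.der / (2 * sqrt(a.val)))

# Norm written as sqrt of a sum of squares, with no special case at zero.
naive_norm(v) = sqrt(sum(x -> x * x, v))

# Hypothetical guarded variant: treat an exactly zero input as a structural
# zero and return a zero tangent instead of evaluating 0 / 0.
function guarded_norm(v)
    s = sum(x -> x * x, v)
    s.val == 0 ? Dual(0.0, 0.0) : sqrt(s)
end

# Zero vector seeded with a unit tangent in direction i.
seed(i) = [Dual(0.0, j == i ? 1.0 : 0.0) for j in 1:3]

naive   = [naive_norm(seed(i)).der for i in 1:3]    # every entry is 0 / 0
guarded = [guarded_norm(seed(i)).der for i in 1:3]  # every entry is 0.0
```

In the naive version each squared component has tangent `2 * 0 * 1 = 0`, so `sqrt` evaluates `0 / 0` and produces NaN, which then survives any later multiplication by zero. The guarded version returns a zero tangent, matching the ChainRules results listed above.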
done in #476
Opening this to keep track of progress on obtaining stresses via autodiff.
Goal
Calculate the stress as the total derivative of the total energy with respect to the lattice parameters via automatic differentiation. As this falls under the scope of the Hellmann-Feynman theorem, we do not need to differentiate through the full SCF solve, but only through a post-processing step on the final solution scfres.
We start with the following minimal example of silicon with a single scalar lattice parameter a
Approach
We plan to try ForwardDiff.jl, ReverseDiff.jl and Zygote.jl.
For stresses only (#params < 10) we expect ForwardDiff to perform best. Going further, the reverse modes of ReverseDiff and Zygote are also interesting, as they could jointly evaluate stresses and other derivatives of the total energy (e.g. forces) more efficiently.
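As a rough sketch of the forward-mode workflow for a single scalar lattice parameter, the following differentiates a toy energy with `ForwardDiff.derivative` and cross-checks it against a central finite difference. Everything about the model is an assumption for illustration: `toy_energy`, the Gaussian-like sum and the value `a = 10.26` are made up and are not the DFTK energy or the silicon example referred to above.

```julia
using ForwardDiff

# Hypothetical toy model (NOT the DFTK energy): a sum over a few reciprocal
# lattice vectors G = (2π / a) * n of a cubic cell, using only |n|^2.
function toy_energy(a)
    n2 = (1, 1, 2, 3)   # squared integer norms |n|^2 of the chosen vectors
    sum(exp(-(2π / a)^2 * m / 2) for m in n2)
end

a = 10.26   # assumed lattice parameter, for illustration only

# Forward-mode AD: one dual-number evaluation of toy_energy.
dE_ad = ForwardDiff.derivative(toy_energy, a)

# Central finite difference as an independent reference.
h = 1e-5
dE_fd = (toy_energy(a + h) - toy_energy(a - h)) / (2h)
```

With only a handful of lattice parameters, each derivative costs roughly one extra evaluation in forward mode, which is why it is the natural first candidate here. The finite-difference value serves the same role as the FiniteDiff comparison mentioned in the comments.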
Expected challenges:
Progress
no method matching zero(::String)
(TODO understand stack trace)
Related links
An overview of AD tools in Julia: https://juliadiff.org/
Chris Rackauckas on strengths and weaknesses of different AD packages: https://discourse.julialang.org/t/state-of-automatic-differentiation-in-julia/43083/3
Common patterns that need rules in Zygote: https://juliadiff.org/ChainRulesCore.jl/stable/writing_good_rules.html