Hellmann-Feynman stresses via ForwardDiff and custom rules #476

Merged on Jul 15, 2021 (52 commits)
d442e91
add kinetic stress autodiff example
niklasschmitz Jun 7, 2021
2074095
add stress-total error messages of FD, RD, Zygote
niklasschmitz Jun 7, 2021
cafed13
add stress-total scalar test
niklasschmitz Jun 7, 2021
38db38f
disable lattice cond check temporarily
niklasschmitz Jun 14, 2021
797bf34
add model_atomic_debug with Kinetic() only
niklasschmitz Jun 14, 2021
223dc24
stacktrace: no next_working_fft_size for Duals
niklasschmitz Jun 14, 2021
abe7c7b
stress-kinetic with generic fft forwarddiff works
niklasschmitz Jun 14, 2021
4861677
move model term selection to call site
niklasschmitz Jun 14, 2021
4bea529
generic fft: Kinetic, AtomicLocal, Ewald, Psp work
niklasschmitz Jun 14, 2021
30dbc98
direct return
niklasschmitz Jun 14, 2021
a5660f8
use make_basis in stress-total
niklasschmitz Jun 14, 2021
746aad6
rename stress-forward-genericlinearalgebra
niklasschmitz Jun 14, 2021
cdad8eb
stack trace: no next_working_fft_size for Dual
niklasschmitz Jun 14, 2021
324918f
stack trace: no build_fft_plans for Dual
niklasschmitz Jun 14, 2021
9e772a5
rename stress-forward-genericfft
niklasschmitz Jun 14, 2021
e7c4105
add generic-fft stress btimes
niklasschmitz Jun 14, 2021
5512526
add (a few) FFTW ForwardDiff rules
niklasschmitz Jun 14, 2021
1a5b54c
stack trace: no mul!(...Dual, ...FFTW.cFFTWPlan,.)
niklasschmitz Jun 14, 2021
42cb076
fwddiff FFTW: Kinetic, AtomicLocal, Ewald, Psp
niklasschmitz Jun 15, 2021
791a5a8
add fwddiff FFTW btimes
niklasschmitz Jun 15, 2021
320ce60
add description header
niklasschmitz Jun 15, 2021
57331b0
update generic fft stress values
niklasschmitz Jun 15, 2021
3da7d29
del forward-genericfft in favor of separate branch
niklasschmitz Jun 15, 2021
7a0dcf3
update stack traces for Zygote, ReverseDiff
niklasschmitz Jun 15, 2021
e763015
stack trace: one carbon
niklasschmitz Jun 17, 2021
34b5129
stack trace: AtomicNonlocal NaN ForwardDiff
niklasschmitz Jun 17, 2021
5cf5a2c
fix AtomicNonlocal NaN
niklasschmitz Jun 17, 2021
eb89f36
fix fft normalization & dual-scaled plans
niklasschmitz Jun 17, 2021
cb1b968
add ForwardDiff norm of SVector workaround
niklasschmitz Jun 21, 2021
6eafdef
add total stress result of ForwardDiff
niklasschmitz Jun 21, 2021
b0dcf5e
add model_DFT ForwardDiff stress
niklasschmitz Jun 28, 2021
11288af
Merge branch 'JuliaMolSim:master' into autodiff-stress
niklasschmitz Jun 30, 2021
631b591
move forwarddiff rules to workarounds dir
niklasschmitz Jun 30, 2021
00078e3
add comments on r_to_G on duals
niklasschmitz Jun 30, 2021
b878050
add lda_x
niklasschmitz Jul 6, 2021
51593f6
Merge branch 'JuliaMolSim:master' into autodiff-stress
niklasschmitz Jul 8, 2021
65aad15
delete scratch files
niklasschmitz Jul 8, 2021
e8c9b5f
add silicon stress testcase
niklasschmitz Jul 8, 2021
80f30bd
revert project toml comment
niklasschmitz Jul 8, 2021
21e39c9
rm hardcoded fft_size
niklasschmitz Jul 8, 2021
badd501
tighten test tolerances
niklasschmitz Jul 8, 2021
f377c5a
add spglib dual rule
niklasschmitz Jul 8, 2021
d00aed5
re-enable cond check
niklasschmitz Jul 8, 2021
23a77a8
fix hellmann-feynman: recompute density from psi
niklasschmitz Jul 8, 2021
a181ed2
move FiniteDiff to test deps
niklasschmitz Jul 8, 2021
3c57d6b
apply suggestions from review I
niklasschmitz Jul 13, 2021
f1ef7ee
avoid code duplication of DummyInplace
niklasschmitz Jul 13, 2021
8be63da
use silicon numbers from testcases.jl
niklasschmitz Jul 15, 2021
6a821bd
use promote_type between basis and psi
niklasschmitz Jul 15, 2021
d90f8ca
Merge branch 'master' into stress-fwddiff-abstractffts
niklasschmitz Jul 15, 2021
30ada7f
line break
niklasschmitz Jul 15, 2021
cb466c1
d to n_dim
niklasschmitz Jul 15, 2021
3 changes: 2 additions & 1 deletion Project.toml
@@ -76,6 +76,7 @@ JLD2 = "0.4 - 0.4.7"
 [extras]
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
 DoubleFloats = "497a8b3b-efae-58df-a0af-a86822472b78"
+FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
 GenericLinearAlgebra = "14197337-ba66-59df-a3e3-ca00e7dcff7a"
 IntervalArithmetic = "d1acc4aa-44c8-5952-acd4-ba5d80a2a253"
 JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
@@ -86,4 +87,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 WriteVTK = "64499a7a-5c06-52f2-abe2-ccb03c286192"

 [targets]
-test = ["Test", "Aqua", "DoubleFloats", "GenericLinearAlgebra", "IntervalArithmetic", "Plots", "Random", "KrylovKit", "JLD2", "WriteVTK"]
+test = ["Test", "Aqua", "DoubleFloats", "FiniteDiff", "GenericLinearAlgebra", "IntervalArithmetic", "Plots", "Random", "KrylovKit", "JLD2", "WriteVTK"]
10 changes: 8 additions & 2 deletions src/DFTK.jl
@@ -177,13 +177,19 @@ function __init__()
     # DoubleFloats has been loaded (via a "using" or an "import").
     # See https://github.com/JuliaPackaging/Requires.jl for details.
     #
-    # The global variable GENERIC_FFT_LOADED makes sure that things are only
-    # included once.
+    # The global variables GENERIC_FFT_LOADED and DUMMY_INPLACE_LOADED
+    # make sure that things are only included once.
+    @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin
+        !isdefined(DFTK, :DUMMY_INPLACE_LOADED) && include("workarounds/dummy_inplace_fft.jl")
+        include("workarounds/forwarddiff_rules.jl")
+    end
     @require IntervalArithmetic="d1acc4aa-44c8-5952-acd4-ba5d80a2a253" begin
         include("workarounds/intervals.jl")
+        !isdefined(DFTK, :DUMMY_INPLACE_LOADED) && include("workarounds/dummy_inplace_fft.jl")
         !isdefined(DFTK, :GENERIC_FFT_LOADED) && include("workarounds/fft_generic.jl")
     end
     @require DoubleFloats="497a8b3b-efae-58df-a0af-a86822472b78" begin
+        !isdefined(DFTK, :DUMMY_INPLACE_LOADED) && include("workarounds/dummy_inplace_fft.jl")
         !isdefined(DFTK, :GENERIC_FFT_LOADED) && include("workarounds/fft_generic.jl")
     end
     @require Plots="91a5bcdd-55d7-5caf-9e0b-520d859cae80" include("plotting.jl")
6 changes: 4 additions & 2 deletions src/Model.jl
@@ -103,8 +103,8 @@ function Model(lattice::AbstractMatrix{T};
         norm(lattice[:, i]) == norm(lattice[i, :]) == 0 || error(
             "For 1D and 2D systems, the non-empty dimensions must come first")
     end
-    cond(lattice[1:n_dim, 1:n_dim]) > 1e-5 || (
-        @warn "Your lattice is badly conditioned, the computation is likely to fail.")
+    _check_well_conditioned(lattice[1:n_dim, 1:n_dim]) || @warn (
+        "Your lattice is badly conditioned, the computation is likely to fail.")

     # Compute reciprocal lattice and volumes.
     # recall that the reciprocal lattice is the set of G vectors such
@@ -213,3 +213,5 @@ function spin_components(spin_polarization::Symbol)
     spin_polarization == :full && return (:undefined, )
 end
 spin_components(model::Model) = spin_components(model.spin_polarization)
+
+_check_well_conditioned(A; tol=1e5) = (cond(A) <= tol)
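
The change to the conditioning check in src/Model.jl can be illustrated with a small sketch. It uses only the standard library; `check_well_conditioned` is a hypothetical stand-in that mirrors the helper added in this diff, and the "bad" lattice is an invented example. Since `cond(A) >= 1` for any invertible matrix, the previous test `cond(A) > 1e-5` could never trigger the warning, whereas the new test compares against an upper bound.

```julia
using LinearAlgebra

# Hypothetical stand-in mirroring the helper added in src/Model.jl.
check_well_conditioned(A; tol=1e5) = cond(A) <= tol

a = 10.26
# fcc silicon lattice as constructed in test/stresses.jl
good = a / 2 * [0 1 1.0; 1 0 1.0; 1 1 0.0]
# Invented example: one lattice vector is nearly zero.
bad = [1.0 0 0; 0 1.0 0; 0 0 1e-8]

# The condition number is never below one, so `cond(A) > 1e-5` always holds.
@assert cond(good) >= 1
@assert cond(bad) >= 1

# The new check accepts the silicon lattice and rejects the degenerate one.
@assert check_well_conditioned(good)
@assert !check_well_conditioned(bad)
```

Keeping the check in a separate function is also what allows forwarddiff_rules.jl to add a method for `Dual` matrices that strips the derivative information before calling `cond`.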
5 changes: 3 additions & 2 deletions src/densities.jl
@@ -60,7 +60,8 @@ is not collinear the spin density is `nothing`.
     @assert n_k > 0

     # Allocate an accumulator for ρ in each thread for each spin component
-    ρaccus = [similar(view(ψ[1], :, 1), (basis.fft_size..., n_spin))
+    T = promote_type(eltype(basis), eltype(ψ[1]))
+    ρaccus = [similar(ψ[1], T, (basis.fft_size..., n_spin))
               for ithread in 1:Threads.nthreads()]

     # TODO Better load balancing ... the workload per kpoint depends also on
@@ -79,7 +80,7 @@ is not collinear the spin density is `nothing`.

     Threads.@threads for (ikpts, ρaccu) in collect(zip(kpt_per_thread, ρaccus))
         ρaccu .= 0
-        ρ_k = similar(ψ[1][:, 1], basis.fft_size)
+        ρ_k = similar(ψ[1], T, basis.fft_size)
         for ik in ikpts
             kpt = basis.kpoints[ik]
             compute_partial_density!(ρ_k, basis, kpt, ψ[ik], occupation[ik])
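
The `promote_type` change in src/densities.jl can be sketched in isolation. The snippet assumes that `eltype(basis)` returns the real number type of the basis; `BigFloat` is used here only as a stand-in for a wider type such as a `ForwardDiff.Dual`, and all variable names are illustrative. The point is that the density buffers take the wider of the two element types, so a `ComplexF64` wavefunction combined with a wider basis type no longer yields a `ComplexF64` accumulator.

```julia
# Stand-in for ψ[1]: plane-wave coefficients (rows) for two bands (columns).
ψk = zeros(ComplexF64, 5, 2)

# Stand-in for eltype(basis); in the PR this would be a Dual type.
Tbasis = BigFloat

# Wider of the basis type and the wavefunction element type.
T = promote_type(Tbasis, eltype(ψk))

fft_size = (4, 4, 4)
n_spin = 1

# Same allocation pattern as the new ρaccus and ρ_k lines.
ρaccu = similar(ψk, T, (fft_size..., n_spin))
ρ_k = similar(ψk, T, fft_size)

@assert T == Complex{BigFloat}
@assert size(ρaccu) == (4, 4, 4, 1)
@assert eltype(ρ_k) == Complex{BigFloat}
```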
15 changes: 15 additions & 0 deletions src/workarounds/dummy_inplace_fft.jl
@@ -0,0 +1,15 @@
# This is needed to flag that the dummy_inplace_fft.jl file has already been loaded
const DUMMY_INPLACE_LOADED = true

# A dummy wrapper around an out-of-place FFT plan to make it appear in-place
# This is needed for some generic FFT implementations, which do not have in-place plans
struct DummyInplace{opFFT}
    fft::opFFT
end
LinearAlgebra.mul!(Y, p::DummyInplace, X) = (Y .= mul!(similar(X), p.fft, X))
LinearAlgebra.ldiv!(Y, p::DummyInplace, X) = (Y .= ldiv!(similar(X), p.fft, X))

import Base: *, \, length
*(p::DummyInplace, X) = p.fft * X
\(p::DummyInplace, X) = p.fft \ X
length(p::DummyInplace) = length(p.fft)
14 changes: 0 additions & 14 deletions src/workarounds/fft_generic.jl
@@ -90,17 +90,3 @@ function generic_plan_bfft(data::AbstractArray{T, 3}) where T
                        FourierTransforms.plan_bfft(data[1, :, 1]),
                        FourierTransforms.plan_bfft(data[1, 1, :])], T(1))
 end
-
-
-# A dummy wrapper around an out-of-place FFT plan to make it appear in-place
-# This is needed for some generic FFT implementations, which do not have in-place plans
-struct DummyInplace{opFFT}
-    fft::opFFT
-end
-LinearAlgebra.mul!(Y, p::DummyInplace, X) = (Y .= mul!(similar(X), p.fft, X))
-LinearAlgebra.ldiv!(Y, p::DummyInplace, X) = (Y .= ldiv!(similar(X), p.fft, X))
-
-import Base: *, \, length
-*(p::DummyInplace, X) = p.fft * X
-\(p::DummyInplace, X) = p.fft \ X
-length(p::DummyInplace) = length(p.fft)
137 changes: 137 additions & 0 deletions src/workarounds/forwarddiff_rules.jl
@@ -0,0 +1,137 @@
import ForwardDiff
import AbstractFFTs

# original PR by mcabbott: https://github.com/JuliaDiff/ForwardDiff.jl/pull/495

ForwardDiff.value(x::Complex{<:ForwardDiff.Dual}) = Complex(x.re.value, x.im.value)

ForwardDiff.partials(x::Complex{<:ForwardDiff.Dual}, n::Int) =
Complex(ForwardDiff.partials(x.re, n), ForwardDiff.partials(x.im, n))

ForwardDiff.npartials(x::Complex{<:ForwardDiff.Dual{T,V,N}}) where {T,V,N} = N
ForwardDiff.npartials(::Type{<:Complex{<:ForwardDiff.Dual{T,V,N}}}) where {T,V,N} = N

ForwardDiff.tagtype(x::Complex{<:ForwardDiff.Dual{T,V,N}}) where {T,V,N} = T
ForwardDiff.tagtype(::Type{<:Complex{<:ForwardDiff.Dual{T,V,N}}}) where {T,V,N} = T

# AbstractFFTs.complexfloat(x::AbstractArray{<:ForwardDiff.Dual}) = float.(x .+ 0im)
AbstractFFTs.complexfloat(x::AbstractArray{<:ForwardDiff.Dual}) = AbstractFFTs.complexfloat.(x)
AbstractFFTs.complexfloat(d::ForwardDiff.Dual{T,V,N}) where {T,V,N} = convert(ForwardDiff.Dual{T,float(V),N}, d) + 0im

AbstractFFTs.realfloat(x::AbstractArray{<:ForwardDiff.Dual}) = AbstractFFTs.realfloat.(x)
AbstractFFTs.realfloat(d::ForwardDiff.Dual{T,V,N}) where {T,V,N} = convert(ForwardDiff.Dual{T,float(V),N}, d)

for plan in [:plan_fft, :plan_ifft, :plan_bfft]
    @eval begin
        AbstractFFTs.$plan(x::AbstractArray{<:ForwardDiff.Dual}, region=1:ndims(x); kwargs...) =
            AbstractFFTs.$plan(ForwardDiff.value.(x) .+ 0im, region; kwargs...)

        AbstractFFTs.$plan(x::AbstractArray{<:Complex{<:ForwardDiff.Dual}}, region=1:ndims(x); kwargs...) =
            AbstractFFTs.$plan(ForwardDiff.value.(x), region; kwargs...)
    end
end

# rfft only accepts real arrays
AbstractFFTs.plan_rfft(x::AbstractArray{<:ForwardDiff.Dual}, region=1:ndims(x); kwargs...) =
    AbstractFFTs.plan_rfft(ForwardDiff.value.(x), region; kwargs...)

for plan in [:plan_irfft, :plan_brfft]  # these take an extra argument, only when complex?
    @eval begin
        AbstractFFTs.$plan(x::AbstractArray{<:ForwardDiff.Dual}, region=1:ndims(x); kwargs...) =
            AbstractFFTs.$plan(ForwardDiff.value.(x) .+ 0im, region; kwargs...)

        AbstractFFTs.$plan(x::AbstractArray{<:Complex{<:ForwardDiff.Dual}}, d::Integer, region=1:ndims(x); kwargs...) =
            AbstractFFTs.$plan(ForwardDiff.value.(x), d, region; kwargs...)
    end
end

for P in [:Plan, :ScaledPlan]  # need ScaledPlan to avoid ambiguities
    @eval begin
        Base.:*(p::AbstractFFTs.$P, x::AbstractArray{<:ForwardDiff.Dual}) =
            _apply_plan(p, x)

        Base.:*(p::AbstractFFTs.$P, x::AbstractArray{<:Complex{<:ForwardDiff.Dual}}) =
            _apply_plan(p, x)

        LinearAlgebra.mul!(Y::AbstractArray, p::AbstractFFTs.$P, X::AbstractArray{<:ForwardDiff.Dual}) =
            (Y .= _apply_plan(p, X))

        LinearAlgebra.mul!(Y::AbstractArray, p::AbstractFFTs.$P, X::AbstractArray{<:Complex{<:ForwardDiff.Dual}}) =
            (Y .= _apply_plan(p, X))
    end
end

LinearAlgebra.mul!(Y::AbstractArray{<:Complex{<:ForwardDiff.Dual}}, p::AbstractFFTs.ScaledPlan{T,P,<:ForwardDiff.Dual}, X::AbstractArray{<:ComplexF64}) where {T,P} =
    (Y .= _apply_plan(p, X))

function _apply_plan(p::AbstractFFTs.Plan, x::AbstractArray)
    xtil = p * ForwardDiff.value.(x)
    dxtils = ntuple(ForwardDiff.npartials(eltype(x))) do n
        p * ForwardDiff.partials.(x, n)
    end
    T = ForwardDiff.tagtype(eltype(x))
    map(xtil, dxtils...) do val, parts...
        Complex(
            ForwardDiff.Dual{T}(real(val), map(real, parts)),
            ForwardDiff.Dual{T}(imag(val), map(imag, parts)),
        )
    end
end

function _apply_plan(p::AbstractFFTs.ScaledPlan{T,P,<:ForwardDiff.Dual}, x::AbstractArray) where {T,P}
    _apply_plan(p.p, p.scale * x)  # for when p.scale is Dual, need out-of-place
end

# DFTK setup specific

next_working_fft_size(::Type{<:ForwardDiff.Dual}, size) = size

_fftw_flags(::Type{<:ForwardDiff.Dual}) = FFTW.MEASURE | FFTW.UNALIGNED

function build_fft_plans(T::Type{<:Union{ForwardDiff.Dual,Complex{<:ForwardDiff.Dual}}}, fft_size)
    tmp = Array{Complex{T}}(undef, fft_size...)
    opFFT = FFTW.plan_fft(tmp, flags=_fftw_flags(T))
    opBFFT = FFTW.plan_bfft(tmp, flags=_fftw_flags(T))

    ipFFT = DummyInplace{typeof(opFFT)}(opFFT)
    ipBFFT = DummyInplace{typeof(opBFFT)}(opBFFT)
    # backward by inverting and stripping off normalizations
    ipFFT, opFFT, ipBFFT, opBFFT
end

# PlaneWaveBasis{<:Dual} contains dual-scaled fft, which means that the result f_fourier
# must be able to hold complex dual numbers even if f_real is not dual
function r_to_G(basis::PlaneWaveBasis{T}, f_real::AbstractArray) where {T<:ForwardDiff.Dual}
    f_fourier = similar(f_real, complex(T))
    @assert length(size(f_real)) ∈ (3, 4)
    # this exploits trailing index convention
    for iσ = 1:size(f_real, 4)
        @views r_to_G!(f_fourier[:, :, :, iσ], basis, f_real[:, :, :, iσ])
    end
    f_fourier
end

# determine symmetry operations only from primal lattice values
function spglib_get_symmetry(lattice::Matrix{<:ForwardDiff.Dual}, atoms, magnetic_moments=[]; kwargs...)
    spglib_get_symmetry(ForwardDiff.value.(lattice), atoms, magnetic_moments; kwargs...)
end

function _check_well_conditioned(A::AbstractArray{<:ForwardDiff.Dual}; kwargs...)
    _check_well_conditioned(ForwardDiff.value.(A); kwargs...)
end


# other workarounds

# problem: ForwardDiff of norm of SVector gives NaN derivative at zero
# https://github.com/JuliaMolSim/DFTK.jl/issues/443#issuecomment-864930410
# solution: follow ChainRules custom frule for norm
# https://github.com/JuliaDiff/ChainRules.jl/blob/52a0eeadf8d19bff491f224517b7b064ce1ba378/src/rulesets/LinearAlgebra/norm.jl#L5
# TODO delete, once forward diff AD tools use ChainRules natively
function LinearAlgebra.norm(x::SVector{S,<:ForwardDiff.Dual}) where {S}
    T = ForwardDiff.tagtype(eltype(x))
    dx = ForwardDiff.partials.(x)
    y = norm(ForwardDiff.value.(x))
    dy = real(dot(ForwardDiff.value.(x), dx)) * pinv(y)
    ForwardDiff.Dual{T}(y, dy)
end
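
The FFT rules in forwarddiff_rules.jl rest on the linearity of the Fourier transform: `_apply_plan` transforms the primal values and each partial separately and then reassembles dual numbers. The following sketch checks that idea without ForwardDiff or FFTW. The naive `dft` function, `signal`, and `dsignal` are illustrative stand-ins, not part of DFTK; the derivative is carried by hand instead of through `Dual` numbers.

```julia
# Naive O(n^2) discrete Fourier transform, standing in for an FFT plan.
function dft(x)
    n = length(x)
    [sum(x[j + 1] * cis(-2π * j * k / n) for j in 0:n-1) for k in 0:n-1]
end

# A signal depending on a parameter a, and its analytic derivative in a.
signal(a) = [sin(a * j) for j in 1:8]
dsignal(a) = [j * cos(a * j) for j in 1:8]

a = 0.3

# Same idea as _apply_plan: transform the value and the partial separately.
value = dft(signal(a))
partial = dft(dsignal(a))

# Reference: central finite difference of the transformed signal.
h = 1e-6
fd = (dft(signal(a + h)) - dft(signal(a - h))) / (2h)

@assert length(value) == 8
@assert maximum(abs.(partial - fd)) < 1e-6
```

Because the transform itself carries no dependence on `a`, the plan can be built from primal values only, which is what the `plan_fft` overloads above do. The separate `ScaledPlan` method covers the case where the scale factor is itself a dual number.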
4 changes: 4 additions & 0 deletions test/runtests.jl
@@ -107,5 +107,9 @@ Random.seed!(0)
         include("aqua.jl")
     end

+    if "all" in TAGS
+        include("stresses.jl")
+    end
+
     ("example" in TAGS) && include("runexamples.jl")
 end
47 changes: 47 additions & 0 deletions test/stresses.jl
@@ -0,0 +1,47 @@
using Test
using DFTK
using ForwardDiff
import FiniteDiff
include("testcases.jl")

# Hellmann-Feynman stress
# via ForwardDiff & custom FFTW overloads on ForwardDiff.Dual

@testset "ForwardDiff stresses on silicon" begin
    function make_basis(a)
        lattice = a / 2 * [[0 1 1.];
                           [1 0 1.];
                           [1 1 0.]]
        Si = ElementPsp(silicon.atnum, psp=load_psp(silicon.psp))
        atoms = [Si => silicon.positions]
        model = model_DFT(lattice, atoms, [:lda_x, :lda_c_vwn])
        kgrid = [1, 1, 1]
        Ecut = 7
        PlaneWaveBasis(model, Ecut; kgrid=kgrid)
    end

    function recompute_energy(a)
        basis = make_basis(a)
        scfres = self_consistent_field(basis, is_converged=DFTK.ScfConvergenceDensity(1e-13))
        energies, H = energy_hamiltonian(basis, scfres.ψ, scfres.occupation; ρ=scfres.ρ)
        energies.total
    end

    function hellmann_feynman_energy(scfres_ref, a)
        basis = make_basis(a)
        ρ = DFTK.compute_density(basis, scfres_ref.ψ, scfres_ref.occupation)
        energies, H = energy_hamiltonian(basis, scfres_ref.ψ, scfres_ref.occupation; ρ=ρ)
        energies.total
    end

    a = 10.26
    scfres = self_consistent_field(make_basis(a), is_converged=DFTK.ScfConvergenceDensity(1e-13))
    hellmann_feynman_energy(a) = hellmann_feynman_energy(scfres, a)

    ref_recompute = FiniteDiff.finite_difference_derivative(recompute_energy, a)
    ref_hf = FiniteDiff.finite_difference_derivative(hellmann_feynman_energy, a)
    s_hf = ForwardDiff.derivative(hellmann_feynman_energy, a)

    @test isapprox(ref_hf, ref_recompute, atol=1e-4)
    @test isapprox(s_hf, ref_hf, atol=1e-8)
end
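
The two assertions in test/stresses.jl can be read against the following sketch of the Hellmann-Feynman argument. The notation is an assumption made for illustration: $E(a, \psi)$ is the total energy for lattice constant $a$ and orbital coefficients $\psi$, and $\psi^*(a)$ is the converged SCF solution. The sketch further assumes fixed occupations and that the energy is stationary with respect to admissible variations of $\psi$ at $\psi^*$.

```latex
\frac{\mathrm{d}}{\mathrm{d}a} E\bigl(a, \psi^*(a)\bigr)
  = \left.\frac{\partial E}{\partial a}\right|_{\psi = \psi^*}
  + \left\langle \left.\frac{\partial E}{\partial \psi}\right|_{\psi = \psi^*},
      \frac{\mathrm{d}\psi^*}{\mathrm{d}a} \right\rangle
  = \left.\frac{\partial E}{\partial a}\right|_{\psi = \psi^*}
```

Under these assumptions the second term vanishes. `recompute_energy` corresponds to the left-hand side, since it reruns the SCF for every `a`. `hellmann_feynman_energy` corresponds to the right-hand side, since it keeps `scfres_ref.ψ` fixed and only rebuilds the basis and the density. The first test compares the two by finite differences with the looser tolerance `1e-4`, and the second compares ForwardDiff with finite differences on the same fixed-ψ function with tolerance `1e-8`.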