Skip to content

Commit

Permalink
Stresses API (#487)
Browse files Browse the repository at this point in the history
Co-authored-by: Niklas Schmitz <[email protected]>
Co-authored-by: Michael F. Herbst <[email protected]>
  • Loading branch information
3 people authored Jul 27, 2021
1 parent 26df5ee commit acb79f9
Show file tree
Hide file tree
Showing 10 changed files with 211 additions and 99 deletions.
7 changes: 5 additions & 2 deletions src/DFTK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,6 @@ include("energies.jl")
export Hamiltonian
export HamiltonianBlock
export energy_hamiltonian
export compute_forces
export compute_forces_cart
export Kinetic
export ExternalFromFourier
export ExternalFromReal
Expand Down Expand Up @@ -161,6 +159,11 @@ export high_symmetry_kpath
export plot_bandstructure
include("postprocess/band_structure.jl")

export compute_forces
export compute_forces_cart
include("postprocess/forces.jl")
export compute_stresses
include("postprocess/stresses.jl")
export compute_dos
export compute_ldos
export compute_nos
Expand Down
38 changes: 8 additions & 30 deletions src/PlaneWaveBasis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,10 @@ end

# Lowest-level constructor. All given parameters must be the same on all processors
# and are stored in PlaneWaveBasis for easy reconstruction
@timing function PlaneWaveBasis(model::Model{T}, Ecut::Number,
fft_size, variational,
@timing function PlaneWaveBasis(model::Model{T},
Ecut::Number, fft_size, variational,
kcoords::AbstractVector, ksymops,
kgrid, kshift, symmetries,comm_kpts) where {T <: Real}
kgrid, kshift, symmetries, comm_kpts) where {T <: Real}
if !(all(fft_size .== next_working_fft_size(T, fft_size)))
error("Selected fft_size will not work for the buggy generic " *
"FFT routines; use next_working_fft_size")
Expand Down Expand Up @@ -322,15 +322,15 @@ function PlaneWaveBasis(model::Model;
use_symmetry=true,
kwargs...)
if use_symmetry
kcoords, ksymops, symmetries = bzmesh_ir_wedge(kgrid, model.symmetries, kshift=kshift)
kcoords, ksymops, symmetries = bzmesh_ir_wedge(kgrid, model.symmetries; kshift)
else
kcoords, ksymops, _ = bzmesh_uniform(kgrid, kshift=kshift)
kcoords, ksymops, _ = bzmesh_uniform(kgrid; kshift)
# even when not using symmetry to reduce computations, still
# store in symmetries the set of kgrid-preserving symmetries
symmetries = symmetries_preserving_kgrid(model.symmetries, kcoords)
end
PlaneWaveBasis(model, austrip(Ecut), kcoords, ksymops, symmetries;
kgrid=kgrid, kshift=kshift, kwargs...)
kgrid, kshift, kwargs...)
end

PlaneWaveBasis(model::Model, Ecut; kwargs...) = PlaneWaveBasis(model; Ecut, kwargs...)
Expand Down Expand Up @@ -407,7 +407,7 @@ end
Sum an array over kpoints, taking weights into account
"""
function weighted_ksum(basis::PlaneWaveBasis, array)
res = sum(@. basis.kweights * array)
res = sum(basis.kweights .* array)
mpi_sum(res, basis.comm_kpts)
end

Expand Down Expand Up @@ -543,28 +543,6 @@ function r_to_G_matrix(basis::PlaneWaveBasis{T}) where {T}
ret
end

""""
Convert a `basis` into one that uses or doesn't use BZ symmetrization
Mainly useful for debug purposes (e.g. in cases we don't want to
bother with symmetry)
"""
function PlaneWaveBasis(basis::PlaneWaveBasis; use_symmetry)
use_symmetry && error("Not implemented")
if all(s -> length(s) == 1, basis.ksymops)
return basis
end
kcoords = []
for (ik, kpt) in enumerate(basis.kpoints)
for (S, τ) in basis.ksymops[ik]
push!(kcoords, normalize_kpoint_coordinate(S * kpt.coordinate))
end
end
new_basis = PlaneWaveBasis(basis.model, basis.Ecut, kcoords,
[[identity_symop()] for _ in 1:length(kcoords)];
fft_size=basis.fft_size)
end


"""
Gather the distributed k-Point data on the master process and return
it as a `PlaneWaveBasis`. On the other (non-master) processes `nothing` is returned.
Expand Down Expand Up @@ -625,7 +603,7 @@ function gather_kpts(data::AbstractArray, basis::PlaneWaveBasis)
end
end

# select the occupied orbitals assuming the Aufbau principle
# select the occupied orbitals assuming an insulator
function select_occupied_orbitals(basis::PlaneWaveBasis, ψ)
model = basis.model
n_spin = model.n_spin_components
Expand Down
2 changes: 1 addition & 1 deletion src/bzmesh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ the full mesh. `symmetries` is the tuple returned from
`tol_symmetry` is the tolerance used for searching for symmetry operations.
"""
function bzmesh_ir_wedge(kgrid_size, symmetries; kshift=[0, 0, 0])
all(isequal.(kgrid_size, 1)) && return bzmesh_uniform(kgrid_size, kshift=kshift)
all(isequal.(kgrid_size, 1)) && return bzmesh_uniform(kgrid_size; kshift)

# Transform kshift to the convention used in spglib:
# If is_shift is set (i.e. integer 1), then a shift of 0.5 is performed,
Expand Down
39 changes: 39 additions & 0 deletions src/postprocess/forces.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# This uses the `compute_forces(term, ψ, occ; kwargs...)` function defined by all terms
"""
Compute the forces of an obtained SCF solution. Returns the forces wrt. the fractional
lattice vectors. To get cartesian forces use [`compute_forces_cart`](@ref).
Returns a list of lists of forces
`[[force for atom in positions] for (element, positions) in atoms]`
which has the same structure as the `atoms` object passed to the underlying [`Model`](@ref).
"""
@timing function compute_forces(basis::PlaneWaveBasis, ψ, occ; kwargs...)
# TODO optimize allocs here
T = eltype(basis)
forces = [zeros(Vec3{T}, length(positions)) for (element, positions) in basis.model.atoms]
for term in basis.terms
f_term = compute_forces(term, ψ, occ; kwargs...)
if !isnothing(f_term)
forces += f_term
end
end
forces
end

"""
Compute the cartesian forces of an obtained SCF solution in Hartree / Bohr.
Returns a list of lists of forces
`[[force for atom in positions] for (element, positions) in atoms]`
which has the same structure as the `atoms` object passed to the underlying [`Model`](@ref).
"""
function compute_forces_cart(basis::PlaneWaveBasis, ψ, occ; kwargs...)
lattice = basis.model.lattice
forces = compute_forces(basis::PlaneWaveBasis, ψ, occ; kwargs...)
[[lattice \ f for f in forces_for_element] for forces_for_element in forces]
end

function compute_forces(scfres)
compute_forces(scfres.basis, scfres.ψ, scfres.occupation; ρ=scfres.ρ)
end
function compute_forces_cart(scfres)
compute_forces_cart(scfres.basis, scfres.ψ, scfres.occupation; ρ=scfres.ρ)
end
35 changes: 35 additions & 0 deletions src/postprocess/stresses.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
using ForwardDiff
"""
Compute the stresses (= 1/Vol dE/d(M*lattice), taken at M=I) of an obtained SCF solution.
"""
@timing function compute_stresses(scfres)
# TODO optimize by only computing derivatives wrt 6 independent parameters
scfres = unfold_bz(scfres)
# compute the Hellmann-Feynman energy (with fixed ψ/occ/ρ)
function HF_energy(lattice)
T = eltype(lattice)
basis = scfres.basis
model = basis.model
new_model = Model(lattice;
model.n_electrons,
model.atoms,
magnetic_moments=[], # not used because we give symmetries explicitly
terms=model.term_types,
model.temperature,
model.smearing,
model.spin_polarization,
model.symmetries)
new_basis = PlaneWaveBasis(new_model,
basis.Ecut, basis.fft_size, basis.variational,
basis.kcoords_global, basis.ksymops_global,
basis.kgrid, basis.kshift, basis.symmetries,
basis.comm_kpts)
ρ = DFTK.compute_density(new_basis, scfres.ψ, scfres.occupation)
energies, _ = energy_hamiltonian(new_basis, scfres.ψ, scfres.occupation;
ρ, scfres.eigenvalues, scfres.εF)
energies.total
end
L = scfres.basis.model.lattice
Ω = scfres.basis.model.unit_cell_volume
ForwardDiff.gradient(M -> HF_energy((I+M) * L), zero(L)) / Ω
end
75 changes: 75 additions & 0 deletions src/symmetry.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,78 @@ function check_symmetric(basis, ρin; tol=1e-10, symmetries=ρin.basis.model.sym
@assert norm(symmetrize(ρin, [symop]) - ρin) < tol
end
end

""""
Convert a `basis` into one that doesn't use BZ symmetry.
This is mainly useful for debug purposes (e.g. in cases we don't want to
bother thinking about symmetries).
"""
function unfold_bz(basis::PlaneWaveBasis)
if all(length.(basis.ksymops_global) .== 1)
return basis
else
kcoords = []
for (ik, kpt) in enumerate(basis.kcoords_global)
for (S, τ) in basis.ksymops_global[ik]
push!(kcoords, normalize_kpoint_coordinate(S * kpt))
end
end
new_basis = PlaneWaveBasis(basis.model,
basis.Ecut, basis.fft_size, basis.variational,
kcoords, [[identity_symop()] for _ in 1:length(kcoords)],
basis.kgrid, basis.kshift, basis.symmetries, basis.comm_kpts)
end
end

# find where in the irreducible basis `basis_irred` the kpoint `kpt_unfolded` is handled
function unfold_mapping(basis_irred, kpt_unfolded)
for ik_irred = 1:length(basis_irred.kpoints)
kpt_irred = basis_irred.kpoints[ik_irred]
for symop in basis_irred.ksymops[ik_irred]
Sk_irred = normalize_kpoint_coordinate(symop[1] * kpt_irred.coordinate)
k_unfolded = normalize_kpoint_coordinate(kpt_unfolded.coordinate)
if (Sk_irred k_unfolded) && (kpt_unfolded.spin == kpt_irred.spin)
return ik_irred, symop
end
end
end
error("Invalid unfolding of BZ")
end

function unfold_array_(basis_irred, basis_unfolded, data, is_ψ)
if basis_irred == basis_unfolded
return data
end
if !(basis_irred.comm_kpts == basis_irred.comm_kpts == MPI.COMM_WORLD)
error("Brillouin zone symmetry unfolding not supported with MPI yet")
end
data_unfolded = similar(data, length(basis_unfolded.kpoints))
for ik_unfolded in 1:length(basis_unfolded.kpoints)
kpt_unfolded = basis_unfolded.kpoints[ik_unfolded]
ik_irred, symop = unfold_mapping(basis_irred, kpt_unfolded)
if is_ψ
# transform ψ_k from data into ψ_Sk in data_unfolded
kunfold_coord = kpt_unfolded.coordinate
@assert normalize_kpoint_coordinate(kunfold_coord) kunfold_coord
_, ψSk = apply_ksymop(symop, basis_irred,
basis_irred.kpoints[ik_irred], data[ik_irred])
data_unfolded[ik_unfolded] = ψSk
else
# simple copy
data_unfolded[ik_unfolded] = data[ik_irred]
end
end
data_unfolded
end

function unfold_bz(scfres)
basis_unfolded = unfold_bz(scfres.basis)
ψ = unfold_array_(scfres.basis, basis_unfolded, scfres.ψ, true)
eigenvalues = unfold_array_(scfres.basis, basis_unfolded, scfres.eigenvalues, false)
occupation = unfold_array_(scfres.basis, basis_unfolded, scfres.occupation, false)
E, ham = energy_hamiltonian(basis_unfolded, ψ, occupation;
scfres.ρ, eigenvalues, scfres.εF)
@assert E.total scfres.energies.total
new_scfres = (; basis=basis_unfolded, ψ, ham, eigenvalues, occupation)
merge(scfres, new_scfres)
end
41 changes: 0 additions & 41 deletions src/terms/terms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,50 +56,9 @@ breaks_symmetries(term_builder::Magnetic) = true
include("anyonic.jl")
breaks_symmetries(term_builder::Anyonic) = true


# forces computes either nothing or an array forces[el][at][α]
compute_forces(term::Term, ψ, occ; kwargs...) = nothing # by default, no force

"""
Compute the forces of an obtained SCF solution. Returns the forces wrt. the fractional
lattice vectors. To get cartesian forces use [`compute_forces_cart`](@ref).
Returns a list of lists of forces
`[[force for atom in positions] for (element, positions) in atoms]`
which has the same structure as the `atoms` object passed to the underlying [`Model`](@ref).
"""
@timing function compute_forces(basis::PlaneWaveBasis, ψ, occ; kwargs...)
# TODO optimize allocs here
T = eltype(basis)
forces = [zeros(Vec3{T}, length(positions)) for (element, positions) in basis.model.atoms]
for term in basis.terms
f_term = compute_forces(term, ψ, occ; kwargs...)
if !isnothing(f_term)
forces += f_term
end
end
forces
end

"""
Compute the cartesian forces of an obtained SCF solution in Hartree / Bohr.
Returns a list of lists of forces
`[[force for atom in positions] for (element, positions) in atoms]`
which has the same structure as the `atoms` object passed to the underlying [`Model`](@ref).
"""
function compute_forces_cart(basis::PlaneWaveBasis, ψ, occ; kwargs...)
lattice = basis.model.lattice
forces = compute_forces(basis::PlaneWaveBasis, ψ, occ; kwargs...)
[[lattice \ f for f in forces_for_element] for forces_for_element in forces]
end

function compute_forces(scfres)
compute_forces(scfres.basis, scfres.ψ, scfres.occupation; ρ=scfres.ρ)
end
function compute_forces_cart(scfres)
compute_forces_cart(scfres.basis, scfres.ψ, scfres.occupation; ρ=scfres.ρ)
end


@doc raw"""
compute_kernel(basis::PlaneWaveBasis; kwargs...)
Expand Down
13 changes: 10 additions & 3 deletions src/workarounds/forwarddiff_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@ end
LinearAlgebra.mul!(Y::AbstractArray{<:Complex{<:ForwardDiff.Dual}}, p::AbstractFFTs.ScaledPlan{T,P,<:ForwardDiff.Dual}, X::AbstractArray{<:ComplexF64}) where {T,P} =
(Y .= _apply_plan(p, X))

function _apply_plan(p::AbstractFFTs.Plan, x::AbstractArray)
function _apply_plan(p::AbstractFFTs.Plan, x::AbstractArray{<:Complex{<:ForwardDiff.Dual{T}}}) where T
# TODO do we want x::AbstractArray{<:ForwardDiff.Dual{T}} too?
xtil = p * ForwardDiff.value.(x)
dxtils = ntuple(ForwardDiff.npartials(eltype(x))) do n
p * ForwardDiff.partials.(x, n)
end
T = ForwardDiff.tagtype(eltype(x))
map(xtil, dxtils...) do val, parts...
Complex(
ForwardDiff.Dual{T}(real(val), map(real, parts)),
Expand All @@ -82,14 +82,21 @@ function _apply_plan(p::AbstractFFTs.ScaledPlan{T,P,<:ForwardDiff.Dual}, x::Abst
_apply_plan(p.p, p.scale * x) # for when p.scale is Dual, need out-of-place
end

# this is to avoid method ambiguities between these two:
# _apply_plan(p::AbstractFFTs.Plan, x::AbstractArray{<:Complex{<:ForwardDiff.Dual{T}}}) where T
# _apply_plan(p::AbstractFFTs.ScaledPlan{T,P,<:ForwardDiff.Dual}, x::AbstractArray) where {T,P}
function _apply_plan(p::AbstractFFTs.ScaledPlan{T,P,<:ForwardDiff.Dual}, x::AbstractArray{<:Complex{<:ForwardDiff.Dual{Tg}}}) where {T,P,Tg}
_apply_plan(p.p, p.scale * x)
end

# DFTK setup specific

next_working_fft_size(::Type{<:ForwardDiff.Dual}, size::Int) = size

_fftw_flags(::Type{<:ForwardDiff.Dual}) = FFTW.MEASURE | FFTW.UNALIGNED

function build_fft_plans(T::Type{<:Union{ForwardDiff.Dual,Complex{<:ForwardDiff.Dual}}}, fft_size)
tmp = Array{Complex{T}}(undef, fft_size...)
tmp = Array{complex(T)}(undef, fft_size...) # TODO think about other Array types
opFFT = FFTW.plan_fft(tmp, flags=_fftw_flags(T))
opBFFT = FFTW.plan_bfft(tmp, flags=_fftw_flags(T))

Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ Random.seed!(0)
end

if "all" in TAGS
include("stresses.jl")
mpi_nprocs() == 1 && include("stresses.jl")
end

("example" in TAGS) && include("runexamples.jl")
Expand Down
Loading

0 comments on commit acb79f9

Please sign in to comment.