Skip to content

Commit

Permalink
Merge pull request #13 from BoothGroup/integral_fock
Browse files Browse the repository at this point in the history
Migrates the fock build into the Integrals class
  • Loading branch information
obackhouse authored Aug 7, 2023
2 parents 5f1b5be + 03071df commit a9845a7
Show file tree
Hide file tree
Showing 10 changed files with 176 additions and 78 deletions.
5 changes: 3 additions & 2 deletions momentGW/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
from pyscf import lib
from pyscf.agf2 import mpi_helper
from pyscf.lib import logger
from pyscf.mp.mp2 import get_frozen_mask, get_nmo, get_nocc

Expand Down Expand Up @@ -87,8 +88,8 @@ def __init__(self, mf, **kwargs):
setattr(self, key, val)

# Do not modify:
self.mo_energy = mf.mo_energy
self.mo_coeff = mf.mo_coeff
self.mo_energy = mpi_helper.bcast(mf.mo_energy, root=0)
self.mo_coeff = mpi_helper.bcast(mf.mo_coeff, root=0)
self.mo_occ = mf.mo_occ
self.frozen = None
self._nocc = None
Expand Down
1 change: 1 addition & 0 deletions momentGW/evgw.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def kernel(

# Get the static part of the SE
se_static = gw.build_se_static(
integrals,
mo_energy=mo_energy,
mo_coeff=mo_coeff,
)
Expand Down
52 changes: 6 additions & 46 deletions momentGW/fock.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,54 +11,11 @@
from momentGW import util


def get_j(Lpq, dm):
"""
Build the J matrix. Lpq may be distributed along the final index.
"""

nmo = dm.shape[-1]
p0, p1 = list(mpi_helper.prange(0, nmo, nmo))[0]
vj = np.zeros_like(dm)

tmp = lib.einsum("Qkl,lk->Q", Lpq, dm[p0:p1])
tmp = mpi_helper.allreduce(tmp)
vj[:, p0:p1] = lib.einsum("Qij,Q->ij", Lpq, tmp)
vj = mpi_helper.allreduce(vj)

return vj


def get_k(Lpq, dm):
"""
Build the K matrix. Lpq may be distributed along the final index.
"""

nmo = dm.shape[-1]
p0, p1 = list(mpi_helper.prange(0, nmo, nmo))[0]
vk = np.zeros_like(dm)

tmp = lib.einsum("Qik,kl->Qil", Lpq, dm[p0:p1])
tmp = mpi_helper.allreduce(tmp)
vk[:, p0:p1] = lib.einsum("Qil,Qlj->ij", tmp, Lpq)
vk = mpi_helper.allreduce(vk)

return vk


def get_jk(Lpq, dm):
return get_j(Lpq, dm), get_k(Lpq, dm)


def get_fock(Lpq, dm, h1e):
vj, vk = get_jk(Lpq, dm)
return h1e + vj - vk * 0.5


def fock_loop(
gw,
Lpq,
gf,
se,
integrals=None,
fock_diis_space=10,
fock_diis_min_space=1,
conv_tol_nelec=1e-6,
Expand All @@ -70,6 +27,9 @@ def fock_loop(
consistent field.
"""

if integrals is None:
integrals = gw.ao2mo()

h1e = np.linalg.multi_dot((gw.mo_coeff.T, gw._scf.get_hcore(), gw.mo_coeff))
nmo = gw.nmo
nocc = gw.nocc
Expand All @@ -82,7 +42,7 @@ def fock_loop(
diis.min_space = fock_diis_min_space
gf_to_dm = lambda gf: gf.get_occupied().moment(0) * 2.0
rdm1 = gf_to_dm(gf)
fock = get_fock(Lpq, rdm1, h1e)
fock = integrals.get_fock(rdm1, h1e)

buf = np.zeros((nqmo, nqmo))
converged = False
Expand All @@ -104,7 +64,7 @@ def fock_loop(
gf = gf.__class__(w, v[:nmo], chempot=se.chempot)

rdm1 = gf_to_dm(gf)
fock = get_fock(Lpq, rdm1, h1e)
fock = integrals.get_fock(rdm1, h1e)
fock = diis.update(fock, xerr=None)

if niter2 > 1:
Expand Down
30 changes: 14 additions & 16 deletions momentGW/gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pyscf.ao2mo import _ao2mo
from pyscf.lib import logger

from momentGW import util
from momentGW.base import BaseGW
from momentGW.fock import fock_loop
from momentGW.ints import Integrals
Expand Down Expand Up @@ -65,6 +66,7 @@ def kernel(

# Get the static part of the SE
se_static = gw.build_se_static(
integrals,
mo_energy=mo_energy,
mo_coeff=mo_coeff,
)
Expand Down Expand Up @@ -96,12 +98,14 @@ class GW(BaseGW):
def name(self):
return "G0W0"

def build_se_static(self, mo_coeff=None, mo_energy=None):
def build_se_static(self, integrals, mo_coeff=None, mo_energy=None):
"""Build the static part of the self-energy, including the
Fock matrix.
Parameters
----------
integrals : Integrals
Density-fitted integrals.
mo_energy : numpy.ndarray, optional
Molecular orbital energies. Default value is that of
`self.mo_energy`.
Expand All @@ -121,19 +125,16 @@ def build_se_static(self, mo_coeff=None, mo_energy=None):
if mo_energy is None:
mo_energy = self.mo_energy

with lib.temporary_env(self._scf, verbose=0):
with lib.temporary_env(self._scf.with_df, verbose=0):
v_mf = self._scf.get_veff() - self._scf.get_j()
if getattr(self._scf, "xc", "hf") == "hf":
se_static = np.zeros((self.nmo, self.nmo))
else:
with util.SilentSCF(self._scf):
vmf = self._scf.get_j() - self._scf.get_veff()
dm = self._scf.make_rdm1(mo_coeff=mo_coeff)
v_mf = lib.einsum("pq,pi,qj->ij", v_mf, mo_coeff, mo_coeff)
vk = integrals.get_k(dm, basis="ao")

with lib.temporary_env(self._scf.with_df, verbose=0):
with lib.temporary_env(self._scf.with_df, verbose=0):
vk = scf.hf.SCF.get_veff(self._scf, self.mol, dm)
vk -= scf.hf.SCF.get_j(self._scf, self.mol, dm)
vk = lib.einsum("pq,pi,qj->ij", vk, mo_coeff, mo_coeff)

se_static = vk - v_mf
se_static = vmf - vk * 0.5
se_static = lib.einsum("pq,pi,qj->ij", se_static, mo_coeff, mo_coeff)

if self.diagonal_se:
se_static = np.diag(np.diag(se_static))
Expand Down Expand Up @@ -255,11 +256,8 @@ def solve_dyson(self, se_moments_hole, se_moments_part, se_static, integrals=Non
gf.coupling = mpi_helper.bcast(gf.coupling, root=0)

if self.fock_loop:
if integrals is None:
raise ValueError("Lpq must be passed to solve_dyson if fock_loop=True")

try:
gf, se, conv = fock_loop(self, integrals.Lpq, gf, se, **self.fock_opts)
gf, se, conv = fock_loop(self, gf, se, integrals=integrals, **self.fock_opts)
except IndexError:
pass

Expand Down
110 changes: 110 additions & 0 deletions momentGW/ints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,33 @@
Integral helpers.
"""

import contextlib
import types

import numpy as np
from pyscf import lib
from pyscf.agf2 import mpi_helper
from pyscf.lib import logger


@contextlib.contextmanager
def patch_df_loop(with_df):
"""
Context manager for monkey patching PySCF's density fitting objects
to loop over blocks of the auxiliary functions distributed over MPI.
"""

def prange(self, start, stop, end):
yield from mpi_helper.prange(start, stop, end)

pre_patch = with_df.prange
setattr(with_df, "prange", types.MethodType(prange, with_df))

yield with_df

setattr(with_df, "prange", pre_patch)


class Integrals:
"""
Container for the integrals required for GW methods.
Expand Down Expand Up @@ -43,6 +64,7 @@ def get_compression_metric(self):
Return the compression metric.
"""
# TODO cache this if needed
return None

compression = self.compression.replace("vo", "ov")
compression = set(x for x in compression.split(","))
Expand Down Expand Up @@ -120,6 +142,8 @@ def transform(self, do_Lpq=None, do_Lpx=True, do_Lia=True):
if self._rot is None:
self._rot = self.get_compression_metric()
rot = self._rot
if rot is None:
rot = np.eye(self.naux_full)

do_Lpq = self.store_full if do_Lpq is None else do_Lpq
if not any([do_Lpq, do_Lpx, do_Lia]):
Expand Down Expand Up @@ -201,6 +225,83 @@ def update_coeffs(self, mo_coeff_g=None, mo_coeff_w=None, mo_occ_w=None):
do_Lia=mo_coeff_w is not None,
)

def get_j(self, dm, basis="mo"):
"""Build the J matrix."""

assert basis in ("ao", "mo")

p0, p1 = list(mpi_helper.prange(0, self.nmo, self.nmo))[0]
vj = np.zeros_like(dm, dtype=np.result_type(dm, self.dtype))

if self.store_full and basis == "mo":
tmp = lib.einsum("Qkl,lk->Q", self.Lpq, dm[p0:p1])
tmp = mpi_helper.allreduce(tmp)
vj[:, p0:p1] = lib.einsum("Qij,Q->ij", self.Lpq, tmp)
vj = mpi_helper.allreduce(vj)

else:
if basis == "mo":
dm = np.linalg.multi_dot((self.mo_coeff, dm, self.mo_coeff.T))

with patch_df_loop(self.with_df):
for block in self.with_df.loop():
naux = block.shape[0]
if block.size == naux * self.nmo * (self.nmo + 1) // 2:
block = lib.unpack_tril(block)
block = block.reshape(naux, self.nmo, self.nmo)

tmp = lib.einsum("Qkl,lk->Q", block, dm)
vj += lib.einsum("Qij,Q->ij", block, tmp)

vj = mpi_helper.allreduce(vj)
if basis == "mo":
vj = np.linalg.multi_dot((self.mo_coeff.T, vj, self.mo_coeff))

return vj

def get_k(self, dm, basis="mo"):
"""Build the K matrix."""

assert basis in ("ao", "mo")

p0, p1 = list(mpi_helper.prange(0, self.nmo, self.nmo))[0]
vk = np.zeros_like(dm, dtype=np.result_type(dm, self.dtype))

if self.store_full and basis == "mo":
tmp = lib.einsum("Qik,kl->Qil", self.Lpq, dm[p0:p1])
tmp = mpi_helper.allreduce(tmp)
vk[:, p0:p1] = lib.einsum("Qil,Qlj->ij", tmp, self.Lpq)
vk = mpi_helper.allreduce(vk)

else:
if basis == "mo":
dm = np.linalg.multi_dot((self.mo_coeff, dm, self.mo_coeff.T))

with patch_df_loop(self.with_df):
for block in self.with_df.loop():
naux = block.shape[0]
if block.size == naux * self.nmo * (self.nmo + 1) // 2:
block = lib.unpack_tril(block)
block = block.reshape(naux, self.nmo, self.nmo)

tmp = lib.einsum("Qik,kl->Qil", block, dm)
vk += lib.einsum("Qil,Qlj->ij", tmp, block)

vk = mpi_helper.allreduce(vk)
if basis == "mo":
vk = np.linalg.multi_dot((self.mo_coeff.T, vk, self.mo_coeff))

return vk

def get_jk(self, dm, **kwargs):
"""Build the J and K matrices."""
return self.get_j(dm, **kwargs), self.get_k(dm, **kwargs)

def get_fock(self, dm, h1e, **kwargs):
"""Build the Fock matrix."""
vj, vk = self.get_jk(dm, **kwargs)
return h1e + vj - vk * 0.5

@property
def Lpq(self):
"""
Expand Down Expand Up @@ -300,6 +401,8 @@ def naux(self):
Return the number of auxiliary basis functions, after the
compression.
"""
if self._rot is None:
return self.naux_full
return self._rot.shape[1]

@property
Expand All @@ -317,3 +420,10 @@ def is_bare(self):
no self-consistencies.
"""
return self._mo_coeff_g is None and self._mo_coeff_w is None

@property
def dtype(self):
"""
Return the dtype of the integrals.
"""
return np.result_type(*self._blocks.values())
13 changes: 6 additions & 7 deletions momentGW/qsgw.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ def kernel(
if gw.polarizability == "drpa-exact":
raise NotImplementedError("%s for polarizability=%s" % (gw.name, gw.polarizability))

if integrals is None:
integrals = gw.ao2mo()

nmo = gw.nmo
nocc = gw.nocc
naux = gw.with_df.get_naoaux()
Expand Down Expand Up @@ -123,17 +126,13 @@ def project_basis(m, c1, c2):
diis_qp.space = gw.diis_space_qp
mo_energy_prev = mo_energy.copy()
for qp_cycle in range(1, gw.max_cycle_qp + 1):
dm_ao = np.linalg.multi_dot((mo_coeff_ref, dm, mo_coeff_ref.T))
with lib.temporary_env(gw._scf.with_df, verbose=0):
j, k = gw._scf.get_jk(dm=dm_ao)
j = np.linalg.multi_dot((mo_coeff_ref.T, j, mo_coeff_ref))
k = np.linalg.multi_dot((mo_coeff_ref.T, k, mo_coeff_ref))

fock_eff = h1e + j - 0.5 * k + se_qp
fock = integrals.get_fock(dm, h1e)
fock_eff = fock + se_qp
fock_eff = diis_qp.update(fock_eff)
fock_eff = mpi_helper.bcast(fock_eff, root=0)

mo_energy, u = np.linalg.eigh(fock_eff)
u = mpi_helper.bcast(u, root=0)
mo_coeff = np.dot(mo_coeff_ref, u)

dm_prev = dm
Expand Down
1 change: 1 addition & 0 deletions momentGW/scgw.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def kernel(

# Get the static part of the SE
se_static = gw.build_se_static(
integrals,
mo_energy=mo_energy,
mo_coeff=mo_coeff,
)
Expand Down
Loading

0 comments on commit a9845a7

Please sign in to comment.