Skip to content

Commit

Permalink
Davidson iterations for tddft on GPU (#305)
Browse files Browse the repository at this point in the history
* some mmodifications

* test

* finish writting, start debugging

* Finish debugging and unit tests.

* remove some comments and unused codes

* after review the codes

* change the threshold in precond

* add the import _response_functions

* change codes according to review comments
  • Loading branch information
puzhichen authored Jan 20, 2025
1 parent 86ca248 commit 6ea5c6c
Show file tree
Hide file tree
Showing 9 changed files with 178 additions and 127 deletions.
209 changes: 124 additions & 85 deletions gpu4pyscf/tdscf/_lr_eig.py

Large diffs are not rendered by default.

35 changes: 24 additions & 11 deletions gpu4pyscf/tdscf/rhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,9 @@

import numpy as np
import cupy as cp
import scipy.linalg
from pyscf import gto
from pyscf import lib
from pyscf.tdscf import rhf as tdhf_cpu
from gpu4pyscf.tdscf._lr_eig import eigh as lr_eigh, eig as lr_eig, real_eig
from gpu4pyscf.tdscf._lr_eig import eigh as lr_eigh, real_eig
from gpu4pyscf import scf
from gpu4pyscf.lib.cupy_helper import contract, tag_array
from gpu4pyscf.lib import utils
Expand Down Expand Up @@ -53,7 +51,7 @@ def gen_tda_operation(mf, fock_ao=None, singlet=True, wfnsym=None):
orbo2 = orbo * 2. # *2 for double occupancy

e_ia = hdiag = mo_energy[viridx] - mo_energy[occidx,None]
hdiag = hdiag.ravel().get()
hdiag = hdiag.ravel()
vresp = mf.gen_response(singlet=singlet, hermi=0)
nocc, nvir = e_ia.shape

Expand All @@ -66,7 +64,7 @@ def vind(zs):
v1mo = contract('xpq,qo->xpo', v1ao, orbo)
v1mo = contract('xpo,pv->xov', v1mo, orbv.conj())
v1mo += zs * e_ia
return v1mo.reshape(v1mo.shape[0],-1).get()
return v1mo.reshape(v1mo.shape[0],-1)

return vind, hdiag

Expand Down Expand Up @@ -100,11 +98,15 @@ class TDBase(lib.StreamObject):
get_ab = NotImplemented

def get_precond(self, hdiag):
threshold_t=1.0e-4
def precond(x, e, *args):
if isinstance(e, np.ndarray):
e = e[0]
n_states = x.shape[0]
diagd = cp.repeat(hdiag.reshape(1,-1), n_states, axis=0)
e = e.reshape(-1,1)
diagd = hdiag - (e-self.level_shift)
diagd[abs(diagd)<1e-8] = 1e-8
diagd = cp.where(abs(diagd) < threshold_t, cp.sign(diagd)*threshold_t, diagd)
a_size = x.shape[1]//2
diagd[:,a_size:] = diagd[:,a_size:]*(-1)
return x/diagd
return precond

Expand Down Expand Up @@ -170,6 +172,17 @@ def _contract_multipole(tdobj, ints, hermi=True, xy=None):
class TDA(TDBase):
__doc__ = tdhf_cpu.TDA.__doc__

def get_precond(self, hdiag):
threshold_t=1.0e-4
def precond(x, e, *args):
n_states = x.shape[0]
diagd = cp.repeat(hdiag.reshape(1,-1), n_states, axis=0)
e = e.reshape(-1,1)
diagd = hdiag - (e-self.level_shift)
diagd = cp.where(abs(diagd) < threshold_t, cp.sign(diagd)*threshold_t, diagd)
return x/diagd
return precond

def gen_vind(self, mf=None):
'''Generate function to compute Ax'''
if mf is None:
Expand Down Expand Up @@ -228,7 +241,7 @@ def kernel(self, x0=None, nstates=None):
precond = self.get_precond(hdiag)

def pickeig(w, v, nroots, envs):
idx = np.where(w > self.positive_eig_threshold)[0]
idx = cp.where(w > self.positive_eig_threshold)[0]
return w[idx], v[:,idx], idx

x0sym = None
Expand Down Expand Up @@ -291,10 +304,10 @@ def vind(zs):
v1_top += xs * e_ia # AX
v1_bot += ys * e_ia # (A*)Y
return cp.hstack((v1_top.reshape(nz,nocc*nvir),
-v1_bot.reshape(nz,nocc*nvir))).get()
-v1_bot.reshape(nz,nocc*nvir)))

hdiag = cp.hstack([hdiag.ravel(), -hdiag.ravel()])
return vind, hdiag.get()
return vind, hdiag


class TDHF(TDBase):
Expand Down
7 changes: 3 additions & 4 deletions gpu4pyscf/tdscf/rks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import numpy as np
import cupy as cp
from pyscf import lib
from pyscf.tdscf._lr_eig import eigh as lr_eigh
from gpu4pyscf.tdscf._lr_eig import eigh as lr_eigh
from gpu4pyscf.dft.rks import KohnShamDFT
from gpu4pyscf.lib.cupy_helper import contract, tag_array, transpose_sum
from gpu4pyscf.lib import logger
Expand Down Expand Up @@ -54,7 +54,6 @@ def gen_vind(self, mf=None):
d_ia = e_ia ** .5
ed_ia = e_ia * d_ia
hdiag = e_ia.ravel() ** 2
hdiag = hdiag.get()
vresp = mf.gen_response(singlet=singlet, hermi=1)
nocc, nvir = e_ia.shape

Expand All @@ -71,7 +70,7 @@ def vind(zs):
v1mo = contract('xpo,pv->xov', v1mo, orbv)
v1mo += zs * ed_ia
v1mo *= d_ia
return v1mo.reshape(v1mo.shape[0],-1).get()
return v1mo.reshape(v1mo.shape[0],-1)

return vind, hdiag

Expand All @@ -95,7 +94,7 @@ def kernel(self, x0=None, nstates=None):
precond = self.get_precond(hdiag)

def pickeig(w, v, nroots, envs):
idx = np.where(w > self.positive_eig_threshold)[0]
idx = cp.where(w > self.positive_eig_threshold)[0]
return w[idx], v[:,idx], idx

x0sym = None
Expand Down
8 changes: 4 additions & 4 deletions gpu4pyscf/tdscf/tests/test_tdrhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,12 @@ def test_tda_vind(self):
nvir = nmo - nocc
zs = np.random.rand(3,nocc,nvir)
ref = mf.to_cpu().TDA().set(singlet=False).gen_vind()[0](zs)
dat = mf.TDA().set(singlet=False).gen_vind()[0](cp.asarray(zs))
dat = mf.TDA().set(singlet=False).gen_vind()[0](cp.asarray(zs)).get()
self.assertAlmostEqual(abs(ref - dat).max(), 0, 9)

df_mf = self.df_mf
ref = df_mf.to_cpu().TDA().set(singlet=True).gen_vind()[0](zs)
dat = df_mf.TDA().set(singlet=True).gen_vind()[0](cp.asarray(zs))
dat = df_mf.TDA().set(singlet=True).gen_vind()[0](cp.asarray(zs)).get()
self.assertAlmostEqual(abs(ref - dat).max(), 0, 9)

def test_tdhf_vind(self):
Expand All @@ -140,12 +140,12 @@ def test_tdhf_vind(self):
nvir = nmo - nocc
zs = np.random.rand(3,2,nocc,nvir)
ref = mf.to_cpu().TDHF().set(singlet=True).gen_vind()[0](zs)
dat = mf.TDHF().set(singlet=True).gen_vind()[0](zs)
dat = mf.TDHF().set(singlet=True).gen_vind()[0](zs).get()
self.assertAlmostEqual(abs(ref - dat).max(), 0, 9)

df_mf = self.df_mf
ref = df_mf.to_cpu().TDHF().set(singlet=False).gen_vind()[0](zs)
dat = df_mf.TDHF().set(singlet=False).gen_vind()[0](zs)
dat = df_mf.TDHF().set(singlet=False).gen_vind()[0](zs).get()
self.assertAlmostEqual(abs(ref - dat).max(), 0, 9)

if __name__ == "__main__":
Expand Down
6 changes: 3 additions & 3 deletions gpu4pyscf/tdscf/tests/test_tdrks.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def test_tda_vind(self):
nvir = nmo - nocc
zs = np.random.rand(3,nocc,nvir)
ref = mf.to_cpu().TDA().set(singlet=False).gen_vind()[0](zs)
dat = mf.TDA().set(singlet=False).gen_vind()[0](cp.asarray(zs))
dat = mf.TDA().set(singlet=False).gen_vind()[0](cp.asarray(zs)).get()
self.assertAlmostEqual(abs(ref - dat).max(), 0, 9)

def test_tddft_vind(self):
Expand All @@ -261,7 +261,7 @@ def test_tddft_vind(self):
nvir = nmo - nocc
zs = np.random.rand(3,2,nocc,nvir)
ref = mf.to_cpu().TDDFT().set(singlet=True).gen_vind()[0](zs)
dat = mf.TDDFT().set(singlet=True).gen_vind()[0](zs)
dat = mf.TDDFT().set(singlet=True).gen_vind()[0](zs).get()
self.assertAlmostEqual(abs(ref - dat).max(), 0, 9)

def test_casida_tddft_vind(self):
Expand All @@ -271,7 +271,7 @@ def test_casida_tddft_vind(self):
nvir = nmo - nocc
zs = np.random.rand(3,nocc,nvir)
ref = mf.to_cpu().CasidaTDDFT().gen_vind()[0](zs)
dat = mf.CasidaTDDFT().gen_vind()[0](cp.asarray(zs))
dat = mf.CasidaTDDFT().gen_vind()[0](cp.asarray(zs)).get()
self.assertAlmostEqual(abs(ref - dat).max(), 0, 9)

if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions gpu4pyscf/tdscf/tests/test_tduhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def test_tda_vind(self):
nvirb = nmo - noccb
zs = np.random.rand(3,nocca*nvira+noccb*nvirb)
ref = mf.to_cpu().TDA().set().gen_vind()[0](zs)
dat = mf.TDA().set().gen_vind()[0](cp.asarray(zs))
dat = mf.TDA().set().gen_vind()[0](cp.asarray(zs)).get()
self.assertAlmostEqual(abs(ref - dat).max(), 0, 9)

def test_tdhf_vind(self):
Expand All @@ -111,7 +111,7 @@ def test_tdhf_vind(self):
nvirb = nmo - noccb
zs = np.random.rand(3,2,nocca*nvira+noccb*nvirb)
ref = mf.to_cpu().TDHF().set().gen_vind()[0](zs)
dat = mf.TDHF().set().gen_vind()[0](zs)
dat = mf.TDHF().set().gen_vind()[0](zs).get()
self.assertAlmostEqual(abs(ref - dat).max(), 0, 9)

if __name__ == "__main__":
Expand Down
6 changes: 3 additions & 3 deletions gpu4pyscf/tdscf/tests/test_tduks.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def test_tda_vind(self):
nvirb = nmo - noccb
zs = np.random.rand(3,nocca*nvira+noccb*nvirb)
ref = mf.to_cpu().TDA().gen_vind()[0](zs)
dat = mf.TDA().gen_vind()[0](cp.asarray(zs))
dat = mf.TDA().gen_vind()[0](cp.asarray(zs)).get()
self.assertAlmostEqual(abs(ref - dat).max(), 0, 9)

def test_tddft_vind(self):
Expand All @@ -198,7 +198,7 @@ def test_tddft_vind(self):
nvirb = nmo - noccb
zs = np.random.rand(3,2,nocca*nvira+noccb*nvirb)
ref = mf.to_cpu().TDDFT().gen_vind()[0](zs)
dat = mf.TDDFT().gen_vind()[0](cp.asarray(zs))
dat = mf.TDDFT().gen_vind()[0](cp.asarray(zs)).get()
self.assertAlmostEqual(abs(ref - dat).max(), 0, 9)

def test_casida_tddft_vind(self):
Expand All @@ -209,7 +209,7 @@ def test_casida_tddft_vind(self):
nvirb = nmo - noccb
zs = np.random.rand(3,nocca*nvira+noccb*nvirb)
ref = mf.to_cpu().CasidaTDDFT().gen_vind()[0](zs)
dat = mf.CasidaTDDFT().gen_vind()[0](cp.asarray(zs))
dat = mf.CasidaTDDFT().gen_vind()[0](cp.asarray(zs)).get()
self.assertAlmostEqual(abs(ref - dat).max(), 0, 9)

if __name__ == "__main__":
Expand Down
22 changes: 11 additions & 11 deletions gpu4pyscf/tdscf/uhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def gen_tda_operation(mf, fock_ao=None, wfnsym=None):
e_ia_a = mo_energy[0][viridxa] - mo_energy[0][occidxa,None]
e_ia_b = mo_energy[1][viridxb] - mo_energy[1][occidxb,None]
e_ia = cp.hstack((e_ia_a.reshape(-1), e_ia_b.reshape(-1)))
hdiag = e_ia.get()
hdiag = e_ia
nocca, nvira = e_ia_a.shape
noccb, nvirb = e_ia_b.shape

Expand All @@ -88,7 +88,7 @@ def vind(zs):
v1a += za * e_ia_a
v1b += zb * e_ia_b
hx = cp.hstack((v1a.reshape(nz,-1), v1b.reshape(nz,-1)))
return hx.get()
return hx

return vind, hdiag

Expand Down Expand Up @@ -185,7 +185,7 @@ def kernel(self, x0=None, nstates=None):
precond = self.get_precond(hdiag)

def pickeig(w, v, nroots, envs):
idx = np.where(w > self.positive_eig_threshold)[0]
idx = cp.where(w > self.positive_eig_threshold)[0]
return w[idx], v[:,idx], idx

x0sym = None
Expand Down Expand Up @@ -258,7 +258,7 @@ def gen_vind(self):
orbva = mo_coeff[0][:,viridxa]
orbov = (orbob, orbva)
e_ia = mo_energy[0][viridxa] - mo_energy[1][occidxb,None]
hdiag = e_ia.ravel().get()
hdiag = e_ia.ravel()

elif extype == 1:
occidxa = mo_occ[0] > 0
Expand All @@ -267,7 +267,7 @@ def gen_vind(self):
orbvb = mo_coeff[1][:,viridxb]
orbov = (orboa, orbvb)
e_ia = mo_energy[1][viridxb] - mo_energy[0][occidxa,None]
hdiag = e_ia.ravel().get()
hdiag = e_ia.ravel()

vresp = gen_uhf_response_sf(
mf, hermi=0, collinear=self.collinear,
Expand All @@ -283,7 +283,7 @@ def vind(zs):
v1mo = contract('xpq,qo->xpo', v1ao, orbo)
v1mo = contract('xpo,pv->xov', v1mo, orbv.conj())
v1mo += zs * e_ia
return v1mo.reshape(len(v1mo), -1).get()
return v1mo.reshape(len(v1mo), -1)

return vind, hdiag

Expand Down Expand Up @@ -461,10 +461,10 @@ def vind(zs):
v1_bot[:,:nocca*nvira] += v1a_bot.reshape(nz,-1)
v1_top[:,nocca*nvira:] += v1b_top.reshape(nz,-1)
v1_bot[:,nocca*nvira:] += v1b_bot.reshape(nz,-1)
return cp.hstack([v1_top, -v1_bot]).get()
return cp.hstack([v1_top, -v1_bot])

hdiag = cp.hstack([hdiag.ravel(), -hdiag.ravel()])
return vind, hdiag.get()
return vind, hdiag


class TDHF(TDBase):
Expand Down Expand Up @@ -578,9 +578,9 @@ def gen_vind(self):

extype = self.extype
if extype == 0:
hdiag = cp.hstack([e_ia_b2a.ravel(), -e_ia_a2b.ravel()]).get()
hdiag = cp.hstack([e_ia_b2a.ravel(), -e_ia_a2b.ravel()])
else:
hdiag = cp.hstack([e_ia_a2b.ravel(), -e_ia_b2a.ravel()]).get()
hdiag = cp.hstack([e_ia_a2b.ravel(), -e_ia_b2a.ravel()])

vresp = gen_uhf_response_sf(
mf, hermi=0, collinear=self.collinear,
Expand Down Expand Up @@ -681,7 +681,7 @@ def vind(zs):
v1_top += zs_a2b * e_ia_a2b
v1_bot += zs_b2a * e_ia_b2a
hx = cp.hstack([v1_top.reshape(nz,-1), -v1_bot.reshape(nz,-1)])
return hx.get()
return hx

return vind, hdiag

Expand Down
8 changes: 4 additions & 4 deletions gpu4pyscf/tdscf/uks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import cupy as cp
from pyscf import symm
from pyscf import lib
from pyscf.tdscf._lr_eig import eigh as lr_eigh
from gpu4pyscf.tdscf._lr_eig import eigh as lr_eigh
from gpu4pyscf.dft.rks import KohnShamDFT
from gpu4pyscf.lib.cupy_helper import contract, tag_array, transpose_sum
from gpu4pyscf.lib import logger
Expand Down Expand Up @@ -69,7 +69,7 @@ def gen_vind(self, mf=None):
d_ia = e_ia**.5
ed_ia = e_ia * d_ia
hdiag = e_ia ** 2
hdiag = hdiag.get()
hdiag = hdiag
vresp = mf.gen_response(mo_coeff, mo_occ, hermi=1)
nocca, nvira = e_ia_a.shape
noccb, nvirb = e_ia_b.shape
Expand All @@ -96,7 +96,7 @@ def vind(zs):
hx = cp.hstack((v1a.reshape(nz,-1), v1b.reshape(nz,-1)))
hx += ed_ia * zs
hx *= d_ia
return hx.get()
return hx

return vind, hdiag

Expand All @@ -120,7 +120,7 @@ def kernel(self, x0=None, nstates=None):
precond = self.get_precond(hdiag)

def pickeig(w, v, nroots, envs):
idx = np.where(w > self.positive_eig_threshold)[0]
idx = cp.where(w > self.positive_eig_threshold)[0]
return w[idx], v[:,idx], idx

x0sym = None
Expand Down

0 comments on commit 6ea5c6c

Please sign in to comment.