Skip to content

Commit

Permalink
Split run_calculator and calculate to have pre_ and post_steps
Browse files Browse the repository at this point in the history
  • Loading branch information
elinscott committed May 6, 2024
1 parent 559da8b commit e9b8f14
Show file tree
Hide file tree
Showing 14 changed files with 177 additions and 97 deletions.
5 changes: 2 additions & 3 deletions src/koopmans/calculators/_environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,9 @@ def __init__(self, *args, **kwargs):
# Add dictionary of environ settings
self.environ_settings = copy.deepcopy(_default_settings)

def calculate(self):
# Generic function for running a calculation
def _pre_calculate(self):
self.write_environ_in()
super().calculate()
super()._pre_calculate()

def set_environ_settings(self, settings, use_defaults=True):
self.environ_settings = settings
Expand Down
24 changes: 14 additions & 10 deletions src/koopmans/calculators/_koopmans_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,16 +119,16 @@ def __init__(self, atoms: Atoms, alphas: Optional[List[List[float]]] = None,
# koopmans.workflows._koopmans_dscf.py for more details)
self.fixed_band: Optional[bands.Band] = None

def calculate(self):
# Create a private attribute to keep track of whether the spin channels have been swapped
self._spin_channels_are_swapped: bool = False

def _pre_calculate(self):
# kcp.x imposes nelup >= neldw, so if we try to run a calcualtion with neldw > nelup, swap the spin channels
if self.parameters.nspin == 2:
spin_channels_are_swapped = self.parameters.nelup < self.parameters.neldw
else:
spin_channels_are_swapped = False

# Swap the spin channels
if spin_channels_are_swapped:
self._swap_spin_channels()
self._spin_channels_are_swapped = self.parameters.nelup < self.parameters.neldw
# Swap the spin channels if required
if self._spin_channels_are_swapped:
self._swap_spin_channels()

# Write out screening parameters to file
if self.parameters.get('do_orbdep', False):
Expand All @@ -143,7 +143,11 @@ def calculate(self):
# Autogenerate the nr keywords
self._autogenerate_nr()

super().calculate()
super()._pre_calculate()

def _post_calculate(self):

super()._post_calculate()

# Check spin-up and spin-down eigenvalues match
if 'eigenvalues' in self.results and self.parameters.do_outerloop \
Expand All @@ -154,7 +158,7 @@ def calculate(self):
utils.warn('Spin-up and spin-down eigenvalues differ substantially')

# Swap the spin channels back
if spin_channels_are_swapped:
if self._spin_channels_are_swapped:
self._swap_spin_channels()

def _swap_spin_channels(self):
Expand Down
7 changes: 5 additions & 2 deletions src/koopmans/calculators/_koopmans_ham.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,12 @@ def write_alphas(self):
filling = [True for _ in range(len(alphas))]
utils.write_alpha_file(self.directory, alphas, filling)

def _calculate(self):
def _pre_calculate(self):
super()._pre_calculate()
self.write_alphas()
super()._calculate()

def _post_calculate(self):
super()._post_calculate()
if isinstance(self.parameters.kpts, BandPath) and len(self.parameters.kpts.kpts) > 1:
# Add the bandstructure to the results
self.generate_band_structure()
Expand Down
4 changes: 2 additions & 2 deletions src/koopmans/calculators/_koopmans_screen.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ def __init__(self, atoms: Atoms, *args, **kwargs):

self.command = ParallelCommandWithPostfix(f'kcw.x -in PREFIX{self.ext_in} > PREFIX{self.ext_out} 2>&1')

def calculate(self):
def _pre_calculate(self):
# Check eps infinity
kpoints = [self.parameters.mp1, self.parameters.mp2, self.parameters.mp3]
if np.max(kpoints) > 1 and self.parameters.eps_inf is None:
utils.warn('You have not specified a value for eps_inf. This will mean that the screening parameters will '
'converge very slowly with respect to the k- and q-point grids')

super().calculate()
super()._pre_calculate()

def is_converged(self):
raise NotImplementedError('TODO')
Expand Down
16 changes: 2 additions & 14 deletions src/koopmans/calculators/_ph.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,10 @@ def is_converged(self):
def is_complete(self):
return self.results['job done']

def _calculate(self):
super()._calculate()
def _post_calculate(self):
super()._post_calculate()
if self.parameters.trans:
self.read_dynG()
else:
self.read_stdout()


def read_dynG(self):
with open(self.parameters.fildyn, 'r') as fd:
Expand All @@ -56,12 +53,3 @@ def read_dynG(self):
k = [x.strip() for x in flines].index('Effective Charges E-U: Z_{alpha}{s,beta}')
epsilon = np.array([x.split() for x in flines[i + 2: k - 1]], dtype=float)
self.results['dielectric tensor'] = epsilon

def read_stdout(self):
path=f'{self.prefix}{self.ext_out}'
with open(path, 'r') as fd:
flines = fd.readlines()

i = [x.strip() for x in flines].index('Dielectric constant in cartesian axis')
epsilon = np.array([x.split()[1:4] for x in flines[i + 2: i + 5]], dtype=float)
self.results['dielectric tensor'] = epsilon
7 changes: 5 additions & 2 deletions src/koopmans/calculators/_projwfc.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,15 @@ def __init__(self, atoms: Atoms, *args, **kwargs):
self.pseudo_dir: Optional[Path] = None
self.spin_polarized: Optional[bool] = None

def _calculate(self):
def _pre_calculate(self):
super()._pre_calculate()
for attr in ['pseudopotentials', 'pseudo_dir', 'spin_polarized']:
if not hasattr(self, attr):
raise ValueError(f'Please set {self.__class__.__name__}.{attr} before calling '
f'{self.__class__.__name__.calculate()}')
super()._calculate()

def _post_calculate(self):
super()._post_calculate()
self.generate_dos()

@property
Expand Down
15 changes: 11 additions & 4 deletions src/koopmans/calculators/_pw.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,25 +38,32 @@ def __init__(self, atoms: Atoms, *args, **kwargs):
self.command = ParallelCommandWithPostfix(os.environ.get(
'ASE_ESPRESSO_COMMAND', self.command))

def calculate(self):
def _pre_calculate(self):
# Update ibrav and celldms
if cell_follows_qe_conventions(self.atoms.cell):
self.parameters.update(**cell_to_parameters(self.atoms.cell))
else:
self.parameters.ibrav = 0
super().calculate()

def _calculate(self):
# Make sure kpts has been correctly provided
if self.parameters.calculation == 'bands':
if not isinstance(self.parameters.kpts, BandPath):
raise KeyError('You are running a calculation that requires a kpoint path; please provide a BandPath '
'as the kpts parameter')

super()._pre_calculate()

super()._calculate()
return

def _post_calculate(self):

super()._post_calculate()

if isinstance(self.parameters.kpts, BandPath):
# Add the bandstructure to the results. This is very un-ASE-y and might eventually be replaced
self.generate_band_structure()

return

def is_complete(self):
return self.results.get('job done', False)
Expand Down
32 changes: 26 additions & 6 deletions src/koopmans/calculators/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,24 +128,44 @@ def directory(self, value: Union[Path, str]):
self.parameters.directory = self._directory

def calculate(self):
# Generic function for running a calculation
"""Generic function for running a calculator"""

# First, check the corresponding program is installed
self.check_code_is_installed()
# First run any pre-calculation steps
self._pre_calculate()

# Then call the relevant ASE calculate() function
self._calculate()

# Then check if the calculation completed
if not self.is_complete():
# Then run any post-calculation steps
self._post_calculate()

def _post_calculate(self):
"""Perform any necessary post-calculation steps after running the calculation"""

# Check if the calculation completed
if not self.is_complete():
raise CalculationFailed(
f'{self.directory}/{self.prefix} failed; check the Quantum ESPRESSO output file for more details')

# Then check convergence
# Check convergence
self.check_convergence()

return

def _pre_calculate(self):
"""Perform any necessary pre-calculation steps before running the calculation"""

# By default, check the corresponding program is installed
self.check_code_is_installed()

return

def _calculate(self):
"""Run the calculation using the ASE calculator's calculate() method
This method should NOT be overwritten by child classes. Child classes should only modify _pre_calculate() and
_post_calculate() to perform any necessary pre- and post-calculation steps."""

# ASE expects self.command to be a string
command = copy.deepcopy(self.command)
self.command = str(command)
Expand Down
2 changes: 1 addition & 1 deletion src/koopmans/workflows/_dft.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _run(self):
if calc.parameters.empty_states_maxstep is None:
calc.parameters.empty_states_maxstep = 300

self.run_calculator(calc, enforce_ss=self.parameters.fix_spin_contamination)
self.run_calculator(calc, enforce_spin_symmetry=self.parameters.fix_spin_contamination)

return calc

Expand Down
12 changes: 8 additions & 4 deletions src/koopmans/workflows/_koopmans_dfpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,15 +196,19 @@ def _run(self):
self.bands.alphas = kc_screen_calc.results['alphas']
else:
# If there is orbital grouping, do the orbitals one-by-one

kc_screen_calcs = []
# 1) Create the calculators (in subdirectories)
for band in self.bands.to_solve:
# 1) Create the calculator (in a subdirectory)
kc_screen_calc = self.new_calculator('kc_screen', i_orb=band.index)
kc_screen_calc.directory /= f'band_{band.index}'
kc_screen_calcs.append(kc_screen_calc)

# 2) Run the calculator
self.run_calculator(kc_screen_calc)
# 2) Run the calculators (possibly in parallel)
self.run_calculators(kc_screen_calcs)

# 3) Store the computed screening parameter (accounting for band groupings)
# 3) Store the computed screening parameters (accounting for band groupings)
for band, kc_screen_calc in zip(self.bands.to_solve, kc_screen_calcs):
for b in self.bands:
if b.group == band.group:
alpha = kc_screen_calc.results['alphas'][band.spin]
Expand Down
10 changes: 5 additions & 5 deletions src/koopmans/workflows/_koopmans_dscf.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def perform_initialization(self) -> None:
# to copy the previously calculated Wannier functions
calc = self.new_kcp_calculator('dft_dummy')
calc.directory = Path('init')
self.run_calculator(calc, enforce_ss=False)
self.run_calculator(calc)

# DFT restarting from Wannier functions (after copying the Wannier functions)
calc = self.new_kcp_calculator('dft_init', restart_mode='restart',
Expand Down Expand Up @@ -390,7 +390,7 @@ def perform_initialization(self) -> None:
else:
raise OSError(f'Could not find {evcw_file}')

self.run_calculator(calc, enforce_ss=False)
self.run_calculator(calc)

# Check the consistency between the PW and CP band gaps
pw_calc = [c for c in self.calculations if isinstance(
Expand All @@ -414,7 +414,7 @@ def perform_initialization(self) -> None:
elif self.parameters.functional in ['ki', 'pkipz']:
calc = self.new_kcp_calculator('dft_init')
calc.directory = Path('init')
self.run_calculator(calc, enforce_ss=self.parameters.fix_spin_contamination)
self.run_calculator(calc, enforce_spin_symmetry=self.parameters.fix_spin_contamination)

# Use the KS eigenfunctions as better guesses for the variational orbitals
self._overwrite_canonical_with_variational_orbitals(calc)
Expand All @@ -440,7 +440,7 @@ def perform_initialization(self) -> None:
# DFT from scratch
calc = self.new_kcp_calculator('dft_init')
calc.directory = Path('init')
self.run_calculator(calc, enforce_ss=self.parameters.fix_spin_contamination)
self.run_calculator(calc, enforce_spin_symmetry=self.parameters.fix_spin_contamination)

if self.parameters.init_orbitals == 'kohn-sham':
# Initialize the density with DFT and use the KS eigenfunctions as guesses for the variational orbitals
Expand Down Expand Up @@ -525,7 +525,7 @@ def perform_alpha_calculations(self) -> None:

# Run the calculation and store the result. Note that we only need to continue
# enforcing the spin symmetry if the density will change
self.run_calculator(trial_calc, enforce_ss=self.parameters.fix_spin_contamination and i_sc > 1)
self.run_calculator(trial_calc, enforce_spin_symmetry=self.parameters.fix_spin_contamination and i_sc > 1)

alpha_dep_calcs = [trial_calc]

Expand Down
2 changes: 1 addition & 1 deletion src/koopmans/workflows/_unfold_and_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _run(self) -> None:
calc.spreads = spreads[mask].tolist()

# Run the calculator
self.run_calculator(calc, enforce_ss=False)
self.run_calculator(calc)

# Merge the two calculations to print out the DOS and bands
calc = self.new_ui_calculator('merge')
Expand Down
Loading

0 comments on commit e9b8f14

Please sign in to comment.