diff --git a/src/koopmans/calculators/_environ.py b/src/koopmans/calculators/_environ.py index ba97cb9a8..6263d6c83 100644 --- a/src/koopmans/calculators/_environ.py +++ b/src/koopmans/calculators/_environ.py @@ -26,10 +26,9 @@ def __init__(self, *args, **kwargs): # Add dictionary of environ settings self.environ_settings = copy.deepcopy(_default_settings) - def calculate(self): - # Generic function for running a calculation + def _pre_calculate(self): self.write_environ_in() - super().calculate() + super()._pre_calculate() def set_environ_settings(self, settings, use_defaults=True): self.environ_settings = settings diff --git a/src/koopmans/calculators/_koopmans_cp.py b/src/koopmans/calculators/_koopmans_cp.py index 6f2fe28f2..f90c89e46 100644 --- a/src/koopmans/calculators/_koopmans_cp.py +++ b/src/koopmans/calculators/_koopmans_cp.py @@ -119,16 +119,16 @@ def __init__(self, atoms: Atoms, alphas: Optional[List[List[float]]] = None, # koopmans.workflows._koopmans_dscf.py for more details) self.fixed_band: Optional[bands.Band] = None - def calculate(self): + # Create a private attribute to keep track of whether the spin channels have been swapped + self._spin_channels_are_swapped: bool = False + + def _pre_calculate(self): # kcp.x imposes nelup >= neldw, so if we try to run a calcualtion with neldw > nelup, swap the spin channels if self.parameters.nspin == 2: - spin_channels_are_swapped = self.parameters.nelup < self.parameters.neldw - else: - spin_channels_are_swapped = False - - # Swap the spin channels - if spin_channels_are_swapped: - self._swap_spin_channels() + self._spin_channels_are_swapped = self.parameters.nelup < self.parameters.neldw + # Swap the spin channels if required + if self._spin_channels_are_swapped: + self._swap_spin_channels() # Write out screening parameters to file if self.parameters.get('do_orbdep', False): @@ -143,7 +143,11 @@ def calculate(self): # Autogenerate the nr keywords self._autogenerate_nr() - super().calculate() + super()._pre_calculate() + + def _post_calculate(self): + + super()._post_calculate() # Check spin-up and spin-down eigenvalues match if 'eigenvalues' in self.results and self.parameters.do_outerloop \ @@ -154,7 +158,7 @@ def calculate(self): utils.warn('Spin-up and spin-down eigenvalues differ substantially') # Swap the spin channels back - if spin_channels_are_swapped: + if self._spin_channels_are_swapped: self._swap_spin_channels() def _swap_spin_channels(self): diff --git a/src/koopmans/calculators/_koopmans_ham.py b/src/koopmans/calculators/_koopmans_ham.py index ebe9ed4f9..babfe2736 100644 --- a/src/koopmans/calculators/_koopmans_ham.py +++ b/src/koopmans/calculators/_koopmans_ham.py @@ -49,9 +49,12 @@ def write_alphas(self): filling = [True for _ in range(len(alphas))] utils.write_alpha_file(self.directory, alphas, filling) - def _calculate(self): + def _pre_calculate(self): + super()._pre_calculate() self.write_alphas() - super()._calculate() + + def _post_calculate(self): + super()._post_calculate() if isinstance(self.parameters.kpts, BandPath) and len(self.parameters.kpts.kpts) > 1: # Add the bandstructure to the results self.generate_band_structure() diff --git a/src/koopmans/calculators/_koopmans_screen.py b/src/koopmans/calculators/_koopmans_screen.py index 4ad3da277..999e8ce06 100644 --- a/src/koopmans/calculators/_koopmans_screen.py +++ b/src/koopmans/calculators/_koopmans_screen.py @@ -34,14 +34,14 @@ def __init__(self, atoms: Atoms, *args, **kwargs): self.command = ParallelCommandWithPostfix(f'kcw.x -in PREFIX{self.ext_in} > PREFIX{self.ext_out} 2>&1') - def calculate(self): + def _pre_calculate(self): # Check eps infinity kpoints = [self.parameters.mp1, self.parameters.mp2, self.parameters.mp3] if np.max(kpoints) > 1 and self.parameters.eps_inf is None: utils.warn('You have not specified a value for eps_inf. This will mean that the screening parameters will ' 'converge very slowly with respect to the k- and q-point grids') - super().calculate() + super()._pre_calculate() def is_converged(self): raise NotImplementedError('TODO') diff --git a/src/koopmans/calculators/_ph.py b/src/koopmans/calculators/_ph.py index 17388cf60..7b0ff0799 100644 --- a/src/koopmans/calculators/_ph.py +++ b/src/koopmans/calculators/_ph.py @@ -40,13 +40,10 @@ def is_converged(self): def is_complete(self): return self.results['job done'] - def _calculate(self): - super()._calculate() + def _post_calculate(self): + super()._post_calculate() if self.parameters.trans: self.read_dynG() - else: - self.read_stdout() - def read_dynG(self): with open(self.parameters.fildyn, 'r') as fd: @@ -56,12 +53,3 @@ def read_dynG(self): k = [x.strip() for x in flines].index('Effective Charges E-U: Z_{alpha}{s,beta}') epsilon = np.array([x.split() for x in flines[i + 2: k - 1]], dtype=float) self.results['dielectric tensor'] = epsilon - - def read_stdout(self): - path=f'{self.prefix}{self.ext_out}' - with open(path, 'r') as fd: - flines = fd.readlines() - - i = [x.strip() for x in flines].index('Dielectric constant in cartesian axis') - epsilon = np.array([x.split()[1:4] for x in flines[i + 2: i + 5]], dtype=float) - self.results['dielectric tensor'] = epsilon diff --git a/src/koopmans/calculators/_projwfc.py b/src/koopmans/calculators/_projwfc.py index 1338cdbd1..2ffb75232 100644 --- a/src/koopmans/calculators/_projwfc.py +++ b/src/koopmans/calculators/_projwfc.py @@ -49,12 +49,15 @@ def __init__(self, atoms: Atoms, *args, **kwargs): self.pseudo_dir: Optional[Path] = None self.spin_polarized: Optional[bool] = None - def _calculate(self): + def _pre_calculate(self): + super()._pre_calculate() for attr in ['pseudopotentials', 'pseudo_dir', 'spin_polarized']: if not hasattr(self, attr): raise ValueError(f'Please set {self.__class__.__name__}.{attr} before calling ' f'{self.__class__.__name__.calculate()}') - super()._calculate() + + def _post_calculate(self): + super()._post_calculate() self.generate_dos() @property diff --git a/src/koopmans/calculators/_pw.py b/src/koopmans/calculators/_pw.py index 60b14575c..c5ce88102 100644 --- a/src/koopmans/calculators/_pw.py +++ b/src/koopmans/calculators/_pw.py @@ -38,25 +38,32 @@ def __init__(self, atoms: Atoms, *args, **kwargs): self.command = ParallelCommandWithPostfix(os.environ.get( 'ASE_ESPRESSO_COMMAND', self.command)) - def calculate(self): + def _pre_calculate(self): # Update ibrav and celldms if cell_follows_qe_conventions(self.atoms.cell): self.parameters.update(**cell_to_parameters(self.atoms.cell)) else: self.parameters.ibrav = 0 - super().calculate() - def _calculate(self): + # Make sure kpts has been correctly provided if self.parameters.calculation == 'bands': if not isinstance(self.parameters.kpts, BandPath): raise KeyError('You are running a calculation that requires a kpoint path; please provide a BandPath ' 'as the kpts parameter') + + super()._pre_calculate() - super()._calculate() + return + + def _post_calculate(self): + + super()._post_calculate() if isinstance(self.parameters.kpts, BandPath): # Add the bandstructure to the results. This is very un-ASE-y and might eventually be replaced self.generate_band_structure() + + return def is_complete(self): return self.results.get('job done', False) diff --git a/src/koopmans/calculators/_utils.py b/src/koopmans/calculators/_utils.py index ced8a8aac..39ef7562e 100644 --- a/src/koopmans/calculators/_utils.py +++ b/src/koopmans/calculators/_utils.py @@ -128,24 +128,44 @@ def directory(self, value: Union[Path, str]): self.parameters.directory = self._directory def calculate(self): - # Generic function for running a calculation + """Generic function for running a calculator""" - # First, check the corresponding program is installed - self.check_code_is_installed() + # First run any pre-calculation steps + self._pre_calculate() # Then call the relevant ASE calculate() function self._calculate() - # Then check if the calculation completed - if not self.is_complete(): + # Then run any post-calculation steps + self._post_calculate() + + def _post_calculate(self): + """Perform any necessary post-calculation steps after running the calculation""" + # Check if the calculation completed + if not self.is_complete(): raise CalculationFailed( f'{self.directory}/{self.prefix} failed; check the Quantum ESPRESSO output file for more details') - # Then check convergence + # Check convergence self.check_convergence() + return + + def _pre_calculate(self): + """Perform any necessary pre-calculation steps before running the calculation""" + + # By default, check the corresponding program is installed + self.check_code_is_installed() + + return + def _calculate(self): + """Run the calculation using the ASE calculator's calculate() method + + This method should NOT be overwritten by child classes. Child classes should only modify _pre_calculate() and + _post_calculate() to perform any necessary pre- and post-calculation steps.""" + # ASE expects self.command to be a string command = copy.deepcopy(self.command) self.command = str(command) diff --git a/src/koopmans/workflows/_dft.py b/src/koopmans/workflows/_dft.py index d342284a1..db7bfa9c9 100644 --- a/src/koopmans/workflows/_dft.py +++ b/src/koopmans/workflows/_dft.py @@ -53,7 +53,7 @@ def _run(self): if calc.parameters.empty_states_maxstep is None: calc.parameters.empty_states_maxstep = 300 - self.run_calculator(calc, enforce_ss=self.parameters.fix_spin_contamination) + self.run_calculator(calc, enforce_spin_symmetry=self.parameters.fix_spin_contamination) return calc diff --git a/src/koopmans/workflows/_koopmans_dfpt.py b/src/koopmans/workflows/_koopmans_dfpt.py index 1363a5d6f..9131208cf 100644 --- a/src/koopmans/workflows/_koopmans_dfpt.py +++ b/src/koopmans/workflows/_koopmans_dfpt.py @@ -196,15 +196,19 @@ def _run(self): self.bands.alphas = kc_screen_calc.results['alphas'] else: # If there is orbital grouping, do the orbitals one-by-one + + kc_screen_calcs = [] + # 1) Create the calculators (in subdirectories) for band in self.bands.to_solve: - # 1) Create the calculator (in a subdirectory) kc_screen_calc = self.new_calculator('kc_screen', i_orb=band.index) kc_screen_calc.directory /= f'band_{band.index}' + kc_screen_calcs.append(kc_screen_calc) - # 2) Run the calculator - self.run_calculator(kc_screen_calc) + # 2) Run the calculators (possibly in parallel) + self.run_calculators(kc_screen_calcs) - # 3) Store the computed screening parameter (accounting for band groupings) + # 3) Store the computed screening parameters (accounting for band groupings) + for band, kc_screen_calc in zip(self.bands.to_solve, kc_screen_calcs): for b in self.bands: if b.group == band.group: alpha = kc_screen_calc.results['alphas'][band.spin] diff --git a/src/koopmans/workflows/_koopmans_dscf.py b/src/koopmans/workflows/_koopmans_dscf.py index da01bf5fd..e0cced0c3 100644 --- a/src/koopmans/workflows/_koopmans_dscf.py +++ b/src/koopmans/workflows/_koopmans_dscf.py @@ -352,7 +352,7 @@ def perform_initialization(self) -> None: # to copy the previously calculated Wannier functions calc = self.new_kcp_calculator('dft_dummy') calc.directory = Path('init') - self.run_calculator(calc, enforce_ss=False) + self.run_calculator(calc) # DFT restarting from Wannier functions (after copying the Wannier functions) calc = self.new_kcp_calculator('dft_init', restart_mode='restart', @@ -390,7 +390,7 @@ def perform_initialization(self) -> None: else: raise OSError(f'Could not find {evcw_file}') - self.run_calculator(calc, enforce_ss=False) + self.run_calculator(calc) # Check the consistency between the PW and CP band gaps pw_calc = [c for c in self.calculations if isinstance( @@ -414,7 +414,7 @@ def perform_initialization(self) -> None: elif self.parameters.functional in ['ki', 'pkipz']: calc = self.new_kcp_calculator('dft_init') calc.directory = Path('init') - self.run_calculator(calc, enforce_ss=self.parameters.fix_spin_contamination) + self.run_calculator(calc, enforce_spin_symmetry=self.parameters.fix_spin_contamination) # Use the KS eigenfunctions as better guesses for the variational orbitals self._overwrite_canonical_with_variational_orbitals(calc) @@ -440,7 +440,7 @@ def perform_initialization(self) -> None: # DFT from scratch calc = self.new_kcp_calculator('dft_init') calc.directory = Path('init') - self.run_calculator(calc, enforce_ss=self.parameters.fix_spin_contamination) + self.run_calculator(calc, enforce_spin_symmetry=self.parameters.fix_spin_contamination) if self.parameters.init_orbitals == 'kohn-sham': # Initialize the density with DFT and use the KS eigenfunctions as guesses for the variational orbitals @@ -525,7 +525,7 @@ def perform_alpha_calculations(self) -> None: # Run the calculation and store the result. Note that we only need to continue # enforcing the spin symmetry if the density will change - self.run_calculator(trial_calc, enforce_ss=self.parameters.fix_spin_contamination and i_sc > 1) + self.run_calculator(trial_calc, enforce_spin_symmetry=self.parameters.fix_spin_contamination and i_sc > 1) alpha_dep_calcs = [trial_calc] diff --git a/src/koopmans/workflows/_unfold_and_interp.py b/src/koopmans/workflows/_unfold_and_interp.py index 77745e0cc..a72783b1e 100644 --- a/src/koopmans/workflows/_unfold_and_interp.py +++ b/src/koopmans/workflows/_unfold_and_interp.py @@ -95,7 +95,7 @@ def _run(self) -> None: calc.spreads = spreads[mask].tolist() # Run the calculator - self.run_calculator(calc, enforce_ss=False) + self.run_calculator(calc) # Merge the two calculations to print out the DOS and bands calc = self.new_ui_calculator('merge') diff --git a/src/koopmans/workflows/_workflow.py b/src/koopmans/workflows/_workflow.py index 8f27635d0..97a26ba04 100644 --- a/src/koopmans/workflows/_workflow.py +++ b/src/koopmans/workflows/_workflow.py @@ -716,58 +716,65 @@ def supercell_to_primitive(self, matrix: Optional[npt.NDArray[np.int_]] = None): self.atoms = self.atoms[mask] - def run_calculator(self, master_qe_calc: calculators.Calc, enforce_ss=False): - ''' - Wrapper for run_calculator_single that manages the optional enforcing of spin symmetry + def run_calculator(self, master_calc: calculators.Calc, enforce_spin_symmetry: bool = False): + ''' Run a calculator. + + If enforce_spin_symmetry is True, the calculation will be run with spin symmetry enforced. + Ultimately this wraps self.run_calculators + + :param master_calc: the calculator to run + :param enforce_spin_symmetry: whether to enforce spin symmetry ''' - if enforce_ss: - if not isinstance(master_qe_calc, calculators.CalculatorCanEnforceSpinSym): - raise NotImplementedError(f'{master_qe_calc.__class__.__name__} cannot enforce spin symmetry') + if enforce_spin_symmetry: + if not isinstance(master_calc, calculators.CalculatorCanEnforceSpinSym): + raise NotImplementedError(f'{master_calc.__class__.__name__} cannot enforce spin symmetry') - if not master_qe_calc.from_scratch: + if not master_calc.from_scratch: # PBE with nspin=1 dummy - qe_calc = master_qe_calc.nspin1_dummy_calculator() + qe_calc = master_calc.nspin1_dummy_calculator() qe_calc.skip_qc = True - self.run_calculator_single(qe_calc) + self.run_calculator(qe_calc) # Copy over nspin=2 wavefunction to nspin=1 tmp directory (if it has not been done already) if self.parameters.from_scratch: - master_qe_calc.convert_wavefunction_2to1() + master_calc.convert_wavefunction_2to1() # PBE with nspin=1 - qe_calc = master_qe_calc.nspin1_calculator() - self.run_calculator_single(qe_calc) + qe_calc = master_calc.nspin1_calculator() + self.run_calculator(qe_calc) # PBE from scratch with nspin=2 (dummy run for creating files of appropriate size) - qe_calc = master_qe_calc.nspin2_dummy_calculator() + qe_calc = master_calc.nspin2_dummy_calculator() qe_calc.skip_qc = True - self.run_calculator_single(qe_calc) + self.run_calculator(qe_calc) # Copy over nspin=1 wavefunction to nspin=2 tmp directory (if it has not been done already) if self.parameters.from_scratch: - master_qe_calc.convert_wavefunction_1to2() + master_calc.convert_wavefunction_1to2() # PBE with nspin=2, reading in the spin-symmetric nspin=1 wavefunction - master_qe_calc.prepare_to_read_nspin1() - self.run_calculator_single(master_qe_calc) + master_calc.prepare_to_read_nspin1() + self.run_calculator(master_calc) else: - self.run_calculator_single(master_qe_calc) + self.run_calculators([master_calc]) return - def run_calculator_single(self, qe_calc: calculators.Calc): - # Runs qe_calc.calculate with additional checks + def _pre_run_calculator(self, qe_calc: calculators.Calc) -> bool: + """Perform operations that need to occur before a calculation is run + + :param qe_calc: The calculator to run + :return: Whether the calculation should be run + """ # If an output file already exists, check if the run completed successfully - verb = 'Running' if not self.parameters.from_scratch: calc_file = qe_calc.directory / qe_calc.prefix if calc_file.with_suffix(qe_calc.ext_out).is_file(): - verb = 'Rerunning' is_complete = self.load_old_calculator(qe_calc) if is_complete: @@ -785,37 +792,72 @@ def run_calculator_single(self, qe_calc: calculators.Calc): if isinstance(qe_calc, calculators.PhCalculator): qe_calc.read_dynG() - return - - if not self.silent: - dir_str = os.path.relpath(qe_calc.directory) + '/' - self.print(f'{verb} {dir_str}{qe_calc.prefix}...', end='', flush=True) + return False # Update postfix if relevant if self.parameters.npool: if isinstance(qe_calc.command, ParallelCommandWithPostfix): qe_calc.command.postfix = f'-npool {self.parameters.npool}' - try: - qe_calc.calculate() - except CalculationFailed: - self.print(' failed') - raise + return True + + def _run_calculators(self, calcs: List[calculators.Calc]) -> None: + """Run a list of calculators, without doing anything else (other than printing messages.) + + Any pre- or post-processing of each calculator should have been done in _pre_run_calculator and _post_run_calculator. + + :param calcs: The calculators to run + """ + + for calc in calcs: + verb = "Running" if self.parameters.from_scratch else "Rerunning" - if not self.silent: - self.print(' done') + if not self.silent: + dir_str = os.path.relpath(calc.directory) + '/' + self.print(f'{verb} {dir_str}{calc.prefix}...', end='', flush=True) + try: + calc.calculate() + except CalculationFailed: + self.print(' failed') + raise + + if not self.silent: + self.print(' done') + + return + + def _post_run_calculator(self, calc: calculators.Calc) -> None: + """ + Perform any operations that need to be performed after a calculation has run. + """ # Store the calculator - self.calculations.append(qe_calc) + self.calculations.append(calc) # Ensure we inherit any modifications made to the atoms object - if qe_calc.atoms != self.atoms: - self.atoms = qe_calc.atoms + if calc.atoms != self.atoms: + self.atoms = calc.atoms # If we reached here, all future calculations should be performed from scratch self.parameters.from_scratch = True return + + def run_calculators(self, calcs: List[calculators.Calc]): + ''' + Run a list of *independent* calculators (default implementation is to run them in sequence) + ''' + + calcs_to_run = [] + for calc in calcs: + proceed = self._pre_run_calculator(calc) + if proceed: + calcs_to_run.append(calc) + + self._run_calculators(calcs_to_run) + + for calc in calcs_to_run: + self._post_run_calculator(calc) def load_old_calculator(self, qe_calc: calculators.Calc) -> bool: # This is a separate function so that it can be monkeypatched by the test suite diff --git a/tests/helpers/patches/_check.py b/tests/helpers/patches/_check.py index 5bb24ac19..5181d83da 100644 --- a/tests/helpers/patches/_check.py +++ b/tests/helpers/patches/_check.py @@ -166,15 +166,23 @@ def _print_messages(self, messages: List[Dict[str, str]]) -> None: message = message.replace('disagreements', 'disagreement') raise CalculationFailed(message) - - def _calculate(self): - # Before running the calculation, check the settings are the same - + + def _load_benchmark(self) -> Calc: with utils.chdir(self.directory): # type: ignore[attr-defined] # By moving into the directory where the calculation was run, we ensure when we read in the settings that # paths are interpreted relative to this particular working directory with open(benchmark_filename(self), 'r') as fd: benchmark = read_encoded_json(fd) + return benchmark + + def _pre_calculate(self): + """Before running the calculation, check the settings are the same""" + + # Perform the pre_calculate first, as sometimes this function modifies the input parameters + super()._pre_calculate() + + # Load the benchmark + benchmark = self._load_benchmark() # Compare the settings unique_keys: Set[str] = set(list(self.parameters.keys()) + list(benchmark.parameters.keys())) @@ -208,8 +216,9 @@ def _calculate(self): # Check that the right files exist # TODO - # Run the calculation - super()._calculate() + def _post_calculate(self): + # Perform the post_calculate first, as sometimes this function adds extra entries to self.results + super()._post_calculate() # Check the results if self.skip_qc: @@ -218,6 +227,7 @@ def _calculate(self): # the corresponding workflow pass else: + benchmark = self._load_benchmark() self._check_results(benchmark) # Check the expected files were produced