diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..89b91da0b --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,30 @@ +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v3.2.0 + hooks: + - id: trailing-whitespace + types: [file, python] + - id: end-of-file-fixer + types: [file, python] + - id: check-yaml + + - repo: https://github.com/PyCQA/isort + rev: 5.13.2 + hooks: + - id: isort + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v0.961 + hooks: + - id: mypy + files: src/koopmans/ + args: [--ignore-missing-imports] + + - repo: https://github.com/pre-commit/mirrors-autopep8 + rev: v1.6.0 + hooks: + - id: autopep8 + args: [--max-line-length=120, -i] + diff --git a/ase/__init__.py b/ase/__init__.py index d2c483b0a..5ca201c19 100644 --- a/ase/__init__.py +++ b/ase/__init__.py @@ -19,7 +19,7 @@ __all__ = ['Atoms', 'Atom'] -__version__ = '0.1.4' +__version__ = '0.1.5' from ase.atom import Atom from ase.atoms import Atoms diff --git a/ase/io/espresso/_koopmans_ham.py b/ase/io/espresso/_koopmans_ham.py index d7c2d5a92..4263a446f 100644 --- a/ase/io/espresso/_koopmans_ham.py +++ b/ase/io/espresso/_koopmans_ham.py @@ -10,7 +10,7 @@ from ase.dft.kpoints import BandPath from ase.utils import basestring from ._utils import construct_kpoints_card, generic_construct_namelist, safe_string_to_list_of_floats, time_to_float, \ - read_fortran_namelist + read_fortran_namelist, dict_to_input_lines from ._wann2kc import KEYS as W2KCW_KEYS from ase.calculators.espresso import KoopmansHam @@ -38,14 +38,7 @@ def write_koopmans_ham_in(fd, atoms, input_data=None, pseudopotentials=None, continue lines.append('&{0}\n'.format(section.upper())) - for key, value in input_parameters[section].items(): - if value is True: - lines.append(' {0:16} = .true.\n'.format(key)) - elif value is False: - lines.append(' {0:16} = .false.\n'.format(key)) - elif value is not None: - # repr format to get quotes around strings - lines.append(' {0:16} = {1!r:}\n'.format(key, value)) + lines += dict_to_input_lines(input_parameters[section]) lines.append('/\n') # terminate section if input_parameters['HAM'].get('do_bands', True): diff --git a/ase/io/espresso/_koopmans_screen.py b/ase/io/espresso/_koopmans_screen.py index 5996da461..c5d154e05 100644 --- a/ase/io/espresso/_koopmans_screen.py +++ b/ase/io/espresso/_koopmans_screen.py @@ -8,7 +8,7 @@ from ase import Atoms from ase.calculators.singlepoint import SinglePointDFTCalculator from ase.utils import basestring -from ._utils import read_fortran_namelist, generic_construct_namelist, time_to_float, units +from ._utils import read_fortran_namelist, generic_construct_namelist, time_to_float, units, dict_to_input_lines from ._wann2kc import KEYS as W2KCW_KEYS from ase.calculators.espresso import KoopmansScreen @@ -34,14 +34,7 @@ def write_koopmans_screen_in(fd, atoms, input_data=None, **kwargs): continue lines.append('&{0}\n'.format(section.upper())) - for key, value in input_parameters[section].items(): - if value is True: - lines.append(' {0:16} = .true.\n'.format(key)) - elif value is False: - lines.append(' {0:16} = .false.\n'.format(key)) - elif value is not None: - # repr format to get quotes around strings - lines.append(' {0:16} = {1!r:}\n'.format(key, value)) + lines += dict_to_input_lines(input_parameters[section]) lines.append('/\n') # terminate section fd.writelines(lines) diff --git a/ase/io/espresso/_ph.py b/ase/io/espresso/_ph.py index 1754959e6..672be74b6 100644 --- a/ase/io/espresso/_ph.py +++ b/ase/io/espresso/_ph.py @@ -6,7 +6,7 @@ from pathlib import Path from ase.utils import basestring from ase.atoms import Atoms -from ._utils import read_fortran_namelist, time_to_float +from ._utils import read_fortran_namelist, time_to_float, dict_to_input_lines from ase.calculators.espresso import EspressoPh @@ -70,16 +70,7 @@ def write_ph_in(fd, atoms, **kwargs): all_parameters = dict(**atoms.calc.parameters, **masses) all_parameters.pop('pseudopotentials', None) - for key, value in all_parameters.items(): - if value is True: - ph.append(' {0:16} = .true.\n'.format(key)) - elif value is False: - ph.append(' {0:16} = .false.\n'.format(key)) - elif value is not None: - if isinstance(value, Path): - value = str(value) - # repr format to get quotes around strings - ph.append(' {0:16} = {1!r:}\n'.format(key, value)) + ph += dict_to_input_lines(all_parameters) ph.append('/\n') ph.append('0.0 0.0 0.0') diff --git a/ase/io/espresso/_projwfc.py b/ase/io/espresso/_projwfc.py index bc0dce558..fb9ce99a5 100644 --- a/ase/io/espresso/_projwfc.py +++ b/ase/io/espresso/_projwfc.py @@ -1,6 +1,6 @@ from ase.atoms import Atoms from ase.calculators.espresso import Projwfc -from ._utils import read_fortran_namelist, time_to_float +from ._utils import read_fortran_namelist, time_to_float, dict_to_input_lines from pathlib import Path @@ -56,16 +56,7 @@ def write_projwfc_in(fd, atoms, **kwargs): """ projwf = ['&projwfc\n'] - for key, value in atoms.calc.parameters.items(): - if value is True: - projwf.append(f' {key:16} = .true.\n') - elif value is False: - projwf.append(f' {key:16} = .false.\n') - elif value is not None: - if isinstance(value, Path): - value = str(value) - # repr format to get quotes around strings - projwf.append(f' {key:16} = {value!r:}\n') + projwf += dict_to_input_lines(atoms.calc.parameters) projwf.append('/\n') fd.write(''.join(projwf)) diff --git a/ase/io/espresso/_utils.py b/ase/io/espresso/_utils.py index a14186491..b59afb459 100644 --- a/ase/io/espresso/_utils.py +++ b/ase/io/espresso/_utils.py @@ -1384,14 +1384,7 @@ def write_espresso_in(fd, atoms, input_data=None, pseudopotentials=None, # and that repr converts to a QE readable representation (except bools) for section in input_parameters: flines.append('&{0}\n'.format(section.upper())) - for key, value in input_parameters[section].items(): - if value is True: - flines.append(' {0:16} = .true.\n'.format(key)) - elif value is False: - flines.append(' {0:16} = .false.\n'.format(key)) - elif value is not None: - # repr format to get quotes around strings - flines.append(' {0:16} = {1!r:}\n'.format(key, value)) + flines += dict_to_input_lines(input_parameters[section]) flines.append('/\n') # terminate section flines.append('\n') @@ -1425,3 +1418,20 @@ def write_espresso_in(fd, atoms, input_data=None, pseudopotentials=None, # DONE! fd.write(''.join(flines)) + + +def dict_to_input_lines(dct): + out = [] + for key, value in dct.items(): + if value is not None: + if isinstance(value, Path): + value = str(value) + if isinstance(value, str): + # add quotes around strings + value = f"'{value}'" + if value is True: + value = '.true.' + if value is False: + value = '.false.' + out.append(f' {key:16} = {value}\n') + return out diff --git a/ase/io/espresso/_wann2kc.py b/ase/io/espresso/_wann2kc.py index 3237db35a..fa0fe6ce3 100644 --- a/ase/io/espresso/_wann2kc.py +++ b/ase/io/espresso/_wann2kc.py @@ -8,13 +8,14 @@ from ase.atoms import Atoms from ase.calculators.singlepoint import SinglePointDFTCalculator from ase.utils import basestring -from ._utils import Namelist, read_fortran_namelist, generic_construct_namelist, time_to_float +from ._utils import Namelist, read_fortran_namelist, generic_construct_namelist, time_to_float, dict_to_input_lines from ase.calculators.espresso import Wann2KC KEYS = Namelist(( ('control', ['prefix', 'outdir', 'kcw_iverbosity', 'kcw_at_ks', 'calculation', 'lrpa', - 'mp1', 'mp2', 'mp3', 'homo_only', 'read_unitary_matrix', 'l_vcut', 'assume_isolated', 'spin_component']), + 'mp1', 'mp2', 'mp3', 'homo_only', 'read_unitary_matrix', 'l_vcut', 'assume_isolated', + 'spin_component']), ('wannier', ['seedname', 'check_ks', 'num_wann_occ', 'num_wann_emp', 'have_empty', 'has_disentangle']))) @@ -32,14 +33,7 @@ def write_wann2kc_in(fd, atoms, input_data=None, pseudopotentials=None, for section in input_parameters: assert section in KEYS.keys() lines.append('&{0}\n'.format(section.upper())) - for key, value in input_parameters[section].items(): - if value is True: - lines.append(' {0:16} = .true.\n'.format(key)) - elif value is False: - lines.append(' {0:16} = .false.\n'.format(key)) - elif value is not None: - # repr format to get quotes around strings - lines.append(' {0:16} = {1!r:}\n'.format(key, value)) + lines += dict_to_input_lines(input_parameters[section]) lines.append('/\n') # terminate section fd.writelines(lines) diff --git a/ase/io/espresso/_x2y.py b/ase/io/espresso/_x2y.py index 588e5fd30..d4f7a5f8e 100644 --- a/ase/io/espresso/_x2y.py +++ b/ase/io/espresso/_x2y.py @@ -5,7 +5,7 @@ from pathlib import Path from ase.utils import basestring from ase.atoms import Atoms -from ._utils import read_fortran_namelist, time_to_float +from ._utils import read_fortran_namelist, time_to_float, dict_to_input_lines def read_x2y_in(fileobj, calc_class): @@ -60,16 +60,7 @@ def write_x2y_in(fd, atoms, **kwargs): """ x2y = ['&inputpp\n'] - for key, value in atoms.calc.parameters.items(): - if value is True: - x2y.append(f' {key:16} = .true.\n') - elif value is False: - x2y.append(f' {key:16} = .false.\n') - elif value is not None: - if isinstance(value, Path): - value = str(value) - # repr format to get quotes around strings - x2y.append(f' {key:16} = {value!r:}\n') + x2y += dict_to_input_lines(atoms.calc.parameters) x2y.append('/\n') fd.write(''.join(x2y)) diff --git a/ase/spacegroup/xtal.py b/ase/spacegroup/xtal.py index 594774239..6dbc24b65 100644 --- a/ase/spacegroup/xtal.py +++ b/ase/spacegroup/xtal.py @@ -11,9 +11,9 @@ from scipy import spatial import ase -from ase.symbols import string2symbols -from ase.spacegroup import Spacegroup from ase.geometry import cellpar_to_cell +from ase.spacegroup import Spacegroup +from ase.symbols import string2symbols __all__ = ['crystal'] @@ -103,9 +103,9 @@ def crystal(symbols=None, basis=None, occupancies=None, spacegroup=1, setting=1, """ sg = Spacegroup(spacegroup, setting) if (not isinstance(symbols, str) and - hasattr(symbols, '__getitem__') and - len(symbols) > 0 and - isinstance(symbols[0], ase.Atom)): + hasattr(symbols, '__getitem__') and + len(symbols) > 0 and + isinstance(symbols[0], ase.Atom)): symbols = ase.Atoms(symbols) if isinstance(symbols, ase.Atoms): basis = symbols @@ -117,25 +117,25 @@ def crystal(symbols=None, basis=None, occupancies=None, spacegroup=1, setting=1, if symbols is None: symbols = basis.get_chemical_symbols() else: - basis_coords = np.array(basis, dtype=float, copy=False, ndmin=2) + basis_coords = np.array(basis, dtype=float, ndmin=2) if occupancies is not None: occupancies_dict = {} - + for index, coord in enumerate(basis_coords): # Compute all distances and get indices of nearest atoms dist = spatial.distance.cdist(coord.reshape(1, 3), basis_coords) indices_dist = np.flatnonzero(dist < symprec) - + occ = {symbols[index]: occupancies[index]} - + # Check nearest and update occupancy for index_dist in indices_dist: if index == index_dist: continue else: occ.update({symbols[index_dist]: occupancies[index_dist]}) - + occupancies_dict[index] = occ.copy() sites, kinds = sg.equivalent_sites(basis_coords, @@ -154,7 +154,7 @@ def crystal(symbols=None, basis=None, occupancies=None, spacegroup=1, setting=1, symbols = [symbols[i] for i in kinds] else: # make sure that we put the dominant species there - symbols = [sorted(occupancies_dict[i].items(), key=lambda x : x[1])[-1][0] for i in kinds] + symbols = [sorted(occupancies_dict[i].items(), key=lambda x: x[1])[-1][0] for i in kinds] if cell is None: cell = cellpar_to_cell(cellpar, ab_normal, a_direction) @@ -164,7 +164,7 @@ def crystal(symbols=None, basis=None, occupancies=None, spacegroup=1, setting=1, info['unit_cell'] = 'primitive' else: info['unit_cell'] = 'conventional' - + if 'info' in kwargs: info.update(kwargs['info']) @@ -186,7 +186,7 @@ def crystal(symbols=None, basis=None, occupancies=None, spacegroup=1, setting=1, array = basis.get_array(name) atoms.new_array(name, [array[i] for i in kinds], dtype=array.dtype, shape=array.shape[1:]) - + if kinds: atoms.new_array('spacegroup_kinds', np.asarray(kinds, dtype=int))