Skip to content

Commit

Permalink
Change from using repr to str for printing the values in input files (#…
Browse files Browse the repository at this point in the history
…61)

… because repr causes issues with numpy v2
  • Loading branch information
elinscott authored Jun 18, 2024
1 parent cc20d49 commit 6811c59
Show file tree
Hide file tree
Showing 10 changed files with 76 additions and 83 deletions.
30 changes: 30 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.2.0
hooks:
- id: trailing-whitespace
types: [file, python]
- id: end-of-file-fixer
types: [file, python]
- id: check-yaml

- repo: https://github.com/PyCQA/isort
rev: 5.13.2
hooks:
- id: isort

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.961
hooks:
- id: mypy
files: src/koopmans/
args: [--ignore-missing-imports]

- repo: https://github.com/pre-commit/mirrors-autopep8
rev: v1.6.0
hooks:
- id: autopep8
args: [--max-line-length=120, -i]

2 changes: 1 addition & 1 deletion ase/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

__all__ = ['Atoms', 'Atom']

__version__ = '0.1.4'
__version__ = '0.1.5'

from ase.atom import Atom
from ase.atoms import Atoms
Expand Down
11 changes: 2 additions & 9 deletions ase/io/espresso/_koopmans_ham.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ase.dft.kpoints import BandPath
from ase.utils import basestring
from ._utils import construct_kpoints_card, generic_construct_namelist, safe_string_to_list_of_floats, time_to_float, \
read_fortran_namelist
read_fortran_namelist, dict_to_input_lines
from ._wann2kc import KEYS as W2KCW_KEYS

from ase.calculators.espresso import KoopmansHam
Expand Down Expand Up @@ -38,14 +38,7 @@ def write_koopmans_ham_in(fd, atoms, input_data=None, pseudopotentials=None,
continue

lines.append('&{0}\n'.format(section.upper()))
for key, value in input_parameters[section].items():
if value is True:
lines.append(' {0:16} = .true.\n'.format(key))
elif value is False:
lines.append(' {0:16} = .false.\n'.format(key))
elif value is not None:
# repr format to get quotes around strings
lines.append(' {0:16} = {1!r:}\n'.format(key, value))
lines += dict_to_input_lines(input_parameters[section])
lines.append('/\n') # terminate section

if input_parameters['HAM'].get('do_bands', True):
Expand Down
11 changes: 2 additions & 9 deletions ase/io/espresso/_koopmans_screen.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ase import Atoms
from ase.calculators.singlepoint import SinglePointDFTCalculator
from ase.utils import basestring
from ._utils import read_fortran_namelist, generic_construct_namelist, time_to_float, units
from ._utils import read_fortran_namelist, generic_construct_namelist, time_to_float, units, dict_to_input_lines
from ._wann2kc import KEYS as W2KCW_KEYS
from ase.calculators.espresso import KoopmansScreen

Expand All @@ -34,14 +34,7 @@ def write_koopmans_screen_in(fd, atoms, input_data=None, **kwargs):
continue

lines.append('&{0}\n'.format(section.upper()))
for key, value in input_parameters[section].items():
if value is True:
lines.append(' {0:16} = .true.\n'.format(key))
elif value is False:
lines.append(' {0:16} = .false.\n'.format(key))
elif value is not None:
# repr format to get quotes around strings
lines.append(' {0:16} = {1!r:}\n'.format(key, value))
lines += dict_to_input_lines(input_parameters[section])
lines.append('/\n') # terminate section

fd.writelines(lines)
Expand Down
13 changes: 2 additions & 11 deletions ase/io/espresso/_ph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pathlib import Path
from ase.utils import basestring
from ase.atoms import Atoms
from ._utils import read_fortran_namelist, time_to_float
from ._utils import read_fortran_namelist, time_to_float, dict_to_input_lines
from ase.calculators.espresso import EspressoPh


Expand Down Expand Up @@ -70,16 +70,7 @@ def write_ph_in(fd, atoms, **kwargs):
all_parameters = dict(**atoms.calc.parameters, **masses)
all_parameters.pop('pseudopotentials', None)

for key, value in all_parameters.items():
if value is True:
ph.append(' {0:16} = .true.\n'.format(key))
elif value is False:
ph.append(' {0:16} = .false.\n'.format(key))
elif value is not None:
if isinstance(value, Path):
value = str(value)
# repr format to get quotes around strings
ph.append(' {0:16} = {1!r:}\n'.format(key, value))
ph += dict_to_input_lines(all_parameters)
ph.append('/\n')
ph.append('0.0 0.0 0.0')

Expand Down
13 changes: 2 additions & 11 deletions ase/io/espresso/_projwfc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from ase.atoms import Atoms
from ase.calculators.espresso import Projwfc
from ._utils import read_fortran_namelist, time_to_float
from ._utils import read_fortran_namelist, time_to_float, dict_to_input_lines
from pathlib import Path


Expand Down Expand Up @@ -56,16 +56,7 @@ def write_projwfc_in(fd, atoms, **kwargs):
"""

projwf = ['&projwfc\n']
for key, value in atoms.calc.parameters.items():
if value is True:
projwf.append(f' {key:16} = .true.\n')
elif value is False:
projwf.append(f' {key:16} = .false.\n')
elif value is not None:
if isinstance(value, Path):
value = str(value)
# repr format to get quotes around strings
projwf.append(f' {key:16} = {value!r:}\n')
projwf += dict_to_input_lines(atoms.calc.parameters)
projwf.append('/\n')

fd.write(''.join(projwf))
Expand Down
26 changes: 18 additions & 8 deletions ase/io/espresso/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1384,14 +1384,7 @@ def write_espresso_in(fd, atoms, input_data=None, pseudopotentials=None,
# and that repr converts to a QE readable representation (except bools)
for section in input_parameters:
flines.append('&{0}\n'.format(section.upper()))
for key, value in input_parameters[section].items():
if value is True:
flines.append(' {0:16} = .true.\n'.format(key))
elif value is False:
flines.append(' {0:16} = .false.\n'.format(key))
elif value is not None:
# repr format to get quotes around strings
flines.append(' {0:16} = {1!r:}\n'.format(key, value))
flines += dict_to_input_lines(input_parameters[section])
flines.append('/\n') # terminate section
flines.append('\n')

Expand Down Expand Up @@ -1425,3 +1418,20 @@ def write_espresso_in(fd, atoms, input_data=None, pseudopotentials=None,

# DONE!
fd.write(''.join(flines))


def dict_to_input_lines(dct):
out = []
for key, value in dct.items():
if value is not None:
if isinstance(value, Path):
value = str(value)
if isinstance(value, str):
# add quotes around strings
value = f"'{value}'"
if value is True:
value = '.true.'
if value is False:
value = '.false.'
out.append(f' {key:16} = {value}\n')
return out
14 changes: 4 additions & 10 deletions ase/io/espresso/_wann2kc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
from ase.atoms import Atoms
from ase.calculators.singlepoint import SinglePointDFTCalculator
from ase.utils import basestring
from ._utils import Namelist, read_fortran_namelist, generic_construct_namelist, time_to_float
from ._utils import Namelist, read_fortran_namelist, generic_construct_namelist, time_to_float, dict_to_input_lines
from ase.calculators.espresso import Wann2KC


KEYS = Namelist((
('control', ['prefix', 'outdir', 'kcw_iverbosity', 'kcw_at_ks', 'calculation', 'lrpa',
'mp1', 'mp2', 'mp3', 'homo_only', 'read_unitary_matrix', 'l_vcut', 'assume_isolated', 'spin_component']),
'mp1', 'mp2', 'mp3', 'homo_only', 'read_unitary_matrix', 'l_vcut', 'assume_isolated',
'spin_component']),
('wannier', ['seedname', 'check_ks', 'num_wann_occ', 'num_wann_emp', 'have_empty', 'has_disentangle'])))


Expand All @@ -32,14 +33,7 @@ def write_wann2kc_in(fd, atoms, input_data=None, pseudopotentials=None,
for section in input_parameters:
assert section in KEYS.keys()
lines.append('&{0}\n'.format(section.upper()))
for key, value in input_parameters[section].items():
if value is True:
lines.append(' {0:16} = .true.\n'.format(key))
elif value is False:
lines.append(' {0:16} = .false.\n'.format(key))
elif value is not None:
# repr format to get quotes around strings
lines.append(' {0:16} = {1!r:}\n'.format(key, value))
lines += dict_to_input_lines(input_parameters[section])
lines.append('/\n') # terminate section

fd.writelines(lines)
Expand Down
13 changes: 2 additions & 11 deletions ase/io/espresso/_x2y.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pathlib import Path
from ase.utils import basestring
from ase.atoms import Atoms
from ._utils import read_fortran_namelist, time_to_float
from ._utils import read_fortran_namelist, time_to_float, dict_to_input_lines


def read_x2y_in(fileobj, calc_class):
Expand Down Expand Up @@ -60,16 +60,7 @@ def write_x2y_in(fd, atoms, **kwargs):
"""

x2y = ['&inputpp\n']
for key, value in atoms.calc.parameters.items():
if value is True:
x2y.append(f' {key:16} = .true.\n')
elif value is False:
x2y.append(f' {key:16} = .false.\n')
elif value is not None:
if isinstance(value, Path):
value = str(value)
# repr format to get quotes around strings
x2y.append(f' {key:16} = {value!r:}\n')
x2y += dict_to_input_lines(atoms.calc.parameters)
x2y.append('/\n')

fd.write(''.join(x2y))
Expand Down
26 changes: 13 additions & 13 deletions ase/spacegroup/xtal.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
from scipy import spatial

import ase
from ase.symbols import string2symbols
from ase.spacegroup import Spacegroup
from ase.geometry import cellpar_to_cell
from ase.spacegroup import Spacegroup
from ase.symbols import string2symbols

__all__ = ['crystal']

Expand Down Expand Up @@ -103,9 +103,9 @@ def crystal(symbols=None, basis=None, occupancies=None, spacegroup=1, setting=1,
"""
sg = Spacegroup(spacegroup, setting)
if (not isinstance(symbols, str) and
hasattr(symbols, '__getitem__') and
len(symbols) > 0 and
isinstance(symbols[0], ase.Atom)):
hasattr(symbols, '__getitem__') and
len(symbols) > 0 and
isinstance(symbols[0], ase.Atom)):
symbols = ase.Atoms(symbols)
if isinstance(symbols, ase.Atoms):
basis = symbols
Expand All @@ -117,25 +117,25 @@ def crystal(symbols=None, basis=None, occupancies=None, spacegroup=1, setting=1,
if symbols is None:
symbols = basis.get_chemical_symbols()
else:
basis_coords = np.array(basis, dtype=float, copy=False, ndmin=2)
basis_coords = np.array(basis, dtype=float, ndmin=2)

if occupancies is not None:
occupancies_dict = {}

for index, coord in enumerate(basis_coords):
# Compute all distances and get indices of nearest atoms
dist = spatial.distance.cdist(coord.reshape(1, 3), basis_coords)
indices_dist = np.flatnonzero(dist < symprec)

occ = {symbols[index]: occupancies[index]}

# Check nearest and update occupancy
for index_dist in indices_dist:
if index == index_dist:
continue
else:
occ.update({symbols[index_dist]: occupancies[index_dist]})

occupancies_dict[index] = occ.copy()

sites, kinds = sg.equivalent_sites(basis_coords,
Expand All @@ -154,7 +154,7 @@ def crystal(symbols=None, basis=None, occupancies=None, spacegroup=1, setting=1,
symbols = [symbols[i] for i in kinds]
else:
# make sure that we put the dominant species there
symbols = [sorted(occupancies_dict[i].items(), key=lambda x : x[1])[-1][0] for i in kinds]
symbols = [sorted(occupancies_dict[i].items(), key=lambda x: x[1])[-1][0] for i in kinds]

if cell is None:
cell = cellpar_to_cell(cellpar, ab_normal, a_direction)
Expand All @@ -164,7 +164,7 @@ def crystal(symbols=None, basis=None, occupancies=None, spacegroup=1, setting=1,
info['unit_cell'] = 'primitive'
else:
info['unit_cell'] = 'conventional'

if 'info' in kwargs:
info.update(kwargs['info'])

Expand All @@ -186,7 +186,7 @@ def crystal(symbols=None, basis=None, occupancies=None, spacegroup=1, setting=1,
array = basis.get_array(name)
atoms.new_array(name, [array[i] for i in kinds],
dtype=array.dtype, shape=array.shape[1:])

if kinds:
atoms.new_array('spacegroup_kinds', np.asarray(kinds, dtype=int))

Expand Down

0 comments on commit 6811c59

Please sign in to comment.