Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added typing in lots of io places, added SileSlicer #695

Merged
merged 2 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ we hit release version 1.0.0.
- removed `Selector` and `TimeSelector`, they were never used internally

### Changed
- `stdoutSileSiesta.read_*` now defaults to read the *next* entry, and not the last
- `stdoutSileSiesta.read_*` changed MD output functionality, see #586 for details
- `AtomNeighbours` changed name to `AtomNeighbor` to follow #393
- removed `Lattice.translate|move`, they did not make sense, and so their
usage should be deferred to `Lattice.add` instead.
Expand All @@ -66,7 +68,7 @@ we hit release version 1.0.0.

### Added
- Creation of honeycomb flakes (`sisl.geom.honeycomb_flake`,
`sisl.geom.graphene_flake`). #636
`sisl.geom.graphene_flake`), #636
- added `Geometry.as_supercell` to create the supercell structure,
thanks to @pfebrer for the suggestion
- added `Lattice.to` and `Lattice.new` to function the same
Expand Down
4 changes: 2 additions & 2 deletions src/sisl/_core/lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from sisl._dispatcher import AbstractDispatch, ClassDispatcher, TypeDispatcher
from sisl._internal import set_module
from sisl._math_small import cross3, dot3
from sisl.messages import SislError, deprecate, deprecation, info, warn
from sisl.messages import SislError, deprecate, deprecation, warn
from sisl.shape.prism4 import Cuboid
from sisl.typing import Axes, Axies, Axis
from sisl.utils.mathematics import fnorm
Expand Down Expand Up @@ -274,7 +274,7 @@ def conv(v):
"must have that BC."
)
if changed.any() and (~bc).all() and nsc > 1:
info(
warn(
f"{self.__class__.__name__}.set_boundary_condition is having image connections (nsc={nsc}>1) "
"while having a non-periodic boundary condition."
)
Expand Down
2 changes: 1 addition & 1 deletion src/sisl/_core/tests/test_lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@

def test_lattice_info():
lat = Lattice(1, nsc=[3, 3, 3])
with pytest.warns(sisl.SislInfo) as record:
with pytest.warns(sisl.SislWarning) as record:

Check warning on line 566 in src/sisl/_core/tests/test_lattice.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/_core/tests/test_lattice.py#L566

Added line #L566 was not covered by tests
lat.set_boundary_condition(b=Lattice.BC.DIRICHLET)
lat.set_boundary_condition(c=Lattice.BC.PERIODIC)
assert len(record) == 1
55 changes: 45 additions & 10 deletions src/sisl/io/_multiple.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,43 @@
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
from functools import reduce, update_wrapper
from itertools import zip_longest
from numbers import Integral
from textwrap import dedent
from typing import Any, Callable, Optional, Type

Func = Callable[..., Optional[Any]]


def postprocess_tuple(*funcs):
"""Post-processes the returned value according to the funcs for multiple data

The internal algorithm will use zip_longest to apply the *last* `funcs[-1]` to
the remaining return functions. Useful if there are many tuples that can
be handled equivalently:

Examples
--------
>>> postprocess_tuple(np.array) == postprocess_tuple(np.array, np.array)
"""

def post(ret):
nonlocal funcs
if isinstance(ret[0], tuple):
return tuple(
func(r)
for r, func in zip_longest(zip(*ret), funcs, fillvalue=funcs[-1])
)
return funcs[0](ret)

return post


def is_sliceable(method: Func):
"""Check whether a function is implemented with the `SileBinder` decorator"""
return isinstance(method, (SileBinder, SileBound))

Check warning on line 39 in src/sisl/io/_multiple.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/io/_multiple.py#L39

Added line #L39 was not covered by tests


class SileSlicer:
"""Handling io-methods in sliced behaviour for multiple returns

Expand All @@ -26,6 +56,7 @@
func: Func,
key: Type[Any],
*,
check_empty: Optional[Func] = None,
skip_func: Optional[Func] = None,
postprocess: Optional[Callable[..., Any]] = None,
):
Expand All @@ -39,6 +70,16 @@
self.skip_func = func
else:
self.skip_func = skip_func

if check_empty is None:

def check_empty(r):
if isinstance(r, tuple):
return reduce(lambda x, y: x and y is None, r, True)
return r is None

self.check_empty = check_empty

if postprocess is None:

def postprocess(ret):
Expand All @@ -60,11 +101,6 @@

inf = 100000000000000

def check_none(r):
if isinstance(r, tuple):
return reduce(lambda x, y: x and y is None, r, True)
return r is None

# Determine whether we can reduce the call overheads
start = 0
stop = inf
Expand Down Expand Up @@ -100,7 +136,7 @@

# now do actual parsing
retval = func(obj, *args, **kwargs)
while not check_none(retval):
while not self.check_empty(retval):
append(retval)
if len(retvals) >= stop:
# quick exit
Expand All @@ -112,11 +148,10 @@
return None

# ensure the next call won't use this key
# This will prohibit the use
# This will enable the use
# tmp = sile.read_geometry[:10]
# tmp() # will return the first 10
# tmp() # will return the default (single) item
self.key = None
# tmp() # will returns the first 10
# tmp() # will returns the next 10
if isinstance(key, Integral):
return retvals[key]

Expand Down
48 changes: 28 additions & 20 deletions src/sisl/io/cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,27 +32,27 @@ class cubeSile(Sile):
)
def write_lattice(
self,
lattice,
fmt="15.10e",
lattice: Lattice,
fmt: str = "15.10e",
size=None,
origin=None,
unit="Bohr",
unit: str = "Bohr",
*args,
**kwargs,
):
"""Writes `Lattice` object attached to this grid
Parameters
----------
lattice : Lattice
lattice :
lattice to be written
fmt : str, optional
fmt :
floating point format for stored values
size : (3, ), optional
shape of the stored grid (``[1, 1, 1]``)
origin : (3, ), optional
origin of the cell (``[0, 0, 0]``)
unit: str, optional
unit:
what length unit should the cube file data be written in
"""
sile_raise_write(self)
Expand Down Expand Up @@ -83,27 +83,27 @@ def write_lattice(
@sile_fh_open()
def write_geometry(
self,
geometry,
fmt="15.10e",
geometry: Geometry,
fmt: str = "15.10e",
size=None,
origin=None,
unit="Bohr",
unit: str = "Bohr",
*args,
**kwargs,
):
"""Writes `Geometry` object attached to this grid
Parameters
----------
geometry : Geometry
geometry :
geometry to be written
fmt : str, optional
fmt :
floating point format for stored values
size : (3, ), optional
shape of the stored grid (``[1, 1, 1]``)
origin : (3, ), optional
origin of the cell (``[0, 0, 0]``)
unit: str, optional
unit:
what length unit should the cube file data be written in
"""
sile_raise_write(self)
Expand Down Expand Up @@ -140,24 +140,32 @@ def write_geometry(
)

@sile_fh_open()
def write_grid(self, grid, fmt=".5e", imag=False, unit="Bohr", *args, **kwargs):
def write_grid(
self,
grid: Grid,
fmt: str = ".5e",
imag: bool = False,
unit: str = "Bohr",
*args,
**kwargs,
):
"""Write `Grid` to the contained file
Parameters
----------
grid : Grid
grid :
the grid to be written in the CUBE file
fmt : str, optional
fmt :
format used for precision output
imag : bool, optional
imag :
write only imaginary part of the grid, default to only writing the
real part.
buffersize : int, optional
size of the buffer while writing the data, (6144)
unit: str, optional
unit:
what length unit should the cube file data be written in.
The grid data is assumed to be unit-less, this unit only refers
to the lattice vectors and atomic coordinates.
buffersize : int, optional
size of the buffer while writing the data, (6144)
"""
# Check that we can write to the file
sile_raise_write(self)
Expand Down Expand Up @@ -220,7 +228,7 @@ def _r_header_dict(self):
return header

@sile_fh_open()
def read_lattice(self, na=False):
def read_lattice(self, na: bool = False):
"""Returns `Lattice` object from the CUBE file
Parameters
Expand Down
2 changes: 1 addition & 1 deletion src/sisl/io/siesta/bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def read_fermi_level(self):
return float(self.readline())

@sile_fh_open()
def read_data(self, as_dataarray=False):
def read_data(self, as_dataarray: bool = False):
"""Returns data associated with the bands file
The energy levels are shifted with respect to the Fermi-level.
Expand Down
2 changes: 1 addition & 1 deletion src/sisl/io/siesta/fa.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def read_force(self):
return f

@sile_fh_open()
def write_force(self, f, fmt=".9e"):
def write_force(self, f, fmt: str = ".9e"):
"""Write forces to file
Parameters
Expand Down
14 changes: 9 additions & 5 deletions src/sisl/io/siesta/fc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
from typing import Optional

import numpy as np

from sisl._internal import set_module
Expand All @@ -18,7 +20,9 @@ class fcSileSiesta(SileSiesta):
"""Force constant file"""

@sile_fh_open()
def read_force(self, displacement=None, na=None):
def read_force(
self, displacement: Optional[float] = None, na: Optional[int] = None
):
"""Reads all displacement forces by multiplying with the displacement value
Since the force constant file does not contain the non-displaced configuration
Expand All @@ -33,12 +37,12 @@ def read_force(self, displacement=None, na=None):
Parameters
----------
displacement : float, optional
displacement :
the used displacement in the calculation, since Siesta 4.1-b4 this value
is written in the FC file and hence not required.
If prior Siesta versions are used and this is not supplied the 0.04 Bohr displacement
will be assumed.
na : int, optional
na :
number of atoms in geometry (for returning correct number of atoms), since Siesta 4.1-b4
this value is written in the FC file and hence not required.
If prior Siesta versions are used then the file is expected to only contain 1-atom displacement.
Expand Down Expand Up @@ -67,12 +71,12 @@ def read_force(self, displacement=None, na=None):
return self.read_force_constant(na) * displacement.reshape(1, 3, 2, 1, 1)

@sile_fh_open()
def read_force_constant(self, na=None):
def read_force_constant(self, na: Optional[int] = None):
"""Reads the force-constant stored in the FC file
Parameters
----------
na : int, optional
na :
number of atoms in the unit-cell, if not specified it will guess on only
one atom displacement.
Expand Down
Loading