Skip to content

Commit

Permalink
enabled custom check_empty for SileBound
Browse files Browse the repository at this point in the history
And implemented this for read_scf.
Now there is a slicer for MD steps in the read_scf
cycle.

Signed-off-by: Nick Papior <[email protected]>
  • Loading branch information
zerothi committed Mar 1, 2024
1 parent 862e55d commit 929631a
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 100 deletions.
25 changes: 15 additions & 10 deletions src/sisl/io/_multiple.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
func: Func,
key: Type[Any],
*,
check_empty: Optional[Func] = None,
skip_func: Optional[Func] = None,
postprocess: Optional[Callable[..., Any]] = None,
):
Expand All @@ -69,6 +70,16 @@ def __init__(
self.skip_func = func
else:
self.skip_func = skip_func

if check_empty is None:

def check_empty(r):
if isinstance(r, tuple):
return reduce(lambda x, y: x and y is None, r, True)
return r is None

self.check_empty = check_empty

if postprocess is None:

def postprocess(ret):
Expand All @@ -90,11 +101,6 @@ def __call__(self, *args, **kwargs):

inf = 100000000000000

def check_none(r):
if isinstance(r, tuple):
return reduce(lambda x, y: x and y is None, r, True)
return r is None

# Determine whether we can reduce the call overheads
start = 0
stop = inf
Expand Down Expand Up @@ -130,7 +136,7 @@ def check_none(r):

# now do actual parsing
retval = func(obj, *args, **kwargs)
while not check_none(retval):
while not self.check_empty(retval):
append(retval)
if len(retvals) >= stop:
# quick exit
Expand All @@ -142,11 +148,10 @@ def check_none(r):
return None

# ensure the next call won't use this key
# This will prohibit the use
# This will enable the use
# tmp = sile.read_geometry[:10]
# tmp() # will return the first 10
# tmp() # will return the default (single) item
self.key = None
# tmp() # will return the first 10
# tmp() # will return the next 10
if isinstance(key, Integral):
return retvals[key]

Expand Down
153 changes: 74 additions & 79 deletions src/sisl/io/siesta/stdout.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
import os
from functools import lru_cache
from typing import Optional

import numpy as np

Expand Down Expand Up @@ -50,6 +51,56 @@ def _parse_spin(attr, match):
return Spin()


def _read_scf_empty(scf):
    """Tell whether a single ``read_scf`` result holds no SCF data.

    Passed as ``check_empty`` to the `SileBinder` decorating ``read_scf``.

    Parameters
    ----------
    scf :
        the value returned by one ``read_scf`` call; either the data itself
        or a ``(data, header)`` tuple (returned when ``ret_header`` is set).

    Returns
    -------
    bool
        True when the data part is ``None`` or has zero length.
    """
    if isinstance(scf, tuple):
        # only the data part decides emptiness, the header is always present
        scf = scf[0]
    if scf is None:
        # mirror the default ``check_empty`` which regards None as "no data"
        return True
    return len(scf) == 0


def _read_scf_md_process(scfs):
    """Merge the per-MD-step results of ``read_scf`` into one return value.

    Passed as ``postprocess`` to the `SileBinder` decorating ``read_scf``.

    Parameters
    ----------
    scfs :
        either the result of a single MD step (anything that is not a list),
        or a list with one entry per MD step. The entries are expected to be
        `numpy.ndarray`, ``(ndarray, header)`` tuples or `pandas.DataFrame`.

    Returns
    -------
    None
        when `scfs` has zero length.
    object
        `scfs` itself when it is not a list (single MD step).
    list or numpy.ndarray
        for array entries; stacked into one array only when every MD step
        has length 1, otherwise the list is returned as is.
    tuple
        ``(data, header)`` when the entries carried a header; the header of
        the first MD step is used.
    pandas.DataFrame
        for data-frame entries, concatenated with ``imd`` as outer index.
    """
    # ``len`` instead of truthiness: `scfs` may be an ndarray or a DataFrame
    if len(scfs) == 0:
        return None

    if not isinstance(scfs, list):
        # single MD request either as:
        # - np.ndarray
        # - np.ndarray, tuple
        # - pd.DataFrame
        return scfs

    has_props = isinstance(scfs[0], tuple)
    if has_props:
        lengths = (len(scf[0]) for scf in scfs)
    else:
        lengths = map(len, scfs)

    # True when every MD step has length 1 (must be evaluated before the
    # headers are stripped off below)
    scf_len1 = all(n == 1 for n in lengths)

    if isinstance(scfs[0], (np.ndarray, tuple)):
        props = None
        if has_props:
            # the header of the first MD step is used for the merged result
            props = scfs[0][1]
            scfs = [scf[0] for scf in scfs]

        if scf_len1:
            # steps may only be stacked when they all have the same length
            scfs = np.array(scfs)

        if has_props:
            return scfs, props
        return scfs

    # We are dealing with a dataframe
    import pandas as pd

    df = pd.concat(
        scfs,
        keys=_a.arangei(1, len(scfs) + 1),
        names=["imd"],
    )
    if scf_len1:
        # a single row per MD step: keep only ``imd`` as index and move
        # ``iscf`` into a regular column
        df.reset_index("iscf", inplace=True)
    return df


@set_module("sisl.io.siesta")
class stdoutSileSiesta(SileSiesta):
"""Output file from Siesta
Expand Down Expand Up @@ -663,28 +714,32 @@ def read_data(self, *args, **kwargs):
val = val[0]
return val

@sile_fh_open(True)
@SileBinder(
default_slice=-1, check_empty=_read_scf_empty, postprocess=_read_scf_md_process
)
@sile_fh_open()
def read_scf(
self, key="scf", iscf=-1, imd=None, as_dataframe=False, ret_header=False
self,
key: str = "scf",
iscf: Optional[int] = -1,
as_dataframe: bool = False,
ret_header: bool = False,
):
r"""Parse SCF information and return a table of SCF information depending on what is requested
Parameters
----------
key : {'scf', 'ts-scf'}
parse SCF information from Siesta SCF or TranSiesta SCF
iscf : int, optional
iscf :
which SCF cycle should be stored. If ``-1`` only the final SCF step is stored,
for None *all* SCF cycles are returned. When `iscf` values queried are not found they
will be truncated to the nearest SCF step.
imd: int or None, optional
whether only a particular MD step is queried, if None, all MD steps are
parsed and returned. A negative number wraps for the last MD steps.
as_dataframe: boolean, optional
as_dataframe:
whether the information should be returned as a `pandas.DataFrame`. The advantage of this
format is that everything is indexed and therefore you know what each value means. You can also
perform operations very easily on a dataframe.
ret_header: bool, optional
ret_header:
whether to also return the headers that define each value in the returned array,
will have no effect if `as_dataframe` is true.
"""
Expand All @@ -697,11 +752,6 @@ def read_scf(
raise ValueError(
f"{self.__class__.__name__}.read_scf requires iscf argument to *not* be 0!"
)
if not imd is None:
if imd == 0:
raise ValueError(
f"{self.__class__.__name__}.read_scf requires imd argument to *not* be 0!"
)

def reset_d(d, line):
if line.startswith("SCF cycle converged") or line.startswith(
Expand Down Expand Up @@ -857,7 +907,6 @@ def construct_data(d, data):
data.extend(d[key])
d["data"] = data

md = []
scf = []
for line in self:
parse_next(line, d)
Expand All @@ -869,6 +918,7 @@ def construct_data(d, data):

if iscf is None or iscf < 0:
scf.append(data)

elif data[0] <= iscf:
# this ensures we will retain the latest iscf in
# case the requested iscf is too big
Expand Down Expand Up @@ -900,80 +950,25 @@ def construct_data(d, data):
# truncate to 0
scf = scf[max(len(scf) + iscf, 0)]

# Populate md
md.append(np.array(scf))
# Reset SCF data
scf = []

# In case we wanted a given MD step and it's this one, just stop reading
# We are going to return the last MD (see below)
if imd == len(md):
break
# found a full MD
break

# Define the function that is going to convert the information of a MDstep to a Dataset
if as_dataframe:
import pandas as pd

def MDstep_dataframe(scf):
scf = np.atleast_2d(scf)
return pd.DataFrame(
scf[..., 1:],
index=pd.Index(scf[..., 0].ravel().astype(np.int32), name="iscf"),
columns=props[1:],
)

# Now we know how many MD steps there are

# We will return stuff based on what the user requested
# For pandas DataFrame this will be dependent
# 1. all MD steps requested => imd == index, iscf == column (regardless of iscf==none|int)
# 2. 1 MD step requested => iscf == index

if imd is None:
if as_dataframe:
if len(md) == 0:
# return an empty dataframe (with imd as index)
return pd.DataFrame(index=pd.Index([], name="imd"), columns=props)
# Regardless of what the user requests we will always have imd == index
# and iscf a column, a user may easily change this.
df = pd.concat(
map(MDstep_dataframe, md),
keys=_a.arangei(1, len(md) + 1),
names=["imd"],
)
if iscf is not None:
df.reset_index("iscf", inplace=True)
return df

if iscf is not None:
# since each MD step may be a different number of SCF steps
# we can only convert for a specific entry
md = np.array(md)
if ret_header:
return md, props
return md

# correct imd to ensure we check against the final size
imd = min(len(md) - 1, max(len(md) + imd, 0))
if len(md) == 0:
# no data collected
if as_dataframe:
if len(scf) == 0:
return pd.DataFrame(index=pd.Index([], name="iscf"), columns=props[1:])
md = np.array(md[imd])
if ret_header:
return md, props
return md

if imd > len(md):
raise ValueError(
f"{self.__class__.__name__}.read_scf could not find requested MD step ({imd})."
scf = np.atleast_2d(scf)
return pd.DataFrame(
scf[..., 1:],
index=pd.Index(scf[..., 0].ravel().astype(np.int32), name="iscf"),
columns=props[1:],
)

# If a certain imd was requested, get it
# Remember that if imd is positive, we stopped reading at the moment we reached it
scf = np.array(md[imd])
if as_dataframe:
return MDstep_dataframe(scf)
# Convert to numpy array
scf = np.array(scf)
if ret_header:
return scf, props
return scf
Expand Down
26 changes: 15 additions & 11 deletions src/sisl/io/siesta/tests/test_stdout.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
_dir = osp.join("sisl", "io", "siesta")


@pytest.mark.only
def test_md_nose_out(sisl_files):
f = sisl_files(_dir, "md_nose.out")
out = stdoutSileSiesta(f)
Expand Down Expand Up @@ -66,19 +65,24 @@ def test_md_nose_out(sisl_files):
for S, T in zip(sstatic, stotal):
assert not np.allclose(S, T)


def test_md_nose_out_scf(sisl_files):
f = sisl_files(_dir, "md_nose.out")
out = stdoutSileSiesta(f)

# Ensure SCF reads are consistent
scf_last = out.read_scf()
scf = out.read_scf(imd=-1)
scf_last = out.read_scf[:]()
scf = out.read_scf[-1]()
assert np.allclose(scf_last[-1], scf)
for i in range(len(scf_last)):
scf = out.read_scf(imd=i + 1)
scf = out.read_scf[i]()
assert np.allclose(scf_last[i], scf)

scf_all = out.read_scf(iscf=None, imd=-1)
scf = out.read_scf(imd=-1)
scf_all = out.read_scf[-1](iscf=None)
scf = out.read_scf[-1]()
assert np.allclose(scf_all[-1], scf)
for i in range(len(scf_all)):
scf = out.read_scf(iscf=i + 1, imd=-1)
scf = out.read_scf[-1](iscf=i + 1)
assert np.allclose(scf_all[i], scf)


Expand Down Expand Up @@ -109,15 +113,15 @@ def test_md_nose_out_dataframe(sisl_files):
f = sisl_files(_dir, "md_nose.out")
out = stdoutSileSiesta(f)

data = out.read_scf()
df = out.read_scf(as_dataframe=True)
data = out.read_scf[:]()
df = out.read_scf[:](as_dataframe=True)
# this will read all MD-steps and only latest iscf
assert len(data) == len(df)
assert df.index.names == ["imd"]

df = out.read_scf(iscf=None, as_dataframe=True)
df = out.read_scf[:](iscf=None, as_dataframe=True)
assert df.index.names == ["imd", "iscf"]
df = out.read_scf(iscf=None, imd=-1, as_dataframe=True)
df = out.read_scf(iscf=None, as_dataframe=True)
assert df.index.names == ["iscf"]


Expand Down

0 comments on commit 929631a

Please sign in to comment.