Skip to content

Commit

Permalink
Merge pull request #429 from pyiron/FlattenedStorage-StoreDefaults
Browse files Browse the repository at this point in the history
Flattened storage store defaults
  • Loading branch information
Leimeroth authored Sep 10, 2021
2 parents 442e22b + 010f985 commit 4038228
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 4 deletions.
24 changes: 20 additions & 4 deletions pyiron_base/generic/flattenedstorage.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ class FlattenedStorage:
3
"""

__version__ = "0.1.0"
__hdf_version__ = "0.2.0"
__version__ = "0.2.0"
__hdf_version__ = "0.3.0"

def __init__(self, num_chunks=1, num_elements=1, **kwargs):
"""
Expand All @@ -150,6 +150,7 @@ def __init__(self, num_chunks=1, num_elements=1, **kwargs):
# Also store indices of chunk recently added
self.prev_chunk_index = 0
self.prev_element_index = 0
self._fill_values = {}

self._init_arrays()

Expand Down Expand Up @@ -201,24 +202,33 @@ def _get_per_element_slice(self, frame):
end = start + self._per_chunk_arrays["length"][frame]
return slice(start, end, 1)


def _resize_elements(self, new):
old_max = self._num_elements_alloc
self._num_elements_alloc = new
for k, a in self._per_element_arrays.items():
new_shape = (new,) + a.shape[1:]
try:
a.resize(new_shape)
except ValueError:
self._per_element_arrays[k] = np.resize(a, new_shape)
if old_max < new:
for k in self._per_element_arrays.keys():
if k in self._fill_values.keys():
self._per_element_arrays[k][old_max:] = self._fill_values[k]

def _resize_chunks(self, new):
old_max = self._num_chunks_alloc
self._num_chunks_alloc = new
for k, a in self._per_chunk_arrays.items():
new_shape = (new,) + a.shape[1:]
try:
a.resize(new_shape)
except ValueError:
self._per_chunk_arrays[k] = np.resize(a, new_shape)
if old_max < new:
for k in self._per_chunk_arrays.keys():
if k in self._fill_values.keys():
self._per_chunk_arrays[k][old_max:] = self._fill_values[k]

def add_array(self, name, shape=(), dtype=np.float64, fill=None, per="element"):
"""
Expand Down Expand Up @@ -285,6 +295,7 @@ def add_array(self, name, shape=(), dtype=np.float64, fill=None, per="element"):
store[name] = np.empty(shape=shape, dtype=dtype)
else:
store[name] = np.full(shape=shape, fill_value=fill, dtype=dtype)
self._fill_values[name] = fill

def get_array(self, name, frame):
"""
Expand Down Expand Up @@ -495,6 +506,8 @@ def write_array(name, array, hdf):
for k, a in self._per_chunk_arrays.items():
write_array(k, a, hdf_arrays)

hdf_s_lst["_fill_values"] = self._fill_values

def from_hdf(self, hdf, group_name="flat_storage"):

def read_array(name, hdf):
Expand Down Expand Up @@ -529,7 +542,7 @@ def read_array(name, hdf):
self._per_element_arrays[k] = a
elif a.shape[0] == self._num_chunks_alloc:
self._per_chunk_arrays[k] = a
elif version == "0.2.0":
elif version == "0.2.0" or "0.3.0":
with hdf_s_lst.open("element_arrays") as hdf_arrays:
for k in hdf_arrays.list_nodes():
self._per_element_arrays[k] = read_array(k, hdf_arrays)
Expand All @@ -547,3 +560,6 @@ def read_array(name, hdf):
if a.shape[0] != self._num_elements_alloc:
raise RuntimeError(f"per-element array {k} read inconsistently from HDF: "
f"shape {a.shape[0]} does not match global allocation {self._num_elements_alloc}!")

if version >= "0.3.0":
self._fill_values = hdf_s_lst["_fill_values"]
38 changes: 38 additions & 0 deletions tests/generic/test_flattenedstorage.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,41 @@ def test_hdf_chunklength_one(self):
self.assertEqual(store.get_array("bar", i), read.get_array("bar", i),
"per chunk values not equal after reading from HDF!")

def test_fill_value(self):
"""Test if fill values are correctly assigned when resizing an array and if self._fill_value is correctly read from hdf."""
# Test for per chunk arrays
store = FlattenedStorage()
store.add_array("bar", per="chunk", dtype=bool, fill=True)
store.add_array("foo", per="chunk")
for i in range(3):
store.add_chunk(1, bar=False, foo=i)
store._resize_chunks(6)
self.assertTrue(np.all(store._per_chunk_arrays["bar"][:3]==False), "value is overwritten when resizing")
self.assertTrue(np.all(store._per_chunk_arrays["bar"][3:]==True), "fill value is not correctly set when resizing")
self.assertTrue(np.all(store._per_chunk_arrays["foo"][0:3]==np.array((0,1,2))), "values in array changed on resizing")
# Test for per element arrays
store = FlattenedStorage()
store.add_array("bar", per="element", fill=np.nan)
store.add_array("foo", per="element")
for i in range(1,4):
store.add_chunk(i*2, bar=i*[i, i**2], foo=i*[i, i**2])
store._resize_elements(15)
self.assertTrue(np.all(store._per_element_arrays["foo"][:12]==store._per_element_arrays["bar"][:12]), "arrays are not equal up to resized part")
self.assertTrue(np.all(np.isnan(store._per_element_arrays["bar"][12:])), "array should np.nan where not set")
# Test hdf
store = FlattenedStorage()
store.add_array("bar", per="element", fill=np.nan)
store.add_array("foo", per="element")
store.add_array("fooTrue", per="chunk", dtype=bool, fill=True)
store.add_array("barText", per="chunk", dtype="U4", fill="fill")
hdf = self.project.create_hdf(self.project.path, "test_fill_values")
store.to_hdf(hdf)
read=FlattenedStorage()
read.from_hdf(hdf)
# normally it is possible to compare 2 dicts using ==, but np.nan!=np.nan so this has to be explicitly tested.
for k, v in store._fill_values.items():
if isinstance(v, float) and np.isnan(v):
self.assertTrue(np.isnan(read._fill_values[k]))
else:
self.assertEqual(v, read._fill_values[k], "value read from hdf differs from original value")
self.assertEqual(read._fill_values.keys(), store._fill_values.keys(), "keys read from hdf differ from original keys")

0 comments on commit 4038228

Please sign in to comment.