Skip to content

Commit

Permalink
Add field and observations to state storage test
Browse files Browse the repository at this point in the history
  • Loading branch information
eivindjahren committed Feb 2, 2024
1 parent 848c2be commit 65404c7
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 47 deletions.
5 changes: 5 additions & 0 deletions tests/unit_tests/config/egrid_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,11 @@ class EGrid:
file_head: Filehead
global_grid: GlobalGrid

@property
def shape(self) -> Tuple[int, int, int]:
grid_head = self.global_grid.grid_head
return (grid_head.num_x, grid_head.num_y, grid_head.num_z)

def to_file(
self,
filelike,
Expand Down
25 changes: 0 additions & 25 deletions tests/unit_tests/storage/test_local_ensemble.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import os
from datetime import datetime

import numpy as np
import pytest
import xarray as xr
Expand All @@ -11,8 +8,6 @@
from ert.config.field import Field
from ert.field_utils import FieldFileFormat
from ert.storage import open_storage
from ert.storage.local_ensemble import _Failure
from ert.storage.realization_storage_state import RealizationStorageState


def test_that_egrid_files_are_saved_and_loaded_correctly(tmp_path):
Expand Down Expand Up @@ -160,23 +155,3 @@ def test_that_loading_parameter_via_response_api_fails(tmp_path):
)
with pytest.raises(ValueError, match="PARAMETER is not a response"):
prior.load_responses("PARAMETER", (0,))


def test_get_failure(tmp_path):
with open_storage(tmp_path, mode="w") as storage:
experiment = storage.create_experiment()
ensemble = storage.create_ensemble(experiment, name="foo", ensemble_size=1)

error_file = ensemble._path / "realization-9" / ensemble._error_log_name
os.makedirs(os.path.dirname(error_file), exist_ok=True)

error = _Failure(
type=RealizationStorageState.PARENT_FAILURE,
message="Something is wrong",
time=datetime.now(),
)
with open(error_file, mode="w", encoding="utf-8") as f:
print(error.json(), file=f)

assert ensemble.get_failure(9) == error
assert ensemble.get_failure(8) is None
170 changes: 148 additions & 22 deletions tests/unit_tests/storage/test_local_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,37 @@
import tempfile
from collections import defaultdict
from dataclasses import dataclass, field
from typing import List
from pathlib import Path
from typing import Any, Dict, List
from unittest.mock import patch
from uuid import UUID

import hypothesis.strategies as st
import numpy as np
import pytest
from hypothesis.stateful import Bundle, RuleBasedStateMachine, rule
import xarray as xr
from hypothesis.extra.numpy import arrays
from hypothesis.stateful import Bundle, RuleBasedStateMachine, initialize, rule

from ert.config import (
EnkfObs,
Field,
GenDataConfig,
GenKwConfig,
ParameterConfig,
ResponseConfig,
SummaryConfig,
SurfaceConfig,
)
from ert.config.enkf_observation_implementation_type import (
EnkfObservationImplementationType,
)
from ert.config.general_observation import GenObservation
from ert.config.observation_vector import ObsVector
from ert.storage import StorageReader, open_storage
from ert.storage import local_storage as local
from ert.storage.realization_storage_state import RealizationStorageState
from tests.unit_tests.config.egrid_generator import egrids
from tests.unit_tests.config.summary_generator import summary_keys


Expand Down Expand Up @@ -90,17 +102,6 @@ def test_to_accessor(tmp_path):
storage_reader.to_accessor()


@st.composite
def refcase(draw):
datetimes = draw(st.lists(st.datetimes()))
container_type = draw(st.sampled_from([set(), list(), None]))
if isinstance(container_type, set):
return set(datetimes)
elif isinstance(container_type, list):
return [str(date) for date in datetimes]
return None


parameter_configs = st.lists(
st.one_of(
st.builds(
Expand All @@ -124,21 +125,88 @@ def refcase(draw):
input_file=st.text(
alphabet=st.characters(min_codepoint=65, max_codepoint=90)
),
keys=st.lists(summary_keys),
refcase=refcase(),
keys=st.lists(summary_keys, max_size=3),
refcase=st.just(None),
),
),
unique_by=lambda x: x.name,
)

ensemble_sizes = st.integers(min_value=1, max_value=1000)
coordinates = st.integers(min_value=1, max_value=100)


words = st.text(
min_size=1, max_size=8, alphabet=st.characters(min_codepoint=65, max_codepoint=90)
)


gen_observations = st.integers(min_value=1, max_value=10).flatmap(
lambda size: st.builds(
GenObservation,
values=arrays(np.double, shape=size),
stds=arrays(
np.double,
elements=st.floats(min_value=0.1, max_value=1.0),
shape=size,
),
indices=arrays(
np.int64,
elements=st.integers(min_value=0, max_value=100),
shape=size,
),
std_scaling=arrays(np.double, shape=size),
)
)


@st.composite
def observation_dicts(draw):
return {draw(st.integers(min_value=0, max_value=200)): draw(gen_observations)}


observations = st.builds(
EnkfObs,
obs_vectors=st.dictionaries(
words,
st.builds(
ObsVector,
observation_type=st.just(EnkfObservationImplementationType.GEN_OBS),
observation_key=words,
data_key=words,
observations=observation_dicts(),
),
),
)


@dataclass
class Experiment:
ensembles: List[UUID] = field(default_factory=list)
ensembles: Dict[UUID, Dict[str, Any]] = field(default_factory=dict)
parameters: List[ParameterConfig] = field(default_factory=list)
responses: List[ResponseConfig] = field(default_factory=list)
observations: Dict[str, xr.Dataset] = field(default_factory=dict)


@st.composite
def fields(draw, egrid) -> List[Field]:
grid_file, grid = egrid
nx, ny, nz = grid.shape
return [
draw(
st.builds(
Field,
name=st.just(f"Field{i}"),
file_format=st.just("roff_binary"),
grid_file=st.just(grid_file),
nx=st.just(nx),
ny=st.just(ny),
nz=st.just(nz),
output_file=st.just(Path(f"field{i}.roff")),
)
)
for i in range(10)
]


class StatefulTest(RuleBasedStateMachine):
Expand All @@ -153,6 +221,21 @@ def __init__(self):
experiment_ids = Bundle("experiments")
ensemble_ids = Bundle("ensembles")
failures = Bundle("failures")
field_list = Bundle("field_list")
grid = Bundle("grid")

@initialize(target=grid, egrid=egrids)
def create_grid(self, egrid):
grid_file = self.tmpdir + "/grid.egrid"
egrid.to_file(grid_file)
return (grid_file, egrid)

@initialize(
target=field_list,
fields=grid.flatmap(fields),
)
def create_field_list(self, fields):
return fields

@rule()
def double_open_timeout(self):
Expand All @@ -171,25 +254,65 @@ def reopen(self):

@rule(
target=experiment_ids,
parameters=parameter_configs,
parameters=st.one_of(parameter_configs, field_list),
responses=response_configs,
obs=observations,
)
def create_experiment(
self, parameters: List[ParameterConfig], responses: List[ResponseConfig]
self,
parameters: List[ParameterConfig],
responses: List[ResponseConfig],
obs: EnkfObs,
):
experiment_id = self.storage.create_experiment(
parameters=parameters, responses=responses
parameters=parameters, responses=responses, observations=obs.datasets
).id
self.experiments[experiment_id].parameters = parameters
self.experiments[experiment_id].responses = responses
self.experiments[experiment_id].observations = obs.datasets

# Ensure that there is at least one ensemble in the experiment
# to avoid https://github.com/equinor/ert/issues/7040
ensemble = self.storage.create_ensemble(experiment_id, ensemble_size=1)
self.experiments[experiment_id].ensembles.append(ensemble.id)
self.experiments[experiment_id].ensembles[ensemble.id] = {}

return experiment_id

@rule(
ensemble_id=ensemble_ids,
field_data=grid.flatmap(lambda g: arrays(np.float32, shape=g[1].shape)),
)
def save_field(self, ensemble_id: UUID, field_data):
ensemble = self.storage.get_ensemble(ensemble_id)
experiment_id = ensemble.experiment_id
parameters = self.experiments[experiment_id].parameters
fields = [p for p in parameters if isinstance(p, Field)]
for f in fields:
self.experiments[experiment_id].ensembles[ensemble_id][f.name] = field_data
ensemble.save_parameters(
f.name,
1,
xr.DataArray(
field_data,
name="values",
dims=["x", "y", "z"], # type: ignore
).to_dataset(),
)

@rule(
ensemble_id=ensemble_ids,
)
def get_field(self, ensemble_id: UUID):
ensemble = self.storage.get_ensemble(ensemble_id)
experiment_id = ensemble.experiment_id
field_names = self.experiments[experiment_id].ensembles[ensemble_id].keys()
for f in field_names:
field_data = ensemble.load_parameters(f, 1)
np.testing.assert_array_equal(
self.experiments[experiment_id].ensembles[ensemble_id][f],
field_data["values"],
)

@rule(
target=ensemble_ids,
experiment=experiment_ids,
Expand All @@ -198,7 +321,7 @@ def create_experiment(
def create_ensemble(self, experiment: UUID, ensemble_size: int):
ensemble = self.storage.create_ensemble(experiment, ensemble_size=ensemble_size)
assert ensemble in self.storage.ensembles
self.experiments[experiment].ensembles.append(ensemble.id)
self.experiments[experiment].ensembles[ensemble.id] = {}

# https://github.com/equinor/ert/issues/7046
# assert (
Expand All @@ -220,7 +343,7 @@ def create_ensemble_from_prior(self, prior: UUID):
experiment, ensemble_size=size, prior_ensemble=prior
)
assert ensemble in self.storage.ensembles
self.experiments[experiment].ensembles.append(ensemble.id)
self.experiments[experiment].ensembles[ensemble.id] = {}
# https://github.com/equinor/ert/issues/7046
# assert (
# ensemble.get_ensemble_state()
Expand All @@ -240,6 +363,9 @@ def get_experiment(self, id: UUID):
list(experiment.response_configuration.values())
== self.experiments[id].responses
)
assert self.experiments[id].observations == pytest.approx(
experiment.observations
)

@rule(id=ensemble_ids)
def get_ensemble(self, id: UUID):
Expand Down

0 comments on commit 65404c7

Please sign in to comment.