diff --git a/tests/unit_tests/config/egrid_generator.py b/tests/unit_tests/config/egrid_generator.py index 95d5edb5a0b..13996ac97ee 100644 --- a/tests/unit_tests/config/egrid_generator.py +++ b/tests/unit_tests/config/egrid_generator.py @@ -267,6 +267,11 @@ class EGrid: file_head: Filehead global_grid: GlobalGrid + @property + def shape(self) -> Tuple[int, int, int]: + grid_head = self.global_grid.grid_head + return (grid_head.num_x, grid_head.num_y, grid_head.num_z) + def to_file( self, filelike, diff --git a/tests/unit_tests/storage/test_local_ensemble.py b/tests/unit_tests/storage/test_local_ensemble.py index 3ee3e289687..31270915b12 100644 --- a/tests/unit_tests/storage/test_local_ensemble.py +++ b/tests/unit_tests/storage/test_local_ensemble.py @@ -1,6 +1,3 @@ -import os -from datetime import datetime - import numpy as np import pytest import xarray as xr @@ -11,8 +8,6 @@ from ert.config.field import Field from ert.field_utils import FieldFileFormat from ert.storage import open_storage -from ert.storage.local_ensemble import _Failure -from ert.storage.realization_storage_state import RealizationStorageState def test_that_egrid_files_are_saved_and_loaded_correctly(tmp_path): @@ -160,23 +155,3 @@ def test_that_loading_parameter_via_response_api_fails(tmp_path): ) with pytest.raises(ValueError, match="PARAMETER is not a response"): prior.load_responses("PARAMETER", (0,)) - - -def test_get_failure(tmp_path): - with open_storage(tmp_path, mode="w") as storage: - experiment = storage.create_experiment() - ensemble = storage.create_ensemble(experiment, name="foo", ensemble_size=1) - - error_file = ensemble._path / "realization-9" / ensemble._error_log_name - os.makedirs(os.path.dirname(error_file), exist_ok=True) - - error = _Failure( - type=RealizationStorageState.PARENT_FAILURE, - message="Something is wrong", - time=datetime.now(), - ) - with open(error_file, mode="w", encoding="utf-8") as f: - print(error.json(), file=f) - - assert ensemble.get_failure(9) == error - assert ensemble.get_failure(8) is None diff --git a/tests/unit_tests/storage/test_local_storage.py b/tests/unit_tests/storage/test_local_storage.py index 93598e1ab46..50de709854c 100644 --- a/tests/unit_tests/storage/test_local_storage.py +++ b/tests/unit_tests/storage/test_local_storage.py @@ -2,15 +2,21 @@ import tempfile from collections import defaultdict from dataclasses import dataclass, field -from typing import List +from pathlib import Path +from typing import Any, Dict, List from unittest.mock import patch from uuid import UUID import hypothesis.strategies as st +import numpy as np import pytest -from hypothesis.stateful import Bundle, RuleBasedStateMachine, rule +import xarray as xr +from hypothesis.extra.numpy import arrays +from hypothesis.stateful import Bundle, RuleBasedStateMachine, initialize, rule from ert.config import ( + EnkfObs, + Field, GenDataConfig, GenKwConfig, ParameterConfig, @@ -18,9 +24,15 @@ SummaryConfig, SurfaceConfig, ) +from ert.config.enkf_observation_implementation_type import ( + EnkfObservationImplementationType, +) +from ert.config.general_observation import GenObservation +from ert.config.observation_vector import ObsVector from ert.storage import StorageReader, open_storage from ert.storage import local_storage as local from ert.storage.realization_storage_state import RealizationStorageState +from tests.unit_tests.config.egrid_generator import egrids from tests.unit_tests.config.summary_generator import summary_keys @@ -90,17 +102,6 @@ def test_to_accessor(tmp_path): storage_reader.to_accessor() -@st.composite -def refcase(draw): - datetimes = draw(st.lists(st.datetimes())) - container_type = draw(st.sampled_from([set(), list(), None])) - if isinstance(container_type, set): - return set(datetimes) - elif isinstance(container_type, list): - return [str(date) for date in datetimes] - return None - - parameter_configs = st.lists( st.one_of( st.builds( @@ -124,21 +125,88 @@ def refcase(draw): input_file=st.text( alphabet=st.characters(min_codepoint=65, max_codepoint=90) ), - keys=st.lists(summary_keys), - refcase=refcase(), + keys=st.lists(summary_keys, max_size=3), + refcase=st.just(None), ), ), unique_by=lambda x: x.name, ) ensemble_sizes = st.integers(min_value=1, max_value=1000) +coordinates = st.integers(min_value=1, max_value=100) + + +words = st.text( + min_size=1, max_size=8, alphabet=st.characters(min_codepoint=65, max_codepoint=90) +) + + +gen_observations = st.integers(min_value=1, max_value=10).flatmap( + lambda size: st.builds( + GenObservation, + values=arrays(np.double, shape=size), + stds=arrays( + np.double, + elements=st.floats(min_value=0.1, max_value=1.0), + shape=size, + ), + indices=arrays( + np.int64, + elements=st.integers(min_value=0, max_value=100), + shape=size, + ), + std_scaling=arrays(np.double, shape=size), + ) +) + + +@st.composite +def observation_dicts(draw): + return {draw(st.integers(min_value=0, max_value=200)): draw(gen_observations)} + + +observations = st.builds( + EnkfObs, + obs_vectors=st.dictionaries( + words, + st.builds( + ObsVector, + observation_type=st.just(EnkfObservationImplementationType.GEN_OBS), + observation_key=words, + data_key=words, + observations=observation_dicts(), + ), + ), +) @dataclass class Experiment: - ensembles: List[UUID] = field(default_factory=list) + ensembles: Dict[UUID, Dict[str, Any]] = field(default_factory=dict) parameters: List[ParameterConfig] = field(default_factory=list) responses: List[ResponseConfig] = field(default_factory=list) + observations: Dict[str, xr.Dataset] = field(default_factory=dict) + + +@st.composite +def fields(draw, egrid) -> List[Field]: + grid_file, grid = egrid + nx, ny, nz = grid.shape + return [ + draw( + st.builds( + Field, + name=st.just(f"Field{i}"), + file_format=st.just("roff_binary"), + grid_file=st.just(grid_file), + nx=st.just(nx), + ny=st.just(ny), + nz=st.just(nz), + output_file=st.just(Path(f"field{i}.roff")), + ) + ) + for i in range(10) + ] class StatefulTest(RuleBasedStateMachine): @@ -153,6 +221,21 @@ def __init__(self): experiment_ids = Bundle("experiments") ensemble_ids = Bundle("ensembles") failures = Bundle("failures") + field_list = Bundle("field_list") + grid = Bundle("grid") + + @initialize(target=grid, egrid=egrids) + def create_grid(self, egrid): + grid_file = self.tmpdir + "/grid.egrid" + egrid.to_file(grid_file) + return (grid_file, egrid) + + @initialize( + target=field_list, + fields=grid.flatmap(fields), + ) + def create_field_list(self, fields): + return fields @rule() def double_open_timeout(self): @@ -171,25 +254,65 @@ def reopen(self): @rule( target=experiment_ids, - parameters=parameter_configs, + parameters=st.one_of(parameter_configs, field_list), responses=response_configs, + obs=observations, ) def create_experiment( - self, parameters: List[ParameterConfig], responses: List[ResponseConfig] + self, + parameters: List[ParameterConfig], + responses: List[ResponseConfig], + obs: EnkfObs, ): experiment_id = self.storage.create_experiment( - parameters=parameters, responses=responses + parameters=parameters, responses=responses, observations=obs.datasets ).id self.experiments[experiment_id].parameters = parameters self.experiments[experiment_id].responses = responses + self.experiments[experiment_id].observations = obs.datasets # Ensure that there is at least one ensemble in the experiment # to avoid https://github.com/equinor/ert/issues/7040 ensemble = self.storage.create_ensemble(experiment_id, ensemble_size=1) - self.experiments[experiment_id].ensembles.append(ensemble.id) + self.experiments[experiment_id].ensembles[ensemble.id] = {} return experiment_id + @rule( + ensemble_id=ensemble_ids, + field_data=grid.flatmap(lambda g: arrays(np.float32, shape=g[1].shape)), + ) + def save_field(self, ensemble_id: UUID, field_data): + ensemble = self.storage.get_ensemble(ensemble_id) + experiment_id = ensemble.experiment_id + parameters = self.experiments[experiment_id].parameters + fields = [p for p in parameters if isinstance(p, Field)] + for f in fields: + self.experiments[experiment_id].ensembles[ensemble_id][f.name] = field_data + ensemble.save_parameters( + f.name, + 1, + xr.DataArray( + field_data, + name="values", + dims=["x", "y", "z"], # type: ignore + ).to_dataset(), + ) + + @rule( + ensemble_id=ensemble_ids, + ) + def get_field(self, ensemble_id: UUID): + ensemble = self.storage.get_ensemble(ensemble_id) + experiment_id = ensemble.experiment_id + field_names = self.experiments[experiment_id].ensembles[ensemble_id].keys() + for f in field_names: + field_data = ensemble.load_parameters(f, 1) + np.testing.assert_array_equal( + self.experiments[experiment_id].ensembles[ensemble_id][f], + field_data["values"], + ) + @rule( target=ensemble_ids, experiment=experiment_ids, @@ -198,7 +321,7 @@ def create_experiment( def create_ensemble(self, experiment: UUID, ensemble_size: int): ensemble = self.storage.create_ensemble(experiment, ensemble_size=ensemble_size) assert ensemble in self.storage.ensembles - self.experiments[experiment].ensembles.append(ensemble.id) + self.experiments[experiment].ensembles[ensemble.id] = {} # https://github.com/equinor/ert/issues/7046 # assert ( @@ -220,7 +343,7 @@ def create_ensemble_from_prior(self, prior: UUID): experiment, ensemble_size=size, prior_ensemble=prior ) assert ensemble in self.storage.ensembles - self.experiments[experiment].ensembles.append(ensemble.id) + self.experiments[experiment].ensembles[ensemble.id] = {} # https://github.com/equinor/ert/issues/7046 # assert ( # ensemble.get_ensemble_state() @@ -240,6 +363,9 @@ def get_experiment(self, id: UUID): list(experiment.response_configuration.values()) == self.experiments[id].responses ) + assert self.experiments[id].observations == pytest.approx( + experiment.observations + ) @rule(id=ensemble_ids) def get_ensemble(self, id: UUID):