diff --git a/src/ert/config/design_matrix.py b/src/ert/config/design_matrix.py index f866766e41c..252fb529543 100644 --- a/src/ert/config/design_matrix.py +++ b/src/ert/config/design_matrix.py @@ -11,15 +11,10 @@ from ert.config.gen_kw_config import GenKwConfig, TransformFunctionDefinition from ._option_dict import option_dict -from .parsing import ( - ConfigValidationError, - ErrorInfo, -) +from .parsing import ConfigValidationError, ErrorInfo if TYPE_CHECKING: - from ert.config import ( - ParameterConfig, - ) + from ert.config import ParameterConfig DESIGN_MATRIX_GROUP = "DESIGN_MATRIX" @@ -31,10 +26,17 @@ class DesignMatrix: default_sheet: str def __post_init__(self) -> None: - self.num_realizations: Optional[int] = None - self.active_realizations: Optional[List[bool]] = None - self.design_matrix_df: Optional[pd.DataFrame] = None - self.parameter_configuration: Optional[Dict[str, ParameterConfig]] = None + try: + ( + self.active_realizations, + self.design_matrix_df, + self.parameter_configuration, + ) = self.read_design_matrix() + except (ValueError, AttributeError) as exc: + raise ConfigValidationError.with_context( + f"Error reading design matrix {self.xls_filename}: {exc}", + str(self.xls_filename), + ) from exc @classmethod def from_config_list(cls, config_list: List[str]) -> "DesignMatrix": @@ -76,9 +78,64 @@ def from_config_list(cls, config_list: List[str]) -> "DesignMatrix": default_sheet=default_sheet, ) + def merge_with_existing_parameters( + self, existing_parameters: List[ParameterConfig] + ) -> tuple[List[ParameterConfig], ParameterConfig | None]: + """ + This method merges the design matrix parameters with the existing parameters and + returns the new list of existing parameters, wherein we drop GEN_KW group having a full overlap with the design matrix group. + GEN_KW group that was dropped will acquire a new name from the design matrix group. + Additionally, the ParameterConfig which is the design matrix group is returned separately. + + Args: + existing_parameters (List[ParameterConfig]): List of existing parameters + + Raises: + ConfigValidationError: If there is a partial overlap between the design matrix group and any existing GEN_KW group + + Returns: + tuple[List[ParameterConfig], ParameterConfig]: List of existing parameters and the dedicated design matrix group + """ + + new_param_config: List[ParameterConfig] = [] + + design_parameter_group = self.parameter_configuration[DESIGN_MATRIX_GROUP] + design_keys = [] + if isinstance(design_parameter_group, GenKwConfig): + design_keys = [e.name for e in design_parameter_group.transform_functions] + + design_group_added = False + for parameter_group in existing_parameters: + if not isinstance(parameter_group, GenKwConfig): + new_param_config += [parameter_group] + continue + existing_keys = [e.name for e in parameter_group.transform_functions] + if set(existing_keys) == set(design_keys): + if design_group_added: + raise ConfigValidationError( + ( + "Multiple overlapping groups with design matrix found in existing parameters!\n" + f"{design_parameter_group.name} and {parameter_group.name}" + ) + ) + + design_parameter_group.name = parameter_group.name + design_group_added = True + elif set(design_keys) & set(existing_keys): + raise ConfigValidationError( + ( + "Overlapping parameter names found in design matrix!\n" + f"{DESIGN_MATRIX_GROUP}:{design_keys}\n{parameter_group.name}:{existing_keys}" + "\nThey need to much exactly or not at all." + ) + ) + else: + new_param_config += [parameter_group] + return new_param_config, design_parameter_group + def read_design_matrix( self, - ) -> None: + ) -> tuple[List[bool], pd.DataFrame, Dict[str, ParameterConfig]]: # Read the parameter names (first row) as strings to prevent pandas from modifying them. # This ensures that duplicate or empty column names are preserved exactly as they appear in the Excel sheet. # By doing this, we can properly validate variable names, including detecting duplicates or missing names. @@ -142,11 +199,11 @@ def read_design_matrix( [[DESIGN_MATRIX_GROUP], design_matrix_df.columns] ) reals = design_matrix_df.index.tolist() - self.num_realizations = len(reals) - self.active_realizations = [x in reals for x in range(max(reals) + 1)] - - self.design_matrix_df = design_matrix_df - self.parameter_configuration = parameter_configuration + return ( + [x in reals for x in range(max(reals) + 1)], + design_matrix_df, + parameter_configuration, + ) @staticmethod def _read_excel( diff --git a/src/ert/enkf_main.py b/src/ert/enkf_main.py index 8bfdc4b248c..093375dcd5d 100644 --- a/src/ert/enkf_main.py +++ b/src/ert/enkf_main.py @@ -19,6 +19,8 @@ ) import orjson +import pandas as pd +import xarray as xr from numpy.random import SeedSequence from ert.config.ert_config import forward_model_data_to_json @@ -26,13 +28,8 @@ from ert.config.model_config import ModelConfig from ert.substitutions import Substitutions, substitute_runpath_name -from .config import ( - ExtParamConfig, - Field, - GenKwConfig, - ParameterConfig, - SurfaceConfig, -) +from .config import ExtParamConfig, Field, GenKwConfig, ParameterConfig, SurfaceConfig +from .config.design_matrix import DESIGN_MATRIX_GROUP from .run_arg import RunArg from .runpaths import Runpaths @@ -165,6 +162,29 @@ def _seed_sequence(seed: Optional[int]) -> int: return int_seed +def save_design_matrix_to_ensemble( + design_matrix_df: pd.DataFrame, + ensemble: Ensemble, + active_realizations: Iterable[int], + design_group_name: str = DESIGN_MATRIX_GROUP, +) -> None: + assert not design_matrix_df.empty + for realization_nr in active_realizations: + row = design_matrix_df.loc[realization_nr][DESIGN_MATRIX_GROUP] + ds = xr.Dataset( + { + "values": ("names", list(row.values)), + "transformed_values": ("names", list(row.values)), + "names": list(row.keys()), + } + ) + ensemble.save_parameters( + design_group_name, + realization_nr, + ds, + ) + + def sample_prior( ensemble: Ensemble, active_realizations: Iterable[int], diff --git a/src/ert/gui/simulation/ensemble_experiment_panel.py b/src/ert/gui/simulation/ensemble_experiment_panel.py index dd2cf2e513e..11e59340f26 100644 --- a/src/ert/gui/simulation/ensemble_experiment_panel.py +++ b/src/ert/gui/simulation/ensemble_experiment_panel.py @@ -15,7 +15,7 @@ from ert.gui.tools.design_matrix.design_matrix_panel import DesignMatrixPanel from ert.mode_definitions import ENSEMBLE_EXPERIMENT_MODE from ert.run_models import EnsembleExperiment -from ert.validation import RangeStringArgument +from ert.validation import ActiveRange, RangeStringArgument from ert.validation.proper_name_argument import ExperimentValidation, ProperNameArgument from .experiment_config_panel import ExperimentConfigPanel @@ -85,6 +85,9 @@ def __init__( design_matrix = analysis_config.design_matrix if design_matrix is not None: + self._active_realizations_field.setText( + ActiveRange(design_matrix.active_realizations).rangestring + ) show_dm_param_button = QPushButton("Show parameters") show_dm_param_button.setObjectName("show-dm-parameters") show_dm_param_button.setMinimumWidth(50) @@ -113,23 +116,14 @@ def __init__( self.notifier.ertChanged.connect(self._update_experiment_name_placeholder) def on_show_dm_params_clicked(self, design_matrix: DesignMatrix) -> None: - assert design_matrix is not None - - if design_matrix.design_matrix_df is None: - design_matrix.read_design_matrix() - - if ( - design_matrix.design_matrix_df is not None - and not design_matrix.design_matrix_df.empty - ): - viewer = DesignMatrixPanel( - design_matrix.design_matrix_df, - design_matrix.xls_filename.name, - ) - viewer.setMinimumHeight(500) - viewer.setMinimumWidth(1000) - viewer.adjustSize() - viewer.exec_() + viewer = DesignMatrixPanel( + design_matrix.design_matrix_df, + design_matrix.xls_filename.name, + ) + viewer.setMinimumHeight(500) + viewer.setMinimumWidth(1000) + viewer.adjustSize() + viewer.exec_() @Slot(ExperimentConfigPanel) def experimentTypeChanged(self, w: ExperimentConfigPanel) -> None: diff --git a/src/ert/run_models/ensemble_experiment.py b/src/ert/run_models/ensemble_experiment.py index 7fa54a454c7..6af59ddc078 100644 --- a/src/ert/run_models/ensemble_experiment.py +++ b/src/ert/run_models/ensemble_experiment.py @@ -6,13 +6,14 @@ import numpy as np -from ert.enkf_main import sample_prior +from ert.config import ConfigValidationError +from ert.enkf_main import sample_prior, save_design_matrix_to_ensemble from ert.ensemble_evaluator import EvaluatorServerConfig from ert.storage import Ensemble, Experiment, Storage from ert.trace import tracer from ..run_arg import create_run_arguments -from .base_run_model import BaseRunModel, StatusEvents +from .base_run_model import BaseRunModel, ErtRunError, StatusEvents if TYPE_CHECKING: from ert.config import ErtConfig, QueueConfig @@ -64,10 +65,27 @@ def run_experiment( ) -> None: self.log_at_startup() self.restart = restart + # If design matrix is present, we try to merge design matrix parameters + # to the experiment parameters and set new active realizations + parameters_config = self.ert_config.ensemble_config.parameter_configuration + design_matrix = self.ert_config.analysis_config.design_matrix + design_matrix_group = None + if design_matrix is not None: + try: + parameters_config, design_matrix_group = ( + design_matrix.merge_with_existing_parameters(parameters_config) + ) + except ConfigValidationError as exc: + raise ErtRunError(str(exc)) from exc + if not restart: self.experiment = self._storage.create_experiment( name=self.experiment_name, - parameters=self.ert_config.ensemble_config.parameter_configuration, + parameters=( + [*parameters_config, design_matrix_group] + if design_matrix_group is not None + else parameters_config + ), observations=self.ert_config.observations, responses=self.ert_config.ensemble_config.response_configuration, ) @@ -90,12 +108,21 @@ def run_experiment( np.array(self.active_realizations, dtype=bool), ensemble=self.ensemble, ) + sample_prior( self.ensemble, np.where(self.active_realizations)[0], random_seed=self.random_seed, ) + if design_matrix_group is not None and design_matrix is not None: + save_design_matrix_to_ensemble( + design_matrix.design_matrix_df, + self.ensemble, + np.where(self.active_realizations)[0], + design_matrix_group.name, + ) + self._evaluate_and_postprocess( run_args, self.ensemble, diff --git a/src/ert/run_models/model_factory.py b/src/ert/run_models/model_factory.py index 68cb4a75d56..f0688309201 100644 --- a/src/ert/run_models/model_factory.py +++ b/src/ert/run_models/model_factory.py @@ -117,13 +117,20 @@ def _setup_ensemble_experiment( args: Namespace, status_queue: SimpleQueue[StatusEvents], ) -> EnsembleExperiment: - active_realizations = _realizations(args, config.model_config.num_realizations) + active_realizations = _realizations( + args, config.model_config.num_realizations + ).tolist() + if ( + config.analysis_config.design_matrix is not None + and config.analysis_config.design_matrix.active_realizations is not None + ): + active_realizations = config.analysis_config.design_matrix.active_realizations experiment_name = args.experiment_name assert experiment_name is not None return EnsembleExperiment( random_seed=config.random_seed, - active_realizations=active_realizations.tolist(), + active_realizations=active_realizations, ensemble_name=args.current_ensemble, minimum_required_realizations=config.analysis_config.minimum_required_realizations, experiment_name=experiment_name, @@ -271,9 +278,9 @@ def _setup_iterative_ensemble_smoother( random_seed=config.random_seed, active_realizations=active_realizations.tolist(), target_ensemble=_iterative_ensemble_format(args), - number_of_iterations=int(args.num_iterations) - if args.num_iterations is not None - else 4, + number_of_iterations=( + int(args.num_iterations) if args.num_iterations is not None else 4 + ), minimum_required_realizations=config.analysis_config.minimum_required_realizations, num_retries_per_iter=4, experiment_name=experiment_name, diff --git a/tests/ert/ui_tests/cli/analysis/test_design_matrix.py b/tests/ert/ui_tests/cli/analysis/test_design_matrix.py new file mode 100644 index 00000000000..738c92a8963 --- /dev/null +++ b/tests/ert/ui_tests/cli/analysis/test_design_matrix.py @@ -0,0 +1,189 @@ +import os +import stat +from textwrap import dedent + +import numpy as np +import pandas as pd +import pytest + +from ert.cli.main import ErtCliError +from ert.config import ErtConfig +from ert.mode_definitions import ENSEMBLE_EXPERIMENT_MODE +from ert.storage import open_storage +from tests.ert.ui_tests.cli.run_cli import run_cli + + +@pytest.mark.usefixtures("copy_poly_case") +def test_run_poly_example_with_design_matrix(): + design_matrix = "poly_design.xlsx" + num_realizations = 10 + a_values = list(range(num_realizations)) + design_matrix_df = pd.DataFrame( + { + "REAL": list(range(num_realizations)), + "a": a_values, + } + ) + default_sheet_df = pd.DataFrame([["b", 1], ["c", 2]]) + with pd.ExcelWriter(design_matrix) as xl_write: + design_matrix_df.to_excel(xl_write, index=False, sheet_name="DesignSheet01") + default_sheet_df.to_excel( + xl_write, index=False, sheet_name="DefaultSheet", header=False + ) + + with open("poly.ert", "w", encoding="utf-8") as fout: + fout.write( + dedent( + """\ + QUEUE_OPTION LOCAL MAX_RUNNING 10 + RUNPATH poly_out/realization-/iter- + NUM_REALIZATIONS 10 + MIN_REALIZATIONS 1 + GEN_DATA POLY_RES RESULT_FILE:poly.out + DESIGN_MATRIX poly_design.xlsx DESIGN_SHEET:DesignSheet01 DEFAULT_SHEET:DefaultSheet + INSTALL_JOB poly_eval POLY_EVAL + FORWARD_MODEL poly_eval + """ + ) + ) + + with open("poly_eval.py", "w", encoding="utf-8") as f: + f.write( + dedent( + """\ + #!/usr/bin/env python + import json + + def _load_coeffs(filename): + with open(filename, encoding="utf-8") as f: + return json.load(f)["DESIGN_MATRIX"] + + def _evaluate(coeffs, x): + return coeffs["a"] * x**2 + coeffs["b"] * x + coeffs["c"] + + if __name__ == "__main__": + coeffs = _load_coeffs("parameters.json") + output = [_evaluate(coeffs, x) for x in range(10)] + with open("poly.out", "w", encoding="utf-8") as f: + f.write("\\n".join(map(str, output))) + """ + ) + ) + os.chmod( + "poly_eval.py", + os.stat("poly_eval.py").st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH, + ) + + run_cli( + ENSEMBLE_EXPERIMENT_MODE, + "--disable-monitor", + "poly.ert", + "--experiment-name", + "test-experiment", + ) + storage_path = ErtConfig.from_file("poly.ert").ens_path + with open_storage(storage_path) as storage: + experiment = storage.get_experiment_by_name("test-experiment") + params = experiment.get_ensemble_by_name("default").load_parameters( + "DESIGN_MATRIX" + )["values"] + np.testing.assert_array_equal(params[:, 0], a_values) + np.testing.assert_array_equal(params[:, 1], 10 * [1]) + np.testing.assert_array_equal(params[:, 2], 10 * [2]) + + +@pytest.mark.usefixtures("copy_poly_case") +@pytest.mark.parametrize( + "default_values, error_msg", + [ + ([["b", 1], ["c", 2]], None), + ([["b", 1]], "Overlapping parameter names found in design matrix!"), + ], +) +def test_run_poly_example_with_design_matrix_and_genkw_merge(default_values, error_msg): + design_matrix = "poly_design.xlsx" + num_realizations = 10 + a_values = list(range(num_realizations)) + design_matrix_df = pd.DataFrame( + { + "REAL": list(range(num_realizations)), + "a": a_values, + } + ) + default_sheet_df = pd.DataFrame(default_values) + with pd.ExcelWriter(design_matrix) as xl_write: + design_matrix_df.to_excel(xl_write, index=False, sheet_name="DesignSheet01") + default_sheet_df.to_excel( + xl_write, index=False, sheet_name="DefaultSheet", header=False + ) + + with open("poly.ert", "w", encoding="utf-8") as fout: + fout.write( + dedent( + """\ + QUEUE_OPTION LOCAL MAX_RUNNING 10 + RUNPATH poly_out/realization-/iter- + NUM_REALIZATIONS 10 + MIN_REALIZATIONS 1 + GEN_DATA POLY_RES RESULT_FILE:poly.out + GEN_KW COEFFS coeff_priors + DESIGN_MATRIX poly_design.xlsx DESIGN_SHEET:DesignSheet01 DEFAULT_SHEET:DefaultSheet + INSTALL_JOB poly_eval POLY_EVAL + FORWARD_MODEL poly_eval + """ + ) + ) + + with open("poly_eval.py", "w", encoding="utf-8") as f: + f.write( + dedent( + """\ + #!/usr/bin/env python + import json + + def _load_coeffs(filename): + with open(filename, encoding="utf-8") as f: + return json.load(f)["COEFFS"] + + def _evaluate(coeffs, x): + return coeffs["a"] * x**2 + coeffs["b"] * x + coeffs["c"] + + if __name__ == "__main__": + coeffs = _load_coeffs("parameters.json") + output = [_evaluate(coeffs, x) for x in range(10)] + with open("poly.out", "w", encoding="utf-8") as f: + f.write("\\n".join(map(str, output))) + """ + ) + ) + os.chmod( + "poly_eval.py", + os.stat("poly_eval.py").st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH, + ) + + if error_msg: + with pytest.raises(ErtCliError, match=error_msg): + run_cli( + ENSEMBLE_EXPERIMENT_MODE, + "--disable-monitor", + "poly.ert", + "--experiment-name", + "test-experiment", + ) + return + run_cli( + ENSEMBLE_EXPERIMENT_MODE, + "--disable-monitor", + "poly.ert", + "--experiment-name", + "test-experiment", + ) + storage_path = ErtConfig.from_file("poly.ert").ens_path + with open_storage(storage_path) as storage: + experiment = storage.get_experiment_by_name("test-experiment") + params = experiment.get_ensemble_by_name("default").load_parameters("COEFFS")[ + "values" + ] + np.testing.assert_array_equal(params[:, 0], a_values) + np.testing.assert_array_equal(params[:, 1], 10 * [1]) + np.testing.assert_array_equal(params[:, 2], 10 * [2]) diff --git a/tests/ert/unit_tests/config/test_analysis_config.py b/tests/ert/unit_tests/config/test_analysis_config.py index 9482bf60ebb..ca827485b70 100644 --- a/tests/ert/unit_tests/config/test_analysis_config.py +++ b/tests/ert/unit_tests/config/test_analysis_config.py @@ -1,6 +1,7 @@ from textwrap import dedent import hypothesis.strategies as st +import pandas as pd import pytest from hypothesis import given @@ -15,22 +16,33 @@ def test_analysis_config_from_file_is_same_as_from_dict(monkeypatch, tmp_path): - with open(tmp_path / "my_design_matrix.xlsx", "w", encoding="utf-8"): - pass + with pd.ExcelWriter(tmp_path / "my_design_matrix.xlsx") as xl_write: + design_matrix_df = pd.DataFrame( + { + "REAL": [0, 1, 2], + "a": [1, 2, 3], + "b": [0, 2, 0], + } + ) + default_sheet_df = pd.DataFrame([["a", 1], ["b", 4]]) + design_matrix_df.to_excel(xl_write, index=False, sheet_name="my_sheet") + default_sheet_df.to_excel( + xl_write, index=False, sheet_name="my_default_sheet", header=False + ) monkeypatch.chdir(tmp_path) assert ErtConfig.from_file_contents( dedent( """ - NUM_REALIZATIONS 10 - MIN_REALIZATIONS 10 + NUM_REALIZATIONS 3 + MIN_REALIZATIONS 3 ANALYSIS_SET_VAR STD_ENKF ENKF_TRUNCATION 0.8 DESIGN_MATRIX my_design_matrix.xlsx DESIGN_SHEET:my_sheet DEFAULT_SHEET:my_default_sheet """ ) ).analysis_config == AnalysisConfig.from_dict( { - ConfigKeys.NUM_REALIZATIONS: 10, - ConfigKeys.MIN_REALIZATIONS: "10", + ConfigKeys.NUM_REALIZATIONS: 3, + ConfigKeys.MIN_REALIZATIONS: "3", ConfigKeys.ANALYSIS_SET_VAR: [ ("STD_ENKF", "ENKF_TRUNCATION", 0.8), ], diff --git a/tests/ert/unit_tests/gui/simulation/test_run_dialog.py b/tests/ert/unit_tests/gui/simulation/test_run_dialog.py index 2787be480ba..0287e5ff9c4 100644 --- a/tests/ert/unit_tests/gui/simulation/test_run_dialog.py +++ b/tests/ert/unit_tests/gui/simulation/test_run_dialog.py @@ -2,6 +2,7 @@ from queue import SimpleQueue from unittest.mock import MagicMock, Mock, patch +import pandas as pd import pytest from pytestqt.qtbot import QtBot from qtpy import QtWidgets @@ -732,15 +733,26 @@ def test_that_stdout_and_stderr_buttons_react_to_file_content( def test_that_design_matrix_show_parameters_button_is_visible( design_matrix_entry, qtbot: QtBot, storage ): - xls_filename = "design_matrix.xls" - with open(f"{xls_filename}", "w", encoding="utf-8"): - pass + xls_filename = "design_matrix.xlsx" + design_matrix_df = pd.DataFrame( + { + "REAL": list(range(3)), + "a": [0, 1, 2], + } + ) + default_sheet_df = pd.DataFrame([["b", 1], ["c", 2]]) + with pd.ExcelWriter(xls_filename) as xl_write: + design_matrix_df.to_excel(xl_write, index=False, sheet_name="DesignSheet01") + default_sheet_df.to_excel( + xl_write, index=False, sheet_name="DefaultSheet", header=False + ) + config_file = "minimal_config.ert" with open(config_file, "w", encoding="utf-8") as f: f.write("NUM_REALIZATIONS 1") if design_matrix_entry: f.write( - f"\nDESIGN_MATRIX {xls_filename} DESIGN_SHEET:DesignSheet01 DEFAULT_SHEET:DefaultValues" + f"\nDESIGN_MATRIX {xls_filename} DESIGN_SHEET:DesignSheet01 DEFAULT_SHEET:DefaultSheet" ) args_mock = Mock() diff --git a/tests/ert/unit_tests/sensitivity_analysis/test_design_matrix.py b/tests/ert/unit_tests/sensitivity_analysis/test_design_matrix.py index 2326f76bc3a..3ed8c7309e2 100644 --- a/tests/ert/unit_tests/sensitivity_analysis/test_design_matrix.py +++ b/tests/ert/unit_tests/sensitivity_analysis/test_design_matrix.py @@ -3,6 +3,83 @@ import pytest from ert.config.design_matrix import DESIGN_MATRIX_GROUP, DesignMatrix +from ert.config.gen_kw_config import GenKwConfig, TransformFunctionDefinition + + +@pytest.mark.parametrize( + "parameters, error_msg", + [ + pytest.param( + {"COEFFS": ["a", "b"]}, + "", + id="genkw_replaced", + ), + pytest.param( + {"COEFFS": ["a"]}, + "Overlapping parameter names found in design matrix!", + id="ValidationErrorOverlapping", + ), + pytest.param( + {"COEFFS": ["aa", "bb"], "COEFFS2": ["cc", "dd"]}, + "", + id="DESIGN_MATRIX_GROUP", + ), + pytest.param( + {"COEFFS": ["a", "b"], "COEFFS2": ["a", "b"]}, + "Multiple overlapping groups with design matrix found in existing parameters!", + id="ValidationErrorMultipleGroups", + ), + ], +) +def test_read_and_merge_with_existing_parameters(tmp_path, parameters, error_msg): + extra_genkw_config = [] + if parameters: + for group_name in parameters: + extra_genkw_config.append( + GenKwConfig( + name=group_name, + forward_init=False, + template_file="", + transform_function_definitions=[ + TransformFunctionDefinition(param, "UNIFORM", [0, 1]) + for param in parameters[group_name] + ], + output_file="kw.txt", + update=True, + ) + ) + + realizations = [0, 1, 2] + design_path = tmp_path / "design_matrix.xlsx" + design_matrix_df = pd.DataFrame( + { + "REAL": realizations, + "a": [1, 2, 3], + "b": [0, 2, 0], + } + ) + default_sheet_df = pd.DataFrame([["a", 1], ["b", 4]]) + with pd.ExcelWriter(design_path) as xl_write: + design_matrix_df.to_excel(xl_write, index=False, sheet_name="DesignSheet01") + default_sheet_df.to_excel( + xl_write, index=False, sheet_name="DefaultValues", header=False + ) + design_matrix = DesignMatrix(design_path, "DesignSheet01", "DefaultValues") + if error_msg: + with pytest.raises(ValueError, match=error_msg): + design_matrix.merge_with_existing_parameters(extra_genkw_config) + elif len(parameters) == 1: + new_config_parameters, design_group = ( + design_matrix.merge_with_existing_parameters(extra_genkw_config) + ) + assert len(new_config_parameters) == 0 + assert design_group.name == "COEFFS" + elif len(parameters) == 2: + new_config_parameters, design_group = ( + design_matrix.merge_with_existing_parameters(extra_genkw_config) + ) + assert len(new_config_parameters) == 2 + assert design_group.name == DESIGN_MATRIX_GROUP def test_reading_design_matrix(tmp_path): @@ -23,10 +100,8 @@ def test_reading_design_matrix(tmp_path): xl_write, index=False, sheet_name="DefaultValues", header=False ) design_matrix = DesignMatrix(design_path, "DesignSheet01", "DefaultValues") - design_matrix.read_design_matrix() design_params = design_matrix.parameter_configuration.get(DESIGN_MATRIX_GROUP, []) assert all(param in design_params for param in ("a", "b", "c", "one", "d")) - assert design_matrix.num_realizations == 3 assert design_matrix.active_realizations == [True, True, False, False, True] @@ -62,9 +137,9 @@ def test_reading_design_matrix_validate_reals(tmp_path, real_column, error_msg): default_sheet_df.to_excel( xl_write, index=False, sheet_name="DefaultValues", header=False ) - design_matrix = DesignMatrix(design_path, "DesignSheet01", "DefaultValues") + with pytest.raises(ValueError, match=error_msg): - design_matrix.read_design_matrix() + DesignMatrix(design_path, "DesignSheet01", "DefaultValues") @pytest.mark.parametrize( @@ -98,9 +173,9 @@ def test_reading_design_matrix_validate_headers(tmp_path, column_names, error_ms default_sheet_df.to_excel( xl_write, index=False, sheet_name="DefaultValues", header=False ) - design_matrix = DesignMatrix(design_path, "DesignSheet01", "DefaultValues") + with pytest.raises(ValueError, match=error_msg): - design_matrix.read_design_matrix() + DesignMatrix(design_path, "DesignSheet01", "DefaultValues") @pytest.mark.parametrize( @@ -134,9 +209,9 @@ def test_reading_design_matrix_validate_cells(tmp_path, values, error_msg): default_sheet_df.to_excel( xl_write, index=False, sheet_name="DefaultValues", header=False ) - design_matrix = DesignMatrix(design_path, "DesignSheet01", "DefaultValues") + with pytest.raises(ValueError, match=error_msg): - design_matrix.read_design_matrix() + DesignMatrix(design_path, "DesignSheet01", "DefaultValues") @pytest.mark.parametrize( @@ -180,9 +255,9 @@ def test_reading_default_sheet_validation(tmp_path, data, error_msg): default_sheet_df.to_excel( xl_write, index=False, sheet_name="DefaultValues", header=False ) - design_matrix = DesignMatrix(design_path, "DesignSheet01", "DefaultValues") + with pytest.raises(ValueError, match=error_msg): - design_matrix.read_design_matrix() + DesignMatrix(design_path, "DesignSheet01", "DefaultValues") def test_default_values_used(tmp_path): @@ -202,7 +277,6 @@ def test_default_values_used(tmp_path): xl_write, index=False, sheet_name="DefaultValues", header=False ) design_matrix = DesignMatrix(design_path, "DesignSheet01", "DefaultValues") - design_matrix.read_design_matrix() df = design_matrix.design_matrix_df np.testing.assert_equal(df[DESIGN_MATRIX_GROUP, "one"], np.array([1, 1, 1, 1])) np.testing.assert_equal(df[DESIGN_MATRIX_GROUP, "b"], np.array([0, 2, 0, 1])) diff --git a/tests/ert/unit_tests/test_libres_facade.py b/tests/ert/unit_tests/test_libres_facade.py index 98ecd91a359..7453ade7d8d 100644 --- a/tests/ert/unit_tests/test_libres_facade.py +++ b/tests/ert/unit_tests/test_libres_facade.py @@ -2,12 +2,15 @@ from datetime import datetime, timedelta from textwrap import dedent +import numpy as np import pytest +from pandas import ExcelWriter from pandas.core.frame import DataFrame from resdata.summary import Summary from ert.config import ErtConfig -from ert.enkf_main import sample_prior +from ert.config.design_matrix import DESIGN_MATRIX_GROUP, DesignMatrix +from ert.enkf_main import sample_prior, save_design_matrix_to_ensemble from ert.libres_facade import LibresFacade from ert.storage import open_storage @@ -253,3 +256,52 @@ def test_load_gen_kw_not_sorted(storage, tmpdir, snapshot): data = ensemble.load_all_gen_kw_data() snapshot.assert_match(data.round(12).to_csv(), "gen_kw_unsorted") + + +@pytest.mark.parametrize( + "reals, expect_error", + [ + pytest.param( + list(range(10)), + False, + id="correct_active_realizations", + ), + pytest.param([10, 11], True, id="incorrect_active_realizations"), + ], +) +def test_save_parameters_to_storage_from_design_dataframe( + tmp_path, reals, expect_error +): + design_path = tmp_path / "design_matrix.xlsx" + ensemble_size = 10 + a_values = np.random.default_rng().uniform(-5, 5, 10) + b_values = np.random.default_rng().uniform(-5, 5, 10) + c_values = np.random.default_rng().uniform(-5, 5, 10) + design_matrix_df = DataFrame({"a": a_values, "b": b_values, "c": c_values}) + with ExcelWriter(design_path) as xl_write: + design_matrix_df.to_excel(xl_write, index=False, sheet_name="DesignSheet01") + DataFrame().to_excel( + xl_write, index=False, sheet_name="DefaultValues", header=False + ) + design_matrix = DesignMatrix(design_path, "DesignSheet01", "DefaultValues") + with open_storage(tmp_path / "storage", mode="w") as storage: + experiment_id = storage.create_experiment( + parameters=[design_matrix.parameter_configuration[DESIGN_MATRIX_GROUP]] + ) + ensemble = storage.create_ensemble( + experiment_id, name="default", ensemble_size=ensemble_size + ) + if expect_error: + with pytest.raises(KeyError): + save_design_matrix_to_ensemble( + design_matrix.design_matrix_df, ensemble, reals + ) + else: + save_design_matrix_to_ensemble( + design_matrix.design_matrix_df, ensemble, reals + ) + params = ensemble.load_parameters(DESIGN_MATRIX_GROUP)["values"] + all(params.names.values == ["a", "b", "c"]) + np.testing.assert_array_almost_equal(params[:, 0], a_values) + np.testing.assert_array_almost_equal(params[:, 1], b_values) + np.testing.assert_array_almost_equal(params[:, 2], c_values)