Skip to content

Commit

Permalink
Add default values using Pandas assign in design_matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
larsevj committed Nov 12, 2024
1 parent 44ece92 commit b5d3671
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 8 deletions.
21 changes: 13 additions & 8 deletions src/ert/config/design_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,10 @@ def read_design_matrix(
error_msg = "\n".join(error_list)
raise ValueError(f"Design matrix is not valid, error:\n{error_msg}")

defaults = DesignMatrix._read_defaultssheet(
self.xls_filename, self.default_sheet
defaults_to_use = DesignMatrix._read_defaultssheet(
self.xls_filename, self.default_sheet, design_matrix_df.columns.to_list()
)
for k, v in defaults.items():
if k not in design_matrix_df.columns:
design_matrix_df[k] = v
design_matrix_df = design_matrix_df.assign(**defaults_to_use)

parameter_configuration: Dict[str, ParameterConfig] = {}
transform_function_definitions: List[TransformFunctionDefinition] = []
Expand Down Expand Up @@ -200,11 +198,14 @@ def _validate_design_matrix(design_matrix: pd.DataFrame) -> List[str]:

@staticmethod
def _read_defaultssheet(
xls_filename: Union[Path, str], defaults_sheetname: str
xls_filename: Union[Path, str],
defaults_sheetname: str,
existing_parameters: List[str],
) -> Dict[str, Union[str, float]]:
"""
Construct a dict of keys and values to be used as defaults from the
first two columns in a spreadsheet.
first two columns in a spreadsheet. Only returns the keys that are
not already among the existing parameters.
Returns a dict of default values
Expand All @@ -230,7 +231,11 @@ def _read_defaultssheet(
if not default_df[0].is_unique:
raise ValueError("Default sheet contains duplicate parameter names")

return {row[0]: convert_to_numeric(row[1]) for _, row in default_df.iterrows()}
return {
row[0]: convert_to_numeric(row[1])
for _, row in default_df.iterrows()
if row[0] not in existing_parameters
}


def convert_to_numeric(x: str) -> Union[str, float]:
Expand Down
27 changes: 27 additions & 0 deletions tests/ert/unit_tests/sensitivity_analysis/test_design_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,30 @@ def test_reading_default_sheet_validation(tmp_path, data, error_msg):
design_matrix = DesignMatrix(design_path, "DesignSheet01", "DefaultValues")
with pytest.raises(ValueError, match=error_msg):
design_matrix.read_design_matrix()


def test_default_values_used(tmp_path):
    """Check which default-sheet values end up in the design matrix.

    The design sheet defines columns REAL, a, b and c; the default sheet
    defines one, b and d. After reading, "one" and "d" are expected to hold
    their default value on every row, while "b" is expected to keep the
    per-row values written to the design sheet.
    """
    xlsx_file = tmp_path / "design_matrix.xlsx"

    design_sheet = pd.DataFrame(
        {
            "REAL": [0, 1, 2, 3],
            "a": [1, 2, 3, 4],
            "b": [0, 2, 0, 1],
            "c": ["low", "high", "medium", "low"],
        }
    )
    # Two-column layout (name, value); written without a header row below.
    defaults_sheet = pd.DataFrame([["one", 1], ["b", 4], ["d", "case_name"]])

    with pd.ExcelWriter(xlsx_file) as writer:
        design_sheet.to_excel(writer, sheet_name="DesignSheet01", index=False)
        defaults_sheet.to_excel(
            writer, sheet_name="DefaultValues", index=False, header=False
        )

    matrix = DesignMatrix(xlsx_file, "DesignSheet01", "DefaultValues")
    matrix.read_design_matrix()
    result = matrix.design_matrix_df

    # Column name -> values expected after defaults have been applied.
    expected_columns = {
        "one": np.array([1] * 4),
        "b": np.array([0, 2, 0, 1]),
        "d": np.array(["case_name"] * 4),
    }
    for column, expected in expected_columns.items():
        np.testing.assert_equal(result[DESIGN_MATRIX_GROUP, column], expected)

0 comments on commit b5d3671

Please sign in to comment.