Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for reading designmatrix #8800

Merged
merged 3 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ dependencies = [
"matplotlib",
"netCDF4",
"numpy<2",
"openpyxl", # extra dependency for pandas (excel)
"orjson",
"packaging",
"pandas",
Expand Down
2 changes: 2 additions & 0 deletions src/ert/config/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .analysis_config import AnalysisConfig
from .analysis_module import AnalysisModule, ESSettings, IESSettings
from .capture_validation import capture_validation
from .design_matrix import DesignMatrix
from .enkf_observation_implementation_type import EnkfObservationImplementationType
from .ensemble_config import EnsembleConfig
from .ert_config import ErtConfig
Expand Down Expand Up @@ -48,6 +49,7 @@
"ConfigValidationError",
"ConfigValidationError",
"ConfigWarning",
"DesignMatrix",
"ESSettings",
"EnkfObs",
"EnkfObservationImplementationType",
Expand Down
4 changes: 2 additions & 2 deletions src/ert/config/analysis_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class AnalysisConfig:
ies_module: IESSettings = field(default_factory=IESSettings)
observation_settings: UpdateSettings = field(default_factory=UpdateSettings)
num_iterations: int = 1
design_matrix_args: Optional[DesignMatrix] = None
design_matrix: Optional[DesignMatrix] = None

@no_type_check
@classmethod
Expand Down Expand Up @@ -194,7 +194,7 @@ def from_dict(cls, config_dict: ConfigDict) -> "AnalysisConfig":
observation_settings=obs_settings,
es_module=es_settings,
ies_module=ies_settings,
design_matrix_args=DesignMatrix.from_config_list(design_matrix_config_list)
design_matrix=DesignMatrix.from_config_list(design_matrix_config_list)
if design_matrix_config_list is not None
else None,
)
Expand Down
189 changes: 188 additions & 1 deletion src/ert/config/design_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,37 @@

from dataclasses import dataclass
from pathlib import Path
from typing import List
from typing import TYPE_CHECKING, Dict, List, Optional, Union

import numpy as np
import pandas as pd
from pandas.api.types import is_integer_dtype

from ert.config.gen_kw_config import GenKwConfig, TransformFunctionDefinition

from ._option_dict import option_dict
from .parsing import (
ConfigValidationError,
ErrorInfo,
)

if TYPE_CHECKING:
from ert.config import (
ParameterConfig,
)

DESIGN_MATRIX_GROUP = "DESIGN_MATRIX"


@dataclass
class DesignMatrix:
xls_filename: Path
design_sheet: str
default_sheet: str
num_realizations: Optional[int] = None
active_realizations: Optional[List[bool]] = None
design_matrix_df: Optional[pd.DataFrame] = None
parameter_configuration: Optional[Dict[str, ParameterConfig]] = None

@classmethod
def from_config_list(cls, config_list: List[str]) -> "DesignMatrix":
Expand All @@ -41,6 +58,12 @@ def from_config_list(cls, config_list: List[str]) -> "DesignMatrix":
errors.append(
ErrorInfo("Missing required DEFAULT_SHEET").set_context(config_list)
)
if design_sheet is not None and design_sheet == default_sheet:
errors.append(
ErrorInfo(
"DESIGN_SHEET and DEFAULT_SHEET can not point to the same sheet."
).set_context(config_list)
)
if errors:
raise ConfigValidationError.from_collected(errors)
assert design_sheet is not None
Expand All @@ -50,3 +73,167 @@ def from_config_list(cls, config_list: List[str]) -> "DesignMatrix":
design_sheet=design_sheet,
default_sheet=default_sheet,
)

def read_design_matrix(
self,
) -> None:
param_names = (
pd.read_excel(
self.xls_filename,
sheet_name=self.design_sheet,
nrows=1,
header=None,
dtype="string",
)
.iloc[0]
.apply(lambda x: x.strip() if isinstance(x, str) else x)
)
design_matrix_df = DesignMatrix._read_excel(
self.xls_filename,
self.design_sheet,
header=None,
skiprows=1,
)
design_matrix_df.columns = param_names.to_list()

if "REAL" in design_matrix_df.columns:
if not is_integer_dtype(design_matrix_df.dtypes["REAL"]) or any(
design_matrix_df["REAL"] < 0
):
raise ValueError("REAL column must only contain positive integers")
design_matrix_df = design_matrix_df.set_index(
"REAL", drop=True, verify_integrity=True
)

if error_list := DesignMatrix._validate_design_matrix(design_matrix_df):
error_msg = "\n".join(error_list)
raise ValueError(f"Design matrix is not valid, error:\n{error_msg}")

defaults = DesignMatrix._read_defaultssheet(
self.xls_filename, self.default_sheet
)
for k, v in defaults.items():
if k not in design_matrix_df.columns:
design_matrix_df[k] = v

parameter_configuration: Dict[str, ParameterConfig] = {}
transform_function_definitions: List[TransformFunctionDefinition] = []
for parameter in design_matrix_df.columns:
transform_function_definitions.append(
TransformFunctionDefinition(
name=parameter,
param_name="RAW",
values=[],
)
)
parameter_configuration[DESIGN_MATRIX_GROUP] = GenKwConfig(
name=DESIGN_MATRIX_GROUP,
forward_init=False,
template_file=None,
output_file=None,
transform_function_definitions=transform_function_definitions,
update=False,
)

design_matrix_df.columns = pd.MultiIndex.from_product(
[[DESIGN_MATRIX_GROUP], design_matrix_df.columns]
)
reals = design_matrix_df.index.tolist()
self.num_realizations = len(reals)
self.active_realizations = [x in reals for x in range(max(reals) + 1)]

self.design_matrix_df = design_matrix_df
self.parameter_configuration = parameter_configuration

@staticmethod
def _read_excel(
file_name: Union[Path, str],
sheet_name: str,
usecols: Optional[Union[int, List[int]]] = None,
header: Optional[int] = 0,
skiprows: Optional[int] = None,
dtype: Optional[str] = None,
) -> pd.DataFrame:
"""
Make dataframe from excel file
:return: Dataframe
:raises: OsError if file not found
:raises: ValueError if file not loaded correctly
"""
dframe: pd.DataFrame = pd.read_excel(
file_name,
sheet_name,
usecols=usecols,
header=header,
skiprows=skiprows,
dtype=dtype,
)
return dframe.dropna(axis=1, how="all")

@staticmethod
def _validate_design_matrix(design_matrix: pd.DataFrame) -> List[str]:
"""
Validate user inputted design matrix
:raises: ValueError if design matrix contains empty headers or empty cells
"""
if design_matrix.empty:
return []
errors = []
column_na_mask = design_matrix.columns.isna()
column_indexes_unnamed = [
index for index, value in enumerate(column_na_mask) if value
]
if len(column_indexes_unnamed) > 0:
errors.append(
f"Column headers not present in column {column_indexes_unnamed}"
)
if not design_matrix.columns[~column_na_mask].is_unique:
errors.append("Duplicate parameter names found in design sheet")
empties = [
f"Realization {design_matrix.index[i]}, column {design_matrix.columns[j]}"
for i, j in zip(*np.where(pd.isna(design_matrix)))
]
if len(empties) > 0:
errors.append(f"Design matrix contains empty cells {empties}")
return errors

@staticmethod
def _read_defaultssheet(
xls_filename: Union[Path, str], defaults_sheetname: str
) -> Dict[str, Union[str, float]]:
"""
Construct a dict of keys and values to be used as defaults from the
first two columns in a spreadsheet.

Returns a dict of default values

:raises: ValueError if defaults sheet is non-empty but non-parsable
"""
default_df = DesignMatrix._read_excel(
xls_filename,
defaults_sheetname,
header=None,
dtype="string",
)
if default_df.empty:
return {}
if len(default_df.columns) < 2:
raise ValueError("Defaults sheet must have at least two columns")
empty_cells = [
f"Row {default_df.index[i]}, column {default_df.columns[j]}"
for i, j in zip(*np.where(pd.isna(default_df)))
]
if len(empty_cells) > 0:
raise ValueError(f"Default sheet contains empty cells {empty_cells}")
default_df[0] = default_df[0].apply(lambda x: x.strip())
if not default_df[0].is_unique:
raise ValueError("Default sheet contains duplicate parameter names")

return {row[0]: convert_to_numeric(row[1]) for _, row in default_df.iterrows()}


def convert_to_numeric(x: str) -> Union[str, float]:
try:
return pd.to_numeric(x)
except ValueError:
return x
7 changes: 4 additions & 3 deletions src/ert/config/gen_kw_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,7 @@ class TransformFunctionDefinition:
class GenKwConfig(ParameterConfig):
template_file: Optional[str]
output_file: Optional[str]
transform_function_definitions: (
List[TransformFunctionDefinition] | List[Dict[Any, Any]]
)
transform_function_definitions: List[TransformFunctionDefinition]
forward_init_file: Optional[str] = None

def __post_init__(self) -> None:
Expand All @@ -90,6 +88,9 @@ def __post_init__(self) -> None:
)
self._validate()

def __contains__(self, item: str) -> bool:
return item in [v.name for v in self.transform_function_definitions]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can do this lazily?

Suggested change
return item in [v.name for v in self.transform_function_definitions]
return any(item in v.name for v in self.transform_function_definitions)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if that does the same thing, but will see if there is a better way for this?


def __len__(self) -> int:
return len(self.transform_functions)

Expand Down
Loading