Backport fix memory usage for esupdate #9503

Merged
83 changes: 1 addition & 82 deletions src/ert/analysis/_es_update.py
@@ -142,85 +142,6 @@ def _load_param_ensemble_array(
return config_node.load_parameters(ensemble, param_group, iens_active_index)


def _get_observations_and_responses(
ensemble: Ensemble,
selected_observations: Iterable[str],
iens_active_index: npt.NDArray[np.int_],
) -> polars.DataFrame:
"""Fetches and aligns selected observations with their corresponding simulated responses from an ensemble."""
observations_by_type = ensemble.experiment.observations

dfs = []
for (
response_type,
response_cls,
) in ensemble.experiment.response_configuration.items():
if response_type not in observations_by_type:
continue

observations_for_type = observations_by_type[response_type].filter(
polars.col("observation_key").is_in(list(selected_observations))
)
responses_for_type = ensemble.load_responses(
response_type, realizations=tuple(iens_active_index)
)

# Note that if there are duplicate entries for one
# response at one index, they are aggregated together
# with "mean" by default
pivoted = responses_for_type.pivot(
on="realization",
index=["response_key", *response_cls.primary_key],
aggregate_function="mean",
)

# We need to either assume that if there is a time column
# we will approx-join that, or we could specify in response configs
# that there is a column that requires an approx "asof" join.
# Suggest we simplify and assume that there is always only
# one "time" column, which we will reindex towards the response dataset
# with a given resolution
if "time" in pivoted:
joined = observations_for_type.join_asof(
pivoted,
by=["response_key", *response_cls.primary_key],
on="time",
tolerance="1s",
)
else:
joined = observations_for_type.join(
pivoted,
how="left",
on=["response_key", *response_cls.primary_key],
)

joined = (
joined.with_columns(
polars.concat_str(response_cls.primary_key, separator=", ").alias(
"__tmp_index_key__" # Avoid potential collisions w/ primary key
)
)
.drop(response_cls.primary_key)
.rename({"__tmp_index_key__": "index"})
)

first_columns = [
"response_key",
"index",
"observation_key",
"observations",
"std",
]
joined = joined.select(
first_columns + [c for c in joined.columns if c not in first_columns]
)

dfs.append(joined)

ensemble.load_responses.cache_clear()
return polars.concat(dfs)


def _expand_wildcards(
input_list: npt.NDArray[np.str_], patterns: List[str]
) -> List[str]:
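The removed helper pivots each response type so that realizations become columns (averaging duplicate entries for the same index), then aligns observations to responses, asof-joining on "time" with a one-second tolerance when a time column is present. A minimal sketch of that alignment, with hypothetical keys and values standing in for the real response configs:

import polars
from datetime import datetime

observations = polars.DataFrame(
    {
        "response_key": ["FOPR", "FOPR"],
        "time": [datetime(2020, 1, 1), datetime(2020, 6, 1)],
        "observation_key": ["obs_1", "obs_2"],
        "observations": [0.52, 0.71],
        "std": [0.05, 0.05],
    }
).sort("time")

responses = polars.DataFrame(
    {
        "response_key": ["FOPR", "FOPR", "FOPR"],
        "time": [datetime(2020, 1, 1), datetime(2020, 1, 1), datetime(2020, 6, 1)],
        "realization": [0, 0, 0],
        "values": [0.50, 0.54, 0.69],
    }
)

# Duplicate (response_key, time) entries within a realization are averaged,
# mirroring aggregate_function="mean" in the code above.
pivoted = responses.pivot(
    on="realization",
    index=["response_key", "time"],
    values="values",
    aggregate_function="mean",
).sort("time")

# Each observation is matched to the nearest response within one second.
joined = observations.join_asof(
    pivoted, by=["response_key"], on="time", tolerance="1s"
)
print(joined)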
@@ -260,9 +181,7 @@ def _load_observations_and_responses(
List[ObservationAndResponseSnapshot],
],
]:
# cols: response_key, index, observation_key, observations, std, *[1, ...nreals]
observations_and_responses = _get_observations_and_responses(
ensemble,
observations_and_responses = ensemble.get_observations_and_responses(
selected_observations,
iens_active_index,
)
172 changes: 167 additions & 5 deletions src/ert/storage/local_ensemble.py
@@ -6,7 +6,7 @@
from datetime import datetime
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
from uuid import UUID

import numpy as np
@@ -626,8 +626,32 @@ def save_cross_correlations(
file_path = os.path.join(self.mount_point, "corr_XY.nc")
self._storage._to_netcdf_transaction(file_path, dataset)

@lru_cache # noqa: B019
def load_responses(self, key: str, realizations: Tuple[int]) -> polars.DataFrame:
def load_responses(
self, key: str, realizations: Tuple[int, ...]
) -> polars.DataFrame:
"""Load responses for key and realizations into xarray Dataset.

For each given realization, response data is loaded from the NetCDF
file whose filename matches the given key parameter.

Parameters
----------
key : str
Response key to load.
realizations : tuple of int
Realization indices to load.

Returns
-------
responses : DataFrame
Loaded polars DataFrame with responses.
"""

return self._load_responses_lazy(key, realizations).collect()

def _load_responses_lazy(
self, key: str, realizations: Tuple[int, ...]
) -> polars.LazyFrame:
"""Load responses for key and realizations into xarray Dataset.

For each given realization, response data is loaded from the NetCDF
@@ -660,14 +684,14 @@ def load_responses(self, key: str, realizations: Tuple[int]) -> polars.DataFrame
input_path = self._realization_dir(realization) / f"{response_type}.parquet"
if not input_path.exists():
raise KeyError(f"No response for key {key}, realization: {realization}")
df = polars.read_parquet(input_path)
df = polars.scan_parquet(input_path)

if select_key:
df = df.filter(polars.col("response_key") == key)

loaded.append(df)

return polars.concat(loaded) if loaded else polars.DataFrame()
return polars.concat(loaded) if loaded else polars.DataFrame().lazy()

@deprecated("Use load_responses")
def load_all_summary_data(
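The switch from polars.read_parquet to polars.scan_parquet above means a realization's Parquet file is no longer materialized eagerly; the response_key filter is applied to a LazyFrame, so only the needed rows are read when .collect() is called. A small sketch of the eager-versus-lazy difference (file path and key value are hypothetical):

import polars

# Eager: the whole file is read into memory before filtering.
eager = polars.read_parquet("responses.parquet").filter(
    polars.col("response_key") == "FOPR"
)

# Lazy: build a query plan first; the filter can be pushed down into the scan,
# so only the matching rows are materialized at collect time.
lazy = (
    polars.scan_parquet("responses.parquet")
    .filter(polars.col("response_key") == "FOPR")
    .collect()
)

assert eager.equals(lazy)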
@@ -905,3 +929,141 @@ def get_response_state(
else RealizationStorageState.UNDEFINED
for e in self.experiment.response_configuration
}

def get_observations_and_responses(
self,
selected_observations: Iterable[str],
iens_active_index: npt.NDArray[np.int_],
) -> polars.DataFrame:
"""Fetches and aligns selected observations with their corresponding simulated responses from an ensemble."""
observations_by_type = self.experiment.observations

with polars.StringCache():
dfs_per_response_type = []
for (
response_type,
response_cls,
) in self.experiment.response_configuration.items():
if response_type not in observations_by_type:
continue

observations_for_type = (
observations_by_type[response_type]
.filter(
polars.col("observation_key").is_in(list(selected_observations))
)
.with_columns(
[
polars.col("response_key")
.cast(polars.Categorical)
.alias("response_key")
]
)
)

observed_cols = {
k: observations_for_type[k].unique()
for k in ["response_key", *response_cls.primary_key]
}

reals = iens_active_index.tolist()
reals.sort()
# too much memory to do it all at once, go per realization
first_columns: polars.DataFrame | None = None
realization_columns: List[polars.DataFrame] = []
for real in reals:
responses = self._load_responses_lazy(
response_type, (real,)
).with_columns(
[
polars.col("response_key")
.cast(polars.Categorical)
.alias("response_key")
]
)

# Filter out responses without observations
for col, observed_values in observed_cols.items():
responses = responses.filter(
polars.col(col).is_in(observed_values)
)

pivoted = responses.collect().pivot(
on="realization",
index=["response_key", *response_cls.primary_key],
values="values",
aggregate_function="mean",
)

if pivoted.is_empty():
# There are no responses for this realization,
# so we explicitly create a column of nans
# to represent this. We are basically saying that
# for this realization, each observation points
# to a NaN response.
joined = observations_for_type.with_columns(
polars.Series(
str(real),
[np.nan] * len(observations_for_type),
dtype=polars.Float32,
)
)
elif "time" in pivoted:
joined = observations_for_type.join_asof(
pivoted,
by=["response_key", *response_cls.primary_key],
on="time",
tolerance="1s",
)
else:
joined = observations_for_type.join(
pivoted,
how="left",
on=["response_key", *response_cls.primary_key],
)

joined = (
joined.with_columns(
polars.concat_str(
response_cls.primary_key, separator=", "
).alias(
"__tmp_index_key__"
# Avoid potential collisions w/ primary key
)
)
.drop(response_cls.primary_key)
.rename({"__tmp_index_key__": "index"})
)

if first_columns is None:
# The "leftmost" index columns are not yet collected.
# They are the same for all iterations, and indexed the same
# because we do a left join for the observations.
# Hence, we select these columns only once.
first_columns = joined.select(
[
"response_key",
"index",
"observation_key",
"observations",
"std",
]
)

realization_columns.append(joined.select(str(real)))

if first_columns is None:
# Not a single realization had any responses to the
# observations. Hence, there is no need to include
# it in the dataset
continue

dfs_per_response_type.append(
polars.concat(
[first_columns, *realization_columns], how="horizontal"
)
)

return polars.concat(dfs_per_response_type, how="vertical").with_columns(
polars.col("response_key").cast(polars.String).alias("response_key")
)
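The new method keeps peak memory bounded by never holding more than one realization's responses at a time: each realization contributes a single value column, the shared "leftmost" columns are taken once from the first join, and everything is concatenated horizontally at the end. A stripped-down sketch of that assembly, with made-up data standing in for the loaded, filtered and pivoted responses:

import numpy as np
import polars

# Shared index columns, collected once (stand-in for the first join above).
first_columns = polars.DataFrame(
    {
        "observation_key": ["obs_1", "obs_2"],
        "observations": [0.52, 0.71],
        "std": [0.05, 0.05],
    }
)

realization_columns = []
for real in range(3):
    # Stand-in for loading, filtering and pivoting one realization's responses;
    # only this realization's values are in memory at this point.
    values = np.random.default_rng(real).normal(0.6, 0.05, first_columns.height)
    realization_columns.append(polars.DataFrame({str(real): values}))

result = polars.concat([first_columns, *realization_columns], how="horizontal")
print(result)  # columns: observation_key, observations, std, "0", "1", "2"

Casting response_key to Categorical inside a shared polars.StringCache, as done above, stores each repeated key string once and keeps the categorical encodings comparable across the per-realization frames.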
2 changes: 1 addition & 1 deletion tests/ert/performance_tests/test_memory_usage.py
@@ -62,7 +62,7 @@ def test_memory_smoothing(poly_template):
)

stats = memray._memray.compute_statistics(str(poly_template / "memray.bin"))
assert stats.peak_memory_allocated < 1024**2 * 450
assert stats.peak_memory_allocated < 1024**2 * 300


def fill_storage_with_data(poly_template: Path, ert_config: ErtConfig) -> None:
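The test tightens the peak-allocation budget from 450 MiB to 300 MiB, read from a memray capture file written during the smoothing run. A minimal sketch of producing such a capture with memray's Tracker (run_smoother_workload is a hypothetical stand-in for the workload profiled in the test); the resulting file can then be inspected, for example with the memray stats command-line tool:

import memray

# Write an allocation capture while the workload runs; peak memory can then be
# read from the capture, e.g. via `memray stats memray.bin` on the command line.
with memray.Tracker("memray.bin"):
    run_smoother_workload()  # hypothetical stand-in for the ES update under test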