Return cross correlations for multiple GEN_KWs
dafeda committed Dec 13, 2024
1 parent 9725d7d commit 359eb5a
Showing 16 changed files with 532 additions and 505 deletions.
15 changes: 15 additions & 0 deletions src/ert/analysis/_es_update.py
@@ -5,6 +5,7 @@
import time
from collections.abc import Callable, Iterable, Sequence
from fnmatch import fnmatch
from itertools import groupby
from typing import (
TYPE_CHECKING,
Generic,
@@ -168,6 +169,7 @@ def _load_observations_and_responses(
npt.NDArray[np.float64],
tuple[
npt.NDArray[np.float64],
list[str],
npt.NDArray[np.float64],
list[ObservationAndResponseSnapshot],
],
@@ -315,6 +317,7 @@ def _load_observations_and_responses(

return S[obs_mask], (
observations[obs_mask],
obs_keys[obs_mask],
scaled_errors[obs_mask],
update_snapshot,
)
@@ -458,6 +461,7 @@ def adaptive_localization_progress_callback(
S,
(
observation_values,
observation_keys,
observation_errors,
update_snapshot,
),
@@ -474,6 +478,14 @@ def adaptive_localization_progress_callback(
num_obs = len(observation_values)

smoother_snapshot.update_step_snapshots = update_snapshot
# Used as labels for observations in the cross-correlation matrix.
# Say we have two observation groups "FOPR" and "WOPR" where "FOPR" has
# 2 responses and "WOPR" has 3.
# In this case we create a list [FOPR_0, FOPR_1, WOPR_0, WOPR_1, WOPR_2]
# as labels for observations.
unique_obs_names = [
f"{k}_{i}" for k, g in groupby(observation_keys) for i, _ in enumerate(list(g))
]

if num_obs == 0:
msg = "No active observations for update step"
@@ -577,6 +589,8 @@ def correlation_callback(
cross_correlations_,
param_group,
parameter_names[: cross_correlations_.shape[0]],
unique_obs_names,
list(observation_keys),
)
logger.info(
f"Adaptive Localization of {param_group} completed in {(time.time() - start) / 60} minutes"
@@ -639,6 +653,7 @@ def analysis_IES(
S,
(
observation_values,
_,
observation_errors,
update_snapshot,
),
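
Note on the labeling above: itertools.groupby only groups consecutive equal keys, so observation_keys is assumed to arrive already grouped by observation group. A minimal, self-contained sketch with made-up keys, reproducing the example from the comment:

from itertools import groupby

# Made-up input, grouped consecutively as groupby requires:
# "FOPR" has 2 responses and "WOPR" has 3.
observation_keys = ["FOPR", "FOPR", "WOPR", "WOPR", "WOPR"]

# Same comprehension as in _es_update.py: enumerate within each run
# of equal keys to get a per-group counter.
unique_obs_names = [
    f"{k}_{i}" for k, g in groupby(observation_keys) for i, _ in enumerate(list(g))
]

print(unique_obs_names)  # ['FOPR_0', 'FOPR_1', 'WOPR_0', 'WOPR_1', 'WOPR_2']
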
3 changes: 2 additions & 1 deletion src/ert/resources/forward_models/template_render.py
@@ -2,7 +2,8 @@
import argparse
import json
import os
from typing import Any, Sequence
from collections.abc import Sequence
from typing import Any

import jinja2
import yaml
@@ -1,6 +1,6 @@
import json
import os
from typing import Sequence
from collections.abc import Sequence

import pandas

@@ -1,7 +1,8 @@
import contextlib
import json
import os
from typing import Any, Sequence
from collections.abc import Sequence
from typing import Any

import numpy
import pandas as pd
39 changes: 25 additions & 14 deletions src/ert/storage/local_ensemble.py
@@ -12,6 +12,7 @@

import numpy as np
import pandas as pd
import polars as pl
import xarray as xr
from pydantic import BaseModel
from typing_extensions import deprecated
@@ -560,16 +561,15 @@ def load_parameters(

return self._load_dataset(group, realizations)

def load_cross_correlations(self) -> xr.Dataset:
input_path = self.mount_point / "corr_XY.nc"

def load_cross_correlations(self) -> pl.DataFrame:
input_path = self.mount_point / "corr_XY.parquet"
if not input_path.exists():
raise FileNotFoundError(
f"No cross-correlation data available at '{input_path}'. Make sure to run the update with "
"Adaptive Localization enabled."
)
logger.info("Loading cross correlations")
return xr.open_dataset(input_path, engine="scipy")
return pl.read_parquet(input_path)

@require_write
def save_observation_scaling_factors(self, dataset: polars.DataFrame) -> None:
@@ -592,17 +592,28 @@ def save_cross_correlations(
cross_correlations: npt.NDArray[np.float64],
param_group: str,
parameter_names: list[str],
unique_obs_names: list[str],
observation_keys: list[str],
) -> None:
data_vars = {
param_group: xr.DataArray(
data=cross_correlations,
dims=["parameter", "response"],
coords={"parameter": parameter_names},
)
}
dataset = xr.Dataset(data_vars)
file_path = os.path.join(self.mount_point, "corr_XY.nc")
self._storage._to_netcdf_transaction(file_path, dataset)
n_responses = cross_correlations.shape[1]
new_df = pl.DataFrame(
{
"param_group": [param_group]
* (len(parameter_names) * len(unique_obs_names)),
"param_name": np.repeat(parameter_names, n_responses),
"obs_group": observation_keys * len(parameter_names),
"obs_name": unique_obs_names * len(parameter_names),
"value": cross_correlations.ravel(),
}
)

file_path = os.path.join(self.mount_point, "corr_XY.parquet")
if os.path.exists(file_path):
existing_df = pl.read_parquet(file_path)
df = pl.concat([existing_df, new_df])
else:
df = new_df
self._storage._to_parquet_transaction(file_path, df)

def load_responses(
self, key: str, realizations: tuple[int, ...]
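
To make the long format written by save_cross_correlations concrete, here is a minimal sketch with a made-up 2x3 correlation matrix: np.repeat varies the parameter name slowest while the tiled observation columns vary fastest, matching the row-major ravel() of the matrix.

import numpy as np
import polars as pl

# Made-up inputs: 2 parameters by 3 observations.
cross_correlations = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
parameter_names = ["a", "b"]
unique_obs_names = ["FOPR_0", "FOPR_1", "WOPR_0"]
observation_keys = ["FOPR", "FOPR", "WOPR"]

n_responses = cross_correlations.shape[1]
df = pl.DataFrame(
    {
        "param_group": ["COEFFS"] * (len(parameter_names) * len(unique_obs_names)),
        "param_name": np.repeat(parameter_names, n_responses),  # a a a b b b
        "obs_group": observation_keys * len(parameter_names),  # tiled per parameter
        "obs_name": unique_obs_names * len(parameter_names),
        "value": cross_correlations.ravel(),  # row-major, matches repeat/tile order
    }
)
print(df.row(0))  # ('COEFFS', 'a', 'FOPR', 'FOPR_0', 0.1)

Because the method concatenates onto any existing corr_XY.parquet, a single file accumulates rows for every parameter group, which is what lets load_cross_correlations return correlations for multiple GEN_KWs at once.
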
264 changes: 264 additions & 0 deletions test-data/ert/heat_equation/Plot_correlations.ipynb

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions test-data/ert/heat_equation/config.ert
@@ -13,13 +13,19 @@ QUEUE_OPTION LOCAL MAX_RUNNING 100

RANDOM_SEED 11223344

ANALYSIS_SET_VAR STD_ENKF LOCALIZATION True
ANALYSIS_SET_VAR STD_ENKF LOCALIZATION_CORRELATION_THRESHOLD 0.1

NUM_REALIZATIONS 100
GRID CASE.EGRID

OBS_CONFIG observations

FIELD COND PARAMETER cond.bgrdecl INIT_FILES:cond.bgrdecl FORWARD_INIT:True

GEN_KW INIT_TEMP_SCALE init_temp_prior.txt
GEN_KW CORR_LENGTH corr_length_prior.txt

GEN_DATA MY_RESPONSE RESULT_FILE:gen_data_%d.out REPORT_STEPS:10,71,132,193,255,316,377,438 INPUT_FORMAT:ASCII

INSTALL_JOB heat_equation HEAT_EQUATION
1 change: 1 addition & 0 deletions test-data/ert/heat_equation/corr_length_prior.txt
@@ -0,0 +1 @@
x NORMAL 0.8 0.1
23 changes: 19 additions & 4 deletions test-data/ert/heat_equation/heat_equation.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
"""Partial Differential Equations to use as forward models."""

import json
import sys

import geostat
@@ -51,16 +52,28 @@ def heat_equation(
return u_


def sample_prior_conductivity(ensemble_size, nx, rng):
def sample_prior_conductivity(ensemble_size, nx, rng, corr_length):
mesh = np.meshgrid(np.linspace(0, 1, nx), np.linspace(0, 1, nx))
return np.exp(geostat.gaussian_fields(mesh, rng, ensemble_size, r=0.8))
return np.exp(geostat.gaussian_fields(mesh, rng, ensemble_size, r=corr_length))


def load_parameters(filename):
with open(filename, encoding="utf-8") as f:
return json.load(f)


if __name__ == "__main__":
iens = int(sys.argv[1])
iteration = int(sys.argv[2])
rng = np.random.default_rng(iens)
cond = sample_prior_conductivity(ensemble_size=1, nx=nx, rng=rng).reshape(nx, nx)

parameters = load_parameters("parameters.json")
init_temp_scale = parameters["INIT_TEMP_SCALE"]
corr_length = parameters["CORR_LENGTH"]

cond = sample_prior_conductivity(
ensemble_size=1, nx=nx, rng=rng, corr_length=float(corr_length["x"])
).reshape(nx, nx)

if iteration == 0:
resfo.write(
@@ -78,7 +91,9 @@ def sample_prior_conductivity(ensemble_size, nx, rng):
# Note that this could be avoided if we used an implicit solver.
dt = dx**2 / (4 * max(np.max(cond), np.max(cond)))

response = heat_equation(u_init, cond, dx, dt, k_start, k_end, rng)
scaled_u_init = u_init * float(init_temp_scale["x"])

response = heat_equation(scaled_u_init, cond, dx, dt, k_start, k_end, rng)

index = sorted((obs.x, obs.y) for obs in obs_coordinates)
for time_step in obs_times:
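
The forward-model change assumes ERT's per-realization parameters.json, keyed first by GEN_KW group and then by parameter name (that is the shape load_parameters is indexed with above). A runnable sketch with made-up values:

import json

# Made-up parameters.json in the layout the forward model expects:
# one object per GEN_KW group, mapping parameter names to sampled values.
with open("parameters.json", "w", encoding="utf-8") as f:
    json.dump({"INIT_TEMP_SCALE": {"x": 0.42}, "CORR_LENGTH": {"x": 0.8}}, f)

with open("parameters.json", encoding="utf-8") as f:
    parameters = json.load(f)

# Both GEN_KWs define a single parameter named "x" (see the prior files).
init_temp_scale = float(parameters["INIT_TEMP_SCALE"]["x"])
corr_length = float(parameters["CORR_LENGTH"]["x"])
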
1 change: 1 addition & 0 deletions test-data/ert/heat_equation/init_temp_prior.txt
@@ -0,0 +1 @@
x UNIFORM 0 1
548 changes: 105 additions & 443 deletions test-data/ert/poly_example/Plot_correlations.ipynb

Large diffs are not rendered by default.

15 changes: 9 additions & 6 deletions tests/ert/ui_tests/cli/analysis/test_adaptive_localization.py
@@ -300,12 +300,15 @@ def test_that_posterior_generalized_variance_increases_in_cutoff():
)

cross_correlations = prior_ensemble_cutoff1.load_cross_correlations()
assert all(cross_correlations.parameter.to_numpy() == ["a", "b"])
assert cross_correlations["COEFFS"].values.shape == (2, 5)
assert (
(cross_correlations["COEFFS"].values >= -1)
& (cross_correlations["COEFFS"].values <= 1)
).all()
assert cross_correlations["param_group"].unique().to_list() == ["COEFFS"]
assert sorted(cross_correlations["param_name"].unique().to_list()) == [
"a",
"b",
"c",
]
# Make sure correlations are between -1 and 1.
is_valid = (cross_correlations["value"] >= -1) & (cross_correlations["value"] <= 1)
assert is_valid.all()

prior_sample_cutoff1 = prior_ensemble_cutoff1.load_parameters("COEFFS")["values"]
prior_cov = np.cov(prior_sample_cutoff1, rowvar=False)
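
The schema asserted above can be exercised without a storage instance; a small sketch with made-up values, using the same [-1, 1] validity check:

import polars as pl

# Toy frame with the same columns load_cross_correlations() returns
# (values made up for illustration).
corr_XY = pl.DataFrame(
    {
        "param_group": ["COEFFS"] * 4,
        "param_name": ["a", "a", "b", "b"],
        "obs_group": ["POLY_OBS"] * 4,
        "obs_name": ["POLY_OBS_0", "POLY_OBS_1"] * 2,
        "value": [0.9, -0.2, 0.1, 0.7],
    }
)

# Same check as the test: correlations must lie in [-1, 1].
assert ((corr_XY["value"] >= -1) & (corr_XY["value"] <= 1)).all()

# For example, strongest absolute correlations first.
print(corr_XY.sort(pl.col("value").abs(), descending=True).head(2))
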
20 changes: 20 additions & 0 deletions tests/ert/ui_tests/cli/analysis/test_es_update.py
@@ -162,6 +162,10 @@ def sample_prior(nx, ny):
@pytest.mark.integration_test
@pytest.mark.usefixtures("copy_snake_oil_field")
def test_update_multiple_param():
with open("snake_oil.ert", "a", encoding="utf-8") as f:
f.write("\nANALYSIS_SET_VAR STD_ENKF LOCALIZATION True\n")
f.write("ANALYSIS_SET_VAR STD_ENKF LOCALIZATION_CORRELATION_THRESHOLD 0.1\n")

run_cli(
ENSEMBLE_SMOOTHER_MODE,
"--disable-monitor",
@@ -183,6 +187,22 @@ def test_update_multiple_param():
# https://en.wikipedia.org/wiki/Variance#For_vector-valued_random_variables
assert np.trace(np.cov(posterior_array)) < np.trace(np.cov(prior_array))

corr_XY = prior_ensemble.load_cross_correlations()
expected_obs_groups = [obs[0] for obs in ert_config.observation_config]
obs_groups = corr_XY["obs_group"].unique().to_list()
assert sorted(obs_groups) == sorted(expected_obs_groups)
# Check that obs names are created using obs groups
obs_name_starts_with_group = (
corr_XY.with_columns(
polars.col("obs_name")
.str.starts_with(polars.col("obs_group"))
.alias("starts_with_check")
)
.get_column("starts_with_check")
.all()
)
assert obs_name_starts_with_group


@pytest.mark.usefixtures("copy_poly_case")
def test_that_update_works_with_failed_realizations():
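
The prefix assertion above relies on str.starts_with accepting another column as the prefix, so the check runs row by row. A minimal sketch with made-up rows:

import polars

# Made-up rows: each obs_name should begin with its obs_group.
df = polars.DataFrame(
    {
        "obs_group": ["FOPR", "FOPR", "WOPR"],
        "obs_name": ["FOPR_0", "FOPR_1", "WOPR_0"],
    }
)

obs_name_starts_with_group = (
    df.with_columns(
        polars.col("obs_name")
        .str.starts_with(polars.col("obs_group"))
        .alias("starts_with_check")
    )
    .get_column("starts_with_check")
    .all()
)
assert obs_name_starts_with_group
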
77 changes: 51 additions & 26 deletions tests/ert/ui_tests/cli/test_field_parameter.py
@@ -23,38 +23,63 @@
from .run_cli import run_cli


def test_field_param_update_using_heat_equation(heat_equation_storage):
config = ErtConfig.from_file("config.ert")
with open_storage(config.ens_path, mode="w") as storage:
def test_shared_heat_equation_storage(heat_equation_storage):
"""The fixture heat_equation_storage runs the heat equation test case.
This test verifies that results are as expected.
"""
config = heat_equation_storage
with open_storage(config.ens_path) as storage:
experiment = storage.get_experiment_by_name("es-mda")
prior = experiment.get_ensemble_by_name("default_0")
posterior = experiment.get_ensemble_by_name("default_1")

prior_result = prior.load_parameters("COND")["values"]
ensembles = [experiment.get_ensemble_by_name(f"default_{i}") for i in range(4)]

param_config = config.ensemble_config.parameter_configs["COND"]
assert len(prior_result.x) == param_config.nx
assert len(prior_result.y) == param_config.ny
assert len(prior_result.z) == param_config.nz

posterior_result = posterior.load_parameters("COND")["values"]
prior_covariance = np.cov(
prior_result.values.reshape(
prior.ensemble_size, param_config.nx * param_config.ny * param_config.nz
),
rowvar=False,
)
posterior_covariance = np.cov(
posterior_result.values.reshape(
posterior.ensemble_size,
# Check that generalized variance decreases across consecutive ensembles
covariances = []
for ensemble in ensembles:
results = ensemble.load_parameters("COND")["values"]
reshaped_values = results.values.reshape(
ensemble.ensemble_size,
param_config.nx * param_config.ny * param_config.nz,
),
rowvar=False,
)
# Check that generalized variance is reduced by update step.
assert np.trace(prior_covariance) > np.trace(posterior_covariance)
)
covariances.append(np.cov(reshaped_values, rowvar=False))
for i in range(len(covariances) - 1):
assert np.trace(covariances[i]) > np.trace(
covariances[i + 1]
), f"Generalized variance did not decrease from iteration {i} to {i + 1}"

# Check that the saved cross-correlations are as expected.
for i in range(3):
ensemble = ensembles[i]
corr_XY = ensemble.load_cross_correlations()

assert sorted(corr_XY["param_group"].unique().to_list()) == [
"CORR_LENGTH",
"INIT_TEMP_SCALE",
]
assert corr_XY["param_name"].unique().to_list() == ["x"]

# Make sure correlations are between -1 and 1.
is_valid = (corr_XY["value"] >= -1) & (corr_XY["value"] <= 1)
assert is_valid.all()

# Check obs names and obs groups
expected_obs_groups = [obs[0] for obs in config.observation_config]
obs_groups = corr_XY["obs_group"].unique().to_list()
assert sorted(obs_groups) == sorted(expected_obs_groups)
# Check that obs names are created using obs groups
obs_name_starts_with_group = (
corr_XY.with_columns(
pl.col("obs_name")
.str.starts_with(pl.col("obs_group"))
.alias("starts_with_check")
)
.get_column("starts_with_check")
.all()
)
assert obs_name_starts_with_group

# Check that fields in the runpath are different between iterations
# Check that fields in the runpath are different between ensembles
cond_iter0 = resfo.read("simulations/realization-0/iter-0/cond.bgrdecl")[0][1]
cond_iter1 = resfo.read("simulations/realization-0/iter-1/cond.bgrdecl")[0][1]
assert (cond_iter0 != cond_iter1).all()
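
The monotone-spread assertion above uses the trace of the ensemble covariance (the sum of per-component variances) as its scalar spread measure. A self-contained sketch with synthetic ensembles standing in for the real iterations:

import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in: each "update" shrinks the same 100-realization,
# 9-component ensemble toward its mean.
base = rng.normal(size=(100, 9))
ensembles = [0.7**i * base for i in range(4)]

# rowvar=False: rows are realizations, columns are parameter components.
traces = [np.trace(np.cov(e, rowvar=False)) for e in ensembles]

for i in range(len(traces) - 1):
    assert traces[i] > traces[i + 1], f"Spread did not decrease at step {i}"
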
