Skip to content

Commit

Permalink
Add workaround for problems with microseconds
Browse files Browse the repository at this point in the history
Workaround for storage not handling datetimes
with microseconds due to index overflow in netcdf3.
#6952
  • Loading branch information
eivindjahren committed Jan 23, 2024
1 parent cdd505b commit 7cdce3d
Show file tree
Hide file tree
Showing 7 changed files with 264 additions and 108 deletions.
20 changes: 13 additions & 7 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,22 +311,28 @@ def _get_obs_and_measure_data(
name: list(set(index.get_level_values(name))) for name in index.names
}
observation = observation.sel(sub_selection)
ds = source_fs.load_responses(group, tuple(iens_active_index))
response = source_fs.load_responses(group, tuple(iens_active_index))
if "time" in observation.coords:
response = response.reindex(
time=observation.time, method="nearest", tolerance="1s" # type: ignore
)
try:
filtered_ds = observation.merge(ds, join="left")
filtered_response = observation.merge(response, join="left")
except KeyError as e:
raise ErtAnalysisError(
f"Mismatched index for: "
f"Observation: {obs_key} attached to response: {group}"
) from e

observation_keys.append([obs_key] * len(filtered_ds.observations.data.ravel()))
observation_values.append(filtered_ds["observations"].data.ravel())
observation_errors.append(filtered_ds["std"].data.ravel())
observation_keys.append(
[obs_key] * len(filtered_response.observations.data.ravel())
)
observation_values.append(filtered_response["observations"].data.ravel())
observation_errors.append(filtered_response["std"].data.ravel())
measured_data.append(
filtered_ds["values"]
filtered_response["values"]
.transpose(..., "realization")
.values.reshape((-1, len(filtered_ds.realization)))
.values.reshape((-1, len(filtered_response.realization)))
)
source_fs.load_responses.cache_clear()
return (
Expand Down
25 changes: 23 additions & 2 deletions src/ert/config/_read_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,8 @@ def _read_spec(
hour=hour,
minute=minute,
second=microsecond // 10**6,
microsecond=microsecond % 10**6,
# Microseconds are dropped due to https://github.com/equinor/ert/issues/6952
# microsecond=microsecond % 10**6,
)
except Exception as err:
raise ValueError(
Expand Down Expand Up @@ -404,6 +405,19 @@ def optional_get(arr: Optional[npt.NDArray[Any]], idx: int) -> Any:
)


def _round_to_seconds(dt: datetime) -> datetime:
"""
>>> _round_to_seconds(datetime(2000, 1, 1, 1, 0, 1, 1))
datetime.datetime(2000, 1, 1, 1, 0, 1)
>>> _round_to_seconds(datetime(2000, 1, 1, 1, 0, 1, 500001))
datetime.datetime(2000, 1, 1, 1, 0, 2)
>>> _round_to_seconds(datetime(2000, 1, 1, 1, 0, 1, 499999))
datetime.datetime(2000, 1, 1, 1, 0, 1)
"""
extra_sec = round(dt.microsecond / 10**6)
return dt.replace(microsecond=0) + timedelta(seconds=extra_sec)


def _read_summary(
summary: str,
start_date: datetime,
Expand All @@ -427,7 +441,14 @@ def read_params() -> None:
if last_params is not None:
vals = _check_vals("PARAMS", summary, last_params.read_array())
values.append(vals[indices])
dates.append(start_date + unit.make_delta(float(vals[date_index])))

dates.append(
_round_to_seconds(
start_date + unit.make_delta(float(vals[date_index]))
),
)
# Due to https://github.com/equinor/ert/issues/6952
# dates.append(start_date + unit.make_delta(float(vals[date_index])))
last_params = None

with open(summary, mode) as fp:
Expand Down
13 changes: 0 additions & 13 deletions src/ert/config/summary_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,6 @@ def __post_init__(self) -> None:
def read_from_file(self, run_path: str, iens: int) -> xr.Dataset:
filename = self.input_file.replace("<IENS>", str(iens))
_, keys, time_map, data = read_summary(f"{run_path}/{filename}", self.keys)

if self.refcase:
assert isinstance(self.refcase, set)
missing = self.refcase.difference(time_map)
if missing:
first, last = min(missing), max(missing)
logger.warning(
f"Realization: {iens}, load warning: {len(missing)} "
f"inconsistencies in time map, first: Time mismatch for response "
f"time: {first}, last: Time mismatch for response time: "
f"{last} from: {run_path}/{filename}.UNSMRY"
)

ds = xr.Dataset(
{"values": (["name", "time"], data)},
coords={"time": time_map, "name": keys},
Expand Down
159 changes: 159 additions & 0 deletions tests/integration_tests/test_observation_times.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
"""
Tests behavior of matching response times to observation times
"""
from contextlib import redirect_stderr
from datetime import datetime
from io import StringIO
from textwrap import dedent

import hypothesis.strategies as st
import numpy as np
import pytest
from hypothesis import assume, given, settings

from ert.cli import ES_MDA_MODE
from ert.cli.main import ErtCliError
from tests.unit_tests.config.observations_generator import summary_observations
from tests.unit_tests.config.summary_generator import summaries

from .run_cli import run_cli

# Start date handed to every generated summary in the tests below.
start = datetime(1969, 1, 1)
# Date of the single observation; response times are placed relative to it.
observation_time = datetime(2024, 1, 1)
# Offset in seconds added to the response time in the small-mismatch test
# (it is added to a total_seconds() value before conversion to hours).
epsilon = 0.1


@settings(max_examples=3)
@given(
    responses=st.lists(
        summaries(
            start_date=st.just(start),
            time_deltas=st.just(
                [((observation_time - start).total_seconds() + epsilon) / 3600]
            ),
            summary_keys=st.just(["FOPR"]),
            use_days=st.just(False),
        ),
        min_size=2,
        max_size=2,
    ),
    observation=summary_observations(
        summary_keys=st.just("FOPR"),
        std_cutoff=10.0,
        names=st.just("FOPR_OBSERVATION"),
        dates=st.just(observation_time),
        time_types=st.just("date"),
    ),
    std_cutoff=st.floats(min_value=0.0, max_value=1.0),
    enkf_alpha=st.floats(min_value=3.0, max_value=10.0),
)
def test_small_time_mismatches_are_ignored(
    responses, observation, tmp_path_factory, std_cutoff, enkf_alpha
):
    """
    A response that is `epsilon` seconds past the observation time is
    still matched: the ES-MDA run is expected to report completion.
    """
    workdir = tmp_path_factory.mktemp("summary")
    config_file = workdir / "config.ert"
    config_file.write_text(
        dedent(
            f"""
            NUM_REALIZATIONS 2
            ECLBASE CASE
            SUMMARY FOPR
            MAX_SUBMIT 1
            GEN_KW KW_NAME prior.txt
            OBS_CONFIG observations.txt
            STD_CUTOFF {std_cutoff}
            ENKF_ALPHA {enkf_alpha}
            """
        )
    )
    (workdir / "prior.txt").write_text("KW_NAME NORMAL 0 1")

    # Last parameter of the last ministep of each generated summary.
    last_values = np.array(
        [unsmry.steps[-1].ministeps[-1].params[-1] for _, unsmry in responses]
    )
    spread = last_values.std(ddof=0)
    # Discard examples whose spread is not finite or not above STD_CUTOFF.
    assume(np.isfinite(spread))
    assume(spread > std_cutoff)
    observation.value = float(last_values.mean())

    # Exactly two responses are generated, one per realization; the same
    # summary files are written for each of the four iterations.
    for realization, (smspec, unsmry) in enumerate(responses):
        for iteration in range(4):
            run_dir = (
                workdir / f"simulations/realization-{realization}/iter-{iteration}"
            )
            run_dir.mkdir(parents=True)
            smspec.to_file(run_dir / "CASE.SMSPEC")
            unsmry.to_file(run_dir / "CASE.UNSMRY")
    (workdir / "observations.txt").write_text(str(observation))

    captured = StringIO()
    with redirect_stderr(captured):
        run_cli(ES_MDA_MODE, str(config_file), "--weights=0,1")
    assert "Experiment completed" in captured.getvalue()


@settings(max_examples=3)
@given(
    responses=st.lists(
        summaries(
            start_date=st.just(start),
            time_deltas=st.just(
                [((observation_time - start).total_seconds() + 360000) / 3600]
            ),
            summary_keys=st.just(["FOPR"]),
            use_days=st.just(False),
        ),
        min_size=2,
        max_size=2,
    ),
    observation=summary_observations(
        summary_keys=st.just("FOPR"),
        std_cutoff=10.0,
        names=st.just("FOPR_OBSERVATION"),
        dates=st.just(observation_time),
        time_types=st.just("date"),
    ),
    std_cutoff=st.floats(min_value=0.0, max_value=1.0),
    enkf_alpha=st.floats(min_value=3.0, max_value=10.0),
)
def test_big_time_mismatches_results_in_failure(
    responses, observation, tmp_path_factory, std_cutoff, enkf_alpha
):
    """
    A response that is 360000 seconds (100 hours) past the observation
    time is not matched: the ES-MDA run is expected to raise ErtCliError
    with a "No active observations" message.
    """
    workdir = tmp_path_factory.mktemp("summary")
    config_file = workdir / "config.ert"
    config_file.write_text(
        dedent(
            f"""
            NUM_REALIZATIONS 2
            ECLBASE CASE
            SUMMARY FOPR
            MAX_SUBMIT 1
            GEN_KW KW_NAME prior.txt
            OBS_CONFIG observations.txt
            STD_CUTOFF {std_cutoff}
            ENKF_ALPHA {enkf_alpha}
            """
        )
    )
    (workdir / "prior.txt").write_text("KW_NAME NORMAL 0 1")

    # Last parameter of the last ministep of each generated summary.
    last_values = np.array(
        [unsmry.steps[-1].ministeps[-1].params[-1] for _, unsmry in responses]
    )
    spread = last_values.std(ddof=0)
    # Discard examples whose spread is not finite or not above STD_CUTOFF.
    assume(np.isfinite(spread))
    assume(spread > std_cutoff)
    observation.value = float(last_values.mean())

    # Exactly two responses are generated, one per realization; the same
    # summary files are written for each of the four iterations.
    for realization, (smspec, unsmry) in enumerate(responses):
        for iteration in range(4):
            run_dir = (
                workdir / f"simulations/realization-{realization}/iter-{iteration}"
            )
            run_dir.mkdir(parents=True)
            smspec.to_file(run_dir / "CASE.SMSPEC")
            unsmry.to_file(run_dir / "CASE.UNSMRY")
    (workdir / "observations.txt").write_text(str(observation))

    with pytest.raises(ErtCliError, match="No active observations"):
        run_cli(ES_MDA_MODE, str(config_file), "--weights=0,1")
38 changes: 27 additions & 11 deletions tests/unit_tests/config/observations_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,29 +149,45 @@ def general_observations(draw, ensemble_keys, std_cutoff, names):
return GeneralObservation(**kws)


positive_floats = st.floats(min_value=0.1, allow_nan=False, allow_infinity=False)
positive_floats = st.floats(
min_value=0.1, max_value=1e25, allow_nan=False, allow_infinity=False
)
dates = st.datetimes(
max_value=datetime.datetime(year=2024, month=1, day=1),
min_value=datetime.datetime(year=1969, month=1, day=1),
)
time_types = st.sampled_from(["date", "days", "restart", "hours"])


@st.composite
def summary_observations(draw, summary_keys, std_cutoff, names):
def summary_observations(
draw, summary_keys, std_cutoff, names, dates=dates, time_types=time_types
):
kws = {
"name": draw(names),
"key": draw(summary_keys),
"error": draw(
st.floats(min_value=std_cutoff, allow_nan=False, allow_infinity=False)
st.floats(
min_value=std_cutoff,
max_value=std_cutoff * 1.1,
allow_nan=False,
allow_infinity=False,
)
),
"error_min": draw(
st.floats(
min_value=std_cutoff,
max_value=std_cutoff * 1.1,
allow_nan=False,
allow_infinity=False,
)
),
"error_min": draw(positive_floats),
"error_mode": draw(st.sampled_from(ErrorMode)),
"value": draw(positive_floats),
}
time_type = draw(st.sampled_from(["date", "days", "restart", "hours"]))
time_type = draw(time_types)
if time_type == "date":
date = draw(
st.datetimes(
max_value=datetime.datetime(year=2037, month=1, day=1),
min_value=datetime.datetime(year=1999, month=1, day=2),
)
)
date = draw(dates)
kws["date"] = date.strftime("%Y-%m-%d")
if time_type in ["days", "hours"]:
kws[time_type] = draw(st.floats(min_value=1, max_value=3000))
Expand Down
Loading

0 comments on commit 7cdce3d

Please sign in to comment.