Skip to content

Commit

Permalink
Remove try-except around loading responses
Browse files Browse the repository at this point in the history
Function should raise if errors are encountered
  • Loading branch information
dafeda committed Dec 21, 2023
1 parent f874da8 commit 86a702b
Showing 1 changed file with 47 additions and 46 deletions.
93 changes: 47 additions & 46 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from collections import UserDict
from dataclasses import dataclass, field
from datetime import datetime
from math import sqrt
from pathlib import Path
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -348,7 +347,14 @@ def _load_observations_and_responses(
iens_ative_index: npt.NDArray[np.int_],
selected_observations: List[Tuple[str, Optional[List[int]]]],
misfit_process: bool,
) -> Any:
) -> Tuple[
npt.NDArray[np.float_],
Tuple[
npt.NDArray[np.float_],
npt.NDArray[np.float_],
List[ObservationAndResponseSnapshot],
],
]:
S, observations, errors, obs_keys = _get_obs_and_measure_data(
source_fs,
selected_observations,
Expand All @@ -359,19 +365,19 @@ def _load_observations_and_responses(
# in for example evensen2018 - Analysis of iterative ensemble smoothers for
# solving inverse problems.
# `global_std_scaling` is 1.0 for ES.
scaling = np.ones(len(errors))
scaling *= sqrt(global_std_scaling)
scaling = np.sqrt(global_std_scaling) * np.ones_like(errors)
scaled_errors = errors * scaling

# Identifies non-outlier observations based on responses.
ens_mean = S.mean(axis=1)
ens_std = S.std(ddof=0, axis=1)

ens_std_mask = ens_std > std_cutoff
ens_mean_mask = abs(observations - ens_mean) <= alpha * (ens_std + errors * scaling)
ens_mean_mask = abs(observations - ens_mean) <= alpha * (ens_std + scaled_errors)
obs_mask = np.logical_and(ens_mean_mask, ens_std_mask)

if misfit_process:
scaling[obs_mask] *= misfit_preprocessor.main(
S[obs_mask], (errors * scaling)[obs_mask]
S[obs_mask], scaled_errors[obs_mask]
)

update_snapshot = []
Expand Down Expand Up @@ -409,10 +415,10 @@ def _load_observations_and_responses(

for missing_obs in obs_keys[~obs_mask]:
_logger.warning(f"Deactivating observation: {missing_obs}")
errors *= scaling

return S[obs_mask], (
observations[obs_mask],
errors[obs_mask],
scaled_errors[obs_mask],
update_snapshot,
)

Expand Down Expand Up @@ -512,25 +518,23 @@ def analysis_ES(
progress_callback(
AnalysisStatusEvent(msg="Loading observations and responses..")
)
try:
(
S,
(
S,
(
observation_values,
observation_errors,
update_snapshot,
),
) = _load_observations_and_responses(
source_fs,
alpha,
std_cutoff,
global_scaling,
iens_active_index,
update_step.observation_config(),
misfit_process,
)
except IndexError as e:
raise ErtAnalysisError(e) from e
observation_values,
observation_errors,
update_snapshot,
),
) = _load_observations_and_responses(
source_fs,
alpha,
std_cutoff,
global_scaling,
iens_active_index,
update_step.observation_config(),
misfit_process,
)

smoother_snapshot.update_step_snapshots[update_step.name] = update_snapshot

num_obs = len(observation_values)
Expand Down Expand Up @@ -747,26 +751,23 @@ def analysis_IES(
AnalysisStatusEvent(msg="Loading observations and responses..")
)

# Load responses, and observations
try:
(
S,
(
S,
(
observation_values,
observation_errors,
update_snapshot,
),
) = _load_observations_and_responses(
source_fs,
alpha,
std_cutoff,
global_scaling,
iens_active_index,
update_step.observation_config(),
misfit_preprocessor,
)
except IndexError as e:
raise ErtAnalysisError(str(e)) from e
observation_values,
observation_errors,
update_snapshot,
),
) = _load_observations_and_responses(
source_fs,
alpha,
std_cutoff,
global_scaling,
iens_active_index,
update_step.observation_config(),
misfit_preprocessor,
)

smoother_snapshot.update_step_snapshots[update_step.name] = update_snapshot
if len(observation_values) == 0:
raise ErtAnalysisError(
Expand Down

0 comments on commit 86a702b

Please sign in to comment.