Skip to content

Commit

Permalink
Detemine source for update in separate function
Browse files Browse the repository at this point in the history
  • Loading branch information
dafeda committed Feb 22, 2024
1 parent 9ad8072 commit 1a39d2a
Showing 1 changed file with 20 additions and 15 deletions.
35 changes: 20 additions & 15 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,23 @@ def _split_by_batchsize(
return np.array_split(arr, sections)


def _determine_parameter_source(
param_group_name: str, target_fs: EnsembleAccessor, source_fs: EnsembleReader
) -> Union[EnsembleReader, EnsembleAccessor]:
"""
Determines the source for a parameter group based on whether it is available in `target_fs`.
It is possible to update the same parameter multiple times.
For example, two update steps may be defined where one udpates the parameter using observation
`A` while the other udpates it using observation `B`.
After the processing of the first update step has completed, the updated parameter is stored in `traget_fs`.
Hence, when processing the second update step, we need to load the parameter from `target_fs` and not `source_fs`.
"""
if target_fs.has_parameter_group(param_group_name):
return target_fs
else:
return source_fs


def _update_with_row_scaling(
update_step: UpdateStep,
source_fs: EnsembleReader,
Expand All @@ -485,10 +502,7 @@ def _update_with_row_scaling(
) -> None:
for param_group in update_step.row_scaling_parameters:
source: Union[EnsembleReader, EnsembleAccessor]
if target_fs.has_parameter_group(param_group.name):
source = target_fs
else:
source = source_fs
source = _determine_parameter_source(param_group.name, target_fs, source_fs)
temp_storage = _create_temporary_parameter_storage(
source, iens_active_index, param_group.name
)
Expand Down Expand Up @@ -637,10 +651,7 @@ def adaptive_localization_progress_callback(

for param_group in update_step.parameters:
source: Union[EnsembleReader, EnsembleAccessor]
if target_fs.has_parameter_group(param_group.name):
source = target_fs
else:
source = source_fs
source = _determine_parameter_source(param_group.name, target_fs, source_fs)
temp_storage = _create_temporary_parameter_storage(
source, iens_active_index, param_group.name
)
Expand Down Expand Up @@ -837,13 +848,7 @@ def analysis_IES(
sies_smoother.W[:, masking_of_initial_parameters] = proposed_W

for param_group in update_step.parameters:
source: Union[EnsembleReader, EnsembleAccessor] = target_fs
try:
target_fs.load_parameters(group=param_group.name, realizations=0)[
"values"
]
except Exception:
source = source_fs
source = _determine_parameter_source(param_group.name, target_fs, source_fs)
temp_storage = _create_temporary_parameter_storage(
source, iens_active_index, param_group.name
)
Expand Down

0 comments on commit 1a39d2a

Please sign in to comment.