Skip to content

Commit

Permalink
Sample a single value instead of ensemble_size
Browse files Browse the repository at this point in the history
It is not necessary to sample ensemble_size values
and then discard all but one.
Instead, this commit advances the RNG by the realization number
and then samples just one value.
  • Loading branch information
dafeda committed Jan 2, 2024
1 parent b758cec commit cad295e
Showing 1 changed file with 31 additions and 4 deletions.
35 changes: 31 additions & 4 deletions src/ert/config/gen_kw_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,6 @@ def sample_or_load(
keys,
str(random_seed),
real_nr,
ensemble_size,
)

return xr.Dataset(
Expand Down Expand Up @@ -322,8 +321,31 @@ def _sample_value(
keys: List[str],
global_seed: str,
realization: int,
nr_samples: int,
) -> npt.NDArray[np.double]:
"""
Generate a sample value for each key in a parameter group.
The sampling is reproducible and dependent on a global seed combined
with the parameter group name and individual key names. The 'realization' parameter
determines the specific sample point from the distribution for each parameter.
Parameters:
- parameter_group_name (str): The name of the parameter group, used to ensure unique RNG
seeds for different groups.
- keys (List[str]): A list of parameter keys for which the sample values are generated.
- global_seed (str): A global seed string used for RNG seed generation to ensure
reproducibility across runs.
- realization (int): The number of draws by which the RNG is advanced before sampling,
which selects the sample at zero-based index 'realization' in the RNG's sequence.
Returns:
- npt.NDArray[np.double]: An array of sample values, one for each key in the provided list.
Note:
The method uses SHA-256 for hash generation and numpy's default random number generator
for sampling. The RNG state is advanced by drawing and discarding 'realization' values before
generating a single sample, so only 'realization' + 1 values are drawn per key instead of a
full ensemble-sized sample set.
"""
parameter_values = []
for key in keys:
key_hash = sha256(
Expand All @@ -332,8 +354,13 @@ def _sample_value(
)
seed = np.frombuffer(key_hash.digest(), dtype="uint32")
rng = np.random.default_rng(seed)
values = rng.standard_normal(nr_samples)
parameter_values.append(values[realization])

# Advance the RNG state to the realization point
rng.standard_normal(realization)

# Generate a single sample
value = rng.standard_normal(1)
parameter_values.append(value[0])
return np.array(parameter_values)

@staticmethod
Expand Down

0 comments on commit cad295e

Please sign in to comment.