Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Adds probability of improvement as an acquisition function #458

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/examples/decision_making.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,12 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
num_initial_samples=100, num_restarts=1
)

# %% [markdown]

# It is worth noting that `ThompsonSampling` is not the only utility function we could use,
# since our module also provides e.g. `ProbabilityOfImprovement`,
# which was briefly discussed in [our previous introduction to Bayesian optimisation](https://docs.jaxgaussianprocesses.com/examples/bayesian_optimisation/).


# %% [markdown]
# ## Putting it All Together with the Decision Maker
Expand Down
4 changes: 4 additions & 0 deletions gpjax/decision_making/utility_functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
SinglePointUtilityFunction,
UtilityFunction,
)
from gpjax.decision_making.utility_functions.probability_of_improvement import (
ProbabilityOfImprovement,
)
from gpjax.decision_making.utility_functions.thompson_sampling import ThompsonSampling

__all__ = [
Expand All @@ -26,4 +29,5 @@
"AbstractSinglePointUtilityFunctionBuilder",
"SinglePointUtilityFunction",
"ThompsonSampling",
"ProbabilityOfImprovement",
]
127 changes: 127 additions & 0 deletions gpjax/decision_making/utility_functions/probability_of_improvement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright 2024 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from dataclasses import dataclass

from beartype.typing import Mapping
from jaxtyping import Num
import tensorflow_probability.substrates.jax as tfp

from gpjax.dataset import Dataset
from gpjax.decision_making.utility_functions.base import (
AbstractSinglePointUtilityFunctionBuilder,
SinglePointUtilityFunction,
)
from gpjax.decision_making.utils import OBJECTIVE
from gpjax.gps import ConjugatePosterior
from gpjax.typing import (
Array,
KeyArray,
)


@dataclass
class ProbabilityOfImprovement(AbstractSinglePointUtilityFunctionBuilder):
r"""
An acquisition function which returns the probability of improvement
of the objective function over the best observed value.

More precisely, given a predictive posterior distribution of the objective
function $`f`$, the probability of improvement at a test point $`x`$ is defined as:
$$`\text{PI}(x) = \text{Prob}[f(x) < f(x_{\text{best}})]`$$
where $`x_{\text{best}}`$ is the minimiser of the posterior mean
at previously observed values (to handle noisy observations).

The probability of improvement can be easily computed using the
cumulative distribution function of the standard normal distribution $`\Phi`$:
$$`\text{PI}(x) = \Phi\left(\frac{f(x_{\text{best}}) - \mu}{\sigma}\right)`$$
where $`\mu`$ and $`\sigma`$ are the mean and standard deviation of the
predictive distribution of the objective function at $`x`$.

References
----------
[1] Kushner, H. J. (1964).
A new method of locating the maximum point of an arbitrary multipeak curve in the presence of noise.
Journal of Basic Engineering, 86(1), 97-106.

[2] Shahriari, B., Swersky, K., Wang, Z., Adams, R. P., & de Freitas, N. (2016).
Taking the human out of the loop: A review of Bayesian optimization.
Proceedings of the IEEE, 104(1), 148-175. doi: 10.1109/JPROC.2015.2494218
"""

def build_utility_function(
self,
posteriors: Mapping[str, ConjugatePosterior],
datasets: Mapping[str, Dataset],
key: KeyArray,
) -> SinglePointUtilityFunction:
"""
Constructs the probability of improvement utility function
using the predictive posterior of the objective function.

Args:
posteriors (Mapping[str, AbstractPosterior]): Dictionary of posteriors to be
used to form the utility function. One of the posteriors must correspond
to the `OBJECTIVE` key, as we sample from the objective posterior to form
the utility function.
datasets (Mapping[str, Dataset]): Dictionary of datasets which may be used
to form the utility function. Keys in `datasets` should correspond to
keys in `posteriors`. One of the datasets must correspond
to the `OBJECTIVE` key.
key (KeyArray): JAX PRNG key used for random number generation. Since
the probability of improvement is computed deterministically
from the predictive posterior, the key is not used.

Returns:
SinglePointUtilityFunction: the probability of improvement utility function.
"""
self.check_objective_present(posteriors, datasets)

objective_posterior = posteriors[OBJECTIVE]
if not isinstance(objective_posterior, ConjugatePosterior):
raise ValueError(
"Objective posterior must be a ConjugatePosterior to compute the Probability of Improvement using a Gaussian CDF."
)

objective_dataset = datasets[OBJECTIVE]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably best to have something along the lines of

if objective_dataset.X is None or objective_dataset.n == 0:
            raise ValueError("Objective dataset must contain at least one item")

given that we use the objective dataset to find best_y.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed!

if (
objective_dataset.X is None
or objective_dataset.n == 0
or objective_dataset.y is None
):
raise ValueError(
"Objective dataset must be non-empty to compute the "
"Probability of Improvement (since we need a "
"`best_y` value)."
)

def probability_of_improvement(x_test: Num[Array, "N D"]):
# Computing the posterior mean for the training dataset
# for computing the best_y value (as the minimum
# posterior mean of the objective function)
predictive_dist_for_training = objective_posterior.predict(
objective_dataset.X, objective_dataset
)
best_y = predictive_dist_for_training.mean().min()

predictive_dist = objective_posterior.predict(x_test, objective_dataset)

normal_dist = tfp.distributions.Normal(
loc=predictive_dist.mean(),
scale=predictive_dist.stddev(),
)

return normal_dist.cdf(best_y).reshape(-1, 1)

return probability_of_improvement
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright 2023 The GPJax Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from jax import config

config.update("jax_enable_x64", True)

import jax
import jax.numpy as jnp
import jax.random as jr

from gpjax.decision_making.test_functions.continuous_functions import Forrester
from gpjax.decision_making.utility_functions.probability_of_improvement import (
ProbabilityOfImprovement,
)
from gpjax.decision_making.utils import OBJECTIVE
from tests.test_decision_making.utils import generate_dummy_conjugate_posterior


def test_probability_of_improvement_gives_correct_value_for_a_seed():
key = jr.key(42)
forrester = Forrester()
dataset = forrester.generate_dataset(num_points=10, key=key)
posterior = generate_dummy_conjugate_posterior(dataset)
posteriors = {OBJECTIVE: posterior}
datasets = {OBJECTIVE: dataset}

pi_utility_builder = ProbabilityOfImprovement()
pi_utility = pi_utility_builder.build_utility_function(
posteriors=posteriors, datasets=datasets, key=key
)

test_X = forrester.generate_test_points(num_points=10, key=key)
utility_values = pi_utility(test_X)

# Computing the expected utility values
predictive_dist = posterior.predict(test_X, train_data=dataset)
predictive_mean = predictive_dist.mean()
predictive_std = predictive_dist.stddev()

# Computing best_y as the min. of the posterior predictive mean
# over the training set.
predictive_dist_for_training_data = posterior.predict(dataset.X, train_data=dataset)
best_y = predictive_dist_for_training_data.mean().min()

# Gaussian CDF computed "by hand"
x_ = (best_y - predictive_mean) / predictive_std
expected_utility_values = 0.5 * (
1 + jax.scipy.special.erf(x_ / jnp.sqrt(2))
).reshape(-1, 1)

assert utility_values.shape == (10, 1)
assert jnp.isclose(utility_values, expected_utility_values).all()
Original file line number Diff line number Diff line change
Expand Up @@ -17,99 +17,18 @@
config.update("jax_enable_x64", True)

from beartype.typing import Callable
import jax.numpy as jnp
import jax.random as jr
import pytest

from gpjax.dataset import Dataset
from gpjax.decision_making.test_functions.continuous_functions import (
AbstractContinuousTestFunction,
Forrester,
LogarithmicGoldsteinPrice,
)
from gpjax.decision_making.utility_functions.thompson_sampling import ThompsonSampling
from gpjax.decision_making.utils import OBJECTIVE
from gpjax.gps import (
ConjugatePosterior,
NonConjugatePosterior,
Prior,
)
from gpjax.kernels import RBF
from gpjax.likelihoods import (
Gaussian,
Poisson,
)
from gpjax.mean_functions import Zero
from gpjax.typing import KeyArray


def generate_dummy_conjugate_posterior(dataset: Dataset) -> ConjugatePosterior:
kernel = RBF(lengthscale=jnp.ones(dataset.X.shape[1]))
mean_function = Zero()
prior = Prior(kernel=kernel, mean_function=mean_function)
likelihood = Gaussian(num_datapoints=dataset.n)
posterior = prior * likelihood
return posterior


def generate_dummy_non_conjugate_posterior(dataset: Dataset) -> NonConjugatePosterior:
kernel = RBF(lengthscale=jnp.ones(dataset.X.shape[1]))
mean_function = Zero()
prior = Prior(kernel=kernel, mean_function=mean_function)
likelihood = Poisson(num_datapoints=dataset.n)
posterior = prior * likelihood
return posterior


@pytest.mark.filterwarnings(
"ignore::UserWarning"
) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort
def test_thompson_sampling_no_objective_posterior_raises_error():
key = jr.key(42)
forrester = Forrester()
dataset = forrester.generate_dataset(num_points=10, key=key)
posterior = generate_dummy_conjugate_posterior(dataset)
posteriors = {"CONSTRAINT": posterior}
datasets = {OBJECTIVE: dataset}
with pytest.raises(ValueError):
ts_utility_builder = ThompsonSampling(num_features=100)
ts_utility_builder.build_utility_function(
posteriors=posteriors, datasets=datasets, key=key
)


@pytest.mark.filterwarnings(
"ignore::UserWarning"
) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort
def test_thompson_sampling_no_objective_dataset_raises_error():
key = jr.key(42)
forrester = Forrester()
dataset = forrester.generate_dataset(num_points=10, key=key)
posterior = generate_dummy_conjugate_posterior(dataset)
posteriors = {OBJECTIVE: posterior}
datasets = {"CONSTRAINT": dataset}
with pytest.raises(ValueError):
ts_utility_builder = ThompsonSampling(num_features=100)
ts_utility_builder.build_utility_function(
posteriors=posteriors, datasets=datasets, key=key
)


@pytest.mark.filterwarnings(
"ignore::UserWarning"
) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort
def test_thompson_sampling_non_conjugate_posterior_raises_error():
key = jr.key(42)
forrester = Forrester()
dataset = forrester.generate_dataset(num_points=10, key=key)
posterior = generate_dummy_non_conjugate_posterior(dataset)
posteriors = {OBJECTIVE: posterior}
datasets = {OBJECTIVE: dataset}
with pytest.raises(ValueError):
ts_utility_builder = ThompsonSampling(num_features=100)
ts_utility_builder.build_utility_function(
posteriors=posteriors, datasets=datasets, key=key
)
from tests.test_decision_making.utils import generate_dummy_conjugate_posterior


@pytest.mark.parametrize("num_rff_features", [0, -1, -10])
Expand All @@ -130,34 +49,6 @@ def test_thompson_sampling_invalid_rff_num_raises_error(num_rff_features: int):
)


@pytest.mark.parametrize(
"test_target_function",
[(Forrester()), (LogarithmicGoldsteinPrice())],
)
@pytest.mark.parametrize("num_test_points", [50, 100])
@pytest.mark.parametrize("key", [jr.key(42), jr.key(10)])
@pytest.mark.filterwarnings(
"ignore::UserWarning"
) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort
def test_thompson_sampling_utility_function_correct_shapes(
test_target_function: AbstractContinuousTestFunction,
num_test_points: int,
key: KeyArray,
):
dataset = test_target_function.generate_dataset(num_points=10, key=key)
posterior = generate_dummy_conjugate_posterior(dataset)
posteriors = {OBJECTIVE: posterior}
datasets = {OBJECTIVE: dataset}
ts_utility_builder = ThompsonSampling(num_features=100)
ts_utility_function = ts_utility_builder.build_utility_function(
posteriors=posteriors, datasets=datasets, key=key
)
test_key, _ = jr.split(key)
test_X = test_target_function.generate_test_points(num_test_points, test_key)
ts_utility_function_values = ts_utility_function(test_X)
assert ts_utility_function_values.shape == (num_test_points, 1)


@pytest.mark.parametrize(
"test_target_function",
[(Forrester()), (LogarithmicGoldsteinPrice())],
Expand Down
Loading
Loading