From 4b15a8f97e2225ca41f147a4ae8b7d5d12909de5 Mon Sep 17 00:00:00 2001
From: Thomas-Christie
Date: Thu, 28 Dec 2023 19:28:01 +0000
Subject: [PATCH] Fix BO notebook

---
 docs/examples/bayesian_optimisation.py | 65 +++++++++++++++-----------
 1 file changed, 39 insertions(+), 26 deletions(-)

diff --git a/docs/examples/bayesian_optimisation.py b/docs/examples/bayesian_optimisation.py
index a2361814d..39feba64f 100644
--- a/docs/examples/bayesian_optimisation.py
+++ b/docs/examples/bayesian_optimisation.py
@@ -164,13 +164,21 @@
 # $$f(x) = (6x - 2)^2 \sin(12x - 4)$$
 #
 # treating $f$ as a black-box function. Moreover, we shall restrict the domain of the
-# function to $\mathbf{x} \in [0, 1]$. The global minimum of this function is located at
-# $x = 0.757$, where $f(x) = -6.021$.
+# function to $\mathbf{x} \in [0, 1]$. We shall also *standardise* the output of the function, such that
+# it has a mean of 0 and standard deviation of 1. This is quite common practice when using
+# GPs; we're using a zero mean prior, so ensuring that our data has a mean of zero aligns
+# with this, and often we have scale parameters in the covariance function, which
+# are frequently initialised, or have priors set on them, under the assumption that the
+# function being modelled has unit variance. For similar reasons, it can also be useful
+# to normalise the inputs to a GP. The global minimum of this (standardised)
+# function is located at $x = 0.757$, where $f(x) = -1.463$.
 
 
 # %%
-def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
-    return (6 * x - 2) ** 2 * jnp.sin(12 * x - 4)
+def standardised_forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
+    mean = 0.45321
+    std = 4.4258
+    return ((6 * x - 2) ** 2 * jnp.sin(12 * x - 4) - mean) / std
 
 
 # %% [markdown]
@@ -189,7 +197,7 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
 initial_x = tfp.mcmc.sample_halton_sequence(
     dim=1, num_results=initial_sample_num, seed=key, dtype=jnp.float64
 ).reshape(-1, 1)
-initial_y = forrester(initial_x)
+initial_y = standardised_forrester(initial_x)
 
 D = gpx.Dataset(X=initial_x, y=initial_y)
 
@@ -204,7 +212,7 @@ def return_optimised_posterior(
     data: gpx.Dataset, prior: gpx.base.Module, key: Array
 ) -> gpx.base.Module:
     likelihood = gpx.likelihoods.Gaussian(
-        num_datapoints=data.n, obs_stddev=jnp.array(1e-3)
+        num_datapoints=data.n, obs_stddev=jnp.array(1e-6)
     )  # Our function is noise-free, so we set the observation noise's standard deviation to a very small value
     likelihood = likelihood.replace_trainable(obs_stddev=False)
 
@@ -214,7 +222,7 @@ def return_optimised_posterior(
     negative_mll(posterior, train_data=data)
     negative_mll = jit(negative_mll)
 
-    opt_posterior, history = gpx.fit(
+    opt_posterior, _ = gpx.fit(
         model=posterior,
         objective=negative_mll,
         train_data=data,
@@ -301,7 +309,7 @@ def optimise_sample(
 
 
 x_star = optimise_sample(approx_sample, key, lower_bound, upper_bound, 100)
-y_star = forrester(x_star)
+y_star = standardised_forrester(x_star)
 
 
 # %% [markdown]
@@ -321,7 +329,7 @@ def plot_bayes_opt(
     queried_x: ScalarFloat,
 ) -> None:
     plt_x = jnp.linspace(0, 1, 1000).reshape(-1, 1)
-    forrester_y = forrester(plt_x)
+    forrester_y = standardised_forrester(plt_x)
     sample_y = sample(plt_x)
 
     latent_dist = posterior.predict(plt_x, train_data=dataset)
@@ -392,7 +400,7 @@ def plot_bayes_opt(
 initial_x = tfp.mcmc.sample_halton_sequence(
     dim=1, num_results=initial_sample_num, seed=key, dtype=jnp.float64
 ).reshape(-1, 1)
-initial_y = forrester(initial_x)
+initial_y = standardised_forrester(initial_x)
 D = gpx.Dataset(X=initial_x, y=initial_y)
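The hard-coded `mean = 0.45321` and `std = 4.4258` in `standardised_forrester` are consistent with taking the empirical mean and standard deviation of the raw Forrester function over a dense grid on $[0, 1]$. Below is a minimal sketch of that check, assuming this is how the constants were obtained (the sketch is illustrative and not part of the patch):

import jax.numpy as jnp

# Raw (unstandardised) Forrester function, as defined earlier in the notebook.
def forrester(x):
    return (6 * x - 2) ** 2 * jnp.sin(12 * x - 4)

# Empirical mean/std over a dense grid on [0, 1] -- these come out close to the
# 0.45321 and 4.4258 used by standardised_forrester in the patch.
x_grid = jnp.linspace(0.0, 1.0, 100_001).reshape(-1, 1)
y_grid = forrester(x_grid)
print(y_grid.mean(), y_grid.std())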
 
 for i in range(bo_iters):
@@ -415,7 +423,7 @@ def plot_bayes_opt(
     plot_bayes_opt(opt_posterior, approx_sample, D, x_star)
 
     # Evaluate the black-box function at the best point observed so far, and add it to the dataset
-    y_star = forrester(x_star)
+    y_star = standardised_forrester(x_star)
     print(f"Queried Point: {x_star}, Black-Box Function Value: {y_star}")
     D = D + gpx.Dataset(X=x_star, y=y_star)
 
@@ -435,7 +443,7 @@ def plot_bayes_opt(
 cumulative_best_y = jax.lax.associative_scan(jax.numpy.minimum, D.y)
 ax.plot(fn_evaluations, cumulative_best_y)
 ax.axvline(x=initial_sample_num, linestyle=":")
-ax.axhline(y=-6.0207, linestyle="--", label="True Minimum")
+ax.axhline(y=-1.463, linestyle="--", label="True Minimum")
 ax.set_xlabel("Number of Black-Box Function Evaluations")
 ax.set_ylabel("Best Observed Value")
 ax.legend()
@@ -452,18 +460,23 @@ def plot_bayes_opt(
 #
 # $$f(x_1, x_2) = (4 - 2.1x_1^2 + \frac{x_1^4}{3})x_1^2 + x_1x_2 + (-4 + 4x_2^2)x_2^2$$
 #
-# We'll be evaluating it over the domain $x_1 \in [-2, 2]$ and $x_2 \in [-1, 1]$. The
-# global minima of this function are located at $\mathbf{x} = (0.0898, -0.7126)$ and $\mathbf{x} = (-0.0898, 0.7126)$, where the function takes the value $f(\mathbf{x}) = -1.0316$.
+# We'll be evaluating it over the domain $x_1 \in [-2, 2]$ and $x_2 \in [-1, 1]$, and
+# shall standardise it. The
+# global minima of this function are located at $\mathbf{x} = (0.0898, -0.7126)$ and
+# $\mathbf{x} = (-0.0898, 0.7126)$, where the standardised function takes the value $f(\mathbf{x}) =
+# -1.8377$.
 
 
 # %%
-def six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]:
+def standardised_six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]:
+    mean = 1.12767
+    std = 1.17500
     x1 = x[..., :1]
     x2 = x[..., 1:]
     term1 = (4 - 2.1 * x1**2 + x1**4 / 3) * x1**2
     term2 = x1 * x2
     term3 = (-4 + 4 * x2**2) * x2**2
-    return term1 + term2 + term3
+    return (term1 + term2 + term3 - mean) / std
 
 
 # %% [markdown]
@@ -474,7 +487,7 @@ def six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]:
 x2 = jnp.linspace(-1, 1, 100)
 x1, x2 = jnp.meshgrid(x1, x2)
 x = jnp.stack([x1.flatten(), x2.flatten()], axis=1)
-y = six_hump_camel(x)
+y = standardised_six_hump_camel(x)
 
 fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
 surf = ax.plot_surface(
@@ -521,7 +534,7 @@ def six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]:
 lower_bound = jnp.array([-2.0, -1.0])
 upper_bound = jnp.array([2.0, 1.0])
 initial_sample_num = 5
-bo_iters = 11
+bo_iters = 12
 num_experiments = 5
 bo_experiment_results = []
 
@@ -532,7 +545,7 @@ def six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]:
         dim=2, num_results=initial_sample_num, seed=key, dtype=jnp.float64
     )
     initial_x = jnp.array(lower_bound + (upper_bound - lower_bound) * initial_x)
-    initial_y = six_hump_camel(initial_x)
+    initial_y = standardised_six_hump_camel(initial_x)
     D = gpx.Dataset(X=initial_x, y=initial_y)
 
     for i in range(bo_iters):
@@ -559,7 +572,7 @@ def six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]:
         )
 
         # Evaluate the black-box function at the best point observed so far, and add it to the dataset
-        y_star = six_hump_camel(x_star)
+        y_star = standardised_six_hump_camel(x_star)
         print(
             f"BO Iteration: {i + 1}, Queried Point: {x_star}, Black-Box Function Value:"
             f" {y_star}"
@@ -587,7 +600,7 @@ def six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]:
         minval=lower_bound,
         maxval=upper_bound,
     )
-    final_y = six_hump_camel(final_x)
+    final_y = standardised_six_hump_camel(final_x)
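The new standardised minimum of $-1.8377$ quoted for the six-hump camel function follows directly from the well-known raw minimum of $-1.0316$ together with the `mean` and `std` constants introduced by the patch; equivalently, evaluating the patched function at the two quoted minimisers returns roughly that value. A small self-contained check (illustrative only, reusing the constants and terms exactly as added by the diff, with the type annotations dropped for brevity):

import jax.numpy as jnp

def standardised_six_hump_camel(x):
    # Same constants and terms as in the patched notebook cell.
    mean = 1.12767
    std = 1.17500
    x1 = x[..., :1]
    x2 = x[..., 1:]
    term1 = (4 - 2.1 * x1**2 + x1**4 / 3) * x1**2
    term2 = x1 * x2
    term3 = (-4 + 4 * x2**2) * x2**2
    return (term1 + term2 + term3 - mean) / std

# The two global minimisers quoted in the markdown cell.
minimisers = jnp.array([[0.0898, -0.7126], [-0.0898, 0.7126]])
print(standardised_six_hump_camel(minimisers))  # both entries are approximately -1.8377
# For reference: (-1.0316 - 1.12767) / 1.17500 is approximately -1.8377.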
     random_x = jnp.concatenate([initial_x, final_x], axis=0)
     random_y = jnp.concatenate([initial_y, final_y], axis=0)
     random_experiment_results.append(gpx.Dataset(X=random_x, y=random_y))
@@ -628,12 +641,12 @@ def obtain_log_regret_statistics(
 
 
 bo_log_regret_mean, bo_log_regret_std = obtain_log_regret_statistics(
-    bo_experiment_results, -1.031625
+    bo_experiment_results, -1.8377
 )
 (
     random_log_regret_mean,
     random_log_regret_std,
-) = obtain_log_regret_statistics(random_experiment_results, -1.031625)
+) = obtain_log_regret_statistics(random_experiment_results, -1.8377)
 
 # %% [markdown]
 # Now, when we plot the mean and standard deviation of the log regret at each iteration,
@@ -665,7 +678,7 @@ def obtain_log_regret_statistics(
 # %% [markdown]
 # It can also be useful to plot the queried points over the course of a single BO run, in
 # order to gain some insight into how the algorithm queries the search space. Below
-# we do this for the first BO experiment, and can see that the algorithm initially
+# we do this for one of the BO experiments, and can see that the algorithm initially
 # performs some exploration of the search space whilst it is uncertain about the black-box
 # function, but it then hones in on one of the global minima of the function, as we would hope!
 
@@ -684,8 +697,8 @@ def obtain_log_regret_statistics(
 )
 ax.scatter(x_star_two[0][0], x_star_two[0][1], marker="*", color=cols[2], zorder=2)
 ax.scatter(
-    bo_experiment_results[0].X[:, 0],
-    bo_experiment_results[0].X[:, 1],
+    bo_experiment_results[1].X[:, 0],
+    bo_experiment_results[1].X[:, 1],
     marker="x",
     color=cols[1],
     label="Bayesian Optimisation Queries",
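The hunks above only touch the call sites of `obtain_log_regret_statistics`, passing the new standardised minimum of $-1.8377$ as the known optimum; the function body itself is not shown in this part of the diff. For intuition, a hypothetical log-regret computation along the lines the notebook describes (running best observation versus the true minimum, mirroring the `jax.lax.associative_scan` pattern used earlier in the diff) might look like:

import jax
import jax.numpy as jnp

def log_regret(y, true_minimum):
    # Running best (minimum) observation after each black-box evaluation,
    # followed by the log10 gap to the known global minimum (simple regret).
    cumulative_best = jax.lax.associative_scan(jnp.minimum, y)
    return jnp.log10(cumulative_best - true_minimum)

# e.g. log_regret(D.y, -1.8377) for one experiment's dataset D; averaging the
# resulting curves across experiments gives mean/std statistics of the kind plotted above.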