Fix BO notebook
Thomas-Christie committed Dec 28, 2023
1 parent 9e4006c commit 4b15a8f
Showing 1 changed file with 39 additions and 26 deletions.
65 changes: 39 additions & 26 deletions docs/examples/bayesian_optimisation.py
@@ -164,13 +164,21 @@
# $$f(x) = (6x - 2)^2 \sin(12x - 4)$$
#
# treating $f$ as a black-box function. Moreover, we shall restrict the domain of the
-# function to $\mathbf{x} \in [0, 1]$. The global minimum of this function is located at
-# $x = 0.757$, where $f(x) = -6.021$.
+# function to $\mathbf{x} \in [0, 1]$. We shall also *standardise* the output of the function, so that
+# it has a mean of 0 and a standard deviation of 1. This is common practice when using
+# GPs: since we use a zero-mean prior, ensuring that our data has zero mean aligns
+# with this, and the scale parameters of the covariance function are often
+# initialised, or given priors, under the assumption that the
+# function being modelled has unit variance. For similar reasons, it can also be useful
+# to normalise the inputs to a GP. The global minimum of this (standardised)
+# function is located at $x = 0.757$, where $f(x) = -1.463$.


# %%
-def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
-    return (6 * x - 2) ** 2 * jnp.sin(12 * x - 4)
+def standardised_forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
+    mean = 0.45321
+    std = 4.4258
+    return ((6 * x - 2) ** 2 * jnp.sin(12 * x - 4) - mean) / std
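The constants `mean = 0.45321` and `std = 4.4258` above are hard-coded, and the commit does not say how they were obtained. Below is a minimal sketch of one way to reproduce values of this kind, assuming they come from evaluating the raw Forrester function over $[0, 1]$; the `raw_forrester` helper is hypothetical and not part of the notebook.

import jax.numpy as jnp


def raw_forrester(x):
    return (6 * x - 2) ** 2 * jnp.sin(12 * x - 4)


# Empirical mean/std of the raw function over a dense grid on [0, 1];
# these land in the right ballpark of the hard-coded constants (the exact
# values presumably depend on how the authors computed them).
grid = jnp.linspace(0, 1, 100_000).reshape(-1, 1)
vals = raw_forrester(grid)
print(vals.mean(), vals.std())

# Sanity check: the raw minimum f(0.757) = -6.021 maps to
# (-6.021 - 0.45321) / 4.4258 = -1.463, the standardised minimum quoted above.
print((raw_forrester(jnp.array([[0.757]])) - 0.45321) / 4.4258)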


# %% [markdown]
@@ -189,7 +197,7 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
initial_x = tfp.mcmc.sample_halton_sequence(
    dim=1, num_results=initial_sample_num, seed=key, dtype=jnp.float64
).reshape(-1, 1)
-initial_y = forrester(initial_x)
+initial_y = standardised_forrester(initial_x)
D = gpx.Dataset(X=initial_x, y=initial_y)


@@ -204,7 +212,7 @@ def return_optimised_posterior(
    data: gpx.Dataset, prior: gpx.base.Module, key: Array
) -> gpx.base.Module:
    likelihood = gpx.likelihoods.Gaussian(
-        num_datapoints=data.n, obs_stddev=jnp.array(1e-3)
+        num_datapoints=data.n, obs_stddev=jnp.array(1e-6)
    )  # Our function is noise-free, so we set the observation noise's standard deviation to a very small value
    likelihood = likelihood.replace_trainable(obs_stddev=False)

@@ -214,7 +222,7 @@ def return_optimised_posterior(
    negative_mll(posterior, train_data=data)
    negative_mll = jit(negative_mll)

-    opt_posterior, history = gpx.fit(
+    opt_posterior, _ = gpx.fit(
        model=posterior,
        objective=negative_mll,
        train_data=data,
@@ -301,7 +309,7 @@ def optimise_sample(


x_star = optimise_sample(approx_sample, key, lower_bound, upper_bound, 100)
-y_star = forrester(x_star)
+y_star = standardised_forrester(x_star)
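`optimise_sample` itself is defined outside the hunks shown here, so its body is not visible in this diff. As a rough, simplified sketch of the general idea only (evaluate the approximate posterior sample at random candidate points within the bounds and return the minimiser), with all names below hypothetical and no claim that this matches the notebook's implementation, which may additionally refine the point with a gradient-based optimiser:

from typing import Callable

import jax.numpy as jnp
import jax.random as jr
from jax import Array


def optimise_sample_sketch(
    sample: Callable[[Array], Array],  # approximate posterior sample, maps (N, D) -> (N, 1)
    key: Array,
    lower_bound: Array,
    upper_bound: Array,
    num_initial_sample_points: int,
) -> Array:
    # Draw candidate points uniformly within the box bounds.
    candidates = jr.uniform(
        key,
        shape=(num_initial_sample_points, lower_bound.shape[0]),
        minval=lower_bound,
        maxval=upper_bound,
    )
    # Evaluate the sampled function and keep the best candidate.
    values = sample(candidates)
    best_idx = jnp.argmin(values)
    return candidates[best_idx : best_idx + 1]  # shape (1, D)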


# %% [markdown]
@@ -321,7 +329,7 @@ def plot_bayes_opt(
    queried_x: ScalarFloat,
) -> None:
    plt_x = jnp.linspace(0, 1, 1000).reshape(-1, 1)
-    forrester_y = forrester(plt_x)
+    forrester_y = standardised_forrester(plt_x)
    sample_y = sample(plt_x)

    latent_dist = posterior.predict(plt_x, train_data=dataset)
@@ -392,7 +400,7 @@ def plot_bayes_opt(
initial_x = tfp.mcmc.sample_halton_sequence(
    dim=1, num_results=initial_sample_num, seed=key, dtype=jnp.float64
).reshape(-1, 1)
-initial_y = forrester(initial_x)
+initial_y = standardised_forrester(initial_x)
D = gpx.Dataset(X=initial_x, y=initial_y)

for i in range(bo_iters):
@@ -415,7 +423,7 @@ def plot_bayes_opt(
    plot_bayes_opt(opt_posterior, approx_sample, D, x_star)

    # Evaluate the black-box function at the best point observed so far, and add it to the dataset
-    y_star = forrester(x_star)
+    y_star = standardised_forrester(x_star)
    print(f"Queried Point: {x_star}, Black-Box Function Value: {y_star}")
    D = D + gpx.Dataset(X=x_star, y=y_star)

@@ -435,7 +443,7 @@ def plot_bayes_opt(
cumulative_best_y = jax.lax.associative_scan(jax.numpy.minimum, D.y)
ax.plot(fn_evaluations, cumulative_best_y)
ax.axvline(x=initial_sample_num, linestyle=":")
-ax.axhline(y=-6.0207, linestyle="--", label="True Minimum")
+ax.axhline(y=-1.463, linestyle="--", label="True Minimum")
ax.set_xlabel("Number of Black-Box Function Evaluations")
ax.set_ylabel("Best Observed Value")
ax.legend()
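For reference, `jax.lax.associative_scan(jnp.minimum, D.y)` above computes the running best-observed value after each black-box evaluation. A tiny standalone illustration on toy data (not taken from the notebook):

import jax
import jax.numpy as jnp

# Toy observations: the running minimum tracks the best value seen so far.
y = jnp.array([[0.3], [-0.8], [0.1], [-1.2], [-0.5]])
running_best = jax.lax.associative_scan(jnp.minimum, y)
print(running_best.ravel())  # [ 0.3 -0.8 -0.8 -1.2 -1.2]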
@@ -452,18 +460,23 @@ def plot_bayes_opt(
#
# $$f(x_1, x_2) = (4 - 2.1x_1^2 + \frac{x_1^4}{3})x_1^2 + x_1x_2 + (-4 + 4x_2^2)x_2^2$$
#
-# We'll be evaluating it over the domain $x_1 \in [-2, 2]$ and $x_2 \in [-1, 1]$. The
-# global minima of this function are located at $\mathbf{x} = (0.0898, -0.7126)$ and $\mathbf{x} = (-0.0898, 0.7126)$, where the function takes the value $f(\mathbf{x}) = -1.0316$.
+# We'll be evaluating it over the domain $x_1 \in [-2, 2]$ and $x_2 \in [-1, 1]$, and
+# shall standardise it. The
+# global minima of this function are located at $\mathbf{x} = (0.0898, -0.7126)$ and
+# $\mathbf{x} = (-0.0898, 0.7126)$, where the standardised function takes the value $f(\mathbf{x}) =
+# -1.8377$.


# %%
-def six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]:
+def standardised_six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]:
+    mean = 1.12767
+    std = 1.17500
    x1 = x[..., :1]
    x2 = x[..., 1:]
    term1 = (4 - 2.1 * x1**2 + x1**4 / 3) * x1**2
    term2 = x1 * x2
    term3 = (-4 + 4 * x2**2) * x2**2
-    return term1 + term2 + term3
+    return (term1 + term2 + term3 - mean) / std
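As a quick check on the value quoted in the markdown above, the raw six-hump camel minimum of roughly -1.0316 maps to roughly -1.8377 under this standardisation:

# Raw global minimum of the six-hump camel function, and the hard-coded
# standardisation constants from the diff above.
raw_min = -1.0316
mean, std = 1.12767, 1.17500
print((raw_min - mean) / std)  # approximately -1.8377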


# %% [markdown]
@@ -474,7 +487,7 @@ def six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]:
x2 = jnp.linspace(-1, 1, 100)
x1, x2 = jnp.meshgrid(x1, x2)
x = jnp.stack([x1.flatten(), x2.flatten()], axis=1)
-y = six_hump_camel(x)
+y = standardised_six_hump_camel(x)

fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
surf = ax.plot_surface(
@@ -521,7 +534,7 @@ def six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]:
lower_bound = jnp.array([-2.0, -1.0])
upper_bound = jnp.array([2.0, 1.0])
initial_sample_num = 5
-bo_iters = 11
+bo_iters = 12
num_experiments = 5
bo_experiment_results = []

@@ -532,7 +545,7 @@ def six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]:
dim=2, num_results=initial_sample_num, seed=key, dtype=jnp.float64
)
initial_x = jnp.array(lower_bound + (upper_bound - lower_bound) * initial_x)
-initial_y = six_hump_camel(initial_x)
+initial_y = standardised_six_hump_camel(initial_x)
D = gpx.Dataset(X=initial_x, y=initial_y)

for i in range(bo_iters):
@@ -559,7 +572,7 @@ def six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]:
)

# Evaluate the black-box function at the best point observed so far, and add it to the dataset
-y_star = six_hump_camel(x_star)
+y_star = standardised_six_hump_camel(x_star)
print(
f"BO Iteration: {i + 1}, Queried Point: {x_star}, Black-Box Function Value:"
f" {y_star}"
@@ -587,7 +600,7 @@ def six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]:
minval=lower_bound,
maxval=upper_bound,
)
-final_y = six_hump_camel(final_x)
+final_y = standardised_six_hump_camel(final_x)
random_x = jnp.concatenate([initial_x, final_x], axis=0)
random_y = jnp.concatenate([initial_y, final_y], axis=0)
random_experiment_results.append(gpx.Dataset(X=random_x, y=random_y))
@@ -628,12 +641,12 @@ def obtain_log_regret_statistics(


bo_log_regret_mean, bo_log_regret_std = obtain_log_regret_statistics(
-    bo_experiment_results, -1.031625
+    bo_experiment_results, -1.8377
)
(
    random_log_regret_mean,
    random_log_regret_std,
-) = obtain_log_regret_statistics(random_experiment_results, -1.031625)
+) = obtain_log_regret_statistics(random_experiment_results, -1.8377)
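`obtain_log_regret_statistics` is defined outside the hunks shown, so only its call sites appear in this diff. A rough sketch of what such a helper computes (cumulative best observed value per experiment, minus the true standardised minimum, on a log scale, then mean/std across experiments); the names and signature below are hypothetical and the notebook's implementation may differ, e.g. by taking `gpx.Dataset` objects directly:

from typing import List, Tuple

import jax
import jax.numpy as jnp
from jax import Array


def log_regret_statistics_sketch(
    experiment_observations: List[Array],  # one (num_evals, 1) array of y values per experiment
    true_min: float,
) -> Tuple[Array, Array]:
    log_regrets = []
    for y in experiment_observations:
        # Best value observed after each black-box function evaluation.
        cumulative_best = jax.lax.associative_scan(jnp.minimum, y)
        # Simple regret relative to the global minimum, on a log scale.
        log_regrets.append(jnp.log(cumulative_best - true_min))
    stacked = jnp.concatenate(log_regrets, axis=1)  # (num_evals, num_experiments)
    return jnp.mean(stacked, axis=1), jnp.std(stacked, axis=1)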

# %% [markdown]
# Now, when we plot the mean and standard deviation of the log regret at each iteration,
@@ -665,7 +678,7 @@ def obtain_log_regret_statistics(
# %% [markdown]
# It can also be useful to plot the queried points over the course of a single BO run, in
# order to gain some insight into how the algorithm queries the search space. Below
-# we do this for the first BO experiment, and can see that the algorithm initially
+# we do this for one of the BO experiments, and can see that the algorithm initially
# performs some exploration of the search space whilst it is uncertain about the black-box
# function, but it then homes in on one of the global minima of the function, as we would hope!

@@ -684,8 +697,8 @@ def obtain_log_regret_statistics(
)
ax.scatter(x_star_two[0][0], x_star_two[0][1], marker="*", color=cols[2], zorder=2)
ax.scatter(
-    bo_experiment_results[0].X[:, 0],
-    bo_experiment_results[0].X[:, 1],
+    bo_experiment_results[1].X[:, 0],
+    bo_experiment_results[1].X[:, 1],
    marker="x",
    color=cols[1],
    label="Bayesian Optimisation Queries",
