Fix BO notebook #430

Merged
docs/examples/bayesian_optimisation.py: 39 additions & 26 deletions
@@ -164,13 +164,21 @@
# $$f(x) = (6x - 2)^2 \sin(12x - 4)$$
#
# treating $f$ as a black-box function. Moreover, we shall restrict the domain of the
-# function to $\mathbf{x} \in [0, 1]$. The global minimum of this function is located at
-# $x = 0.757$, where $f(x) = -6.021$.
+# function to $\mathbf{x} \in [0, 1]$. We shall also *standardise* the output of the
+# function so that it has a mean of 0 and a standard deviation of 1. This is common
+# practice when using GPs: we are using a zero-mean prior, so ensuring that our data
+# has zero mean aligns with this prior, and the scale parameters of the covariance
+# function are frequently initialised, or have priors set on them, under the
+# assumption that the function being modelled has unit variance. For similar reasons,
+# it can also be useful to normalise the inputs to a GP. The global minimum of this
+# (standardised) function is located at $x = 0.757$, where $f(x) = -1.463$.


# %%
-def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
-    return (6 * x - 2) ** 2 * jnp.sin(12 * x - 4)
+def standardised_forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
+    mean = 0.45321
+    std = 4.4258
+    return ((6 * x - 2) ** 2 * jnp.sin(12 * x - 4) - mean) / std
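
As a quick sanity check on the constants above, they can be recovered empirically by densely evaluating the raw Forrester function over $[0, 1]$ and taking the sample mean and standard deviation. A minimal sketch, assuming the constants were obtained this way (the grid resolution here is arbitrary):

import jax.numpy as jnp

xs = jnp.linspace(0.0, 1.0, 100_000).reshape(-1, 1)
raw = (6 * xs - 2) ** 2 * jnp.sin(12 * xs - 4)  # unstandardised Forrester values
print(raw.mean(), raw.std())  # should be close to 0.45321 and 4.4258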


# %% [markdown]
@@ -189,7 +197,7 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
initial_x = tfp.mcmc.sample_halton_sequence(
    dim=1, num_results=initial_sample_num, seed=key, dtype=jnp.float64
).reshape(-1, 1)
-initial_y = forrester(initial_x)
+initial_y = standardised_forrester(initial_x)
D = gpx.Dataset(X=initial_x, y=initial_y)


@@ -204,7 +212,7 @@ def return_optimised_posterior(
    data: gpx.Dataset, prior: gpx.base.Module, key: Array
) -> gpx.base.Module:
    likelihood = gpx.likelihoods.Gaussian(
-        num_datapoints=data.n, obs_stddev=jnp.array(1e-3)
+        num_datapoints=data.n, obs_stddev=jnp.array(1e-6)
    )  # Our function is noise-free, so we set the observation noise's standard deviation to a very small value
    likelihood = likelihood.replace_trainable(obs_stddev=False)  # Keep the noise fixed during optimisation

@@ -214,7 +222,7 @@ def return_optimised_posterior(
    negative_mll(posterior, train_data=data)
    negative_mll = jit(negative_mll)

-    opt_posterior, history = gpx.fit(
+    opt_posterior, _ = gpx.fit(
        model=posterior,
        objective=negative_mll,
        train_data=data,
@@ -301,7 +309,7 @@ def optimise_sample(


x_star = optimise_sample(approx_sample, key, lower_bound, upper_bound, 100)
-y_star = forrester(x_star)
+y_star = standardised_forrester(x_star)


# %% [markdown]
@@ -321,7 +329,7 @@ def plot_bayes_opt(
    queried_x: ScalarFloat,
) -> None:
    plt_x = jnp.linspace(0, 1, 1000).reshape(-1, 1)
-    forrester_y = forrester(plt_x)
+    forrester_y = standardised_forrester(plt_x)
    sample_y = sample(plt_x)

    latent_dist = posterior.predict(plt_x, train_data=dataset)
@@ -392,7 +400,7 @@ def plot_bayes_opt(
initial_x = tfp.mcmc.sample_halton_sequence(
    dim=1, num_results=initial_sample_num, seed=key, dtype=jnp.float64
).reshape(-1, 1)
-initial_y = forrester(initial_x)
+initial_y = standardised_forrester(initial_x)
D = gpx.Dataset(X=initial_x, y=initial_y)

for i in range(bo_iters):
@@ -415,7 +423,7 @@ def plot_bayes_opt(
    plot_bayes_opt(opt_posterior, approx_sample, D, x_star)

    # Evaluate the black-box function at the queried point, and add it to the dataset
-    y_star = forrester(x_star)
+    y_star = standardised_forrester(x_star)
    print(f"Queried Point: {x_star}, Black-Box Function Value: {y_star}")
    D = D + gpx.Dataset(X=x_star, y=y_star)

@@ -435,7 +443,7 @@ def plot_bayes_opt(
cumulative_best_y = jax.lax.associative_scan(jax.numpy.minimum, D.y)
ax.plot(fn_evaluations, cumulative_best_y)
ax.axvline(x=initial_sample_num, linestyle=":")
-ax.axhline(y=-6.0207, linestyle="--", label="True Minimum")
+ax.axhline(y=-1.463, linestyle="--", label="True Minimum")
ax.set_xlabel("Number of Black-Box Function Evaluations")
ax.set_ylabel("Best Observed Value")
ax.legend()
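
For intuition on the cumulative-best computation above: jax.lax.associative_scan with jnp.minimum as the operator is simply a running minimum. A small standalone illustration on a toy array:

import jax
import jax.numpy as jnp

y = jnp.array([3.0, 1.0, 2.0, -0.5])
print(jax.lax.associative_scan(jnp.minimum, y))
# [ 3.   1.   1.  -0.5]  (each entry is the best, i.e. lowest, value seen so far)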
@@ -452,18 +460,23 @@ def plot_bayes_opt(
#
# $$f(x_1, x_2) = (4 - 2.1x_1^2 + \frac{x_1^4}{3})x_1^2 + x_1x_2 + (-4 + 4x_2^2)x_2^2$$
#
-# We'll be evaluating it over the domain $x_1 \in [-2, 2]$ and $x_2 \in [-1, 1]$. The
-# global minima of this function are located at $\mathbf{x} = (0.0898, -0.7126)$ and $\mathbf{x} = (-0.0898, 0.7126)$, where the function takes the value $f(\mathbf{x}) = -1.0316$.
+# We'll be evaluating it over the domain $x_1 \in [-2, 2]$ and $x_2 \in [-1, 1]$, and
+# shall standardise it. The global minima of this function are located at
+# $\mathbf{x} = (0.0898, -0.7126)$ and $\mathbf{x} = (-0.0898, 0.7126)$, where the
+# standardised function takes the value $f(\mathbf{x}) = -1.8377$.


# %%
-def six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]:
+def standardised_six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]:
+    mean = 1.12767
+    std = 1.17500
    x1 = x[..., :1]
    x2 = x[..., 1:]
    term1 = (4 - 2.1 * x1**2 + x1**4 / 3) * x1**2
    term2 = x1 * x2
    term3 = (-4 + 4 * x2**2) * x2**2
-    return term1 + term2 + term3
+    return (term1 + term2 + term3 - mean) / std
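
As a quick check of the quoted minimum, evaluating the standardised function at one of the known minimisers should return a value close to $-1.8377$. A usage sketch, relying on the definition above:

x_check = jnp.array([[0.0898, -0.7126]])
print(standardised_six_hump_camel(x_check))  # approximately -1.8377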


# %% [markdown]
@@ -474,7 +487,7 @@ def six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]:
x2 = jnp.linspace(-1, 1, 100)
x1, x2 = jnp.meshgrid(x1, x2)
x = jnp.stack([x1.flatten(), x2.flatten()], axis=1)
-y = six_hump_camel(x)
+y = standardised_six_hump_camel(x)

fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
surf = ax.plot_surface(
@@ -521,7 +534,7 @@ def six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]:
lower_bound = jnp.array([-2.0, -1.0])
upper_bound = jnp.array([2.0, 1.0])
initial_sample_num = 5
-bo_iters = 11
+bo_iters = 12
num_experiments = 5
bo_experiment_results = []

@@ -532,7 +545,7 @@ def six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]:
        dim=2, num_results=initial_sample_num, seed=key, dtype=jnp.float64
    )
    initial_x = jnp.array(lower_bound + (upper_bound - lower_bound) * initial_x)
-    initial_y = six_hump_camel(initial_x)
+    initial_y = standardised_six_hump_camel(initial_x)
    D = gpx.Dataset(X=initial_x, y=initial_y)

    for i in range(bo_iters):
@@ -559,7 +572,7 @@ def six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]:
        )

        # Evaluate the black-box function at the queried point, and add it to the dataset
-        y_star = six_hump_camel(x_star)
+        y_star = standardised_six_hump_camel(x_star)
        print(
            f"BO Iteration: {i + 1}, Queried Point: {x_star}, Black-Box Function Value:"
            f" {y_star}"
@@ -587,7 +600,7 @@ def six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]:
        minval=lower_bound,
        maxval=upper_bound,
    )
-    final_y = six_hump_camel(final_x)
+    final_y = standardised_six_hump_camel(final_x)
    random_x = jnp.concatenate([initial_x, final_x], axis=0)
    random_y = jnp.concatenate([initial_y, final_y], axis=0)
    random_experiment_results.append(gpx.Dataset(X=random_x, y=random_y))
@@ -628,12 +641,12 @@ def obtain_log_regret_statistics(


bo_log_regret_mean, bo_log_regret_std = obtain_log_regret_statistics(
-    bo_experiment_results, -1.031625
+    bo_experiment_results, -1.8377
)
(
    random_log_regret_mean,
    random_log_regret_std,
-) = obtain_log_regret_statistics(random_experiment_results, -1.031625)
+) = obtain_log_regret_statistics(random_experiment_results, -1.8377)
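
The body of obtain_log_regret_statistics sits outside this diff. For context, here is a minimal sketch of the usual computation, an assumption about (not a copy of) the notebook's implementation: the simple regret at each step is the running best observed value minus the true minimum, and its base-10 logarithm is averaged across experiments.

import jax
import jax.numpy as jnp

def sketch_log_regret_statistics(experiment_results, true_min):
    # Running best observed value per experiment, shifted by the true minimum,
    # gives the simple regret at each function evaluation.
    regrets = jnp.stack(
        [
            jax.lax.associative_scan(jnp.minimum, d.y.flatten()) - true_min
            for d in experiment_results
        ]
    )
    log_regret = jnp.log10(regrets)  # compress for plotting
    return log_regret.mean(axis=0), log_regret.std(axis=0)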

# %% [markdown]
# Now, when we plot the mean and standard deviation of the log regret at each iteration,
@@ -665,7 +678,7 @@ def obtain_log_regret_statistics(
# %% [markdown]
# It can also be useful to plot the queried points over the course of a single BO run, in
# order to gain some insight into how the algorithm queries the search space. Below
-# we do this for the first BO experiment, and can see that the algorithm initially
+# we do this for one of the BO experiments, and can see that the algorithm initially
# performs some exploration of the search space whilst it is uncertain about the black-box
# function, but it then homes in on one of the global minima of the function, as we would hope!

@@ -684,8 +697,8 @@ def obtain_log_regret_statistics(
)
ax.scatter(x_star_two[0][0], x_star_two[0][1], marker="*", color=cols[2], zorder=2)
ax.scatter(
-    bo_experiment_results[0].X[:, 0],
-    bo_experiment_results[0].X[:, 1],
+    bo_experiment_results[1].X[:, 0],
+    bo_experiment_results[1].X[:, 1],
marker="x",
color=cols[1],
label="Bayesian Optimisation Queries",