Commit

daniel-dodd committed Jun 1, 2022
2 parents a5e50ad + 44e7648 commit 013adf9
Showing 5 changed files with 84 additions and 18 deletions.
38 changes: 22 additions & 16 deletions docs/nbs/classification.pct.py
@@ -44,7 +44,7 @@

D = gpx.Dataset(X=x, y=y)

xtest = jnp.linspace(-1., 1., 500).reshape(-1, 1)
xtest = jnp.linspace(-1.0, 1.0, 500).reshape(-1, 1)
plt.plot(x, y, "o", markersize=8)
# %% [markdown]
# ## MAP inference
@@ -60,7 +60,7 @@
posterior = prior * likelihood
print(type(posterior))
# %% [markdown]
# Whilst the latent function is Gaussian, the posterior distribution is non-Gaussian since our generative model first samples the latent GP and propagates these samples through the likelihood function's inverse link function. This step prevents us from being able to analytically integrate the latent function's values out of our posterior, and we must instead adopt alternative inference techniques. We begin with maximum a posteriori (MAP) estimation, a fast inference procedure to obtain point estimates for the latent function and the kernel's hyperparameters by maximising the marginal log-likelihood.
# Whilst the latent function is Gaussian, the posterior distribution is non-Gaussian since our generative model first samples the latent GP and propagates these samples through the likelihood function's inverse link function. This step prevents us from being able to analytically integrate the latent function's values out of our posterior, and we must instead adopt alternative inference techniques. We begin with maximum a posteriori (MAP) estimation, a fast inference procedure to obtain point estimates for the latent function and the kernel's hyperparameters by maximising the marginal log-likelihood.
# %% [markdown]
# To begin, we obtain a set of initial parameter values through the `initialise` callable and transform these to the unconstrained space via `transform` (see the [regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html)). We also define the negative marginal log-likelihood and JIT-compile it to accelerate training.
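The cell that implements this step is collapsed in the diff below; as a reference, an editor's sketch of a typical MAP step is given here. It assumes the `initialise`/`transform` API referenced above (the exact return values of `initialise` are an assumption) and uses a plain Optax Adam loop rather than any GPJax-specific fitting helper.

import jax
import optax as ox

params, _, constrainer, unconstrainer = gpx.initialise(posterior)  # assumed return signature
params = gpx.transform(params, unconstrainer)  # move parameters to the unconstrained space

# Negative marginal log-likelihood, JIT-compiled to accelerate training.
negative_mll = jax.jit(posterior.marginal_log_likelihood(D, constrainer, negative=True))

optimiser = ox.adam(learning_rate=0.01)
opt_state = optimiser.init(params)

for _ in range(500):
    grads = jax.grad(negative_mll)(params)
    updates, opt_state = optimiser.update(grads, opt_state)
    params = ox.apply_updates(params, updates)

map_estimate = gpx.transform(params, constrainer)  # back to the constrained space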
# %%
@@ -96,17 +96,16 @@
# %%
# Adapted from BlackJax's introduction notebook.
num_adapt = 1000
num_samples = 1000
num_samples = 500

mll = jax.jit(posterior.marginal_log_likelihood(D, constrainer, negative=False))

adapt = blackjax.window_adaptation(
blackjax.nuts, mll, num_adapt, target_acceptance_rate=0.65
)
adapt = blackjax.window_adaptation(blackjax.nuts, mll, num_adapt, target_acceptance_rate=0.65)

# Initialise the chain
last_state, kernel, _ = adapt.run(key, params)


def inference_loop(rng_key, kernel, initial_state, num_samples):
def one_step(state, rng_key):
state, info = kernel(rng_key, state)
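# The remainder of this function is collapsed in the diff view. For context, an
# editor's sketch of the standard BlackJax sampling loop it follows (not the
# commit's elided lines; `jax` is already imported in this notebook):
def inference_loop_sketch(rng_key, kernel, initial_state, num_samples):
    def one_step(state, rng_key):
        # A single transition of the window-adapted NUTS kernel.
        state, info = kernel(rng_key, state)
        return state, (state, info)

    # Run num_samples transitions with lax.scan, stacking states and diagnostics.
    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)
    return states, infos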
@@ -133,17 +132,17 @@ def one_step(state, rng_key):
# Our acceptance rate is slightly too large, prompting an examination of the chain's trace plots. A well-mixing chain will have very few (if any) flat spots in its trace plot whilst also not having too many steps in the same direction. In addition to the model's hyperparameters, there will be 500 samples for each of the 100 latent function values in the `states.position` dictionary. We depict the chains that correspond to the model hyperparameters and the first value of the latent function for brevity.
# %%
fig, (ax0, ax1, ax2) = plt.subplots(ncols=3, figsize=(15, 5), tight_layout=True)
ax0.plot(states.position['kernel']['lengthscale'])
ax1.plot(states.position['kernel']['variance'])
ax2.plot(states.position['latent'][:, 0, :])
ax0.plot(states.position["kernel"]["lengthscale"])
ax1.plot(states.position["kernel"]["variance"])
ax2.plot(states.position["latent"][:, 0, :])
ax0.set_title("Kernel Lengthscale")
ax1.set_title("Kernel Variance")
ax2.set_title("Latent Function (index = 1)")

# %% [markdown]
# ## Prediction
#
# Having obtained samples from the posterior, we draw ten instances from our model's predictive distribution per MCMC sample. Using these draws, we will be able to compute credible values and expected values under our posterior distribution.
# Having obtained samples from the posterior, we draw ten instances from our model's predictive distribution per MCMC sample. Using these draws, we will be able to compute credible values and expected values under our posterior distribution.
#
# An ideal Markov chain would produce samples that are completely uncorrelated with their neighbours after a single lag. In practice, however, correlations often persist within the chain. A commonly used technique to reduce this correlation is _thinning_, whereby we keep only every $n$th sample, where $n$ is the minimum lag at which we believe the samples are uncorrelated. Although a proper analysis of the chain's autocorrelation is required to choose an appropriate thinning factor, we use a thin factor of 10 for demonstration purposes.
# %%
@@ -152,9 +151,9 @@ def one_step(state, rng_key):

for i in range(0, num_samples, thin_factor):
ps = gpx.parameters.copy_dict_structure(params)
ps['kernel']['lengthscale'] = states.position['kernel']['lengthscale'][i]
ps['kernel']['variance'] = states.position['kernel']['variance'][i]
ps['latent'] = states.position['latent'][i, :, :]
ps["kernel"]["lengthscale"] = states.position["kernel"]["lengthscale"][i]
ps["kernel"]["variance"] = states.position["kernel"]["variance"][i]
ps["latent"] = states.position["latent"][i, :, :]
ps = gpx.transform(ps, constrainer)

predictive_dist = likelihood(posterior(D, ps)(xtest), ps)
@@ -171,9 +170,16 @@ def one_step(state, rng_key):
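The lines that compute the plotted summaries are collapsed above; a typical way to obtain them from the per-sample predictive draws is sketched here (an editor's sketch — `samples` is an assumed name for the list of stacked draws, and `jnp` is already imported in this notebook):

samples = jnp.vstack(samples)  # stack the per-sample draws into one (num_draws, 500) array
expected_val = jnp.mean(samples, axis=0)
lower_ci, upper_ci = jnp.percentile(samples, jnp.array([2.5, 97.5]), axis=0)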

# %%
fig, ax = plt.subplots(figsize=(16, 5), tight_layout=True)
ax.plot(x, y, "o", markersize=5, color='tab:red', label='Observations', zorder=2, alpha=0.7)
ax.plot(xtest, expected_val, linewidth=2, color='tab:blue', label='Predicted mean', zorder=1)
ax.fill_between(xtest.flatten(), lower_ci.flatten(), upper_ci.flatten(), alpha=0.2, color='tab:blue', label='95% CI')
ax.plot(x, y, "o", markersize=5, color="tab:red", label="Observations", zorder=2, alpha=0.7)
ax.plot(xtest, expected_val, linewidth=2, color="tab:blue", label="Predicted mean", zorder=1)
ax.fill_between(
xtest.flatten(),
lower_ci.flatten(),
upper_ci.flatten(),
alpha=0.2,
color="tab:blue",
label="95% CI",
)

# %% [markdown]
# ## System configuration
4 changes: 2 additions & 2 deletions docs/nbs/regression.pct.py
@@ -64,7 +64,7 @@
# Our aim in this tutorial will be to reconstruct the latent function from our noisy observations $\mathcal{D}$ via Gaussian process regression. We begin by defining a Gaussian process prior in the next section.


# %% [markdown]
# %% [markdown]
# ## Defining the prior
#
# A zero-mean Gaussian process (GP) places a prior distribution over real-valued functions $f(\cdot)$ where $f(\boldsymbol{x}) \sim \mathcal{N}(0, \mathbf{K}_{\boldsymbol{x}\boldsymbol{x}})$ for any finite collection of inputs $\boldsymbol{x}$.
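A minimal sketch of how such a prior is instantiated in GPJax follows (the RBF kernel is an illustrative choice; the notebook's collapsed code may use a different kernel):

kernel = gpx.RBF()                # illustrative kernel choice
prior = gpx.Prior(kernel=kernel)  # zero-mean GP prior over functions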
Expand Down Expand Up @@ -107,7 +107,7 @@
# The posterior is proportional to the prior multiplied by the likelihood, written as
#
# $$ p(f(\cdot) | \mathcal{D}) \propto p(f(\cdot)) * p(\mathcal{D} | f(\cdot)). $$
#
#
# Mimicking this construct, the posterior is established in GPJax through the `*` operator.
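A sketch of this construction (the `*` composition appears in the classification diff above; the Gaussian-likelihood constructor argument shown here is an assumption about the API):

likelihood = gpx.Gaussian(num_datapoints=D.n)  # assumed constructor argument
posterior = prior * likelihood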

# %%
58 changes: 58 additions & 0 deletions docs/nbs/spectral_kernels.pct.py
@@ -0,0 +1,58 @@
# ---
# jupyter:
# jupytext:
# cell_metadata_filter: -all
# custom_cell_magics: kql
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.11.2
# kernelspec:
# display_name: Python 3.9.7 ('gpjax')
# language: python
# name: python3
# ---

# %%
import gpjax as gpx
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import optax as ox
import distrax as dx

key = jr.PRNGKey(123)

# %%
n = 100
noise = 0.3

x = jr.uniform(key=key, minval=-3.0, maxval=3.0, shape=(n,)).sort().reshape(-1, 1)
f = lambda x: jnp.sin(4 * x) + jnp.cos(2 * x)
signal = f(x)
y = signal + jr.normal(key, shape=signal.shape) * noise

D = gpx.Dataset(X=x, y=y)

xtest = jnp.linspace(-3.25, 3.25, 500).reshape(-1, 1)
ytest = f(xtest)

# %%
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(xtest, ytest, label="Latent function")
ax.plot(x, y, "o", label="Observations")
ax.legend(loc="best")

# %%
base_kernel = gpx.Matern32()
kernel = gpx.RandomFourierFeature(key=key, base_kernel=base_kernel, num_basis_fns=50)
prior = gpx.Prior(kernel=kernel)
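# Editor's note: a random Fourier feature (RFF) kernel approximates the stationary
# base kernel (here Matern-3/2) with a finite set of random trigonometric basis
# functions whose frequencies are sampled from the base kernel's spectral density
# (Bochner's theorem); `num_basis_fns` sets the number of features and hence the
# fidelity of the approximation.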

# %%
kernel.params

# %%

# %%
1 change: 1 addition & 0 deletions docs/requirements.txt
@@ -14,5 +14,6 @@ watermark
sphinxext-opengraph
blackjax
dm-haiku
ipywidgets
# Install GPJax itself
.
1 change: 1 addition & 0 deletions requirements.txt
@@ -7,3 +7,4 @@ tensorflow == 2.8.1
tensorflow-probability==0.16.0
tqdm>=4.0.0
ml-collections==0.1.0
protobuf==3.19.0
