diff --git a/docs/nbs/classification.pct.py b/docs/nbs/classification.pct.py
index a184527c5..45788476a 100644
--- a/docs/nbs/classification.pct.py
+++ b/docs/nbs/classification.pct.py
@@ -44,7 +44,7 @@

 D = gpx.Dataset(X=x, y=y)

-xtest = jnp.linspace(-1., 1., 500).reshape(-1, 1)
+xtest = jnp.linspace(-1.0, 1.0, 500).reshape(-1, 1)
 plt.plot(x, y, "o", markersize=8)
 # %% [markdown]
 # ## MAP inference
@@ -60,7 +60,7 @@
 posterior = prior * likelihood
 print(type(posterior))
 # %% [markdown]
-# Whilst the latent function is Gaussian, the posterior distribution is non-Gaussian since our generative model first samples the latent GP and propagates these samples through the likelihood function's inverse link function. This step prevents us from being able to analytically integrate the latent function's values out of our posterior, and we must instead adopt alternative inference techniques. We begin with maximum a posteriori (MAP) estimation, a fast inference procedure to obtain point estimates for the latent function and the kernel's hyperparameters by maximising the marginal log-likelihood.
+# Whilst the latent function is Gaussian, the posterior distribution is non-Gaussian since our generative model first samples the latent GP and propagates these samples through the likelihood function's inverse link function. This step prevents us from being able to analytically integrate the latent function's values out of our posterior, and we must instead adopt alternative inference techniques. We begin with maximum a posteriori (MAP) estimation, a fast inference procedure to obtain point estimates for the latent function and the kernel's hyperparameters by maximising the marginal log-likelihood.
 # %% [markdown]
 # To begin we obtain a set of initial parameter values through the `initialise` callable, and transform these to the unconstrained space via `transform` (see the [regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html)). We also define the negative marginal log-likelihood, and JIT compile this to accelerate training.
 # %%
@@ -96,17 +96,16 @@
 # %%
 # Adapted from BlackJax's introduction notebook.
 num_adapt = 1000
-num_samples = 1000
+num_samples = 500

 mll = jax.jit(posterior.marginal_log_likelihood(D, constrainer, negative=False))

-adapt = blackjax.window_adaptation(
-    blackjax.nuts, mll, num_adapt, target_acceptance_rate=0.65
-)
+adapt = blackjax.window_adaptation(blackjax.nuts, mll, num_adapt, target_acceptance_rate=0.65)

 # Initialise the chain
 last_state, kernel, _ = adapt.run(key, params)

+
 def inference_loop(rng_key, kernel, initial_state, num_samples):
     def one_step(state, rng_key):
         state, info = kernel(rng_key, state)
@@ -133,9 +132,9 @@ def one_step(state, rng_key):
 # Our acceptance rate is slightly too large, prompting an examination of the chain's trace plots. A well-mixing chain will have very few (if any) flat spots in its trace plot whilst also not having too many steps in the same direction. In addition to the model's hyperparameters, there will be 500 samples for each of the 100 latent function values in the `states.position` dictionary. We depict the chains that correspond to the model hyperparameters and the first value of the latent function for brevity.
 # %%
 fig, (ax0, ax1, ax2) = plt.subplots(ncols=3, figsize=(15, 5), tight_layout=True)
-ax0.plot(states.position['kernel']['lengthscale'])
-ax1.plot(states.position['kernel']['variance'])
-ax2.plot(states.position['latent'][:, 0, :])
+ax0.plot(states.position["kernel"]["lengthscale"])
+ax1.plot(states.position["kernel"]["variance"])
+ax2.plot(states.position["latent"][:, 0, :])
 ax0.set_title("Kernel Lengthscale")
 ax1.set_title("Kernel Variance")
 ax2.set_title("Latent Function (index = 1)")
@@ -143,7 +142,7 @@ def one_step(state, rng_key):
 # %% [markdown]
 # ## Prediction
 #
-# Having obtained samples from the posterior, we draw ten instances from our model's predictive distribution per MCMC sample. Using these draws, we will be able to compute credible values and expected values under our posterior distribution.
+# Having obtained samples from the posterior, we draw ten instances from our model's predictive distribution per MCMC sample. Using these draws, we will be able to compute credible values and expected values under our posterior distribution.
 #
 # An ideal Markov chain would have samples completely uncorrelated with their neighbours after a single lag. However, in practice, correlations often exist within our chain's sample set. A commonly used technique to try and reduce this correlation is _thinning_ whereby we select every $n$th sample where $n$ is the minimum lag length at which we believe the samples are uncorrelated. Although further analysis of the chain's autocorrelation is required to find appropriate thinning factors, we employ a thin factor of 10 for demonstration purposes.
 # %%
@@ -152,9 +151,9 @@ def one_step(state, rng_key):

 for i in range(0, num_samples, thin_factor):
     ps = gpx.parameters.copy_dict_structure(params)
-    ps['kernel']['lengthscale'] = states.position['kernel']['lengthscale'][i]
-    ps['kernel']['variance'] = states.position['kernel']['variance'][i]
-    ps['latent'] = states.position['latent'][i, :, :]
+    ps["kernel"]["lengthscale"] = states.position["kernel"]["lengthscale"][i]
+    ps["kernel"]["variance"] = states.position["kernel"]["variance"][i]
+    ps["latent"] = states.position["latent"][i, :, :]
     ps = gpx.transform(ps, constrainer)

     predictive_dist = likelihood(posterior(D, ps)(xtest), ps)
@@ -171,9 +170,16 @@ def one_step(state, rng_key):

 # %%
 fig, ax = plt.subplots(figsize=(16, 5), tight_layout=True)
-ax.plot(x, y, "o", markersize=5, color='tab:red', label='Observations', zorder=2, alpha=0.7)
-ax.plot(xtest, expected_val, linewidth=2, color='tab:blue', label='Predicted mean', zorder=1)
-ax.fill_between(xtest.flatten(), lower_ci.flatten(), upper_ci.flatten(), alpha=0.2, color='tab:blue', label='95% CI')
+ax.plot(x, y, "o", markersize=5, color="tab:red", label="Observations", zorder=2, alpha=0.7)
+ax.plot(xtest, expected_val, linewidth=2, color="tab:blue", label="Predicted mean", zorder=1)
+ax.fill_between(
+    xtest.flatten(),
+    lower_ci.flatten(),
+    upper_ci.flatten(),
+    alpha=0.2,
+    color="tab:blue",
+    label="95% CI",
+)

 # %% [markdown]
 # ## System configuration
diff --git a/docs/nbs/regression.pct.py b/docs/nbs/regression.pct.py
index a381e7815..2731dd13f 100644
--- a/docs/nbs/regression.pct.py
+++ b/docs/nbs/regression.pct.py
@@ -64,7 +64,7 @@

 # Our aim in this tutorial will be to reconstruct the latent function from our noisy observations $\mathcal{D}$ via Gaussian process regression. We begin by defining a Gaussian process prior in the next section.

- # %% [markdown]
+# %% [markdown]
 # ## Defining the prior
 #
 # A zero-mean Gaussian process (GP) places a prior distribution over real-valued functions $f(\cdot)$ where $f(\boldsymbol{x}) \sim \mathcal{N}(0, \mathbf{K}_{\boldsymbol{x}\boldsymbol{x}})$ for any finite collection of inputs $\boldsymbol{x}$.
@@ -107,7 +107,7 @@
 # The posterior is proportional to the prior multiplied by the likelihood, written as
 #
 # $$ p(f(\cdot) | \mathcal{D}) \propto p(f(\cdot)) * p(\mathcal{D} | f(\cdot)). $$
-#
+#
 # Mimicking this construct, the posterior is established in GPJax through the `*` operator.

 # %%
diff --git a/docs/nbs/spectral_kernels.pct.py b/docs/nbs/spectral_kernels.pct.py
new file mode 100644
index 000000000..7356341e9
--- /dev/null
+++ b/docs/nbs/spectral_kernels.pct.py
@@ -0,0 +1,58 @@
+# ---
+# jupyter:
+#   jupytext:
+#     cell_metadata_filter: -all
+#     custom_cell_magics: kql
+#     text_representation:
+#       extension: .py
+#       format_name: percent
+#       format_version: '1.3'
+#       jupytext_version: 1.11.2
+#   kernelspec:
+#     display_name: Python 3.9.7 ('gpjax')
+#     language: python
+#     name: python3
+# ---
+
+# %%
+import gpjax as gpx
+import jax
+import jax.numpy as jnp
+import jax.random as jr
+import matplotlib.pyplot as plt
+import optax as ox
+import distrax as dx
+
+key = jr.PRNGKey(123)
+
+# %%
+n = 100
+noise = 0.3
+
+x = jr.uniform(key=key, minval=-3.0, maxval=3.0, shape=(n,)).sort().reshape(-1, 1)
+f = lambda x: jnp.sin(4 * x) + jnp.cos(2 * x)
+signal = f(x)
+y = signal + jr.normal(key, shape=signal.shape) * noise
+
+D = gpx.Dataset(X=x, y=y)
+
+xtest = jnp.linspace(-3.25, 3.25, 500).reshape(-1, 1)
+ytest = f(xtest)
+
+# %%
+fig, ax = plt.subplots(figsize=(10, 5))
+ax.plot(xtest, ytest, label="Latent function")
+ax.plot(x, y, "o", label="Observations")
+ax.legend(loc="best")
+
+# %%
+base_kernel = gpx.Matern32()
+kernel = gpx.RandomFourierFeature(key = key, base_kernel = base_kernel, num_basis_fns=50)
+prior = gpx.Prior(kernel=kernel)
+
+# %%
+kernel.params
+
+# %%
+
+# %%
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 184bc833b..a0a56ea07 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -14,5 +14,6 @@ watermark
 sphinxext-opengraph
 blackjax
 dm-haiku
+ipywidgets
 # Install GPJax istself
 .
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index edcd00b3a..da3a4a9aa 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,3 +7,4 @@ tensorflow == 2.8.1
 tensorflow-probability==0.16.0
 tqdm>=4.0.0
 ml-collections==0.1.0
+protobuf==3.19.0
\ No newline at end of file
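
Note on the new spectral_kernels.pct.py notebook: the patch stops after constructing the RandomFourierFeature kernel and inspecting kernel.params. For readers unfamiliar with the construction it relies on, the sketch below illustrates the random Fourier feature (RFF) approximation in plain JAX for an RBF base kernel with unit variance. This is an illustrative sketch only, not the GPJax implementation: the helper name rff_features, the choice of an RBF base kernel, and the constants are assumptions made for this example.

# Illustrative RFF sketch (assumed example, not part of the patch above).
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(123)
num_basis_fns = 50  # mirrors the notebook's num_basis_fns=50
lengthscale = 1.0   # assumed unit lengthscale and unit variance for the RBF kernel


def rff_features(X, omega, bias, num_basis_fns):
    # Feature map phi(x) such that phi(x) @ phi(x') approximates k(x, x').
    projection = X @ omega.T + bias  # shape (n, num_basis_fns)
    return jnp.sqrt(2.0 / num_basis_fns) * jnp.cos(projection)


n, d = 5, 1
X = jnp.linspace(-3.0, 3.0, n).reshape(-1, d)

# For an RBF kernel, frequencies are drawn from its (Gaussian) spectral density
# and phase shifts uniformly on [0, 2*pi].
key, freq_key, phase_key = jr.split(key, 3)
omega = jr.normal(freq_key, shape=(num_basis_fns, d)) / lengthscale
bias = jr.uniform(phase_key, shape=(num_basis_fns,), minval=0.0, maxval=2.0 * jnp.pi)

Phi = rff_features(X, omega, bias, num_basis_fns)
approx_gram = Phi @ Phi.T  # low-rank approximation of the Gram matrix
exact_gram = jnp.exp(-0.5 * ((X - X.T) / lengthscale) ** 2)  # exact RBF Gram for 1-D inputs
print(jnp.max(jnp.abs(approx_gram - exact_gram)))  # error shrinks as num_basis_fns grows

The approximation error decays as the number of basis functions grows. The RandomFourierFeature kernel added in this patch presumably plays the analogous role for the Matern-3/2 base kernel, whose spectral density is a Student-t distribution rather than a Gaussian.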