diff --git a/docs/examples/likelihoods_guide.py b/docs/examples/likelihoods_guide.py
index a04b2c529..f1d36d0fb 100644
--- a/docs/examples/likelihoods_guide.py
+++ b/docs/examples/likelihoods_guide.py
@@ -1,4 +1,3 @@
-# %% [markdown]
 # # Likelihood guide
 #
 # In this notebook, we will walk users through the process of creating a new likelihood
@@ -49,7 +48,7 @@
 # these methods in the forthcoming sections, but first, we will show how to instantiate
 # a likelihood object. To do this, we'll need a dataset.
 
-# %%
+# +
 # Enable Float64 for more stable matrix inversions.
 from jax import config
 
@@ -81,8 +80,8 @@
 ax.plot(x, y, "o", label="Observations")
 ax.plot(x, f(x), label="Latent function")
 ax.legend()
+# -
 
-# %% [markdown]
 # In this example, our observations have support $[-3, 3]$ and are generated from a
 # sinusoidal function with Gaussian noise. As such, our response values $\mathbf{y}$
 # range between $-1$ and $1$, subject to Gaussian noise. Due to this, a Gaussian
@@ -93,10 +92,8 @@
 # instantiating a likelihood object. We do this by specifying the `num_datapoints`
 # argument.
 
-# %%
 gpx.likelihoods.Gaussian(num_datapoints=D.n)
 
-# %% [markdown]
 # ### Likelihood parameters
 #
 # Some likelihoods, such as the Gaussian likelihood, contain parameters that we seek
@@ -108,10 +105,8 @@
 # initialise the likelihood standard deviation with a value of $0.5$, then we would do
 # this as follows:
 
-# %%
 gpx.likelihoods.Gaussian(num_datapoints=D.n, obs_stddev=0.5)
 
-# %% [markdown]
 # To control other properties of the observation noise such as trainability and value
 # constraints, see our [PyTree guide](pytrees.md).
 #
@@ -128,7 +123,7 @@
 # samples of $\mathbf{f}^{\star}$, whilst in red we see samples of
 # $\mathbf{y}^{\star}$.
 
-# %%
+# +
 kernel = gpx.kernels.Matern32()
 meanf = gpx.mean_functions.Zero()
 prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
@@ -158,11 +153,11 @@
     color=cols[1],
     label="Predictive samples",
 )
+# -
 
-# %% [markdown]
 # Similarly, for a Bernoulli likelihood function, the samples of $y$ would be binary.
 
-# %%
+# +
 likelihood = gpx.likelihoods.Bernoulli(num_datapoints=D.n)
 
 
@@ -185,8 +180,8 @@
     color=cols[1],
     label="Predictive samples",
 )
+# -
 
-# %% [markdown]
 # ### Link functions
 #
 # In the above figure, we can see the latent samples being constrained to be either 0 or
@@ -234,7 +229,7 @@
 # this, let us consider a Gaussian likelihood where we'll first define a variational
 # approximation to the posterior.
 
-# %%
+# +
 z = jnp.linspace(-3.0, 3.0, 10).reshape(-1, 1)
 q = gpx.variational_families.VariationalGaussian(posterior=posterior, inducing_inputs=z)
 
@@ -245,32 +240,27 @@
 def q_moments(x):
     qx = q(x)
     return qx.mean(), qx.variance()
 
 
 mean, variance = jax.vmap(q_moments)(x[:, None])
+# -
 
-# %% [markdown]
 # Now that we have the variational mean and variational (co)variance, we can compute
 # the expected log-likelihood using the `expected_log_likelihood` method of the
 # likelihood object.
 
-# %%
 jnp.sum(likelihood.expected_log_likelihood(y=y, mean=mean, variance=variance))
 
-# %% [markdown]
 # However, had we wanted to do this using quadrature, then we would have done the
 # following:
 
-# %%
 lquad = gpx.likelihoods.Gaussian(
     num_datapoints=D.n,
     obs_stddev=jnp.array([0.1]),
+    integrator=gpx.integrators.GHQuadratureIntegrator(num_points=20),
 )
 
-# %% [markdown]
 # However, this is not recommended for the Gaussian likelihood given that the
 # expectation can be computed analytically.
 
-# %% [markdown]
 # ## System configuration
 
-# %%
 # %reload_ext watermark
 # %watermark -n -u -v -iv -w -a 'Thomas Pinder'
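
Note on the "Link functions" hunk above: latent GP samples live on the whole real line, and the likelihood's link function is what maps them onto the response's support (binary values in the Bernoulli case). A minimal standalone sketch of that mapping, assuming an inverse-probit link (the standard normal CDF, as commonly used for Bernoulli GPs); the names f_samples, probs and y_samples are illustrative, not GPJax API:

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

key = jax.random.PRNGKey(42)

# Latent function values are unconstrained reals.
f_samples = jax.random.normal(key, shape=(5,))

# The inverse-probit link squashes them into probabilities in (0, 1).
probs = norm.cdf(f_samples)

# Sampling the Bernoulli likelihood then yields binary responses.
y_samples = jax.random.bernoulli(key, probs).astype(jnp.int32)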
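
The functional change in the final hunk is the added `integrator` argument. A minimal sketch of how that quadrature route compares with the default analytical one, assuming the guide's `D`, `y`, `mean` and `variance` are in scope; the names lanalytic, ell_analytic and ell_quad are illustrative:

import jax.numpy as jnp
import gpjax as gpx

# Default Gaussian likelihood: the expected log-likelihood is evaluated
# analytically, since the Gaussian expectation has a closed form.
lanalytic = gpx.likelihoods.Gaussian(num_datapoints=D.n, obs_stddev=jnp.array([0.1]))

# Variant from the diff: a 20-point Gauss-Hermite quadrature approximation.
lquad = gpx.likelihoods.Gaussian(
    num_datapoints=D.n,
    obs_stddev=jnp.array([0.1]),
    integrator=gpx.integrators.GHQuadratureIntegrator(num_points=20),
)

# The two totals should agree closely here; quadrature only pays off for
# likelihoods whose expectation lacks a closed form.
ell_analytic = jnp.sum(lanalytic.expected_log_likelihood(y=y, mean=mean, variance=variance))
ell_quad = jnp.sum(lquad.expected_log_likelihood(y=y, mean=mean, variance=variance))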