Commit 5aa7cc2
Revert "Update likelihoods_guide.py"
This reverts commit 5f51cfe.
daniel-dodd committed Mar 22, 2024
1 parent 5f51cfe commit 5aa7cc2
Showing 1 changed file with 9 additions and 19 deletions.
28 changes: 9 additions & 19 deletions docs/examples/likelihoods_guide.py
@@ -1,4 +1,3 @@
-# %% [markdown]
# # Likelihood guide
#
# In this notebook, we will walk users through the process of creating a new likelihood
@@ -49,7 +48,7 @@
# these methods in the forthcoming sections, but first, we will show how to instantiate
# a likelihood object. To do this, we'll need a dataset.

-# %%
+# +
# Enable Float64 for more stable matrix inversions.
from jax import config

@@ -81,8 +80,8 @@
ax.plot(x, y, "o", label="Observations")
ax.plot(x, f(x), label="Latent function")
ax.legend()
+# -

-# %% [markdown]
# In this example, our inputs have support $[-3, 3]$ and the responses are generated
# from a sinusoidal function corrupted by Gaussian noise, so our response values
# $\mathbf{y}$ range between $-1$ and $1$ up to that noise. Due to this, a Gaussian
@@ -93,10 +92,8 @@
# instantiating a likelihood object. We do this by specifying the `num_datapoints`
# argument.

-# %%
gpx.likelihoods.Gaussian(num_datapoints=D.n)

-# %% [markdown]
# ### Likelihood parameters
#
# Some likelihoods, such as the Gaussian likelihood, contain parameters that we seek
@@ -108,10 +105,8 @@
# initialise the likelihood standard deviation with a value of $0.5$, then we would do
# this as follows:

-# %%
gpx.likelihoods.Gaussian(num_datapoints=D.n, obs_stddev=0.5)

-# %% [markdown]
# To control other properties of the observation noise such as trainability and value
# constraints, see our [PyTree guide](pytrees.md).
#
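#
# As a brief illustration of fixing the noise so that training leaves it untouched,
# a sketch is given below. The `replace_trainable` helper is referenced from the
# PyTree guide and is an assumption here; check that guide for the exact API.

# Hypothetical sketch: mark `obs_stddev` as non-trainable so that it stays at 0.5.
fixed_likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n, obs_stddev=0.5)
fixed_likelihood = fixed_likelihood.replace_trainable(obs_stddev=False)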
Expand All @@ -128,7 +123,7 @@
# samples of $\mathbf{f}^{\star}$, whilst in red we see samples of
# $\mathbf{y}^{\star}$.

-# %%
+# +
kernel = gpx.kernels.Matern32()
meanf = gpx.mean_functions.Zero()
prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
@@ -158,11 +153,11 @@
color=cols[1],
label="Predictive samples",
)
+# -

-# %% [markdown]
# Similarly, for a Bernoulli likelihood function, the samples of $y$ would be binary.

-# %%
+# +
likelihood = gpx.likelihoods.Bernoulli(num_datapoints=D.n)


@@ -185,8 +180,8 @@
color=cols[1],
label="Predictive samples",
)
+# -
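
# Before turning to link functions in the next section, here is a minimal sketch of
# the transformation behind these binary samples: the Bernoulli likelihood maps
# latent values onto $(0, 1)$ via an inverse probit, i.e. the standard Gaussian CDF.
# The helper below is an illustrative re-implementation, not GPJax's internal one,
# which may differ slightly (e.g. by adding jitter for numerical stability).

import jax.numpy as jnp
from jax.scipy.special import erf


def inv_probit_sketch(f):
    # Standard Gaussian CDF: maps values on the real line to probabilities in (0, 1).
    return 0.5 * (1.0 + erf(f / jnp.sqrt(2.0)))


inv_probit_sketch(jnp.array([-2.0, 0.0, 2.0]))  # approx. [0.023, 0.5, 0.977]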

-# %% [markdown]
# ### Link functions
#
# In the above figure, we can see the latent samples being constrained to be either 0 or
@@ -234,7 +229,7 @@
# this, let us consider a Gaussian likelihood where we'll first define a variational
# approximation to the posterior.

-# %%
+# +
z = jnp.linspace(-3.0, 3.0, 10).reshape(-1, 1)
q = gpx.variational_families.VariationalGaussian(posterior=posterior, inducing_inputs=z)

@@ -245,32 +240,27 @@ def q_moments(x):


mean, variance = jax.vmap(q_moments)(x[:, None])
+# -

-# %% [markdown]
# Now that we have the variational mean and variational (co)variance, we can compute
# the expected log-likelihood using the `expected_log_likelihood` method of the
# likelihood object.

-# %%
jnp.sum(likelihood.expected_log_likelihood(y=y, mean=mean, variance=variance))

-# %% [markdown]
# However, had we wanted to do this using quadrature, then we would have done the
# following:

-# %%
lquad = gpx.likelihoods.Gaussian(
num_datapoints=D.n,
obs_stddev=jnp.array([0.1]),
integrator=gpx.integrators.GHQuadratureIntegrator(num_points=20),
)
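
# The quadrature-backed likelihood exposes the same `expected_log_likelihood`
# interface, so, reusing the variational moments computed above, its estimate
# should agree closely with the analytic value:
jnp.sum(lquad.expected_log_likelihood(y=y, mean=mean, variance=variance))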

-# %% [markdown]
# That said, quadrature is not recommended for the Gaussian likelihood, since the
# expectation can be computed analytically.
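#
# For reference, the analytic form is the standard Gaussian result: with
# $q(f) = \mathcal{N}(f \mid \mu, v)$,
#
# $$\mathbb{E}_{q(f)}\left[\log \mathcal{N}(y \mid f, \sigma^2)\right]
# = \log \mathcal{N}(y \mid \mu, \sigma^2) - \frac{v}{2\sigma^2},$$
#
# which incurs no quadrature error.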

-# %% [markdown]
# ## System configuration

-# %%
# %reload_ext watermark
# %watermark -n -u -v -iv -w -a 'Thomas Pinder'
