Skip to content

Commit

Permalink
Renaming Variational Inference
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-dodd committed Jun 1, 2022
1 parent 790ea25 commit a5e50ad
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 11 deletions.
4 changes: 2 additions & 2 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,14 @@ Sparse Frameworks
Abstract Sparse Objects
*********************************

.. autoclass:: VariationalPosterior
.. autoclass:: AbstractVariationalInference
:members:


Sparse Methods
*********************************

.. autoclass:: SVGP
.. autoclass:: StochasticVI
:members:


Expand Down
4 changes: 2 additions & 2 deletions docs/nbs/sparse_regression.pct.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,10 @@
# Here, the variational process $q(\cdot)$ depends on the prior through $p(f(\cdot)|f(\boldsymbol{z}))$ in $(\times)$.
# %% [markdown]
#
# We combine our true and approximate posterior Gaussian processes into an `SVGP` object to define the variational strategy that we will adopt in the forthcoming inference.
# We combine our true and approximate posterior Gaussian processes into an `StochasticVI` object to define the variational strategy that we will adopt in the forthcoming inference.

# %%
svgp = gpx.SVGP(posterior=p, variational_family=q)
svgp = gpx.StochasticVI(posterior=p, variational_family=q)

# %% [markdown]
# ## Inference
Expand Down
2 changes: 1 addition & 1 deletion gpjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .likelihoods import Bernoulli, Gaussian
from .mean_functions import Constant, Zero
from .parameters import copy_dict_structure, initialise, transform
from .sparse_gps import SVGP
from .sparse_gps import StochasticVI
from .types import Dataset
from .variational import VariationalGaussian, WhitenedVariationalGaussian

Expand Down
8 changes: 4 additions & 4 deletions gpjax/sparse_gps.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@


@dataclass
class VariationalPosterior:
"""A variational posterior object. With reference to some true posterior distribution :math:`p`, this can be used to minimise the KL-divergence between :math:`p` and a variational posterior :math:`q`."""
class AbstractVariationalInference:
"""A base class for inference and training of variational families against an extact posterior"""

posterior: AbstractPosterior
variational_family: VariationalFamily
Expand Down Expand Up @@ -50,8 +50,8 @@ def elbo(


@dataclass
class SVGP(VariationalPosterior):
"""Sparse Variational Gaussian Process (SVGP) training module. The key reference is Hensman et. al., (2013) - Gaussian processes for big data."""
class StochasticVI(AbstractVariationalInference):
"""Stochastic Variational inference training module. The key reference is Hensman et. al., (2013) - Gaussian processes for big data."""

def __post_init__(self):
self.prior = self.posterior.prior
Expand Down
2 changes: 1 addition & 1 deletion tests/test_abstractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_batch_fitting(nb, ndata):

q = gpx.VariationalGaussian(prior=prior, inducing_inputs=z)

svgp = gpx.SVGP(posterior=p, variational_family=q)
svgp = gpx.StochasticVI(posterior=p, variational_family=q)
params, trainable_status, constrainer, unconstrainer = initialise(svgp)
params = gpx.transform(params, unconstrainer)
objective = svgp.elbo(D, constrainer)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_sparse_gps.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_svgp(n_datapoints, n_inducing_points, n_test, whiten, diag, jit_fns):
inducing_inputs=inducing_inputs, diag=diag
)

svgp = gpx.SVGP(posterior=post, variational_family=q)
svgp = gpx.StochasticVI(posterior=post, variational_family=q)

assert svgp.posterior.prior == post.prior
assert svgp.posterior.likelihood == post.likelihood
Expand Down

0 comments on commit a5e50ad

Please sign in to comment.