diff --git a/docs/api.rst b/docs/api.rst
index 6a16787d0..58633f043 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -163,14 +163,14 @@ Sparse Frameworks
 
 Abstract Sparse Objects
 *********************************
 
-.. autoclass:: VariationalPosterior
+.. autoclass:: AbstractVariationalInference
    :members:
 
 
 Sparse Methods
 *********************************
 
-.. autoclass:: SVGP
+.. autoclass:: StochasticVI
    :members:
 
diff --git a/docs/nbs/sparse_regression.pct.py b/docs/nbs/sparse_regression.pct.py
index 1eec88617..f933eed6d 100644
--- a/docs/nbs/sparse_regression.pct.py
+++ b/docs/nbs/sparse_regression.pct.py
@@ -124,10 +124,10 @@
 # Here, the variational process $q(\cdot)$ depends on the prior through $p(f(\cdot)|f(\boldsymbol{z}))$ in $(\times)$.
 
 # %% [markdown]
 #
-# We combine our true and approximate posterior Gaussian processes into an `SVGP` object to define the variational strategy that we will adopt in the forthcoming inference.
+# We combine our true and approximate posterior Gaussian processes into a `StochasticVI` object to define the variational strategy that we will adopt in the forthcoming inference.
 
 # %%
-svgp = gpx.SVGP(posterior=p, variational_family=q)
+svgp = gpx.StochasticVI(posterior=p, variational_family=q)
 
 # %% [markdown]
 # ## Inference
diff --git a/gpjax/__init__.py b/gpjax/__init__.py
index a407ad8e6..53136e8b4 100644
--- a/gpjax/__init__.py
+++ b/gpjax/__init__.py
@@ -20,7 +20,7 @@
 from .likelihoods import Bernoulli, Gaussian
 from .mean_functions import Constant, Zero
 from .parameters import copy_dict_structure, initialise, transform
-from .sparse_gps import SVGP
+from .sparse_gps import StochasticVI
 from .types import Dataset
 from .variational import VariationalGaussian, WhitenedVariationalGaussian
diff --git a/gpjax/sparse_gps.py b/gpjax/sparse_gps.py
index 4c37e87e1..2a52d9b9d 100644
--- a/gpjax/sparse_gps.py
+++ b/gpjax/sparse_gps.py
@@ -14,8 +14,8 @@
 
 
 @dataclass
-class VariationalPosterior:
-    """A variational posterior object. With reference to some true posterior distribution :math:`p`, this can be used to minimise the KL-divergence between :math:`p` and a variational posterior :math:`q`."""
+class AbstractVariationalInference:
+    """A base class for inference and training of variational families against an exact posterior."""
 
     posterior: AbstractPosterior
     variational_family: VariationalFamily
@@ -50,8 +50,8 @@ def elbo(
 
 
 @dataclass
-class SVGP(VariationalPosterior):
-    """Sparse Variational Gaussian Process (SVGP) training module. The key reference is Hensman et. al., (2013) - Gaussian processes for big data."""
+class StochasticVI(AbstractVariationalInference):
+    """Stochastic variational inference training module. The key reference is Hensman et al. (2013) - Gaussian processes for big data."""
 
     def __post_init__(self):
         self.prior = self.posterior.prior
diff --git a/tests/test_abstractions.py b/tests/test_abstractions.py
index 155b2377d..5c21f2598 100644
--- a/tests/test_abstractions.py
+++ b/tests/test_abstractions.py
@@ -51,7 +51,7 @@ def test_batch_fitting(nb, ndata):
     q = gpx.VariationalGaussian(prior=prior, inducing_inputs=z)
 
-    svgp = gpx.SVGP(posterior=p, variational_family=q)
+    svgp = gpx.StochasticVI(posterior=p, variational_family=q)
     params, trainable_status, constrainer, unconstrainer = initialise(svgp)
     params = gpx.transform(params, unconstrainer)
     objective = svgp.elbo(D, constrainer)
diff --git a/tests/test_sparse_gps.py b/tests/test_sparse_gps.py
index 4d2277f15..b7ebf56d2 100644
--- a/tests/test_sparse_gps.py
+++ b/tests/test_sparse_gps.py
@@ -39,7 +39,7 @@ def test_svgp(n_datapoints, n_inducing_points, n_test, whiten, diag, jit_fns):
         inducing_inputs=inducing_inputs, diag=diag
     )
 
-    svgp = gpx.SVGP(posterior=post, variational_family=q)
+    svgp = gpx.StochasticVI(posterior=post, variational_family=q)
 
     assert svgp.posterior.prior == post.prior
     assert svgp.posterior.likelihood == post.likelihood
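
For reference, a minimal end-to-end sketch of the renamed API. The `StochasticVI`, `VariationalGaussian`, `initialise`, `transform`, and `elbo` calls mirror the updated tests in this diff; the toy data and the `Dataset`/`Prior`/`RBF`/`Gaussian` construction are assumptions about the surrounding GPJax API, which this diff does not touch.

import jax.numpy as jnp
import jax.random as jr

import gpjax as gpx

# Toy 1D regression data (illustrative only; any (n, 1) dataset works).
key = jr.PRNGKey(123)
x = jnp.sort(jr.uniform(key, shape=(50, 1), minval=-3.0, maxval=3.0), axis=0)
y = jnp.sin(x) + 0.1 * jr.normal(key, shape=x.shape)
D = gpx.Dataset(X=x, y=y)

# Exact posterior p = prior * likelihood. The Prior/RBF/Gaussian calls are
# assumed from GPJax's standard API and are unchanged by this diff.
prior = gpx.Prior(kernel=gpx.RBF())
p = prior * gpx.Gaussian(num_datapoints=D.n)

# Variational family over a set of inducing inputs z, as in the updated tests.
z = jnp.linspace(-3.0, 3.0, num=10).reshape(-1, 1)
q = gpx.VariationalGaussian(prior=prior, inducing_inputs=z)

# The renamed entry point: gpx.StochasticVI replaces gpx.SVGP.
svgp = gpx.StochasticVI(posterior=p, variational_family=q)

# Initialise parameters and build the ELBO objective, mirroring
# tests/test_abstractions.py above.
params, trainable_status, constrainer, unconstrainer = gpx.initialise(svgp)
params = gpx.transform(params, unconstrainer)
objective = svgp.elbo(D, constrainer)

The negative of `objective` can then be minimised with any JAX-compatible optimiser; only the class names change in this diff, so existing training loops need no other edits.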