From 0a26a9f819667fe0a45f1cbd16176fb4f80093d9 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Wed, 1 Jun 2022 16:12:03 +0000 Subject: [PATCH] Update nonconjugate --- gpjax/gps.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/gpjax/gps.py b/gpjax/gps.py index 2636a8d84..68d71d550 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -240,11 +240,12 @@ def predict(self, train_data: Dataset, params: dict) -> tp.Callable[[Array], dx. def predict_fn(test_inputs: Array) -> dx.Distribution: t = test_inputs + nt = t.shape[0] Ktx = cross_covariance(self.prior.kernel, t, x, params["kernel"]) - Ktt = gram(self.prior.kernel, t, params["kernel"]) + Ktt = gram(self.prior.kernel, t, params["kernel"]) + I(nt) * self.jitter μt = self.prior.mean_function(t, params["mean_function"]) A = solve_triangular(Lx, Ktx.T, lower=True) - latent_var = Ktt - jnp.sum(jnp.square(A), -2) + latent_var = jnp.diag(jnp.diag(Ktt - jnp.sum(jnp.square(A), -2))) latent_mean = μt + jnp.matmul(A.T, params["latent"]) return dx.MultivariateNormalFullCovariance( jnp.atleast_1d(latent_mean.squeeze()), latent_var @@ -301,4 +302,4 @@ def construct_posterior(prior: Prior, likelihood: AbstractLikelihood) -> Abstrac PosteriorGP = NonConjugatePosterior else: raise NotImplementedError(f"No posterior implemented for {likelihood.name} likelihood") - return PosteriorGP(prior=prior, likelihood=likelihood) \ No newline at end of file + return PosteriorGP(prior=prior, likelihood=likelihood)