Skip to content

Commit

Permalink
Update nonconjugate
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Pinder committed Jun 1, 2022
1 parent 013adf9 commit 0a26a9f
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions gpjax/gps.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,11 +240,12 @@ def predict(self, train_data: Dataset, params: dict) -> tp.Callable[[Array], dx.

def predict_fn(test_inputs: Array) -> dx.Distribution:
t = test_inputs
nt = t.shape[0]
Ktx = cross_covariance(self.prior.kernel, t, x, params["kernel"])
Ktt = gram(self.prior.kernel, t, params["kernel"])
Ktt = gram(self.prior.kernel, t, params["kernel"]) + I(nt) * self.jitter
μt = self.prior.mean_function(t, params["mean_function"])
A = solve_triangular(Lx, Ktx.T, lower=True)
latent_var = Ktt - jnp.sum(jnp.square(A), -2)
latent_var = jnp.diag(jnp.diag(Ktt - jnp.sum(jnp.square(A), -2)))
latent_mean = μt + jnp.matmul(A.T, params["latent"])
return dx.MultivariateNormalFullCovariance(
jnp.atleast_1d(latent_mean.squeeze()), latent_var
Expand Down Expand Up @@ -301,4 +302,4 @@ def construct_posterior(prior: Prior, likelihood: AbstractLikelihood) -> Abstrac
PosteriorGP = NonConjugatePosterior
else:
raise NotImplementedError(f"No posterior implemented for {likelihood.name} likelihood")
return PosteriorGP(prior=prior, likelihood=likelihood)
return PosteriorGP(prior=prior, likelihood=likelihood)

0 comments on commit 0a26a9f

Please sign in to comment.