From 30ed414f1a241be17ddc8648df6146094e4b8f82 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Fri, 22 Mar 2024 12:33:05 +0000 Subject: [PATCH] Update notebooks. (#447) * Update yacht.py * Update likelihoods_guide.py * Revert "Update likelihoods_guide.py" This reverts commit 5f51cfe76fa0bec6dfd7fcbfc69495ff126b6818. * Update oceanmodelling.py --- docs/examples/oceanmodelling.py | 16 +++++++--------- docs/examples/yacht.py | 12 +++++++----- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/docs/examples/oceanmodelling.py b/docs/examples/oceanmodelling.py index 77f125b12..5d3490495 100644 --- a/docs/examples/oceanmodelling.py +++ b/docs/examples/oceanmodelling.py @@ -195,16 +195,14 @@ def dataset_3d(pos, vel): # %% - - -@dataclass class VelocityKernel(gpx.kernels.AbstractKernel): - kernel0: gpx.kernels.AbstractKernel = field( - default_factory=lambda: gpx.kernels.RBF(active_dims=[0, 1]) - ) - kernel1: gpx.kernels.AbstractKernel = field( - default_factory=lambda: gpx.kernels.RBF(active_dims=[0, 1]) - ) + def __init__( + self, + kernel0: gpx.kernels.AbstractKernel = gpx.kernels.RBF(active_dims=[0, 1]), + kernel1: gpx.kernels.AbstractKernel = gpx.kernels.RBF(active_dims=[0, 1]), + ): + self.kernel0 = kernel0 + self.kernel1 = kernel1 def __call__( self, X: Float[Array, "1 D"], Xp: Float[Array, "1 D"] diff --git a/docs/examples/yacht.py b/docs/examples/yacht.py index b5d7fd0dc..c1d0958aa 100644 --- a/docs/examples/yacht.py +++ b/docs/examples/yacht.py @@ -31,6 +31,7 @@ from jax import jit import jax.random as jr +import jax.numpy as jnp from jaxtyping import install_import_hook import matplotlib as mpl import matplotlib.pyplot as plt @@ -169,8 +170,8 @@ n_train, n_covariates = scaled_Xtr.shape kernel = gpx.kernels.RBF( active_dims=list(range(n_covariates)), - variance=np.var(scaled_ytr), - lengthscale=0.1 * np.ones((n_covariates,)), + variance=jnp.var(scaled_ytr), + lengthscale=0.1 * jnp.ones((n_covariates,)), ) meanf = gpx.mean_functions.Zero() prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel) @@ -188,14 +189,15 @@ # %% training_data = gpx.Dataset(X=scaled_Xtr, y=scaled_ytr) -negative_mll = jit(gpx.objectives.ConjugateMLL(negative=True)) - opt_posterior, history = gpx.fit_scipy( model=posterior, - objective=negative_mll, + # we use the negative mll as we are minimising + objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d), train_data=training_data, ) +print(-gpx.objectives.conjugate_mll(opt_posterior, training_data)) + # %% [markdown] # ## Prediction #