diff --git a/docs/examples/oceanmodelling.py b/docs/examples/oceanmodelling.py index 77f125b12..5d3490495 100644 --- a/docs/examples/oceanmodelling.py +++ b/docs/examples/oceanmodelling.py @@ -195,16 +195,14 @@ def dataset_3d(pos, vel): # %% - - -@dataclass class VelocityKernel(gpx.kernels.AbstractKernel): - kernel0: gpx.kernels.AbstractKernel = field( - default_factory=lambda: gpx.kernels.RBF(active_dims=[0, 1]) - ) - kernel1: gpx.kernels.AbstractKernel = field( - default_factory=lambda: gpx.kernels.RBF(active_dims=[0, 1]) - ) + def __init__( + self, + kernel0: gpx.kernels.AbstractKernel = gpx.kernels.RBF(active_dims=[0, 1]), + kernel1: gpx.kernels.AbstractKernel = gpx.kernels.RBF(active_dims=[0, 1]), + ): + self.kernel0 = kernel0 + self.kernel1 = kernel1 def __call__( self, X: Float[Array, "1 D"], Xp: Float[Array, "1 D"] diff --git a/docs/examples/yacht.py b/docs/examples/yacht.py index b5d7fd0dc..c1d0958aa 100644 --- a/docs/examples/yacht.py +++ b/docs/examples/yacht.py @@ -31,6 +31,7 @@ from jax import jit import jax.random as jr +import jax.numpy as jnp from jaxtyping import install_import_hook import matplotlib as mpl import matplotlib.pyplot as plt @@ -169,8 +170,8 @@ n_train, n_covariates = scaled_Xtr.shape kernel = gpx.kernels.RBF( active_dims=list(range(n_covariates)), - variance=np.var(scaled_ytr), - lengthscale=0.1 * np.ones((n_covariates,)), + variance=jnp.var(scaled_ytr), + lengthscale=0.1 * jnp.ones((n_covariates,)), ) meanf = gpx.mean_functions.Zero() prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel) @@ -188,14 +189,15 @@ # %% training_data = gpx.Dataset(X=scaled_Xtr, y=scaled_ytr) -negative_mll = jit(gpx.objectives.ConjugateMLL(negative=True)) - opt_posterior, history = gpx.fit_scipy( model=posterior, - objective=negative_mll, + # we use the negative mll as we are minimising + objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d), train_data=training_data, ) +print(-gpx.objectives.conjugate_mll(opt_posterior, training_data)) + # %% [markdown] # ## Prediction #