Skip to content

Commit

Permalink
Update notebooks. (#447)
Browse files Browse the repository at this point in the history
* Update yacht.py

* Update likelihoods_guide.py

* Revert "Update likelihoods_guide.py"

This reverts commit 5f51cfe.

* Update oceanmodelling.py
  • Loading branch information
daniel-dodd authored and thomaspinder committed Jul 9, 2024
1 parent 2d3951a commit 1783be0
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
16 changes: 7 additions & 9 deletions docs/examples/oceanmodelling.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,16 +195,14 @@ def dataset_3d(pos, vel):


# %%


@dataclass
class VelocityKernel(gpx.kernels.AbstractKernel):
kernel0: gpx.kernels.AbstractKernel = field(
default_factory=lambda: gpx.kernels.RBF(active_dims=[0, 1])
)
kernel1: gpx.kernels.AbstractKernel = field(
default_factory=lambda: gpx.kernels.RBF(active_dims=[0, 1])
)
def __init__(
self,
kernel0: gpx.kernels.AbstractKernel = gpx.kernels.RBF(active_dims=[0, 1]),
kernel1: gpx.kernels.AbstractKernel = gpx.kernels.RBF(active_dims=[0, 1]),
):
self.kernel0 = kernel0
self.kernel1 = kernel1

def __call__(
self, X: Float[Array, "1 D"], Xp: Float[Array, "1 D"]
Expand Down
12 changes: 7 additions & 5 deletions docs/examples/yacht.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

from jax import jit
import jax.random as jr
import jax.numpy as jnp
from jaxtyping import install_import_hook
import matplotlib as mpl
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -169,8 +170,8 @@
n_train, n_covariates = scaled_Xtr.shape
kernel = gpx.kernels.RBF(
active_dims=list(range(n_covariates)),
variance=np.var(scaled_ytr),
lengthscale=0.1 * np.ones((n_covariates,)),
variance=jnp.var(scaled_ytr),
lengthscale=0.1 * jnp.ones((n_covariates,)),
)
meanf = gpx.mean_functions.Zero()
prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
Expand All @@ -188,14 +189,15 @@
# %%
training_data = gpx.Dataset(X=scaled_Xtr, y=scaled_ytr)

negative_mll = jit(gpx.objectives.ConjugateMLL(negative=True))

opt_posterior, history = gpx.fit_scipy(
model=posterior,
objective=negative_mll,
# we use the negative mll as we are minimising
objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
train_data=training_data,
)

print(-gpx.objectives.conjugate_mll(opt_posterior, training_data))

# %% [markdown]
# ## Prediction
#
Expand Down

0 comments on commit 1783be0

Please sign in to comment.