Skip to content

Commit

Permalink
updated notebooks
Browse files Browse the repository at this point in the history
  • Loading branch information
henrymoss committed Sep 19, 2023
1 parent 3f73d19 commit eeadba6
Show file tree
Hide file tree
Showing 11 changed files with 14 additions and 40 deletions.
6 changes: 1 addition & 5 deletions docs/examples/barycentres.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import jax.scipy.linalg as jsl
from jaxtyping import install_import_hook
import matplotlib.pyplot as plt
import optax as ox
import jaxopt
import tensorflow_probability.substrates.jax.distributions as tfd

Expand Down Expand Up @@ -137,13 +136,10 @@ def fit_gp(x: jax.Array, y: jax.Array) -> tfd.MultivariateNormalFullCovariance:

likelihood = gpx.Gaussian(num_datapoints=n)
posterior = gpx.Prior(mean_function=gpx.Constant(), kernel=gpx.RBF()) * likelihood

opt_posterior, _ = gpx.fit(
model=posterior,
train_data=D,
solver=jaxopt.OptaxSolver(
gpx.ConjugateMLL(negative=True), opt=ox.adam(0.01), maxiter=500
),
solver=jaxopt.LBFGS(gpx.ConjugateMLL(negative=True), maxiter=500),
key=key,
)
latent_dist = opt_posterior.predict(xtest, train_data=D)
Expand Down
5 changes: 1 addition & 4 deletions docs/examples/bayesian_optimisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import cm
import optax as ox
import jaxopt
import tensorflow_probability.substrates.jax as tfp
from typing import List, Tuple
Expand Down Expand Up @@ -218,9 +217,7 @@ def return_optimised_posterior(
opt_posterior, history = gpx.fit(
model=posterior,
train_data=D,
solver=jaxopt.OptaxSolver(
gpx.ConjugateMLL(negative=True), opt=ox.adam(0.01), maxiter=1000
),
solver=jaxopt.LBFGS(gpx.ConjugateMLL(negative=True), maxiter=500),
safe=True,
key=key,
verbose=False,
Expand Down
3 changes: 1 addition & 2 deletions docs/examples/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
install_import_hook,
)
import matplotlib.pyplot as plt
import optax as ox
import jaxopt
import tensorflow_probability.substrates.jax as tfp
from tqdm import trange
Expand Down Expand Up @@ -122,7 +121,7 @@
opt_posterior, history = gpx.fit(
model=posterior,
train_data=D,
solver=jaxopt.OptaxSolver(negative_lpd, opt=ox.adam(0.01), maxiter=1000),
solver=jaxopt.LBFGS(gpx.ConjugateMLL(negative=True), maxiter=500),
key=key,
)

Expand Down
4 changes: 1 addition & 3 deletions docs/examples/constructing_new_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,7 @@ def __call__(
opt_posterior, history = gpx.fit(
model=circular_posterior,
train_data=D,
solver=jaxopt.OptaxSolver(
gpx.ConjugateMLL(negative=True), opt=ox.adamw(0.05), maxiter=500
),
solver=jaxopt.LBFGS(gpx.ConjugateMLL(negative=True), maxiter=500),
key=key,
)

Expand Down
10 changes: 3 additions & 7 deletions docs/examples/decision_making.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import jax.random as jr
import matplotlib as mpl
import matplotlib.pyplot as plt
import optax as ox

import gpjax as gpx
from gpjax.decision_making.utility_functions import (
Expand Down Expand Up @@ -164,17 +163,14 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
# defined above. We tend to also optimise the hyperparameters of the GP prior when
# "fitting" our GP, as demonstrated in the [Regression
# notebook](https://docs.jaxgaussianprocesses.com/examples/regression/). This will be
# using the GPJax `fit` method under the hood, which requires an `optimization_objective`,
# `optimizer` and `num_optimization_iters`. Therefore, we also pass these to the
# `PosteriorHandler` as demonstrated below:
# using the GPJax `fit` method under the hood, which requires an jaxopt `solver`.
# Therefore, we also pass this to the `PosteriorHandler` as demonstrated below:

# %%
posterior_handler = PosteriorHandler(
prior,
likelihood_builder=likelihood_builder,
optimization_objective=gpx.ConjugateMLL(negative=True),
optimizer=ox.adam(learning_rate=0.01),
num_optimization_iters=1000,
solver=jaxopt.LBFGS(gpx.ConjugateMLL(negative=True), maxiter=500),
)
posterior_handlers = {OBJECTIVE: posterior_handler}

Expand Down
5 changes: 1 addition & 4 deletions docs/examples/graph_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import matplotlib as mpl
import matplotlib.pyplot as plt
import networkx as nx
import optax as ox
import jaxopt

with install_import_hook("gpjax", "beartype.beartype"):
Expand Down Expand Up @@ -157,9 +156,7 @@
opt_posterior, training_history = gpx.fit(
model=posterior,
train_data=D,
solver=jaxopt.OptaxSolver(
gpx.ConjugateMLL(negative=True), opt=ox.adamw(0.01), maxiter=1000
),
solver=jaxopt.LBFGS(gpx.ConjugateMLL(negative=True), maxiter=500),
key=key,
)

Expand Down
9 changes: 2 additions & 7 deletions docs/examples/intro_to_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from jaxtyping import install_import_hook, Float
import matplotlib as mpl
import matplotlib.pyplot as plt
import optax as ox
import jaxopt
import pandas as pd
from docs.examples.utils import clean_legend
Expand Down Expand Up @@ -238,9 +237,7 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
opt_posterior, history = gpx.fit(
model=no_opt_posterior,
train_data=D,
solver=jaxopt.OptaxSolver(
gpx.ConjugateMLL(negative=True), opt=ox.adamw(0.01), maxiter=2000
),
solver=jaxopt.LBFGS(gpx.ConjugateMLL(negative=True), maxiter=500),
safe=True,
key=key,
)
Expand Down Expand Up @@ -540,9 +537,7 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
opt_posterior, history = gpx.fit(
model=posterior,
train_data=D,
solver=jaxopt.OptaxSolver(
gpx.ConjugateMLL(negative=True), opt=ox.adamw(0.01), maxiter=1000
),
solver=jaxopt.LBFGS(gpx.ConjugateMLL(negative=True), maxiter=500),
safe=True,
key=key,
)
Expand Down
3 changes: 1 addition & 2 deletions docs/examples/oceanmodelling.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from matplotlib import rcParams
import matplotlib.pyplot as plt
import jaxopt
import optax as ox
import pandas as pd
import tensorflow_probability as tfp

Expand Down Expand Up @@ -251,7 +250,7 @@ def optimise_mll(posterior, dataset, NIters=1000, key=key, plot_history=True):
opt_posterior, history = gpx.fit(
model=posterior,
train_data=dataset,
solver=jaxopt.OptaxSolver(objective, opt=ox.adam(0.1), maxiter=NIters),
solver=jaxopt.LBFGS(gpx.ConjugateMLL(negative=True), maxiter=NIters),
safe=True,
key=key,
)
Expand Down
3 changes: 1 addition & 2 deletions docs/examples/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from jaxtyping import install_import_hook
import matplotlib as mpl
import matplotlib.pyplot as plt
import optax as ox
import jaxopt
from docs.examples.utils import clean_legend

Expand Down Expand Up @@ -218,7 +217,7 @@
opt_posterior, history = gpx.fit(
model=posterior,
train_data=D,
solver=jaxopt.OptaxSolver(negative_mll, opt=ox.adamw(0.01), maxiter=500),
solver=jaxopt.LBFGS(negative_mll, maxiter=500),
safe=True,
key=key,
)
Expand Down
3 changes: 1 addition & 2 deletions docs/examples/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
)
import matplotlib as mpl
import matplotlib.pyplot as plt
import optax as ox
import jaxopt
import pandas as pd
import planetary_computer
Expand Down Expand Up @@ -191,7 +190,7 @@ def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N 1"]:
posterior, history = gpx.fit(
model=posterior,
train_data=D,
solver=jaxopt.OptaxSolver(negative_mll, opt=optim, maxiter=3000),
solver=jaxopt.LBFGS(gpx.ConjugateMLL(negative=True), maxiter=1000),
safe=True,
key=key,
)
Expand Down
3 changes: 1 addition & 2 deletions docs/examples/yacht.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import optax as ox
import jaxopt
import pandas as pd
from sklearn.metrics import (
Expand Down Expand Up @@ -195,7 +194,7 @@
opt_posterior, history = gpx.fit(
model=posterior,
train_data=training_data,
solver=jaxopt.OptaxSolver(negative_mll, opt=ox.adamw(0.05), maxiter=500),
solver=jaxopt.LBFGS(gpx.ConjugateMLL(negative=True), maxiter=500),
key=key,
)

Expand Down

0 comments on commit eeadba6

Please sign in to comment.