
Model fitting using jaxopt solvers #364

Closed
wants to merge 25 commits
Changes from 12 commits
2 changes: 1 addition & 1 deletion .gitignore
@@ -151,4 +151,4 @@ package.json
package-lock.json
node_modules/

docs/api
docs/api
14 changes: 8 additions & 6 deletions README.md
@@ -95,10 +95,14 @@ from jax import config
config.update("jax_enable_x64", True)

import gpjax as gpx
import jax
from jax import grad, jit
import jax.numpy as jnp
import jax.random as jr
import optax as ox
import jaxopt

jax.config.update("jax_enable_x64", True)

key = jr.PRNGKey(123)

@@ -120,19 +124,17 @@ likelihood = gpx.Gaussian(num_datapoints = n)
# Construct the posterior
posterior = prior * likelihood

# Define an optimiser
optimiser = ox.adam(learning_rate=1e-2)

# Define the marginal log-likelihood
negative_mll = jit(gpx.objectives.ConjugateMLL(negative=True))

# Define a solver
solver = jaxopt.OptaxSolver(negative_mll, ox.adam(learning_rate=1e-2), maxiter=500)

# Obtain Type 2 MLEs of the hyperparameters
opt_posterior, history = gpx.fit(
model=posterior,
objective=negative_mll,
train_data=D,
optim=optimiser,
num_iters=500,
solver=solver,
safe=True,
key=key,
)
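The README example above passes the objective to `jaxopt.OptaxSolver` and hands the solver to `gpx.fit`. As a hypothetical extension, a native jaxopt solver such as `jaxopt.LBFGS` could presumably be supplied through the same `solver` argument; the diff only demonstrates `OptaxSolver`, so treat the following as a sketch under that assumption, reusing `posterior`, `negative_mll`, `D` and `key` from the snippet above.

# Hypothetical sketch: assumes gpx.fit on this branch also accepts native
# jaxopt solvers, which the diff does not demonstrate.
import jaxopt

lbfgs = jaxopt.LBFGS(negative_mll, maxiter=100)

opt_posterior, history = gpx.fit(
    model=posterior,
    train_data=D,
    solver=lbfgs,
    safe=True,
    key=key,
)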
7 changes: 4 additions & 3 deletions docs/examples/barycentres.py
@@ -27,6 +27,7 @@
from jaxtyping import install_import_hook
import matplotlib.pyplot as plt
import optax as ox
import jaxopt
import tensorflow_probability.substrates.jax.distributions as tfd

with install_import_hook("gpjax", "beartype.beartype"):
@@ -139,10 +140,10 @@ def fit_gp(x: jax.Array, y: jax.Array) -> tfd.MultivariateNormalFullCovariance:

opt_posterior, _ = gpx.fit(
model=posterior,
objective=jax.jit(gpx.ConjugateMLL(negative=True)),
train_data=D,
optim=ox.adamw(learning_rate=0.01),
num_iters=500,
solver=jaxopt.OptaxSolver(
gpx.ConjugateMLL(negative=True), opt=ox.adam(0.01), maxiter=500
),
key=key,
)
latent_dist = opt_posterior.predict(xtest, train_data=D)
9 changes: 5 additions & 4 deletions docs/examples/bayesian_optimisation.py
@@ -20,6 +20,7 @@
import matplotlib.pyplot as plt
from matplotlib import cm
import optax as ox
import jaxopt
import tensorflow_probability.substrates.jax as tfp
from typing import List, Tuple

@@ -216,10 +217,10 @@ def return_optimised_posterior(

opt_posterior, history = gpx.fit(
model=posterior,
objective=negative_mll,
train_data=data,
optim=ox.adam(learning_rate=0.01),
num_iters=1000,
train_data=D,
solver=jaxopt.OptaxSolver(
gpx.ConjugateMLL(negative=True), opt=ox.adam(0.01), maxiter=1000
),
safe=True,
key=key,
verbose=False,
779 changes: 0 additions & 779 deletions docs/examples/classification.ipynb

This file was deleted.

9 changes: 3 additions & 6 deletions docs/examples/classification.py
@@ -43,6 +43,7 @@
)
import matplotlib.pyplot as plt
import optax as ox
import jaxopt
import tensorflow_probability.substrates.jax as tfp
from tqdm import trange

@@ -113,19 +114,15 @@

# %% [markdown]
# We can obtain a MAP estimate by optimising the log-posterior density with
# Optax's optimisers.
# `jaxopt` solvers.

# %%
negative_lpd = jax.jit(gpx.LogPosteriorDensity(negative=True))

optimiser = ox.adam(learning_rate=0.01)

opt_posterior, history = gpx.fit(
model=posterior,
objective=negative_lpd,
train_data=D,
optim=ox.adamw(learning_rate=0.01),
num_iters=1000,
solver=jaxopt.OptaxSolver(negative_lpd, opt=ox.adam(0.01), maxiter=1000),
key=key,
)

5 changes: 2 additions & 3 deletions docs/examples/collapsed_vi.py
@@ -38,6 +38,7 @@
import matplotlib as mpl
import matplotlib.pyplot as plt
import optax as ox
import jaxopt
from docs.examples.utils import clean_legend

with install_import_hook("gpjax", "beartype.beartype"):
@@ -155,10 +156,8 @@
# %%
opt_posterior, history = gpx.fit(
model=q,
objective=elbo,
train_data=D,
optim=ox.adamw(learning_rate=1e-2),
num_iters=500,
solver=jaxopt.OptaxSolver(elbo, opt=ox.adamw(1e-2), maxiter=500),
key=key,
)

7 changes: 4 additions & 3 deletions docs/examples/constructing_new_kernels.py
@@ -41,6 +41,7 @@
import matplotlib.pyplot as plt
import numpy as np
import optax as ox
import jaxopt
from simple_pytree import static_field
import tensorflow_probability.substrates.jax as tfp

@@ -270,10 +271,10 @@ def __call__(
# Optimise GP's marginal log-likelihood using Adam
opt_posterior, history = gpx.fit(
model=circular_posterior,
objective=jit(gpx.ConjugateMLL(negative=True)),
train_data=D,
optim=ox.adamw(learning_rate=0.05),
num_iters=500,
solver=jaxopt.OptaxSolver(
gpx.ConjugateMLL(negative=True), opt=ox.adamw(0.05), maxiter=500
),
key=key,
)

11 changes: 6 additions & 5 deletions docs/examples/deep_kernels.py
@@ -33,6 +33,7 @@
import matplotlib as mpl
import matplotlib.pyplot as plt
import optax as ox
import jaxopt
from scipy.signal import sawtooth
from gpjax.base import static_field

@@ -182,8 +183,8 @@ def __call__(self, x):
# hyperparameter set.
#
# With the inclusion of a neural network, we take this opportunity to highlight the
# additional benefits gleaned from using
# [Optax](https://optax.readthedocs.io/en/latest/) for optimisation. In particular, we
# additional benefits gleaned from using `jaxopt`'s
# [Optax](https://optax.readthedocs.io/en/latest/) solver for optimisation. In particular, we
# showcase the ability to use a learning rate scheduler that decays the optimiser's
# learning rate throughout the inference. We decrease the learning rate according to a
# half-cosine curve over 700 iterations, providing us with large step sizes early in
@@ -207,10 +208,10 @@ def __call__(self, x):

opt_posterior, history = gpx.fit(
model=posterior,
objective=jax.jit(gpx.ConjugateMLL(negative=True)),
train_data=D,
optim=optimiser,
num_iters=800,
solver=jaxopt.OptaxSolver(
gpx.ConjugateMLL(negative=True), opt=optimiser, maxiter=800
),
key=key,
)

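The `optimiser` passed to `OptaxSolver` above is driven by the learning-rate schedule described in the prose. A minimal sketch of such a schedule, assuming an illustrative peak learning rate of 0.01 and a 50-step warm-up (the concrete values live in a collapsed part of the diff):

# Sketch of the half-cosine decay described above; the numeric values are
# assumptions, not necessarily those used in deep_kernels.py.
import optax as ox

schedule = ox.warmup_cosine_decay_schedule(
    init_value=0.0,   # start of the warm-up ramp
    peak_value=0.01,  # assumed peak learning rate
    warmup_steps=50,  # assumed warm-up length
    decay_steps=700,  # half-cosine decay over 700 iterations, as stated above
    end_value=0.0,
)
optimiser = ox.adam(learning_rate=schedule)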
9 changes: 5 additions & 4 deletions docs/examples/graph_kernels.py
@@ -23,6 +23,7 @@
import matplotlib.pyplot as plt
import networkx as nx
import optax as ox
import jaxopt

with install_import_hook("gpjax", "beartype.beartype"):
import gpjax as gpx
@@ -132,7 +133,7 @@
# For this reason, we simply perform gradient descent on the GP's marginal
# log-likelihood term as in the
# [regression notebook](https://docs.jaxgaussianprocesses.com/examples/regression/).
# We do this using the Adam optimiser provided in `optax`.
# We do this using the `OptaxSolver` provided by `jaxopt`, instantiated with the Adam optimiser.

# %%
likelihood = gpx.Gaussian(num_datapoints=D.n)
@@ -155,10 +156,10 @@
# %%
opt_posterior, training_history = gpx.fit(
model=posterior,
objective=jit(gpx.ConjugateMLL(negative=True)),
train_data=D,
optim=ox.adamw(learning_rate=0.01),
num_iters=1000,
solver=jaxopt.OptaxSolver(
gpx.ConjugateMLL(negative=True), opt=ox.adamw(0.01), maxiter=1000
),
key=key,
)

19 changes: 7 additions & 12 deletions docs/examples/intro_to_kernels.py
@@ -17,6 +17,7 @@
import matplotlib as mpl
import matplotlib.pyplot as plt
import optax as ox
import jaxopt
import pandas as pd
from docs.examples.utils import clean_legend

@@ -233,16 +234,13 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
# We can then optimise the hyperparameters by minimising the negative log marginal likelihood of the data:

# %%
negative_mll = gpx.objectives.ConjugateMLL(negative=True)
negative_mll(no_opt_posterior, train_data=D)
negative_mll = jit(negative_mll)

opt_posterior, history = gpx.fit(
model=no_opt_posterior,
objective=negative_mll,
train_data=D,
optim=ox.adam(learning_rate=0.01),
num_iters=2000,
solver=jaxopt.OptaxSolver(
gpx.ConjugateMLL(negative=True), opt=ox.adamw(0.01), maxiter=2000
),
safe=True,
key=key,
)
@@ -538,16 +536,13 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
# marginal likelihood of the data:

# %%
negative_mll = gpx.objectives.ConjugateMLL(negative=True)
negative_mll(posterior, train_data=D)
negative_mll = jit(negative_mll)

opt_posterior, history = gpx.fit(
model=posterior,
objective=negative_mll,
train_data=D,
optim=ox.adam(learning_rate=0.01),
num_iters=1000,
solver=jaxopt.OptaxSolver(
gpx.ConjugateMLL(negative=True), opt=ox.adamw(0.01), maxiter=1000
),
safe=True,
key=key,
)
6 changes: 2 additions & 4 deletions docs/examples/oceanmodelling.py
@@ -23,6 +23,7 @@
)
from matplotlib import rcParams
import matplotlib.pyplot as plt
import jaxopt
import optax as ox
import pandas as pd
import tensorflow_probability as tfp
@@ -247,13 +248,10 @@ def optimise_mll(posterior, dataset, NIters=1000, key=key, plot_history=True):
# define the MLL using dataset_train
objective = gpx.objectives.ConjugateMLL(negative=True)
# Optimise to minimise the MLL
optimiser = ox.adam(learning_rate=0.1)
opt_posterior, history = gpx.fit(
model=posterior,
objective=objective,
train_data=dataset,
optim=optimiser,
num_iters=NIters,
solver=jaxopt.OptaxSolver(objective, opt=ox.adam(0.1), maxiter=NIters),
safe=True,
key=key,
)
9 changes: 4 additions & 5 deletions docs/examples/regression.py
@@ -32,6 +32,7 @@
import matplotlib as mpl
import matplotlib.pyplot as plt
import optax as ox
import jaxopt
from docs.examples.utils import clean_legend

with install_import_hook("gpjax", "beartype.beartype"):
@@ -210,16 +211,14 @@
# accelerate training.

# %% [markdown]
# We can now define an optimiser with `optax`. For this example we'll use the `adam`
# optimiser.
# We can now train our model using a `jaxopt` solver. In this case we opt for the `OptaxSolver`,
# which wraps an `optax` optimiser.

# %%
opt_posterior, history = gpx.fit(
model=posterior,
objective=negative_mll,
train_data=D,
optim=ox.adam(learning_rate=0.01),
num_iters=500,
solver=jaxopt.OptaxSolver(negative_mll, opt=ox.adamw(0.01), maxiter=500),
safe=True,
key=key,
)
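For readers new to `jaxopt.OptaxSolver`, here is a self-contained toy illustration of the wrapper outside `gpx.fit`; the quadratic loss and the numbers are purely illustrative.

# jaxopt.OptaxSolver wraps an optax optimiser and exposes run(), which returns
# an OptStep namedtuple of (params, state).
import jax.numpy as jnp
import jaxopt
import optax as ox

def loss(params):
    return jnp.sum((params - 3.0) ** 2)

solver = jaxopt.OptaxSolver(loss, opt=ox.adam(0.1), maxiter=200)
result = solver.run(jnp.zeros(2))
print(result.params)  # approximately [3., 3.]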
5 changes: 2 additions & 3 deletions docs/examples/spatial.py
@@ -55,6 +55,7 @@
import matplotlib as mpl
import matplotlib.pyplot as plt
import optax as ox
import jaxopt
import pandas as pd
import planetary_computer
import pystac_client
@@ -189,10 +190,8 @@ def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N 1"]:
optim = ox.chain(ox.adam(learning_rate=0.1), ox.clip(1.0))
posterior, history = gpx.fit(
model=posterior,
objective=negative_mll,
train_data=D,
optim=optim,
num_iters=3000,
solver=jaxopt.OptaxSolver(negative_mll, opt=optim, maxiter=3000),
safe=True,
key=key,
)
11 changes: 4 additions & 7 deletions docs/examples/uncollapsed_vi.py
@@ -43,6 +43,7 @@
import matplotlib as mpl
import matplotlib.pyplot as plt
import optax as ox
import jaxopt
import tensorflow_probability.substrates.jax as tfp

with install_import_hook("gpjax", "beartype.beartype"):
@@ -228,7 +229,7 @@
# see Sections 3.1 and 4.1 of the excellent review paper
# <strong data-cite="leibfried2020tutorial"></strong>.
#
# Since Optax's optimisers work to minimise functions, to maximise the ELBO we return
# Since `jaxopt`'s solvers work to minimise functions, to maximise the ELBO we return
# its negative.

# %%
@@ -266,10 +267,8 @@

opt_posterior, history = gpx.fit(
model=q,
objective=negative_elbo,
train_data=D,
optim=ox.adam(learning_rate=schedule),
num_iters=3000,
solver=jaxopt.OptaxSolver(negative_elbo, opt=ox.adam(schedule), maxiter=3000),
key=jr.PRNGKey(42),
batch_size=128,
)
@@ -330,10 +329,8 @@
# %%
opt_rep, history = gpx.fit(
model=reparameterised_q,
objective=negative_elbo,
train_data=D,
optim=ox.adam(learning_rate=0.01),
num_iters=3000,
solver=jaxopt.OptaxSolver(negative_elbo, opt=ox.adam(0.01), maxiter=3000),
key=jr.PRNGKey(42),
batch_size=128,
)
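For completeness, a hedged sketch of how the `negative_elbo` objective used in the two `gpx.fit` calls above is typically constructed; its actual definition sits in a collapsed part of the diff, and exposing the ELBO as `gpx.ELBO` is an assumption by analogy with `gpx.ConjugateMLL` elsewhere in this PR.

# Assumed construction: jaxopt solvers minimise, so the ELBO is negated to turn
# maximisation of the ELBO into a minimisation problem.
negative_elbo = gpx.ELBO(negative=True)
solver = jaxopt.OptaxSolver(negative_elbo, opt=ox.adam(learning_rate=0.01), maxiter=3000)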