Merge pull request #426 from JaxGaussianProcesses/henry/new_optim
Perhaps finally a decent LBFGS?
thomaspinder authored Dec 3, 2023
2 parents 6955957 + 5308794 commit 9e4006c
Showing 4 changed files with 35 additions and 37 deletions.
7 changes: 2 additions & 5 deletions docs/examples/graph_kernels.py
@@ -154,14 +154,11 @@
# With a posterior defined, we can now optimise the model's hyperparameters.

# %%
opt_posterior, training_history = gpx.fit(
opt_posterior, training_history = gpx.fit_scipy(
model=posterior,
objective=gpx.objectives.ConjugateMLL(negative=True),
train_data=D,
optim=ox.adam(learning_rate=0.01),
num_iters=1000,
key=key
)
)

# %% [markdown]
#
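For orientation, here is a minimal standalone sketch of the call pattern this hunk moves the example to. The regression setup (gpx.Dataset, gpx.gps.Prior, gpx.likelihoods.Gaussian, gpx.mean_functions.Zero, gpx.kernels.RBF) is the standard GPJax one and is assumed here rather than taken from this commit; only gpx.fit_scipy and gpx.objectives.ConjugateMLL appear in the diff itself.

```python
# Hedged sketch, not part of the diff: fitting a toy 1D GP regression with the
# new scipy-backed optimiser. No optax optimiser, iteration count, or PRNG key
# is passed any more; the optimisation run is handled by scipy inside fit_scipy.
import jax.numpy as jnp
import jax.random as jr
import gpjax as gpx

key = jr.PRNGKey(123)
x = jnp.linspace(0.0, 3.0, 50).reshape(-1, 1)
y = jnp.sin(2.0 * x) + 0.1 * jr.normal(key, shape=x.shape)
D = gpx.Dataset(X=x, y=y)

prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF())
posterior = prior * gpx.likelihoods.Gaussian(num_datapoints=D.n)

opt_posterior, training_history = gpx.fit_scipy(
    model=posterior,
    objective=gpx.objectives.ConjugateMLL(negative=True),
    train_data=D,
)
```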
54 changes: 28 additions & 26 deletions gpjax/fit.py
@@ -23,11 +23,16 @@
Union,
)
import jax
from jax import (
jit,
value_and_grad,
)
from jax._src.random import _check_prng_key
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
import jax.random as jr
import jaxopt
import optax as ox
import scipy

from gpjax.base import Module
from gpjax.dataset import Dataset
@@ -42,10 +47,6 @@
ModuleModel = TypeVar("ModuleModel", bound=Module)


class FailedScipyFitError(Exception):
"""Raised a model fit using Scipy fails"""


def fit( # noqa: PLR0913
*,
model: ModuleModel,
@@ -214,30 +215,31 @@ def fit_scipy( # noqa: PLR0913
model = model.unconstrain()

# Unconstrained space loss function with stop-gradient rule for non-trainable params.
def loss(model: Module, data: Dataset) -> ScalarFloat:
def loss(model: Module) -> ScalarFloat:
model = model.stop_gradient()
return objective(model.constrain(), data)

solver = jaxopt.ScipyMinimize(
fun=loss,
maxiter=max_iters,
return objective(model.constrain(), train_data)

# convert to numpy for interface with scipy
x0, scipy_to_jnp = ravel_pytree(model)

@jit
def scipy_wrapper(x0):
value, grads = value_and_grad(loss)(scipy_to_jnp(jnp.array(x0)))
scipy_grads = ravel_pytree(grads)[0]
return value, scipy_grads

history = [scipy_wrapper(x0)[0]]
result = scipy.optimize.minimize(
fun=scipy_wrapper,
x0=x0,
jac=True,
callback=lambda X: history.append(scipy_wrapper(X)[0]),
options={"maxiter": max_iters, "disp": verbose},
)
history = jnp.array(history)

initial_loss = solver.fun(model, train_data)
model, result = solver.run(model, data=train_data)
history = jnp.array([initial_loss, result.fun_val])

if verbose:
print(f"Initial loss is {initial_loss}")
if result.success:
print("Optimization was successful")
else:
raise FailedScipyFitError(
"Optimization failed, try increasing max_iters or using a different optimiser."
)
print(f"Final loss is {result.fun_val} after {result.num_fun_eval} iterations")

# Constrained space.
# convert back to pytree and reconstrain
model = scipy_to_jnp(result.x)
model = model.constrain()
return model, history

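The core of the new fit_scipy is the ravel_pytree round trip: flatten the unconstrained model to a single vector for scipy, evaluate a jitted value-and-gradient wrapper, and record the loss through scipy's iteration callback. Below is a self-contained sketch of that pattern on a toy least-squares problem; the parameter pytree, data, and names are illustrative stand-ins, not GPJax's.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
import scipy.optimize

# Toy parameter pytree standing in for a GPJax model's unconstrained leaves.
params = {"w": jnp.zeros(3), "b": jnp.array(0.0)}
X = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
y = jnp.array([1.0, 2.0])

def loss(p):
    return jnp.mean((X @ p["w"] + p["b"] - y) ** 2)

# Flatten the pytree to a single vector for scipy, keeping the inverse map
# (the inverse is called scipy_to_jnp in the diff above, unravel here).
x0, unravel = ravel_pytree(params)

@jax.jit
def scipy_objective(x):
    # scipy hands over a flat vector; map it back to a pytree, then return
    # (value, flat gradient) so that minimize(..., jac=True) can use both.
    value, grads = jax.value_and_grad(loss)(unravel(jnp.asarray(x)))
    return value, ravel_pytree(grads)[0]

history = [scipy_objective(x0)[0]]
result = scipy.optimize.minimize(
    fun=scipy_objective,
    x0=x0,
    jac=True,
    callback=lambda xk: history.append(scipy_objective(xk)[0]),
)
fitted = unravel(result.x)  # back to the original pytree structure
```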
1 change: 0 additions & 1 deletion pyproject.toml
@@ -27,7 +27,6 @@ jax = ">=0.4.10"
jaxlib = ">=0.4.10"
orbax-checkpoint = ">=0.2.3"
cola-ml = "^0.0.5"
jaxopt = "^0.8.2"

[tool.poetry.group.test.dependencies]
pytest = "^7.2.2"
10 changes: 5 additions & 5 deletions tests/test_fit.py
@@ -25,6 +25,7 @@
)
import optax as ox
import pytest
import scipy
import tensorflow_probability.substrates.jax.bijectors as tfb

from gpjax.base import (
@@ -33,7 +34,6 @@
)
from gpjax.dataset import Dataset
from gpjax.fit import (
FailedScipyFitError,
fit,
fit_scipy,
get_batch,
@@ -116,7 +116,7 @@ def step(self, model: LinearModel, train_data: Dataset) -> float:
)

# Ensure we return a history of the correct length
assert len(hist) == 2
assert len(hist) > 2

# Ensure we return a model of the same class
assert isinstance(trained_model, LinearModel)
@@ -180,7 +180,7 @@ def test_gaussian_process_regression(n_data: int, verbose: bool) -> None:
assert isinstance(trained_model_bfgs, ConjugatePosterior)

# Ensure we return a history_bfgs of the correct length
assert len(history_bfgs) == 2
assert len(history_bfgs) > 2

# Ensure we reduce the loss
assert mll(trained_model_bfgs, D) < mll(posterior, D)
@@ -206,7 +206,7 @@ def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N O"]:
# Define loss function:
mll = ConjugateMLL(negative=True)

with pytest.raises(FailedScipyFitError):
with pytest.raises(scipy.optimize.OptimizeWarning):
fit_scipy(
model=posterior,
objective=mll,
@@ -220,7 +220,7 @@ def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N O"]:
posterior = prior * likelihood
mll = ConjugateMLL(negative=True)

with pytest.raises(FailedScipyFitError):
with pytest.raises(scipy.optimize.OptimizeWarning):
fit_scipy(
model=posterior,
objective=mll,
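The history-length assertions change from == 2 to > 2 because the jaxopt path recorded only the initial and final loss, while the new path appends one loss per scipy callback invocation. A toy sketch of that callback mechanism (the Rosenbrock objective here is purely illustrative, not from the tests):

```python
import numpy as np
from scipy.optimize import minimize, rosen

history = [rosen(np.array([1.3, 0.7]))]  # initial loss
minimize(
    rosen,
    x0=np.array([1.3, 0.7]),
    callback=lambda xk: history.append(rosen(xk)),  # one entry per iteration
)
# history now holds the initial loss plus one loss per completed iteration,
# so its length tracks the number of optimisation steps rather than being 2.
```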
