Revert "doesnt seem to be working"
This reverts commit e3ab411.
henrymoss committed Oct 21, 2023
1 parent 09dc153 commit 70838a8
Showing 8 changed files with 26 additions and 160 deletions.
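
In effect, the examples and tests go back to gradient-based training through `gpx.fit` with an Optax optimiser, replacing the short-lived `gpx.fit_bfgs`. As a reference for the hunks below, here is a minimal sketch of the restored call pattern; the toy data, key, and hyperparameter values are illustrative rather than taken from the repository.

import jax.numpy as jnp
import jax.random as jr
import optax as ox

import gpjax as gpx

key = jr.PRNGKey(42)

# Illustrative 1D regression data (not from the repository).
x = jnp.linspace(-3.0, 3.0, 50).reshape(-1, 1)
y = jnp.sin(x) + 0.1 * jr.normal(key, x.shape)
D = gpx.Dataset(X=x, y=y)

# Conjugate posterior: Gaussian likelihood times a GP prior.
likelihood = gpx.Gaussian(num_datapoints=D.n)
posterior = gpx.Prior(mean_function=gpx.Constant(), kernel=gpx.RBF()) * likelihood

# The restored training entry point: gpx.fit driven by an Optax optimiser.
opt_posterior, history = gpx.fit(
    model=posterior,
    objective=gpx.ConjugateMLL(negative=True),
    train_data=D,
    optim=ox.adamw(learning_rate=0.01),
    num_iters=500,
    key=key,
)

The returned `history` records the objective value over the run, which the `oceanmodelling.py` helper below goes on to plot.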
docs/examples/barycentres.py (6 changes: 4 additions & 2 deletions)
@@ -137,11 +137,13 @@ def fit_gp(x: jax.Array, y: jax.Array) -> tfd.MultivariateNormalFullCovariance:
    likelihood = gpx.Gaussian(num_datapoints=n)
    posterior = gpx.Prior(mean_function=gpx.Constant(), kernel=gpx.RBF()) * likelihood

-    opt_posterior, _ = gpx.fit_bfgs(
+    opt_posterior, _ = gpx.fit(
        model=posterior,
        objective=jax.jit(gpx.ConjugateMLL(negative=True)),
        train_data=D,
-        max_iters=500,
+        optim=ox.adamw(learning_rate=0.01),
+        num_iters=500,
+        key=key,
    )
    latent_dist = opt_posterior.predict(xtest, train_data=D)
    return opt_posterior.likelihood(latent_dist)
docs/examples/constructing_new_kernels.py (7 changes: 5 additions & 2 deletions)
@@ -267,11 +267,14 @@ def __call__(
likelihood = gpx.Gaussian(num_datapoints=n)
circular_posterior = gpx.Prior(mean_function=meanf, kernel=PKern) * likelihood

-# Optimise GP's marginal log-likelihood using BFGS
-opt_posterior, history = gpx.fit_bfgs(
+# Optimise GP's marginal log-likelihood using Adam
+opt_posterior, history = gpx.fit(
    model=circular_posterior,
    objective=jit(gpx.ConjugateMLL(negative=True)),
    train_data=D,
+    optim=ox.adamw(learning_rate=0.05),
+    num_iters=500,
+    key=key,
)

# %% [markdown]
docs/examples/graph_kernels.py (5 changes: 4 additions & 1 deletion)
@@ -153,10 +153,13 @@
# With a posterior defined, we can now optimise the model's hyperparameters.

# %%
-opt_posterior, training_history = gpx.fit_bfgs(
+opt_posterior, training_history = gpx.fit(
    model=posterior,
    objective=jit(gpx.ConjugateMLL(negative=True)),
    train_data=D,
+    optim=ox.adamw(learning_rate=0.01),
+    num_iters=1000,
+    key=key,
)

# %% [markdown]
docs/examples/oceanmodelling.py (6 changes: 5 additions & 1 deletion)
@@ -248,10 +248,14 @@ def optimise_mll(posterior, dataset, NIters=1000, key=key, plot_history=True):
    objective = gpx.objectives.ConjugateMLL(negative=True)
    # Optimise to minimise the MLL
    optimiser = ox.adam(learning_rate=0.1)
-    opt_posterior, history = gpx.fit_bfgs(
+    opt_posterior, history = gpx.fit(
        model=posterior,
        objective=objective,
        train_data=dataset,
+        optim=optimiser,
+        num_iters=NIters,
+        safe=True,
+        key=key,
    )
    # plot MLL value at each iteration
    if plot_history:
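
`optimise_mll` plots the MLL trace whenever `plot_history=True`. A sketch of what that plotting step could look like, assuming `history` is the per-iteration array returned by `gpx.fit`; the helper and figure styling are ours, and the actual code in oceanmodelling.py may differ.

import matplotlib.pyplot as plt


def plot_mll_history(history) -> None:
    """Plot the negative MLL recorded at each optimisation step (illustrative)."""
    fig, ax = plt.subplots()
    ax.plot(history)
    ax.set_xlabel("iteration")
    ax.set_ylabel("negative marginal log-likelihood")
    fig.tight_layout()
    plt.show()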
docs/examples/regression.py (8 changes: 6 additions & 2 deletions)
@@ -210,14 +210,18 @@
# accelerate training.

# %% [markdown]
-# We can now define an optimiser with `jaxopt`. For this example we'll use the `bfgs`
+# We can now define an optimiser with `optax`. For this example we'll use the `adam`
# optimiser.

# %%
-opt_posterior, history = gpx.fit_bfgs(
+opt_posterior, history = gpx.fit(
    model=posterior,
    objective=negative_mll,
    train_data=D,
+    optim=ox.adam(learning_rate=0.01),
+    num_iters=500,
+    safe=True,
+    key=key,
)

# %% [markdown]
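
The restored markdown cell points the reader at Optax. Since `gpx.fit` takes an Optax gradient transformation for `optim`, richer optimisers can be slotted in without changing the rest of the call. A hedged sketch using standard Optax combinators; the schedule values here are arbitrary, not values used by the example.

import optax as ox

# Adam driven by an exponentially decaying learning rate, with gradient clipping.
# (Illustrative settings only.)
schedule = ox.exponential_decay(init_value=0.01, transition_steps=100, decay_rate=0.9)
optimiser = ox.chain(
    ox.clip_by_global_norm(1.0),
    ox.adam(learning_rate=schedule),
)

# `optimiser` can then replace `ox.adam(learning_rate=0.01)` in the gpx.fit call above.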
gpjax/__init__.py (3 changes: 1 addition & 2 deletions)
@@ -22,7 +22,7 @@
)
from gpjax.citation import cite
from gpjax.dataset import Dataset
-from gpjax.fit import fit, fit_bfgs
+from gpjax.fit import fit
from gpjax.gps import (
    Prior,
    construct_posterior,
@@ -87,7 +87,6 @@
    "decision_making",
    "kernels",
    "fit",
-    "fit_bfgs",
    "Prior",
    "construct_posterior",
    "integrators",
gpjax/fit.py (104 changes: 0 additions & 104 deletions)
@@ -23,11 +23,9 @@
    Union,
)
import jax
-import jax.numpy as jnp
from jax._src.random import _check_prng_key
import jax.random as jr
import optax as ox
-import jaxopt

from gpjax.base import Module
from gpjax.dataset import Dataset
@@ -170,100 +168,6 @@ def step(carry, key):
    model = model.constrain()

    return model, history
-
-
-
-
-
-def fit_bfgs( # noqa: PLR0913
-    *,
-    model: ModuleModel,
-    objective: Union[AbstractObjective, Callable[[ModuleModel, Dataset], ScalarFloat]],
-    train_data: Dataset,
-    max_iters: Optional[int] = 500,
-    tol: float = 0.01,
-    verbose: Optional[bool] = True,
-    safe: Optional[bool] = True,
-) -> Tuple[ModuleModel, Array]:
-    r"""Train a Module model with respect to a supplied Objective function.
-    Optimisers used here should originate from Optax. todo
-    Args:
-        model (Module): The model Module to be optimised.
-        objective (Objective): The objective function that we are optimising with
-            respect to.
-        train_data (Dataset): The training data to be used for the optimisation.
-        max_iters (Optional[int]): The maximum number of optimisation steps to run. Defaults
-            to 500.
-        tol (Optional[float]): The tolerance for termination. Defaults to scipy default.
-        verbose (Optional[bool]): Whether to print the information about the optimisation. Defaults
-            to True.
-    Returns
-    -------
-        Tuple[Module, Array]: A Tuple comprising the optimised model and training
-            history respectively.
-    """
-    if safe:
-        # Check inputs.
-        _check_model(model)
-        _check_train_data(train_data)
-        _check_num_iters(max_iters)
-        _check_verbose(verbose)
-        _check_tol(tol)
-
-
-    # Unconstrained space model.
-    model = model.unconstrain()
-
-    # Unconstrained space loss function with stop-gradient rule for non-trainable params.
-    def loss(model: Module, data: Dataset) -> ScalarFloat:
-        model = model.stop_gradient()
-        return objective(model.constrain(), data)
-
-    solver = jaxopt.BFGS(
-        fun=loss,
-        maxiter=max_iters,
-        tol=tol,
-        # method="L-BFGS-B",
-        #implicit_diff=False,
-    )
-
-    initial_loss = solver.fun(model, train_data)
-    model, result = solver.run(model, data = train_data)
-    history = jnp.array([initial_loss, result.value])
-
-    if verbose:
-        print(f"Initial loss: {initial_loss}")
-        # if result.success:
-        #     print(f"Optimization was successful")
-        # else:
-        #     print(f"Optimization was not successful")
-        print(f"Final loss {result.value} after {result.num_fun_eval} iterations")
-
-    # Constrained space.
-    model = model.constrain()
-    return model, history
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
def get_batch(train_data: Dataset, batch_size: int, key: KeyArray) -> Dataset:
@@ -321,14 +225,6 @@ def _check_log_rate(log_rate: Any) -> None:
    if not log_rate > 0:
        raise ValueError("log_rate must be positive")

-def _check_tol(tol: Any) -> None:
-    """Check that the tolerance is of type float and positive."""
-    if not isinstance(tol, float):
-        raise TypeError("tol must be of type float or None")
-
-    if not tol > 0:
-        raise ValueError("tol must be positive")
-

def _check_verbose(verbose: Any) -> None:
    """Check that the verbose is of type bool."""
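
Both the retained `fit` and the deleted `fit_bfgs` rely on the same wrapper visible in the hunk above: optimise in unconstrained parameter space, mask non-trainable leaves with `stop_gradient`, and evaluate the objective on the constrained model. A minimal standalone sketch of that wrapper; the helper name is ours, not part of the GPJax API.

from gpjax.base import Module
from gpjax.dataset import Dataset


def make_unconstrained_loss(objective):
    """Wrap `objective` so it can be minimised in unconstrained space (illustrative helper)."""

    def loss(model: Module, data: Dataset):
        # Zero the gradients of leaves flagged as non-trainable.
        model = model.stop_gradient()
        # Map back to constrained space before evaluating the objective.
        return objective(model.constrain(), data)

    return loss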
tests/test_fit.py (47 changes: 1 addition & 46 deletions)
@@ -30,7 +30,6 @@
from gpjax.dataset import Dataset
from gpjax.fit import (
    fit,
-    fit_bfgs,
    get_batch,
)
from gpjax.gps import (
@@ -98,28 +97,6 @@ def step(self, model: LinearModel, train_data: Dataset) -> float:
    # Test stop_gradient on bias:
    assert trained_model.bias == 1.0

-    # Train with bfgs!
-    trained_model, hist = fit_bfgs(
-        model=model,
-        objective=loss,
-        train_data=D,
-        max_iters=10,
-    )
-
-    # Ensure we return a history of the correct length
-    assert len(hist) == 2
-
-    # Ensure we return a model of the same class
-    assert isinstance(trained_model, LinearModel)
-
-    # Test reduction in loss:
-    assert loss(trained_model, D) < loss(model, D)
-
-    # Test stop_gradient on bias:
-    assert trained_model.bias == 1.0
-
-

@pytest.mark.parametrize("num_iters", [1, 5])
@pytest.mark.parametrize("n_data", [1, 20])
@@ -141,7 +118,7 @@ def test_gaussian_process_regression(num_iters, n_data: int, verbose: bool) -> None:
    # Define loss function:
    mll = ConjugateMLL(negative=True)

-    # Train with optax!
+    # Train!
    trained_model, history = fit(
        model=posterior,
        objective=mll,
@@ -161,28 +138,6 @@ def test_gaussian_process_regression(num_iters, n_data: int, verbose: bool) -> None:
    # Ensure we reduce the loss
    assert mll(trained_model, D) < mll(posterior, D)

-    # Train with BFGS!
-    trained_model_bfgs, history_bfgs = fit_bfgs(
-        model=posterior,
-        objective=mll,
-        train_data=D,
-        max_iters=num_iters,
-        verbose=verbose,
-    )
-
-    # Ensure the trained model is a Gaussian process posterior
-    assert isinstance(trained_model_bfgs, ConjugatePosterior)
-
-    # Ensure we return a history_bfgs of the correct length
-    assert len(history_bfgs) == 2
-
-    # Ensure we reduce the loss
-    assert mll(trained_model_bfgs, D) < mll(posterior, D)
-
-
-
-
-

@pytest.mark.parametrize("num_iters", [1, 5])
@pytest.mark.parametrize("batch_size", [1, 20, 50])
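
The surviving tests still exercise mini-batching through `get_batch`, whose signature is visible as context in the `gpjax/fit.py` hunk above. A small sketch of drawing a random mini-batch with it; the dataset and key below are illustrative, not taken from the tests.

import jax.numpy as jnp
import jax.random as jr

from gpjax.dataset import Dataset
from gpjax.fit import get_batch

key = jr.PRNGKey(0)
# Illustrative dataset of 100 points.
D = Dataset(X=jnp.arange(100.0).reshape(-1, 1), y=jnp.zeros((100, 1)))

# Each call with a fresh key draws a fresh random batch of 20 points.
batch = get_batch(D, batch_size=20, key=key)
assert batch.n == 20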
