diff --git a/docs/examples/barycentres.py b/docs/examples/barycentres.py
index 25d872d5c..9ab7a2d83 100644
--- a/docs/examples/barycentres.py
+++ b/docs/examples/barycentres.py
@@ -137,11 +137,13 @@ def fit_gp(x: jax.Array, y: jax.Array) -> tfd.MultivariateNormalFullCovariance:
     likelihood = gpx.Gaussian(num_datapoints=n)
     posterior = gpx.Prior(mean_function=gpx.Constant(), kernel=gpx.RBF()) * likelihood
 
-    opt_posterior, _ = gpx.fit_bfgs(
+    opt_posterior, _ = gpx.fit(
         model=posterior,
         objective=jax.jit(gpx.ConjugateMLL(negative=True)),
         train_data=D,
-        max_iters=500,
+        optim=ox.adamw(learning_rate=0.01),
+        num_iters=500,
+        key=key,
     )
     latent_dist = opt_posterior.predict(xtest, train_data=D)
     return opt_posterior.likelihood(latent_dist)
diff --git a/docs/examples/constructing_new_kernels.py b/docs/examples/constructing_new_kernels.py
index dabf66cf6..9355b614f 100644
--- a/docs/examples/constructing_new_kernels.py
+++ b/docs/examples/constructing_new_kernels.py
@@ -267,11 +267,14 @@ def __call__(
 likelihood = gpx.Gaussian(num_datapoints=n)
 circular_posterior = gpx.Prior(mean_function=meanf, kernel=PKern) * likelihood
 
-# Optimise GP's marginal log-likelihood using BFGS
-opt_posterior, history = gpx.fit_bfgs(
+# Optimise GP's marginal log-likelihood using Adam
+opt_posterior, history = gpx.fit(
     model=circular_posterior,
     objective=jit(gpx.ConjugateMLL(negative=True)),
     train_data=D,
+    optim=ox.adamw(learning_rate=0.05),
+    num_iters=500,
+    key=key,
 )
 
 # %% [markdown]
diff --git a/docs/examples/graph_kernels.py b/docs/examples/graph_kernels.py
index 6c1884817..82154b3a4 100644
--- a/docs/examples/graph_kernels.py
+++ b/docs/examples/graph_kernels.py
@@ -153,10 +153,13 @@
 # With a posterior defined, we can now optimise the model's hyperparameters.
 
 # %%
-opt_posterior, training_history = gpx.fit_bfgs(
+opt_posterior, training_history = gpx.fit(
     model=posterior,
     objective=jit(gpx.ConjugateMLL(negative=True)),
     train_data=D,
+    optim=ox.adamw(learning_rate=0.01),
+    num_iters=1000,
+    key=key,
 )
 
 # %% [markdown]
diff --git a/docs/examples/oceanmodelling.py b/docs/examples/oceanmodelling.py
index 4791d8bcc..3488d1640 100644
--- a/docs/examples/oceanmodelling.py
+++ b/docs/examples/oceanmodelling.py
@@ -248,10 +248,14 @@ def optimise_mll(posterior, dataset, NIters=1000, key=key, plot_history=True):
     objective = gpx.objectives.ConjugateMLL(negative=True)
     # Optimise to minimise the MLL
     optimiser = ox.adam(learning_rate=0.1)
-    opt_posterior, history = gpx.fit_bfgs(
+    opt_posterior, history = gpx.fit(
         model=posterior,
         objective=objective,
         train_data=dataset,
+        optim=optimiser,
+        num_iters=NIters,
+        safe=True,
+        key=key,
     )
     # plot MLL value at each iteration
     if plot_history:
diff --git a/docs/examples/regression.py b/docs/examples/regression.py
index 2d83b8536..bccbb8068 100644
--- a/docs/examples/regression.py
+++ b/docs/examples/regression.py
@@ -210,14 +210,18 @@
 # accelerate training.
 
 # %% [markdown]
-# We can now define an optimiser with `jaxopt`. For this example we'll use the `bfgs`
+# We can now define an optimiser with `optax`. For this example we'll use the `adam`
 # optimiser.
 
 # %%
-opt_posterior, history = gpx.fit_bfgs(
+opt_posterior, history = gpx.fit(
     model=posterior,
     objective=negative_mll,
     train_data=D,
+    optim=ox.adam(learning_rate=0.01),
+    num_iters=500,
+    safe=True,
+    key=key,
 )
 
 # %% [markdown]
diff --git a/gpjax/__init__.py b/gpjax/__init__.py
index d68e6685e..eca71e399 100644
--- a/gpjax/__init__.py
+++ b/gpjax/__init__.py
@@ -22,7 +22,7 @@
 )
 from gpjax.citation import cite
 from gpjax.dataset import Dataset
-from gpjax.fit import fit, fit_bfgs
+from gpjax.fit import fit
 from gpjax.gps import (
     Prior,
     construct_posterior,
@@ -87,7 +87,6 @@
     "decision_making",
     "kernels",
     "fit",
-    "fit_bfgs",
     "Prior",
     "construct_posterior",
     "integrators",
diff --git a/gpjax/fit.py b/gpjax/fit.py
index 7d7190252..69b6a699a 100644
--- a/gpjax/fit.py
+++ b/gpjax/fit.py
@@ -23,11 +23,9 @@
     Union,
 )
 import jax
-import jax.numpy as jnp
 from jax._src.random import _check_prng_key
 import jax.random as jr
 import optax as ox
-import jaxopt
 
 from gpjax.base import Module
 from gpjax.dataset import Dataset
@@ -170,100 +168,6 @@ def step(carry, key):
     model = model.constrain()
 
     return model, history
-
-
-
-
-
-def fit_bfgs(  # noqa: PLR0913
-    *,
-    model: ModuleModel,
-    objective: Union[AbstractObjective, Callable[[ModuleModel, Dataset], ScalarFloat]],
-    train_data: Dataset,
-    max_iters: Optional[int] = 500,
-    tol: float = 0.01,
-    verbose: Optional[bool] = True,
-    safe: Optional[bool] = True,
-) -> Tuple[ModuleModel, Array]:
-    r"""Train a Module model with respect to a supplied Objective function.
-    Optimisers used here should originate from Optax. todo
-
-    Args:
-        model (Module): The model Module to be optimised.
-        objective (Objective): The objective function that we are optimising with
-            respect to.
-        train_data (Dataset): The training data to be used for the optimisation.
-        max_iters (Optional[int]): The maximum number of optimisation steps to run. Defaults
-            to 500.
-        tol (Optional[float]): The tolerance for termination. Defaults to scipy default.
-        verbose (Optional[bool]): Whether to print the information about the optimisation. Defaults
-            to True.
-
-    Returns
-    -------
-        Tuple[Module, Array]: A Tuple comprising the optimised model and training
-            history respectively.
-    """
-    if safe:
-        # Check inputs.
-        _check_model(model)
-        _check_train_data(train_data)
-        _check_num_iters(max_iters)
-        _check_verbose(verbose)
-        _check_tol(tol)
-
-
-    # Unconstrained space model.
-    model = model.unconstrain()
-
-    # Unconstrained space loss function with stop-gradient rule for non-trainable params.
-    def loss(model: Module, data: Dataset) -> ScalarFloat:
-        model = model.stop_gradient()
-        return objective(model.constrain(), data)
-
-    solver = jaxopt.BFGS(
-        fun=loss,
-        maxiter=max_iters,
-        tol=tol,
-        # method="L-BFGS-B",
-        #implicit_diff=False,
-    )
-
-    initial_loss = solver.fun(model, train_data)
-    model, result = solver.run(model, data = train_data)
-    history = jnp.array([initial_loss, result.value])
-
-    if verbose:
-        print(f"Initial loss: {initial_loss}")
-        # if result.success:
-        #     print(f"Optimization was successful")
-        # else:
-        #     print(f"Optimization was not successful")
-        print(f"Final loss {result.value} after {result.num_fun_eval} iterations")
-
-    # Constrained space.
-    model = model.constrain()
-    return model, history
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
 
 
 def get_batch(train_data: Dataset, batch_size: int, key: KeyArray) -> Dataset:
@@ -321,14 +225,6 @@ def _check_log_rate(log_rate: Any) -> None:
     if not log_rate > 0:
         raise ValueError("log_rate must be positive")
 
-def _check_tol(tol: Any) -> None:
-    """Check that the tolerance is of type float and positive."""
-    if not isinstance(tol, float):
-        raise TypeError("tol must be of type float or None")
-
-    if not tol > 0:
-        raise ValueError("tol must be positive")
-
 
 def _check_verbose(verbose: Any) -> None:
     """Check that the verbose is of type bool."""
diff --git a/tests/test_fit.py b/tests/test_fit.py
index 73966ae3f..6eeaed992 100644
--- a/tests/test_fit.py
+++ b/tests/test_fit.py
@@ -30,7 +30,6 @@
 from gpjax.dataset import Dataset
 from gpjax.fit import (
     fit,
-    fit_bfgs,
     get_batch,
 )
 from gpjax.gps import (
@@ -98,28 +97,6 @@ def step(self, model: LinearModel, train_data: Dataset) -> float:
     # Test stop_gradient on bias:
     assert trained_model.bias == 1.0
 
-    # Train with bfgs!
-    trained_model, hist = fit_bfgs(
-        model=model,
-        objective=loss,
-        train_data=D,
-        max_iters=10,
-    )
-
-    # Ensure we return a history of the correct length
-    assert len(hist) == 2
-
-    # Ensure we return a model of the same class
-    assert isinstance(trained_model, LinearModel)
-
-    # Test reduction in loss:
-    assert loss(trained_model, D) < loss(model, D)
-
-    # Test stop_gradient on bias:
-    assert trained_model.bias == 1.0
-
-
-
 
 @pytest.mark.parametrize("num_iters", [1, 5])
 @pytest.mark.parametrize("n_data", [1, 20])
@@ -141,7 +118,7 @@ def test_gaussian_process_regression(num_iters, n_data: int, verbose: bool) -> N
     # Define loss function:
     mll = ConjugateMLL(negative=True)
 
-    # Train with optax!
+    # Train!
     trained_model, history = fit(
         model=posterior,
         objective=mll,
@@ -161,28 +138,6 @@ def test_gaussian_process_regression(num_iters, n_data: int, verbose: bool) -> N
     # Ensure we reduce the loss
     assert mll(trained_model, D) < mll(posterior, D)
 
-    # Train with BFGS!
-    trained_model_bfgs, history_bfgs = fit_bfgs(
-        model=posterior,
-        objective=mll,
-        train_data=D,
-        max_iters=num_iters,
-        verbose=verbose,
-    )
-
-    # Ensure the trained model is a Gaussian process posterior
-    assert isinstance(trained_model_bfgs, ConjugatePosterior)
-
-    # Ensure we return a history_bfgs of the correct length
-    assert len(history_bfgs) == 2
-
-    # Ensure we reduce the loss
-    assert mll(trained_model_bfgs, D) < mll(posterior, D)
-
-
-
-
 
 @pytest.mark.parametrize("num_iters", [1, 5])
 @pytest.mark.parametrize("batch_size", [1, 20, 50])
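For callers migrating off `fit_bfgs`, the replacement pattern is the `gpx.fit` call shown in the example diffs above, driven by an Optax optimiser. The following is a minimal, self-contained sketch assembled only from the API surface visible in this patch (`gpx.Dataset`, `gpx.Prior`, `gpx.Constant`, `gpx.RBF`, `gpx.Gaussian`, `gpx.ConjugateMLL`, `gpx.fit`); the toy data, learning rate, and iteration count are illustrative assumptions rather than recommended settings.

```python
# Migration sketch: training a GP posterior with gpx.fit + Optax, as used in the
# updated docs examples. Data and hyperparameter choices here are placeholders.
import jax.numpy as jnp
import jax.random as jr
import optax as ox

import gpjax as gpx

key = jr.PRNGKey(123)

# Toy 1D regression data.
x = jnp.linspace(0.0, 1.0, 50).reshape(-1, 1)
y = jnp.sin(10.0 * x) + 0.1 * jr.normal(key, shape=x.shape)
D = gpx.Dataset(X=x, y=y)

# Prior * likelihood -> posterior, mirroring the docs examples above.
likelihood = gpx.Gaussian(num_datapoints=x.shape[0])
posterior = gpx.Prior(mean_function=gpx.Constant(), kernel=gpx.RBF()) * likelihood

# Where fit_bfgs(..., max_iters=...) was previously called, gpx.fit takes an
# explicit Optax optimiser, an iteration budget, and a PRNG key.
opt_posterior, history = gpx.fit(
    model=posterior,
    objective=gpx.ConjugateMLL(negative=True),
    train_data=D,
    optim=ox.adamw(learning_rate=0.01),  # illustrative optimiser and learning rate
    num_iters=500,                       # illustrative iteration count
    key=key,
)
```

Relative to `fit_bfgs`, the only call-site changes are the explicit `optim`, `num_iters`, and `key` arguments; `history` now records the loss at each iteration, whereas `fit_bfgs` returned only the initial and final loss values.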