From 70838a8282a81dd2da71f42288fc4d861a7e2a16 Mon Sep 17 00:00:00 2001
From: hmoss <32096840+henrymoss@users.noreply.github.com>
Date: Sat, 21 Oct 2023 13:21:03 +0100
Subject: [PATCH] Revert "doesnt seem to be working"

This reverts commit e3ab411965c186dc8d0cdbd736e3f4d8ae25a9b1.
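
The reverted commit added a `fit_bfgs` training routine built on jaxopt's BFGS
solver, which did not appear to be working as intended. This revert removes
`fit_bfgs` from `gpjax/fit.py`, the public API in `gpjax/__init__.py`, and the
tests, and restores the optax-based `gpx.fit` calls in the example notebooks.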
---
 docs/examples/barycentres.py              |   6 +-
 docs/examples/constructing_new_kernels.py |   7 +-
 docs/examples/graph_kernels.py            |   5 +-
 docs/examples/oceanmodelling.py           |   6 +-
 docs/examples/regression.py               |   8 +-
 gpjax/__init__.py                         |   3 +-
 gpjax/fit.py                              | 104 ----------------------
 tests/test_fit.py                         |  47 +---------
 8 files changed, 26 insertions(+), 160 deletions(-)
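
For reference, the optax-based training path that these examples return to looks
roughly like the sketch below. The synthetic dataset, key handling, and
hyperparameter values are illustrative placeholders only; the GPJax and optax
calls mirror those used in the diffs that follow.

    import gpjax as gpx
    import jax.numpy as jnp
    import jax.random as jr
    import optax as ox

    # Toy regression data (illustrative only).
    n = 50
    key = jr.PRNGKey(123)
    x = jnp.sort(jr.uniform(key, shape=(n, 1)), axis=0)
    y = jnp.sin(3.0 * x) + 0.1 * jr.normal(key, shape=(n, 1))
    D = gpx.Dataset(X=x, y=y)

    # Prior x likelihood -> posterior, as in the examples below.
    prior = gpx.Prior(mean_function=gpx.Constant(), kernel=gpx.RBF())
    likelihood = gpx.Gaussian(num_datapoints=n)
    posterior = prior * likelihood

    # Optax-driven optimisation of the (negative) marginal log-likelihood.
    opt_posterior, history = gpx.fit(
        model=posterior,
        objective=gpx.ConjugateMLL(negative=True),
        train_data=D,
        optim=ox.adam(learning_rate=0.01),
        num_iters=500,
        key=key,
    )

As in `gpjax/fit.py` below, `fit` returns the optimised model (back in
constrained space) together with the loss history.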

diff --git a/docs/examples/barycentres.py b/docs/examples/barycentres.py
index 25d872d5c..9ab7a2d83 100644
--- a/docs/examples/barycentres.py
+++ b/docs/examples/barycentres.py
@@ -137,11 +137,13 @@ def fit_gp(x: jax.Array, y: jax.Array) -> tfd.MultivariateNormalFullCovariance:
     likelihood = gpx.Gaussian(num_datapoints=n)
     posterior = gpx.Prior(mean_function=gpx.Constant(), kernel=gpx.RBF()) * likelihood
 
-    opt_posterior, _ = gpx.fit_bfgs(
+    opt_posterior, _ = gpx.fit(
         model=posterior,
         objective=jax.jit(gpx.ConjugateMLL(negative=True)),
         train_data=D,
-        max_iters=500,
+        optim=ox.adamw(learning_rate=0.01),
+        num_iters=500,
+        key=key,
     )
     latent_dist = opt_posterior.predict(xtest, train_data=D)
     return opt_posterior.likelihood(latent_dist)
diff --git a/docs/examples/constructing_new_kernels.py b/docs/examples/constructing_new_kernels.py
index dabf66cf6..9355b614f 100644
--- a/docs/examples/constructing_new_kernels.py
+++ b/docs/examples/constructing_new_kernels.py
@@ -267,11 +267,14 @@ def __call__(
 likelihood = gpx.Gaussian(num_datapoints=n)
 circular_posterior = gpx.Prior(mean_function=meanf, kernel=PKern) * likelihood
 
-# Optimise GP's marginal log-likelihood using BFGS
-opt_posterior, history = gpx.fit_bfgs(
+# Optimise GP's marginal log-likelihood using AdamW
+opt_posterior, history = gpx.fit(
     model=circular_posterior,
     objective=jit(gpx.ConjugateMLL(negative=True)),
     train_data=D,
+    optim=ox.adamw(learning_rate=0.05),
+    num_iters=500,
+    key=key,
 )
 
 # %% [markdown]
diff --git a/docs/examples/graph_kernels.py b/docs/examples/graph_kernels.py
index 6c1884817..82154b3a4 100644
--- a/docs/examples/graph_kernels.py
+++ b/docs/examples/graph_kernels.py
@@ -153,10 +153,13 @@
 # With a posterior defined, we can now optimise the model's hyperparameters.
 
 # %%
-opt_posterior, training_history = gpx.fit_bfgs(
+opt_posterior, training_history = gpx.fit(
     model=posterior,
     objective=jit(gpx.ConjugateMLL(negative=True)),
     train_data=D,
+    optim=ox.adamw(learning_rate=0.01),
+    num_iters=1000,
+    key=key,
 )
 
 # %% [markdown]
diff --git a/docs/examples/oceanmodelling.py b/docs/examples/oceanmodelling.py
index 4791d8bcc..3488d1640 100644
--- a/docs/examples/oceanmodelling.py
+++ b/docs/examples/oceanmodelling.py
@@ -248,10 +248,14 @@ def optimise_mll(posterior, dataset, NIters=1000, key=key, plot_history=True):
     objective = gpx.objectives.ConjugateMLL(negative=True)
     # Optimise to minimise the MLL
     optimiser = ox.adam(learning_rate=0.1)
-    opt_posterior, history = gpx.fit_bfgs(
+    opt_posterior, history = gpx.fit(
         model=posterior,
         objective=objective,
         train_data=dataset,
+        optim=optimiser,
+        num_iters=NIters,
+        safe=True,
+        key=key,
     )
     # plot MLL value at each iteration
     if plot_history:
diff --git a/docs/examples/regression.py b/docs/examples/regression.py
index 2d83b8536..bccbb8068 100644
--- a/docs/examples/regression.py
+++ b/docs/examples/regression.py
@@ -210,14 +210,18 @@
 # accelerate training.
 
 # %% [markdown]
-# We can now define an optimiser with `jaxopt`. For this example we'll use the `bfgs`
+# We can now define an optimiser with `optax`. For this example we'll use the `adam`
 # optimiser.
 
 # %%
-opt_posterior, history = gpx.fit_bfgs(
+opt_posterior, history = gpx.fit(
     model=posterior,
     objective=negative_mll,
     train_data=D,
+    optim=ox.adam(learning_rate=0.01),
+    num_iters=500,
+    safe=True,
+    key=key,
 )
 
 # %% [markdown]
diff --git a/gpjax/__init__.py b/gpjax/__init__.py
index d68e6685e..eca71e399 100644
--- a/gpjax/__init__.py
+++ b/gpjax/__init__.py
@@ -22,7 +22,7 @@
 )
 from gpjax.citation import cite
 from gpjax.dataset import Dataset
-from gpjax.fit import fit, fit_bfgs
+from gpjax.fit import fit
 from gpjax.gps import (
     Prior,
     construct_posterior,
@@ -87,7 +87,6 @@
     "decision_making",
     "kernels",
     "fit",
-    "fit_bfgs",
     "Prior",
     "construct_posterior",
     "integrators",
diff --git a/gpjax/fit.py b/gpjax/fit.py
index 7d7190252..69b6a699a 100644
--- a/gpjax/fit.py
+++ b/gpjax/fit.py
@@ -23,11 +23,9 @@
     Union,
 )
 import jax
-import jax.numpy as jnp
 from jax._src.random import _check_prng_key
 import jax.random as jr
 import optax as ox
-import jaxopt
 
 from gpjax.base import Module
 from gpjax.dataset import Dataset
@@ -170,100 +168,6 @@ def step(carry, key):
     model = model.constrain()
 
     return model, history
-
-
-
-
-
-def fit_bfgs(  # noqa: PLR0913
-    *,
-    model: ModuleModel,
-    objective: Union[AbstractObjective, Callable[[ModuleModel, Dataset], ScalarFloat]],
-    train_data: Dataset,
-    max_iters: Optional[int] = 500,
-    tol: float = 0.01,
-    verbose: Optional[bool] = True,
-    safe: Optional[bool] = True,
-) -> Tuple[ModuleModel, Array]:
-    r"""Train a Module model with respect to a supplied Objective function.
-    Optimisers used here should originate from Optax. todo
-
-    Args:
-        model (Module): The model Module to be optimised.
-        objective (Objective): The objective function that we are optimising with
-            respect to.
-        train_data (Dataset): The training data to be used for the optimisation.
-        max_iters (Optional[int]): The maximum number of optimisation steps to run. Defaults
-            to 500.
-        tol (Optional[float]): The tolerance for termination. Defaults to scipy default.
-        verbose (Optional[bool]): Whether to print the information about the optimisation. Defaults
-            to True.
-
-    Returns
-    -------
-        Tuple[Module, Array]: A Tuple comprising the optimised model and training
-            history respectively.
-    """
-    if safe:
-        # Check inputs.
-        _check_model(model)
-        _check_train_data(train_data)
-        _check_num_iters(max_iters)
-        _check_verbose(verbose)
-        _check_tol(tol)
-
-
-    # Unconstrained space model.
-    model = model.unconstrain()
-
-    # Unconstrained space loss function with stop-gradient rule for non-trainable params.
-    def loss(model: Module, data: Dataset) -> ScalarFloat:
-        model = model.stop_gradient()
-        return objective(model.constrain(), data)
-
-    solver = jaxopt.BFGS(
-        fun=loss, 
-        maxiter=max_iters, 
-        tol=tol, 
-        # method="L-BFGS-B",
-        #implicit_diff=False,
-        )
-
-    initial_loss = solver.fun(model, train_data)
-    model, result = solver.run(model, data = train_data)
-    history = jnp.array([initial_loss, result.value])
-    
-    if verbose:
-        print(f"Initial loss: {initial_loss}")
-        # if result.success:
-        #     print(f"Optimization was successful")
-        # else:
-        #     print(f"Optimization was not successful")
-        print(f"Final loss {result.value} after {result.num_fun_eval} iterations")
-        
-    # Constrained space.
-    model = model.constrain()
-    return model, history
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
 
 
 def get_batch(train_data: Dataset, batch_size: int, key: KeyArray) -> Dataset:
@@ -321,14 +225,6 @@ def _check_log_rate(log_rate: Any) -> None:
     if not log_rate > 0:
         raise ValueError("log_rate must be positive")
 
-def _check_tol(tol: Any) -> None:
-    """Check that the tolerance is of type float and positive."""
-    if not isinstance(tol, float):
-        raise TypeError("tol must be of type float or None")
-    
-    if not tol > 0:
-        raise ValueError("tol must be positive")
-
 
 def _check_verbose(verbose: Any) -> None:
     """Check that the verbose is of type bool."""
diff --git a/tests/test_fit.py b/tests/test_fit.py
index 73966ae3f..6eeaed992 100644
--- a/tests/test_fit.py
+++ b/tests/test_fit.py
@@ -30,7 +30,6 @@
 from gpjax.dataset import Dataset
 from gpjax.fit import (
     fit,
-    fit_bfgs,
     get_batch,
 )
 from gpjax.gps import (
@@ -98,28 +97,6 @@ def step(self, model: LinearModel, train_data: Dataset) -> float:
     # Test stop_gradient on bias:
     assert trained_model.bias == 1.0
 
-    # Train with bfgs!
-    trained_model, hist = fit_bfgs(
-        model=model,
-        objective=loss,
-        train_data=D,
-        max_iters=10,
-    )
-
-    # Ensure we return a history of the correct length
-    assert len(hist) == 2
-
-    # Ensure we return a model of the same class
-    assert isinstance(trained_model, LinearModel)
-
-    # Test reduction in loss:
-    assert loss(trained_model, D) < loss(model, D)
-
-    # Test stop_gradient on bias:
-    assert trained_model.bias == 1.0
-
-
-
 
 @pytest.mark.parametrize("num_iters", [1, 5])
 @pytest.mark.parametrize("n_data", [1, 20])
@@ -141,7 +118,7 @@ def test_gaussian_process_regression(num_iters, n_data: int, verbose: bool) -> N
     # Define loss function:
     mll = ConjugateMLL(negative=True)
 
-    # Train with optax!
+    # Train!
     trained_model, history = fit(
         model=posterior,
         objective=mll,
@@ -161,28 +138,6 @@ def test_gaussian_process_regression(num_iters, n_data: int, verbose: bool) -> N
     # Ensure we reduce the loss
     assert mll(trained_model, D) < mll(posterior, D)
 
-    # Train with BFGS!
-    trained_model_bfgs, history_bfgs = fit_bfgs(
-        model=posterior,
-        objective=mll,
-        train_data=D,
-        max_iters=num_iters,
-        verbose=verbose,
-    )
-
-    # Ensure the trained model is a Gaussian process posterior
-    assert isinstance(trained_model_bfgs, ConjugatePosterior)
-
-    # Ensure we return a history_bfgs of the correct length
-    assert len(history_bfgs) == 2
-
-    # Ensure we reduce the loss
-    assert mll(trained_model_bfgs, D) < mll(posterior, D)
-
-
-
-
-
 
 @pytest.mark.parametrize("num_iters", [1, 5])
 @pytest.mark.parametrize("batch_size", [1, 20, 50])