Skip to content

Commit

Permalink
u2d
Browse files Browse the repository at this point in the history
  • Loading branch information
henrymoss committed Oct 22, 2023
1 parent 6eca0ec commit 8bbce75
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 18 deletions.
7 changes: 0 additions & 7 deletions gpjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,10 @@
)
from gpjax.citation import cite
from gpjax.dataset import Dataset
<<<<<<< HEAD
from gpjax.fit import (
fit,
fit_scipy,
)
=======
from gpjax.fit import fit
>>>>>>> main
from gpjax.gps import (
Prior,
construct_posterior,
Expand Down Expand Up @@ -94,10 +90,7 @@
"decision_making",
"kernels",
"fit",
<<<<<<< HEAD
"fit_scipy",
=======
>>>>>>> main
"Prior",
"construct_posterior",
"integrators",
Expand Down
3 changes: 0 additions & 3 deletions gpjax/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,7 @@
from jax._src.random import _check_prng_key
import jax.numpy as jnp
import jax.random as jr
<<<<<<< HEAD
import jaxopt
=======
>>>>>>> main
import optax as ox

from gpjax.base import Module
Expand Down
11 changes: 3 additions & 8 deletions tests/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from gpjax.dataset import Dataset
from gpjax.fit import (
fit,
fit_scipy,
get_batch,
)
from gpjax.gps import (
Expand Down Expand Up @@ -97,9 +98,8 @@ def step(self, model: LinearModel, train_data: Dataset) -> float:
# Test stop_gradient on bias:
assert trained_model.bias == 1.0

<<<<<<< HEAD
# Train with bfgs!
trained_model, hist = fit_bfgs(
trained_model, hist = fit_scipy(
model=model,
objective=loss,
train_data=D,
Expand All @@ -118,8 +118,6 @@ def step(self, model: LinearModel, train_data: Dataset) -> float:
# Test stop_gradient on bias:
assert trained_model.bias == 1.0

=======
>>>>>>> main

@pytest.mark.parametrize("num_iters", [1, 5])
@pytest.mark.parametrize("n_data", [1, 20])
Expand Down Expand Up @@ -161,9 +159,8 @@ def test_gaussian_process_regression(num_iters, n_data: int, verbose: bool) -> N
# Ensure we reduce the loss
assert mll(trained_model, D) < mll(posterior, D)

<<<<<<< HEAD
# Train with BFGS!
trained_model_bfgs, history_bfgs = fit_bfgs(
trained_model_bfgs, history_bfgs = fit_scipy(
model=posterior,
objective=mll,
train_data=D,
Expand All @@ -180,8 +177,6 @@ def test_gaussian_process_regression(num_iters, n_data: int, verbose: bool) -> N
# Ensure we reduce the loss
assert mll(trained_model_bfgs, D) < mll(posterior, D)

=======
>>>>>>> main

@pytest.mark.parametrize("num_iters", [1, 5])
@pytest.mark.parametrize("batch_size", [1, 20, 50])
Expand Down

0 comments on commit 8bbce75

Please sign in to comment.