Skip to content

Commit

Permalink
still broken
Browse files Browse the repository at this point in the history
  • Loading branch information
henrymoss committed Sep 26, 2023
1 parent 0a09d61 commit f46288e
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions gpjax/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from dataclasses import (
asdict,

Check failure on line 16 in gpjax/fit.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

gpjax/fit.py:16:5: F401 `dataclasses.asdict` imported but unused
dataclass,

Check failure on line 17 in gpjax/fit.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

gpjax/fit.py:17:5: F401 `dataclasses.dataclass` imported but unused
)

from beartype.typing import (
Any,
Expand Down Expand Up @@ -131,12 +134,13 @@ def evaluate(self, model: LinearModel, train_data: gpx.Dataset) -> float:

# Initialise solver state.
solver.fun = _wrap_objective(solver.fun)
if hasattr(solver, "options"): # allow __post_init__ without weird jaxopt error
solver.options.pop("maxiter", None)
solver.__post_init__() # needed to propagate changes to `fun` attribute

if isinstance(solver, OptaxSolver): # hack for Optax compatibility
model = jax.tree_map(lambda x: x.astype(jnp.float64), model)
# # elif isinstance(solver, ScipyMinimize): # hack for jaxopt compatibility
# del solver.options["maxiter"]

solver.__post_init__() # needed to propagate changes to `fun` attribute

if isinstance(solver, OptaxSolver): # For optax, run optimization by step
solver_state = solver.init_state(
Expand Down

0 comments on commit f46288e

Please sign in to comment.