Model fitting using jaxopt solvers #364

Closed · wants to merge 25 commits

Conversation

@frazane (Contributor) commented Aug 22, 2023

Type of changes

  • Bug fix
  • New feature
  • Documentation / docstrings
  • Tests
  • Other

Checklist

  • I've formatted the new code by running poetry run pre-commit run --all-files --show-diff-on-failure before committing.
  • I've added tests for new code.
  • I've added docstrings for the new code.

Description

Model training using fit can now be called with a much wider choice of optimisation algorithms, provided as jaxopt solvers. These also include optax solvers (which are what fit currently uses). A rough usage sketch is below.
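
A rough usage sketch (illustrative only: the solver argument to fit is an assumption about this PR's interface, not copied from the diff; the rest follows the public GPJax API):

import jax.numpy as jnp
import jax.random as jr
import jaxopt
import optax
import gpjax as gpx

key = jr.PRNGKey(42)
x = jnp.linspace(0.0, 1.0, 50).reshape(-1, 1)
y = jnp.sin(10.0 * x) + 0.1 * jr.normal(key, x.shape)
D = gpx.Dataset(X=x, y=y)

prior = gpx.Prior(mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF())
posterior = prior * gpx.Gaussian(num_datapoints=D.n)

# Any supported jaxopt solver can be passed; here an optax optimiser wrapped in
# jaxopt's OptaxSolver, with the negative marginal log-likelihood as the objective.
solver = jaxopt.OptaxSolver(
    fun=gpx.ConjugateMLL(negative=True), opt=optax.adam(1e-2), maxiter=500
)

trained_posterior, history = gpx.fit(model=posterior, train_data=D, solver=solver)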

@frazane frazane added the enhancement New feature or request label Aug 22, 2023
@frazane frazane requested a review from henrymoss August 22, 2023 10:42
@frazane frazane self-assigned this Aug 22, 2023
@frazane frazane marked this pull request as ready for review August 23, 2023 10:16
@frazane frazane requested a review from thomaspinder August 23, 2023 11:31
@thomaspinder (Collaborator) left a comment:

Looks good @frazane! A high-level comment on the notebook updates: I see you've updated the code, but can you ensure that the supporting text is aligned? In the regression notebook, for example, I think there should be changes.

gpjax/fit.py (outdated), comment on lines 127 to 138:
if isinstance(solver, jaxopt.OptaxSolver):
    model = jax.tree_map(lambda x: x.astype(jnp.float64), model)

# Initialise solver state.
solver.fun = _wrap_objective(solver.fun)
solver.__post_init__()  # needed to propagate changes to `fun` attribute

solver_state = solver.init_state(
    model,
    get_batch(train_data, batch_size, key) if batch_size != -1 else train_data,
)
jitted_update = jax.jit(solver.update)
Collaborator:

Should we have any additional unit tests to check this block is running correctly?

Contributor (author):

It's a good idea, yes. However, I am not convinced about this piece of code (particularly lines 131-132) to start with, and I am open to suggestions.

I wonder if, instead of wrapping the objective function inside fit like this, we could specify whether to apply constraints and stop gradients when instantiating the objective. Something like:

nmll = gpx.ConjugateMLL(negative=True, stop_gradients=True, constrain_model=True)

Collaborator:

Hmm, my concern with this approach is that it demands a lot of the user. If I've passed a bijector or a trainability status to my parameters, then I would not expect to have to explicitly apply this functionality again later on. Thoughts?

Contributor (author):

I agree with your concerns. For now I'll keep it as is and add a test to check that specific block 👍

@frazane frazane removed the request for review from henrymoss August 24, 2023 08:49
@henrymoss (Contributor):

I've taken this on so we can merge in @frazane's absence.

I've fiddled it so it all works with the new BO code and it seems to pass all the tests locally.

I can't work out why it won't pass on the CI though. Please can someone take a quick look at the CI logs and explain them to me, @thomaspinder or @daniel-dodd?

@daniel-dodd (Member) commented Sep 18, 2023

Nice work @henrymoss.

I think we need to remove the plum dependency! (It conflicts with cola's plum.) Doing that should (🤞) fix the issue!

@daniel-dodd (Member) left a comment:

I have reviewed the code. Other than the aforementioned issue, this PR looks excellent to me. I will give an approval ahead of the fix; feel free to merge as and when the tests pass. Thanks, @henrymoss, @frazane. :)

@henrymoss (Contributor):

@daniel-dodd I actually think something is really wrong with this.

I was updating the notebooks to use LBFGS instead of Adam, and I couldn't get good performance on the ocean or decision-maker notebooks.

In the decision-maker notebook, I was getting weird things like the plots of the posterior becoming invisible...

@henrymoss (Contributor):

Also @daniel-dodd, if you use the new version of Adam (e.g. in the decision-maker notebook), model fits seem much worse than with the old version of Adam.

@daniel-dodd (Member):

Ah strange @henrymoss!

@daniel-dodd (Member) commented Sep 19, 2023

@henrymoss fixed it!

In the old code, each time we call fit we create an (optax) optimiser state; that state only exists within the scope of fit and is discarded after optimisation.

In the new code, we pass a solver. The (optax) optimiser state is stored on the solver, so it was being carried over across data acquisitions, whereas we need to discard/refresh it, since the data changes between acquisitions.
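
To make the refresh concrete, here is a minimal, self-contained sketch (toy quadratic objective and a hypothetical acquisition loop, not the gpjax code) of re-initialising the solver state whenever the data changes:

import jax.numpy as jnp
import jaxopt
import optax

def loss(params, data):
    # toy quadratic objective standing in for the marginal log-likelihood
    return jnp.sum((data - params) ** 2)

solver = jaxopt.OptaxSolver(fun=loss, opt=optax.adam(0.1), maxiter=100)
params = jnp.zeros(3)

for data in (jnp.ones(3), 2.0 * jnp.ones(3)):  # stand-in for changing datasets
    state = solver.init_state(params, data)    # fresh optimiser state per acquisition
    for _ in range(solver.maxiter):
        params, state = solver.update(params, state, data)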

@henrymoss (Contributor):

Right then, @thomaspinder and @daniel-dodd. I have rejigged this somewhat. When I went to update some of the notebooks to use the jaxopt LBFGS, everything got worse, i.e. it was even worse than Adam!

It turns out that the jaxopt LBFGS is its own implementation, which lots of people say is a bit naff (similar to the torch and tensorflow versions). I have instead built a wrapper that goes to SciPy (still through jaxopt), so we now use SciPy's LBFGS directly.

We now have a model fitter that works really well and has made a lot of the notebooks run much faster. For example, the barycentre model fits take 4 seconds to reach a lower loss than Adam reaches after 25 seconds.

I had to rejig the code a little, because the SciPy optimisers optimise in one go (rather than step by step like the optax ones). This is actually great, because it means we don't have to specify a maxiter, which, to me, is one of the advantages of LBFGS. A rough standalone sketch of the SciPy route is below.

I have also strengthened the typing a bit. We now only support the optax and scipy wrappers from jaxopt. Before, the code suggested that we supported all IterativeSolvers from jaxopt, but this was nonsense: we never actually checked it, and the different solvers would likely require weird hacks to get working.
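
For reference, a minimal standalone example of the SciPy route through jaxopt (toy objective only; this is not the gpjax wrapper itself). jaxopt.ScipyMinimize dispatches to scipy.optimize.minimize and runs the whole optimisation in a single .run() call:

import jax.numpy as jnp
import jaxopt

def loss(params):
    # toy objective standing in for the negative marginal log-likelihood
    return jnp.sum((params - 3.0) ** 2)

solver = jaxopt.ScipyMinimize(fun=loss, method="L-BFGS-B")
result = solver.run(jnp.zeros(2))  # one-shot optimisation, no manual update loop
print(result.params)               # expected to be close to [3., 3.]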

@thomaspinder (Collaborator):

Thanks for picking this up @henrymoss. It looks like some of the documentation is failing to build. Would you mind looking into this? Once resolved, we should be good to merge.

@henrymoss (Contributor):

> Thanks for picking this up @henrymoss. It looks like some of the documentation is failing to build. Would you mind looking into this? Once resolved, we should be good to merge.

I can't recreate this error on my machine :( @daniel-dodd reckons it's something to do with Macs...

@daniel-dodd (Member):

@henrymoss This is not what I thought earlier. I thought we were passing the same solver through to the posterior and returning it from fit.

What is actually happening is side-effect behaviour: when we call fit, the solver gets modified outside the function. This is a bug.
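
To illustrate the side effect (toy objective and hypothetical helper names; the copy-based variant is only a sketch of one possible remedy, not the change made in this PR):

import copy
import jax.numpy as jnp
import jaxopt
import optax

def loss(params):
    return jnp.sum(params ** 2)

solver = jaxopt.OptaxSolver(fun=loss, opt=optax.adam(1e-2))

def fit_with_side_effect(solver):
    # stands in for the objective wrapping done inside fit: the caller's solver
    # object is mutated, so its `fun` stays wrapped after fit returns
    solver.fun = lambda p: 2.0 * loss(p)
    solver.__post_init__()  # re-run setup so the new `fun` is picked up
    return solver

def fit_without_side_effect(solver):
    solver = copy.copy(solver)  # shallow copy; the caller's solver is left untouched
    solver.fun = lambda p: 2.0 * loss(p)
    solver.__post_init__()
    return solver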

@mblondel commented Oct 4, 2023

JAXopt developer here. JAXopt's LBFGS is getting better and better. We try to fix things as people report issues. The issue is that optimization with linesearch in float32 precision is very hard. Using scipy's LBFGS means that float64 precision is used, while JAX uses float32 by default. If people compare JAXopt's LBFGS and scipy's LBFGS, they should compare both with float64. Please do report issues when you encounter them.
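
A minimal sketch of the comparison suggested above: enable float64 before running anything, then compare jaxopt's LBFGS against SciPy's (wrapped by jaxopt) on the same objective.

import jax
jax.config.update("jax_enable_x64", True)  # JAX defaults to float32

import jax.numpy as jnp
import jaxopt

def rosenbrock(p):
    x, y = p
    return (1.0 - x) ** 2 + 100.0 * (y - x**2) ** 2

init = jnp.array([-1.2, 1.0])
lbfgs_params, _ = jaxopt.LBFGS(fun=rosenbrock, maxiter=500).run(init)
scipy_params, _ = jaxopt.ScipyMinimize(fun=rosenbrock, method="L-BFGS-B").run(init)
# With x64 enabled, both should converge to (1.0, 1.0).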

@fabianp commented Oct 4, 2023

Also, if you keep encountering issues in jaxopt, we would be very grateful if you can report them (ideally with a reproducible test 🙏🏼 )

@adam-hartshorne commented Oct 17, 2023

I don't know if you have seen this:

Optimistix is a JAX library for nonlinear solvers: root finding, minimisation, fixed points, and least squares.
https://docs.kidger.site/optimistix/

Reasons to use Optimistix rather than JAXopt:

  • Optimistix is much faster to compile, and faster to run.
  • Optimistix supports some solvers not found in JAXopt (e.g. optimistix.Newton for root-finding problems).
  • Optimistix's APIs will integrate more cleanly with the scientific ecosystem being built up around Equinox.
  • Optimistix is much more flexible for advanced use-cases, see e.g. the way we can mix-and-match different optimisers.

https://docs.kidger.site/optimistix/faq/

@daniel-dodd daniel-dodd closed this Nov 7, 2023
@frazane frazane deleted the jaxopt_fit branch November 15, 2023 08:04
Labels: enhancement (New feature or request)
7 participants