Model fitting using jaxopt solvers #364
Conversation
Looks good @frazane! A high-level comment on the notebook updates: I see you've updated the code, but can you ensure that the supporting text is aligned? In the regression notebook, for example, I think there should be changes.
gpjax/fit.py (outdated)

```python
if isinstance(solver, jaxopt.OptaxSolver):
    model = jax.tree_map(lambda x: x.astype(jnp.float64), model)

# Initialise solver state.
solver.fun = _wrap_objective(solver.fun)
solver.__post_init__()  # needed to propagate changes to `fun` attribute

solver_state = solver.init_state(
    model,
    get_batch(train_data, batch_size, key) if batch_size != -1 else train_data,
)
jitted_update = jax.jit(solver.update)
```
Should we have any additional unit tests to check this block is running correctly?
It's a good idea, yes. However, I am not convinced by this piece of code (particularly lines 131-132) to start with, and I am open to suggestions. I wonder if, instead of wrapping the objective function inside `fit` like this, we could specify whether to apply constraints and stop gradients when instantiating the objective. Something like:

nmll = gpx.ConjugateMLL(negative=True, stop_gradients=True, constrain_model=True)
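For concreteness, a rough sketch of what such a flag-based objective could look like (the attribute names, and the `model.constrain()` / `model.stop_gradient()` methods, are assumptions for illustration rather than GPJax's actual implementation):

```python
# Hypothetical sketch of the proposal above; not the code in this PR.
from dataclasses import dataclass
from typing import Any, Callable

@dataclass
class FlaggedObjective:
    fn: Callable[[Any, Any], Any]   # the raw objective, e.g. a marginal log-likelihood
    negative: bool = False
    stop_gradients: bool = False
    constrain_model: bool = False

    def __call__(self, model, data):
        # The objective itself decides whether to map parameters to their
        # constrained space and freeze non-trainable leaves, rather than
        # having `fit` wrap the objective behind the user's back.
        if self.constrain_model:
            model = model.constrain()        # assumed model/Module method
        if self.stop_gradients:
            model = model.stop_gradient()    # assumed model/Module method
        value = self.fn(model, data)
        return -value if self.negative else value
```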
Hmm, my concern with this approach is that it demands a lot of the user. If I've passed a bijector or a trainability status to my parameters, then I would not expect to have to explicitly apply this functionality again later on. Thoughts?
I agree with your concerns. For now I'll keep it as is and add a test to check that specific block 👍
I've taken this on so we can merge in @frazane's absence. I've fiddled it so it all works with the new BO code and it seems to pass all the tests locally. I can't work out why it won't pass on the CI though. Please can someone take a quick look at the CI logs and explain them to me @thomaspinder or @daniel-dodd?
Nice work @henrymoss. I think we need to remove the
Have reviewed the code. Other than the aforementioned issue, this PR looks excellent to me - will give an approval ahead of fixing this, feel free to merge as and when the tests pass. Thanks, @henrymoss, @frazane. :)
Force-pushed from eeadba6 to 1d51c54.
@daniel-dodd I actually think something is really wrong with this. I was updating the notebooks to use LBFGS instead of Adam and I couldn't get good performance on the ocean or decision maker notebooks. In the decision maker notebook I was getting weird things, like the plots of the posterior becoming invisible...
Also @daniel-dodd, if you use the new version of Adam (e.g. in the decision maker notebook), model fits seem much worse than with the old version of Adam.
Ah, strange @henrymoss!
@henrymoss fixed it! In the old code, each time we call […]. In the new code, we pass […]
Right then @thomaspinder and @daniel-dodd, I have rejigged this somewhat. When I went to update some of the notebooks to use the jaxopt LBFGS, everything got worse, i.e. it was even worse than Adam! It turns out that the jaxopt LBFGS is its own implementation that lots of people say is a bit naff (similar to the torch and tensorflow versions). I have instead built a wrapper that goes to SciPy (still through jaxopt), but now using SciPy's LBFGS directly. We now have a model fitter that works really, really well, and it has resulted in a lot of the notebooks running much faster. For example, the barycentre model fits take 4 seconds to reach a lower loss than 25 seconds of Adam achieves.

I had to rejig the code a little, because the SciPy optimisers just optimise in one go (rather than step by step like the optax ones). This is actually great, because it means we don't have to specify a […]

I have also strengthened the typing a bit. We now only support the optax and scipy wrappers from jaxopt. Before, the code suggested that we supported all IterativeSolvers from jaxopt, but this was nonsense: we never actually checked this, and the different solvers would likely require weird hacks to get working.
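For anyone following along, this is roughly what "going through jaxopt to SciPy" looks like in isolation (a toy sketch with a dummy objective, not the actual wrapper added in this PR):

```python
import jax.numpy as jnp
import jaxopt

def loss(params):
    # Toy objective standing in for the (negative) marginal log-likelihood.
    return jnp.sum((params - 3.0) ** 2)

# jaxopt.ScipyMinimize defers to scipy.optimize.minimize, so the whole
# optimisation runs in a single call rather than step by step.
solver = jaxopt.ScipyMinimize(fun=loss, method="L-BFGS-B", maxiter=500)
params, state = solver.run(jnp.zeros(5))
```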
Thanks for picking this up @henrymoss. It looks like some of the documentation is failing to build - would you mind looking into this? Once resolved, we should be good to merge.
I can't recreate this error on my machine :( @daniel-dodd reckons it's something to do with Macs...
@henrymoss This is not what I thought earlier. I thought we were passing the same solver through to the posterior, and returning this from […]. Actually, what is happening is side-effect behaviour: when we call […]
JAXopt developer here. JAXopt's LBFGS is getting better and better; we try to fix things as people report issues. The issue is that optimization with a line search in float32 precision is very hard. Using scipy's LBFGS means that float64 precision is used, while JAX uses float32 by default. If people compare JAXopt's LBFGS and scipy's LBFGS, they should compare both with float64. Please do report issues when you encounter them.
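For reference, JAX can be switched to double precision globally, which is what a like-for-like comparison of the two LBFGS implementations would need:

```python
import jax

# JAX defaults to float32; enable 64-bit mode before any arrays are created
# so that both solvers are benchmarked at the same precision.
jax.config.update("jax_enable_x64", True)
```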
Also, if you keep encountering issues in jaxopt, we would be very grateful if you could report them (ideally with a reproducible test 🙏🏼).
I don't know if you have seen this, but Optimistix is a JAX library for nonlinear solvers: root finding, minimisation, fixed points, and least squares. Reasons to use Optimistix rather than JAXopt:
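For context, a minimal sketch of what an equivalent minimisation looks like in Optimistix (API recalled from its documentation; treat the details as indicative rather than exact):

```python
import jax.numpy as jnp
import optimistix as optx

def loss(y, args):
    # Optimistix objectives take (y, args) and return a scalar.
    return jnp.sum((y - 3.0) ** 2)

solver = optx.BFGS(rtol=1e-6, atol=1e-6)
sol = optx.minimise(loss, solver, jnp.zeros(5))
print(sol.value)  # optimised parameters
```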
|
Type of changes

Checklist

- Run `poetry run pre-commit run --all-files --show-diff-on-failure` before committing.

Description

Model training using `fit` can now be called with a much wider choice of optimization algorithms, provided as `jaxopt` solvers. These also include `optax` solvers (which are currently used).
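As a rough illustration of the intended usage (argument names such as `solver=`, and the fact that the objective is handed to the solver's `fun` rather than to `fit` directly, are inferred from the diff above and may differ from the final API):

```python
import jax.numpy as jnp
import jaxopt
import optax
import gpjax as gpx

# A small toy regression dataset.
x = jnp.linspace(0.0, 1.0, 50).reshape(-1, 1)
y = jnp.sin(10.0 * x)
data = gpx.Dataset(X=x, y=y)

# Standard GPJax model construction.
prior = gpx.Prior(mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF())
posterior = prior * gpx.Gaussian(num_datapoints=data.n)

# The objective is attached to the solver, which `fit` then drives.
objective = gpx.ConjugateMLL(negative=True)

solver = jaxopt.OptaxSolver(fun=objective, opt=optax.adam(1e-2), maxiter=500)
# or, for a full-batch SciPy LBFGS run:
# solver = jaxopt.ScipyMinimize(fun=objective, method="L-BFGS-B")

trained_model, history = gpx.fit(model=posterior, train_data=data, solver=solver)
```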