-
Notifications
You must be signed in to change notification settings - Fork 54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Jaxopt 2 #402
Conversation
Some Qs:
|
Thanks @henrymoss this PR looks great to me. In the barycentres, it will actually break on: fit_gp(x, ys[0])
fit_gp(x, ys[1]) Doing It would probably be safer to remove these following lines of code on ...
def __hash__(self):
return hash(tuple(jtu.tree_leaves(self))) # Probably put this on the Module!
def __call__(self, *args, **kwargs) -> ScalarFloat:
return self.step(*args, **kwargs) So that Objectives could still be passed as |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
super-duper. 🍾
Just so we make some progress and we support BFGS in some capacity (we really should!). I have done a quick PR where we have a separate fit function
fit_scipy
that uses scipy (BFGS, or L-BFGS depending on problem size), for those that are keen to do so. This is a very small non-breaking change unlike the other PR which was breaking and involved 10x the code.Note, I also played around with optimistix, but it didnt really make things easier.
@daniel-dodd , by having it as a separate function, I only need to instantiate the jaxopt thingy once and so no immediate horrors pop up.
A couple of interesting observations when I update the notebooks with the scipy optimzier.
One weird thing for @daniel-dodd. In the barycentre notebook, I have issues with a strange error popping up due to calling fit a few times in a for loop. Its entirely incomprehensible.