Add JAX-based find_MAP function #385
base: main
Conversation
Hey, nice one, yeah I agree, we should only have one
The current behaviour of
The behaviour when you only pass a subset of variables isn't really desirable in my opinion (see #345 (comment)), so we added a warning. So as you say:
Agree, that's the best plan. Judging by your docs and a quick glance at your code, I think you're basically doing the same thing. The current implementation is a few lines of code and a few docs, so I reckon
Then it should be safe to delete the existing code and we can go back to one
I would love a generic optimiser in pure pytensor, but I can see looking at your code that there are a lot of fancy extras that would take a large effort to write in pytensor. Still, if we want to go back to one of our efforts with a fixed point operator (pymc-devs/pytensor#978 and pymc-devs/pytensor#944), we could probably write
Happy to look at your code and review properly later in the week if you'd like me to. Let me know. Otherwise, I'll leave it to the core devs.
out_shapes[rv] = tuple(
    [len(model.coords[dim]) for dim in model.named_vars_to_dims.get(rv.name, [])]
)
What if an rv doesn't have dims specified?
is there already a helper in pymc to figure out shapes of an rv?
If you froze the dims you should have everything in rv.type.shape. Otherwise we usually use model.initial_point.
    If None, no jitter is added. Default is 1e-8.
progressbar : bool
    Whether or not to display progress bar. Default is True.
mode : str
No option is being given besides JAX
Obviously because it's all you need : )
The posterior distribution will be approximated as a Gaussian
distribution centered at the posterior mode.
The covariance is the inverse of the negative Hessian matrix of
the log-posterior evaluated at the mode.
? Less ambiguous since there is a mode argument
- the log-posterior evaluated at the mode.
+ the log-posterior evaluated at the optimized_point.
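The docstring's recipe (covariance = inverse of the negative-log-posterior Hessian at the mode) can be checked on a toy problem. This sketch uses plain scipy/numpy rather than the PR's JAX code, with a quadratic log-posterior whose exact variance is known:

```python
import numpy as np
from scipy.optimize import minimize

# Toy negative log-posterior (up to a constant): Normal(2.0, 0.5),
# so the exact posterior variance is 0.5**2 = 0.25
def neg_logp(x):
    return 0.5 * ((x[0] - 2.0) / 0.5) ** 2

res = minimize(neg_logp, x0=np.array([0.0]))
mode = res.x[0]

# Hessian of the negative log-posterior at the mode, via central
# finite differences; the Laplace covariance is its inverse
eps = 1e-4
hess = (neg_logp([mode + eps]) - 2 * neg_logp([mode]) + neg_logp([mode - eps])) / eps**2
cov = 1.0 / hess
```

For this quadratic target the recovered `mode` is 2.0 and `cov` is 0.25, matching the true posterior exactly; for a non-Gaussian posterior the result is only a local Gaussian approximation.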
That would be appreciated
Agree with what @theorashid said. This
No objections about your custom library wrapper
This PR adds code to run find_MAP using JAX. I'm using JAX for gradients, because I found the compile times were faster. Open to suggestions/rebuke.

It also adds a fit_laplace function, which is bad because we already have a fit_laplace function. This one has a slightly different objective though -- it isn't meant to be used as a step sampler on a subset of model variables. Instead, it is meant to be used on the MAP result to give an approximation to the full posterior. My function also lets you do the Laplace approximation in the transformed space, then do sample-wise reverse transformation. I think this is legit, and lets you obtain approximate posteriors that respect the domain of the prior. Tagging @theorashid so we can resolve the differences.

Last point is that I added a dependency on better_optimize. This is a package I wrote that basically rips out the wrapper code used in PyMC's find_MAP and applies it to arbitrary optimization problems. It is more feature complete than the PyMC wrapper -- it supports all optimizer modes for scipy.optimize.minimize and scipy.optimize.root, and also helps get keywords to the right place in those functions (who can ever remember if an argument goes in method_kwargs or in the function itself?). I plan to add support for basinhopping as well, which will be nice for really hairy minimizations.

I could see an objection to adding another dependency, but 1) it's a lightweight wrapper around functionality that doesn't really belong in PyMC anyway, and 2) it's a big value-add compared to working directly with the scipy.optimize functions, which have gnarly, inconsistent signatures.
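To illustrate the keyword-placement headache described above, here is a plain scipy.optimize sketch (better_optimize itself is not shown here, since its API isn't spelled out in this thread):

```python
import numpy as np
from scipy.optimize import minimize

# Rosenbrock function, a standard hard-ish test problem
def rosen(x):
    return 100.0 * (x[1] - x[0] ** 2) ** 2 + (1.0 - x[0]) ** 2

# Method-specific settings go in the `options` dict, not as top-level
# keyword arguments, and the valid keys differ between methods
# ("gtol" for BFGS vs "ftol"/"gtol" for L-BFGS-B, "maxiter" vs "maxfun", ...)
res = minimize(
    rosen,
    x0=np.array([-1.2, 1.0]),
    method="L-BFGS-B",
    options={"maxiter": 500, "ftol": 1e-12},
)
```

Passing a key the chosen method doesn't recognize is silently ignored or only warned about, which is exactly the kind of footgun a wrapper can catch up front.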