Add JAX-based find_MAP function #385

Open · wants to merge 7 commits into main

Conversation

jessegrabowski (Member)

This PR adds code to run find_MAP using JAX. I'm using JAX for gradients, because I found the compile times were faster. Open to suggestions/rebuke.

It also adds a fit_laplace function, which is bad because we already have a fit_laplace function. This one has a slightly different objective, though -- it isn't meant to be used as a step sampler on a subset of model variables. Instead, it is meant to be used on the MAP result to give an approximation to the full posterior. My function also lets you do the Laplace approximation in the transformed space, then do a sample-wise reverse transformation. I think this is legit, and it lets you obtain approximate posteriors that respect the domain of the prior. Tagging @theorashid so we can resolve the differences.
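
For concreteness, here's a minimal numpy sketch of that last point (illustrative numbers only, not this PR's API): do the Laplace approximation for a positive parameter in log-space, then apply the reverse transformation to each draw so the result respects the prior's domain.

```python
import numpy as np

rng = np.random.default_rng(0)

# Pretend the MAP of log(sigma) is 0.3, with Hessian-based variance 0.04
map_log_sigma, var_log_sigma = 0.3, 0.04

# Gaussian (Laplace) approximation in the transformed, unconstrained space
log_sigma_draws = rng.normal(map_log_sigma, np.sqrt(var_log_sigma), size=1_000)

# Sample-wise reverse transformation back to the constrained space
sigma_draws = np.exp(log_sigma_draws)

assert (sigma_draws > 0).all()  # every draw respects sigma's support
```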

Last point is that I added a dependency on better_optimize. This is a package I wrote that basically rips out the wrapper code used in PyMC's find_MAP and applies it to arbitrary optimization problems. It is more feature-complete than the PyMC wrapper -- it supports all optimizer modes for scipy.optimize.minimize and scipy.optimize.root, and also helps get keywords to the right place in those functions (who can ever remember if an argument goes in method_kwargs or in the function itself?). I plan to add support for basinhopping as well, which will be nice for really hairy minimizations.

I could see an objection to adding another dependency, but 1) it's a lightweight wrapper around functionality that doesn't really belong in PyMC anyway, and 2) it's a big value-add compared to working directly with the scipy.optimize functions, which have gnarly, inconsistent signatures.
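
To illustrate the keyword-routing pain with plain scipy (better_optimize's own API is not shown here): some settings are top-level arguments to minimize, while others must go into the method-specific options dict, and the split differs between methods.

```python
import numpy as np
from scipy.optimize import minimize, rosen

res = minimize(
    rosen,
    x0=np.array([1.3, 0.7, 0.8]),
    method="L-BFGS-B",
    tol=1e-8,                    # top-level argument
    options={"maxiter": 500},    # method-specific option, not a keyword argument
)
print(res.x)  # -> approximately [1., 1., 1.]
```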

@theorashid (Contributor)

Hey, nice one. Yeah, I agree, we should only have one fit_laplace function.

it isn't meant to be used as a step sampler on a subset of model variables

The current fit_laplace isn't either. It isn't a step sampler. (The INLA stuff in #340 still has a few blockers, so that's separate and not yet in the library.) The implementation was made by a user following Statistical Rethinking, where McElreath fits some models using the Laplace approximation of all parameters.

The current behaviour of fit_laplace when you only pass a subset of variables isn't really desirable in my opinion (see #345 (comment)), so we put a warning. So, as you say:

Instead, it is meant to be used on the MAP result to give an approximation to the full posterior.

Agree, that's the best plan for fit_laplace.

Judging by your docs and a quick glance at your code, I think you're basically doing the same thing. The current implementation is a few lines of code and a few docs, so I reckon:

  1. make sure you can pass the test case with your method, which is an example from BDA3: https://github.com/pymc-devs/pymc-experimental/blob/main/tests/test_laplace.py
  2. throw any of the useful code and docs into your method

Then it should be safe to delete the existing code and we can go back to one fit_laplace.

I could see an objection to adding another dependency

I would love a generic optimiser in pure pytensor, but I can see from looking at your code that there are a lot of fancy extras that would take a large effort to write in pytensor. Still, if we want to go back to one of our efforts with a fixed point operator (pymc-devs/pytensor#978 and pymc-devs/pytensor#944), we could probably write find_MAP with that in some form, though with fewer bells and whistles.
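
For the sake of discussion, a rough sketch of that idea (illustrative only, using plain JAX for the gradient rather than an actual pytensor fixed-point Op): MAP estimation as the fixed point of a gradient-ascent update on the log-posterior.

```python
# Illustrative only: MAP as the fixed point of x* = x* + lr * grad(logp)(x*),
# the kind of update a fixed-point Op could express.
import jax
import jax.numpy as jnp


def logp(x):
    # Stand-in log-posterior: an isotropic standard normal (up to a constant)
    return -0.5 * jnp.sum(x**2)


grad_logp = jax.grad(logp)


def find_map_fixed_point(x0, lr=0.1, tol=1e-8, max_iter=10_000):
    x = x0
    for _ in range(max_iter):
        x_new = x + lr * grad_logp(x)  # fixed point where grad(logp)(x) == 0
        if jnp.max(jnp.abs(x_new - x)) < tol:
            break
        x = x_new
    return x_new


print(find_map_fixed_point(jnp.array([3.0, -2.0])))  # -> approximately [0., 0.]
```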

Happy to look at your code and review properly later in the week if you'd like me to. Let me know. Otherwise, I'll leave it to the core devs.

Comment on lines +72 to +74
out_shapes[rv] = tuple(
    [len(model.coords[dim]) for dim in model.named_vars_to_dims.get(rv.name, [])]
)
Member

What if an rv doesn't have dims specified?

Member Author

is there already a helper in pymc to figure out shapes of an rv?

Member

If you froze the dims you should have everything in rv.type.shape. Otherwise we usually use model.initial_point
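
For reference, a hedged sketch of both suggestions. The freeze_dims_and_data import path is an assumption (pymc.model.transform.optimization in recent PyMC) and may differ by version.

```python
import pymc as pm
from pymc.model.transform.optimization import freeze_dims_and_data

with pm.Model(coords={"group": ["a", "b", "c"]}) as model:
    mu = pm.Normal("mu", dims="group")
    sigma = pm.HalfNormal("sigma")

# Option 1: freeze dims so static shapes show up on rv.type.shape
frozen = freeze_dims_and_data(model)
static_shapes = {rv.name: rv.type.shape for rv in frozen.free_RVs}
# {'mu': (3,), 'sigma': ()}

# Option 2: shapes of the initial point (keyed by transformed value variables)
initial_point = model.initial_point()
point_shapes = {name: value.shape for name, value in initial_point.items()}
# {'mu': (3,), 'sigma_log__': ()}
```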

    If None, no jitter is added. Default is 1e-8.
progressbar : bool
    Whether or not to display progress bar. Default is True.
mode : str
Member

No option is being given besides JAX

Member Author

Obviously because it's all you need : )

The posterior distribution will be approximated as a Gaussian
distribution centered at the posterior mode.
The covariance is the inverse of the negative Hessian matrix of
the log-posterior evaluated at the mode.
Member

optimized_point? Less ambiguous, since there is a mode argument.

Suggested change:
- the log-posterior evaluated at the mode.
+ the log-posterior evaluated at the optimized_point.
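
For reference, a toy JAX sketch of what the quoted docstring describes (not the PR's code): the Laplace covariance is the inverse of the negative Hessian of the log-posterior at the optimized point.

```python
import jax
import jax.numpy as jnp


def logp(x):
    # Stand-in log-posterior: independent normals with variances 1 and 4
    return -0.5 * jnp.sum(x**2 / jnp.array([1.0, 4.0]))


x_opt = jnp.zeros(2)           # the optimized point (posterior mode)
H = jax.hessian(logp)(x_opt)   # Hessian of the log-posterior at that point
cov = jnp.linalg.inv(-H)       # Laplace covariance
print(jnp.diag(cov))           # -> approximately [1., 4.]
```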

@ricardoV94 (Member)

Happy to look at your code and review properly later in the week if you'd like me to. Let me know. Otherwise, I'll leave to the core devs.

That would be appreciated

@ricardoV94 (Member)

Agree with what @theorashid said. This fit_laplace is going for the same goal as the previous one. Happy to replace it, as long as it's not married to the JAX backend. It's still fine to allow using JAX for the autodiff. What you're offering is very similar to nutpie's gradient_backend kwarg, so we could use the same terminology.

@ricardoV94 (Member)

No objections about your custom library wrapper
