Expose all nutpie compile backends through pm.sample #7497

Open
jessegrabowski opened this issue Sep 10, 2024 · 3 comments · May be fixed by #7498

Comments

@jessegrabowski
Member

Description

Nutpie currently has two compile modes, numba and JAX, with a third (PyTorch) backend on the way. It would be nice if we could easily access these via pm.sample.

Proposal 1: Allow nutpie.compile_pymc kwargs in nuts_sampler_kwargs

  • Pros: It's easy, since there are only two such arguments: backend and gradient_backend. We just check for and pop them before forwarding all other arguments to nutpie.sample.
  • Cons: It might be seen as "unexpected" behavior, since some keywords go to one function and some to another. Also, the nuts_sampler_kwargs argument isn't very beautiful in the first place.
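A minimal sketch of what Proposal 1 could look like inside pm.sample, assuming the two intercepted keywords are named `backend` and `gradient_backend` as in `nutpie.compile_pymc`. The helper `split_nutpie_kwargs` is hypothetical (not part of PyMC); it only splits the dict, and the actual nutpie calls are left as comments.

```python
# Hypothetical: keywords that would be intercepted and sent to
# nutpie.compile_pymc instead of nutpie.sample.
COMPILE_KWARG_NAMES = ("backend", "gradient_backend")


def split_nutpie_kwargs(nuts_sampler_kwargs):
    """Split user kwargs into (compile_kwargs, sample_kwargs).

    Compile-time options are popped from a copy; everything else is left
    for nutpie.sample. The caller's dict is not mutated.
    """
    sample_kwargs = dict(nuts_sampler_kwargs or {})
    compile_kwargs = {
        name: sample_kwargs.pop(name)
        for name in COMPILE_KWARG_NAMES
        if name in sample_kwargs
    }
    return compile_kwargs, sample_kwargs


compile_kwargs, sample_kwargs = split_nutpie_kwargs(
    {"backend": "jax", "gradient_backend": "jax", "chains": 4}
)
# compile_kwargs -> nutpie.compile_pymc(model, **compile_kwargs)
# sample_kwargs  -> nutpie.sample(compiled_model, **sample_kwargs)
```

This also makes the "unexpected" con concrete: nothing in the single `nuts_sampler_kwargs` dict tells the user which keys end up in which function.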

Proposal 2: pip-style optional arguments, like nuts_sampler="nutpie[jax]" and nuts_sampler="nutpie[numba]"

  • Pros: It's quite pretty!
  • Cons: technically nutpie lets you pick the forward (logp) and backward (gradient) compile modes separately, so a user who wanted to mix them would still have to import nutpie and do it manually. Maybe that's enough of a corner case that it's OK? It's also a different API from the other samplers (although blackjax could benefit from something similar to select among its many options, but that's beyond the scope here).
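A rough sketch of how pm.sample might parse the pip-style string in Proposal 2. The function name `parse_nuts_sampler` and the backend set are assumptions for illustration; it also assumes a bare "nutpie" keeps today's numba behavior and that other samplers don't accept a bracketed backend.

```python
import re

# Hypothetical set of accepted backends; the issue mentions numba and JAX
# today, with PyTorch planned, so this would need extending later.
KNOWN_BACKENDS = {"numba", "jax"}

# Matches "name" or "name[backend]", e.g. "nutpie" or "nutpie[jax]".
_SAMPLER_SPEC = re.compile(r"^(?P<name>[a-z_]+)(?:\[(?P<backend>[a-z_]+)\])?$")


def parse_nuts_sampler(spec):
    """Parse a nuts_sampler string into (sampler_name, backend_or_None)."""
    match = _SAMPLER_SPEC.match(spec.strip().lower())
    if match is None:
        raise ValueError(f"Malformed nuts_sampler string: {spec!r}")
    name, backend = match.group("name"), match.group("backend")
    if name != "nutpie":
        # Only nutpie gets the bracket syntax in this sketch.
        if backend is not None:
            raise ValueError(f"{name!r} does not accept a backend suffix")
        return name, None
    if backend is None:
        # Bare "nutpie" falls back to numba (assumed default).
        return name, "numba"
    if backend not in KNOWN_BACKENDS:
        raise ValueError(f"Unknown nutpie backend {backend!r}")
    return name, backend
```

Note that the string carries only one backend, which is exactly the con above: there is no slot for a separate gradient backend.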

Proposal 3: Add a new compile_kwargs argument to pm.sample

  • Pros: It's very clear. It could be used to forward kwargs to pytensor as well, which is a nice side bonus.
  • Cons: It's another argument to an already bloated pm.sample function.
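One way to picture Proposal 3: a hypothetical `compile_kwargs` dict whose destination depends on the chosen sampler. The mapping below is an assumption for illustration; it returns target names as strings rather than calling PyMC, PyTensor, or nutpie.

```python
# Hypothetical routing of a new pm.sample(compile_kwargs=...) argument.
COMPILE_TARGETS = {
    "pymc": "pytensor.function",      # e.g. compile_kwargs={"mode": "NUMBA"}
    "nutpie": "nutpie.compile_pymc",  # e.g. compile_kwargs={"backend": "jax"}
}


def route_compile_kwargs(nuts_sampler, compile_kwargs=None):
    """Return (target_name, kwargs) describing the compile step."""
    try:
        target = COMPILE_TARGETS[nuts_sampler]
    except KeyError:
        raise ValueError(
            f"compile_kwargs routing is not defined for {nuts_sampler!r}"
        ) from None
    # Copy so the caller's dict is never mutated downstream.
    return target, dict(compile_kwargs or {})
```

The same argument then serves both the default sampler (PyTensor compile mode) and nutpie (compile backend), which is the "side bonus" mentioned above.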
@ricardoV94
Member

I like 2, with just "nutpie" defaulting to numba for back compat

@lucianopaz
Contributor

I think that compile_kwargs is the most flexible way to go by far. It's also similar to the API that we provide for every other sampling function. The pip style does look very slick and user friendly though.

@jessegrabowski
Member Author

Unhappily we might want both? It's a bit of an anti-pattern to forward compile_kwargs to nutpie in pm.sample only, since in every other function they only go to pytensor.function. But as discussed in the linked issue, we might want to allow compile kwargs to access alternative backends in the default pymc sampler.

I suspect a lot of users could see speedups just from compiling their model's logp/dlogp in JAX or numba mode, without adding an additional dependency to their projects.
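If we do end up wanting both, a sketch of how the two could be reconciled: the pip-style suffix sets `backend`, while `compile_kwargs` can add options the string can't express (such as a separate gradient backend) but may not contradict it. The function `resolve_nutpie_compile_kwargs` and this precedence rule are hypothetical.

```python
def resolve_nutpie_compile_kwargs(nuts_sampler, compile_kwargs=None):
    """Merge a 'nutpie[<backend>]' shorthand with explicit compile_kwargs."""
    kwargs = dict(compile_kwargs or {})
    name, bracket, rest = nuts_sampler.partition("[")
    if name != "nutpie":
        raise ValueError(f"Expected a nutpie sampler string, got {nuts_sampler!r}")
    if not bracket:
        # Bare "nutpie": numba unless compile_kwargs explicitly says otherwise.
        kwargs.setdefault("backend", "numba")
        return kwargs
    backend = rest[:-1]
    if not rest.endswith("]") or not backend.isidentifier():
        raise ValueError(f"Malformed nuts_sampler string: {nuts_sampler!r}")
    if kwargs.get("backend", backend) != backend:
        raise ValueError(
            f"nuts_sampler requests backend {backend!r}, "
            f"but compile_kwargs requests {kwargs['backend']!r}"
        )
    kwargs["backend"] = backend
    return kwargs
```

For example, `resolve_nutpie_compile_kwargs("nutpie[jax]", {"gradient_backend": "pytensor"})` would give the mixed forward/gradient setup that the string alone cannot express.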
