Nutpie currently has two compile modes, numba and JAX, with a third (PyTorch) backend on the way. It would be nice if we could easily access these via pm.sample.
Proposal 1: Allow nutpie.compile_pymc kwargs in nuts_sampler_kwargs
Pros: It's easy, since there are only two such arguments: backend and gradient_backend. We just check for them and pop them before forwarding all other arguments to nutpie.sample.
Cons: It might be seen as "unexpected" behavior, since some keywords go to one function and some to another. Also, the nuts_sampler_kwargs argument isn't very beautiful in the first place.
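A minimal sketch of the "check for and pop" step in Proposal 1, assuming the two compile keywords are spelled backend and gradient_backend (a Python keyword argument cannot contain a hyphen). The helper name split_nutpie_kwargs is hypothetical and not part of PyMC or nutpie. It copies the input so the caller's nuts_sampler_kwargs is left untouched.

```python
# Hypothetical sketch of Proposal 1: pull the compile-time keywords out of
# nuts_sampler_kwargs and leave everything else for nutpie.sample.
from typing import Any, Dict, Tuple

# Assumed keyword names, taken from the proposal text.
COMPILE_KEYS = ("backend", "gradient_backend")


def split_nutpie_kwargs(
    nuts_sampler_kwargs: Dict[str, Any],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Return (compile_kwargs, sample_kwargs) without mutating the input dict."""
    remaining = dict(nuts_sampler_kwargs)  # shallow copy of the caller's dict
    compile_kwargs = {
        key: remaining.pop(key) for key in COMPILE_KEYS if key in remaining
    }
    return compile_kwargs, remaining


compile_kwargs, sample_kwargs = split_nutpie_kwargs(
    {"backend": "jax", "gradient_backend": "jax", "chains": 4}
)
print(compile_kwargs)  # {'backend': 'jax', 'gradient_backend': 'jax'}
print(sample_kwargs)   # {'chains': 4}
```

Inside pm.sample, the first dict would presumably go to nutpie's compile step and the second to nutpie.sample; that wiring is not shown here.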
Proposal 2: pip-style optional arguments, like nuts_sampler="nutpie[jax]" and nuts_sampler="nutpie[numba]"
Pros: It's quite pretty!
Cons: Technically you can pick the forward and backward (gradient) compile modes independently, so if a user wanted that, she'd still have to import nutpie and do it manually. Maybe that's a rare enough corner case that it's OK? Also, it's a different API from the other samplers (although blackjax could benefit from something similar for choosing among its many options, but that's beyond the scope here).
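Proposal 2 mostly needs a small parser for the sampler string. The sketch below is hypothetical (parse_nuts_sampler is not an existing PyMC function) and accepts either a bare name or name[option]. It returns a single option, which illustrates the limitation above: one bracketed value cannot express separate forward and gradient compile modes.

```python
# Hypothetical sketch of Proposal 2: parse pip-style sampler strings such as
# "nutpie[jax]" into a sampler name and an optional compile backend.
import re
from typing import Optional, Tuple

# Matches "name" or "name[option]"; the option may not contain brackets.
_SAMPLER_RE = re.compile(r"^(?P<name>[A-Za-z_]\w*)(?:\[(?P<option>[^\[\]]+)\])?$")


def parse_nuts_sampler(spec: str) -> Tuple[str, Optional[str]]:
    """Split 'nutpie[jax]' into ('nutpie', 'jax'); a bare 'nutpie' gives ('nutpie', None)."""
    match = _SAMPLER_RE.match(spec.strip())
    if match is None:
        raise ValueError(f"Invalid nuts_sampler specification: {spec!r}")
    return match.group("name"), match.group("option")


print(parse_nuts_sampler("nutpie[jax]"))  # ('nutpie', 'jax')
print(parse_nuts_sampler("nutpie"))       # ('nutpie', None)
```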
Proposal 3: Add a new compile_kwargs argument to pm.sample
Pros: It's very clear. It could be used to forward kwargs to pytensor as well, which is a nice side bonus.
Cons: It's another argument to an already bloated pm.sample function.
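A hypothetical sketch of Proposal 3, where a single compile_kwargs argument is routed according to the chosen sampler. The two compile callables are injected stand-ins (in PyMC they would wrap nutpie's compile step and pytensor.function, respectively), and rejecting any sampler other than "nutpie" or "pymc" is an assumption made for this sketch, not existing behavior.

```python
# Hypothetical sketch of Proposal 3: one compile_kwargs argument whose
# destination depends on nuts_sampler. The compile callables are stand-ins.
from typing import Any, Callable, Dict, Optional


def route_compile_kwargs(
    nuts_sampler: str,
    compile_kwargs: Optional[Dict[str, Any]],
    compile_for_nutpie: Callable[..., Any],
    compile_for_pymc: Callable[..., Any],
) -> Any:
    """Forward compile_kwargs to whichever backend builds the model's logp/dlogp."""
    kwargs = dict(compile_kwargs or {})  # None means "use the defaults"
    if nuts_sampler == "nutpie":
        # e.g. {"backend": "jax"} would end up in nutpie's compile step
        return compile_for_nutpie(**kwargs)
    if nuts_sampler == "pymc":
        # the default sampler compiles through PyTensor, e.g. {"mode": "NUMBA"}
        return compile_for_pymc(**kwargs)
    # Assumption for this sketch: other samplers do not take compile_kwargs here.
    raise ValueError(f"compile_kwargs routing not defined for {nuts_sampler!r}")


result = route_compile_kwargs(
    "nutpie",
    {"backend": "jax"},
    compile_for_nutpie=lambda **kw: ("nutpie", kw),
    compile_for_pymc=lambda **kw: ("pytensor", kw),
)
print(result)  # ('nutpie', {'backend': 'jax'})
```

Passing the compile functions in keeps the routing logic testable without installing nutpie; a real implementation would call the libraries directly.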
I think that compile_kwargs is the most flexible way to go by far. It's also similar to the API that we provide for every other sampling function. The pip style does look very slick and user friendly though.
Unfortunately, we might want both. It's a bit of an anti-pattern to forward compile_kwargs to nutpie only in pm.sample, since in every other function they go only to pytensor.function. But as discussed in the linked issue, we might want to allow compile_kwargs to access alternative backends in the default PyMC sampler.
I suspect a lot of users could see speedups just from compiling their model's logp/dlogp in JAX or numba mode, without adding an additional dependency to their projects.