Non-JAX likelihoods #59
For now, this is on hold as it would require a fair amount of work. For those wishing to use JAXNS when they don't have a JAX-based likelihood, you should consider using a Gaussian process, or a similar surrogate, to learn the likelihood. I think this would be a good example to add to the list, since many people ask about it.
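To make the surrogate suggestion concrete, here is a minimal sketch in plain JAX: fit a Gaussian process to (parameter, log-likelihood) pairs precomputed with the expensive non-JAX simulator, then use the GP posterior mean as a cheap, JAX-native stand-in likelihood. Every name here, the RBF kernel, and the noise setup are illustrative assumptions, not JAXNS API.

```python
import jax
import jax.numpy as jnp


def rbf_kernel(x1, x2, length_scale=1.0):
    # Squared-exponential kernel between point sets of shapes (n, d) and (m, d).
    d2 = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * d2 / length_scale**2)


def fit_gp(x_train, y_train, jitter=1e-6):
    # x_train: (n, d) parameter samples; y_train: (n,) simulator log-likelihoods.
    K = rbf_kernel(x_train, x_train) + jitter * jnp.eye(x_train.shape[0])
    L = jnp.linalg.cholesky(K)
    return jax.scipy.linalg.cho_solve((L, True), y_train)


def surrogate_log_likelihood(x, x_train, alpha):
    # GP posterior mean at a single point x of shape (d,): fully traceable by JAX.
    k = rbf_kernel(x[None, :], x_train)
    return (k @ alpha)[0]
```

A partial application of `surrogate_log_likelihood` over the fitted `alpha` could then be handed to JAXNS like any ordinary JAX likelihood.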
This is now possible using JAX's new pure callback functionality. A wrapper is now implemented in JAXNS that allows non-JAX likelihoods:
```python
import jax
import tensorflow_probability.substrates.jax as tfp
from jaxns import Model, Prior, jaxify_likelihood

tfpd = tfp.distributions


def prior_model():
    x = yield Prior(tfpd.Uniform(), name='x')
    return x


@jaxify_likelihood
def log_likelihood(x):
    return x


model = Model(prior_model=prior_model, log_likelihood=log_likelihood)
model.sanity_check(key=jax.random.PRNGKey(0), S=10)
```

This works with all types of priors, including parametrised prior variables and neural-network-based priors. JIT and AOT compilation work on the resulting computations.
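For intuition about what such a wrapper does, here is a standalone sketch built directly on `jax.pure_callback`. This is not JAXNS's actual implementation; the NumPy likelihood and the scalar output shape are assumptions for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np


def numpy_log_likelihood(x):
    # Ordinary NumPy code that JAX cannot trace.
    return np.asarray(-0.5 * np.sum(x**2), dtype=np.float32)


def wrapped_log_likelihood(x):
    # Declare the callback's output shape/dtype so JAX can trace around it.
    result_shape = jax.ShapeDtypeStruct((), jnp.float32)
    return jax.pure_callback(numpy_log_likelihood, result_shape, x)


print(jax.jit(wrapped_log_likelihood)(jnp.arange(3.0)))  # runs under JIT
```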
However, some limitations apply.

Vectorisable likelihoods

You can improve performance if the likelihood handles a batch dimension. It must follow this semantic: if the arguments to the likelihood have a leading batch dimension, then the output has the same leading batch dimension, and the batch elements are independent, i.e. the inputs in one batch element have no effect on the outputs of any other. For example:

```python
from functools import partial

from jaxns import jaxify_likelihood


@partial(jaxify_likelihood, vectorised=True)
def log_likelihood(x):
    return x
```
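For a less trivial illustration of the batch semantic, a hypothetical NumPy likelihood might look like the following. The Gaussian form is an assumption; the key property is that a leading batch dimension on the input maps to the same leading dimension on the output, with each element computed independently.

```python
from functools import partial

import numpy as np
from jaxns import jaxify_likelihood


@partial(jaxify_likelihood, vectorised=True)
def log_likelihood(x):
    # x arrives with shape (batch, ...); reduce each element independently
    # so the output has shape (batch,) and no cross-batch interaction.
    return -0.5 * np.sum(np.reshape(np.square(x), (x.shape[0], -1)), axis=-1)
```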
Is your feature request related to a problem? Please describe.
A problem that comes up frequently in physics, biology, and other sciences is that the likelihood requires special simulation code that is simply not expressible in JAX. This makes it virtually impossible to use JAXNS.
Describe the solution you'd like
An abstraction of the backend, so that non-JAX likelihoods can be used. Good candidates for this abstraction would be `numpy`, `tensorflow`, and `pytorch`. This would enable a larger suite of projects to use JAXNS without conforming their likelihood to JAX.
Describe alternatives you've considered
One option is to use `disable_jit`, however this will undoubtedly bring the evaluation to a grinding halt due to the massive overhead of dispatching.
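For reference, that alternative would amount to something like the sketch below, where `run_sampler` is a hypothetical stand-in for whatever driver code calls the likelihood; `jax.disable_jit` is a real context manager, but inside it every primitive dispatches eagerly, one operation at a time.

```python
import jax

with jax.disable_jit():
    # Op-by-op dispatch: arbitrary Python (and hence a non-JAX likelihood)
    # can run here, but each primitive pays the full dispatch overhead.
    results = run_sampler()  # hypothetical driver call
```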
Edit:
You can use `jaxify_likelihood` as shown below: #59 (comment)