Non-JAX likelihoods #59

Closed
Joshuaalbert opened this issue May 5, 2022 · 2 comments

Joshuaalbert (Owner) commented May 5, 2022

Is your feature request related to a problem? Please describe.
A problem that comes up frequently in physics, biology, and other sciences is that the likelihood requires special simulation code that is simply not expressible in JAX. This makes it virtually impossible to use Jaxns.

Describe the solution you'd like
An abstraction of the backend, so that non-JAX likelihoods can be used. Good candidates for this abstraction would be: numpy, tensorflow, and pytorch. This would enable a larger suite of projects to use Jaxns without having to rewrite their likelihoods in JAX.

Describe alternatives you've considered
One option is to use disable_jit; however, this will undoubtedly bring the evaluation to a grinding halt due to the massive overhead of op-by-op dispatch.

from jax import random, disable_jit
from jaxns import NestedSampler

# log_likelihood and prior_chain are assumed to be defined elsewhere.
with disable_jit():
    ns = NestedSampler(log_likelihood, prior_chain)
    results = ns(random.PRNGKey(42))

Edit:

You can use jaxify_likelihood as shown below: #59 (comment)

@Joshuaalbert Joshuaalbert added the enhancement label May 5, 2022
@Joshuaalbert Joshuaalbert added this to the Release 1.2 milestone May 5, 2022
Joshuaalbert (Owner, Author) commented
For now, this is on hold as it would require a fair amount of work. For those wishing to use JAXNS without a JAX-based likelihood, consider using a Gaussian process or a similar surrogate to learn the likelihood. I think this would be a good example to add to the list, since many people ask about it.
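
As an illustration of that workaround (not part of jaxns; the names and data below are made up for the example): evaluate the expensive non-JAX likelihood offline at a set of training points, fit a simple RBF-kernel Gaussian process to the resulting log-likelihood values in pure JAX, and use the GP posterior mean as a JAX-compatible surrogate.

import jax.numpy as jnp
import numpy as np

# Hypothetical expensive non-JAX likelihood; NumPy stands in for an external simulator.
def expensive_log_likelihood(theta):
    return float(-0.5 * np.sum(theta ** 2))

# Evaluate the true likelihood offline at a handful of training points.
train_x = np.random.uniform(-3.0, 3.0, size=(50, 2))
train_y = np.array([expensive_log_likelihood(t) for t in train_x])

def rbf(a, b, lengthscale=1.0):
    d2 = jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * d2 / lengthscale ** 2)

K = rbf(jnp.asarray(train_x), jnp.asarray(train_x)) + 1e-6 * jnp.eye(len(train_x))
alpha = jnp.linalg.solve(K, jnp.asarray(train_y))

# GP posterior mean: a pure-JAX, differentiable stand-in for the log-likelihood.
def surrogate_log_likelihood(theta):
    k_star = rbf(theta[None, :], jnp.asarray(train_x))[0]
    return jnp.dot(k_star, alpha)

The surrogate can then be passed to the nested sampler in place of the original likelihood.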

Joshuaalbert (Owner, Author) commented
This is now possible using JAX's pure callback functionality. A wrapper, jaxify_likelihood, is now implemented in jaxns allowing non-JAX likelihoods.

import jax
import tensorflow_probability.substrates.jax as tfp
from jaxns import Model, Prior, jaxify_likelihood

tfpd = tfp.distributions

def prior_model():
    x = yield Prior(tfpd.Uniform(), name='x')
    return x

@jaxify_likelihood
def log_likelihood(x):
    return x

model = Model(prior_model=prior_model, log_likelihood=log_likelihood)
model.sanity_check(key=jax.random.PRNGKey(0), S=10)

This works with all types of priors, including parametrised prior variables and neural-network-based priors. JIT and AOT compilation work on the resulting calculations. However, there are some limitations:

  1. The likelihood must be a pure function, i.e. it should be deterministic and have no side effects.
  2. The speed of the resulting nested sampling will depend on the performance of the likelihood function.
  3. You cannot use parameters inside your likelihood function, e.g. neural networks whose parameters are learned by JAXNS to maximise evidence.
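
For reference, jaxify_likelihood builds on JAX's pure callback mechanism mentioned above. A minimal standalone sketch of that mechanism, independent of jaxns (the NumPy function below is just a stand-in for arbitrary external code):

import jax
import jax.numpy as jnp
import numpy as np

# Arbitrary non-JAX code; NumPy stands in for an external simulator here.
def numpy_log_likelihood(x):
    return np.asarray(-0.5 * np.sum(np.square(x)), dtype=np.float32)

def wrapped_log_likelihood(x):
    # Tell JAX the shape/dtype of the callback's output so the call can be traced.
    result_shape = jax.ShapeDtypeStruct((), jnp.float32)
    return jax.pure_callback(numpy_log_likelihood, result_shape, x)

# The wrapper can be jit-compiled; the callback drops out to Python at run time.
print(jax.jit(wrapped_log_likelihood)(jnp.arange(3.0)))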

Vectorisable likelihoods

You can improve performance if the likelihood handles a batch dimension. It must follow this semantic: if the arguments to the likelihood have a leading batch dimension, then the output has the same leading batch dimension, and the output for one batch element does not depend on any other batch element.

from functools import partial

@partial(jaxify_likelihood, vectorised=True)
def log_likelihood(x):
    return x
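
For a slightly more concrete illustration, here is a sketch of a NumPy-based likelihood that respects those batch semantics (the data and names are made up for the example):

import numpy as np
from functools import partial
from jaxns import jaxify_likelihood

# Illustrative data; in practice this would come from your experiment or simulator.
data = np.random.normal(loc=1.0, scale=0.5, size=(100,))

@partial(jaxify_likelihood, vectorised=True)
def log_likelihood(mu):
    # mu is either a scalar or a batch of shape (B,); reduce over the data axis
    # only, so each batch element's output depends only on that element.
    mu = np.asarray(mu)
    batched = mu.ndim > 0
    mu_b = np.atleast_1d(mu)
    resid = data[None, :] - mu_b[:, None]    # shape (B, N)
    out = -0.5 * np.sum(resid ** 2, axis=-1)  # shape (B,)
    return out if batched else out[0]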

@Joshuaalbert Joshuaalbert reopened this May 15, 2024
Joshuaalbert added a commit that referenced this issue May 15, 2024
* bump to 2.5.0
Joshuaalbert added a commit that referenced this issue May 15, 2024
@Joshuaalbert Joshuaalbert changed the title Abstraction of backend for non-JAX likelihoods Non-JAX likelihoods May 15, 2024