Multistart Parallelization with Jit Compatible Code #493
Comments
Hey @nikithiel, Thanks for opening the issue! Could you show me a minimal reproducible example of the behavior mentioned? And could you also post the versions of estimagic, joblib, and jax/jaxlib you use? That would be very helpful in solving your issue!
Hey @timmens, I'm using the following versions: estimagic=0.4.6.

This warning is occurring when I'm running my code on a Linux HPC, so I think it's related to that. I found two very interesting posts: https://discuss.python.org/t/switching-default-multiprocessing-context-to-spawn-on-posix-as-well/21868/22. It seems like fork is still the default start method on POSIX, and I'm not sure how to force joblib to use a different one.

I could also try to create an MWE. However, this is not so straightforward, as the problem occurs in a large code project.

Hope this helps,
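(For reference, a minimal sketch, plain joblib rather than anything estimagic does internally, of how the backend can be forced away from fork-based worker processes when calling joblib directly; `parallel_config` needs a recent joblib, older versions have `parallel_backend` instead.)

```python
# Sketch (plain joblib, not estimagic internals): a thread-based backend
# avoids fork-based worker processes altogether.
from joblib import Parallel, delayed, parallel_config


def square(x):
    return x * x


with parallel_config(backend="threading"):
    results = Parallel(n_jobs=2)(delayed(square)(i) for i in range(4))

print(results)  # [0, 1, 4, 9]
```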
I accidentally closed this issue. Sorry!
It's probably just a check in JAX whether fork has been called. This happened to me recently, too, in a project that uses pytask-parallel and JAX.
I've found a minimal reproducible example using JAX and joblib and a way to fix it (in the MRE). As you correctly anticipated @nikithiel, choosing a different parallelization backend fixes the MRE problem. If you want to validate that this fixes your problem, you could use a local estimagic installation and switch the joblib backend to "threading" there. Additionally, you can always run the multistart in serial.

Note: The following is tested on my Linux ThinkPad and might not work on your HPC machine.

@janosg, I propose we add an option to the batch_evaluator for custom kwargs and allow these to be passed through the multistart_options. What are your thoughts?

Minimal Reproducible Example

```python
import jax.numpy as jnp
from joblib import Parallel, delayed

x_list = [jnp.ones(2) for _ in range(2)]

# Backend: loky (results in a warning)
Parallel(n_jobs=2, backend="loky")(delayed(jnp.mean)(x) for x in x_list)

# Backend: threading (does *not* result in a warning)
Parallel(n_jobs=2, backend="threading")(delayed(jnp.mean)(x) for x in x_list)
```

Backend: loky

Results in the warning:

RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.

Backend: threading

Results in no warning.
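To make the proposal concrete, here is a rough sketch of a joblib-based batch evaluator that forwards extra keyword arguments to joblib.Parallel. The function name and signature are hypothetical and not existing estimagic API:

```python
# Hypothetical sketch, not existing estimagic API: a joblib-based batch
# evaluator that forwards extra keyword arguments (e.g. backend="threading")
# to joblib.Parallel.
import jax.numpy as jnp
from joblib import Parallel, delayed


def joblib_batch_evaluator_with_kwargs(func, arguments, n_cores=1, **joblib_kwargs):
    """Evaluate func on every element of arguments in parallel via joblib."""
    return Parallel(n_jobs=n_cores, **joblib_kwargs)(
        delayed(func)(arg) for arg in arguments
    )


# With a thread-based backend there is no os.fork(), so no JAX warning:
results = joblib_batch_evaluator_with_kwargs(
    jnp.mean, [jnp.ones(2), jnp.zeros(2)], n_cores=2, backend="threading"
)
print(results)
```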
Hey @timmens, thanks for the MRE and the suggested solution. I have changed the argument as suggested and I don't get an error on the HPC machine either. I also compared the performance of my bigger code project for a serial run and a run with the threading backend.
This is a very important use case for us, and we should offer a batch evaluator that supports jax functions. It's not just for multistart but also for bootstrap or parallelizing optimizers. Instead of making the batch evaluator configurable with more arguments, I would probably just add a new batch evaluator. In the meantime, I see two workarounds:
Disabling JAX's default parallelism is probably a good idea anyway when you do multistart. Running multiple optimizations in parallel is a very simple and efficient form of parallelization, so as long as you have enough optimizations to keep your computer busy, you probably don't want to parallelize the objective function.
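A minimal sketch of one commonly suggested recipe for dialing down JAX's own CPU parallelism. The XLA flags used here are not a stable public API, must be set before jax is imported, and their effect can vary across jax/jaxlib versions:

```python
# Sketch: limit JAX's own CPU threading before running many optimizations
# in parallel. These XLA flags are not a stable public API and must be set
# before jax is imported.
import os

os.environ["XLA_FLAGS"] = (
    "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
)

import jax.numpy as jnp  # import only after the flags are set

print(jnp.mean(jnp.ones(3)))
```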
Hey there,
I am trying to run a gradient-based algorithm with multistart of my jit-compatible code in parallel. Can I use estimagic's parallelisation via n_procs with joblib or pathos, or do I need to create a sample for the exploration phase manually and distribute it using JAX parallelisation?

When running multistart=True with n_procs=2, I'm encountering the following warning:

RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
If helpful, I can post some code snippets from my implementation.
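For context, a rough, illustrative sketch of the kind of call described above. Argument names follow this thread and estimagic 0.4.x conventions, but they may differ in your version, so treat this purely as an illustration rather than a verified snippet:

```python
# Rough sketch of the setup described above; argument names follow the
# thread and may differ across estimagic versions.
import estimagic as em
import numpy as np
import jax.numpy as jnp
from jax import jit


@jit
def sphere(params):
    return jnp.sum(params**2)


res = em.minimize(
    criterion=sphere,
    params=np.arange(3.0),
    algorithm="scipy_lbfgsb",
    soft_lower_bounds=-5 * np.ones(3),
    soft_upper_bounds=5 * np.ones(3),
    multistart=True,
    # Requesting parallel multistart (n_procs above) is what triggers the
    # fork-related RuntimeWarning with the default joblib backend.
)
```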