Limit jax multithreading #743
Comments
Great question! I'm not sure, but I think we'd want to set the …
My Hail Mary attempt of setting the …
Those are the right environment variables to fiddle with for PyTorch, but they won't have an effect in JAX (or any TensorFlow or XLA-based codebase on CPU) because the BLAS library used is Eigen Tensor (not OpenBLAS or MKL) and the threading mechanism used is Eigen threadpools (not OpenMP).
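For reference, a minimal sketch of the XLA-level approach that later comments in this thread converge on; the flag string is copied from those comments, and the assumption here is that it must be set before JAX initializes its CPU backend:

```python
import os

# Set the XLA CPU flags before importing jax so they are visible when the
# backend initializes (assumption about initialization timing).
os.environ["XLA_FLAGS"] = (
    "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
)

import jax

# A big matmul that would otherwise fan out across all available cores.
A = jax.random.normal(jax.random.PRNGKey(0), (2000, 2000))
(A @ A).block_until_ready()
```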
This is now a blocking issue for me. I'd be happy to put together a PR, but I'm not really sure how to start... Also not sure how you would like this to be exposed to the user. For my use case I'd like to run many jax experiments in parallel. Ideally these could all be managed as

```python
import multiprocessing
import os

os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
...

pool = multiprocessing.Pool()
pool.imap_unordered(my_task, range(num_random_seeds))
```

But if thread-based throttling isn't possible that's not a dealbreaker. I can always kick off jobs as separate Python processes.
Did you try setting these environment variables? (My comment didn't explain this very well.)
That seems to work for me in a test, at least for a big matmul.
Awesome, that seems to do the trick! Thank you so much!
For my own (and others') future googling, my current approach looks like

```python
from multiprocessing import get_context
import os

# Limit ourselves to single-threaded jax/xla operations to avoid thrashing. See
# https://github.com/google/jax/issues/743.
os.environ["XLA_FLAGS"] = ("--xla_cpu_multi_thread_eigen=false "
                           "intra_op_parallelism_threads=1")


def job(random_seed: int):
    # jax jax jax
    ...


if __name__ == "__main__":
    # See https://codewithoutrules.com/2018/09/04/python-multiprocessing/.
    with get_context("spawn").Pool() as pool:
        # Consume the iterator so the jobs actually run.
        for _ in pool.imap_unordered(job, range(100)):
            pass
```

There may be a way better way, but it seems to work 🤷♀️
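A possible variation on the above (a sketch, not something from this thread): set the flags in a `Pool` initializer so each worker process configures XLA before it first touches jax. The helper name `_limit_xla_threads` and the problem sizes are made up for illustration.

```python
from multiprocessing import get_context
import os


def _limit_xla_threads():
    # Hypothetical worker initializer: runs once in each worker process,
    # before any jax work, so XLA sees the flags at backend init time.
    os.environ["XLA_FLAGS"] = (
        "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
    )


def job(random_seed: int):
    # Import jax inside the worker so the flags above are already in place.
    import jax
    key = jax.random.PRNGKey(random_seed)
    A = jax.random.normal(key, (1000, 1000))
    return float((A @ A).sum())


if __name__ == "__main__":
    with get_context("spawn").Pool(initializer=_limit_xla_threads) as pool:
        for result in pool.imap_unordered(job, range(8)):
            print(result)
```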
The solution above does not work for me on an M1 chip.
Update. The following worked, keeping CPU usage low:

```python
import os

# Set the threading-related environment variables before importing jax so
# they are picked up when the XLA backend initializes.
os.environ[
    "XLA_FLAGS"
] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

import jax


def job():
    A = jax.random.normal(jax.random.PRNGKey(0), (3000, 3000))
    return A @ A @ A @ A


if __name__ == "__main__":
    for i in range(100):
        job()
        print("Done", i + 1)
```
Hi, I'm trying to limit my XLA-compiled C++ code to run on a single CPU core. I tried all of the approaches here, including the last one with four environment variables. The program does run slower after the change, but it's still quite a bit faster than adding …
As @juehang pointed out, one possible fix is #22739, and in some circumstances it even seems to perform better than the multi-core default.
Related TODO: https://github.com/openxla/xla/blob/8d7f72345ce58ce959dcc99f2b1a13e5c672b3e9/xla/pjrt/utils.cc#L817
By default jax appears to multithread most operations; e.g., a large matrix multiplication will run across all available cores. This is great in general, and matches numpy's behavior. But it presents problems when trying to run a bunch of small operations in parallel, e.g. running the same script initialized with 4 different random seeds on a 4-core machine.
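A minimal sketch of that behavior (the matrix size here is arbitrary):

```python
import jax

# With default settings, a big matmul like this typically spreads across
# all available CPU cores.
A = jax.random.normal(jax.random.PRNGKey(0), (4000, 4000))
(A @ A).block_until_ready()
```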
Is there any option in jax to cap the number of threads that it uses? Something like https://stackoverflow.com/questions/17053671/python-how-do-you-stop-numpy-from-multithreading?