
Limit jax multithreading #743

Closed
samuela opened this issue May 21, 2019 · 12 comments
Labels
enhancement New feature or request

Comments

@samuela
Contributor

samuela commented May 21, 2019

By default jax appears to multithread most operations, e.g.

import jax.random as jr

jrkey = jr.PRNGKey(0)
x = jr.normal(jrkey, shape=(50000, 50000))
x @ x

will run across all available cores. This is great in general, and matches numpy's behavior. But it presents problems when trying to run a bunch of small operations in parallel, e.g. running the same script initialized with 4 different random seeds on a 4-core machine.

Is there any option in jax to cap the number of threads that it uses? Something like https://stackoverflow.com/questions/17053671/python-how-do-you-stop-numpy-from-multithreading?
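For reference, the NumPy-side recipe from that Stack Overflow link is roughly the following (a sketch; which variable actually takes effect depends on whether your NumPy build links OpenBLAS, MKL, or uses OpenMP):

```python
import os

# Cap BLAS/OpenMP threads *before* numpy is imported; these are read once
# at library load time, so setting them later has no effect.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

import numpy as np  # noqa: E402  (import only after the caps are in place)

x = np.ones((100, 100))
y = x @ x  # this matmul now runs single-threaded under the caps above
print(y[0, 0])  # -> 100.0
```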

@mattjj
Collaborator

mattjj commented May 21, 2019

Great question!

I'm not sure, but I think we'd want to set the xla_cpu_multi_thread_eigen XLA option (also defined as a flag) to false, and intra_op_parallelism_threads to 1. We could expose those in JAX somehow...

@mattjj mattjj added the enhancement New feature or request label May 21, 2019
@samuela
Contributor Author

samuela commented May 22, 2019

My Hail Mary attempt of setting the OPENBLAS_NUM_THREADS, MKL_NUM_THREADS, and OMP_NUM_THREADS environment variables didn't work.

@jekbradbury
Contributor

Those are the right environment variables to fiddle with for PyTorch, but they won't have an effect in JAX (or any TensorFlow or XLA-based codebase on CPU) because the BLAS library used is Eigen Tensor (not OpenBLAS or MKL) and the threading mechanism used is Eigen threadpools (not OpenMP).

@samuela
Contributor Author

samuela commented Jul 16, 2019

This is now a blocking issue for me. I'd be happy to put together a PR, but I'm not really sure how to start... Also not sure how you guys would like this to be exposed to the user.

For my use case I'd like to run many jax experiments in parallel. Ideally these could all be managed as tasks in a multiprocessing Pool, with each task restricted to some subset of CPUs. When using numpy I accomplish this with

import multiprocessing
import os

# Set these before numpy is imported.
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

...

pool = multiprocessing.Pool()
# Consume the iterator so every task actually runs to completion.
for _ in pool.imap_unordered(my_task, range(num_random_seeds)):
    pass

But if thread-based throttling isn't possible that's not a dealbreaker. I can always kick off jobs as separate python processes.

@mattjj
Collaborator

mattjj commented Jul 16, 2019

Did you try setting these environment variables? (My comment didn't explain this very well.)

XLA_FLAGS="--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1" python my_file.py

That seems to work for me in a test, at least for a big matmul.
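The same flags can also be applied from inside Python rather than on the command line, as long as this runs before jax is imported (a sketch; XLA reads XLA_FLAGS once, when its CPU backend is first initialized, so setting it after import has no effect):

```python
import os

# Must run before `import jax`: XLA parses XLA_FLAGS at CPU backend
# initialization, not on every operation.
os.environ["XLA_FLAGS"] = ("--xla_cpu_multi_thread_eigen=false "
                           "intra_op_parallelism_threads=1")

# import jax  # safe to import only after the flags are in place
```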

@samuela
Contributor Author

samuela commented Jul 16, 2019

Awesome, that seems to do the trick! Thank you so much!

@mattjj mattjj closed this as completed Jul 16, 2019
@samuela
Contributor Author

samuela commented Sep 30, 2019

For my own (and others') future googling, my current approach looks like

from multiprocessing import get_context
import os

# Limit ourselves to single-threaded jax/xla operations to avoid thrashing. See
# https://github.com/google/jax/issues/743.
os.environ["XLA_FLAGS"] = ("--xla_cpu_multi_thread_eigen=false "
                           "intra_op_parallelism_threads=1")

def job(random_seed: int):
  # jax jax jax
  ...

if __name__ == "__main__":
  # See https://codewithoutrules.com/2018/09/04/python-multiprocessing/.
  with get_context("spawn").Pool() as pool:
    # Consume the iterator so the jobs finish before the pool shuts down.
    for _ in pool.imap_unordered(job, range(100)):
      pass

There may be a way better way, but it seems to work 🤷‍♀️

@Joshuaalbert
Contributor

@mattjj I don't see intra_op_parallelism_threads among the XLA_FLAGS options. Also, I still see the multi-threaded behaviour when trying your suggested combination of flags. Any chance that it would be possible to set an intra-op thread limit specifically somewhere?

@VolodyaCO

The solution above does not work for me on an M1 chip.

@VolodyaCO

Update: the following worked, keeping CPU usage low

import os

# Set everything before importing jax, so the flags are seen at backend init.
os.environ["XLA_FLAGS"] = (
    "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
)
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

import jax


def job():
    A = jax.random.normal(jax.random.PRNGKey(0), (3000, 3000))
    return A @ A @ A @ A


if __name__ == "__main__":
    for i in range(100):
        job()
        print("Done", i + 1)

@louis-shao

Hi, I'm trying to limit the computation of my xla compiled C++ code to run on one CPU core. I tried all the approaches here including the last one with 4 environment variables. The running time of the program is indeed slower after the change, but still quite a bit faster than adding taskset NN prefix to the command to explicitly bind to one core. Does that indicate that the environment variables don't fully limit it to one core?
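Pinning can also be done programmatically, which is what taskset does under the hood. A Linux-only sketch using the standard library (unlike the thread-count flags, which only cap XLA's intra-op pool, this confines the whole process, including any remaining runtime threads, to one core):

```python
import os

# Linux-only: restrict this process (pid 0 = the current process) to a
# single core from its currently allowed set -- the programmatic
# equivalent of launching under `taskset -c N`.
allowed = os.sched_getaffinity(0)
one_core = {min(allowed)}
os.sched_setaffinity(0, one_core)
print(os.sched_getaffinity(0) == one_core)  # -> True
```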

@dachengx

dachengx commented Oct 10, 2024

As @juehang pointed out, there is one possible fix: #22739. It also seems to perform better than using multiple cores in some circumstances.

export NPROC=1

related TODO: https://github.com/openxla/xla/blob/8d7f72345ce58ce959dcc99f2b1a13e5c672b3e9/xla/pjrt/utils.cc#L817
