Selectively enabling/disabling multithreading on CPU #6790
Labels: enhancement, P2 (eventual)
Hi!
By default, JAX uses multithreading when running on CPU. I have a use case where I'd like to disable this for part of my code (so that I can parallelize things myself in a more flexible way), while keeping it enabled everywhere else.
I'm aware that issues #743 and #1539 suggest setting the environment variable
`XLA_FLAGS="--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"`
to disable multithreading for the whole program. However, this is not flexible enough for me, because I'd like to enable or disable multithreading selectively. An ideal API for me would let me specify something along the lines of `intra_op_parallelism_threads=1` when calling `jax.jit`.
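For context, the whole-program workaround from #743 and #1539 can also be applied from inside a Python script, as long as it runs before JAX initializes its backend. This is a minimal sketch, assuming XLA parses `XLA_FLAGS` only once per process; `set_xla_cpu_flags` is a hypothetical helper, not part of JAX, and the flag string is copied verbatim from those issues.

```python
import os
from collections.abc import MutableMapping


def set_xla_cpu_flags(flags: str, env: MutableMapping[str, str] = os.environ) -> str:
    """Hypothetical helper: append `flags` to XLA_FLAGS in `env`, keeping any
    flags that are already set, and return the combined value.

    Assumption: XLA reads XLA_FLAGS once, when the backend is first
    initialized, so this must run before `import jax` to have any effect.
    """
    existing = env.get("XLA_FLAGS", "").strip()
    combined = " ".join(part for part in (existing, flags) if part)
    env["XLA_FLAGS"] = combined
    return combined


# Flag string quoted from issues #743 / #1539 (whole-program setting).
set_xla_cpu_flags("--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1")

# Import JAX only after the environment has been updated, e.g.:
# import jax
```

Because the flags are fixed for the lifetime of the process, this still applies to every jitted function, which is exactly the limitation this request is about.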