Hi, I know this is probably more of an issue on the JAX side and has been discussed there (e.g. jax-ml/jax#743, jax-ml/jax#1539 and jax-ml/jax#6790), but I'm still wondering whether you know how to limit the number of threads used by JAX. The snippet below shows that JAX currently does not respect the threadpool limits.
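```python
import jax.numpy as jnp
from threadpoolctl import threadpool_limits

ja = jnp.ones((1000, 1000))

# Ask threadpoolctl to limit the native thread pools it knows about to 5 threads.
with threadpool_limits(5):
    for _ in range(100):
        # JAX dispatches asynchronously; block so the work runs inside the `with`.
        foo = (ja @ ja).block_until_ready()
```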
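One way to observe the problem directly, instead of watching a CPU monitor, is to count the process's OS-level threads. A minimal sketch, assuming Linux, where every thread of the current process (including threads started by native libraries, which Python's `threading` module does not track) appears as an entry under `/proc/self/task`. The helper name `count_native_threads` is hypothetical, not part of JAX or threadpoolctl.

```python
import os
import threading


def count_native_threads() -> int:
    """Return the number of OS-level threads in this process (Linux only).

    Each thread, including ones created by C/C++ libraries outside
    Python's `threading` module, has its own directory in /proc/self/task.
    """
    return len(os.listdir("/proc/self/task"))


if __name__ == "__main__":
    before = count_native_threads()

    # Start a few idle Python threads so the count visibly changes.
    stop = threading.Event()
    workers = [threading.Thread(target=stop.wait) for _ in range(3)]
    for w in workers:
        w.start()

    during = count_native_threads()

    stop.set()
    for w in workers:
        w.join()

    print(f"threads before: {before}, with 3 extra workers: {during}")
```

In the JAX snippet, calling `count_native_threads()` right after `(ja @ ja).block_until_ready()` inside the `threadpool_limits(5)` block gives a thread count that can be compared against the requested limit of 5.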
Hi @HerculesJack, according to this comment (jax-ml/jax#743 (comment)), the threading mechanism used by JAX is not one that threadpoolctl supports. It would be worth checking whether Eigen's threadpool exposes symbols that allow controlling its number of threads.
Note: if Eigen exposes well-defined symbols to inspect and control the number of threads in its threadpool, then the mechanism implemented in #137 should make it possible to add TensorFlow and JAX support to threadpoolctl.
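What threadpoolctl needs from a library is, in essence, a getter and a setter for the size of its thread pool: with both, it can record the current value, apply the limit on entry, and restore the original value on exit. The sketch below illustrates that contract in plain Python. It is not threadpoolctl's actual code, and `FakeEigenPool`, `ToyController` and `limit_threads` are hypothetical stand-ins; in a real controller, the getter and setter would call symbols exported by the shared library (e.g. through `ctypes`).

```python
from contextlib import contextmanager
from typing import Iterator, List


class FakeEigenPool:
    """Hypothetical stand-in for a native thread pool with a resizable size."""

    def __init__(self, num_threads: int) -> None:
        self.num_threads = num_threads


class ToyController:
    """Wraps the two operations a controller must expose: get and set."""

    def __init__(self, pool: FakeEigenPool) -> None:
        self._pool = pool

    def get_num_threads(self) -> int:
        # Real case: call a symbol such as a hypothetical `eigen_get_num_threads`.
        return self._pool.num_threads

    def set_num_threads(self, n: int) -> None:
        # Real case: call a matching (hypothetical) setter symbol.
        self._pool.num_threads = n


@contextmanager
def limit_threads(controllers: List[ToyController], limit: int) -> Iterator[None]:
    """Set every controlled pool to `limit` threads, restoring them on exit."""
    saved = [c.get_num_threads() for c in controllers]
    try:
        for c in controllers:
            c.set_num_threads(limit)
        yield
    finally:
        # Restore the original sizes even if the body raised.
        for c, original in zip(controllers, saved):
            c.set_num_threads(original)


pool = FakeEigenPool(num_threads=16)
ctrl = ToyController(pool)
with limit_threads([ctrl], 5):
    print(pool.num_threads)  # 5 inside the block
print(pool.num_threads)      # 16 again afterwards
```

Without a getter, the original value could not be restored; without a setter, the limit could not be applied at all, which is why both symbols are needed.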