JAX hangs when allocating array of at least 2^19 bytes on CPU in subprocess #18852
Comments
I'm not very familiar with … I think you're likely to have better luck with either the …
Ah, that makes sense. Changing the … does indeed resolve the error. How sensible would it be to try to proactively detect misuse like this? The way this problem manifests is pretty non-obvious, IMO. It looks like … Thanks!
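Assuming the change referred to here is the `multiprocessing` start method (the exact setting is not named in this thread), a minimal sketch of running the CPU-side work in a child started with "spawn" rather than Linux's default "fork" might look like this. `cpu_worker`, `spawn_context`, and `run_cpu_job` are hypothetical names, using the `JAX_PLATFORMS` environment variable to pin the child to the CPU is an assumption, and a plain `bytearray` stands in for the JAX computation so the sketch runs without JAX.

```python
import multiprocessing as mp
import os


def cpu_worker(n_bytes: int) -> int:
    """Work done in the child process (hypothetical stand-in).

    In the issue, the child runs a JAX computation on the CPU; here a plain
    bytearray of the same size is allocated so the sketch runs without JAX.
    """
    # Assumption: select the CPU backend via JAX_PLATFORMS before jax is
    # first imported in this freshly started interpreter.
    os.environ["JAX_PLATFORMS"] = "cpu"
    # With JAX installed, the real work would go here, e.g.:
    #   import jax.numpy as jnp
    #   x = jnp.zeros(n_bytes // 4, dtype=jnp.float32)
    buf = bytearray(n_bytes)
    return len(buf)


def spawn_context():
    # "spawn" starts a fresh interpreter instead of fork()-ing the parent, so
    # the child does not inherit the parent's already-running threads or any
    # locks those threads happened to hold at the moment of the fork.
    return mp.get_context("spawn")


def run_cpu_job(n_bytes: int) -> int:
    # Only the spawned child does the CPU work; the parent (which may already
    # have initialized JAX on the GPU) just waits for the result.
    with spawn_context().Pool(processes=1) as pool:
        return pool.apply(cpu_worker, (n_bytes,))
```

Under "spawn" the child re-imports the main module, so a call such as `run_cpu_job(2**19)` belongs inside an `if __name__ == "__main__":` block. "forkserver" avoids inheriting the parent's threads in a similar way.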
#18989 adds a warning if …
Thanks!
Prevents possible interference with JAX, see: jax-ml/jax#18852
Description
Suppose I have one main process that's running JAX computations on the GPU. That process spawns a subprocess and uses the CPU as the default device for the subprocess. A JAX computation in the CPU process that tries to allocate an array of 2^19 or more bytes hangs indefinitely. However, this does not occur if the main process never initializes JAX on the GPU (i.e. commenting out the `print(jnp.array(1).device())` line below allows the whole CPU computation to succeed). Here's code that reproduces the problem:
And here's output from a run:
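For scale, 2^19 bytes is 524,288 bytes (512 KiB). With 4-byte float32 elements that is 131,072 elements, for example a `(512, 256)` array. The small stdlib helper below checks a shape against that cutoff. It is illustrative only, and the names `THRESHOLD_BYTES`, `array_nbytes`, and `reaches_threshold` are not from the report.

```python
import math
from typing import Sequence

# Size reported in the issue: allocations of at least 2**19 bytes hung.
THRESHOLD_BYTES = 2**19  # 524,288 bytes (512 KiB)


def array_nbytes(shape: Sequence[int], itemsize: int) -> int:
    """Bytes in a dense array: element count times bytes per element."""
    return math.prod(shape) * itemsize


def reaches_threshold(shape: Sequence[int], itemsize: int) -> bool:
    return array_nbytes(shape, itemsize) >= THRESHOLD_BYTES


# float32 uses 4 bytes per element, so the cutoff is 2**19 // 4 = 131,072 elements.
print(array_nbytes((512, 256), 4))       # 524288
print(reaches_threshold((512, 256), 4))  # True
print(reaches_threshold((512, 255), 4))  # False
```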
What jax/jaxlib version are you using?
0.4.20 for both
Which accelerator(s) are you using?
CPU and GPU
Additional system info?
Numpy: 1.26.1 Sys: 3.11.6 (main, Oct 23 2023, 22:48:54) [GCC 11.4.0] Uname: uname_result(system='Linux', node='np9fqenf2u', release='5.19.0-45-generic', version='#46~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Wed Jun 7 15:06:04 UTC 20', machine='x86_64')
NVIDIA GPU info