JAX hangs when allocating array of at least 2^19 bytes on CPU in subprocess #18852

Closed
colehaus opened this issue Dec 6, 2023 · 4 comments · Fixed by #18989
Labels
bug Something isn't working

colehaus commented Dec 6, 2023

Description

Suppose I have one main process that's running JAX computations on the GPU. That process spawns a subprocess and uses the CPU as the default device for the subprocess. A JAX computation in the CPU process which tries to allocate an array with 2^19 or more bytes hangs indefinitely. However, this does not occur if the main process never initializes JAX on the GPU (i.e. commenting out the print(jnp.array(1).device()) below allows the whole CPU computation to succeed).

Here's code reproducing the issue:

import logging
import sys
from multiprocessing import Process, current_process

import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def fn():
    test = jnp.ones((2**18 - 1,), dtype=np.int16)
    jax.debug.print("{x} {y}", x=test, y=test.nbytes)
    test = jnp.ones((2**17 - 1,), dtype=np.int32)
    jax.debug.print("{x} {y}", x=test, y=test.nbytes)
    test = jnp.ones((2**17 - 1,), dtype=np.float32)
    jax.debug.print("{x} {y}", x=test, y=test.nbytes)
    # Lines above this point print. Lines below do not, regardless of
    # ordering: all of them hang.
    test = jnp.ones((2**17,), dtype=np.float32)
    jax.debug.print("{x} {y}", x=test, y=test.nbytes)
    test = jnp.ones((2**18,), dtype=np.int16)
    jax.debug.print("{x} {y}", x=test, y=test.nbytes)
    test = jnp.ones((2**17,), dtype=np.int32)
    jax.debug.print("{x} {y}", x=test, y=test.nbytes)


def worker():
    print(current_process().name)
    logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True)
    with jax.default_device(jax.devices("cpu")[0]):
        fn()


if __name__ == "__main__":
    # If this is commented out, the code runs to completion
    print(jnp.array(1).device())
    Process(target=worker).start()

And here's output from a run:

2023-12-06 22:23:45.259725: W external/xla/xla/service/gpu/nvptx_compiler.cc:679] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.3.103). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
cuda:0
Process-1
DEBUG:jax._src.dispatch:Finished tracing + transforming fn for pjit in 0.005032062530517578 sec
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types []. Argument mapping: ().
DEBUG:jax._src.dispatch:Finished jaxpr to MLIR module conversion jit(fn) in 0.008342504501342773 sec
DEBUG:jax._src.compiler:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.compiler:get_compile_options XLA-AutoFDO profile: using XLA-AutoFDO profile version -1
DEBUG:jax._src.dispatch:Finished XLA compilation of jit(fn) in 0.039002180099487305 sec
[1 1 1 ... 1 1 1] 524286
[1 1 1 ... 1 1 1] 524284
[1. 1. 1. ... 1. 1. 1.] 524284
<hangs indefinitely at this point>

What jax/jaxlib version are you using?

0.4.20 for both

Which accelerator(s) are you using?

CPU and GPU

Additional system info?

Numpy: 1.26.1 Sys: 3.11.6 (main, Oct 23 2023, 22:48:54) [GCC 11.4.0] Uname: uname_result(system='Linux', node='np9fqenf2u', release='5.19.0-45-generic', version='#46~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Wed Jun 7 15:06:04 UTC 20', machine='x86_64')

NVIDIA GPU info

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.116.04   Driver Version: 525.116.04   CUDA Version: 12.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA RTX A6000    Off  | 00000000:00:05.0 Off |                  Off |
| 30%   29C    P8    25W / 300W |      1MiB / 49140MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
colehaus added the bug (Something isn't working) label on Dec 6, 2023
hawkinsp (Collaborator) commented Dec 6, 2023

I'm not very familiar with multiprocessing, but if Process uses fork() to start the subprocess, this is expected. fork() is incompatible with multithreading, and JAX is always multithreaded.

I think you're likely to have better luck with either the spawn or forkserver approaches.
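A minimal sketch of that suggestion (the worker body here is a stand-in; in the real reproduction JAX would be imported and used inside the child):

```python
import multiprocessing


def worker() -> int:
    # In the real case, JAX would be imported here, inside the child:
    # a "spawn"ed child is a fresh interpreter, so it builds its own
    # thread pools and backends instead of inheriting a fork()ed copy
    # of the parent's multithreaded state.
    return 2**17 * 4  # stand-in for the 512 KiB allocation that hung


if __name__ == "__main__":
    # "spawn" (or "forkserver") starts a new interpreter rather than
    # fork()ing the multithreaded parent, avoiding the deadlock.
    ctx = multiprocessing.get_context("spawn")
    p = ctx.Process(target=worker)
    p.start()
    p.join()
```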

colehaus (Author) commented Dec 6, 2023

Ah, that makes sense. Changing the Process line to:

multiprocessing.get_context("spawn").Process(target=worker).start()

does indeed resolve the hang.

How sensible would it be to try to proactively detect misuse like this? The way this problem manifests is pretty non-obvious IMO. It looks like multiprocessing.get_start_method() does generally reflect how the subprocess was started (this presumably wouldn't be perfect but would perhaps be better than nothing).
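A rough sketch of such a check, assuming `multiprocessing.get_start_method()` and `multiprocessing.parent_process()` (Python 3.8+) behave as documented; the helper name is hypothetical:

```python
import multiprocessing


def started_via_risky_fork() -> bool:
    # Hypothetical helper: True when running in a multiprocessing child
    # that was created with the "fork" start method. parent_process()
    # is None in the main process, so this never fires there.
    # get_start_method() reports the configured start method which, as
    # noted above, may not be perfect but is better than nothing.
    in_child = multiprocessing.parent_process() is not None
    return in_child and multiprocessing.get_start_method(allow_none=True) == "fork"
```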

Thanks!

hawkinsp (Collaborator) commented
#18989 adds a warning if os.fork() is called!

colehaus (Author) commented
Thanks!

hawkinsp added a commit to hawkinsp/jax that referenced this issue Dec 15, 2023
nico-bohlinger added a commit to nico-bohlinger/RL-X that referenced this issue Feb 25, 2024
Prevents possible interference with JAX, see: jax-ml/jax#18852