-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Install fork() warning during backend initialization, rather than jax… #20734
Conversation
… import. This avoids warning people making an incidental import of JAX.
+1 - the existing behavior is incredibly disruptive to my team right now. Because of how our test suite works every other line is one of these warnings, and we only use jax in a single place in the codebase. The sooner a release can be made with this fix the better! |
@apmorton I'd suggest just downgrading to the previous version or using the next JAX nightly for the moment. (JAX by itself is enough, jaxlib doesn't have to be a nightly.) |
I think this feature of warning on each fork attempt is too aggressive as it is triggered also on valid os.fork use cases in a program that may import jax. Notice that the feature here also duplicates Python 3.12 os.fork feature that will raise So, I suggest to disable |
I don't think so! If JAX's runtime has been initialized, then we have immediately created a threadpool. So you can no longer safely call (From C++: maybe, if you're really careful about what you do, but then the Python
My concern is that |
Ok, just to clarify, is the following program unsafe?: >>> from multiprocessing import Pool
>>> def f(x): return x * x
...
>>> import jax
>>> jax.devices()
[CudaDevice(id=0), CudaDevice(id=1)]
>>> with Pool(5) as p:
... print(p.map(f, [1, 2, 3]))
...
multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
self.pid = os.fork()
[1, 4, 9] that is, os.fork is called independently from jax. |
Yes, it is unsafe. You must not use the But, for example, this works fine:
|
Yes, I was aware of different start methods in multiprocessing. Note that my original example with the fork start method works fine as well (there is no deadlock), just with the warning message. So there seems to be some slack in the "You must not use ..." statement which I was looking for to clarify. But it is not a big deal and thanks for all your feedback! |
Well, it's a warning. To give an analogy, I'm just telling you you're letting a bunch of angry Komodo dragons out of their cage. They may or may not bite you any given day... :-) How might this go wrong? Suppose your |
… import.
This avoids warning people making an incidental import of JAX.