Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Install fork() warning during backend initialization, rather than jax… #20734

Merged
merged 1 commit into from
Apr 15, 2024

Conversation

hawkinsp
Copy link
Collaborator

… import.

This avoids warning people making an incidental import of JAX.

… import.

This avoids warning people making an incidental import of JAX.
@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Apr 12, 2024
@apmorton
Copy link

+1 - the existing behavior is incredibly disruptive to my team right now.

Because of how our test suite works every other line is one of these warnings, and we only use jax in a single place in the codebase.

The sooner a release can be made with this fix the better!

@copybara-service copybara-service bot merged commit b9a853d into jax-ml:main Apr 15, 2024
13 checks passed
@hawkinsp
Copy link
Collaborator Author

@apmorton I'd suggest just downgrading to the previous version or using the next JAX nightly for the moment. (JAX by itself is enough, jaxlib doesn't have to be a nightly.)

@pearu
Copy link
Collaborator

pearu commented Sep 18, 2024

I think this feature of warning on each fork attempt is too aggressive as it is triggered also on valid os.fork use cases in a program that may import jax.

Notice that the feature here also duplicates Python 3.12 os.fork feature that will raise DeprecationWarning when one attempts to fork from a process with multiple threads.

So, I suggest to disable _at_fork handler for Python 3.12 or newer, or implement https://github.com/gpshead/cpython/blob/f10ece00f86f084c099c38f3f32fac2cd0a4d3ee/Modules/posixmodule.c#L6757 method within _at_fork. What do you think?

^ @hawkinsp

@hawkinsp
Copy link
Collaborator Author

I think this feature of warning on each fork attempt is too aggressive as it is triggered also on valid os.fork use cases in a program that may import jax.

I don't think so! If JAX's runtime has been initialized, then we have immediately created a threadpool. So you can no longer safely call fork() from Python.

(From C++: maybe, if you're really careful about what you do, but then the Python atfork() handler isn't relevant.)

Notice that the feature here also duplicates Python 3.12 os.fork feature that will raise DeprecationWarning when one attempts to fork from a process with multiple threads.

My concern is that DeprecationWarning is likely to be suppressed by the user. If you see this warning, it's a pretty high priority thing: it says: "your program is likely about to crash or hang".

@pearu
Copy link
Collaborator

pearu commented Sep 18, 2024

Ok, just to clarify, is the following program unsafe?:

>>> from multiprocessing import Pool
>>> def f(x): return x * x
... 
>>> import jax
>>> jax.devices()
[CudaDevice(id=0), CudaDevice(id=1)]
>>> with Pool(5) as p:
...   print(p.map(f, [1, 2, 3]))
... 
multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = os.fork()
[1, 4, 9]

that is, os.fork is called independently from jax.

@hawkinsp
Copy link
Collaborator Author

Yes, it is unsafe. You must not use the fork strategy in multiprocessing after JAX has been initialized.

But, for example, this works fine:

import multiprocessing
import jax

def f(x): return x * x

if __name__ == "__main__":
  multiprocessing.set_start_method("spawn")
  print(jax.devices())
  with multiprocessing.Pool(5) as p:
    print(p.map(f, [1, 2, 3]))

@pearu
Copy link
Collaborator

pearu commented Sep 18, 2024

You must not use the fork strategy in multiprocessing after JAX has been initialized.
But, for example, this works fine: ...

Yes, I was aware of different start methods in multiprocessing. Note that my original example with the fork start method works fine as well (there is no deadlock), just with the warning message. So there seems to be some slack in the "You must not use ..." statement which I was looking for to clarify. But it is not a big deal and thanks for all your feedback!

@hawkinsp
Copy link
Collaborator Author

Well, it's a warning. To give an analogy, I'm just telling you you're letting a bunch of angry Komodo dragons out of their cage. They may or may not bite you any given day... :-)

How might this go wrong? Suppose your fork() happened while another thread in JAX held a mutex. You might, say, deadlock on shutdown if we needed to acquire that mutex to shut down.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants