Initialize jax distributed when checkpointing is enabled #895
Nightly tests are failing due to jax.distributed not being initialized in the synchronous checkpointing case.
Thanks Jon!
@gobbleturk Is there a way to detect a multihost environment? I see GPU and inference tests failing due to this change.
We have to initialize the jax distributed system before the runtime backend, so it doesn't seem this can be inferred. However, we could add a boolean config option to MaxText, something like "is_single_host" or "should_initialize_jax_distributed_system". wdyt?
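A rough sketch of what such a flag could look like. The `Config` class here is a hypothetical stand-in for the MaxText config, `enable_checkpointing` is an assumed field name, and the helper names are made up for illustration; only `jax.distributed.initialize()` is a real JAX call.

```python
from dataclasses import dataclass


@dataclass
class Config:
    """Hypothetical stand-in for the MaxText config object."""
    enable_checkpointing: bool = True  # assumed field name
    is_single_host: bool = False       # one of the option names proposed above


def should_initialize_jax_distributed(config: Config) -> bool:
    # Initialize only when checkpointing is on and the user has not
    # declared the run single-host. This reads config only, so it never
    # touches the runtime backend.
    return config.enable_checkpointing and not config.is_single_host


def maybe_initialize_jax_distributed(config: Config) -> bool:
    """Call jax.distributed.initialize() if the config asks for it."""
    if not should_initialize_jax_distributed(config):
        return False
    import jax  # imported lazily so the decision logic stays backend-free
    jax.distributed.initialize()
    return True
```

With a flag like this, the single-host GPU and inference tests could set `is_single_host=True` and skip initialization entirely.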
Do we still need to initialize the JDI even for single host? We may want to support supplying the correct args in this case: process_id=0, num_processes=1, coordinator_ip=get_own_ip().
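Something like the following might work for the single-host case. This is only a sketch: `get_own_ip` is a hypothetical helper (here it resolves the local hostname and falls back to loopback), the port is an arbitrary placeholder, and the real `jax.distributed.initialize` argument is `coordinator_address` (a `host:port` string) rather than `coordinator_ip`.

```python
import socket


def get_own_ip() -> str:
    """Hypothetical helper: best-effort local IP, falling back to loopback."""
    try:
        return socket.gethostbyname(socket.gethostname())
    except OSError:
        return "127.0.0.1"


def single_host_init_kwargs(port: int = 1234) -> dict:
    # Arguments for a one-process "cluster": this process is both the
    # coordinator and the only participant. The default port is a placeholder.
    return {
        "coordinator_address": f"{get_own_ip()}:{port}",
        "num_processes": 1,
        "process_id": 0,
    }


def initialize_single_host(port: int = 1234) -> None:
    import jax  # imported lazily; must run before the backend is initialized
    jax.distributed.initialize(**single_host_init_kwargs(port))
```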
For the TPU tests, perhaps I should try to fix the backend being initialized: the error is in pyconfig, so I don't expect the backend to be up. The GPU failures can be fixed by setting the coordinator address in the environment, but that does seem like overkill for single-host...