Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initialize jax distributed when checkpointing is enabled #895

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

jonb377
Copy link
Collaborator

@jonb377 jonb377 commented Sep 16, 2024

Nightly tests are failing due to jax.distributed not being initialized in the synchronous checkpointing case.

Nightly tests are failing due to jax.distributed not being initialized in the synchronous checkpointing case.
Copy link
Collaborator

@gobbleturk gobbleturk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Jon!

@jonb377
Copy link
Collaborator Author

jonb377 commented Sep 19, 2024

@gobbleturk Is there a way to detect a multihost environment? I see GPU and inference tests failing due to this change.

@gobbleturk
Copy link
Collaborator

@gobbleturk Is there a way to detect a multihost environment? I see GPU and inference tests failing due to this change.

We have to initialize the jax distributed system before the runtime backend, so it doesn't seem this can be inferred. However we can create an additional option to MaxText - either something like "is_single_host" or "should_initialize_jax_distributed_system" in the config as a boolean option, wdyt?

@gobbleturk
Copy link
Collaborator

@gobbleturk Is there a way to detect a multihost environment? I see GPU and inference tests failing due to this change.

We have to initialize the jax distributed system before the runtime backend, so it doesn't seem this can be inferred. However we can create an additional option to MaxText - either something like "is_single_host" or "should_initialize_jax_distributed_system" in the config as a boolean option, wdyt?

Do we still need to initialize the JDI even for single host? We may want to support supplying correct args in this case - process_id = 0, num_processes=1 coordinator_ip = get_own_ip()

@jonb377
Copy link
Collaborator Author

jonb377 commented Sep 19, 2024

For the TPU tests, perhaps I should try to fix the backend being initialized - the error is in pyconfig, so I don't expect the backend to be up.

The GPU failures can be fixed by setting the coordinator address in the environment. But that does seem overkill for single-host...

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants