diff --git a/axlearn/cloud/gcp/tpu_health_check.py b/axlearn/cloud/gcp/tpu_health_check.py index 9424d8eef..99e21e84f 100644 --- a/axlearn/cloud/gcp/tpu_health_check.py +++ b/axlearn/cloud/gcp/tpu_health_check.py @@ -15,11 +15,13 @@ shortest timeout. Pairwise health check should have the longest timeout since different slices may bring up their container at different times. -The main API is the `health_check` function, which is commonly enabled via context manager: - -with health_check(spec, output_dir=...): +The main API is the `setup` function, which is commonly enabled via context manager: +``` +with setup(spec, output_dir=...): # Initialize jax distributed. +``` """ + import os import signal import subprocess @@ -31,7 +33,6 @@ from typing import Literal, Optional, Union import tensorflow as tf -import tensorflow_io # pylint: disable=unused-import from absl import logging from axlearn.cloud.gcp import tpu_health_check_main @@ -127,7 +128,7 @@ def _run_health_check_program( @contextmanager -def health_check(check_spec: str, *, output_dir: str): +def setup(check_spec: str, *, output_dir: str): _pre_init_health_check(check_spec, output_dir=output_dir) yield # Skip global health check if there's an exception. diff --git a/axlearn/common/launch.py b/axlearn/common/launch.py index 31532be2a..7e8ebfa27 100644 --- a/axlearn/common/launch.py +++ b/axlearn/common/launch.py @@ -2,16 +2,17 @@ """A library with common flags to launch a trainer.""" +import contextlib import importlib import os import sys -# pylint: disable=wrong-import-position,wrong-import-order -from contextlib import nullcontext - # pylint: disable-next=ungrouped-imports from axlearn.common import compiler_options +# pylint: disable=wrong-import-position,wrong-import-order + + instance_type = os.environ.get("TPU_TYPE", "none") num_tpu_slices = int(os.environ.get("NUM_TPU_SLICES", 1)) @@ -77,36 +78,40 @@ os.environ.get("PROCESS_ID", None), "Rank of the current process. Must be None on tpu, otherwise required.", ) -flags.DEFINE_string( - "health_check_module", - None, - "Path to health check module to run, e.g. axlearn.cloud.gcp.tpu_health_check. " - "Defaults to None, meaning no health check will run.", -) -flags.DEFINE_string( - "health_check_spec", - "", - "See the docstring of your `health_check_module`.", +flags.DEFINE_multi_string( + "init_module", + [], + "Zero or more init modules to import prior to setting up JAX distributed. " + "Each flag value should be a string containing 'module_path' or 'module_path:spec', e.g. " + "'axlearn.cloud.gcp.tpu_health_check' or 'axlearn.cloud.gcp.tpu_health_check:output_dir=...'.\n" + "The module should expose a public function `setup`, a context manager exposing pre- and post-" + "SPMD setup logic which is entered prior to `setup_spmd` and exited immediately afterwards.\n" + "The spec (if provided) will be provided to `module.setup(spec)` and therefore can be " + "implementation dependent. Not specifying a spec is equivalent to passing `None` to `setup`.\n" + "If specifying multiple modules, each `setup` context is entered in the given order.", ) FLAGS = flags.FLAGS +# Kept separate for easier testing. +@contextlib.contextmanager +def _init_context(fv: flags.FlagValues = FLAGS): + with contextlib.ExitStack() as ctx: + for module_spec in fv.init_module: + parts = module_spec.split(":", maxsplit=1) + [None] + module, spec = parts[:2] + ctx.enter_context(importlib.import_module(module).setup(spec)) + yield + + def setup(): if tpu_flags_exc is not None: logging.info("LIBTPU_INIT_FLAGS was not set. Reason: %s", tpu_flags_exc) else: logging.info("LIBTPU_INIT_ARGS='%s'", os.environ["LIBTPU_INIT_ARGS"]) - if FLAGS.health_check_module: - health_check = importlib.import_module(FLAGS.health_check_module).health_check( - FLAGS.health_check_spec, - output_dir=FLAGS.trainer_dir, - ) - else: - health_check = nullcontext() - - with health_check: + with _init_context(): setup_spmd( distributed_coordinator=FLAGS.distributed_coordinator, num_processes=FLAGS.num_processes, diff --git a/axlearn/common/launch_test.py b/axlearn/common/launch_test.py new file mode 100644 index 000000000..4cb87dd37 --- /dev/null +++ b/axlearn/common/launch_test.py @@ -0,0 +1,61 @@ +# Copyright © 2024 Apple Inc. + +"""Tests launch utils.""" + +import contextlib +from typing import Optional +from unittest import mock + +from absl import flags +from absl.testing import parameterized + +from axlearn.common.launch import _init_context +from axlearn.common.test_utils import TestCase + + +class TestInitContext(TestCase): + """Tests _init_context.""" + + @parameterized.parameters( + dict(value=["my.module.path"], expect={"my.module.path": None}), + dict( + value=["my.module.path:k1=v1,k2=v2"], + expect={"my.module.path": "k1=v1,k2=v2"}, + ), + dict( + value=["my.module.path:k1:v1"], + expect={"my.module.path": "k1:v1"}, + ), + dict( + value=["my.module.path:k1:v1", "my.other.module:k2:v2,k3:v3"], + expect={ + "my.module.path": "k1:v1", + "my.other.module": "k2:v2,k3:v3", + }, + ), + ) + def test_init_context(self, value, expect: dict[str, Optional[str]]): + fv = flags.FlagValues() + flags.DEFINE_multi_string("init_module", value, "", flag_values=fv) + fv.mark_as_parsed() + + with mock.patch("importlib.import_module") as mock_import: + with _init_context(fv): + for i, k in enumerate(expect): + self.assertEqual(k, mock_import.call_args_list[i][0][0]) + + side_effect = [] + actual_specs = [] + for _ in range(len(value)): + + @contextlib.contextmanager + def mock_setup(actual): + actual_specs.append(actual) + yield + + mock_module = mock.Mock(**{"setup.side_effect": mock_setup}) + side_effect.append(mock_module) + + with mock.patch("importlib.import_module", side_effect=side_effect): + with _init_context(fv): + self.assertEqual(list(expect.values()), actual_specs)