diff --git a/axlearn/common/checkpointer_orbax.py b/axlearn/common/checkpointer_orbax.py index befde370..26c84d57 100644 --- a/axlearn/common/checkpointer_orbax.py +++ b/axlearn/common/checkpointer_orbax.py @@ -8,7 +8,6 @@ import asyncio import copy import dataclasses -import functools import os from concurrent import futures from typing import Any, Dict, List, Optional, Sequence, Tuple, Union @@ -63,6 +62,10 @@ class RestoreArgs(ocp.type_handlers.RestoreArgs): def typestr(self) -> str: return "TfIterator" + def _ckpt_dir(self, info: ocp.type_handlers.ParamInfo) -> str: + # Each worker writes its tf checkpoints under a different path. + return os.path.join(info.parent_dir, f"tf_{jax.process_index()}") + async def serialize( self, values: Sequence[tf.data.Iterator], @@ -74,7 +77,11 @@ async def serialize( futs = [] with futures.ThreadPoolExecutor(max_workers=1) as executor: for value, info in zip(values, infos): - futs.append(async_save_tf_savables(value, executor=executor, dir=info.path)) + futs.append( + async_save_tf_savables( + {info.name: value}, executor=executor, dir=self._ckpt_dir(info) + ) + ) return futs async def deserialize( @@ -84,14 +91,16 @@ async def deserialize( ) -> Sequence[tf.data.Iterator]: if args is None: raise ValueError(f"{self.RestoreArgs.__name__} should be supplied as args.") + futs = [] with futures.ThreadPoolExecutor(max_workers=1) as executor: - futs = [ - asyncio.get_event_loop().run_in_executor( - executor, - functools.partial(restore_tf_savables, arg.item, dir=info.path), - ) - for arg, info in zip(args, infos) - ] + for arg, info in zip(args, infos): + + def restore(arg=arg, info=info): + return restore_tf_savables({info.name: arg.item}, dir=self._ckpt_dir(info))[ + info.name + ] + + futs.append(asyncio.get_event_loop().run_in_executor(executor, restore)) return await asyncio.gather(*futs) async def metadata( @@ -115,6 +124,10 @@ class RestoreArgs(ocp.type_handlers.RestoreArgs): def typestr(self) -> str: return "DatasetIterator" + 
def _ckpt_dir(self, info: ocp.type_handlers.ParamInfo) -> str: + # Each worker writes its grain checkpoints under a different path. + return os.path.join(info.parent_dir, f"grain_{jax.process_index()}") + async def serialize( self, values: Sequence[grain.DatasetIterator], @@ -124,9 +137,7 @@ async def serialize( """Serializes `values` into corresponding `info.path`s.""" del args # Unused. for value, info in zip(values, infos): - ckpt_dir = os.path.dirname(info.path) - path = os.path.basename(info.path) - maybe_save_grain_savables({path: value}, dir=ckpt_dir) + maybe_save_grain_savables({info.name: value}, dir=self._ckpt_dir(info)) return [] async def deserialize( self, infos: Sequence[ocp.type_handlers.ParamInfo], args: Optional[Sequence[RestoreArgs]] = None, ) -> Sequence[_GrainIterator]: if args is None: raise ValueError(f"{self.RestoreArgs.__name__} should be supplied as args.") - return [ - maybe_restore_grain_savables(arg.item, dir=info.path) - for arg, info in zip(args, infos) - ] + ret = [] + for arg, info in zip(args, infos): + ret.append( + maybe_restore_grain_savables({info.name: arg.item}, dir=self._ckpt_dir(info))[ + info.name + ] + ) + return ret async def metadata( self, infos: Sequence[ocp.type_handlers.ParamInfo] diff --git a/axlearn/common/checkpointer_test.py b/axlearn/common/checkpointer_test.py index 8678692a..7e348530 100644 --- a/axlearn/common/checkpointer_test.py +++ b/axlearn/common/checkpointer_test.py @@ -431,9 +431,18 @@ def test_input_iterator(self, checkpointer_cls): input_iter=input_iter, ) + self.assertEqual([], os.listdir(cfg.dir)) + ckpt.save(step=100, state=state0) ckpt.wait_until_finished() + # Check that input iterators are saved under a per-worker path. + # E.g., /path/to/[state/]tf_0/input_iter.index. 
+ state_dir = ckpt.ckpt_dir(100) + if "state" in os.listdir(state_dir): + state_dir = os.path.join(state_dir, "state") + self.assertIn("tf_0", os.listdir(state_dir)) + state0_specs = dict( x=utils.TensorSpec(shape=[], dtype=jnp.int32), # The same iterator, but with the position at 0. @@ -469,9 +478,18 @@ def test_grain(self, checkpointer_cls): self.assertEqual(next(ds), 1) state0 = dict(x=jnp.ones([3, 2]), y=ds) + self.assertEqual([], os.listdir(cfg.dir)) + ckpt.save(step=100, state=state0) ckpt.wait_until_finished() + # Check that input iterators are saved under a per-worker path. + # E.g., /path/to/[state/]grain_0/input_iter.index. + state_dir = ckpt.ckpt_dir(100) + if "state" in os.listdir(state_dir): + state_dir = os.path.join(state_dir, "state") + self.assertIn("grain_0", os.listdir(state_dir)) + state0_specs = dict( x=utils.TensorSpec(shape=[3, 2], dtype=jnp.float32), # The same iterator, but with the position at 0.