Skip to content

Commit

Permalink
Ensure iterators are saved in per-process dir. (#846)
Browse files Browse the repository at this point in the history
* Ensure iterators are saved in per-process dir.

* Move dir handling to method.
  • Loading branch information
markblee authored Nov 18, 2024
1 parent 4bf92fa commit 2335d13
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 16 deletions.
47 changes: 31 additions & 16 deletions axlearn/common/checkpointer_orbax.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import asyncio
import copy
import dataclasses
import functools
import os
from concurrent import futures
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
Expand Down Expand Up @@ -63,6 +62,10 @@ class RestoreArgs(ocp.type_handlers.RestoreArgs):
def typestr(self) -> str:
return "TfIterator"

def _ckpt_dir(self, info: ocp.type_handlers.ParamInfo) -> str:
    """Returns the directory in which this process stores tf iterator checkpoints.

    Args:
        info: Param info for the iterator; only `info.parent_dir` is read.

    Returns:
        `info.parent_dir` joined with a `tf_<process index>` subdirectory, so that each
        worker writes its tf checkpoints under a different path.
    """
    parent_dir = info.parent_dir
    return os.path.join(parent_dir, f"tf_{jax.process_index()}")

async def serialize(
self,
values: Sequence[tf.data.Iterator],
Expand All @@ -74,7 +77,11 @@ async def serialize(
futs = []
with futures.ThreadPoolExecutor(max_workers=1) as executor:
for value, info in zip(values, infos):
futs.append(async_save_tf_savables(value, executor=executor, dir=info.path))
futs.append(
async_save_tf_savables(
{info.name: value}, executor=executor, dir=self._ckpt_dir(info)
)
)
return futs

async def deserialize(
Expand All @@ -84,14 +91,16 @@ async def deserialize(
) -> Sequence[tf.data.Iterator]:
if args is None:
raise ValueError(f"{self.RestoreArgs.__name__} should be supplied as args.")
futs = []
with futures.ThreadPoolExecutor(max_workers=1) as executor:
futs = [
asyncio.get_event_loop().run_in_executor(
executor,
functools.partial(restore_tf_savables, arg.item, dir=info.path),
)
for arg, info in zip(args, infos)
]
for arg, info in zip(args, infos):

def restore(arg=arg, info=info):
return restore_tf_savables({info.name: arg.item}, dir=self._ckpt_dir(info))[
info.name
]

futs.append(asyncio.get_event_loop().run_in_executor(executor, restore))
return await asyncio.gather(*futs)

async def metadata(
Expand All @@ -115,6 +124,10 @@ class RestoreArgs(ocp.type_handlers.RestoreArgs):
def typestr(self) -> str:
return "DatasetIterator"

def _ckpt_dir(self, info: ocp.type_handlers.ParamInfo) -> str:
    """Returns the directory in which this process stores grain iterator checkpoints.

    Args:
        info: Param info for the iterator; only `info.parent_dir` is read.

    Returns:
        `info.parent_dir` joined with a `grain_<process index>` subdirectory, so that each
        worker writes its grain checkpoints under a different path.
    """
    base_dir = info.parent_dir
    return os.path.join(base_dir, f"grain_{jax.process_index()}")

async def serialize(
self,
values: Sequence[grain.DatasetIterator],
Expand All @@ -124,9 +137,7 @@ async def serialize(
"""Serializes `values` into corresponding `info.path`s."""
del args # Unused.
for value, info in zip(values, infos):
ckpt_dir = os.path.dirname(info.path)
path = os.path.basename(info.path)
maybe_save_grain_savables({path: value}, dir=ckpt_dir)
maybe_save_grain_savables({info.name: value}, dir=self._ckpt_dir(info))
return []

async def deserialize(
Expand All @@ -136,10 +147,14 @@ async def deserialize(
) -> Sequence[_GrainIterator]:
if args is None:
raise ValueError(f"{self.RestoreArgs.__name__} should be supplied as args.")
return [
maybe_restore_grain_savables(arg.item, dir=info.path)
for arg, info in zip(args, infos)
]
ret = []
for arg, info in zip(args, infos):
ret.append(
maybe_restore_grain_savables({info.name: arg.item}, dir=self._ckpt_dir(info))[
info.name
]
)
return ret

async def metadata(
self, infos: Sequence[ocp.type_handlers.ParamInfo]
Expand Down
18 changes: 18 additions & 0 deletions axlearn/common/checkpointer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,9 +431,18 @@ def test_input_iterator(self, checkpointer_cls):
input_iter=input_iter,
)

self.assertEqual([], os.listdir(cfg.dir))

ckpt.save(step=100, state=state0)
ckpt.wait_until_finished()

# Check that input iterators are saved under a per-worker path.
# E.g., /path/to/<step>/[state/]tf_0/input_iter.index.
state_dir = ckpt.ckpt_dir(100)
if "state" in os.listdir(state_dir):
state_dir = os.path.join(state_dir, "state")
self.assertIn("tf_0", os.listdir(state_dir))

state0_specs = dict(
x=utils.TensorSpec(shape=[], dtype=jnp.int32),
# The same iterator, but with the position at 0.
Expand Down Expand Up @@ -469,9 +478,18 @@ def test_grain(self, checkpointer_cls):
self.assertEqual(next(ds), 1)
state0 = dict(x=jnp.ones([3, 2]), y=ds)

self.assertEqual([], os.listdir(cfg.dir))

ckpt.save(step=100, state=state0)
ckpt.wait_until_finished()

# Check that input iterators are saved under a per-worker path.
# E.g., /path/to/<step>/[state/]grain_0/ (the directory for this worker's grain iterator state).
state_dir = ckpt.ckpt_dir(100)
if "state" in os.listdir(state_dir):
state_dir = os.path.join(state_dir, "state")
self.assertIn("grain_0", os.listdir(state_dir))

state0_specs = dict(
x=utils.TensorSpec(shape=[3, 2], dtype=jnp.float32),
# The same iterator, but with the position at 0.
Expand Down

0 comments on commit 2335d13

Please sign in to comment.