Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement FullSyncIterDataPipe #713

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions test/test_distributed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import os
import unittest

from unittest import TestCase

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.testing._internal.common_utils import instantiate_parametrized_tests, parametrize

from torchdata.datapipes.iter import IterableWrapper
from torchdata.datapipes.iter.util.prefetch import PrefetchTimeoutError

TEST_MASTER_ADDR = "127.0.0.1"
TEST_MASTER_PORT = "29500"
DEFAULT_WORLD_SIZE = 2


if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)


def launch_distributed_training(backend, world_size, fn):
os.environ["MASTER_ADDR"] = TEST_MASTER_ADDR
os.environ["MASTER_PORT"] = TEST_MASTER_PORT
mp.spawn(
fn,
args=(
world_size,
backend,
),
nprocs=world_size,
join=True,
)


class DistributedTest(TestCase):
@staticmethod
def _test_fullsync(rank, world_size, backend):
dist.init_process_group(backend, rank=rank, world_size=world_size)
# Use a prime number to make sure uneven data sharding
data_length = 23
dp = IterableWrapper(list(range(data_length))).sharding_filter()
torch.utils.data.graph_settings.apply_sharding(dp, world_size, rank)

dp1 = dp.fullsync()
for _ in range(2):
res = []
for d in dp1:
res.append(d)
# Simulate training synchronization
dist.barrier()
assert res == list(range(rank, data_length // world_size * world_size, world_size))

# Timeout Test
dp2 = dp.fullsync(timeout=0.01)
try:
for _ in range(2):
_ = list(dp2)
except Exception as e:
assert isinstance(e, PrefetchTimeoutError)

@parametrize(
"backend",
["gloo", "nccl"]
if torch.cuda.nccl.is_available([])
else [
"gloo",
],
)
def test_fullsync(self, backend) -> None:
world_size = DEFAULT_WORLD_SIZE if backend == "gloo" else torch.cuda.device_count()
launch_distributed_training(backend, world_size, DistributedTest._test_fullsync)


instantiate_parametrized_tests(DistributedTest)


if __name__ == "__main__":
unittest.main()
8 changes: 8 additions & 0 deletions torchdata/_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Use the same timeout as PyTorch Distributed
default_timeout_in_s = 30 * 60
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a comment (no change) - we should put other things such as default buffer size here too.

2 changes: 2 additions & 0 deletions torchdata/datapipes/iter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
CSVParserIterDataPipe as CSVParser,
LineReaderIterDataPipe as LineReader,
)
from torchdata.datapipes.iter.util.prefetch import FullSyncIterDataPipe as FullSync
from torchdata.datapipes.iter.util.rararchiveloader import RarArchiveLoaderIterDataPipe as RarArchiveLoader
from torchdata.datapipes.iter.util.rows2columnar import Rows2ColumnarIterDataPipe as Rows2Columnar
from torchdata.datapipes.iter.util.samplemultiplexer import SampleMultiplexerDataPipe as SampleMultiplexer
Expand Down Expand Up @@ -154,6 +155,7 @@
"Filter",
"FlatMapper",
"Forker",
"FullSync",
"GDriveReader",
"Grouper",
"HashChecker",
Expand Down
1 change: 1 addition & 0 deletions torchdata/datapipes/iter/__init__.pyi.in
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ ${init_base}
# classes/objects here, even though we are not injecting extra code into them at the moment.

from .util.decompressor import CompressionType
from torchdata._constants import default_timeout_in_s
ejguan marked this conversation as resolved.
Show resolved Hide resolved
from torchdata.datapipes.map import MapDataPipe
from torch.utils.data import DataChunk, IterableDataset, default_collate
from torch.utils.data.datapipes._typing import _DataPipeMeta
Expand Down
209 changes: 209 additions & 0 deletions torchdata/datapipes/iter/util/prefetch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import threading

from collections import deque
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
from dataclasses import dataclass
from functools import partial
from typing import Callable, Deque, Iterator, Optional, TypeVar

import torch
import torch.distributed as dist

from torchdata._constants import default_timeout_in_s
from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe

T_co = TypeVar("T_co", covariant=True)


__all__ = ["Expected", "FullSyncIterDataPipe", "PrefetchTimeoutError"]


class PrefetchTimeoutError(RuntimeError):
def __init__(self, timeout: int) -> None:
super().__init__(f"Fail to fetch data within {timeout} seconds")


class _EndOfPrefetch:
...


@dataclass
class Expected:
r"""
Expected data provided to callback function in ``_PrefetchExecutor``.
"""
index: int
error: Optional[BaseException] = None

def has_error(self) -> bool:
return self.error is not None


class _PrefetchExecutor:
Copy link
Contributor Author

@ejguan ejguan Aug 3, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This executor takes reference from the internal implementation: https://fburl.com/code/7dk6mvs4
On top of the implementation, I added prefetch_size and attached index to Expected object to make sure it can work with Prefetch in the future.

def __init__(
self,
datapipe_iterator: Iterator,
prefetch_size: int = 1,
callback_fn: Optional[Callable[[Expected], None]] = None,
timeout: int = default_timeout_in_s,
) -> None:
self.datapipe_iterator = datapipe_iterator
self.prefetch_size = prefetch_size
self.callback_fn = callback_fn
self.timeout = timeout
# Use max_workers as 1 to guarantee the order of data fetched from iterator
self._executor = ThreadPoolExecutor(max_workers=1)
self._futures: Deque[Future] = deque()
self._lock = threading.RLock()
self._end_flag = False
self._idx = 0
for _ in range(prefetch_size):
with self._lock:
if self._end_flag:
break
fetch_future: Future = self._executor.submit(self.fetch_next)
fetch_future.add_done_callback(partial(self._done_callback_fn, self._idx))
self._futures.append(fetch_future)
with self._lock:
self._idx += 1

def fetch_next(self):
return next(self.datapipe_iterator)

def _done_callback_fn(self, index: int, f: Future):
if f.exception():
with self._lock:
self._end_flag = True
ejguan marked this conversation as resolved.
Show resolved Hide resolved
if self.callback_fn is not None:
self._executor.submit(self.callback_fn, Expected(index, f.exception()))

def return_next(self):
if self._futures:
fetch_future = self._futures.popleft()
try:
data = fetch_future.result(timeout=self.timeout)
except TimeoutError:
raise PrefetchTimeoutError(self.timeout)
with self._lock:
if not self._end_flag:
next_future = self._executor.submit(self.fetch_next)
next_future.add_done_callback(partial(self._done_callback_fn, self._idx))
self._futures.append(next_future)
self._idx += 1
else:
data = _EndOfPrefetch()
return data

def shutdown(self):
self._executor.shutdown(wait=True)


@functional_datapipe("fullsync")
class FullSyncIterDataPipe(IterDataPipe[T_co]):
r"""
Synchronizes data across distributed processes to prevent hanging during training,
which is caused by uneven sharded data (functional name: ``fullsync``). It should
be appended at the end of the graph of ``DataPipe`` by ``DistributedReadingService``
automatically.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: do we recommend against usage of this DataPipe outside of a ReadingService? If not, can we potentially include an example?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense. Will add it even though we should always recommend users relying on RS

Args:
datapipe: IterDataPipe that needs to be synchronized
timeout: Timeout for prefetching data in seconds. Default value equals to 30 minutes
"""

def __init__(self, datapipe: IterDataPipe, timeout=default_timeout_in_s):
self.datapipe = datapipe
self.timeout = timeout

self._process_group = None
self._world_size = 1

self._lock = threading.RLock()
self._cv = threading.Condition(lock=self._lock)
NivekT marked this conversation as resolved.
Show resolved Hide resolved
self._executor: Optional[_PrefetchExecutor] = None
# Use single values rather than deques for the following variables
# because fullsync only prefetches 1 element
self._error = None
self._sync_counter = torch.tensor([0], dtype=torch.int32)
self._done_callback = False

def _callback_fn(self, exp: Expected) -> None:
with self._cv:
if exp.has_error():
if not isinstance(exp.error, StopIteration):
self._error = exp.error # type: ignore[assignment]
self._sync_counter = torch.tensor([0], dtype=torch.int32)
else:
self._sync_counter = torch.tensor([1], dtype=torch.int32)
dist.all_reduce(
tensor=self._sync_counter,
ejguan marked this conversation as resolved.
Show resolved Hide resolved
op=dist.ReduceOp.SUM,
group=self._process_group,
)
self._done_callback = True
self._cv.notify()

def __iter__(self) -> Iterator[T_co]:
assert self._executor is None

if not (dist.is_available() and dist.is_initialized()):
raise RuntimeError("Torch Distributed is required to be initialized")
self._process_group = dist.new_group(backend="gloo")
self._world_size = dist.get_world_size()

self._executor = _PrefetchExecutor(iter(self.datapipe), 1, self._callback_fn, self.timeout)
while True:
with self._cv:
is_success = self._cv.wait_for(
lambda: self._done_callback is True,
self.timeout,
)
if not is_success:
raise PrefetchTimeoutError(self.timeout)
if self._error is not None:
raise self._error
if bool(self._sync_counter < self._world_size):
break
self._done_callback = False
data = self._executor.return_next() # type: ignore[attr-defined]
if isinstance(data, _EndOfPrefetch):
break
yield data

def reset(self):
if self._executor is not None:
self._executor.shutdown()
self._executor = None
self._process_group = None
self._world_size = 1
with self._cv:
self._error = None
self._sync_counter = torch.tensor([0], dtype=torch.int32)
self._done_callback = False

def __getstate__(self):
if IterDataPipe.getstate_hook is not None:
return IterDataPipe.getstate_hook(self)
state = (
self.datapipe,
self.timeout,
)
return state
Comment on lines +205 to +212
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMHO, checkpoint for fullsync or prefetch is a little tricky.
Let's confirm the expected behavior. When we do checkpoint, we should pause any further prefetching and save all prefetched data into a buffer. Then, we serialize the buffer ant inner datapipe (because we have to serialize datapipe after prefetching is done). And, only when we start iteration again, would we start prefetching again.

WDYT: @VitalyFedyunin @NivekT

Then, the whole logic of fullsync should be changed. This is even more complicated when the data ends when put the prefetched data into the buffer. I might open a new PR to achieve serialization.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea I think we should stop the prefetch and capture current data. I feel this can be similar to internal client snapshot, so https://fburl.com/code/6hrjawgh may be helpful for reference


def __setstate__(self, state):
self.datapipe, self.timeout = state
self._process_group = None
self._world_size = 1
self._lock = threading.RLock()
self._cv = threading.Condition(lock=self._lock)
self._executor = None
self._error = None
self._sync_counter = torch.tensor([0], dtype=torch.int32)
self._done_callback = False