From a087018c6e8144bb49ac865c2e5dd182730b8c6d Mon Sep 17 00:00:00 2001 From: Andrew Ho Date: Wed, 23 Oct 2024 19:25:17 -0400 Subject: [PATCH] Andrewkh/polylithic initial commit (#1335) Initial commit of torchdata.nodes (polylithic) Includes multi-process map, multi-thread map, batching, prefetching, pin-memory, and in-order / out-of-order support. --- .github/workflows/nodes_ci.yml | 81 +++++ test/nodes/__init__.py | 0 test/nodes/test_batch.py | 40 +++ test/nodes/test_map.py | 107 ++++++ test/nodes/test_pin_memory.py | 64 ++++ test/nodes/test_prefetch.py | 37 ++ test/nodes/utils.py | 50 +++ test/requirements.txt | 1 + test/stateful_dataloader/test_hugging_face.py | 6 + test/stateful_dataloader/test_sampler.py | 4 +- test/stateful_dataloader/test_state_dict.py | 1 - torchdata/nodes/__init__.py | 22 ++ torchdata/nodes/_apply_udf.py | 53 +++ torchdata/nodes/_populate_queue.py | 83 +++++ torchdata/nodes/base_node.py | 38 ++ torchdata/nodes/batch.py | 27 ++ torchdata/nodes/constants.py | 7 + torchdata/nodes/exception_wrapper.py | 11 + torchdata/nodes/map.py | 324 ++++++++++++++++++ torchdata/nodes/pin_memory.py | 102 ++++++ torchdata/nodes/prefetch.py | 26 ++ 21 files changed, 1080 insertions(+), 4 deletions(-) create mode 100644 .github/workflows/nodes_ci.yml create mode 100644 test/nodes/__init__.py create mode 100644 test/nodes/test_batch.py create mode 100644 test/nodes/test_map.py create mode 100644 test/nodes/test_pin_memory.py create mode 100644 test/nodes/test_prefetch.py create mode 100644 test/nodes/utils.py create mode 100644 torchdata/nodes/__init__.py create mode 100644 torchdata/nodes/_apply_udf.py create mode 100644 torchdata/nodes/_populate_queue.py create mode 100644 torchdata/nodes/base_node.py create mode 100644 torchdata/nodes/batch.py create mode 100644 torchdata/nodes/constants.py create mode 100644 torchdata/nodes/exception_wrapper.py create mode 100644 torchdata/nodes/map.py create mode 100644 torchdata/nodes/pin_memory.py create mode 100644 torchdata/nodes/prefetch.py diff --git a/.github/workflows/nodes_ci.yml b/.github/workflows/nodes_ci.yml new file mode 100644 index 000000000..4a37aad5b --- /dev/null +++ b/.github/workflows/nodes_ci.yml @@ -0,0 +1,81 @@ +name: Run Nodes Tests +on: + push: + branches: + - main + - release/* + tags: + pull_request: + types: [opened, synchronize, reopened, labeled] + branches: + - main + # For PR created by ghstack + - gh/*/*/base + - release/* + +jobs: + test: + if: ${{ github.repository_owner == 'pytorch' }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: + - macos-latest + - ubuntu-latest + - windows-latest + python-version: + - 3.9 + - "3.10" + - "3.11" + - "3.12" + steps: + - name: Get PyTorch Channel + shell: bash + run: | + if [[ "${{ github.base_ref }}" == release/* ]] || [[ "${{ github.ref }}" == refs/heads/release/* ]] || [[ "${{ github.ref }}" == refs/tags/v* ]]; then + PT_CHANNEL="https://download.pytorch.org/whl/test/cpu" + else + PT_CHANNEL="https://download.pytorch.org/whl/nightly/cpu" + fi + echo "value=$PT_CHANNEL" >> $GITHUB_OUTPUT + id: pytorch_channel + - name: Setup additional system libraries + if: startsWith( matrix.os, 'ubuntu' ) + run: | + sudo add-apt-repository multiverse + sudo apt update + sudo apt install rar unrar libssl-dev libcurl4-openssl-dev zlib1g-dev + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Setup msbuild on Windows + if: matrix.os == 'windows-latest' + uses: 
microsoft/setup-msbuild@v1.1 + - name: Set up Visual Studio shell + if: matrix.os == 'windows-latest' + uses: egor-tensin/vs-shell@v2 + with: + arch: x64 + - name: Check out source repository + uses: actions/checkout@v4 + with: + submodules: recursive + - name: Install dependencies + run: | + pip3 install -r requirements.txt + pip3 install networkx + pip3 install --pre torch --index-url "${{ steps.pytorch_channel.outputs.value }}" + pip3 install cmake ninja + echo "/home/runner/.local/bin" >> $GITHUB_PATH + - name: Build TorchData + run: | + pip3 install . + env: + BUILD_S3: 0 + - name: Install test requirements + run: pip3 install -r test/requirements.txt + - name: Run Node tests with pytest - dataloader + if: ${{ ! contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }} + run: pytest --durations=0 --no-header -v test/nodes/ diff --git a/test/nodes/__init__.py b/test/nodes/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/nodes/test_batch.py b/test/nodes/test_batch.py new file mode 100644 index 000000000..1d6c0ade8 --- /dev/null +++ b/test/nodes/test_batch.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import testslide +import torch +from torchdata.nodes.batch import Batcher + +from .utils import MockSource + + +class TestBatcher(testslide.TestCase): + def test_batcher(self) -> None: + batch_size = 6 + src = MockSource(num_samples=20) + node = Batcher(src, batch_size=batch_size, drop_last=True) + + results = list(node) + self.assertEqual(len(results), 3) + for i in range(3): + for j in range(batch_size): + self.assertEqual(results[i][j]["step"], i * batch_size + j) + self.assertEqual(results[i][j]["test_tensor"], torch.tensor([i * batch_size + j])) + self.assertEqual(results[i][j]["test_str"], f"str_{i * batch_size + j}") + + def test_batcher_drop_last_false(self) -> None: + batch_size = 6 + src = MockSource(num_samples=20) + root = Batcher(src, batch_size=batch_size, drop_last=False) + + results = list(root) + self.assertEqual(len(results), 4) + for i in range(4): + n = batch_size if i < 3 else 2 + for j in range(n): + self.assertEqual(results[i][j]["step"], i * batch_size + j) + self.assertEqual(results[i][j]["test_tensor"], torch.tensor([i * batch_size + j])) + self.assertEqual(results[i][j]["test_str"], f"str_{i * batch_size + j}") diff --git a/test/nodes/test_map.py b/test/nodes/test_map.py new file mode 100644 index 000000000..551876e5c --- /dev/null +++ b/test/nodes/test_map.py @@ -0,0 +1,107 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest +from typing import List + +import testslide +from torch.testing._internal.common_utils import IS_WINDOWS, TEST_CUDA +from torchdata.nodes.batch import Batcher +from torchdata.nodes.map import Mapper, ParallelMapper +from torchdata.nodes.pin_memory import PinMemory +from torchdata.nodes.prefetch import Prefetcher + +from .utils import MockSource, RandomSleepUdf, udf_raises + + +class TestMap(testslide.TestCase): + def _test_exception_handling_mapper(self, pin_memory, method): + batch_size = 6 + multiprocessing_context = None if IS_WINDOWS else "forkserver" + src = MockSource(num_samples=20) + node = Batcher(src, batch_size=batch_size) + node = ParallelMapper( + node, + udf_raises, + num_workers=2, + method=method, + multiprocessing_context=multiprocessing_context, + ) + node = Mapper(node, udf_raises) + if pin_memory: + node = PinMemory(node) + node = Prefetcher(node, prefetch_factor=2) + + with self.assertRaisesRegex(ValueError, "test exception"): + print(list(node)) + + def test_exception_handling_mapper(self): + self._test_exception_handling_mapper(False, "thread") + + @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") + def test_exception_handling_mapper_cuda(self): + self._test_exception_handling_mapper(True, "thread") + + def test_exception_handling_mapper_multiprocess(self): + self._test_exception_handling_mapper(False, "process") + + @unittest.skipIf(not TEST_CUDA, "CUDA not found") + def test_exception_handling_mapper_multiprocess_cuda(self): + self._test_exception_handling_mapper(True, "process") + + def _test_map(self, in_order, method) -> None: + batch_size = 6 + n = 80 + multiprocessing_context = None if IS_WINDOWS else "forkserver" + src = MockSource(num_samples=n) + node = Batcher(src, batch_size=batch_size, drop_last=False) + node = ParallelMapper( + node, + RandomSleepUdf(), + num_workers=4, + in_order=in_order, + method=method, + multiprocessing_context=multiprocessing_context, + ) + node = Prefetcher(node, prefetch_factor=2) + + results: List[List[dict]] = [[], []] + for epoch in range(2): + for batch in node: + print(f"{epoch=}, {batch=}") + results[epoch].extend(batch) + + for result in results: + self.assertEqual(len(result), n, epoch) + if in_order: + for i, row in enumerate(result): + self.assertEqual(row["step"], i, epoch) + self.assertEqual(row["test_tensor"].item(), i, epoch) + self.assertEqual(row["test_str"], f"str_{i}", epoch) + else: + self.assertEqual({row["step"] for row in result}, set(range(n))), epoch + self.assertEqual( + {row["test_tensor"].item() for row in result}, + set(range(n)), + epoch, + ) + self.assertEqual( + {row["test_str"] for row in result}, + {f"str_{i}" for i in range(n)}, + epoch, + ) + + def test_in_order_threads(self): + self._test_map(True, "thread") + + def test_out_of_order_threads(self): + self._test_map(False, "thread") + + def test_in_order_process(self): + self._test_map(True, "process") + + def test_out_of_order_process(self): + self._test_map(False, "process") diff --git a/test/nodes/test_pin_memory.py b/test/nodes/test_pin_memory.py new file mode 100644 index 000000000..f92ae1769 --- /dev/null +++ b/test/nodes/test_pin_memory.py @@ -0,0 +1,64 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest + +import testslide +import torch + +from torch.testing._internal.common_utils import TEST_CUDA + +from torchdata.nodes.batch import Batcher +from torchdata.nodes.map import Mapper +from torchdata.nodes.pin_memory import PinMemory +from torchdata.nodes.prefetch import Prefetcher + +from .utils import Collate, IterInitError, MockSource + + +@unittest.skipIf(not TEST_CUDA, "CUDA unavailable") +class TestPinMemory(testslide.TestCase): + def test_pin_memory(self) -> None: + batch_size = 6 + src = MockSource(num_samples=20) + node = Batcher(src, batch_size=batch_size) + node = Mapper(node, Collate()) + node = PinMemory(node) + root = Prefetcher(node, prefetch_factor=2) + + # 2 epochs + for epoch in range(2): + results = list(root) + self.assertEqual(len(results), 3, epoch) + for i in range(3): + for j in range(batch_size): + self.assertEqual(results[i]["step"][j], i * batch_size + j) + self.assertEqual(results[i]["test_tensor"][j], torch.tensor([i * batch_size + j])) + self.assertEqual(results[i]["test_str"][j], f"str_{i * batch_size + j}") + + def test_exception_handling(self): + class PinMemoryFails: + def pin_memory(self): + raise ValueError("test exception") + + batch_size = 6 + src = MockSource(num_samples=20) + node = Mapper(src, lambda x: dict(fail=PinMemoryFails(), **x)) + node = Batcher(node, batch_size=batch_size) + node = Mapper(node, Collate()) + node = PinMemory(node) + root = Prefetcher(node, prefetch_factor=2) + + with self.assertRaisesRegex(ValueError, "test exception"): + list(root) + + def test_iter_init_error(self): + node = IterInitError() + node = PinMemory(node) + root = Prefetcher(node, prefetch_factor=2) + + with self.assertRaisesRegex(ValueError, "Iter Init Error"): + list(root) diff --git a/test/nodes/test_prefetch.py b/test/nodes/test_prefetch.py new file mode 100644 index 000000000..82d7d9545 --- /dev/null +++ b/test/nodes/test_prefetch.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import testslide +import torch +from torchdata.nodes.batch import Batcher +from torchdata.nodes.prefetch import Prefetcher + +from .utils import IterInitError, MockSource + + +class TestPrefetcher(testslide.TestCase): + def test_prefetcher(self) -> None: + batch_size = 6 + src = MockSource(num_samples=20) + node = Batcher(src, batch_size=batch_size, drop_last=True) + root = Prefetcher(node, prefetch_factor=2) + + # Test multi epoch shutdown and restart + for _ in range(2): + results = list(root) + self.assertEqual(len(results), 3) + for i in range(3): + for j in range(batch_size): + self.assertEqual(results[i][j]["step"], i * batch_size + j) + self.assertEqual(results[i][j]["test_tensor"], torch.tensor([i * batch_size + j])) + self.assertEqual(results[i][j]["test_str"], f"str_{i * batch_size + j}") + + def test_iter_init_error(self): + node = IterInitError() + root = Prefetcher(node, prefetch_factor=2) + + with self.assertRaisesRegex(ValueError, "Iter Init Error"): + list(root) diff --git a/test/nodes/utils.py b/test/nodes/utils.py new file mode 100644 index 000000000..3a44e8772 --- /dev/null +++ b/test/nodes/utils.py @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import random +import time +from typing import Iterator + +import torch +from torchdata.nodes import BaseNode + + +class MockSource(BaseNode[dict]): + def __init__(self, num_samples: int) -> None: + self.num_samples = num_samples + + def iterator(self) -> Iterator[dict]: + for i in range(self.num_samples): + yield {"step": i, "test_tensor": torch.tensor([i]), "test_str": f"str_{i}"} + + +def udf_raises(item): + raise ValueError("test exception") + + +class RandomSleepUdf: + def __init__(self, sleep_max_sec: float = 0.01) -> None: + self.sleep_max_sec = sleep_max_sec + + def __call__(self, x): + time.sleep(random.random() * self.sleep_max_sec) + return x + + +class Collate: + def __call__(self, x): + result = {} + for k in x[0].keys(): + result[k] = [i[k] for i in x] + return result + + +class IterInitError(BaseNode[int]): + def __init__(self, msg: str = "Iter Init Error") -> None: + self.msg = msg + + def iterator(self) -> Iterator[int]: + raise ValueError(self.msg) diff --git a/test/requirements.txt b/test/requirements.txt index 169e812cb..904437976 100644 --- a/test/requirements.txt +++ b/test/requirements.txt @@ -1,4 +1,5 @@ pytest +testslide expecttest fsspec numpy<2 diff --git a/test/stateful_dataloader/test_hugging_face.py b/test/stateful_dataloader/test_hugging_face.py index 34980cc5f..cf793ef7b 100644 --- a/test/stateful_dataloader/test_hugging_face.py +++ b/test/stateful_dataloader/test_hugging_face.py @@ -1,3 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + import itertools from datasets.info import DatasetInfo diff --git a/test/stateful_dataloader/test_sampler.py b/test/stateful_dataloader/test_sampler.py index f1dd5a734..665cf5a36 100644 --- a/test/stateful_dataloader/test_sampler.py +++ b/test/stateful_dataloader/test_sampler.py @@ -4,10 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - import math import unittest -import warnings import torch @@ -15,7 +13,7 @@ from torch.utils.data import Dataset -from torchdata.stateful_dataloader import Stateful, StatefulDataLoader +from torchdata.stateful_dataloader import StatefulDataLoader from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler diff --git a/test/stateful_dataloader/test_state_dict.py b/test/stateful_dataloader/test_state_dict.py index 5bd1e6161..639a46310 100644 --- a/test/stateful_dataloader/test_state_dict.py +++ b/test/stateful_dataloader/test_state_dict.py @@ -9,7 +9,6 @@ import unittest from copy import deepcopy -from enum import Enum from typing import Iterator import torch diff --git a/torchdata/nodes/__init__.py b/torchdata/nodes/__init__.py new file mode 100644 index 000000000..e4b723f91 --- /dev/null +++ b/torchdata/nodes/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from .base_node import BaseNode, T +from .batch import Batcher +from .map import Mapper, ParallelMapper +from .pin_memory import PinMemory +from .prefetch import Prefetcher + + +__all__ = [ + "BaseNode", + "Batcher", + "Mapper", + "Prefetcher", + "ParallelMapper", + "PinMemory", + "T", +] diff --git a/torchdata/nodes/_apply_udf.py b/torchdata/nodes/_apply_udf.py new file mode 100644 index 000000000..019ed7b44 --- /dev/null +++ b/torchdata/nodes/_apply_udf.py @@ -0,0 +1,53 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import multiprocessing.synchronize as python_mp_synchronize +import queue +import threading +from typing import Callable, Union + +import torch +import torch.multiprocessing as mp + +from torch._utils import ExceptionWrapper + +from .constants import QUEUE_TIMEOUT + + +def _apply_udf( + worker_id: int, + in_q: Union[queue.Queue, mp.Queue], + out_q: Union[queue.Queue, mp.Queue], + udf: Callable, + stop_event: Union[threading.Event, python_mp_synchronize.Event], +): + """_apply_udf assumes in_q emits tuples of (x, idx) where x is the + payload, idx is the index of the result, potentially used for maintaining + ordered outputs. For every input it pulls, a tuple (y, idx) is put on the out_q + where the output of udf(x), an ExceptionWrapper, or StopIteration (if it pulled + StopIteration from in_q). + """ + torch.set_num_threads(1) + while True: + if stop_event.is_set() and in_q.empty(): + break + + try: + item, idx = in_q.get(block=True, timeout=QUEUE_TIMEOUT) + except queue.Empty: + continue + + if isinstance(item, ExceptionWrapper): + out_q.put((item, idx)) + elif isinstance(item, StopIteration): + out_q.put((item, idx)) + else: + try: + y = udf(item) + except Exception: + y = ExceptionWrapper(where="in _apply_udf") + + out_q.put((y, idx), block=False) diff --git a/torchdata/nodes/_populate_queue.py b/torchdata/nodes/_populate_queue.py new file mode 100644 index 000000000..c7e288c7c --- /dev/null +++ b/torchdata/nodes/_populate_queue.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import queue +import threading +from dataclasses import dataclass +from typing import Iterable + +from torchdata.nodes.exception_wrapper import ExceptionWrapper, StartupExceptionWrapper + +from .constants import QUEUE_TIMEOUT + + +@dataclass +class _MonotonicIndex: + initial: int = 0 + + def __post_init__(self): + self._idx = self.initial + + def get(self) -> int: + idx = self._idx + self._idx += 1 + return idx + + +def _populate_queue( + source: Iterable, + q: queue.Queue, + semaphore: threading.BoundedSemaphore, + stop_event: threading.Event, + add_index: bool = False, +): + """_populate_queue calls `iter(source)` to get an iterator `it`, waits for semaphore.acquire, + and puts its outputs onto q. It never releases the sempahore. It continues to put items on the + q as long as it can acquire the sempahore, stop_event is not set, and StopIteration has not + been thrown by the `it`. + + If add_index = True, this function will always put tuples of (x, idx) on the q where idx + starts from 0 and is monotonically increasing. x may be the output of next(it), StopIteration, + or an ExceptionWrapper. 
+ + If there is an exception raised during the call to `iter(source)`, this function does not + wait to acquire sempahore before putting StartupExceptionWrapper on q. + + Note: this is only intended to be used by a single thread at once. Each instance + creates its own iter for source so if this is called with multiple threads, you may get + duplicates if source is not sharded properly. + """ + + # Include a monotonic index starting from 0 to each item in the queue + idx = _MonotonicIndex() + + def _put(item, block: bool = True): + if add_index: + q.put((item, idx.get()), block=block) + else: + q.put(item, block=block) + + try: + src_iter = iter(source) + except Exception: + e = StartupExceptionWrapper(where="in _populate_queue startup for device") + _put(e) + return + + while not stop_event.is_set(): + if not semaphore.acquire(blocking=True, timeout=QUEUE_TIMEOUT): + continue + try: + item = next(src_iter) # FIXME: This may hang! + except StopIteration as e: + _put(e) + break + except Exception: + item = ExceptionWrapper(where="in _populate_queue") + try: + _put(item, block=False) # Semaphore should prevent this from throwing + except queue.Full: + raise RuntimeError("Queue should not be full") diff --git a/torchdata/nodes/base_node.py b/torchdata/nodes/base_node.py new file mode 100644 index 000000000..10c7825b3 --- /dev/null +++ b/torchdata/nodes/base_node.py @@ -0,0 +1,38 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Generic, Iterator, TypeVar + +import torch.utils.data + + +T = TypeVar("T") + + +class BaseNode(torch.utils.data.IterableDataset, Generic[T]): + def iterator(self) -> Iterator[T]: + """Override this method to implement the iterator. + Iterators are expected to raise StopIteration to signal + end of iteration, so they can be used in for loops. + Generators just need to return, as usual. + """ + raise NotImplementedError() + + def __iter__(self) -> "_EagerIter[T]": + return _EagerIter(self) + + +class _EagerIter(Iterator[T]): + """ + Basic iterator which will runs next-calls eagerly + """ + + def __init__(self, parent: BaseNode[T]): + self.parent = parent + self.it = self.parent.iterator() + + def __next__(self): + return next(self.it) diff --git a/torchdata/nodes/batch.py b/torchdata/nodes/batch.py new file mode 100644 index 000000000..482368135 --- /dev/null +++ b/torchdata/nodes/batch.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Iterator, List + +from torchdata.nodes import BaseNode, T + + +class Batcher(BaseNode[List[T]]): + def __init__(self, source: BaseNode[T], batch_size: int, drop_last: bool = True): + self.source = source + self.batch_size = batch_size + self.drop_last = drop_last + + def iterator(self) -> Iterator[List[T]]: + batch = [] + for item in self.source: + batch.append(item) + if len(batch) == self.batch_size: + yield batch + batch = [] + + if len(batch) and not self.drop_last: + yield batch diff --git a/torchdata/nodes/constants.py b/torchdata/nodes/constants.py new file mode 100644 index 000000000..f1fce5d62 --- /dev/null +++ b/torchdata/nodes/constants.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +QUEUE_TIMEOUT = 0.1 diff --git a/torchdata/nodes/exception_wrapper.py b/torchdata/nodes/exception_wrapper.py new file mode 100644 index 000000000..95d97a73a --- /dev/null +++ b/torchdata/nodes/exception_wrapper.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torch._utils import ExceptionWrapper + + +class StartupExceptionWrapper(ExceptionWrapper): + pass diff --git a/torchdata/nodes/map.py b/torchdata/nodes/map.py new file mode 100644 index 000000000..f8e12c351 --- /dev/null +++ b/torchdata/nodes/map.py @@ -0,0 +1,324 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import queue +import threading +from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Protocol, TypeVar, Union + +import torch.multiprocessing as mp +from torchdata.nodes import BaseNode, T +from torchdata.nodes.exception_wrapper import ExceptionWrapper, StartupExceptionWrapper + +from ._apply_udf import _apply_udf + +from ._populate_queue import _populate_queue + +from .constants import QUEUE_TIMEOUT + + +# We define this protocol for type checking +class _MultiprocessContext(Protocol): + def Process(self, *args, **kwargs): + ... + + def Event(self, *args, **kwargs): + ... + + def Queue(self, *args, **kwargs): + ... + + +X = TypeVar("X") + + +class Mapper(BaseNode[T]): + def __init__( + self, + source: BaseNode[X], + map_fn: Callable[[X], T], + ): + self.source = source + self.map_fn = map_fn + + def iterator(self) -> Iterator[T]: + for item in self.source: + yield self.map_fn(item) + + +def _sort_worker(in_q: Union[queue.Queue, mp.Queue], out_q: queue.Queue, stop_event: threading.Event): + buffer: Dict[int, Any] = {} + cur_idx = 0 + while not stop_event.is_set(): + try: + item, idx = in_q.get(block=True, timeout=QUEUE_TIMEOUT) + except queue.Empty: + continue + if idx == cur_idx: + out_q.put((item, cur_idx), block=False) + cur_idx += 1 + else: + if idx in buffer: + # This is the easiest way to create an exception wrapper + try: + raise ValueError(f"Duplicate index {idx=}, {buffer.keys()=}, {item=}") + except Exception: + item = ExceptionWrapper(where="in _sort_worker") + out_q.put((item, idx), block=False) + break + buffer[idx] = item + while cur_idx in buffer: + out_q.put((buffer.pop(cur_idx), cur_idx), block=False) + cur_idx += 1 + + +class _ParallelMapperIter(Iterator[T]): + """_ParallelMapperIter will start at least two threads, one running + _populate_queue, and one for _apply_udf. If in_order == True, a + third thread will be started to read from _apply_udf's result q + and block the output_q until the appropriate in_order element is available, + buffering outputs as needed. + + A BoundedSemaphore with initial value max_concurrent will limit the number + of items in flight, and in all of the queues. 
+ """ + + def __init__( + self, + source: BaseNode[X], + map_fn: Callable[[X], T], + num_workers: int, + in_order: bool, + method: Literal["thread", "process"], + mp_context: _MultiprocessContext, + max_concurrent: Optional[int], + ): + self.source = source + self.map_fn = map_fn + self.num_workers = num_workers + self.in_order = in_order + self.method = method + self.mp_context = mp_context + + self._in_q: Union[queue.Queue, mp.Queue] = queue.Queue() if method == "thread" else mp_context.Queue() + self._intermed_q: Union[queue.Queue, mp.Queue] = queue.Queue() if method == "thread" else mp_context.Queue() + self._max_tasks = 2 * self.num_workers if max_concurrent is None else max_concurrent + self._sem = threading.BoundedSemaphore(value=self._max_tasks) + + self._done = False + + self._stop = threading.Event() + self._mp_stop = mp_context.Event() + + self._read_thread = threading.Thread( + target=_populate_queue, + args=(self.source, self._in_q, self._sem, self._stop, True), + ) + self._map_threads: List[Union[threading.Thread, mp.Process]] = [] + for worker_id in range(self.num_workers): + args = ( + worker_id, + self._in_q, + self._intermed_q, + self.map_fn, + self._stop if self.method == "thread" else self._mp_stop, + ) + self._map_threads.append( + threading.Thread(target=_apply_udf, args=args) + if self.method == "thread" + else mp_context.Process(target=_apply_udf, args=args) + ) + self._sort_q: queue.Queue = queue.Queue() + self._sort_thread = threading.Thread(target=_sort_worker, args=(self._intermed_q, self._sort_q, self._stop)) + + self._out_q = self._intermed_q + if self.in_order: + self._out_q = self._sort_q + + self._read_thread.start() + for t in self._map_threads: + t.start() + if self.in_order: + self._sort_thread.start() + + def __iter__(self): + return self + + def __next__(self): + while True: + if self._stop.is_set(): + raise StopIteration() + elif self._done and self._sem._value == self._max_tasks: + # Don't stop if we still have items in the queue + self._stop.set() + self._mp_stop.set() + raise StopIteration() + try: + item, idx = self._out_q.get(block=True, timeout=QUEUE_TIMEOUT) + except queue.Empty: + continue + + if isinstance(item, StopIteration): + self._done = True + self._sem.release() + # Make sure queues are flushed before returning early + continue + elif isinstance(item, ExceptionWrapper): + if not isinstance(item, StartupExceptionWrapper): + self._sem.release() + item.reraise() + else: + self._sem.release() + return item + + def __del__(self): + self._shutdown() + + def _shutdown(self): + self._stop.set() + self._mp_stop.set() + if self._read_thread.is_alive(): + self._read_thread.join(timeout=QUEUE_TIMEOUT) + if self._sort_thread.is_alive(): + self._sort_thread.join(timeout=QUEUE_TIMEOUT) + for t in self._map_threads: + if t.is_alive(): + t.join(timeout=QUEUE_TIMEOUT) + + +class ParallelMapper(BaseNode[T]): + """ParallelMapper executes map_fn in parallel either in num_workers threads or + processes. For processes, multiprocessing_context can be spawn, forkserver, fork, + or None (chooses OS default). At most max_concurrent items will be either processed + or in the iterator's output queue, to limit CPU and Memory utilization. If None + (default) the value will be 2 * num_workers. + + At most one iter() is created from source, and at most one thread will call + next() on it at once. + + If in_order is true, the iterator will return items in the order from which they arrive + from source's iterator, potentially blocking even if other items are available. 
+ """ + + def __init__( + self, + source: BaseNode[X], + map_fn: Callable[[X], T], + num_workers: int, + in_order: bool = True, + method: Literal["thread", "process"] = "thread", + multiprocessing_context: Optional[str] = None, + max_concurrent: Optional[int] = None, + ): + assert method in ["thread", "process"] + self.source = source + self.map_fn = map_fn + self.num_workers = num_workers + self.in_order = in_order + self.method = method + self.multiprocessing_context = multiprocessing_context + self._mp_context: Any = mp + if self.method == "process" and self.multiprocessing_context is not None: + self._mp_context = mp.get_context(self.multiprocessing_context) + + if max_concurrent is not None: + if not isinstance(max_concurrent, int) and max_concurrent > num_workers: + raise ValueError(f"{max_concurrent=} should be >= {num_workers=}!") + self.max_concurrent = max_concurrent + + def iterator(self) -> Iterator[T]: + return _ParallelMapperIter( + source=self.source, + map_fn=self.map_fn, + num_workers=self.num_workers, + in_order=self.in_order, + method=self.method, + mp_context=self._mp_context, + max_concurrent=self.max_concurrent, + ) + + +_WorkerType = Callable[[BaseNode, queue.Queue, threading.BoundedSemaphore, threading.Event], None] + + +class _SingleThreadedMapper(Iterator[T]): + """Utility Iterator for performing mapping with a single thread. + Because only a single thread is used, we don't need an input queue to guard + against multiple threads reading from the same iterator. This is used for + Prefetcher and PinMemory. + + A thread is started on __init__ and stopped on __del__/_shutdown. + The thread runs _populate_queue, which acquires a BoundedSemaphore with initial value + of `prefetch_factor`. + + When next() is called on this iterator, it will block until an item is available on _q. + Next will perform the following depending on what is pulled from the q: + - StopIteration: raise StopIteration. Any subsequent next() calls will also raise StopIteration + - ExceptionWrapper: call reraise() on the exception wraper + - any other item: return the item + + A Bounded semaphore is used to limit concurrency and memory utilization. + If N items have been pulled from the source, and M items have been yielded by this iterator, + we maintain the invariant that semaphore.value + (N - M) == prefetch_factor (modulo + non-atomicness of operations). + + _populate_queue calls semaphore.acquire. When we pull an item from the queue, we + call semaphore.release (unless it's a StartupExceptionWrapper, because _populate_queue + does not acquire sempahores in this case). All outstanding items are either being + processed in _populate_queue, in the _q, or about to be returned by an in-flight next() call. 
+ """ + + def __init__(self, source: BaseNode[T], prefetch_factor: int, worker: _WorkerType): + self.source = source + self.prefetch_factor = prefetch_factor + self.worker = worker + + self._q: queue.Queue = queue.Queue() + self._sem = threading.BoundedSemaphore(value=prefetch_factor) + self._stop_event = threading.Event() + + self._thread = threading.Thread( + target=self.worker, + args=(self.source, self._q, self._sem, self._stop_event), + ) + self._thread.start() + self._stopped = False + + def __iter__(self) -> Iterator[T]: + return self + + def __next__(self): + if self._stopped: + raise StopIteration() + + while True: + try: + item = self._q.get(block=True, timeout=QUEUE_TIMEOUT) + break + except queue.Empty: + continue + + if isinstance(item, StopIteration): + self._stopped = True + self._sem.release() + self._stop_event.set() + raise item + elif isinstance(item, ExceptionWrapper): + self._stopped = True + if not isinstance(item, StartupExceptionWrapper): + # We don't need to release for startup exceptions + self._sem.release() + self._stop_event.set() + item.reraise() + else: + self._sem.release() + return item + + def __del__(self): + self._shutdown() + + def _shutdown(self): + self._stop_event.set() + self._thread.join(timeout=QUEUE_TIMEOUT) diff --git a/torchdata/nodes/pin_memory.py b/torchdata/nodes/pin_memory.py new file mode 100644 index 000000000..c7d0b57e2 --- /dev/null +++ b/torchdata/nodes/pin_memory.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import functools +import queue +import threading + +from typing import Iterator, Optional, Union + +import torch +import torch.multiprocessing + +from torch.utils.data._utils.pin_memory import pin_memory +from torchdata.nodes import BaseNode, T + +from torchdata.nodes.exception_wrapper import ExceptionWrapper, StartupExceptionWrapper +from torchdata.nodes.map import _SingleThreadedMapper + + +def _pin_memory_loop( + source: BaseNode, + q: queue.Queue, + semaphore: threading.BoundedSemaphore, + stop_event: threading.Event, + device_id: Union[int, str], + device: Optional[str], +): + # this is fork of from torch.utils.data._utils.pin_memory import _pin_memory_loop + # to remove the index tuples + + # This setting is thread local, and prevents the copy in pin_memory from + # consuming all CPU cores. 
+ try: + torch.set_num_threads(1) + + torch.multiprocessing._set_thread_name("pt_data_pin") + + if device == "cuda": + torch.cuda.set_device(device_id) + elif device == "xpu": + torch.xpu.set_device(device_id) # type: ignore[attr-defined] + elif device == torch._C._get_privateuse1_backend_name(): + custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name()) + custom_device_mod.set_device(device_id) + + src_iter = iter(source) + except Exception: + e = StartupExceptionWrapper(where=f"in _pin_memory_loop startup for device {device_id}") + q.put(e) + return + + while not stop_event.is_set(): + if not semaphore.acquire(blocking=True, timeout=0.1): + continue + try: + item = next(src_iter) + item = pin_memory(item, device) + q.put(item, block=False) + except StopIteration as e: + item = e + q.put(item, block=False) + break + except Exception: + item = ExceptionWrapper(where=f"in _pin_memory_loop for device {device_id}") + q.put(item, block=False) + break + + +class PinMemory(BaseNode[T]): + def __init__( + self, + source: BaseNode[T], + pin_memory_device: str = "", + ): + self.source = source + self._pin_memory = torch.cuda.is_available() + if len(pin_memory_device) == 0: + self._pin_memory_device = None + else: + self._pin_memory_device = pin_memory_device + + if self._pin_memory_device == "xpu": + self._current_device = torch.xpu.current_device() # type: ignore[attr-defined] + elif self._pin_memory_device == torch._C._get_privateuse1_backend_name(): + custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name()) + self._current_device = custom_device_mod.current_device() + else: + self._current_device = torch.cuda.current_device() + + def iterator(self) -> Iterator[T]: + return _SingleThreadedMapper( + source=self.source, + prefetch_factor=1, + worker=functools.partial( + _pin_memory_loop, + device_id=self._current_device, + device=self._pin_memory_device, + ), + ) diff --git a/torchdata/nodes/prefetch.py b/torchdata/nodes/prefetch.py new file mode 100644 index 000000000..a2715bbd4 --- /dev/null +++ b/torchdata/nodes/prefetch.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Iterator + +from torchdata.nodes import BaseNode, T + +from torchdata.nodes.map import _SingleThreadedMapper + +from ._populate_queue import _populate_queue + + +class Prefetcher(BaseNode[T]): + def __init__(self, source: BaseNode[T], prefetch_factor: int): + self.source = source + self.prefetch_factor = prefetch_factor + + def iterator(self) -> Iterator[T]: + return _SingleThreadedMapper( + source=self.source, + prefetch_factor=self.prefetch_factor, + worker=_populate_queue, + )
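
Not part of the patch: a minimal usage sketch showing how the nodes introduced by this commit compose into a small loading pipeline. It assumes only the public API added above in torchdata/nodes/__init__.py (BaseNode, Batcher, ParallelMapper, Prefetcher); RangeSource and add_one are hypothetical stand-ins for a real source node and user transform, not part of the commit.

    # Hedged example -- composes the nodes added in this PR; RangeSource and
    # add_one are illustrative placeholders, not shipped code.
    from typing import Iterator, List

    import torch
    from torchdata.nodes import BaseNode, Batcher, ParallelMapper, Prefetcher


    class RangeSource(BaseNode[dict]):
        """Toy source node, mirroring the MockSource used in test/nodes/utils.py."""

        def __init__(self, num_samples: int) -> None:
            self.num_samples = num_samples

        def iterator(self) -> Iterator[dict]:
            # Subclasses override iterator(); __iter__ is provided by BaseNode.
            for i in range(self.num_samples):
                yield {"step": i, "value": torch.tensor([i])}


    def add_one(batch: List[dict]) -> List[dict]:
        # User-defined function applied to each batch by the parallel workers.
        return [{**item, "value": item["value"] + 1} for item in batch]


    if __name__ == "__main__":
        node = RangeSource(num_samples=100)
        node = Batcher(node, batch_size=8, drop_last=False)   # group samples into lists
        node = ParallelMapper(                                 # run add_one in 4 threads
            node, add_one, num_workers=4, in_order=True, method="thread"
        )
        node = Prefetcher(node, prefetch_factor=2)             # overlap producer and consumer

        for epoch in range(2):                                 # re-iterating restarts the pipeline
            for batch in node:
                pass  # training step would go here

Each call to iter() on the root node builds a fresh iterator chain down to the source, which is why the prefetch and pin-memory tests above can run two epochs over the same pipeline; swapping method="thread" for "process" (with an appropriate multiprocessing_context) moves the map workers into subprocesses without changing the rest of the graph.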