-
Notifications
You must be signed in to change notification settings - Fork 151
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Andrewkh/polylithic initial commit (#1335)
Initial commit of torchdata.nodes (polylithic) Includes multi-process map, multi-thread map, batching, prefetching, pin-memory, and in-order / out-of-order support.
- Loading branch information
Showing
21 changed files
with
1,080 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
name: Run Nodes Tests | ||
on: | ||
push: | ||
branches: | ||
- main | ||
- release/* | ||
tags: | ||
pull_request: | ||
types: [opened, synchronize, reopened, labeled] | ||
branches: | ||
- main | ||
# For PR created by ghstack | ||
- gh/*/*/base | ||
- release/* | ||
|
||
jobs: | ||
test: | ||
if: ${{ github.repository_owner == 'pytorch' }} | ||
runs-on: ${{ matrix.os }} | ||
strategy: | ||
fail-fast: false | ||
matrix: | ||
os: | ||
- macos-latest | ||
- ubuntu-latest | ||
- windows-latest | ||
python-version: | ||
- 3.9 | ||
- "3.10" | ||
- "3.11" | ||
- "3.12" | ||
steps: | ||
- name: Get PyTorch Channel | ||
shell: bash | ||
run: | | ||
if [[ "${{ github.base_ref }}" == release/* ]] || [[ "${{ github.ref }}" == refs/heads/release/* ]] || [[ "${{ github.ref }}" == refs/tags/v* ]]; then | ||
PT_CHANNEL="https://download.pytorch.org/whl/test/cpu" | ||
else | ||
PT_CHANNEL="https://download.pytorch.org/whl/nightly/cpu" | ||
fi | ||
echo "value=$PT_CHANNEL" >> $GITHUB_OUTPUT | ||
id: pytorch_channel | ||
- name: Setup additional system libraries | ||
if: startsWith( matrix.os, 'ubuntu' ) | ||
run: | | ||
sudo add-apt-repository multiverse | ||
sudo apt update | ||
sudo apt install rar unrar libssl-dev libcurl4-openssl-dev zlib1g-dev | ||
- name: Setup Python ${{ matrix.python-version }} | ||
uses: actions/setup-python@v5 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
- name: Setup msbuild on Windows | ||
if: matrix.os == 'windows-latest' | ||
uses: microsoft/[email protected] | ||
- name: Set up Visual Studio shell | ||
if: matrix.os == 'windows-latest' | ||
uses: egor-tensin/vs-shell@v2 | ||
with: | ||
arch: x64 | ||
- name: Check out source repository | ||
uses: actions/checkout@v4 | ||
with: | ||
submodules: recursive | ||
- name: Install dependencies | ||
run: | | ||
pip3 install -r requirements.txt | ||
pip3 install networkx | ||
pip3 install --pre torch --index-url "${{ steps.pytorch_channel.outputs.value }}" | ||
pip3 install cmake ninja | ||
echo "/home/runner/.local/bin" >> $GITHUB_PATH | ||
- name: Build TorchData | ||
run: | | ||
pip3 install . | ||
env: | ||
BUILD_S3: 0 | ||
- name: Install test requirements | ||
run: pip3 install -r test/requirements.txt | ||
- name: Run Node tests with pytest - dataloader | ||
if: ${{ ! contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }} | ||
run: pytest --durations=0 --no-header -v test/nodes/ |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import testslide | ||
import torch | ||
from torchdata.nodes.batch import Batcher | ||
|
||
from .utils import MockSource | ||
|
||
|
||
class TestBatcher(testslide.TestCase): | ||
def test_batcher(self) -> None: | ||
batch_size = 6 | ||
src = MockSource(num_samples=20) | ||
node = Batcher(src, batch_size=batch_size, drop_last=True) | ||
|
||
results = list(node) | ||
self.assertEqual(len(results), 3) | ||
for i in range(3): | ||
for j in range(batch_size): | ||
self.assertEqual(results[i][j]["step"], i * batch_size + j) | ||
self.assertEqual(results[i][j]["test_tensor"], torch.tensor([i * batch_size + j])) | ||
self.assertEqual(results[i][j]["test_str"], f"str_{i * batch_size + j}") | ||
|
||
def test_batcher_drop_last_false(self) -> None: | ||
batch_size = 6 | ||
src = MockSource(num_samples=20) | ||
root = Batcher(src, batch_size=batch_size, drop_last=False) | ||
|
||
results = list(root) | ||
self.assertEqual(len(results), 4) | ||
for i in range(4): | ||
n = batch_size if i < 3 else 2 | ||
for j in range(n): | ||
self.assertEqual(results[i][j]["step"], i * batch_size + j) | ||
self.assertEqual(results[i][j]["test_tensor"], torch.tensor([i * batch_size + j])) | ||
self.assertEqual(results[i][j]["test_str"], f"str_{i * batch_size + j}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import unittest | ||
from typing import List | ||
|
||
import testslide | ||
from torch.testing._internal.common_utils import IS_WINDOWS, TEST_CUDA | ||
from torchdata.nodes.batch import Batcher | ||
from torchdata.nodes.map import Mapper, ParallelMapper | ||
from torchdata.nodes.pin_memory import PinMemory | ||
from torchdata.nodes.prefetch import Prefetcher | ||
|
||
from .utils import MockSource, RandomSleepUdf, udf_raises | ||
|
||
|
||
class TestMap(testslide.TestCase): | ||
def _test_exception_handling_mapper(self, pin_memory, method): | ||
batch_size = 6 | ||
multiprocessing_context = None if IS_WINDOWS else "forkserver" | ||
src = MockSource(num_samples=20) | ||
node = Batcher(src, batch_size=batch_size) | ||
node = ParallelMapper( | ||
node, | ||
udf_raises, | ||
num_workers=2, | ||
method=method, | ||
multiprocessing_context=multiprocessing_context, | ||
) | ||
node = Mapper(node, udf_raises) | ||
if pin_memory: | ||
node = PinMemory(node) | ||
node = Prefetcher(node, prefetch_factor=2) | ||
|
||
with self.assertRaisesRegex(ValueError, "test exception"): | ||
print(list(node)) | ||
|
||
def test_exception_handling_mapper(self): | ||
self._test_exception_handling_mapper(False, "thread") | ||
|
||
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable") | ||
def test_exception_handling_mapper_cuda(self): | ||
self._test_exception_handling_mapper(True, "thread") | ||
|
||
def test_exception_handling_mapper_multiprocess(self): | ||
self._test_exception_handling_mapper(False, "process") | ||
|
||
@unittest.skipIf(not TEST_CUDA, "CUDA not found") | ||
def test_exception_handling_mapper_multiprocess_cuda(self): | ||
self._test_exception_handling_mapper(True, "process") | ||
|
||
def _test_map(self, in_order, method) -> None: | ||
batch_size = 6 | ||
n = 80 | ||
multiprocessing_context = None if IS_WINDOWS else "forkserver" | ||
src = MockSource(num_samples=n) | ||
node = Batcher(src, batch_size=batch_size, drop_last=False) | ||
node = ParallelMapper( | ||
node, | ||
RandomSleepUdf(), | ||
num_workers=4, | ||
in_order=in_order, | ||
method=method, | ||
multiprocessing_context=multiprocessing_context, | ||
) | ||
node = Prefetcher(node, prefetch_factor=2) | ||
|
||
results: List[List[dict]] = [[], []] | ||
for epoch in range(2): | ||
for batch in node: | ||
print(f"{epoch=}, {batch=}") | ||
results[epoch].extend(batch) | ||
|
||
for result in results: | ||
self.assertEqual(len(result), n, epoch) | ||
if in_order: | ||
for i, row in enumerate(result): | ||
self.assertEqual(row["step"], i, epoch) | ||
self.assertEqual(row["test_tensor"].item(), i, epoch) | ||
self.assertEqual(row["test_str"], f"str_{i}", epoch) | ||
else: | ||
self.assertEqual({row["step"] for row in result}, set(range(n))), epoch | ||
self.assertEqual( | ||
{row["test_tensor"].item() for row in result}, | ||
set(range(n)), | ||
epoch, | ||
) | ||
self.assertEqual( | ||
{row["test_str"] for row in result}, | ||
{f"str_{i}" for i in range(n)}, | ||
epoch, | ||
) | ||
|
||
def test_in_order_threads(self): | ||
self._test_map(True, "thread") | ||
|
||
def test_out_of_order_threads(self): | ||
self._test_map(False, "thread") | ||
|
||
def test_in_order_process(self): | ||
self._test_map(True, "process") | ||
|
||
def test_out_of_order_process(self): | ||
self._test_map(False, "process") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import unittest | ||
|
||
import testslide | ||
import torch | ||
|
||
from torch.testing._internal.common_utils import TEST_CUDA | ||
|
||
from torchdata.nodes.batch import Batcher | ||
from torchdata.nodes.map import Mapper | ||
from torchdata.nodes.pin_memory import PinMemory | ||
from torchdata.nodes.prefetch import Prefetcher | ||
|
||
from .utils import Collate, IterInitError, MockSource | ||
|
||
|
||
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable") | ||
class TestPinMemory(testslide.TestCase): | ||
def test_pin_memory(self) -> None: | ||
batch_size = 6 | ||
src = MockSource(num_samples=20) | ||
node = Batcher(src, batch_size=batch_size) | ||
node = Mapper(node, Collate()) | ||
node = PinMemory(node) | ||
root = Prefetcher(node, prefetch_factor=2) | ||
|
||
# 2 epochs | ||
for epoch in range(2): | ||
results = list(root) | ||
self.assertEqual(len(results), 3, epoch) | ||
for i in range(3): | ||
for j in range(batch_size): | ||
self.assertEqual(results[i]["step"][j], i * batch_size + j) | ||
self.assertEqual(results[i]["test_tensor"][j], torch.tensor([i * batch_size + j])) | ||
self.assertEqual(results[i]["test_str"][j], f"str_{i * batch_size + j}") | ||
|
||
def test_exception_handling(self): | ||
class PinMemoryFails: | ||
def pin_memory(self): | ||
raise ValueError("test exception") | ||
|
||
batch_size = 6 | ||
src = MockSource(num_samples=20) | ||
node = Mapper(src, lambda x: dict(fail=PinMemoryFails(), **x)) | ||
node = Batcher(node, batch_size=batch_size) | ||
node = Mapper(node, Collate()) | ||
node = PinMemory(node) | ||
root = Prefetcher(node, prefetch_factor=2) | ||
|
||
with self.assertRaisesRegex(ValueError, "test exception"): | ||
list(root) | ||
|
||
def test_iter_init_error(self): | ||
node = IterInitError() | ||
node = PinMemory(node) | ||
root = Prefetcher(node, prefetch_factor=2) | ||
|
||
with self.assertRaisesRegex(ValueError, "Iter Init Error"): | ||
list(root) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import testslide | ||
import torch | ||
from torchdata.nodes.batch import Batcher | ||
from torchdata.nodes.prefetch import Prefetcher | ||
|
||
from .utils import IterInitError, MockSource | ||
|
||
|
||
class TestPrefetcher(testslide.TestCase): | ||
def test_prefetcher(self) -> None: | ||
batch_size = 6 | ||
src = MockSource(num_samples=20) | ||
node = Batcher(src, batch_size=batch_size, drop_last=True) | ||
root = Prefetcher(node, prefetch_factor=2) | ||
|
||
# Test multi epoch shutdown and restart | ||
for _ in range(2): | ||
results = list(root) | ||
self.assertEqual(len(results), 3) | ||
for i in range(3): | ||
for j in range(batch_size): | ||
self.assertEqual(results[i][j]["step"], i * batch_size + j) | ||
self.assertEqual(results[i][j]["test_tensor"], torch.tensor([i * batch_size + j])) | ||
self.assertEqual(results[i][j]["test_str"], f"str_{i * batch_size + j}") | ||
|
||
def test_iter_init_error(self): | ||
node = IterInitError() | ||
root = Prefetcher(node, prefetch_factor=2) | ||
|
||
with self.assertRaisesRegex(ValueError, "Iter Init Error"): | ||
list(root) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import random | ||
import time | ||
from typing import Iterator | ||
|
||
import torch | ||
from torchdata.nodes import BaseNode | ||
|
||
|
||
class MockSource(BaseNode[dict]): | ||
def __init__(self, num_samples: int) -> None: | ||
self.num_samples = num_samples | ||
|
||
def iterator(self) -> Iterator[dict]: | ||
for i in range(self.num_samples): | ||
yield {"step": i, "test_tensor": torch.tensor([i]), "test_str": f"str_{i}"} | ||
|
||
|
||
def udf_raises(item): | ||
raise ValueError("test exception") | ||
|
||
|
||
class RandomSleepUdf: | ||
def __init__(self, sleep_max_sec: float = 0.01) -> None: | ||
self.sleep_max_sec = sleep_max_sec | ||
|
||
def __call__(self, x): | ||
time.sleep(random.random() * self.sleep_max_sec) | ||
return x | ||
|
||
|
||
class Collate: | ||
def __call__(self, x): | ||
result = {} | ||
for k in x[0].keys(): | ||
result[k] = [i[k] for i in x] | ||
return result | ||
|
||
|
||
class IterInitError(BaseNode[int]): | ||
def __init__(self, msg: str = "Iter Init Error") -> None: | ||
self.msg = msg | ||
|
||
def iterator(self) -> Iterator[int]: | ||
raise ValueError(self.msg) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
pytest | ||
testslide | ||
expecttest | ||
fsspec | ||
numpy<2 | ||
|
Oops, something went wrong.