Make DistributedSampler stateful #1315

Merged: 16 commits, Aug 28, 2024
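This PR makes DistributedSampler resumable by adding StatefulDistributedSampler, a subclass of torch.utils.data.distributed.DistributedSampler that counts how many indices it has yielded in the current epoch and exposes that count through state_dict()/load_state_dict(). Combined with StatefulDataLoader, a training job can checkpoint mid-epoch and resume from the same position. A minimal usage sketch (the toy dataset and the five-batch cutoff are illustrative, not part of the PR):

import torch
from torch.utils.data import TensorDataset

from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler

dataset = TensorDataset(torch.arange(100))
sampler = StatefulDistributedSampler(dataset, num_replicas=1, rank=0, shuffle=False)
loader = StatefulDataLoader(dataset, batch_size=10, sampler=sampler)

# Consume five batches, then checkpoint mid-epoch.
for i, batch in enumerate(loader):
    if i == 4:
        break
state = loader.state_dict()

# Rebuild and resume: the new loader continues with batch 5 of the same epoch.
new_sampler = StatefulDistributedSampler(dataset, num_replicas=1, rank=0, shuffle=False)
new_loader = StatefulDataLoader(dataset, batch_size=10, sampler=new_sampler)
new_loader.load_state_dict(state)
remaining = list(new_loader)  # only the batches that were not yet served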
3 changes: 3 additions & 0 deletions .github/workflows/stateful_dataloader_ci.yml
@@ -78,6 +78,9 @@ jobs:
       - name: Run StatefulDataLoader tests with pytest - dataloader
         if: ${{ ! contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }}
         run: pytest --durations=0 --no-header -v test/stateful_dataloader/test_dataloader.py
+      - name: Run StatefulDataSampler tests with pytest - datasampler
+        if: ${{ ! contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }}
+        run: pytest --durations=0 --no-header -v test/stateful_dataloader/test_sampler.py
       - name: Run StatefulDataLoader tests with pytest - state_dict 0
         if: ${{ ! contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }}
         run: pytest --durations=0 --no-header -v test/stateful_dataloader/test_state_dict.py -k _shard0
2 changes: 1 addition & 1 deletion test/stateful_dataloader/test_dataloader.py
@@ -2580,7 +2580,7 @@ def setUp(self):
     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     def test_shuffle_pin_memory(self):
         loader = DataLoader(self.dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True)
-        for (s, n) in loader:
+        for s, n in loader:
             self.assertIsInstance(s[0], str)
             self.assertTrue(n.is_pinned())

191 changes: 191 additions & 0 deletions test/stateful_dataloader/test_sampler.py
@@ -0,0 +1,191 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import math
import unittest

import torch

from torch.testing._internal.common_utils import run_tests, TEST_WITH_ASAN, TEST_WITH_TSAN, TestCase

from torch.utils.data import Dataset

from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler


class MockDataset(Dataset):
    def __init__(self, size):
        self.size = size
        self.data = torch.arange(size)  # Simple data that is easy to verify

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return self.data[idx]


@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)",
)
@unittest.skipIf(TEST_WITH_ASAN, "DataLoader tests hang in ASAN, see: https://github.com/pytorch/pytorch/issues/66223")
class TestDataLoader(TestCase):
    def setUp(self):
        super().setUp()
        self.dataset = MockDataset(100)
        self.persistent_workers = False

    def test_initialization_StatefulDistributedSampler(self):
        sampler = StatefulDistributedSampler(
            self.dataset, num_replicas=10, rank=0, shuffle=False, seed=42, drop_last=False
        )
        self.assertEqual(sampler.dataset, self.dataset)
        self.assertEqual(sampler.num_replicas, 10)
        self.assertEqual(sampler.rank, 0)
        self.assertFalse(sampler.shuffle)
        self.assertEqual(sampler.seed, 42)
        self.assertFalse(sampler.drop_last)
        self.assertEqual(sampler.yielded, 0)
        self.assertIsNone(sampler.next_yielded)

    def test_dataloader_state_dict(self):
        sampler = StatefulDistributedSampler(self.dataset, num_replicas=1, rank=0, shuffle=False)
        dataloader = StatefulDataLoader(self.dataset, batch_size=10, sampler=sampler)
        # Partial iteration over the DataLoader
        iter_count = 5
        for i, _ in enumerate(dataloader):
            if i == iter_count - 1:
                break
        state_dict = dataloader.state_dict()
        new_sampler = StatefulDistributedSampler(self.dataset, num_replicas=1, rank=0, shuffle=False)

        new_dataloader = StatefulDataLoader(self.dataset, batch_size=10, sampler=new_sampler)
        new_dataloader.load_state_dict(state_dict)
        resumed_data = []
        for data in new_dataloader:
            resumed_data.append(data.tolist())
        expected_data = []
        full_dataloader = StatefulDataLoader(self.dataset, batch_size=10, sampler=sampler)
        for data in full_dataloader:
            expected_data.append(data.tolist())

        self.assertEqual(resumed_data, expected_data[iter_count:])

    def test_sampler_state_dict(self):
        sampler = StatefulDistributedSampler(self.dataset, num_replicas=10, rank=0)
        sampler.yielded = 5
        state_dict = sampler.state_dict()
        self.assertEqual(state_dict["yielded"], 5)

    def test_sampler_load_state_dict(self):
        sampler = StatefulDistributedSampler(self.dataset, num_replicas=10, rank=0)
        sampler.load_state_dict({"yielded": 3})
        self.assertEqual(sampler.next_yielded, 3)
        with self.assertRaises(ValueError):
            sampler.load_state_dict({"yielded": -1})

    def test_sampler_next_yielded(self):
        sampler = StatefulDistributedSampler(self.dataset, num_replicas=2, rank=0, shuffle=True, seed=42)
        iterator = iter(sampler)
        next(iterator)  # advance the iterator
        self.assertEqual(sampler.yielded, 1)
        self.assertIsNone(sampler.next_yielded)
        sampler.load_state_dict({StatefulDistributedSampler._YIELDED: 5})
        self.assertEqual(sampler.next_yielded, 5)
        iterator = iter(sampler)
        next(iterator)  # advance the iterator again
        self.assertEqual(sampler.yielded, 6)

    def test_drop_last_effect(self):
        num_replicas = 3
        total_samples = len(self.dataset)
        expected_length_with_drop = total_samples // num_replicas
        expected_length_without_drop = math.ceil(total_samples / num_replicas)

        sampler_with_drop = StatefulDistributedSampler(
            self.dataset, num_replicas=3, rank=0, drop_last=True, shuffle=False
        )
        dataloader_with_drop = StatefulDataLoader(self.dataset, sampler=sampler_with_drop)

        sampler_without_drop = StatefulDistributedSampler(
            self.dataset, num_replicas=3, rank=0, drop_last=False, shuffle=False
        )
        dataloader_without_drop = StatefulDataLoader(self.dataset, sampler=sampler_without_drop)

        # Collect all indices from dataloaders
        indices_with_drop = [data for batch in dataloader_with_drop for data in batch]
        indices_without_drop = [data for batch in dataloader_without_drop for data in batch]

        # Check the lengths of the outputs
        self.assertEqual(
            len(indices_with_drop),
            expected_length_with_drop,
            "Length with drop_last=True should match the truncated per-replica length",
        )
        self.assertEqual(
            len(indices_without_drop),
            expected_length_without_drop,
            "Length with drop_last=False should match the padded per-replica length",
        )

        self.assertTrue(
            len(indices_with_drop) <= len(indices_without_drop), "Drop last should result in fewer or equal indices"
        )

    def test_data_order_with_shuffle(self):
        sampler = StatefulDistributedSampler(self.dataset, num_replicas=1, rank=0, shuffle=True)
        indices = list(iter(sampler))
        data_sampled = [self.dataset[i] for i in indices]
        self.assertNotEqual(data_sampled, list(range(100)), "Data should be shuffled")

        dataloader = StatefulDataLoader(self.dataset, sampler=sampler)
        data_loaded = []
        for batch in dataloader:
            data_loaded.extend(batch)
        self.assertEqual(len(data_loaded), len(self.dataset), "All data should be loaded")
        self.assertEqual(data_loaded, data_sampled, "Data loaded by DataLoader should match data sampled by sampler")

    def test_data_order_without_shuffle(self):
        sampler = StatefulDistributedSampler(self.dataset, num_replicas=1, rank=0, shuffle=False)
        indices = list(iter(sampler))
        data_sampled = [self.dataset[i] for i in indices]
        self.assertEqual(data_sampled, list(range(100)), "Data should not be shuffled")

        batch_size = 32
        dataloader = StatefulDataLoader(self.dataset, batch_size=batch_size, sampler=sampler)
        data_loaded = []
        for batch in dataloader:
            data_loaded.extend(batch)
        self.assertEqual(len(data_loaded), len(self.dataset), "All data should be loaded")
        self.assertEqual(data_loaded, data_sampled, "Data loaded by DataLoader should match data sampled by sampler")
        self.assertEqual(data_loaded, list(range(100)), "Data loaded by DataLoader should be in original order")

    def test_data_distribution_across_replicas(self):
        num_replicas = 5
        all_data = []
        for rank in range(num_replicas):
            sampler = StatefulDistributedSampler(self.dataset, num_replicas=num_replicas, rank=rank, shuffle=False)
            dataloader = StatefulDataLoader(self.dataset, sampler=sampler)
            data_loaded = []
            for batch in dataloader:
                data_loaded.extend([int(x.item()) for x in batch])
            all_data.extend(data_loaded)
        self.assertEqual(
            sorted(all_data), list(range(100)), "All data points should be covered exactly once across all replicas"
        )


if __name__ == "__main__":
    run_tests()
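These tests mirror the CI step added above. Assuming a development checkout of torchdata, they can be run locally with the same command CI uses:

pytest --durations=0 --no-header -v test/stateful_dataloader/test_sampler.py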
39 changes: 39 additions & 0 deletions torchdata/stateful_dataloader/sampler.py
@@ -4,9 +4,11 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

+import itertools
 from typing import Any, Dict, Iterator, Optional, Sized
 
 import torch.utils.data.sampler
+from torch.utils.data import Dataset
 from torch.utils.data.dataloader import _InfiniteConstantSampler
 
 from .stateful import Stateful
@@ -125,3 +127,40 @@ def __iter__(self):
                    batch = [0] * self.batch_size
            if idx_in_batch > 0:
                yield batch[:idx_in_batch]


class StatefulDistributedSampler(torch.utils.data.distributed.DistributedSampler):
    _YIELDED = "yielded"

    def __init__(
        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
    ) -> None:
        super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
        self.yielded = 0
        self.next_yielded = None

    def __iter__(self):
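        # Start each epoch from scratch unless load_state_dict() stashed a
        # resume point in next_yielded; in that case, fast-forward past the
        # indices that were already served before the checkpoint was taken.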
        self.yielded = 0
        if self.next_yielded is not None:
            self.yielded = self.next_yielded
            self.next_yielded = None
        it = super().__iter__()
        for idx in itertools.islice(it, self.yielded, None):
            self.yielded += 1
            yield idx

    def state_dict(self) -> Dict[str, Any]:
        return {self._YIELDED: self.yielded}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        if self._YIELDED not in state_dict:
            raise ValueError("Invalid state_dict")
        if state_dict[self._YIELDED] < 0:
            raise ValueError("Cannot load state_dict with negative yielded value")
        self.next_yielded = state_dict[self._YIELDED]
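
Resuming also works at the sampler level, independent of the DataLoader. A minimal sketch, reusing the MockDataset helper from the tests above:

dataset = MockDataset(8)
sampler = StatefulDistributedSampler(dataset, num_replicas=1, rank=0, shuffle=False)
it = iter(sampler)
next(it), next(it)  # serve the first two indices
ckpt = sampler.state_dict()  # {"yielded": 2}

restored = StatefulDistributedSampler(dataset, num_replicas=1, rank=0, shuffle=False)
restored.load_state_dict(ckpt)
print(list(iter(restored)))  # [2, 3, 4, 5, 6, 7]: the two served indices are skipped

Storing only the yielded count keeps the checkpoint tiny: the index order itself is reproduced deterministically from num_replicas, rank, seed, and epoch, so the inherited set_epoch() contract of DistributedSampler is unchanged.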