
Make DistributedSampler stateful #1315

Merged: 16 commits, Aug 28, 2024
Conversation

ramanishsingh
Contributor

Fixes #1269

Changes

  • torchdata/stateful_dataloader/sampler.py: added new classes StatefulDistributedSampler and _StatefulDistributedSamplerIterator
  • test/stateful_dataloader/test_dataloader.py: added new tests for StatefulDistributedSampler
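The core idea of the PR is that a sampler exposes state_dict / load_state_dict so iteration can resume mid-epoch after a checkpoint. A minimal pure-Python sketch of that pattern (class name, attribute names, and state-dict keys here are illustrative, not torchdata's actual API):

```python
# Minimal sketch of the "stateful sampler" pattern this PR adds.
# TinyStatefulSampler is a hypothetical stand-in, not torchdata code.

class TinyStatefulSampler:
    """Yields indices 0..n-1 and can checkpoint/restore its position."""

    def __init__(self, n):
        self.n = n
        self.yielded = 0  # how many indices have been handed out so far

    def __iter__(self):
        # Resume from wherever a previous run (or checkpoint) left off.
        for i in range(self.yielded, self.n):
            self.yielded = i + 1
            yield i

    def state_dict(self):
        return {"yielded": self.yielded}

    def load_state_dict(self, state):
        self.yielded = state["yielded"]


sampler = TinyStatefulSampler(5)
it = iter(sampler)
first_two = [next(it), next(it)]   # consume part of the epoch
ckpt = sampler.state_dict()        # checkpoint mid-epoch

restored = TinyStatefulSampler(5)
restored.load_state_dict(ckpt)
rest = list(iter(restored))        # picks up where the checkpoint stopped
```

The real StatefulDistributedSampler additionally has to account for shuffling, rank-based subsampling, and padding, which is what the forked iterator code reviewed below deals with.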

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Aug 21, 2024
@ramanishsingh ramanishsingh changed the title Make DistributedSampling stateful Make DistributedSampler stateful Aug 21, 2024
@andrewkho
Contributor

AI Store test can be safely ignored for now

@andrewkho andrewkho left a comment

Looks pretty good, but would like to simplify the code a bit and move the tests around as well

test/stateful_dataloader/test_dataloader.py (review threads outdated, resolved)
@@ -1947,6 +1960,116 @@ def test_sampler_reproducibility(self):
ls[i].append(next(its[i]))
self.assertEqual(ls[0], ls[1])

def test_initialization_StatefulDistributedSampler(self):
Contributor

Let's move all of these tests out to a new file called test_sampler.py. You can update https://github.com/pytorch/data/blob/main/.github/workflows/stateful_dataloader_ci.yml to call it in an additional step

Contributor Author

Created here: https://github.com/pytorch/data/blob/stateful_distributedsampler/test/stateful_dataloader/test_sampler.py

Added new line here:

- name: Run StatefulDataSampler tests with pytest - datasampler

from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler

dataset = self.dataset
sampler = StatefulDistributedSampler(dataset, num_replicas=10, rank=0, shuffle=False, seed=42, drop_last=False)
Contributor

For testing state_dict, let's have most of the tests set up with passing sampler + dataset to StatefulDataLoader so we can test that it works end-to-end

Contributor

You might need to use a dummy Collate function to easily inspect elements, check the test_state_dict.py file for examples
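A "dummy collate" in the spirit of this suggestion replaces tensor stacking with an identity function, so tests can assert on raw elements. A sketch (identity_collate is an illustrative name; in a real test it would be passed as collate_fn to StatefulDataLoader):

```python
# Identity collate: return the batch list unchanged instead of stacking
# samples into tensors, so test assertions can inspect elements directly.

def identity_collate(batch):
    return batch

# Applied by hand here to fake batches; a DataLoader would call it per batch.
data = list(range(6))
batch_size = 2
batches = [identity_collate(data[i:i + batch_size])
           for i in range(0, len(data), batch_size)]
```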

Contributor Author

New tests here:

def test_dataloader_state_dict(self):

self.next_yielded = None

def __iter__(self):

Contributor

Is it possible to fork the DistributedSampler.__iter__ code here instead and just update, instead of having a separate Iterator class?

Contributor Author

self.indices = list(super().__iter__())
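The one-liner above reuses the parent's __iter__ instead of duplicating its shuffle/pad/subsample logic, which is what the reviewer asked for. A minimal pure-Python sketch of that pattern (ParentSampler stands in for torch's DistributedSampler; all names are illustrative):

```python
# Sketch: call super().__iter__() and skip already-yielded indices,
# rather than forking the parent's iteration logic.

class ParentSampler:
    def __init__(self, size):
        self.size = size

    def __iter__(self):
        # Imagine shuffle / padding / rank-subsampling logic living here.
        return iter(range(self.size))


class StatefulSampler(ParentSampler):
    def __init__(self, size):
        super().__init__(size)
        self.next_yielded = None  # set by load_state_dict on resume

    def __iter__(self):
        # Materialize the parent's full ordering, then fast-forward past
        # whatever was yielded before the checkpoint.
        indices = list(super().__iter__())
        start = self.next_yielded or 0
        self.next_yielded = None
        for i in indices[start:]:
            yield i

    def load_state_dict(self, state):
        self.next_yielded = state["next_yielded"]


s = StatefulSampler(6)
s.load_state_dict({"next_yielded": 4})
resumed = list(s)
```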

Comment on lines 136 to 162
if self.sampler.shuffle:
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.sampler.seed + self.sampler.epoch)
indices = torch.randperm(len(self.sampler.dataset), generator=g).tolist() # type: ignore[arg-type]
else:
indices = list(range(len(self.sampler.dataset))) # type: ignore[arg-type]

if not self.sampler.drop_last:
# add extra samples to make it evenly divisible
padding_size = self.sampler.total_size - len(indices)
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
else:
# remove tail of data to make it evenly divisible.
indices = indices[: self.sampler.total_size]
assert len(indices) == self.sampler.total_size

# subsample
indices = indices[self.sampler.rank : self.sampler.total_size : self.sampler.num_replicas]
assert len(indices) == self.sampler.num_samples

self.parent_iterator = iter(indices)
self.indices = list(self.parent_iterator)
self.current_index = 0
Contributor

Is there a way to call the original code instead of forking it here?

Comment on lines 177 to 181
def state_dict(self) -> Dict[str, Any]:
return self.sampler.state_dict()

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.sampler.load_state_dict(state_dict)
Contributor

I don't think we need this both here and in the main sampler class, can we consolidate to have this in just one place?

@andrewkho andrewkho left a comment

Couple of suggestions, but looks great! very nice test suite.

When you're done making changes, please run the fbcode CI for media_dataloader.

test/stateful_dataloader/test_sampler.py (review thread outdated, resolved)
torchdata/stateful_dataloader/sampler.py (review thread outdated, resolved)
@facebook-github-bot
Contributor

@ramanishsingh has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.


@andrewkho andrewkho left a comment

LGTM!

@ramanishsingh ramanishsingh merged commit 8b6e903 into main Aug 28, 2024
42 of 45 checks passed
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D61772177

Labels: CLA Signed, fb-exported