
Commit 6db8b7e
make the split_by_worker and split_by_rank checks optional
AlirezaSohofi committed Sep 14, 2023
1 parent 89192b4 commit 6db8b7e
Showing 3 changed files with 23 additions and 6 deletions.
18 changes: 15 additions & 3 deletions squirrel/iterstream/base.py
@@ -324,11 +324,23 @@ def split_by_rank_pytorch(self, torch_dist_group: t.Optional[str] = None) -> Composable:
 
         return self.compose(SplitByRank, torch_dist_group)
 
-    def to_torch_iterable(self) -> Composable:
-        """Convert the stream to a torch iterable."""
+    def to_torch_iterable(self, enforce_rank_check: bool = True, enforce_worker_check: bool = True) -> Composable:
+        """
+        Convert the stream to a torch iterable.
+
+        Args:
+            enforce_rank_check: if set to true, checks that the method `split_by_rank_pytorch` has been called prior to
+                calling `to_torch_iterable`. This is important to avoid loading the same sample more than once in the
+                multi-rank pytorch environment.
+            enforce_worker_check: if set to true, checks that the method `split_by_worker_pytorch` has been called
+                prior to calling `to_torch_iterable`. This is important to avoid loading the same sample more than
+                once in the multi-worker pytorch environment.
+        """
         from squirrel.iterstream.torch_composables import TorchIterable
 
-        return self.compose(TorchIterable)
+        return self.compose(
+            partial(TorchIterable, enforce_rank_check=enforce_rank_check, enforce_worker_check=enforce_worker_check)
+        )
 
 
 class _Iterable(Composable):
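In practice, the new flags let a pipeline opt out of the safety checks when it is known to run on a single rank with a single worker. A minimal sketch of both call styles (an illustration, assuming `IterableSource` is importable from `squirrel.iterstream` as in the test file below):

    from squirrel.iterstream import IterableSource

    samples = [1, 2, 3]

    # Recommended pattern: split by rank and worker before converting, so the
    # checks in TorchIterable pass and no sample is loaded more than once.
    it = (
        IterableSource(samples)
        .split_by_rank_pytorch()
        .split_by_worker_pytorch()
        .to_torch_iterable()
    )

    # New in this commit: skip the checks explicitly, e.g. for single-process
    # debugging where the splits are unnecessary.
    it = IterableSource(samples).to_torch_iterable(
        enforce_rank_check=False, enforce_worker_check=False
    )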
8 changes: 5 additions & 3 deletions squirrel/iterstream/torch_composables.py
@@ -61,19 +61,21 @@ def __iter__(self) -> Iterator:
 class TorchIterable(Composable, IterableDataset):
     """Mixin-Composable to have squirrel pipeline inherit from PyTorch IterableDataset"""
 
-    def __init__(self) -> None:
+    def __init__(self, enforce_rank_check: bool = True, enforce_worker_check: bool = True) -> None:
         """Init"""
         super().__init__()
+        self.enforce_rank_check = enforce_rank_check
+        self.enforce_worker_check = enforce_worker_check
 
     def __iter__(self) -> Iterator:
         """Method to iterate over the source"""
-        if _in_multi_rank_env():
+        if self.enforce_rank_check and _in_multi_rank_env():
             if not self._contains_rank_split(self.source):
                 raise PyTorchSplittingException(
                     "Composable was not split by rank. This will lead to unexpected iteration behaviour."
                     "Add a 'split_by_rank_pytorch' call to your composable to avoid this error. "
                 )
-        if _in_multi_worker_env():
+        if self.enforce_worker_check and _in_multi_worker_env():
             if not self._contains_worker_split(self.source):
                 raise PyTorchSplittingException(
                     "Composable was not split by worker. This will lead to unexpected iteration behaviour."
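For context, `_in_multi_worker_env()` reports true inside each PyTorch `DataLoader` worker process, so the guard fires per worker. A sketch of the failure mode the check prevents (the DataLoader wiring is illustrative; only the names visible in this diff are taken from the library):

    from torch.utils.data import DataLoader
    from squirrel.iterstream import IterableSource

    # Without a worker split, every DataLoader worker would iterate the full
    # source and duplicate samples; iterating this through a multi-worker
    # DataLoader therefore raises PyTorchSplittingException, unless
    # enforce_worker_check=False is passed.
    unsafe = IterableSource(range(8)).to_torch_iterable()

    # With the split, each worker sees a disjoint slice of the source.
    safe = IterableSource(range(8)).split_by_worker_pytorch().to_torch_iterable()
    for x in DataLoader(safe, num_workers=2, batch_size=None):
        print(x)  # every sample appears exactly once across the two workers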
3 changes: 3 additions & 0 deletions test/test_iterstream/test_torch_composables.py
@@ -212,6 +212,9 @@ def test_error_when_not_splitting_in_mp(mock_get_worker_info: Any, samples: List
         it = IterableSource(samples).to_torch_iterable()
         next(iter(it))
 
+    res = IterableSource(samples).to_torch_iterable(enforce_worker_check=False, enforce_rank_check=False).collect()
+    assert res == samples
+
     # Split by rank and worker, this should work
 
     # ADD SIMPLE MAP FN
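The trailing comments hint at the positive case; spelled out, it might look roughly like this (a sketch, not part of the committed test):

    # Split by rank and worker, this should work:
    it = (
        IterableSource(samples)
        .split_by_rank_pytorch()
        .split_by_worker_pytorch()
        .to_torch_iterable()
    )
    # On a single rank/worker the splits pass everything through unchanged.
    assert list(it) == samples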
