Throw error when torch iterable was not split by rank or split by worker #107

Merged: 3 commits merged on Feb 10, 2023
Changes from all commits
2 changes: 1 addition & 1 deletion squirrel/__init__.py
@@ -1 +1 @@
__version__ = "0.18.3"
__version__ = "0.18.4"
4 changes: 4 additions & 0 deletions squirrel/framework/exceptions.py
@@ -4,3 +4,7 @@ class SourceExistsError(Exception):

class SourceArgumentsCombinationException(Exception):
    pass


class PyTorchSplittingException(Exception):
    pass
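
For readers skimming the diff: a minimal sketch, not part of this PR, of how the new exception class surfaces at iteration time. The check only fires when the iterable is consumed under a multi-worker DataLoader or a multi-rank torch.distributed run; in a plain single-process script nothing is raised.

```python
# Sketch only (assumes iteration happens in a multi-worker or multi-rank environment;
# otherwise the new check passes silently and no exception is raised).
from squirrel.framework.exceptions import PyTorchSplittingException
from squirrel.iterstream.source import IterableSource

it = IterableSource(range(8)).to_torch_iterable()  # no split_by_worker/rank calls
try:
    for sample in it:  # the check runs when iteration starts
        pass
except PyTorchSplittingException as err:
    # Remedy: add split_by_worker_pytorch() / split_by_rank_pytorch() before to_torch_iterable().
    print(err)
```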
46 changes: 46 additions & 0 deletions squirrel/iterstream/torch_composables.py
@@ -7,6 +7,7 @@
from torch.utils.data import IterableDataset

from squirrel.iterstream.base import Composable
from squirrel.framework.exceptions import PyTorchSplittingException

logger = logging.getLogger(__name__)

@@ -66,8 +67,38 @@ def __init__(self) -> None:

    def __iter__(self) -> Iterator:
        """Method to iterate over the source"""
        if _in_multi_rank_env():
            if not self._contains_rank_split(self.source):
                raise PyTorchSplittingException(
                    "Composable was not split by rank. This will lead to unexpected iteration behaviour. "
                    "Add a 'split_by_rank_pytorch' call to your composable to avoid this error."
                )
        if _in_multi_worker_env():
            if not self._contains_worker_split(self.source):
                raise PyTorchSplittingException(
                    "Composable was not split by worker. This will lead to unexpected iteration behaviour. "
                    "Add a 'split_by_worker_pytorch' call to your composable to avoid this error."
                )
        yield from self.source

    def _contains_rank_split(self, source: Composable) -> bool:
        """Check if SplitByRank was chained to this Composable"""
        if isinstance(source, SplitByRank):
            return True
        elif not isinstance(source, Composable):
            return False
        else:
            return self._contains_rank_split(source.source)

    def _contains_worker_split(self, source: Composable) -> bool:
        """Check if SplitByWorker was chained to this Composable"""
        if isinstance(source, SplitByWorker):
            return True
        elif not isinstance(source, Composable):
            return False
        else:
            return self._contains_worker_split(source.source)


def _skip_k(it: Iterable, start: int, step: int) -> Iterator:
"""
Expand All @@ -93,3 +124,18 @@ def skip_k(rank: int, world_size: int) -> Callable[[Iterable], Iterator]:
world_size: int denoting the full world size.
"""
return partial(_skip_k, start=rank, step=world_size)


def _in_multi_worker_env() -> bool:
    """Check if currently in a multi-worker environment"""
    return torch.utils.data.get_worker_info() is not None


def _in_multi_rank_env() -> bool:
    """Check if currently in a multi-rank environment"""
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        group = torch.distributed.group.WORLD
        size = torch.distributed.get_world_size(group=group)
        return size > 1
    else:
        return False
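
Taken together, the new checks expect a pipeline shaped like the sketch below (a hedged example: the composable names come from this diff, and the DataLoader wiring mirrors the test file that follows):

```python
# Pipeline shape that satisfies both _contains_worker_split and _contains_rank_split.
# The splits can sit anywhere before to_torch_iterable(), since the checks walk
# the .source chain recursively.
import torch.utils.data as tud

from squirrel.iterstream.source import IterableSource

it = (
    IterableSource(range(100))
    .split_by_worker_pytorch()  # found by _contains_worker_split
    .split_by_rank_pytorch()    # found by _contains_rank_split
    .to_torch_iterable()
)
loader = tud.DataLoader(it, num_workers=3)  # each sample is served by exactly one worker
```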
50 changes: 50 additions & 0 deletions test/test_iterstream/test_torch_composables.py
@@ -2,6 +2,7 @@
from functools import partial
from typing import List, Any
from unittest import mock
from collections import namedtuple

import pytest
import torch
@@ -11,6 +12,7 @@
from squirrel.iterstream.iterators import map_
from squirrel.iterstream.source import IterableSource
from squirrel.iterstream.torch_composables import SplitByRank, SplitByWorker, TorchIterable, skip_k
from squirrel.framework.exceptions import PyTorchSplittingException


@pytest.fixture(scope="module", autouse=True)
@@ -175,3 +177,51 @@ def _cb(x: Any) -> Any:
    dl3 = tud.DataLoader(it3, num_workers=num_workers)
    out3 = torch.Tensor(list(dl3))
    assert sorted(out3.cpu().flatten().numpy().tolist()) == expected


@mock.patch("torch.distributed.is_available", mock.MagicMock(return_value=True))
@mock.patch("torch.distributed.is_initialized", mock.MagicMock(return_value=True))
@mock.patch("torch.distributed.group.WORLD", mock.MagicMock(return_value="WORLD"))
@mock.patch("torch.distributed.get_rank", mock.MagicMock(return_value=4))
@mock.patch("torch.distributed.get_world_size", mock.MagicMock(return_value=4))
@mock.patch("torch.utils.data.get_worker_info")
def test_error_when_not_splitting_in_mp(mock_get_worker_info: Any, samples: List[int]) -> None:
"""Test that a ValueError is thrown when composable is not split by rank and worker if calling to_torch_iterable"""
    # Needed for multi-worker env
    num_workers = 3
    worker_id = 0
    WorkerInfo = namedtuple("WorkerInfo", ["id", "num_workers"])
    mock_get_worker_info.return_value = WorkerInfo(id=worker_id, num_workers=num_workers)

    # Needed for multi-rank env
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()

    # Not splitting by worker
    with pytest.raises(PyTorchSplittingException):
        it = IterableSource(samples).split_by_rank_pytorch().to_torch_iterable()
        next(iter(it))

    # Not splitting by rank
    with pytest.raises(PyTorchSplittingException):
        it = IterableSource(samples).split_by_worker_pytorch().to_torch_iterable()
        next(iter(it))

    # None of the above
    with pytest.raises(PyTorchSplittingException):
        it = IterableSource(samples).to_torch_iterable()
        next(iter(it))

    # Split by rank and worker (plus a simple map fn); this should work
    it = (
        IterableSource(samples)
        .split_by_worker_pytorch()
        .split_by_rank_pytorch()
        .async_map(_times_two)
        .to_torch_iterable()
    )
    dl = tud.DataLoader(it, num_workers=num_workers)
    out = torch.Tensor(list(dl))
    assert len(out.cpu().flatten().numpy().tolist()) == len(samples[rank::world_size])
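
As the final pipeline in this test shows, other composables (here async_map) may sit between the splits and to_torch_iterable(); the recursive .source walk still finds them. A small standalone sketch of the same idea, with an assumed _times_two helper mirroring the one used in the test:

```python
# Sketch, not part of the PR: intermediate steps do not hide the splits from the
# recursive chain inspection added in this change.
from squirrel.iterstream.source import IterableSource


def _times_two(x: int) -> int:
    return x * 2


it = (
    IterableSource(range(10))
    .split_by_worker_pytorch()
    .split_by_rank_pytorch()
    .async_map(_times_two)  # intermediate composable between the splits and the torch iterable
    .to_torch_iterable()
)
out = list(it)  # in a single-process run the splits are pass-through (assumption) and no error is raised
```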