Skip to content

Commit

Permalink
Deep copy DataPipe during initialization (#786)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #786

Deep copy the DataPipe during initialization to avoid sharing state when the same DataPipe instance is passed to multiple DataLoader2 objects.

This can be merged before or after #746 (the other PR will have to rebase).

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D39741743

Pulled By: ejguan

fbshipit-source-id: f1cdaca054e10bcb3672857b933732178b370535
  • Loading branch information
NivekT authored and ejguan committed Oct 14, 2022
1 parent 87b7426 commit c542222
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 5 deletions.
12 changes: 10 additions & 2 deletions test/test_dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,15 +224,23 @@ class DataLoader2IntegrationTest(TestCase):
def _get_mp_reading_service():
return MultiProcessingReadingService(num_workers=2)

@staticmethod
def _access_datapipe(dl):
"""
Returns a reference to the DataPipe, bypassing serialization wrapper and etc.
"""
return dl.datapipe._datapipe

def test_lazy_load(self):
source_dp: IterDataPipe = IterableWrapper([(i, i) for i in range(10)])
source_dp = IterableWrapper([(i, i) for i in range(10)])
map_dp = source_dp.to_map_datapipe()

reading_service_generators = (self._get_mp_reading_service,)
for reading_service_gen in reading_service_generators:
dl: DataLoader2 = DataLoader2(datapipe=map_dp, reading_service=reading_service_gen())
# Lazy loading
self.assertTrue(dl.datapipe._map is None)
dp = self._access_datapipe(dl)
self.assertTrue(dp._map is None)
it = iter(dl)
self.assertEqual(list(it), list(range(10)))
# Lazy loading in multiprocessing
Expand Down
46 changes: 43 additions & 3 deletions torchdata/dataloader2/dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,13 @@
from dataclasses import dataclass
from typing import Any, Dict, Generic, Iterable, Iterator, Optional, TypeVar, Union

from torch.utils.data.datapipes.datapipe import _IterDataPipeSerializationWrapper, _MapDataPipeSerializationWrapper

from torch.utils.data.graph import DataPipe
from torchdata.dataloader2.adapter import Adapter

from torchdata.dataloader2.graph import DataPipe
from torchdata.datapipes.iter import IterDataPipe
from torchdata.datapipes.map import MapDataPipe

from .error import PauseIteration
from .reading_service import CheckpointableReadingServiceInterface, ReadingServiceInterface
Expand Down Expand Up @@ -80,13 +84,27 @@ def __getattr__(self, name):


class DataLoader2(Generic[T_co]):
"""
DataLoader2. Given a DataPipe, a ReadingService and adapter function(s), this provides an iterable over
the given DataPipe.
Args:
datapipe (``IterDataPipe`` or ``MapDataPipe``): DataPipe from which to load the data. A deepcopy of this will be made during
initialization, allowing the input to be re-used in a different DataLoader2 without sharing states.
datapipe_adapter_fn (Iterable[Adapter] or Adapter, optional): Adapter function(s) that will be applied
to the DataPipe (default: ``None``).
reading_service (ReadingServiceInterface, optional): defines how DataLoader2 should execute operations over
the DataPipe, e.g. multiprocessing/distributed (default: ``None``). A deepcopy of this will be made during
initialization, allowing the input to be re-used in a different DataLoader2 without sharing states.
"""

def __init__(
self,
datapipe: DataPipe,
datapipe_adapter_fn: Optional[Union[Iterable[Adapter], Adapter]] = None,
reading_service: Optional[ReadingServiceInterface] = None,
) -> None:
self.datapipe = datapipe
self.datapipe = self._wrap_and_copy_dp(datapipe)
self._adapted: bool = False
self._datapipe_iter: Optional[Iterator[T_co]] = None
self._reset_iter: bool = True # Sets to `False` when __iter__ starts, and `True` when `StopIteration``
Expand All @@ -105,7 +123,7 @@ def __init__(
if self.datapipe_adapter_fns is not None:
for adapter_fn in self.datapipe_adapter_fns:
self.datapipe = adapter_fn(self.datapipe)
self._datapipe_before_reading_service_adapt: DataPipe = self.datapipe
self._datapipe_before_reading_service_adapt: DataPipe = self._copy(self.datapipe)

def __iter__(self) -> Iterator[T_co]:
if self._terminated:
Expand Down Expand Up @@ -133,6 +151,28 @@ def __iter__(self) -> Iterator[T_co]:
def __del__(self) -> None:
self.shutdown()

@staticmethod
def _copy(obj):
"""
Standardized way for DataLoader2 to copy an object when needed, such as for DataPipe/ReadingService.
This uses `pickle` to serialize/deserialize to create the copy.
"""
return pickle.loads(pickle.dumps(obj))

@staticmethod
def _wrap_and_copy_dp(datapipe: DataPipe):
    """
    Wrap the DataPipe in the serialization wrapper matching its kind
    (iter-style or map-style), then return a copy made via ``DataLoader2._copy``.
    A DataPipe of neither kind is copied without wrapping.
    """
    if isinstance(datapipe, IterDataPipe):
        wrapped: DataPipe = _IterDataPipeSerializationWrapper(datapipe)
    elif isinstance(datapipe, MapDataPipe):
        wrapped = _MapDataPipeSerializationWrapper(datapipe)
    else:
        wrapped = datapipe
    return DataLoader2._copy(wrapped)

def shutdown(self) -> None:
if not self._reset_iter:
self._reset_iter = True
Expand Down

0 comments on commit c542222

Please sign in to comment.