diff --git a/test/test_dataloader2.py b/test/test_dataloader2.py index 02872950c..5f9d8ebb0 100644 --- a/test/test_dataloader2.py +++ b/test/test_dataloader2.py @@ -224,15 +224,23 @@ class DataLoader2IntegrationTest(TestCase): def _get_mp_reading_service(): return MultiProcessingReadingService(num_workers=2) + @staticmethod + def _access_datapipe(dl): + """ + Returns a reference to the DataPipe, bypassing the serialization wrapper, etc. + """ + return dl.datapipe._datapipe + def test_lazy_load(self): - source_dp: IterDataPipe = IterableWrapper([(i, i) for i in range(10)]) + source_dp = IterableWrapper([(i, i) for i in range(10)]) map_dp = source_dp.to_map_datapipe() reading_service_generators = (self._get_mp_reading_service,) for reading_service_gen in reading_service_generators: dl: DataLoader2 = DataLoader2(datapipe=map_dp, reading_service=reading_service_gen()) # Lazy loading - self.assertTrue(dl.datapipe._map is None) + dp = self._access_datapipe(dl) + self.assertTrue(dp._map is None) it = iter(dl) self.assertEqual(list(it), list(range(10))) # Lazy loading in multiprocessing diff --git a/torchdata/dataloader2/dataloader2.py b/torchdata/dataloader2/dataloader2.py index 7740adb20..fcd023d24 100644 --- a/torchdata/dataloader2/dataloader2.py +++ b/torchdata/dataloader2/dataloader2.py @@ -9,9 +9,13 @@ from dataclasses import dataclass from typing import Any, Dict, Generic, Iterable, Iterator, Optional, TypeVar, Union +from torch.utils.data.datapipes.datapipe import _IterDataPipeSerializationWrapper, _MapDataPipeSerializationWrapper + +from torch.utils.data.graph import DataPipe from torchdata.dataloader2.adapter import Adapter -from torchdata.dataloader2.graph import DataPipe +from torchdata.datapipes.iter import IterDataPipe +from torchdata.datapipes.map import MapDataPipe from .error import PauseIteration from .reading_service import CheckpointableReadingServiceInterface, ReadingServiceInterface @@ -80,13 +84,27 @@ def __getattr__(self, name): class 
DataLoader2(Generic[T_co]): + """ + DataLoader2. Given a DataPipe, a ReadingService and adapter function(s), this provides an iterable over + the given DataPipe. + + Args: + datapipe (Dataset): DataPipe from which to load the data. A deepcopy of this will be made during + initialization, allowing the input to be re-used in a different DataLoader2 without sharing states. + datapipe_adapter_fn (Iterable[Adapter] or Adapter, optional): Adapter function(s) that will be applied + to the DataPipe (default: ``None``). + reading_service (ReadingServiceInterface, optional): defines how DataLoader2 should execute operations over + the DataPipe, e.g. multiprocessing/distributed (default: ``None``). A deepcopy of this will be made during + initialization, allowing the input to be re-used in a different DataLoader2 without sharing states. + """ + def __init__( self, datapipe: DataPipe, datapipe_adapter_fn: Optional[Union[Iterable[Adapter], Adapter]] = None, reading_service: Optional[ReadingServiceInterface] = None, ) -> None: - self.datapipe = datapipe + self.datapipe = self._wrap_and_copy_dp(datapipe) self._adapted: bool = False self._datapipe_iter: Optional[Iterator[T_co]] = None self._reset_iter: bool = True # Sets to `False` when __iter__ starts, and `True` when `StopIteration`` @@ -105,7 +123,7 @@ def __init__( if self.datapipe_adapter_fns is not None: for adapter_fn in self.datapipe_adapter_fns: self.datapipe = adapter_fn(self.datapipe) - self._datapipe_before_reading_service_adapt: DataPipe = self.datapipe + self._datapipe_before_reading_service_adapt: DataPipe = self._copy(self.datapipe) def __iter__(self) -> Iterator[T_co]: if self._terminated: @@ -133,6 +151,28 @@ def __iter__(self) -> Iterator[T_co]: def __del__(self) -> None: self.shutdown() + @staticmethod + def _copy(obj): + """ + Standardized way for DataLoader2 to copy an object when needed, such as for DataPipe/ReadingService. + This uses `pickle` to serialize/deserialize to create the copy. 
+ """ + return pickle.loads(pickle.dumps(obj)) + + @staticmethod + def _wrap_and_copy_dp(datapipe: DataPipe): + """ + Wraps the DataPipe with the corresponding serialization wrapper. + Then, creates a copy with the class's static copy method. + + """ + wrapped_dp: DataPipe = datapipe + if isinstance(datapipe, IterDataPipe): + wrapped_dp = _IterDataPipeSerializationWrapper(datapipe) + elif isinstance(datapipe, MapDataPipe): + wrapped_dp = _MapDataPipeSerializationWrapper(datapipe) + return DataLoader2._copy(wrapped_dp) + def shutdown(self) -> None: if not self._reset_iter: self._reset_iter = True