Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DataLoader] Deep copy ReadingService during DL2 initialization #746

Closed
wants to merge 6 commits into from
23 changes: 22 additions & 1 deletion torchdata/dataloader2/dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,19 @@ def __getattr__(self, name):


class DataLoader2(Generic[T_co]):
"""
DataLoader2. Given a DataPipe, a ReadingService and adapter function(s), this provides an iterable over
the given DataPipe.

Args:
datapipe (DataPipe): DataPipe from which to load the data.
datapipe_adapter_fn (Iterable[Adapter] or Adapter, optional): Adapter function(s) that will be applied
to the DataPipe (default: ``None``).
reading_service (ReadingServiceInterface, optional): defines how DataLoader2 should execute operations over
the DataPipe, e.g. multiprocessing/distributed (default: ``None``). A deepcopy of this will be made during
initialization, allowing the input to be re-used in a different DataLoader2.
"""

def __init__(
self,
datapipe: DataPipe,
Expand All @@ -97,7 +110,7 @@ def __init__(
self.datapipe_adapter_fns = datapipe_adapter_fn
else:
self.datapipe_adapter_fns = [datapipe_adapter_fn]
self.reading_service = reading_service
self.reading_service = self._copy(reading_service)
self.reading_service_state: Optional[bytes] = None # is not `None` when `load_state_dict` is called
self._terminated: bool = False
self.valid_iterator_id: Optional[int] = None
Expand Down Expand Up @@ -133,6 +146,14 @@ def __iter__(self) -> Iterator[T_co]:
def __del__(self) -> None:
    """Finalizer hook: delegate to ``shutdown()`` when this object is being destroyed."""
    self.shutdown()

@staticmethod
def _copy(obj):
    """
    Return an independent duplicate of ``obj`` by round-tripping it through ``pickle``.

    This is the single, uniform copying mechanism DataLoader2 uses whenever it needs its own
    copy of an object (for example a DataPipe or a ReadingService). Because the copy is built
    by serializing and then deserializing, ``obj`` has to be picklable.
    """
    serialized = pickle.dumps(obj)
    return pickle.loads(serialized)

def shutdown(self) -> None:
if not self._reset_iter:
self._reset_iter = True
Expand Down