[DataLoader] Deep copy ReadingService during DL2 initialization #746

Closed
wants to merge 6 commits into from
16 changes: 15 additions & 1 deletion torchdata/dataloader2/dataloader2.py
@@ -6,6 +6,7 @@


 import pickle
+from copy import deepcopy
 from dataclasses import dataclass
 from typing import Any, Dict, Generic, Iterable, Iterator, Optional, TypeVar, Union

@@ -80,6 +81,19 @@ def __getattr__(self, name):


 class DataLoader2(Generic[T_co]):
+    """
+    DataLoader2. Given a DataPipe, a ReadingService and adapter function(s), this provides an iterable over
+    the given DataPipe.
+
+    Args:
+        datapipe (Dataset): DataPipe from which to load the data.
+        datapipe_adapter_fn (Iterable[Adapter] or Adapter, optional): Adapter function(s) that will be applied
+            to the DataPipe (default: ``None``).
+        reading_service (ReadingServiceInterface, optional): defines how DataLoader2 should execute operations over
+            the DataPipe, e.g. multiprocessing/distributed (default: ``None``). A deepcopy of this will be made during
+            initialization, allowing the input to be re-used in a different DataLoader2.
+    """
+
     def __init__(
         self,
         datapipe: DataPipe,
@@ -97,7 +111,7 @@ def __init__(
             self.datapipe_adapter_fns = datapipe_adapter_fn
         else:
             self.datapipe_adapter_fns = [datapipe_adapter_fn]
-        self.reading_service = reading_service
+        self.reading_service = deepcopy(reading_service)
Contributor

Let's try to import your PR to internal to validate if this is working properly with internal RS.

Contributor

+1 on checking existing implementations.

         self.reading_service_state: Optional[bytes] = None  # is not `None` when `load_state_dict` is called
         self._terminated: bool = False
         self.valid_iterator_id: Optional[int] = None
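
For readers following the change, here is a minimal sketch of the behaviour the new docstring describes: copying the reading service at construction time so the caller's object can be passed to more than one loader. It does not use torchdata. `ToyReadingService`, `ToyLoader`, the `copy_service` flag and the `initialized_for` attribute are hypothetical stand-ins invented for illustration, not part of the real `DataLoader2` or `ReadingServiceInterface` API.

```python
from copy import deepcopy


class ToyReadingService:
    """Hypothetical stand-in for a ReadingService with mutable per-loader state."""

    def __init__(self, num_workers):
        self.num_workers = num_workers
        self.initialized_for = None  # set once a loader starts using the service

    def initialize(self, name):
        self.initialized_for = name


class ToyLoader:
    """Hypothetical stand-in for DataLoader2, showing only the copy behaviour."""

    def __init__(self, name, reading_service=None, copy_service=True):
        self.name = name
        # Mirrors the one-line change in this PR: keep a private copy
        # instead of holding on to the caller's object.
        if copy_service:
            self.reading_service = deepcopy(reading_service)
        else:
            self.reading_service = reading_service

    def start(self):
        if self.reading_service is not None:
            self.reading_service.initialize(self.name)


# Without copying, both loaders mutate the same service object.
rs = ToyReadingService(num_workers=2)
shared_a = ToyLoader("a", rs, copy_service=False)
shared_b = ToyLoader("b", rs, copy_service=False)
shared_a.start()
shared_b.start()
shared_state = shared_a.reading_service.initialized_for  # overwritten by loader "b"

# With deepcopy, each loader owns its state and the caller's object is untouched.
rs2 = ToyReadingService(num_workers=2)
own_a = ToyLoader("a", rs2)
own_b = ToyLoader("b", rs2)
own_a.start()
own_b.start()
```

The sketch only covers plain attribute state. `deepcopy` can raise for objects that hold non-copyable resources such as thread locks, which may be one reason the reviewers ask to validate the change against existing ReadingService implementations.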