Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DataLoader2] Handle MapDataPipe by converting to IterDataPipe internally by default #1146

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion torchdata/dataloader2/dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@

from torchdata.dataloader2.adapter import Adapter
from torchdata.dataloader2.error import PauseIteration
from torchdata.dataloader2.graph._serialization import clone, DataPipe, deserialize_datapipe, serialize_datapipe
from torchdata.dataloader2.graph._serialization import (
clone,
DataPipe,
deserialize_datapipe,
MapDataPipe,
serialize_datapipe,
)
from torchdata.dataloader2.random import SeedGenerator
from torchdata.dataloader2.random.seed_generator import _UINT64_UPPER_BOUND
from torchdata.dataloader2.reading_service import CheckpointableReadingServiceInterface, ReadingServiceInterface
Expand Down Expand Up @@ -145,6 +151,13 @@ class DataLoader2(Generic[T_co]):
the ``DataPipe``, e.g. multiprocessing/distributed (default: ``None``). A deepcopy of this will be
created during initialization, allowing the ReadingService to be re-used in a different
``DataLoader2`` without sharing states.

Note:
When a ``MapDataPipe`` is passed into ``DataLoader2``, in order to iterate through
the data, ``DataLoader2`` will attempt to create an iterator via ``iter(datapipe)``.
If the object has a non-zero-indexed indices, this may fail.
Consider using ``.shuffle()`` (which converts ``MapDataPipe`` to ``IterDataPipe``)
or ``datapipe.to_iter_datapipe(custom_indices)``.
"""

def __init__(
Expand All @@ -153,6 +166,8 @@ def __init__(
datapipe_adapter_fn: Optional[Union[Iterable[Adapter], Adapter]] = None,
reading_service: Optional[ReadingServiceInterface] = None,
) -> None:
if isinstance(datapipe, MapDataPipe):
datapipe = datapipe.to_iter_datapipe()
self.datapipe = clone(datapipe) if datapipe is not None else None
self._adapted: bool = False
self._datapipe_iter: Optional[Iterator[T_co]] = None
Expand Down