Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor to use traverse_dps instead of traverse #793

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions test/test_dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@
import unittest
from unittest import TestCase

from torch.utils.data.graph import DataPipe

from torchdata.dataloader2 import (
DataLoader2,
MultiProcessingReadingService,
PrototypeMultiProcessingReadingService,
ReadingServiceInterface,
)
from torchdata.dataloader2.dataloader2 import READING_SERVICE_STATE_KEY_NAME, SERIALIZED_DATAPIPE_KEY_NAME

from torchdata.dataloader2.graph import DataPipe
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe


Expand Down
16 changes: 8 additions & 8 deletions test/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from _utils._common_utils_for_test import IS_WINDOWS
from torch.utils.data import IterDataPipe
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, ReadingServiceInterface
from torchdata.dataloader2.graph import find_dps, remove_dp, replace_dp, traverse
from torchdata.dataloader2.graph import find_dps, remove_dp, replace_dp, traverse_dps
from torchdata.datapipes.iter import IterableWrapper, Mapper

T_co = TypeVar("T_co", covariant=True)
Expand All @@ -37,7 +37,7 @@ class TempReadingService(ReadingServiceInterface):
adaptors: List[IterDataPipe] = []

def initialize(self, datapipe: IterDataPipe) -> IterDataPipe:
graph = traverse(datapipe, only_datapipe=True)
graph = traverse_dps(datapipe)
dps = find_dps(graph, Mapper)

for dp in reversed(dps):
Expand Down Expand Up @@ -78,7 +78,7 @@ def _get_datapipes(self) -> Tuple[IterDataPipe, IterDataPipe, IterDataPipe]:
m2 = c1.map(_x_mult_2)
dp = m2.zip(c2)

return traverse(dp, only_datapipe=True), (src_dp, m1, ub, dm, c1, c2, m2, dp)
return traverse_dps(dp), (src_dp, m1, ub, dm, c1, c2, m2, dp)

def test_find_dps(self) -> None:
graph, (_, m1, *_, m2, _) = self._get_datapipes() # pyre-ignore
Expand Down Expand Up @@ -123,7 +123,7 @@ def test_replace_dps(self) -> None:
],
]
]
self._validate_graph(traverse(dp, only_datapipe=True), exp_g1)
self._validate_graph(traverse_dps(dp), exp_g1)

graph = replace_dp(graph, m2, new_dp2)
exp_g2 = [
Expand All @@ -135,7 +135,7 @@ def test_replace_dps(self) -> None:
],
]
]
self._validate_graph(traverse(dp, only_datapipe=True), exp_g2)
self._validate_graph(traverse_dps(dp), exp_g2)

graph = replace_dp(graph, m1, new_dp3)
exp_g3 = [
Expand All @@ -147,7 +147,7 @@ def test_replace_dps(self) -> None:
],
]
]
self._validate_graph(traverse(dp, only_datapipe=True), exp_g3)
self._validate_graph(traverse_dps(dp), exp_g3)

def test_remove_dps(self) -> None:
# pyre-fixme[23]: Unable to unpack 3 values, 2 were expected.
Expand All @@ -164,11 +164,11 @@ def test_remove_dps(self) -> None:

graph = remove_dp(graph, m1)
exp_g1 = [[dp, [[m2, [[c1, [[dm, [[ub, [[src_dp, []]]]]]]]]], [c2, [[dm, [[ub, [[src_dp, []]]]]]]]]]]
self._validate_graph(traverse(dp, only_datapipe=True), exp_g1)
self._validate_graph(traverse_dps(dp), exp_g1)

graph = remove_dp(graph, m2)
exp_g2 = [[dp, [[c1, [[dm, [[ub, [[src_dp, []]]]]]]], [c2, [[dm, [[ub, [[src_dp, []]]]]]]]]]]
self._validate_graph(traverse(dp, only_datapipe=True), exp_g2)
self._validate_graph(traverse_dps(dp), exp_g2)

with self.assertRaisesRegex(RuntimeError, "Cannot remove the source DataPipe"):
remove_dp(graph, src_dp)
Expand Down
4 changes: 2 additions & 2 deletions torchdata/dataloader2/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import torch

from torch.utils.data.graph import DataPipe
from torchdata.dataloader2.graph import DataPipe, traverse_dps
from torchdata.datapipes.iter.util.cacheholder import _WaitPendingCacheItemIterDataPipe


Expand Down Expand Up @@ -68,7 +68,7 @@ def __init__(self, timeout=None):
self.timeout = timeout

def __call__(self, datapipe: DataPipe) -> DataPipe:
graph = torch.utils.data.graph.traverse(datapipe, only_datapipe=True)
graph = traverse_dps(datapipe)
all_pipes = torch.utils.data.graph_settings.get_all_graph_pipes(graph)
cache_locks = {pipe for pipe in all_pipes if isinstance(pipe, _WaitPendingCacheItemIterDataPipe)}

Expand Down
4 changes: 2 additions & 2 deletions torchdata/dataloader2/dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from dataclasses import dataclass
from typing import Any, Dict, Generic, Iterable, Iterator, Optional, TypeVar, Union

from torch.utils.data.graph import DataPipe

from torchdata.dataloader2.adapter import Adapter

from torchdata.dataloader2.graph import DataPipe

from .error import PauseIteration
from .reading_service import CheckpointableReadingServiceInterface, ReadingServiceInterface

Expand Down
16 changes: 8 additions & 8 deletions torchdata/dataloader2/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from typing import List, Type

from torch.utils.data.graph import DataPipe, DataPipeGraph, traverse
from torch.utils.data.graph import DataPipe, DataPipeGraph, traverse_dps

from torchdata.datapipes.iter import IterDataPipe
from torchdata.datapipes.map import MapDataPipe
Expand All @@ -18,7 +18,7 @@

def find_dps(graph: DataPipeGraph, dp_type: Type[DataPipe]) -> List[DataPipe]:
r"""
Given the graph of DataPipe generated by ``traverse`` function, return DataPipe
Given the graph of DataPipe generated by ``traverse_dps`` function, return DataPipe
instances with the provided DataPipe type.
"""
dps: List[DataPipe] = []
Expand All @@ -37,33 +37,33 @@ def helper(g) -> None: # pyre-ignore
# Given the DataPipe needs to be replaced and the expected DataPipe, return a new graph
def replace_dp(graph: DataPipeGraph, old_datapipe: DataPipe, new_datapipe: DataPipe) -> DataPipeGraph:
r"""
Given the graph of DataPipe generated by ``traverse`` function and the DataPipe to be replaced and
Given the graph of DataPipe generated by ``traverse_dps`` function and the DataPipe to be replaced and
the new DataPipe, return the new graph of DataPipe.
"""
assert len(graph) == 1

if id(old_datapipe) in graph:
graph = traverse(new_datapipe, only_datapipe=True)
graph = traverse_dps(new_datapipe)

final_datapipe = list(graph.values())[0][0]

for recv_dp, send_graph in graph.values():
_replace_dp(recv_dp, send_graph, old_datapipe, new_datapipe)

return traverse(final_datapipe, only_datapipe=True)
return traverse_dps(final_datapipe)


def remove_dp(graph: DataPipeGraph, datapipe: DataPipe) -> DataPipeGraph:
r"""
Given the graph of DataPipe generated by ``traverse`` function and the DataPipe to be removed,
Given the graph of DataPipe generated by ``traverse_dps`` function and the DataPipe to be removed,
return the new graph of DataPipe.

Note:
- This function can not remove DataPipe that takes multiple DataPipes as the input.
"""
assert len(graph) == 1

dp_graph = traverse(datapipe, only_datapipe=True)
dp_graph = traverse_dps(datapipe)
dp_id = id(datapipe)
if len(dp_graph[dp_id][1]) == 0:
raise RuntimeError("Cannot remove the source DataPipe from the graph of DataPipe")
Expand All @@ -80,7 +80,7 @@ def remove_dp(graph: DataPipeGraph, datapipe: DataPipe) -> DataPipeGraph:
assert len(graph) == 1
datapipe = list(graph.values())[0][0]

return traverse(datapipe, only_datapipe=True)
return traverse_dps(datapipe)


# For each `recv_dp`, find if the source_datapipe needs to be replaced by the new one.
Expand Down
4 changes: 2 additions & 2 deletions torchdata/dataloader2/linter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torch.utils.data.graph import DataPipe, DataPipeGraph, traverse
from torchdata.dataloader2.graph import DataPipe, DataPipeGraph, traverse_dps

from torchdata.datapipes.iter import ShardingFilter, Shuffler

Expand All @@ -14,7 +14,7 @@ def _check_shuffle_before_sharding(datapipe: DataPipe) -> bool:
This function will check if a ``shuffle`` operation is presented before each
``sharding_filter`` operation for every single path in the ``DataPipe`` graph.
"""
graph: DataPipeGraph = traverse(datapipe) # type: ignore[arg-type]
graph: DataPipeGraph = traverse_dps(datapipe) # type: ignore[arg-type]
return _check_shuffler_before_sharding_helper(graph)


Expand Down
2 changes: 1 addition & 1 deletion torchdata/dataloader2/reading_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
import torch.distributed as dist

from torch.utils.data import DataLoader
from torch.utils.data.graph import DataPipe

from torchdata._constants import default_timeout_in_s
from torchdata.dataloader2 import communication
from torchdata.dataloader2.graph import DataPipe
from torchdata.datapipes.iter import FullSync, IterableWrapper


Expand Down
4 changes: 2 additions & 2 deletions torchdata/datapipes/iter/util/cacheholder.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, DILL_AVAILABLE

from torch.utils.data.graph import traverse
from torch.utils.data.graph import traverse_dps
from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import FileLister, IterDataPipe

Expand Down Expand Up @@ -396,7 +396,7 @@ def __new__(cls, datapipe, mode="wb", filepath_fn=None, *, same_filepath_fn=Fals
if filepath_fn is not None and same_filepath_fn:
raise ValueError("`filepath_fn` is mutually exclusive with `same_filepath_fn`")

graph = traverse(datapipe, only_datapipe=True)
graph = traverse_dps(datapipe)
# Get the last CacheHolder
cache_holder = EndOnDiskCacheHolderIterDataPipe._recursive_search(graph)
if cache_holder is None:
Expand Down
4 changes: 2 additions & 2 deletions torchdata/datapipes/utils/_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Optional, Set, TYPE_CHECKING

from torch.utils.data.datapipes.iter.combining import _ChildDataPipe, IterDataPipe
from torch.utils.data.graph import traverse
from torch.utils.data.graph import traverse_dps

if TYPE_CHECKING:
import graphviz
Expand Down Expand Up @@ -117,7 +117,7 @@ def aggregate(nodes):

return nodes

return aggregate(recurse(traverse(dp, only_datapipe=True)))
return aggregate(recurse(traverse_dps(dp)))


def to_graph(dp, *, debug: bool = False) -> "graphviz.Digraph":
Expand Down