Skip to content

Commit

Permalink
Refactor to use traverse_dps instead of traverse (#793)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #793

This should be merged only after pytorch/pytorch#85667 has been merged into nightly and internal.

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D39860692

Pulled By: NivekT

fbshipit-source-id: 4fd992acfa39d0877575d739c50949fe9d52056f
  • Loading branch information
NivekT authored and facebook-github-bot committed Oct 3, 2022
1 parent b87d215 commit db5ec7a
Show file tree
Hide file tree
Showing 9 changed files with 29 additions and 29 deletions.
4 changes: 2 additions & 2 deletions test/test_dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@
import unittest
from unittest import TestCase

from torch.utils.data.graph import DataPipe

from torchdata.dataloader2 import (
DataLoader2,
MultiProcessingReadingService,
PrototypeMultiProcessingReadingService,
ReadingServiceInterface,
)
from torchdata.dataloader2.dataloader2 import READING_SERVICE_STATE_KEY_NAME, SERIALIZED_DATAPIPE_KEY_NAME

from torchdata.dataloader2.graph import DataPipe
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe


Expand Down
16 changes: 8 additions & 8 deletions test/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from _utils._common_utils_for_test import IS_WINDOWS
from torch.utils.data import IterDataPipe
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, ReadingServiceInterface
from torchdata.dataloader2.graph import find_dps, remove_dp, replace_dp, traverse
from torchdata.dataloader2.graph import find_dps, remove_dp, replace_dp, traverse_dps
from torchdata.datapipes.iter import IterableWrapper, Mapper

T_co = TypeVar("T_co", covariant=True)
Expand All @@ -37,7 +37,7 @@ class TempReadingService(ReadingServiceInterface):
adaptors: List[IterDataPipe] = []

def initialize(self, datapipe: IterDataPipe) -> IterDataPipe:
graph = traverse(datapipe, only_datapipe=True)
graph = traverse_dps(datapipe)
dps = find_dps(graph, Mapper)

for dp in reversed(dps):
Expand Down Expand Up @@ -78,7 +78,7 @@ def _get_datapipes(self) -> Tuple[IterDataPipe, IterDataPipe, IterDataPipe]:
m2 = c1.map(_x_mult_2)
dp = m2.zip(c2)

return traverse(dp, only_datapipe=True), (src_dp, m1, ub, dm, c1, c2, m2, dp)
return traverse_dps(dp), (src_dp, m1, ub, dm, c1, c2, m2, dp)

def test_find_dps(self) -> None:
graph, (_, m1, *_, m2, _) = self._get_datapipes() # pyre-ignore
Expand Down Expand Up @@ -123,7 +123,7 @@ def test_replace_dps(self) -> None:
],
]
]
self._validate_graph(traverse(dp, only_datapipe=True), exp_g1)
self._validate_graph(traverse_dps(dp), exp_g1)

graph = replace_dp(graph, m2, new_dp2)
exp_g2 = [
Expand All @@ -135,7 +135,7 @@ def test_replace_dps(self) -> None:
],
]
]
self._validate_graph(traverse(dp, only_datapipe=True), exp_g2)
self._validate_graph(traverse_dps(dp), exp_g2)

graph = replace_dp(graph, m1, new_dp3)
exp_g3 = [
Expand All @@ -147,7 +147,7 @@ def test_replace_dps(self) -> None:
],
]
]
self._validate_graph(traverse(dp, only_datapipe=True), exp_g3)
self._validate_graph(traverse_dps(dp), exp_g3)

def test_remove_dps(self) -> None:
# pyre-fixme[23]: Unable to unpack 3 values, 2 were expected.
Expand All @@ -164,11 +164,11 @@ def test_remove_dps(self) -> None:

graph = remove_dp(graph, m1)
exp_g1 = [[dp, [[m2, [[c1, [[dm, [[ub, [[src_dp, []]]]]]]]]], [c2, [[dm, [[ub, [[src_dp, []]]]]]]]]]]
self._validate_graph(traverse(dp, only_datapipe=True), exp_g1)
self._validate_graph(traverse_dps(dp), exp_g1)

graph = remove_dp(graph, m2)
exp_g2 = [[dp, [[c1, [[dm, [[ub, [[src_dp, []]]]]]]], [c2, [[dm, [[ub, [[src_dp, []]]]]]]]]]]
self._validate_graph(traverse(dp, only_datapipe=True), exp_g2)
self._validate_graph(traverse_dps(dp), exp_g2)

with self.assertRaisesRegex(RuntimeError, "Cannot remove the source DataPipe"):
remove_dp(graph, src_dp)
Expand Down
4 changes: 2 additions & 2 deletions torchdata/dataloader2/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import torch

from torch.utils.data.graph import DataPipe
from torchdata.dataloader2.graph import DataPipe, traverse_dps
from torchdata.datapipes.iter.util.cacheholder import _WaitPendingCacheItemIterDataPipe


Expand Down Expand Up @@ -68,7 +68,7 @@ def __init__(self, timeout=None):
self.timeout = timeout

def __call__(self, datapipe: DataPipe) -> DataPipe:
graph = torch.utils.data.graph.traverse(datapipe, only_datapipe=True)
graph = traverse_dps(datapipe)
all_pipes = torch.utils.data.graph_settings.get_all_graph_pipes(graph)
cache_locks = {pipe for pipe in all_pipes if isinstance(pipe, _WaitPendingCacheItemIterDataPipe)}

Expand Down
4 changes: 2 additions & 2 deletions torchdata/dataloader2/dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from dataclasses import dataclass
from typing import Any, Dict, Generic, Iterable, Iterator, Optional, TypeVar, Union

from torch.utils.data.graph import DataPipe

from torchdata.dataloader2.adapter import Adapter

from torchdata.dataloader2.graph import DataPipe

from .error import PauseIteration
from .reading_service import CheckpointableReadingServiceInterface, ReadingServiceInterface

Expand Down
16 changes: 8 additions & 8 deletions torchdata/dataloader2/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from typing import List, Type

from torch.utils.data.graph import DataPipe, DataPipeGraph, traverse
from torch.utils.data.graph import DataPipe, DataPipeGraph, traverse_dps

from torchdata.datapipes.iter import IterDataPipe
from torchdata.datapipes.map import MapDataPipe
Expand All @@ -18,7 +18,7 @@

def find_dps(graph: DataPipeGraph, dp_type: Type[DataPipe]) -> List[DataPipe]:
r"""
Given the graph of DataPipe generated by ``traverse`` function, return DataPipe
Given the graph of DataPipe generated by ``traverse_dps`` function, return DataPipe
instances with the provided DataPipe type.
"""
dps: List[DataPipe] = []
Expand All @@ -37,33 +37,33 @@ def helper(g) -> None: # pyre-ignore
# Given the DataPipe needs to be replaced and the expected DataPipe, return a new graph
def replace_dp(graph: DataPipeGraph, old_datapipe: DataPipe, new_datapipe: DataPipe) -> DataPipeGraph:
r"""
Given the graph of DataPipe generated by ``traverse`` function and the DataPipe to be replaced and
Given the graph of DataPipe generated by ``traverse_dps`` function and the DataPipe to be replaced and
the new DataPipe, return the new graph of DataPipe.
"""
assert len(graph) == 1

if id(old_datapipe) in graph:
graph = traverse(new_datapipe, only_datapipe=True)
graph = traverse_dps(new_datapipe)

final_datapipe = list(graph.values())[0][0]

for recv_dp, send_graph in graph.values():
_replace_dp(recv_dp, send_graph, old_datapipe, new_datapipe)

return traverse(final_datapipe, only_datapipe=True)
return traverse_dps(final_datapipe)


def remove_dp(graph: DataPipeGraph, datapipe: DataPipe) -> DataPipeGraph:
r"""
Given the graph of DataPipe generated by ``traverse`` function and the DataPipe to be removed,
Given the graph of DataPipe generated by ``traverse_dps`` function and the DataPipe to be removed,
return the new graph of DataPipe.
Note:
- This function can not remove DataPipe that takes multiple DataPipes as the input.
"""
assert len(graph) == 1

dp_graph = traverse(datapipe, only_datapipe=True)
dp_graph = traverse_dps(datapipe)
dp_id = id(datapipe)
if len(dp_graph[dp_id][1]) == 0:
raise RuntimeError("Cannot remove the source DataPipe from the graph of DataPipe")
Expand All @@ -80,7 +80,7 @@ def remove_dp(graph: DataPipeGraph, datapipe: DataPipe) -> DataPipeGraph:
assert len(graph) == 1
datapipe = list(graph.values())[0][0]

return traverse(datapipe, only_datapipe=True)
return traverse_dps(datapipe)


# For each `recv_dp`, find if the source_datapipe needs to be replaced by the new one.
Expand Down
4 changes: 2 additions & 2 deletions torchdata/dataloader2/linter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torch.utils.data.graph import DataPipe, DataPipeGraph, traverse
from torchdata.dataloader2.graph import DataPipe, DataPipeGraph, traverse_dps

from torchdata.datapipes.iter import ShardingFilter, Shuffler

Expand All @@ -14,7 +14,7 @@ def _check_shuffle_before_sharding(datapipe: DataPipe) -> bool:
This function will check if a ``shuffle`` operation is presented before each
``sharding_filter`` operation for every single path in the ``DataPipe`` graph.
"""
graph: DataPipeGraph = traverse(datapipe) # type: ignore[arg-type]
graph: DataPipeGraph = traverse_dps(datapipe) # type: ignore[arg-type]
return _check_shuffler_before_sharding_helper(graph)


Expand Down
2 changes: 1 addition & 1 deletion torchdata/dataloader2/reading_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
import torch.distributed as dist

from torch.utils.data import DataLoader
from torch.utils.data.graph import DataPipe

from torchdata._constants import default_timeout_in_s
from torchdata.dataloader2 import communication
from torchdata.dataloader2.graph import DataPipe
from torchdata.datapipes.iter import FullSync, IterableWrapper


Expand Down
4 changes: 2 additions & 2 deletions torchdata/datapipes/iter/util/cacheholder.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, DILL_AVAILABLE

from torch.utils.data.graph import traverse
from torch.utils.data.graph import traverse_dps
from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import FileLister, IterDataPipe

Expand Down Expand Up @@ -396,7 +396,7 @@ def __new__(cls, datapipe, mode="wb", filepath_fn=None, *, same_filepath_fn=Fals
if filepath_fn is not None and same_filepath_fn:
raise ValueError("`filepath_fn` is mutually exclusive with `same_filepath_fn`")

graph = traverse(datapipe, only_datapipe=True)
graph = traverse_dps(datapipe)
# Get the last CacheHolder
cache_holder = EndOnDiskCacheHolderIterDataPipe._recursive_search(graph)
if cache_holder is None:
Expand Down
4 changes: 2 additions & 2 deletions torchdata/datapipes/utils/_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Optional, Set, TYPE_CHECKING

from torch.utils.data.datapipes.iter.combining import _ChildDataPipe, IterDataPipe
from torch.utils.data.graph import traverse
from torch.utils.data.graph import traverse_dps

if TYPE_CHECKING:
import graphviz
Expand Down Expand Up @@ -117,7 +117,7 @@ def aggregate(nodes):

return nodes

return aggregate(recurse(traverse(dp, only_datapipe=True)))
return aggregate(recurse(traverse_dps(dp)))


def to_graph(dp, *, debug: bool = False) -> "graphviz.Digraph":
Expand Down

0 comments on commit db5ec7a

Please sign in to comment.