Skip to content

Commit

Permalink
Add tests for shardable branches
Browse files Browse the repository at this point in the history
  • Loading branch information
ejguan committed Dec 12, 2022
1 parent 71a5c04 commit 47201aa
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 4 deletions.
40 changes: 40 additions & 0 deletions test/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
replace_dp,
traverse_dps,
)
from torchdata.dataloader2.utils.non_shardable import _DummyIterDataPipe, find_shardable_branches
from torchdata.datapipes.iter import IterableWrapper, Mapper
from torchdata.datapipes.utils import to_graph

Expand Down Expand Up @@ -266,6 +267,10 @@ def _is_shardable(self):
return datapipe


def replace_by_dummy(graph, datapipe):
    r"""
    Substitute ``datapipe`` in ``graph`` with a newly created
    ``_DummyIterDataPipe`` and return the result of ``replace_dp``.
    """
    placeholder = _DummyIterDataPipe()
    return replace_dp(graph, datapipe, placeholder)


class TestNonShardableDataPipe(expecttest.TestCase):
def _make_dp(self):
r"""
Expand Down Expand Up @@ -345,6 +350,41 @@ def test_multi_non_shardable_dps(self):
cir_br_dp = make_dp_non_shardable(cir_br_dp)
self.assertEqual(find_lca_non_shardable_dp(graph), end_dp)

def test_shardable_branches(self):
    r"""
    Replace one DataPipe of the graph built by ``_make_dp`` with
    ``_DummyIterDataPipe`` at a time and check that every DataPipe returned
    by ``find_shardable_branches`` is one of the branches expected for that
    replacement.
    """

    def shardable_after_dummy(dp_graph, target):
        # Swap ``target`` for the dummy first, then search the new graph.
        return find_shardable_branches(replace_by_dummy(dp_graph, target))

    # NOTE(review): ``all(...)`` is also true for an empty result, so the
    # membership checks below do not prove that any branch was returned --
    # consider asserting the expected number of branches as well.

    # Dummy replaces ``single_br_dp``
    single_br_dp, *_, fork_zip_dp, _, cir_map_dp, _, graph = self._make_dp()
    found = shardable_after_dummy(graph, single_br_dp)
    allowed = (fork_zip_dp, cir_map_dp)
    self.assertTrue(all(dp in allowed for dp in found))

    # Dummy replaces ``multi_br_dp``
    single_br_dp, multi_br_dp, *_, cir_map_dp, _, graph = self._make_dp()
    found = shardable_after_dummy(graph, multi_br_dp)
    allowed = (single_br_dp, cir_map_dp)
    self.assertTrue(all(dp in allowed for dp in found))

    # Dummy replaces ``ch1``
    single_br_dp, _, ch1, ch2, *_, cir_map_dp, _, graph = self._make_dp()
    found = shardable_after_dummy(graph, ch1)
    allowed = (single_br_dp, ch2, cir_map_dp)
    self.assertTrue(all(dp in allowed for dp in found))

    # Dummy replaces ``cir_map_dp``
    single_br_dp, *_, fork_zip_dp, _, cir_map_dp, _, graph = self._make_dp()
    found = shardable_after_dummy(graph, cir_map_dp)
    allowed = (single_br_dp, fork_zip_dp)
    self.assertTrue(all(dp in allowed for dp in found))

    # Dummy replaces ``end_dp``: nothing is expected to be returned
    *_, end_dp, graph = self._make_dp()
    found = shardable_after_dummy(graph, end_dp)
    self.assertEqual(len(found), 0)

    # Dummy replaces ``fork_zip_dp``
    single_br_dp, *_, fork_zip_dp, _, cir_map_dp, _, graph = self._make_dp()
    found = shardable_after_dummy(graph, fork_zip_dp)
    allowed = (single_br_dp, cir_map_dp)
    self.assertTrue(all(dp in allowed for dp in found))


class TestGraphVisualization(expecttest.TestCase):
@unittest.skipIf(not HAS_GRAPHVIZ, "Package `graphviz` is required to test graph visualization functionalities.")
Expand Down
8 changes: 4 additions & 4 deletions torchdata/dataloader2/utils/non_shardable.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def find_shardable_branches(graph: DataPipeGraph) -> List[DataPipe]:
cache: Dict[int, bool] = {}

root_dp_id = list(graph.keys())[0]
root_dp, root_graph = graph[roo_dp_id]
root_dp, root_graph = graph[root_dp_id]

def helper(root_dp_id, root_dp, root_graph) -> bool: # pyre-ignore
if root_dp_id in cache:
Expand All @@ -45,12 +45,12 @@ def helper(root_dp_id, root_dp, root_graph) -> bool: # pyre-ignore
for dp_id, (dp, src_graph) in root_graph.items():
if not helper(dp_id, dp, src_graph):
cache[root_dp_id] = False
break
# Do not break to go through all children
if not cache[root_dp_id]:
# All children should have been cached already
for dp_id in root_graph.keys():
for dp_id, (dp, _) in root_graph.items():
if cache[dp_id]:
dps.append(dp_id)
dps.append(dp)
return cache[root_dp_id]

if helper(root_dp_id, root_dp, root_graph):
Expand Down

0 comments on commit 47201aa

Please sign in to comment.