Skip to content

Commit

Permalink
Change cache and helper function to more meaningful names
Browse files Browse the repository at this point in the history
  • Loading branch information
ejguan committed Dec 22, 2022
1 parent 0262278 commit f93215b
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 24 deletions.
2 changes: 2 additions & 0 deletions test/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,8 @@ def test_replicable_branches(self):
self.assertEqual(len(dps), 2)
self.assertTrue(all(dp in (single_br_dp, cir_map_dp) for dp in dps))

# In theory, this case should never happen because the LCA (fork_zip_dp) should be
# replaced by _DummyIterDataPipe if any child is non-replicable
single_br_dp, _, ch1, ch2, *_, cir_map_dp, _, graph = self._make_dp()
graph = replace_by_dummy(graph, ch1)
dps = find_replicable_branches(graph)
Expand Down
48 changes: 24 additions & 24 deletions torchdata/dataloader2/utils/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,18 @@ def _is_round_robin_sharding(dp: DataPipe) -> bool:
root_dp_id = list(graph.keys())[0]
root_dp, root_graph = graph[root_dp_id]

cache: Dict[int, Optional[DataPipe]] = {}
lca_for_subgraph: Dict[int, Optional[DataPipe]] = {}

def helper(root_dp_id, root_dp, root_graph) -> Optional[DataPipe]: # pyre-ignore
if root_dp_id in cache:
return cache[root_dp_id]
def _get_lca_from_graph(root_dp_id, root_dp, root_graph) -> Optional[DataPipe]: # pyre-ignore
if root_dp_id in lca_for_subgraph:
return lca_for_subgraph[root_dp_id]
if root_dp_id in non_replicable_dps:
cache[root_dp_id] = root_dp
lca_for_subgraph[root_dp_id] = root_dp
return root_dp
cache[root_dp_id] = None
lca_for_subgraph[root_dp_id] = None
non_replicable_parents = []
for dp_id, (dp, src_graph) in root_graph.items():
res = helper(dp_id, dp, src_graph)
res = _get_lca_from_graph(dp_id, dp, src_graph)
if res is not None:
non_replicable_parents.append(res)
# `root_dp` becomes the lowest common ancestor of this branch,
Expand All @@ -76,13 +76,13 @@ def helper(root_dp_id, root_dp, root_graph) -> Optional[DataPipe]: # pyre-ignor
if len(non_replicable_parents) == 1 or all(
dp == non_replicable_parents[0] for dp in non_replicable_parents
):
cache[root_dp_id] = non_replicable_parents[0]
lca_for_subgraph[root_dp_id] = non_replicable_parents[0]
# Multiple non-replicable DataPipes
else:
cache[root_dp_id] = root_dp
return cache[root_dp_id]
lca_for_subgraph[root_dp_id] = root_dp
return lca_for_subgraph[root_dp_id]

return helper(root_dp_id, root_dp, root_graph)
return _get_lca_from_graph(root_dp_id, root_dp, root_graph)


def find_replicable_branches(graph: DataPipeGraph) -> List[DataPipe]:
Expand All @@ -93,30 +93,30 @@ def find_replicable_branches(graph: DataPipeGraph) -> List[DataPipe]:
assert len(graph) == 1, "DataPipeGraph should only contain a single output DataPipe"

dps: List[DataPipe] = []
cache: Dict[int, bool] = {}
branch_is_replicable: Dict[int, bool] = {}

root_dp_id = list(graph.keys())[0]
root_dp, root_graph = graph[root_dp_id]

def helper(root_dp_id, root_dp, root_graph) -> bool: # pyre-ignore
if root_dp_id in cache:
return cache[root_dp_id]
def _is_replicable_graph(root_dp_id, root_dp, root_graph) -> bool: # pyre-ignore
if root_dp_id in branch_is_replicable:
return branch_is_replicable[root_dp_id]
if type(root_dp) == _DummyIterDataPipe:
cache[root_dp_id] = False
branch_is_replicable[root_dp_id] = False
return False
cache[root_dp_id] = True
branch_is_replicable[root_dp_id] = True
for dp_id, (dp, src_graph) in root_graph.items():
if not helper(dp_id, dp, src_graph):
cache[root_dp_id] = False
if not _is_replicable_graph(dp_id, dp, src_graph):
branch_is_replicable[root_dp_id] = False
# Do not break here, so that every child is still visited and recorded
if not cache[root_dp_id]:
# All children should have been cached already
if not branch_is_replicable[root_dp_id]:
# All children should have been added to branch_is_replicable already
for dp_id, (dp, _) in root_graph.items():
if cache[dp_id]:
if branch_is_replicable[dp_id]:
dps.append(dp)
return cache[root_dp_id]
return branch_is_replicable[root_dp_id]

if helper(root_dp_id, root_dp, root_graph):
if _is_replicable_graph(root_dp_id, root_dp, root_graph):
dps.append(root_dp)

return dps

0 comments on commit f93215b

Please sign in to comment.