diff --git a/test/test_graph.py b/test/test_graph.py index e472c3d74..8f10f4abe 100644 --- a/test/test_graph.py +++ b/test/test_graph.py @@ -363,6 +363,8 @@ def test_replicable_branches(self): self.assertEqual(len(dps), 2) self.assertTrue(all(dp in (single_br_dp, cir_map_dp) for dp in dps)) + # In theory, this case should never happen because LCA (fork_zip_dp) should be + # replaced by _DummyIterDataPipe if any child is non-replicable single_br_dp, _, ch1, ch2, *_, cir_map_dp, _, graph = self._make_dp() graph = replace_by_dummy(graph, ch1) dps = find_replicable_branches(graph) diff --git a/torchdata/dataloader2/utils/dispatch.py b/torchdata/dataloader2/utils/dispatch.py index b8c58e924..716eee4cb 100644 --- a/torchdata/dataloader2/utils/dispatch.py +++ b/torchdata/dataloader2/utils/dispatch.py @@ -55,18 +55,18 @@ def _is_round_robin_sharding(dp: DataPipe) -> bool: root_dp_id = list(graph.keys())[0] root_dp, root_graph = graph[root_dp_id] - cache: Dict[int, Optional[DataPipe]] = {} + lca_for_subgraph: Dict[int, Optional[DataPipe]] = {} - def helper(root_dp_id, root_dp, root_graph) -> Optional[DataPipe]: # pyre-ignore - if root_dp_id in cache: - return cache[root_dp_id] + def _get_lca_from_graph(root_dp_id, root_dp, root_graph) -> Optional[DataPipe]: # pyre-ignore + if root_dp_id in lca_for_subgraph: + return lca_for_subgraph[root_dp_id] if root_dp_id in non_replicable_dps: - cache[root_dp_id] = root_dp + lca_for_subgraph[root_dp_id] = root_dp return root_dp - cache[root_dp_id] = None + lca_for_subgraph[root_dp_id] = None non_replicable_parents = [] for dp_id, (dp, src_graph) in root_graph.items(): - res = helper(dp_id, dp, src_graph) + res = _get_lca_from_graph(dp_id, dp, src_graph) if res is not None: non_replicable_parents.append(res) # `root_dp` becomes the lowest common ancestor of this branch, @@ -76,13 +76,13 @@ def helper(root_dp_id, root_dp, root_graph) -> Optional[DataPipe]: # pyre-ignor if len(non_replicable_parents) == 1 or all( dp == 
non_replicable_parents[0] for dp in non_replicable_parents ): - cache[root_dp_id] = non_replicable_parents[0] + lca_for_subgraph[root_dp_id] = non_replicable_parents[0] # Multiple non-replicable DataPipes else: - cache[root_dp_id] = root_dp - return cache[root_dp_id] + lca_for_subgraph[root_dp_id] = root_dp + return lca_for_subgraph[root_dp_id] - return helper(root_dp_id, root_dp, root_graph) + return _get_lca_from_graph(root_dp_id, root_dp, root_graph) def find_replicable_branches(graph: DataPipeGraph) -> List[DataPipe]: @@ -93,30 +93,30 @@ def find_replicable_branches(graph: DataPipeGraph) -> List[DataPipe]: assert len(graph) == 1, "DataPipeGraph should only contain a single output DataPipe" dps: List[DataPipe] = [] - cache: Dict[int, bool] = {} + branch_is_replicable: Dict[int, bool] = {} root_dp_id = list(graph.keys())[0] root_dp, root_graph = graph[root_dp_id] - def helper(root_dp_id, root_dp, root_graph) -> bool: # pyre-ignore - if root_dp_id in cache: - return cache[root_dp_id] + def _is_replicable_graph(root_dp_id, root_dp, root_graph) -> bool: # pyre-ignore + if root_dp_id in branch_is_replicable: + return branch_is_replicable[root_dp_id] if type(root_dp) == _DummyIterDataPipe: - cache[root_dp_id] = False + branch_is_replicable[root_dp_id] = False return False - cache[root_dp_id] = True + branch_is_replicable[root_dp_id] = True for dp_id, (dp, src_graph) in root_graph.items(): - if not helper(dp_id, dp, src_graph): - cache[root_dp_id] = False + if not _is_replicable_graph(dp_id, dp, src_graph): + branch_is_replicable[root_dp_id] = False # Do not break to go through all children - if not cache[root_dp_id]: - # All children should have been cached already + if not branch_is_replicable[root_dp_id]: + # All children should have been added to branch_is_replicable already for dp_id, (dp, _) in root_graph.items(): - if cache[dp_id]: + if branch_is_replicable[dp_id]: dps.append(dp) - return cache[root_dp_id] + return branch_is_replicable[root_dp_id] - if 
helper(root_dp_id, root_dp, root_graph): + if _is_replicable_graph(root_dp_id, root_dp, root_graph): dps.append(root_dp) return dps