diff --git a/dvc/repo/reproduce.py b/dvc/repo/reproduce.py
index a0ae1c35af..8aa95af4a2 100644
--- a/dvc/repo/reproduce.py
+++ b/dvc/repo/reproduce.py
@@ -95,17 +95,12 @@ def reproduce(
             path, name=name, recursive=recursive, graph=active_graph
         )
 
-    ret = []
-    for target in targets:
-        stages = _reproduce_stages(active_graph, target, **kwargs)
-        ret.extend(stages)
-
-    return ret
+    return _reproduce_stages(active_graph, targets, **kwargs)
 
 
 def _reproduce_stages(
     G,
-    stage,
+    stages,
     downstream=False,
     ignore_build_cache=False,
     single_item=False,
@@ -148,23 +143,34 @@ def _reproduce_stages(
     import networkx as nx
 
     if single_item:
-        pipeline = [stage]
-    elif downstream:
-        # NOTE (py3 only):
-        # Python's `deepcopy` defaults to pickle/unpickle the object.
-        # Stages are complex objects (with references to `repo`, `outs`,
-        # and `deps`) that cause struggles when you try to serialize them.
-        # We need to create a copy of the graph itself, and then reverse it,
-        # instead of using graph.reverse() directly because it calls
-        # `deepcopy` underneath -- unless copy=False is specified.
-        pipeline = nx.dfs_preorder_nodes(G.copy().reverse(copy=False), stage)
+        all_pipelines = stages
     else:
-        pipeline = nx.dfs_postorder_nodes(G, stage)
+        all_pipelines = []
+        for stage in stages:
+            if downstream:
+                # NOTE (py3 only):
+                # Python's `deepcopy` defaults to pickle/unpickle the object.
+                # Stages are complex objects (with references to `repo`,
+                # `outs`, and `deps`) that cause struggles when you try
+                # to serialize them. We need to create a copy of the graph
+                # itself, and then reverse it, instead of using
+                # graph.reverse() directly because it calls `deepcopy`
+                # underneath -- unless copy=False is specified.
+                all_pipelines += nx.dfs_preorder_nodes(
+                    G.copy().reverse(copy=False), stage
+                )
+            else:
+                all_pipelines += nx.dfs_postorder_nodes(G, stage)
+
+    pipeline = []
+    for stage in all_pipelines:
+        if stage not in pipeline:
+            pipeline.append(stage)
 
     result = []
-    for st in pipeline:
+    for stage in pipeline:
         try:
-            ret = _reproduce_stage(st, **kwargs)
+            ret = _reproduce_stage(stage, **kwargs)
 
             if len(ret) != 0 and ignore_build_cache:
                 # NOTE: we are walking our pipeline from the top to the
@@ -176,5 +182,6 @@ def _reproduce_stages(
 
             result.extend(ret)
         except Exception as exc:
-            raise ReproductionError(st.relpath) from exc
+            raise ReproductionError(stage.relpath) from exc
+
     return result
diff --git a/tests/unit/repo/test_reproduce.py b/tests/unit/repo/test_reproduce.py
index 2b49273ed3..7df9603f36 100644
--- a/tests/unit/repo/test_reproduce.py
+++ b/tests/unit/repo/test_reproduce.py
@@ -1,3 +1,5 @@
+import mock
+
 from dvc.repo.reproduce import _get_active_graph
 
 
@@ -23,3 +25,19 @@ def test_get_active_graph(tmp_dir, dvc):
     active_graph = _get_active_graph(graph)
     assert set(active_graph.nodes) == {bar_stage, baz_stage}
     assert not active_graph.edges
+
+
+@mock.patch("dvc.repo.reproduce._reproduce_stage", return_value=[])
+def test_number_reproduces(reproduce_stage_mock, tmp_dir, dvc):
+    tmp_dir.dvc_gen({"pre-foo": "pre-foo"})
+
+    dvc.run(deps=["pre-foo"], outs=["foo"], cmd="echo foo > foo")
+    dvc.run(deps=["foo"], outs=["bar"], cmd="echo bar > bar")
+    dvc.run(deps=["foo"], outs=["baz"], cmd="echo baz > baz")
+    dvc.run(deps=["bar"], outs=["boop"], cmd="echo boop > boop")
+
+    reproduce_stage_mock.reset_mock()
+
+    dvc.reproduce(all_pipelines=True)
+
+    assert reproduce_stage_mock.call_count == 5