From 78e12327bcce996dc634fb66f9bd254b404407c9 Mon Sep 17 00:00:00 2001 From: Saugat Pachhai Date: Mon, 18 Jan 2021 18:55:19 +0545 Subject: [PATCH] use repo.stage.create_from_cli to create stages (#5281) --- dvc/command/stage.py | 13 +++------- dvc/repo/__init__.py | 5 ++-- dvc/repo/add.py | 2 +- dvc/repo/run.py | 27 +++++++++----------- dvc/repo/stage.py | 19 ++++++++++++-- dvc/stage/utils.py | 43 +++++++++++++++++++++----------- tests/unit/command/test_stage.py | 4 +-- 7 files changed, 65 insertions(+), 48 deletions(-) diff --git a/dvc/command/stage.py b/dvc/command/stage.py index 5a69f9a499..c2463804ac 100644 --- a/dvc/command/stage.py +++ b/dvc/command/stage.py @@ -8,19 +8,12 @@ class CmdStageAdd(CmdBase): - @staticmethod - def create(repo, force, **kwargs): - from dvc.stage.utils import check_graphs, create_stage_from_cli - - stage = create_stage_from_cli(repo, **kwargs) - check_graphs(repo, stage, force=force) - return stage - def run(self): - stage = self.create(self.repo, self.args.force, **vars(self.args)) + kwargs = vars(self.args) + stage = self.repo.stage.create_from_cli(validate=True, **kwargs) + stage.ignore_outs() stage.dump() - return 0 diff --git a/dvc/repo/__init__.py b/dvc/repo/__init__.py index 4b64fd0cda..95bda4cdb1 100644 --- a/dvc/repo/__init__.py +++ b/dvc/repo/__init__.py @@ -280,7 +280,7 @@ def _ignore(self): self.scm.ignore_list(flist) - def check_modified_graph(self, new_stages): + def check_modified_graph(self, new_stages, old_stages=None): """Generate graph including the new stage to check for errors""" # Building graph might be costly for the ones with many DVC-files, # so we provide this undocumented hack to skip it. See [1] for @@ -295,7 +295,8 @@ def check_modified_graph(self, new_stages): # # [1] https://github.com/iterative/dvc/issues/2671 if not getattr(self, "_skip_graph_checks", False): - build_graph(self.stages + new_stages) + existing_stages = self.stages if old_stages is None else old_stages + build_graph(existing_stages + new_stages) def used_cache( self, diff --git a/dvc/repo/add.py b/dvc/repo/add.py index 5e59114f55..553ed4adc8 100644 --- a/dvc/repo/add.py +++ b/dvc/repo/add.py @@ -85,7 +85,7 @@ def add( ) except OutputDuplicationError as exc: raise OutputDuplicationError( - exc.output, set(exc.stages) - set(stages) + exc.output, list(set(exc.stages) - set(stages)) ) link_failures.extend( diff --git a/dvc/repo/run.py b/dvc/repo/run.py index 8124355102..93ff4ad70b 100644 --- a/dvc/repo/run.py +++ b/dvc/repo/run.py @@ -1,9 +1,11 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union from . import locked from .scm_context import scm_context if TYPE_CHECKING: + from dvc.stage import PipelineStage, Stage + from . import Repo @@ -11,29 +13,24 @@ @scm_context def run( self: "Repo", - fname: str = None, no_exec: bool = False, - single_stage: bool = False, + no_commit: bool = False, + run_cache: bool = True, + force: bool = True, **kwargs -): - from dvc.stage.utils import check_graphs, create_stage_from_cli - - stage = create_stage_from_cli( - self, single_stage=single_stage, fname=fname, **kwargs - ) +) -> Union["Stage", "PipelineStage", None]: + from dvc.stage.utils import validate_state - if kwargs.get("run_cache", True) and stage.can_be_skipped: + stage = self.stage.create_from_cli(**kwargs) + if run_cache and stage.can_be_skipped: return None - check_graphs(self, stage, force=kwargs.get("force", True)) + validate_state(self, stage, force=force) if no_exec: stage.ignore_outs() else: - stage.run( - no_commit=kwargs.get("no_commit", False), - run_cache=kwargs.get("run_cache", True), - ) + stage.run(no_commit=no_commit, run_cache=run_cache) stage.dump(update_lock=not no_exec) return stage diff --git a/dvc/repo/stage.py b/dvc/repo/stage.py index 1d90af8d3f..15df756ad0 100644 --- a/dvc/repo/stage.py +++ b/dvc/repo/stage.py @@ -3,7 +3,7 @@ import os import typing from contextlib import suppress -from typing import Iterable, List, NamedTuple, Optional, Set, Tuple +from typing import Iterable, List, NamedTuple, Optional, Set, Tuple, Union from dvc.exceptions import ( DvcException, @@ -19,7 +19,7 @@ from networkx import DiGraph from dvc.repo import Repo - from dvc.stage import Stage + from dvc.stage import PipelineStage, Stage from dvc.stage.loader import StageLoader from dvc.types import OptStr @@ -93,6 +93,21 @@ class StageLoad: def __init__(self, repo: "Repo") -> None: self.repo = repo + def create_from_cli( + self, validate: bool = False, **kwargs + ) -> Union["Stage", "PipelineStage"]: + """Creates a stage from CLI args passed as kwargs. + + Args: + validate: if true, the new created stage is checked against the + stages in the repo. Eg: graph correctness, + potential overwrites in dvc.yaml file (unless `force=True`). + """ + from dvc.stage.utils import create_stage_from_cli + + kwargs.update(validate=validate) + return create_stage_from_cli(self.repo, **kwargs) + def from_target( self, target: str, accept_group: bool = False, glob: bool = False, ) -> StageList: diff --git a/dvc/stage/utils.py b/dvc/stage/utils.py index d619cd9223..4464191db9 100644 --- a/dvc/stage/utils.py +++ b/dvc/stage/utils.py @@ -325,28 +325,37 @@ def _check_stage_exists( ) -def check_graphs( - repo: "Repo", stage: Union["Stage", "PipelineStage"], force: bool = True +def validate_state( + repo: "Repo", new: Union["Stage", "PipelineStage"], force: bool = False, ) -> None: - """Checks graph and if that stage already exists. + """Validates that the new stage: - If it exists in the dvc.yaml file, it errors out unless force is given. + * does not already exist with the same (name, path), unless force is True. + * does not affect the correctness of the repo's graph (to-be graph). """ - from dvc.exceptions import OutputDuplicationError + stages = repo.stages[:] if force else None + if force: + # remove an existing stage (if exists) + with suppress(ValueError): + # uses name and path to determine this which is same for the + # existing (if any) and the new one + stages.remove(new) + else: + _check_stage_exists(repo, new, new.path) + # if not `force`, we don't need to replace existing stage with the new + # one, as the dumping this is not going to replace it anyway (as the + # stage with same path + name does not exist (from above check). - try: - if force: - with suppress(ValueError): - repo.stages.remove(stage) - else: - _check_stage_exists(repo, stage, stage.path) - repo.check_modified_graph([stage]) - except OutputDuplicationError as exc: - raise OutputDuplicationError(exc.output, set(exc.stages) - {stage}) + repo.check_modified_graph(new_stages=[new], old_stages=stages) def create_stage_from_cli( - repo: "Repo", single_stage: bool = False, fname: str = None, **kwargs: Any + repo: "Repo", + single_stage: bool = False, + fname: str = None, + validate: bool = False, + force: bool = False, + **kwargs: Any, ) -> Union["Stage", "PipelineStage"]: from dvc.dvcfile import PIPELINE_FILE @@ -388,5 +397,9 @@ def create_stage_from_cli( stage = create_stage( stage_cls, repo=repo, path=path, params=params, **kwargs ) + + if validate: + validate_state(repo, stage, force=force) + restore_meta(stage) return stage diff --git a/tests/unit/command/test_stage.py b/tests/unit/command/test_stage.py index 01cf56178e..462091132a 100644 --- a/tests/unit/command/test_stage.py +++ b/tests/unit/command/test_stage.py @@ -57,13 +57,11 @@ def test_stage_add(mocker, dvc, extra_args, expected_extra): ] ) assert cli_args.func == CmdStageAdd - m = mocker.patch.object(CmdStageAdd, "create") cmd = cli_args.func(cli_args) + m = mocker.patch.object(cmd.repo.stage, "create_from_cli") assert cmd.run() == 0 - assert m.call_args[0] == (cmd.repo, True) - expected = dict( name="name", deps=["deps"],