Skip to content

Commit

Permalink
use repo.stage.create_from_cli to create stages (#5281)
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry authored Jan 18, 2021
1 parent 4f2de7a commit 78e1232
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 48 deletions.
13 changes: 3 additions & 10 deletions dvc/command/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,12 @@


class CmdStageAdd(CmdBase):
@staticmethod
def create(repo, force, **kwargs):
from dvc.stage.utils import check_graphs, create_stage_from_cli

stage = create_stage_from_cli(repo, **kwargs)
check_graphs(repo, stage, force=force)
return stage

def run(self):
stage = self.create(self.repo, self.args.force, **vars(self.args))
kwargs = vars(self.args)
stage = self.repo.stage.create_from_cli(validate=True, **kwargs)

stage.ignore_outs()
stage.dump()

return 0


Expand Down
5 changes: 3 additions & 2 deletions dvc/repo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def _ignore(self):

self.scm.ignore_list(flist)

def check_modified_graph(self, new_stages):
def check_modified_graph(self, new_stages, old_stages=None):
"""Generate graph including the new stage to check for errors"""
# Building graph might be costly for the ones with many DVC-files,
# so we provide this undocumented hack to skip it. See [1] for
Expand All @@ -295,7 +295,8 @@ def check_modified_graph(self, new_stages):
#
# [1] https://github.com/iterative/dvc/issues/2671
if not getattr(self, "_skip_graph_checks", False):
build_graph(self.stages + new_stages)
existing_stages = self.stages if old_stages is None else old_stages
build_graph(existing_stages + new_stages)

def used_cache(
self,
Expand Down
2 changes: 1 addition & 1 deletion dvc/repo/add.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def add(
)
except OutputDuplicationError as exc:
raise OutputDuplicationError(
exc.output, set(exc.stages) - set(stages)
exc.output, list(set(exc.stages) - set(stages))
)

link_failures.extend(
Expand Down
27 changes: 12 additions & 15 deletions dvc/repo/run.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,36 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Union

from . import locked
from .scm_context import scm_context

if TYPE_CHECKING:
from dvc.stage import PipelineStage, Stage

from . import Repo


@locked
@scm_context
def run(
self: "Repo",
fname: str = None,
no_exec: bool = False,
single_stage: bool = False,
no_commit: bool = False,
run_cache: bool = True,
force: bool = True,
**kwargs
):
from dvc.stage.utils import check_graphs, create_stage_from_cli

stage = create_stage_from_cli(
self, single_stage=single_stage, fname=fname, **kwargs
)
) -> Union["Stage", "PipelineStage", None]:
from dvc.stage.utils import validate_state

if kwargs.get("run_cache", True) and stage.can_be_skipped:
stage = self.stage.create_from_cli(**kwargs)
if run_cache and stage.can_be_skipped:
return None

check_graphs(self, stage, force=kwargs.get("force", True))
validate_state(self, stage, force=force)

if no_exec:
stage.ignore_outs()
else:
stage.run(
no_commit=kwargs.get("no_commit", False),
run_cache=kwargs.get("run_cache", True),
)
stage.run(no_commit=no_commit, run_cache=run_cache)

stage.dump(update_lock=not no_exec)
return stage
19 changes: 17 additions & 2 deletions dvc/repo/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import typing
from contextlib import suppress
from typing import Iterable, List, NamedTuple, Optional, Set, Tuple
from typing import Iterable, List, NamedTuple, Optional, Set, Tuple, Union

from dvc.exceptions import (
DvcException,
Expand All @@ -19,7 +19,7 @@
from networkx import DiGraph

from dvc.repo import Repo
from dvc.stage import Stage
from dvc.stage import PipelineStage, Stage
from dvc.stage.loader import StageLoader
from dvc.types import OptStr

Expand Down Expand Up @@ -93,6 +93,21 @@ class StageLoad:
def __init__(self, repo: "Repo") -> None:
self.repo = repo

def create_from_cli(
self, validate: bool = False, **kwargs
) -> Union["Stage", "PipelineStage"]:
"""Creates a stage from CLI args passed as kwargs.
Args:
validate: if true, the new created stage is checked against the
stages in the repo. Eg: graph correctness,
potential overwrites in dvc.yaml file (unless `force=True`).
"""
from dvc.stage.utils import create_stage_from_cli

kwargs.update(validate=validate)
return create_stage_from_cli(self.repo, **kwargs)

def from_target(
self, target: str, accept_group: bool = False, glob: bool = False,
) -> StageList:
Expand Down
43 changes: 28 additions & 15 deletions dvc/stage/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,28 +325,37 @@ def _check_stage_exists(
)


def check_graphs(
repo: "Repo", stage: Union["Stage", "PipelineStage"], force: bool = True
def validate_state(
repo: "Repo", new: Union["Stage", "PipelineStage"], force: bool = False,
) -> None:
"""Checks graph and if that stage already exists.
"""Validates that the new stage:
If it exists in the dvc.yaml file, it errors out unless force is given.
* does not already exist with the same (name, path), unless force is True.
* does not affect the correctness of the repo's graph (to-be graph).
"""
from dvc.exceptions import OutputDuplicationError
stages = repo.stages[:] if force else None
if force:
# remove an existing stage (if exists)
with suppress(ValueError):
# uses name and path to determine this which is same for the
# existing (if any) and the new one
stages.remove(new)
else:
_check_stage_exists(repo, new, new.path)
# if not `force`, we don't need to replace existing stage with the new
# one, as the dumping this is not going to replace it anyway (as the
# stage with same path + name does not exist (from above check).

try:
if force:
with suppress(ValueError):
repo.stages.remove(stage)
else:
_check_stage_exists(repo, stage, stage.path)
repo.check_modified_graph([stage])
except OutputDuplicationError as exc:
raise OutputDuplicationError(exc.output, set(exc.stages) - {stage})
repo.check_modified_graph(new_stages=[new], old_stages=stages)


def create_stage_from_cli(
repo: "Repo", single_stage: bool = False, fname: str = None, **kwargs: Any
repo: "Repo",
single_stage: bool = False,
fname: str = None,
validate: bool = False,
force: bool = False,
**kwargs: Any,
) -> Union["Stage", "PipelineStage"]:

from dvc.dvcfile import PIPELINE_FILE
Expand Down Expand Up @@ -388,5 +397,9 @@ def create_stage_from_cli(
stage = create_stage(
stage_cls, repo=repo, path=path, params=params, **kwargs
)

if validate:
validate_state(repo, stage, force=force)

restore_meta(stage)
return stage
4 changes: 1 addition & 3 deletions tests/unit/command/test_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,11 @@ def test_stage_add(mocker, dvc, extra_args, expected_extra):
]
)
assert cli_args.func == CmdStageAdd
m = mocker.patch.object(CmdStageAdd, "create")

cmd = cli_args.func(cli_args)
m = mocker.patch.object(cmd.repo.stage, "create_from_cli")

assert cmd.run() == 0
assert m.call_args[0] == (cmd.repo, True)

expected = dict(
name="name",
deps=["deps"],
Expand Down

0 comments on commit 78e1232

Please sign in to comment.