Skip to content

Commit

Permalink
add more tests for collection of outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry committed Apr 27, 2020
1 parent 5d53fa8 commit dd2c992
Show file tree
Hide file tree
Showing 28 changed files with 509 additions and 162 deletions.
8 changes: 4 additions & 4 deletions dvc/command/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ def _show(self, target, commands, outs, locked):
from dvc import dvcfile
from dvc.utils import parse_target

path, name = parse_target(target)
stage = dvcfile.Dvcfile(self.repo, path).stages[name]
path, name, tag = parse_target(target)
stage = dvcfile.Dvcfile(self.repo, path, tag=tag).stages[name]
G = self.repo.graph
stages = networkx.dfs_postorder_nodes(G, stage)
if locked:
Expand All @@ -38,8 +38,8 @@ def _build_graph(self, target, commands=False, outs=False):
from dvc.repo.graph import get_pipeline
from dvc.utils import parse_target

path, name = parse_target(target)
target_stage = dvcfile.Dvcfile(self.repo, path).stages[name]
path, name, tag = parse_target(target)
target_stage = dvcfile.Dvcfile(self.repo, path, tag=tag).stages[name]
G = get_pipeline(self.repo.pipelines, target_stage)

nodes = set()
Expand Down
134 changes: 77 additions & 57 deletions dvc/dvcfile.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import contextlib
import os
import re
import logging

import dvc.prompt as prompt

from voluptuous import MultipleInvalid
from dvc import serialize
from dvc.exceptions import DvcException
from dvc.loader import SingleStageLoader, StageLoader
from dvc.stage.loader import SingleStageLoader, StageLoader
from dvc.stage.exceptions import (
StageFileBadNameError,
StageFileDoesNotExistError,
Expand All @@ -16,6 +16,7 @@
StageFileAlreadyExistsError,
)
from dvc.utils import relpath
from dvc.utils.collections import apply_diff
from dvc.utils.stage import (
dump_stage_file,
parse_stage,
Expand All @@ -28,7 +29,6 @@
DVC_FILE_SUFFIX = ".dvc"
PIPELINE_FILE = "pipelines.yaml"
PIPELINE_LOCK = "pipelines.lock"
TAG_REGEX = r"^(?P<path>.*)@(?P<tag>[^\\/@:]*)$"


def is_valid_filename(path):
Expand All @@ -39,7 +39,9 @@ def is_valid_filename(path):


def is_dvc_file(path):
return os.path.isfile(path) and is_valid_filename(path)
return os.path.isfile(path) and (
is_valid_filename(path) or os.path.basename(path) == PIPELINE_LOCK
)


def check_dvc_filename(path):
Expand All @@ -52,59 +54,46 @@ def check_dvc_filename(path):
)


def _get_path_tag(s):
regex = re.compile(TAG_REGEX)
match = regex.match(s)
if not match:
return s, None
return match.group("path"), match.group("tag")


class MultiStageFileLoadError(DvcException):
def __init__(self, file):
super().__init__("Cannot load multi-stage file: '{}'".format(file))


class FileMixin:
SCHEMA = None

def __init__(self, repo, path):
def __init__(self, repo, path, **kwargs):
self.repo = repo
self.path, self.tag = _get_path_tag(path)
self.path = path

def __repr__(self):
return "{}: {}".format(
DVC_FILE, relpath(self.path, self.repo.root_dir)
self.__class__.__name__, relpath(self.path, self.repo.root_dir)
)

def __hash__(self):
return hash(self.path)

def __eq__(self, other):
return self.repo == other.repo and os.path.abspath(
self.path
) == os.path.abspath(other.path)

def __str__(self):
return "{}: {}".format(DVC_FILE, self.relpath)
return "{}: {}".format(self.__class__.__name__, self.relpath)

@property
def relpath(self):
return relpath(self.path)

def exists(self):
return self.repo.tree.exists(self.path)

def check_file_exists(self):
if not self.exists():
raise StageFileDoesNotExistError(self.path)

def check_isfile(self):
if not self.repo.tree.isfile(self.path):
raise StageFileIsNotDvcFileError(self.path)

def check_filename(self):
raise NotImplementedError

def _load(self):
# it raises the proper exceptions by priority:
# 1. when the file doesn't exists
# 2. filename is not a DVC-file
# 3. path doesn't represent a regular file
self.check_file_exists()
if not self.exists():
raise StageFileDoesNotExistError(self.path)
check_dvc_filename(self.path)
self.check_isfile()
if not self.repo.tree.isfile(self.path):
raise StageFileIsNotDvcFileError(self.path)

with self.repo.tree.open(self.path) as fd:
stage_text = fd.read()
Expand All @@ -121,31 +110,28 @@ def validate(cls, d, fname=None):
raise StageFileFormatError(fname, exc)

def remove_with_prompt(self, force=False):
if not self.exists():
return

msg = (
"'{}' already exists. Do you wish to run the command and "
"overwrite it?".format(relpath(self.path))
)
if not (force or prompt.confirm(msg)):
raise StageFileAlreadyExistsError(self.path)
raise NotImplementedError

os.unlink(self.path)
def remove(self):
raise NotImplementedError


class SingleStageFile(FileMixin):
from dvc.schema import COMPILED_SINGLE_STAGE_SCHEMA as SCHEMA

def __init__(self, repo, path, tag=None):
super().__init__(repo, path)
self.tag = tag

@property
def stage(self):
data, raw = self._load()
return SingleStageLoader.load_stage(self, data, raw)
return SingleStageLoader.load_stage(self, data, raw, tag=self.tag)

@property
def stages(self):
data, raw = self._load()
return SingleStageLoader(self, data, raw)
return SingleStageLoader(self, data, raw, tag=self.tag)

def dump(self, stage, **kwargs):
"""Dumps given stage appropriately in the dvcfile."""
Expand All @@ -159,8 +145,27 @@ def dump(self, stage, **kwargs):
dump_stage_file(self.path, serialize.to_single_stage_file(stage))
self.repo.scm.track_file(relpath(self.path))

def remove_with_prompt(self, force=False):
if not self.exists():
return

msg = (
"'{}' already exists. Do you wish to run the command and "
"overwrite it?".format(relpath(self.path))
)
if not (force or prompt.confirm(msg)):
raise StageFileAlreadyExistsError(self.path)

self.remove()

def remove(self, force=False):
with contextlib.suppress(FileNotFoundError):
os.unlink(self.path)


class PipelineFile(FileMixin):
"""Abstraction for pipelines file, .yaml + .lock combined."""

from dvc.schema import COMPILED_MULTI_STAGE_SCHEMA as SCHEMA

@property
Expand All @@ -172,6 +177,7 @@ def dump(self, stage, update_pipeline=False):
from dvc.stage import PipelineStage

assert isinstance(stage, PipelineStage)
check_dvc_filename(self.path)
self._dump_lockfile(stage)
if update_pipeline and not stage.is_data_source:
self._dump_pipeline_file(stage)
Expand All @@ -191,14 +197,21 @@ def _dump_pipeline_file(self, stage):
open(self.path, "w+").close()

data["stages"] = data.get("stages", {})
data["stages"].update(serialize.to_dvcfile(stage))
stage_data = serialize.to_dvcfile(stage)
if data["stages"].get(stage.name):
orig_stage_data = data["stages"][stage.name]
apply_diff(stage_data[stage.name], orig_stage_data)
else:
data["stages"].update(stage_data)

dump_stage_file(self.path, self.SCHEMA(data))
dump_stage_file(self.path, data)
self.repo.scm.track_file(relpath(self.path))

@property
def stage(self):
raise MultiStageFileLoadError(self.path)
raise DvcException(
"PipelineFile has multiple stages. Please specify it's name."
)

@property
def stages(self):
Expand All @@ -208,16 +221,23 @@ def stages(self):
lockfile_data = lockfile.load(self.repo, self._lockfile)
return StageLoader(self, data.get("stages", {}), lockfile_data)

def remove(self, force=False):
if not force:
logger.warning("Cannot remove pipeline file.")
return

for file in [self.path, self._lockfile]:
with contextlib.suppress(FileNotFoundError):
os.unlink(file)


class Dvcfile:
def __new__(cls, repo, path):
def __new__(cls, repo, path, **kwargs):
assert path
assert repo

file, _ = _get_path_tag(path)
_, ext = os.path.splitext(file)
assert not ext or ext in [".yml", ".yaml", ".dvc"]

if not ext or ext == DVC_FILE_SUFFIX:
return SingleStageFile(repo, path)
return PipelineFile(repo, path)
_, ext = os.path.splitext(path)
if ext in [".yaml", ".yml"]:
return PipelineFile(repo, path, **kwargs)
# fallback to single stage file for better error messages
return SingleStageFile(repo, path, **kwargs)
2 changes: 2 additions & 0 deletions dvc/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def __init__(self, output, stages):
output, "\n".join("\t{}".format(s.addressing) for s in stages)
)
super().__init__(msg)
self.stages = stages
self.output = output


class OutputNotFoundError(DvcException):
Expand Down
8 changes: 4 additions & 4 deletions dvc/repo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,8 @@ def collect(self, target, with_deps=False, recursive=False, graph=None):
os.path.abspath(target), graph or self.graph
)

file, name, = parse_target(target)
dvcfile = Dvcfile(self, file)
file, name, tag = parse_target(target)
dvcfile = Dvcfile(self, file, tag=tag)
stages = list(dvcfile.stages.filter(name).values())
if not with_deps:
return stages
Expand All @@ -229,10 +229,10 @@ def collect_granular(self, target, *args, **kwargs):
if not target:
return [(stage, None) for stage in self.stages]

file, name = parse_target(target)
file, name, tag = parse_target(target)
if is_valid_filename(file) and not kwargs.get("with_deps"):
# Optimization: do not collect the graph for a specific .dvc target
stages = Dvcfile(self, file).stages.filter(name)
stages = Dvcfile(self, file, tag=tag).stages.filter(name)
return [(stage, None) for stage in stages.values()]

try:
Expand Down
5 changes: 5 additions & 0 deletions dvc/repo/add.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ..exceptions import (
RecursiveAddingWhileUsingFilename,
OverlappingOutputPathsError,
OutputDuplicationError,
)
from ..output.base import OutputDoesNotExistError
from ..progress import Tqdm
Expand Down Expand Up @@ -68,6 +69,10 @@ def add(repo, targets, recursive=False, no_commit=False, fname=None):
raise OverlappingOutputPathsError(
exc.parent, exc.overlapping_out, msg
)
except OutputDuplicationError as exc:
raise OutputDuplicationError(
exc.output, set(exc.stages) - set(stages)
)

with Tqdm(
total=len(stages),
Expand Down
3 changes: 2 additions & 1 deletion dvc/repo/destroy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
@locked
def _destroy_stages(repo):
for stage in repo.stages:
stage.remove(remove_outs=False)
stage.unprotect_outs()
stage.dvcfile.remove(force=True)


# NOTE: not locking `destroy`, as `remove` will need to delete `.dvc` dir,
Expand Down
8 changes: 7 additions & 1 deletion dvc/repo/imp_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dvc.repo.scm_context import scm_context
from dvc.utils import resolve_output, resolve_paths, relpath
from dvc.utils.fs import path_isin
from ..exceptions import OutputDuplicationError


@locked_repo
Expand Down Expand Up @@ -35,7 +36,12 @@ def imp_url(self, url, out=None, fname=None, erepo=None, locked=True):
dvcfile = Dvcfile(self, stage.path)
dvcfile.remove_with_prompt(force=True)

self.check_modified_graph([stage])
try:
self.check_modified_graph([stage])
except OutputDuplicationError as exc:
raise OutputDuplicationError(
exc.output, set(exc.stages) - set([stage])
)

stage.run()

Expand Down
6 changes: 3 additions & 3 deletions dvc/repo/lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ def lock(self, target, unlock=False):
from .. import dvcfile
from dvc.utils import parse_target

path, target = parse_target(target)
dvcfile = dvcfile.Dvcfile(self, path)
stage = dvcfile.stages[target]
path, name, tag = parse_target(target)
dvcfile = dvcfile.Dvcfile(self, path, tag=tag)
stage = dvcfile.stages[name]
stage.locked = False if unlock else True
dvcfile.dump(stage, update_pipeline=True)

Expand Down
Loading

0 comments on commit dd2c992

Please sign in to comment.