Skip to content

Commit

Permalink
Add top-level artifacts: section (#9220)
Browse files Browse the repository at this point in the history
* add artifacts section

* adding and removing artifacts

* delete code for add/remove; add tests

* fix pre-commit

* remove extra piece of writing to dvc.yaml

* remove extra piece

* simplify kwargs

* allow same names for artifacts in different dvc.yaml files

* set ID as path; don't return dvc.yaml(s) with no artifacts

* fix PR feedback

* issue warning at incorrect artifact name

* issue warning at incorrect artifact name

* reference right schema

* add r to regexp

* add tests for regexp

* set artifacts for SingleStageFile

* fix windows test

* make path required; simplify regexp and remove slash from it

* allow to use gto's artifacts.yaml

* remove extra piece of code for migration from gto
  • Loading branch information
aguschin authored Apr 5, 2023
1 parent 3564813 commit 32bebc0
Show file tree
Hide file tree
Showing 8 changed files with 229 additions and 1 deletion.
11 changes: 11 additions & 0 deletions dvc/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, ClassVar, Dict, List, Optional

from funcy import compact
from voluptuous import Required


@dataclass
Expand All @@ -26,10 +27,20 @@ def to_dict(self) -> Dict[str, str]:
return compact(asdict(self))


@dataclass
class Artifact(Annotation):
    """Annotation that additionally carries the ``path`` of an artifact."""

    # Name of the key that holds the path; used as the required key in
    # ARTIFACT_SCHEMA.
    PARAM_PATH: ClassVar[str] = "path"
    # Defaults to None on the dataclass, although ARTIFACT_SCHEMA marks the
    # key as required when validating file contents.
    path: Optional[str] = None


# Names of the dataclass fields of Annotation.
ANNOTATION_FIELDS = [field.name for field in fields(Annotation)]
# Expected value type for each annotation key.
ANNOTATION_SCHEMA = {
    Annotation.PARAM_DESC: str,
    Annotation.PARAM_TYPE: str,
    Annotation.PARAM_LABELS: [str],
    Annotation.PARAM_META: object,
}
# An artifact entry accepts every annotation key plus a mandatory string path.
ARTIFACT_SCHEMA = {
    Required(Artifact.PARAM_PATH): str,
    **ANNOTATION_SCHEMA,  # type: ignore[arg-type]
}
5 changes: 5 additions & 0 deletions dvc/dvcfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ class SingleStageFile(FileMixin):
metrics: List[str] = []
plots: Any = {}
params: List[str] = []
artifacts: Dict[str, Optional[Dict[str, Any]]] = {}

@property
def stage(self) -> "Stage":
Expand Down Expand Up @@ -323,6 +324,10 @@ def plots(self) -> Any:
def params(self) -> List[str]:
return self.contents.get("params", [])

@property
def artifacts(self) -> Dict[str, Optional[Dict[str, Any]]]:
    """Return the ``artifacts`` entry of the file contents, or ``{}`` if absent."""
    # NOTE(review): the schema for this key also appears to allow a plain
    # string (a path to a separate artifacts file), which the return
    # annotation does not reflect - confirm against dvc/schema.py.
    return self.contents.get("artifacts", {})

def remove(self, force=False):
if not force:
logger.warning("Cannot remove pipeline file.")
Expand Down
2 changes: 2 additions & 0 deletions dvc/repo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def __init__( # noqa: PLR0915
from dvc.data_cloud import DataCloud
from dvc.fs import GitFileSystem, LocalFileSystem, localfs
from dvc.lock import LockNoop, make_lock
from dvc.repo.artifacts import Artifacts
from dvc.repo.metrics import Metrics
from dvc.repo.params import Params
from dvc.repo.plots import Plots
Expand Down Expand Up @@ -224,6 +225,7 @@ def __init__( # noqa: PLR0915
self.metrics: Metrics = Metrics(self)
self.plots: Plots = Plots(self)
self.params: Params = Params(self)
self.artifacts: Artifacts = Artifacts(self)

self.stage_collection_error_handler: Optional[
Callable[[str, Exception], None]
Expand Down
67 changes: 67 additions & 0 deletions dvc/repo/artifacts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import logging
import os
import re
from typing import TYPE_CHECKING, Dict

from dvc.annotations import Artifact
from dvc.dvcfile import FileMixin
from dvc.utils import relpath

if TYPE_CHECKING:
from dvc.repo import Repo

logger = logging.getLogger(__name__)


# Artifact names (IDs): a lowercase ASCII letter first, followed by lowercase
# letters, digits or '-' separators, and never ending with '-'.
NAME_RE = re.compile(r"^[a-z]([a-z0-9-]*[a-z0-9])?$")


def name_is_compatible(name: str) -> bool:
    """Return True if ``name`` is an acceptable artifact name (ID).

    The whole string must match ``NAME_RE``. ``fullmatch`` is used instead of
    ``search`` because the ``$`` anchor on its own also matches just before a
    trailing newline, which would let a name such as ``"model\\n"`` through.
    """
    return NAME_RE.fullmatch(name) is not None


def check_name_format(name: str) -> None:
    """Log a warning when ``name`` is not an acceptable artifact name (ID).

    The check is advisory only: nothing is raised and nothing is returned.
    """
    if name_is_compatible(name):
        return

    message = (
        "Can't use '%s' as artifact name (ID)."
        " You can use letters and numbers, and use '-' as separator"
        " (but not at the start or end). The first character must be a letter."
    )
    logger.warning(message, name)


class ArtifactsFile(FileMixin):
    """``FileMixin`` for a standalone artifacts file.

    ``Artifacts.read`` loads one of these when the ``artifacts`` value of a
    dvc.yaml is a string instead of a mapping. Writing is not implemented.
    """

    # Schema mapping artifact names to artifact entries.
    # NOTE(review): presumably FileMixin validates loaded contents against
    # ``SCHEMA`` - confirm in dvc/dvcfile.py.
    from dvc.schema import SINGLE_ARTIFACT_SCHEMA as SCHEMA

    def dump(self, stage, **kwargs):
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def merge(self, ancestor, other, allowed=None):
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError


class Artifacts:
    """Reads the artifacts recorded in a repo's index."""

    def __init__(self, repo: "Repo") -> None:
        self.repo = repo

    def read(self) -> Dict[str, Dict[str, Artifact]]:
        """Collect artifacts from the repo index, grouped by declaring file.

        Returns:
            A mapping from each dvc.yaml path (relative to the repo root) to
            a mapping of artifact name -> ``Artifact``. Files that declare no
            artifacts are left out. A name in an unexpected format only
            produces a logged warning.
        """
        collected: Dict[str, Dict[str, Artifact]] = {}
        # pylint: disable-next=protected-access
        indexed = self.repo.index._artifacts
        for dvcfile, declared in indexed.items():
            entries = declared
            if isinstance(entries, str):
                # A string value names a separate artifacts.yaml, resolved
                # relative to the directory of the declaring dvc.yaml.
                artifacts_path = os.path.join(os.path.dirname(dvcfile), entries)
                entries = ArtifactsFile(
                    self.repo, artifacts_path, verify=False
                ).load()
            if not entries:
                continue

            key = relpath(dvcfile, self.repo.root_dir)
            per_file: Dict[str, Artifact] = {}
            for name, value in entries.items():
                check_name_format(name)
                per_file[name] = Artifact(**value)
            collected[key] = per_file
        return collected
7 changes: 7 additions & 0 deletions dvc/repo/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,14 @@ def __init__(
metrics: Optional[Dict[str, List[str]]] = None,
plots: Optional[Dict[str, List[str]]] = None,
params: Optional[Dict[str, Any]] = None,
artifacts: Optional[Dict[str, Any]] = None,
) -> None:
self.repo = repo
self.stages = stages or []
self._metrics = metrics or {}
self._plots = plots or {}
self._params = params or {}
self._artifacts = artifacts or {}
self._collected_targets: Dict[int, List["StageInfo"]] = {}

@cached_property
Expand All @@ -202,6 +204,7 @@ def from_repo(
metrics = {}
plots = {}
params = {}
artifacts = {}

onerror = onerror or repo.stage_collection_error_handler
for _, idx in collect_files(repo, onerror=onerror):
Expand All @@ -210,12 +213,14 @@ def from_repo(
metrics.update(idx._metrics)
plots.update(idx._plots)
params.update(idx._params)
artifacts.update(idx._artifacts)
return cls(
repo,
stages=stages,
metrics=metrics,
plots=plots,
params=params,
artifacts=artifacts,
)

@classmethod
Expand All @@ -229,6 +234,7 @@ def from_file(cls, repo: "Repo", path: str) -> "Index":
metrics={path: dvcfile.metrics} if dvcfile.metrics else {},
plots={path: dvcfile.plots} if dvcfile.plots else {},
params={path: dvcfile.params} if dvcfile.params else {},
artifacts={path: dvcfile.artifacts} if dvcfile.artifacts else {},
)

def update(self, stages: Iterable["Stage"]) -> "Index":
Expand All @@ -242,6 +248,7 @@ def update(self, stages: Iterable["Stage"]) -> "Index":
metrics=self._metrics,
plots=self._plots,
params=self._params,
artifacts=self._artifacts,
)

@cached_property
Expand Down
6 changes: 5 additions & 1 deletion dvc/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from voluptuous import Any, Optional, Required, Schema

from dvc import dependency, output
from dvc.annotations import ANNOTATION_SCHEMA
from dvc.annotations import ANNOTATION_SCHEMA, ARTIFACT_SCHEMA
from dvc.output import (
CHECKSUMS_SCHEMA,
CLOUD_SCHEMA,
Expand Down Expand Up @@ -114,6 +114,9 @@ def validator(data):
Output.PARAM_PLOT_TEMPLATE: str,
}
SINGLE_PLOT_SCHEMA = {str: Any(PLOT_DEFINITION, None)}
# Key under which artifacts are declared.
ARTIFACTS = "artifacts"
# Mapping of artifact name -> artifact entry.
SINGLE_ARTIFACT_SCHEMA = Schema({str: ARTIFACT_SCHEMA})
# The artifacts value is either a string or the mapping above.
ARTIFACTS_SCHEMA = Any(str, SINGLE_ARTIFACT_SCHEMA)
FOREACH_IN = {
Required(FOREACH_KWD): Any(dict, list, str),
Required(DO_KWD): STAGE_DEFINITION,
Expand All @@ -127,6 +130,7 @@ def validator(data):
VARS_KWD: VARS_SCHEMA,
StageParams.PARAM_PARAMS: [str],
StageParams.PARAM_METRICS: [str],
ARTIFACTS: ARTIFACTS_SCHEMA,
}

COMPILED_SINGLE_STAGE_SCHEMA = Schema(SINGLE_STAGE_SCHEMA)
Expand Down
Empty file.
132 changes: 132 additions & 0 deletions tests/func/artifacts/test_artifacts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import os

import pytest

from dvc.annotations import Artifact
from dvc.repo.artifacts import name_is_compatible
from dvc.utils.strictyaml import YAMLSyntaxError, YAMLValidationError

# dvc.yaml contents declaring three artifacts; the last one sets every
# annotation field (type, desc, labels, meta) in addition to the path.
dvcyaml = {
    "artifacts": {
        "myart": {"type": "model", "path": "myart.pkl"},
        "hello": {"type": "file", "path": "hello.txt"},
        "world": {
            "type": "object",
            "path": "world.txt",
            "desc": "The world is not enough",
            "labels": ["but", "this", "is"],
            "meta": {"such": "a", "perfect": "place to start"},
        },
    }
}


def test_reading_artifacts_subdir(tmp_dir, dvc):
    """Artifacts of a root and a nested dvc.yaml are both read.

    The same artifact names are used in both files, and the result is keyed
    by each file's path relative to the repo root.
    """
    (tmp_dir / "dvc.yaml").dump(dvcyaml)

    subdir = tmp_dir / "subdir"
    subdir.mkdir()

    (subdir / "dvc.yaml").dump(dvcyaml)

    artifacts = {
        name: Artifact(**values) for name, values in dvcyaml["artifacts"].items()
    }
    assert tmp_dir.dvc.artifacts.read() == {
        "dvc.yaml": artifacts,
        # os.path.sep keeps the expected key correct on Windows as well.
        f"subdir{os.path.sep}dvc.yaml": artifacts,
    }


# Invalid: artifact "lol" carries the key "kek", which is not in the schema.
bad_dvcyaml_extra_field = {
    "artifacts": {
        "lol": {"kek": "cheburek", "path": "lol"},
        "hello": {"type": "file", "path": "hello.txt"},
    }
}


# Invalid: artifact "lol" has no "path", which the schema requires.
bad_dvcyaml_missing_path = {
    "artifacts": {
        "lol": {},
    }
}


@pytest.mark.parametrize(
    "bad_dvcyaml", [bad_dvcyaml_extra_field, bad_dvcyaml_missing_path]
)
def test_broken_dvcyaml_extra_field(tmp_dir, dvc, bad_dvcyaml):
    """Reading artifacts fails validation for an unknown key or a missing path."""
    (tmp_dir / "dvc.yaml").dump(bad_dvcyaml)

    with pytest.raises(YAMLValidationError):
        tmp_dir.dvc.artifacts.read()


# Raw YAML text in which the key "lol" appears twice.
# NOTE(review): the nested keys carry no indentation as shown here; confirm
# the intended nesting (both "lol" entries under "artifacts") in the real file.
bad_dvcyaml_id_duplication = """
artifacts:
lol:
type: kek
lol: {}
"""


def test_broken_dvcyaml_id_duplication(tmp_dir, dvc):
    """A dvc.yaml that repeats an artifact ID is rejected as a YAML syntax error."""
    # Written as raw text because a Python dict cannot hold a duplicated key.
    with open(tmp_dir / "dvc.yaml", "w") as f:
        f.write(bad_dvcyaml_id_duplication)

    with pytest.raises(YAMLSyntaxError):
        tmp_dir.dvc.artifacts.read()


# dvc.yaml whose "artifacts" value is a file name rather than a mapping.
dvcyaml_redirecting = {"artifacts": "artifacts.yaml"}


def test_read_artifacts_yaml(tmp_dir, dvc):
    """Artifacts are loaded from the file that dvc.yaml's ``artifacts`` names.

    The loaded entries are still reported under the ``dvc.yaml`` key.
    """
    (tmp_dir / "dvc.yaml").dump(dvcyaml_redirecting)
    (tmp_dir / "artifacts.yaml").dump(dvcyaml["artifacts"])

    artifacts = {
        name: Artifact(**values) for name, values in dvcyaml["artifacts"].items()
    }
    assert tmp_dir.dvc.artifacts.read() == {
        "dvc.yaml": artifacts,
    }


@pytest.mark.parametrize(
    "name",
    [
        "m",
        "nn",
        "m1",
        "model-prod",
        "model-prod-v1",
    ],
)
def test_check_name_is_valid(name):
    """Names made of lowercase letters, digits and inner '-' are accepted."""
    assert name_is_compatible(name)


@pytest.mark.parametrize(
    "name",
    [
        "",
        "1",
        "m/",
        "/m",
        "1nn",
        "###",
        "@@@",
        "a model",
        "a_model",
        "-model",
        "model-",
        "model@1",
        "model#1",
        "@namespace/model",
        "namespace/model",
    ],
)
def test_check_name_is_invalid(name):
    """Empty names, leading digits, edge '-' and other symbols are rejected."""
    assert not name_is_compatible(name)

0 comments on commit 32bebc0

Please sign in to comment.