Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

log_artifact: add cache option, only write to dvc.yaml if metadata exists #620

Merged
merged 1 commit into from
Jul 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 22 additions & 18 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,10 +413,11 @@ def log_artifact(
path: StrPath,
type: Optional[str] = None, # noqa: A002
name: Optional[str] = None,
desc: Optional[str] = None, # noqa: ARG002
labels: Optional[List[str]] = None, # noqa: ARG002
meta: Optional[Dict[str, Any]] = None, # noqa: ARG002
desc: Optional[str] = None,
labels: Optional[List[str]] = None,
meta: Optional[Dict[str, Any]] = None,
copy: bool = False,
cache: bool = True,
):
"""Tracks a local file or directory with DVC"""
if not isinstance(path, (str, Path)):
Expand All @@ -428,21 +429,24 @@ def log_artifact(
if copy:
path = clean_and_copy_into(path, self.artifacts_dir)

self.cache(path)

name = name or Path(path).stem
if name_is_compatible(name):
self._artifacts[name] = {
k: v
for k, v in locals().items()
if k in ("path", "type", "desc", "labels", "meta") and v is not None
}
else:
logger.warning(
"Can't use '%s' as artifact name (ID)."
" It will not be included in the `artifacts` section.",
name,
)
if cache:
self.cache(path)

if any((type, name, desc, labels, meta)):
name = name or Path(path).stem
if name_is_compatible(name):
self._artifacts[name] = {
k: v
for k, v in locals().items()
if k in ("path", "type", "desc", "labels", "meta")
and v is not None
}
else:
logger.warning(
"Can't use '%s' as artifact name (ID)."
" It will not be included in the `artifacts` section.",
name,
)

def cache(self, path):
try:
Expand Down
26 changes: 15 additions & 11 deletions tests/test_log_artifact.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
import shutil
from pathlib import Path

import pytest

from dvclive import Live
from dvclive.serialize import load_yaml


def test_log_artifact(tmp_dir, dvc_repo):
@pytest.mark.parametrize("cache", [True, False])
def test_log_artifact(tmp_dir, dvc_repo, cache):
data = tmp_dir / "data"
data.touch()
with Live() as live:
live.log_artifact("data")
assert data.with_suffix(".dvc").exists()
live.log_artifact("data", cache=cache)
assert data.with_suffix(".dvc").exists() is cache
assert load_yaml(live.dvc_file) == {}


def test_log_artifact_on_existing_dvc_file(tmp_dir, dvc_repo):
Expand Down Expand Up @@ -78,14 +82,14 @@ def test_log_artifact_copy(tmp_dir, dvc_repo):
(tmp_dir / "model.pth").touch()

with Live() as live:
live.log_artifact("model.pth", copy=True)
live.log_artifact("model.pth", type="model", copy=True)

artifacts_dir = Path(live.artifacts_dir)
assert (artifacts_dir / "model.pth").exists()
assert (artifacts_dir / "model.pth.dvc").exists()

assert load_yaml(live.dvc_file) == {
"artifacts": {"model": {"path": "artifacts/model.pth"}}
"artifacts": {"model": {"path": "artifacts/model.pth", "type": "model"}}
}


Expand All @@ -97,15 +101,15 @@ def test_log_artifact_copy_overwrite(tmp_dir, dvc_repo):
# testing with symlink cache to make sure that DVC protected mode
# does not prevent the overwrite
live._dvc_repo.cache.local.cache_types = ["symlink"]
live.log_artifact("model.pth", copy=True)
live.log_artifact("model.pth", type="model", copy=True)
assert (artifacts_dir / "model.pth").is_symlink()
live.log_artifact("model.pth", copy=True)
live.log_artifact("model.pth", type="model", copy=True)

assert (artifacts_dir / "model.pth").exists()
assert (artifacts_dir / "model.pth.dvc").exists()

assert load_yaml(live.dvc_file) == {
"artifacts": {"model": {"path": "artifacts/model.pth"}}
"artifacts": {"model": {"path": "artifacts/model.pth", "type": "model"}}
}


Expand All @@ -119,14 +123,14 @@ def test_log_artifact_copy_directory_overwrite(tmp_dir, dvc_repo):
# testing with symlink cache to make sure that DVC protected mode
# does not prevent the overwrite
live._dvc_repo.cache.local.cache_types = ["symlink"]
live.log_artifact(model_path, copy=True)
live.log_artifact(model_path, type="model", copy=True)
assert (artifacts_dir / "weights" / "model-epoch-1.pth").is_symlink()

shutil.rmtree(model_path)
model_path.mkdir()
(tmp_dir / "weights" / "model-epoch-10.pth").write_text("Model weights")
(tmp_dir / "weights" / "best.pth").write_text("Best model weights")
live.log_artifact(model_path, copy=True)
live.log_artifact(model_path, type="model", copy=True)

assert (artifacts_dir / "weights").exists()
assert (artifacts_dir / "weights" / "best.pth").is_symlink()
Expand All @@ -135,7 +139,7 @@ def test_log_artifact_copy_directory_overwrite(tmp_dir, dvc_repo):
assert len(list((artifacts_dir / "weights").iterdir())) == 2

assert load_yaml(live.dvc_file) == {
"artifacts": {"weights": {"path": "artifacts/weights"}}
"artifacts": {"weights": {"path": "artifacts/weights", "type": "model"}}
}


Expand Down