Skip to content

Commit

Permalink
add _inside_dvc_pipeline (#718)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dave Berenbaum authored Oct 5, 2023
1 parent f4fd464 commit d21fa55
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 18 deletions.
20 changes: 11 additions & 9 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def __init__(
self._exp_message: Optional[str] = exp_message
self._experiment_rev: Optional[str] = None
self._inside_dvc_exp: bool = False
self._inside_dvc_pipeline: bool = False
self._dvc_repo = None
self._include_untracked: List[str] = []
self._init_dvc()
Expand Down Expand Up @@ -137,7 +138,9 @@ def _init_cleanup(self):
def _init_dvc(self):
from dvc.scm import NoSCM

self._init_dvc_pipeline()
if os.getenv(env.DVC_ROOT, None):
self._inside_dvc_pipeline = True
self._init_dvc_pipeline()
self._dvc_repo = get_dvc_repo()

dvc_logger = logging.getLogger("dvc")
Expand Down Expand Up @@ -170,7 +173,7 @@ def _init_dvc(self):
"\nRemove it from outputs to make DVCLive work as expected."
)

if self._inside_dvc_exp:
if self._inside_dvc_pipeline:
return

self._baseline_rev = self._dvc_repo.scm.get_rev()
Expand Down Expand Up @@ -200,16 +203,15 @@ def _init_dvc_pipeline(self):
self._inside_dvc_exp = True
if self._save_dvc_exp:
logger.info("Ignoring `save_dvc_exp` because `dvc exp run` is running")
self._save_dvc_exp = False
elif os.getenv(env.DVC_ROOT, None):
else:
# `dvc repro` execution
if self._save_dvc_exp:
logger.info("Ignoring `save_dvc_exp` because `dvc repro` is running")
self._save_dvc_exp = False
logger.warning(
"Some DVCLive features are unsupported in `dvc repro`."
"\nTo use DVCLive with a DVC Pipeline, run it with `dvc exp run`."
)
self._save_dvc_exp = False

def _init_studio(self):
self._dvc_studio_config = get_dvc_studio_config(self)
Expand Down Expand Up @@ -494,25 +496,25 @@ def log_artifact(

@catch_and_warn(DvcException, logger)
def cache(self, path):
if self._inside_dvc_exp:
if self._inside_dvc_pipeline:
existing_stage = find_overlapping_stage(self._dvc_repo, path)

if existing_stage:
if existing_stage.cmd:
logger.info(
f"Skipping `dvc add {path}` because it is already being"
" tracked automatically as an output of `dvc exp run`."
" tracked automatically as an output of the DVC pipeline."
)
return # skip caching
logger.warning(
f"To track '{path}' automatically during `dvc exp run`:"
f"To track '{path}' automatically in the DVC pipeline:"
f"\n1. Run `dvc remove {existing_stage.addressing}` "
"to stop tracking it outside the pipeline."
"\n2. Add it as an output of the pipeline stage."
)
else:
logger.warning(
f"To track '{path}' automatically during `dvc exp run`, "
f"To track '{path}' automatically in the DVC pipeline, "
"add it as an output of the pipeline stage."
)

Expand Down
4 changes: 4 additions & 0 deletions tests/test_dvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def test_exp_save_on_end(tmp_dir, save, mocked_dvc_repo):
def test_exp_save_skip_on_env_vars(tmp_dir, monkeypatch, mocker):
monkeypatch.setenv(DVC_EXP_BASELINE_REV, "foo")
monkeypatch.setenv(DVC_EXP_NAME, "bar")
monkeypatch.setenv(DVC_ROOT, tmp_dir)

mocker.patch("dvclive.live.get_dvc_repo", return_value=None)
live = Live()
Expand All @@ -56,6 +57,7 @@ def test_exp_save_skip_on_env_vars(tmp_dir, monkeypatch, mocker):
assert live._baseline_rev == "foo"
assert live._exp_name == "bar"
assert live._inside_dvc_exp
assert live._inside_dvc_pipeline


def test_exp_save_run_on_dvc_repro(tmp_dir, mocker):
Expand Down Expand Up @@ -126,6 +128,7 @@ def test_untracked_dvclive_files_inside_dvc_exp_run_are_added(
):
monkeypatch.setenv(DVC_EXP_BASELINE_REV, "foo")
monkeypatch.setenv(DVC_EXP_NAME, "bar")
monkeypatch.setenv(DVC_ROOT, tmp_dir)
plot_file = os.path.join("dvclive", "plots", "metrics", "foo.tsv")
mocked_dvc_repo.scm.untracked_files.return_value = [
"dvclive/metrics.json",
Expand All @@ -142,6 +145,7 @@ def test_dvc_outs_are_not_added(tmp_dir, mocked_dvc_repo, monkeypatch):
"""Regression test for https://github.com/iterative/dvclive/issues/516"""
monkeypatch.setenv(DVC_EXP_BASELINE_REV, "foo")
monkeypatch.setenv(DVC_EXP_NAME, "bar")
monkeypatch.setenv(DVC_ROOT, tmp_dir)
mocked_dvc_repo.index.outs = ["dvclive/plots"]
plot_file = os.path.join("dvclive", "plots", "metrics", "foo.tsv")
mocked_dvc_repo.scm.untracked_files.return_value = [
Expand Down
16 changes: 8 additions & 8 deletions tests/test_log_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def test_log_artifact_type_model_when_dvc_add_fails(tmp_dir, mocker, mocked_dvc_


@pytest.mark.parametrize("tracked", ["data_source", "stage", None])
def test_log_artifact_inside_exp(tmp_dir, mocker, dvc_repo, tracked):
def test_log_artifact_inside_pipeline(tmp_dir, mocker, dvc_repo, tracked):
logger = mocker.patch("dvclive.live.logger")
data = tmp_dir / "data"
data.touch()
Expand All @@ -239,18 +239,18 @@ def test_log_artifact_inside_exp(tmp_dir, mocker, dvc_repo, tracked):
f.write(dvcyaml)
live = Live(save_dvc_exp=False)
spy = mocker.spy(live._dvc_repo, "add")
live._inside_dvc_exp = True
live._inside_dvc_pipeline = True
live.log_artifact("data")
if tracked == "stage":
msg = (
"Skipping `dvc add data` because it is already being tracked"
" automatically as an output of `dvc exp run`."
" automatically as an output of the DVC pipeline."
)
logger.info.assert_called_with(msg)
spy.assert_not_called()
elif tracked == "data_source":
msg = (
"To track 'data' automatically during `dvc exp run`:"
"To track 'data' automatically in the DVC pipeline:"
"\n1. Run `dvc remove data.dvc` "
"to stop tracking it outside the pipeline."
"\n2. Add it as an output of the pipeline stage."
Expand All @@ -259,14 +259,14 @@ def test_log_artifact_inside_exp(tmp_dir, mocker, dvc_repo, tracked):
spy.assert_called_once()
else:
msg = (
"To track 'data' automatically during `dvc exp run`, "
"To track 'data' automatically in the DVC pipeline, "
"add it as an output of the pipeline stage."
)
logger.warning.assert_called_with(msg)
spy.assert_called_once()


def test_log_artifact_inside_exp_subdir(tmp_dir, mocker, dvc_repo):
def test_log_artifact_inside_pipeline_subdir(tmp_dir, mocker, dvc_repo):
logger = mocker.patch("dvclive.live.logger")
subdir = tmp_dir / "subdir"
subdir.mkdir()
Expand All @@ -275,10 +275,10 @@ def test_log_artifact_inside_exp_subdir(tmp_dir, mocker, dvc_repo):
dvc_repo.add(subdir)
live = Live()
spy = mocker.spy(live._dvc_repo, "add")
live._inside_dvc_exp = True
live._inside_dvc_pipeline = True
live.log_artifact("subdir/data")
msg = (
"To track 'subdir/data' automatically during `dvc exp run`:"
"To track 'subdir/data' automatically in the DVC pipeline:"
"\n1. Run `dvc remove subdir.dvc` "
"to stop tracking it outside the pipeline."
"\n2. Add it as an output of the pipeline stage."
Expand Down
5 changes: 4 additions & 1 deletion tests/test_post_to_studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from PIL import Image as ImagePIL

from dvclive import Live
from dvclive.env import DVC_EXP_BASELINE_REV, DVC_EXP_NAME
from dvclive.env import DVC_EXP_BASELINE_REV, DVC_EXP_NAME, DVC_ROOT
from dvclive.plots import Image, Metric
from dvclive.studio import _adapt_image, get_dvc_studio_config

Expand Down Expand Up @@ -151,6 +151,7 @@ def test_post_to_studio_skip_start_and_done_on_env_var(

monkeypatch.setenv(DVC_EXP_BASELINE_REV, "f" * 40)
monkeypatch.setenv(DVC_EXP_NAME, "bar")
monkeypatch.setenv(DVC_ROOT, tmp_dir)

with Live() as live:
live.log_metric("foo", 1)
Expand All @@ -167,6 +168,7 @@ def test_post_to_studio_dvc_studio_config(

monkeypatch.setenv(DVC_EXP_BASELINE_REV, "f" * 40)
monkeypatch.setenv(DVC_EXP_NAME, "bar")
monkeypatch.setenv(DVC_ROOT, tmp_dir)

mocked_dvc_repo.config = {"studio": {"token": "token"}}

Expand Down Expand Up @@ -228,6 +230,7 @@ def test_post_to_studio_inside_dvc_exp(

monkeypatch.setenv(DVC_EXP_BASELINE_REV, "f" * 40)
monkeypatch.setenv(DVC_EXP_NAME, "bar")
monkeypatch.setenv(DVC_ROOT, tmp_dir)

with Live() as live:
live.log_metric("foo", 1)
Expand Down

0 comments on commit d21fa55

Please sign in to comment.