From 44c37310c41c751e18eba2c26c2ba4a6c10af83c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Redzy=C5=84ski?= Date: Wed, 30 Dec 2020 16:30:46 +0100 Subject: [PATCH] live: fix summary tracking by experiments --- dvc/repo/reproduce.py | 6 +++ dvc/stage/utils.py | 9 ++-- tests/func/test_live.py | 98 +++++++++++++++++++++++------------------ 3 files changed, 65 insertions(+), 48 deletions(-) diff --git a/dvc/repo/reproduce.py b/dvc/repo/reproduce.py index fd27d7c7c4..b36395588a 100644 --- a/dvc/repo/reproduce.py +++ b/dvc/repo/reproduce.py @@ -67,6 +67,12 @@ def _track_stage(stage): for out in stage.outs: if not out.use_scm_ignore and out.is_in_repo: stage.repo.scm.track_file(os.fspath(out.path_info)) + if out.live: + from dvc.repo.live import summary_path_info + + summary = summary_path_info(out) + if summary: + stage.repo.scm.track_file(os.fspath(summary)) stage.repo.scm.track_changed_files() diff --git a/dvc/stage/utils.py b/dvc/stage/utils.py index f038157ed6..8974537f9a 100644 --- a/dvc/stage/utils.py +++ b/dvc/stage/utils.py @@ -59,7 +59,7 @@ def fill_stage_outputs(stage, **kwargs): stage.outs = [] stage.outs += _load_live_outputs( - stage, kwargs.get("live", []), kwargs.get("live_summary", False) + stage, kwargs.get("live", None), kwargs.get("live_summary", False) ) for key in keys: @@ -74,18 +74,15 @@ def fill_stage_outputs(stage, **kwargs): ) -def _load_live_outputs(stage, live_l, live_summary): - from dvc.exceptions import DvcException +def _load_live_outputs(stage, live_l=None, live_summary=False): from dvc.output import BaseOutput outs = [] if live_l: - if len(live_l) != 1: - raise DvcException("Only one live output allowed!") outs += output.loads_from( stage, - live_l, + [live_l], use_cache=False, live={BaseOutput.PARAM_LIVE_SUMMARY: live_summary}, ) diff --git a/tests/func/test_live.py b/tests/func/test_live.py index 391cd5f9d9..54790d6218 100644 --- a/tests/func/test_live.py +++ b/tests/func/test_live.py @@ -1,35 +1,60 @@ -import subprocess from textwrap import dedent import pytest +from dvc import stage as stage_module from dvc.exceptions import MetricsError LIVE_SCRITP = dedent( """ from dvclive import dvclive import sys - r = 5 + r = 2 for i in range(r): - dvclive.log("loss", -i/5) - dvclive.log("accuracy", i/5)""" + dvclive.log("loss", 1-i/r) + dvclive.log("accuracy", i/r) + dvclive.next_step()""" ) +@pytest.fixture +def live_stage(tmp_dir, scm, dvc): + + pytest.skip("dvclive does not exist yet") + + def make(summary=True): + tmp_dir.gen("train.py", LIVE_SCRITP) + tmp_dir.gen("params.yaml", "foo: 1") + stage = dvc.run( + cmd="python train.py", + params=["foo"], + deps=["train.py"], + name="live_stage", + live="logs", + live_summary=summary, + ) + + scm.add(["dvc.yaml", "dvc.lock", "train.py", "params.yaml"]) + scm.commit("initial: live_stage") + return stage + + yield make + + @pytest.mark.parametrize("summary", (True, False)) def test_export_config_tmp(tmp_dir, dvc, mocker, summary): - proc_spy = mocker.spy(subprocess, "Popen") + run_spy = mocker.spy(stage_module.run, "_run") tmp_dir.gen("src", "dependency") dvc.run( cmd="mkdir logs && touch logs.json", deps=["src"], name="run_logger", - live=["logs"], + live="logs", live_summary=summary, ) - assert proc_spy.call_count == 1 - _, kwargs = proc_spy.call_args + assert run_spy.call_count == 1 + _, kwargs = run_spy.call_args assert "DVCLIVE_PATH" in kwargs["env"] assert kwargs["env"]["DVCLIVE_PATH"] == "logs" @@ -38,20 +63,13 @@ def test_export_config_tmp(tmp_dir, dvc, mocker, summary): assert kwargs["env"]["DVCLIVE_SUMMARY"] == str(int(summary)) -@pytest.mark.skip(reason="dvclive does not exist yet") @pytest.mark.parametrize("summary", (True, False)) -def test_export_config(tmp_dir, dvc, mocker, summary): - proc_spy = mocker.spy(subprocess, "Popen") - tmp_dir.gen("log.py", LIVE_SCRITP.format(log_path="logs")) - dvc.run( - cmd="python log.py", - deps=["log.py"], - name="run_logger", - live=["logs"], - live_summary=summary, - ) - assert proc_spy.call_count == 1 - _, kwargs = proc_spy.call_args +def test_export_config(tmp_dir, dvc, mocker, summary, live_stage): + run_spy = mocker.spy(stage_module.run, "_run") + live_stage(summary=summary) + + assert run_spy.call_count == 1 + _, kwargs = run_spy.call_args assert "DVCLIVE_PATH" in kwargs["env"] assert kwargs["env"]["DVCLIVE_PATH"] == "logs" @@ -60,20 +78,12 @@ def test_export_config(tmp_dir, dvc, mocker, summary): assert kwargs["env"]["DVCLIVE_SUMMARY"] == str(int(summary)) -@pytest.mark.skip(reason="dvclive does not exist yet") -def test_live_provides_metrics(tmp_dir, dvc): - tmp_dir.gen("log.py", LIVE_SCRITP.format(log_path="logs")) - dvc.run( - cmd="python log.py", - deps=["log.py"], - name="run_logger", - live=["logs"], - live_summary=True, - ) +def test_live_provides_metrics(tmp_dir, dvc, live_stage): + live_stage(summary=True) assert (tmp_dir / "logs.json").is_file() assert dvc.metrics.show() == { - "": {"logs.json": {"step": 3, "loss": -0.6, "accuracy": 0.6}} + "": {"logs.json": {"step": 1, "loss": 0.5, "accuracy": 0.5}} } assert (tmp_dir / "logs").is_dir() @@ -82,16 +92,8 @@ def test_live_provides_metrics(tmp_dir, dvc): assert "logs/loss.tsv" in plots -@pytest.mark.skip(reason="dvclive does not exist yet") -def test_live_provides_no_metrics(tmp_dir, dvc): - tmp_dir.gen("log.py", LIVE_SCRITP.format(log_path="logs")) - dvc.run( - cmd="python log.py", - deps=["log.py"], - name="run_logger", - live=["logs"], - live_summary=False, - ) +def test_live_provides_no_metrics(tmp_dir, dvc, live_stage): + live_stage(summary=False) assert not (tmp_dir / "logs.json").is_file() with pytest.raises(MetricsError): @@ -101,3 +103,15 @@ def test_live_provides_no_metrics(tmp_dir, dvc): plots = dvc.plots.show() assert "logs/accuracy.tsv" in plots assert "logs/loss.tsv" in plots + + +def test_experiments_track_summary(tmp_dir, scm, dvc, live_stage): + live_stage(summary=True) + baseline_rev = scm.get_rev() + + experiments = dvc.experiments.run(targets=["live_stage"], params=["foo=2"]) + assert len(experiments) == 1 + ((exp_rev, _),) = experiments.items() + + res = dvc.experiments.show() + assert "logs.json" in res[baseline_rev][exp_rev]["metrics"].keys()