From 169f5c1f0ed6ac58a0a1303d0808e6a9c4e89216 Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Tue, 6 Oct 2020 18:09:19 +0900 Subject: [PATCH] tests: add tests for checkpoint and checkpoint_continue --- tests/func/experiments/test_experiments.py | 113 +++++++++++++++++++++ 1 file changed, 113 insertions(+) diff --git a/tests/func/experiments/test_experiments.py b/tests/func/experiments/test_experiments.py index 88a24086a8..6688c6d642 100644 --- a/tests/func/experiments/test_experiments.py +++ b/tests/func/experiments/test_experiments.py @@ -1,8 +1,47 @@ +from textwrap import dedent + import pytest +from funcy import first from dvc.utils.serialize import PythonFileCorruptedError from tests.func.test_repro_multistage import COPY_SCRIPT +CHECKPOINT_SCRIPT_FORMAT = dedent( + """\ + import os + import sys + import shutil + from time import sleep + + from dvc.api import make_checkpoint + + checkpoint_file = {} + checkpoint_iterations = int({}) + if os.path.exists(checkpoint_file): + with open(checkpoint_file) as fobj: + try: + value = int(fobj.read()) + except ValueError: + value = 0 + else: + with open(checkpoint_file, "w"): + pass + value = 0 + + shutil.copyfile({}, {}) + + if os.getenv("DVC_CHECKPOINT"): + for _ in range(checkpoint_iterations): + value += 1 + with open(checkpoint_file, "w") as fobj: + fobj.write(str(value)) + make_checkpoint() +""" +) +CHECKPOINT_SCRIPT = CHECKPOINT_SCRIPT_FORMAT.format( + "sys.argv[1]", "sys.argv[2]", "sys.argv[3]", "sys.argv[4]" +) + def test_new_simple(tmp_dir, scm, dvc, mocker): tmp_dir.gen("copy.py", COPY_SCRIPT) @@ -236,3 +275,77 @@ def test_detached_parent(tmp_dir, scm, dvc, mocker): assert dvc.experiments.get_baseline(exp_rev) == detached_rev assert (tmp_dir / "params.yaml").read_text().strip() == "foo: 3" assert (tmp_dir / "metrics.yaml").read_text().strip() == "foo: 3" + + +def test_new_checkpoint(tmp_dir, scm, dvc, mocker): + tmp_dir.gen("checkpoint.py", CHECKPOINT_SCRIPT) + tmp_dir.gen("params.yaml", "foo: 1") + stage = dvc.run( + cmd="python checkpoint.py foo 5 params.yaml metrics.yaml", + metrics_no_cache=["metrics.yaml"], + params=["foo"], + outs_persist=["foo"], + always_changed=True, + name="checkpoint-file", + ) + scm.add( + [ + "dvc.yaml", + "dvc.lock", + "checkpoint.py", + "params.yaml", + "metrics.yaml", + ] + ) + scm.commit("init") + + new_mock = mocker.spy(dvc.experiments, "new") + dvc.reproduce( + stage.addressing, experiment=True, checkpoint=True, params=["foo=2"] + ) + + new_mock.assert_called_once() + assert (tmp_dir / "foo").read_text() == "5" + assert ( + tmp_dir / ".dvc" / "experiments" / "metrics.yaml" + ).read_text().strip() == "foo: 2" + + +def test_continue_checkpoint(tmp_dir, scm, dvc, mocker): + tmp_dir.gen("checkpoint.py", CHECKPOINT_SCRIPT) + tmp_dir.gen("params.yaml", "foo: 1") + stage = dvc.run( + cmd="python checkpoint.py foo 5 params.yaml metrics.yaml", + metrics_no_cache=["metrics.yaml"], + params=["foo"], + outs_persist=["foo"], + always_changed=True, + name="checkpoint-file", + ) + scm.add( + [ + "dvc.yaml", + "dvc.lock", + "checkpoint.py", + "params.yaml", + "metrics.yaml", + ] + ) + scm.commit("init") + + results = dvc.reproduce( + stage.addressing, experiment=True, checkpoint=True, params=["foo=2"] + ) + exp_rev = first(results) + + dvc.reproduce( + stage.addressing, + experiment=True, + checkpoint=True, + checkpoint_continue=exp_rev, + ) + + assert (tmp_dir / "foo").read_text() == "10" + assert ( + tmp_dir / ".dvc" / "experiments" / "metrics.yaml" + ).read_text().strip() == "foo: 2"