Skip to content

Commit

Permalink
tests: add tests for checkpoint and checkpoint_continue
Browse files Browse the repository at this point in the history
  • Loading branch information
pmrowla committed Oct 6, 2020
1 parent d1a9648 commit 169f5c1
Showing 1 changed file with 113 additions and 0 deletions.
113 changes: 113 additions & 0 deletions tests/func/experiments/test_experiments.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,47 @@
from textwrap import dedent

import pytest
from funcy import first

from dvc.utils.serialize import PythonFileCorruptedError
from tests.func.test_repro_multistage import COPY_SCRIPT

CHECKPOINT_SCRIPT_FORMAT = dedent(
"""\
import os
import sys
import shutil
from time import sleep
from dvc.api import make_checkpoint
checkpoint_file = {}
checkpoint_iterations = int({})
if os.path.exists(checkpoint_file):
with open(checkpoint_file) as fobj:
try:
value = int(fobj.read())
except ValueError:
value = 0
else:
with open(checkpoint_file, "w"):
pass
value = 0
shutil.copyfile({}, {})
if os.getenv("DVC_CHECKPOINT"):
for _ in range(checkpoint_iterations):
value += 1
with open(checkpoint_file, "w") as fobj:
fobj.write(str(value))
make_checkpoint()
"""
)
CHECKPOINT_SCRIPT = CHECKPOINT_SCRIPT_FORMAT.format(
"sys.argv[1]", "sys.argv[2]", "sys.argv[3]", "sys.argv[4]"
)


def test_new_simple(tmp_dir, scm, dvc, mocker):
tmp_dir.gen("copy.py", COPY_SCRIPT)
Expand Down Expand Up @@ -236,3 +275,77 @@ def test_detached_parent(tmp_dir, scm, dvc, mocker):
assert dvc.experiments.get_baseline(exp_rev) == detached_rev
assert (tmp_dir / "params.yaml").read_text().strip() == "foo: 3"
assert (tmp_dir / "metrics.yaml").read_text().strip() == "foo: 3"


def test_new_checkpoint(tmp_dir, scm, dvc, mocker):
tmp_dir.gen("checkpoint.py", CHECKPOINT_SCRIPT)
tmp_dir.gen("params.yaml", "foo: 1")
stage = dvc.run(
cmd="python checkpoint.py foo 5 params.yaml metrics.yaml",
metrics_no_cache=["metrics.yaml"],
params=["foo"],
outs_persist=["foo"],
always_changed=True,
name="checkpoint-file",
)
scm.add(
[
"dvc.yaml",
"dvc.lock",
"checkpoint.py",
"params.yaml",
"metrics.yaml",
]
)
scm.commit("init")

new_mock = mocker.spy(dvc.experiments, "new")
dvc.reproduce(
stage.addressing, experiment=True, checkpoint=True, params=["foo=2"]
)

new_mock.assert_called_once()
assert (tmp_dir / "foo").read_text() == "5"
assert (
tmp_dir / ".dvc" / "experiments" / "metrics.yaml"
).read_text().strip() == "foo: 2"


def test_continue_checkpoint(tmp_dir, scm, dvc, mocker):
tmp_dir.gen("checkpoint.py", CHECKPOINT_SCRIPT)
tmp_dir.gen("params.yaml", "foo: 1")
stage = dvc.run(
cmd="python checkpoint.py foo 5 params.yaml metrics.yaml",
metrics_no_cache=["metrics.yaml"],
params=["foo"],
outs_persist=["foo"],
always_changed=True,
name="checkpoint-file",
)
scm.add(
[
"dvc.yaml",
"dvc.lock",
"checkpoint.py",
"params.yaml",
"metrics.yaml",
]
)
scm.commit("init")

results = dvc.reproduce(
stage.addressing, experiment=True, checkpoint=True, params=["foo=2"]
)
exp_rev = first(results)

dvc.reproduce(
stage.addressing,
experiment=True,
checkpoint=True,
checkpoint_continue=exp_rev,
)

assert (tmp_dir / "foo").read_text() == "10"
assert (
tmp_dir / ".dvc" / "experiments" / "metrics.yaml"
).read_text().strip() == "foo: 2"

0 comments on commit 169f5c1

Please sign in to comment.