dvc: introduce local stage cache #3603

Merged 1 commit on Apr 29, 2020
3 changes: 3 additions & 0 deletions dvc/repo/__init__.py
@@ -72,6 +72,7 @@ def __init__(self, root_dir=None):
         from dvc.repo.params import Params
         from dvc.scm.tree import WorkingTree
         from dvc.utils.fs import makedirs
+        from dvc.stage.cache import StageCache

         root_dir = self.find_root(root_dir)

@@ -104,6 +105,8 @@ def __init__(self, root_dir=None):
         self.cache = Cache(self)
         self.cloud = DataCloud(self)

+        self.stage_cache = StageCache(self.cache.local.cache_dir)
+
         self.metrics = Metrics(self)
         self.params = Params(self)
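With this change, each Repo instance carries a stage_cache handle rooted inside the local object cache. A minimal sketch of what that exposes, assuming a default .dvc/cache layout (the printed paths are illustrative):

    from dvc.repo import Repo

    repo = Repo(".")  # assumes the cwd is inside a DVC repository
    print(repo.cache.local.cache_dir)  # e.g. <repo>/.dvc/cache
    print(repo.stage_cache.cache_dir)  # e.g. <repo>/.dvc/cache/stages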
9 changes: 2 additions & 7 deletions dvc/repo/reproduce.py
@@ -97,12 +97,7 @@ def reproduce(


 def _reproduce_stages(
-    G,
-    stages,
-    downstream=False,
-    ignore_build_cache=False,
-    single_item=False,
-    **kwargs
+    G, stages, downstream=False, single_item=False, **kwargs
 ):
     r"""Derive the evaluation of the given node for the given graph.

@@ -172,7 +167,7 @@ def _reproduce_stages(
         try:
             ret = _reproduce_stage(stage, **kwargs)

-            if len(ret) != 0 and ignore_build_cache:
+            if len(ret) != 0 and kwargs.get("ignore_build_cache", False):
                 # NOTE: we are walking our pipeline from the top to the
                 # bottom. If one stage is changed, it will be reproduced,
                 # which tells us that we should force reproducing all of
5 changes: 4 additions & 1 deletion dvc/repo/run.py
@@ -68,6 +68,9 @@ def run(self, fname=None, no_exec=False, **kwargs):
         raise OutputDuplicationError(exc.output, set(exc.stages) - {stage})

     if not no_exec:
-        stage.run(no_commit=kwargs.get("no_commit", False))
+        stage.run(
+            no_commit=kwargs.get("no_commit", False),
+            ignore_build_cache=kwargs.get("ignore_build_cache", False),
+        )
     dvcfile.dump(stage, update_pipeline=True)
     return stage
12 changes: 10 additions & 2 deletions dvc/stage/__init__.py
@@ -491,6 +491,8 @@ def save(self):

         self.md5 = self._compute_md5()

+        self.repo.stage_cache.save(self)
+
     @staticmethod
     def _changed_entries(entries):
         return [
@@ -617,7 +619,9 @@ def _run(self):
             raise StageCmdFailedError(self)

     @rwlocked(read=["deps"], write=["outs"])
-    def run(self, dry=False, no_commit=False, force=False):
+    def run(
+        self, dry=False, no_commit=False, force=False, ignore_build_cache=False
+    ):
         if (self.cmd or self.is_import) and not self.locked and not dry:
             self.remove_outs(ignore_remove=False, force=False)

@@ -650,16 +654,20 @@ def run(
                 self.check_missing_outputs()

         else:
-            logger.info("Running command:\n\t{}".format(self.cmd))
             if not dry:
+                if not force and not ignore_build_cache:
+                    self.repo.stage_cache.restore(self)
+
                 if (
                     not force
                     and not self.is_callback
                     and not self.always_changed
                     and self._already_cached()
                 ):
+                    logger.info("Stage is cached, skipping.")
                     self.checkout()
                 else:
+                    logger.info("Running command:\n\t{}".format(self.cmd))
                     self._run()

         if not dry:
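The net effect on Stage.run: unless the caller forces execution or passes ignore_build_cache, the stage first restores dep/out checksums from the stage cache, which lets _already_cached() short-circuit into a checkout instead of re-running the command. A hedged sketch of the three paths, assuming the dvc fixture and locking used in the unit test further down:

    # assumes `stage` ran before and left an entry in the stage cache
    with dvc.lock, dvc.state:
        stage.run()  # restores checksums, detects the hit, checks out
        stage.run(force=True)  # bypasses both caches, re-executes the command
        stage.run(ignore_build_cache=True)  # skips only the stage-cache restore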
124 changes: 124 additions & 0 deletions dvc/stage/cache.py
@@ -0,0 +1,124 @@
+import os
+import yaml
+import logging
+
+from voluptuous import Schema, Required, Invalid
+
+from dvc.utils.fs import makedirs
+from dvc.utils import relpath, dict_sha256
+
+logger = logging.getLogger(__name__)
+
+SCHEMA = Schema(
+    {
+        Required("cmd"): str,
+        Required("deps"): {str: str},
+        Required("outs"): {str: str},
+    }
+)
+
+
+def _get_cache_hash(cache, key=False):
+    return dict_sha256(
+        {
+            "cmd": cache["cmd"],
+            "deps": cache["deps"],
+            "outs": list(cache["outs"].keys()) if key else cache["outs"],
+        }
+    )
+
+
+def _get_stage_hash(stage):
+    if not stage.cmd or not stage.deps or not stage.outs:
+        return None
+
+    for dep in stage.deps:
+        if dep.scheme != "local" or not dep.def_path or not dep.get_checksum():
+            return None
+
+    for out in stage.outs:
+        if out.scheme != "local" or not out.def_path or out.persist:
+            return None
+
+    return _get_cache_hash(_create_cache(stage), key=True)
+
+
+def _create_cache(stage):
+    return {
+        "cmd": stage.cmd,
+        "deps": {dep.def_path: dep.get_checksum() for dep in stage.deps},
+        "outs": {out.def_path: out.get_checksum() for out in stage.outs},
+    }
+
+
+class StageCache:
+    def __init__(self, cache_dir):
+        self.cache_dir = os.path.join(cache_dir, "stages")
+
+    def _get_cache_dir(self, key):
+        return os.path.join(self.cache_dir, key[:2], key)
+
+    def _get_cache_path(self, key, value):
+        return os.path.join(self._get_cache_dir(key), value)
+
+    def _load_cache(self, key, value):
+        path = self._get_cache_path(key, value)
+
+        try:
+            with open(path, "r") as fobj:
+                return SCHEMA(yaml.safe_load(fobj))
+        except FileNotFoundError:
+            return None
+        except (yaml.error.YAMLError, Invalid):
+            logger.warning("corrupted cache file '%s'.", relpath(path))
+            os.unlink(path)
+            return None
+
+    def _load(self, stage):
+        key = _get_stage_hash(stage)
+        if not key:
+            return None
+
+        cache_dir = self._get_cache_dir(key)
+        if not os.path.exists(cache_dir):
+            return None
+
+        for value in os.listdir(cache_dir):
+            cache = self._load_cache(key, value)
+            if cache:
+                return cache
+
+        return None
+
+    def save(self, stage):
+        cache_key = _get_stage_hash(stage)
+        if not cache_key:
+            return
+
+        cache = _create_cache(stage)
+        cache_value = _get_cache_hash(cache)
+
+        if self._load_cache(cache_key, cache_value):
+            return
+
+        # sanity check
+        SCHEMA(cache)
+
+        path = self._get_cache_path(cache_key, cache_value)
+        dpath = os.path.dirname(path)
+        makedirs(dpath, exist_ok=True)
+        with open(path, "w+") as fobj:
+            yaml.dump(cache, fobj)
+
+    def restore(self, stage):
+        cache = self._load(stage)
+        if not cache:
+            return
+
+        deps = {dep.def_path: dep for dep in stage.deps}
+        for def_path, checksum in cache["deps"].items():
+            deps[def_path].checksum = checksum
+
+        outs = {out.def_path: out for out in stage.outs}
+        for def_path, checksum in cache["outs"].items():
+            outs[def_path].checksum = checksum
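For orientation, a small sketch (not part of the diff) of how an entry's location is derived. It mirrors _get_cache_hash above: the directory key hashes the cmd, the dep checksums, and the output names only, so a nondeterministic command that yields different output contents still maps to the same directory, while the file name hashes the full entry including output checksums. The cmd and checksums below are illustrative:

    import hashlib
    import json

    def dict_sha256(d):
        # same idea as dvc.utils.dict_sha256: canonical JSON, then sha256
        byts = json.dumps(d, sort_keys=True).encode("utf-8")
        return hashlib.sha256(byts).hexdigest()

    # an entry shaped the way _create_cache builds it
    cache = {
        "cmd": "python train.py",
        "deps": {"data.csv": "2254342becceafbd04538e0a38696791"},
        "outs": {"model.pkl": "f75b8179e4bbe7e2b4a074dcef62de95"},
    }

    key = dict_sha256({**cache, "outs": list(cache["outs"])})  # key=True path
    value = dict_sha256(cache)  # full entry, output checksums included

    # entries land under <local cache>/stages/<key[:2]>/<key>/<value>
    print("stages/{}/{}/{}".format(key[:2], key, value))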
16 changes: 12 additions & 4 deletions dvc/utils/__init__.py
@@ -76,8 +76,8 @@ def file_md5(fname):
     return (None, None)


-def bytes_md5(byts):
-    hasher = hashlib.md5()
+def bytes_hash(byts, typ):
+    hasher = getattr(hashlib, typ)()
     hasher.update(byts)
     return hasher.hexdigest()

@@ -100,10 +100,18 @@ def dict_filter(d, exclude=()):
     return d


-def dict_md5(d, exclude=()):
+def dict_hash(d, typ, exclude=()):
     filtered = dict_filter(d, exclude)
     byts = json.dumps(filtered, sort_keys=True).encode("utf-8")
-    return bytes_md5(byts)
+    return bytes_hash(byts, typ)
+
+
+def dict_md5(d, **kwargs):
+    return dict_hash(d, "md5", **kwargs)
+
+
+def dict_sha256(d, **kwargs):
+    return dict_hash(d, "sha256", **kwargs)


 def _split(list_to_split, chunk_size):
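Since dict_hash serializes with json.dumps(..., sort_keys=True), the digest depends only on the mapping's contents, never on insertion order — the property that makes these hashes usable as stable stage-cache keys. A quick illustrative check (not part of the diff):

    from dvc.utils import dict_md5, dict_sha256

    a = {"cmd": "echo hi", "deps": {}, "outs": {}}
    b = {"outs": {}, "deps": {}, "cmd": "echo hi"}  # same contents, reordered

    # canonical serialization makes the two digests identical
    assert dict_md5(a) == dict_md5(b)
    assert dict_sha256(a) == dict_sha256(b)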
2 changes: 2 additions & 0 deletions tests/func/test_gc.py
@@ -1,4 +1,5 @@
 import logging
+import shutil
 import os

 import configobj
@@ -341,6 +342,7 @@ def test_gc_not_collect_pipeline_tracked_files(tmp_dir, dvc, run_copy):
     tmp_dir.gen("bar", "bar")

     run_copy("foo", "foo2", name="copy")
+    shutil.rmtree(dvc.stage_cache.cache_dir)
     assert _count_files(dvc.cache.local.cache_dir) == 1
     dvc.gc(workspace=True, force=True)
     assert _count_files(dvc.cache.local.cache_dir) == 1
4 changes: 3 additions & 1 deletion tests/func/test_repro.py
@@ -1295,7 +1295,9 @@ def test(self):
             ["repro", self._get_stage_target(self.stage), "--no-commit"]
         )
         self.assertEqual(ret, 0)
-        self.assertFalse(os.path.exists(self.dvc.cache.local.cache_dir))
+        self.assertEqual(
+            os.listdir(self.dvc.cache.local.cache_dir), ["stages"]
+        )


 class TestReproAlreadyCached(TestRepro):
39 changes: 37 additions & 2 deletions tests/unit/test_stage.py
@@ -1,3 +1,4 @@
+import os
 import signal
 import subprocess
 import threading
@@ -51,8 +52,6 @@ def test_meta_ignored():

 class TestPathConversion(TestCase):
     def test(self):
-        import os
-
         stage = Stage(None, "path")

         stage.wdir = os.path.join("..", "..")
@@ -103,3 +102,39 @@ def test_always_changed(dvc):
     with dvc.lock:
         assert stage.changed()
         assert stage.status()["path"] == ["always changed"]
+
+
+def test_stage_cache(tmp_dir, dvc, run_copy, mocker):
+    tmp_dir.gen("dep", "dep")
+    stage = run_copy("dep", "out")
+
+    with dvc.lock, dvc.state:
+        stage.remove(remove_outs=True, force=True)
+
+    assert not (tmp_dir / "out").exists()
+    assert not (tmp_dir / "out.dvc").exists()
+
+    cache_dir = os.path.join(
+        dvc.stage_cache.cache_dir,
+        "ec",
+        "ec5b6d8dea9136dbb62d93a95c777f87e6c54b0a6bee839554acb99fdf23d2b1",
+    )
+    cache_file = os.path.join(
+        cache_dir,
+        "09f9eb17fdb1ee7f8566b3c57394cee060eaf28075244bc6058612ac91fdf04a",
+    )
+
+    assert os.path.isdir(cache_dir)
+    assert os.listdir(cache_dir) == [os.path.basename(cache_file)]
+    assert os.path.isfile(cache_file)
+
+    run_spy = mocker.spy(stage, "_run")
+    checkout_spy = mocker.spy(stage, "checkout")
+    with dvc.lock, dvc.state:
+        stage.run()
+
+    assert not run_spy.called
+    assert checkout_spy.call_count == 1
+
+    assert (tmp_dir / "out").exists()
+    assert (tmp_dir / "out").read_text() == "dep"
26 changes: 26 additions & 0 deletions tests/unit/utils/test_utils.py
@@ -6,6 +6,7 @@
 from dvc.path_info import PathInfo
 from dvc.utils import (
     file_md5,
+    dict_sha256,
     resolve_output,
     fix_env,
     relpath,
@@ -155,3 +156,28 @@ def test_hint_on_lockfile():
     with pytest.raises(Exception) as exc:
         assert parse_target("pipelines.lock:name")
     assert "pipelines.yaml:name" in str(exc.value)
+
+
+@pytest.mark.parametrize(
+    "d,sha",
+    [
+        (
+            {
+                "cmd": "echo content > out",
+                "deps": {"dep": "2254342becceafbd04538e0a38696791"},
+                "outs": {"out": "f75b8179e4bbe7e2b4a074dcef62de95"},
+            },
+            "f472eda60f09660a4750e8b3208cf90b3a3b24e5f42e0371d829710e9464d74a",
+        ),
+        (
+            {
+                "cmd": "echo content > out",
+                "deps": {"dep": "2254342becceafbd04538e0a38696791"},
+                "outs": ["out"],
+            },
+            "a239b67073bd58affcdb81fff3305d1726c6e7f9c86f3d4fca0e92e8147dc7b0",
+        ),
+    ],
+)
+def test_dict_sha256(d, sha):
+    assert dict_sha256(d) == sha