Skip to content

Commit

Permalink
TmpDir: add serialize utilities (#6620)
Browse files: browse the repository at this point in the history
  • Loading branch information
skshetry authored Sep 15, 2021
1 parent 7521e91 commit fbcc3f3
Show file tree
Hide file tree
Showing 36 changed files with 179 additions and 200 deletions.
7 changes: 7 additions & 0 deletions dvc/utils/serialize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@
{".toml": load_toml, ".json": load_json, ".py": load_py} # noqa: F405
)

# Registry of dump functions keyed by file suffix (e.g. ".json").
# The default factory returns ``dump_yaml``, so any suffix without an
# explicit entry below falls back to the YAML dumper.
DUMPERS: DefaultDict[str, DumperFn] = defaultdict( # noqa: F405
lambda: dump_yaml # noqa: F405
)
# Explicit dumpers for the non-YAML suffixes.
DUMPERS.update(
{".toml": dump_toml, ".json": dump_json, ".py": dump_py} # noqa: F405
)

MODIFIERS: DefaultDict[str, ModifierFn] = defaultdict( # noqa: F405
lambda: modify_yaml # noqa: F405
)
Expand Down
10 changes: 8 additions & 2 deletions dvc/utils/serialize/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,16 @@ def _load_data(path: "AnyPath", parser: ParserFn, fs: "BaseFileSystem" = None):
return parser(fd.read(), path)


def _dump_data(path, data: Any, dumper: DumperFn, fs: "BaseFileSystem" = None):
def _dump_data(
path,
data: Any,
dumper: DumperFn,
fs: "BaseFileSystem" = None,
**dumper_args,
):
open_fn = fs.open if fs else open
with open_fn(path, "w+", encoding="utf-8") as fd: # type: ignore
dumper(data, fd)
dumper(data, fd, **dumper_args)


@contextmanager
Expand Down
4 changes: 2 additions & 2 deletions dvc/utils/serialize/_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ def parse_json(text, path, **kwargs):
return json.loads(text, **kwargs) or {}


def dump_json(path, data, fs=None):
return _dump_data(path, data, dumper=json.dump, fs=fs)
def dump_json(path, data, fs=None, **kwargs):
    """Write ``data`` to ``path`` as JSON via ``_dump_data``.

    Args:
        path: destination file path.
        data: object to serialize.
        fs: optional filesystem object, forwarded to ``_dump_data``.
        **kwargs: extra keyword arguments, forwarded through
            ``_dump_data`` to ``json.dump``.

    Returns:
        Whatever ``_dump_data`` returns.
    """
    return _dump_data(path, data, fs=fs, dumper=json.dump, **kwargs)


@contextmanager
Expand Down
4 changes: 2 additions & 2 deletions dvc/utils/serialize/_toml.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def _dump(data, stream):
return toml.dump(data, stream, encoder=toml.TomlPreserveCommentEncoder())


def dump_toml(path, data, fs=None):
return _dump_data(path, data, dumper=_dump, fs=fs)
def dump_toml(path, data, fs=None, **kwargs):
    """Write ``data`` to ``path`` as TOML via ``_dump_data``.

    Args:
        path: destination file path.
        data: object to serialize.
        fs: optional filesystem object, forwarded to ``_dump_data``.
        **kwargs: extra keyword arguments, forwarded through
            ``_dump_data`` to this module's ``_dump`` helper.

    Returns:
        Whatever ``_dump_data`` returns.
    """
    # NOTE(review): the ``_dump`` signature visible in this diff is
    # ``_dump(data, stream)``; confirm it accepts the forwarded keywords
    # before passing any.
    return _dump_data(path, data, fs=fs, dumper=_dump, **kwargs)


@contextmanager
Expand Down
4 changes: 2 additions & 2 deletions dvc/utils/serialize/_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def _dump(data, stream):
return yaml.dump(data, stream)


def dump_yaml(path, data, fs=None):
return _dump_data(path, data, dumper=_dump, fs=fs)
def dump_yaml(path, data, fs=None, **kwargs):
    """Write ``data`` to ``path`` as YAML via ``_dump_data``.

    Args:
        path: destination file path.
        data: object to serialize.
        fs: optional filesystem object, forwarded to ``_dump_data``.
        **kwargs: extra keyword arguments, forwarded through
            ``_dump_data`` to this module's ``_dump`` helper.

    Returns:
        Whatever ``_dump_data`` returns.
    """
    # NOTE(review): the ``_dump`` signature visible in this diff is
    # ``_dump(data, stream)``; confirm it accepts the forwarded keywords
    # before passing any.
    return _dump_data(path, data, fs=fs, dumper=_dump, **kwargs)


def loads_yaml(s, typ="safe"):
Expand Down
18 changes: 18 additions & 0 deletions tests/dir_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,14 @@
import pathlib
import sys
from contextlib import contextmanager
from functools import partialmethod
from textwrap import dedent

import pytest
from funcy import lmap, retry

from dvc.logger import disable_other_loggers
from dvc.utils import serialize
from dvc.utils.fs import makedirs

__all__ = [
Expand Down Expand Up @@ -250,6 +252,22 @@ def read_text(self, *args, **kwargs): # pylint: disable=signature-differs
def hash_to_path_info(self, hash_):
return self / hash_[0:2] / hash_[2:]

def dump(self, *args, **kwargs):
    """Call the dumper registered for this path's suffix.

    The function is looked up in ``serialize.DUMPERS`` by ``self.suffix``
    and called with this path first, followed by the given positional and
    keyword arguments. Its return value is passed back unchanged.
    """
    dumper = serialize.DUMPERS[self.suffix]
    return dumper(self, *args, **kwargs)

def parse(self, *args, **kwargs):
    """Call the loader registered for this path's suffix.

    The function is looked up in ``serialize.LOADERS`` by ``self.suffix``
    and called with this path first, followed by the given positional and
    keyword arguments. Its return value is passed back unchanged.
    """
    loader = serialize.LOADERS[self.suffix]
    return loader(self, *args, **kwargs)

def modify(self, *args, **kwargs):
    """Call the modifier registered for this path's suffix.

    The function is looked up in ``serialize.MODIFIERS`` by ``self.suffix``
    and called with this path first, followed by the given positional and
    keyword arguments. Its return value is passed back unchanged.
    """
    modifier = serialize.MODIFIERS[self.suffix]
    return modifier(self, *args, **kwargs)

# Format-specific shortcuts. Each attribute wraps the matching
# ``dvc.utils.serialize`` function in ``partialmethod``, so calling it on an
# instance passes that instance as the function's first positional argument.
# Unlike ``dump``/``parse``/``modify``, these do not consult the path's
# suffix: the format is fixed by the attribute name.
load_yaml = partialmethod(serialize.load_yaml)
dump_yaml = partialmethod(serialize.dump_yaml)
load_json = partialmethod(serialize.load_json)
dump_json = partialmethod(serialize.dump_json)
load_toml = partialmethod(serialize.load_toml)
dump_toml = partialmethod(serialize.dump_toml)


def _coerce_filenames(filenames):
if isinstance(filenames, (str, bytes, pathlib.PurePath)):
Expand Down
6 changes: 3 additions & 3 deletions tests/func/experiments/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from dvc.dvcfile import PIPELINE_FILE
from dvc.repo.experiments.utils import exp_refs_by_rev
from dvc.utils.serialize import PythonFileCorruptedError, load_yaml
from dvc.utils.serialize import PythonFileCorruptedError
from tests.func.test_repro_multistage import COPY_SCRIPT


Expand Down Expand Up @@ -660,9 +660,9 @@ def test_modified_data_dep(tmp_dir, scm, dvc, workspace, params, target):


def test_exp_run_recursive(tmp_dir, scm, dvc, run_copy_metrics):
tmp_dir.dvc_gen("metric_t.json", "foo: 1")
tmp_dir.dvc_gen("metric_t.json", '{"foo": 1}')
run_copy_metrics(
"metric_t.json", "metric.json", metrics=["metric.json"], no_exec=True
)
assert dvc.experiments.run(".", recursive=True)
assert load_yaml(tmp_dir / "metric.json") == {"foo": 1}
assert (tmp_dir / "metric.json").parse() == {"foo": 1}
10 changes: 5 additions & 5 deletions tests/func/experiments/test_show.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from dvc.repo.experiments.executor.base import BaseExecutor, ExecutorInfo
from dvc.repo.experiments.utils import exp_refs_by_rev
from dvc.utils.fs import makedirs
from dvc.utils.serialize import YAMLFileCorruptedError, dump_yaml
from dvc.utils.serialize import YAMLFileCorruptedError
from tests.func.test_repro_multistage import COPY_SCRIPT


Expand Down Expand Up @@ -286,7 +286,7 @@ def test_show_filter(
"train/bar": 1,
"nested": {"foo": 1, "bar": 1},
}
dump_yaml(params_file, params_data)
(tmp_dir / params_file).dump(params_data)

dvc.run(
cmd="python copy.py params.yaml metrics.yaml",
Expand Down Expand Up @@ -386,7 +386,7 @@ def test_show_running_workspace(tmp_dir, scm, dvc, exp_stage, capsys):
makedirs(pid_dir, True)
info = ExecutorInfo(None, None, None, BaseExecutor.DEFAULT_LOCATION)
pidfile = os.path.join(pid_dir, f"workspace{BaseExecutor.PIDFILE_EXT}")
dump_yaml(pidfile, info.to_dict())
(tmp_dir / pidfile).dump(info.to_dict())

assert dvc.experiments.show()["workspace"] == {
"baseline": {
Expand Down Expand Up @@ -417,7 +417,7 @@ def test_show_running_executor(tmp_dir, scm, dvc, exp_stage):
makedirs(pid_dir, True)
info = ExecutorInfo(None, None, None, BaseExecutor.DEFAULT_LOCATION)
pidfile = os.path.join(pid_dir, f"{exp_rev}{BaseExecutor.PIDFILE_EXT}")
dump_yaml(pidfile, info.to_dict())
(tmp_dir / pidfile).dump(info.to_dict())

results = dvc.experiments.show()
exp_data = get_in(results, [baseline_rev, exp_rev, "data"])
Expand Down Expand Up @@ -455,7 +455,7 @@ def test_show_running_checkpoint(
info = ExecutorInfo(123, "foo.git", baseline_rev, executor)
rev = "workspace" if workspace else stash_rev
pidfile = os.path.join(pid_dir, f"{rev}{BaseExecutor.PIDFILE_EXT}")
dump_yaml(pidfile, info.to_dict())
(tmp_dir / pidfile).dump(info.to_dict())

mocker.patch.object(
BaseExecutor, "fetch_exps", return_value=[str(exp_ref)]
Expand Down
13 changes: 6 additions & 7 deletions tests/func/metrics/test_diff.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import json

from dvc.main import main
from dvc.utils.serialize import dump_yaml


def test_metrics_diff_simple(tmp_dir, scm, dvc, run_copy_metrics):
Expand All @@ -22,7 +21,7 @@ def _gen(val):
def test_metrics_diff_yaml(tmp_dir, scm, dvc, run_copy_metrics):
def _gen(val):
metrics = {"a": {"b": {"c": val, "d": 1, "e": str(val)}}}
dump_yaml("m_temp.yaml", metrics)
(tmp_dir / "m_temp.yaml").dump(metrics)
run_copy_metrics(
"m_temp.yaml", "m.yaml", metrics=["m.yaml"], commit=str(val)
)
Expand All @@ -39,7 +38,7 @@ def _gen(val):
def test_metrics_diff_json(tmp_dir, scm, dvc, run_copy_metrics):
def _gen(val):
metrics = {"a": {"b": {"c": val, "d": 1, "e": str(val)}}}
tmp_dir.gen({"m_temp.json": json.dumps(metrics)})
(tmp_dir / "m_temp.json").dump(metrics)
run_copy_metrics(
"m_temp.json", "m.json", metrics=["m.json"], commit=str(val)
)
Expand All @@ -55,7 +54,7 @@ def _gen(val):
def test_metrics_diff_json_unchanged(tmp_dir, scm, dvc, run_copy_metrics):
def _gen(val):
metrics = {"a": {"b": {"c": val, "d": 1, "e": str(val)}}}
tmp_dir.gen({"m_temp.json": json.dumps(metrics)})
(tmp_dir / "m_temp.json").dump(metrics)
run_copy_metrics(
"m_temp.json", "m.json", metrics=["m.json"], commit=str(val)
)
Expand All @@ -69,7 +68,7 @@ def _gen(val):

def test_metrics_diff_broken_json(tmp_dir, scm, dvc, run_copy_metrics):
metrics = {"a": {"b": {"c": 1, "d": 1, "e": "3"}}}
tmp_dir.gen({"m_temp.json": json.dumps(metrics)})
(tmp_dir / "m_temp.json").dump(metrics)
run_copy_metrics(
"m_temp.json",
"m.json",
Expand All @@ -94,7 +93,7 @@ def test_metrics_diff_no_metrics(tmp_dir, scm, dvc):

def test_metrics_diff_new_metric(tmp_dir, scm, dvc, run_copy_metrics):
metrics = {"a": {"b": {"c": 1, "d": 1, "e": "3"}}}
tmp_dir.gen({"m_temp.json": json.dumps(metrics)})
(tmp_dir / "m_temp.json").dump(metrics)
run_copy_metrics("m_temp.json", "m.json", metrics_no_cache=["m.json"])

assert dvc.metrics.diff() == {
Expand All @@ -107,7 +106,7 @@ def test_metrics_diff_new_metric(tmp_dir, scm, dvc, run_copy_metrics):

def test_metrics_diff_deleted_metric(tmp_dir, scm, dvc, run_copy_metrics):
metrics = {"a": {"b": {"c": 1, "d": 1, "e": "3"}}}
tmp_dir.gen({"m_temp.json": json.dumps(metrics)})
(tmp_dir / "m_temp.json").dump(metrics)
run_copy_metrics(
"m_temp.json",
"m.json",
Expand Down
11 changes: 5 additions & 6 deletions tests/func/metrics/test_show.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@

from dvc.dvcfile import PIPELINE_FILE
from dvc.exceptions import OverlappingOutputPathsError
from dvc.path_info import PathInfo
from dvc.repo import Repo
from dvc.utils.fs import remove
from dvc.utils.serialize import YAMLFileCorruptedError, dump_yaml, modify_yaml
from dvc.utils.serialize import YAMLFileCorruptedError


def test_show_simple(tmp_dir, dvc, run_copy_metrics):
Expand Down Expand Up @@ -227,10 +226,10 @@ def test_show_no_metrics_files(tmp_dir, dvc, caplog):
def test_metrics_show_overlap(
tmp_dir, dvc, run_copy_metrics, clear_before_run
):
data_dir = PathInfo("data")
(tmp_dir / data_dir).mkdir()
data_dir = tmp_dir / "data"
data_dir.mkdir()

dump_yaml(data_dir / "m1_temp.yaml", {"a": {"b": {"c": 2, "d": 1}}})
(data_dir / "m1_temp.yaml").dump({"a": {"b": {"c": 2, "d": 1}}})
run_copy_metrics(
str(data_dir / "m1_temp.yaml"),
str(data_dir / "m1.yaml"),
Expand All @@ -239,7 +238,7 @@ def test_metrics_show_overlap(
name="cp-m1",
metrics=[str(data_dir / "m1.yaml")],
)
with modify_yaml("dvc.yaml") as d:
with (tmp_dir / "dvc.yaml").modify() as d:
# trying to make an output overlaps error
d["stages"]["corrupted-stage"] = {
"cmd": "mkdir data",
Expand Down
7 changes: 3 additions & 4 deletions tests/func/params/test_diff.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from dvc.utils import relpath
from dvc.utils.serialize import dump_yaml


def test_diff_no_params(tmp_dir, scm, dvc):
Expand Down Expand Up @@ -140,7 +139,7 @@ def test_no_commits(tmp_dir):
def test_vars_shows_on_params_diff(tmp_dir, scm, dvc):
params_file = tmp_dir / "test_params.yaml"
param_data = {"vars": {"model1": {"epoch": 15}, "model2": {"epoch": 35}}}
dump_yaml(params_file, param_data)
(tmp_dir / params_file).dump(param_data)
d = {
"vars": ["test_params.yaml"],
"stages": {
Expand All @@ -150,7 +149,7 @@ def test_vars_shows_on_params_diff(tmp_dir, scm, dvc):
}
},
}
dump_yaml("dvc.yaml", d)
(tmp_dir / "dvc.yaml").dump(d)
assert dvc.params.diff() == {
"test_params.yaml": {
"vars.model1.epoch": {"new": 15, "old": None},
Expand All @@ -161,7 +160,7 @@ def test_vars_shows_on_params_diff(tmp_dir, scm, dvc):
scm.commit("added stages")

param_data["vars"]["model1"]["epoch"] = 20
dump_yaml(params_file, param_data)
(tmp_dir / params_file).dump(param_data)
assert dvc.params.diff() == {
"test_params.yaml": {
"vars.model1.epoch": {"new": 20, "old": 15, "diff": 5}
Expand Down
5 changes: 2 additions & 3 deletions tests/func/parsing/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from dvc.parsing.context import Context
from dvc.parsing.interpolate import embrace
from dvc.utils.humanize import join
from dvc.utils.serialize import dump_yaml

from . import make_entry_definition, make_foreach_def

Expand Down Expand Up @@ -133,7 +132,7 @@ def test_interpolate_non_string(tmp_dir, dvc):


def test_partial_vars_doesnot_exist(tmp_dir, dvc):
dump_yaml("test_params.yaml", {"sub1": "sub1", "sub2": "sub2"})
(tmp_dir / "test_params.yaml").dump({"sub1": "sub1", "sub2": "sub2"})

definition = make_entry_definition(
tmp_dir,
Expand Down Expand Up @@ -281,7 +280,7 @@ def test_item_key_in_generated_stage_vars(tmp_dir, dvc, redefine, from_file):
context = Context(foo="bar")
vars_ = [redefine]
if from_file:
dump_yaml("test_params.yaml", redefine)
(tmp_dir / "test_params.yaml").dump(redefine)
vars_ = ["test_params.yaml"]

definition = make_foreach_def(
Expand Down
Loading

0 comments on commit fbcc3f3

Please sign in to comment.