Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dump yaml: .dvc and dvc.lock in 1.1 version, dvc.yaml in 1.2 version #4380

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions dvc/dvcfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
from dvc.stage.loader import SingleStageLoader, StageLoader
from dvc.utils import relpath
from dvc.utils.collections import apply_diff
from dvc.utils.yaml import dump_yaml, parse_yaml, parse_yaml_for_update
from dvc.utils.yaml import (
YAMLVersion,
dump_yaml,
parse_yaml,
parse_yaml_for_update,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -176,7 +181,9 @@ def _dump_pipeline_file(self, stage):
data = {}
if self.exists():
with open(self.path) as fd:
data = parse_yaml_for_update(fd.read(), self.path)
data = parse_yaml_for_update(
fd.read(), self.path, version=YAMLVersion.V12
)
else:
logger.info("Creating '%s'", self.relpath)
open(self.path, "w+").close()
Expand All @@ -196,7 +203,7 @@ def _dump_pipeline_file(self, stage):
else:
data["stages"].update(stage_data)

dump_yaml(self.path, data)
dump_yaml(self.path, data, version=YAMLVersion.V12)
self.repo.scm.track_file(self.relpath)

@property
Expand Down
124 changes: 101 additions & 23 deletions dvc/utils/yaml.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from collections import OrderedDict

from funcy import reraise
from ruamel.yaml import YAML
from ruamel.yaml.emitter import Emitter
from ruamel.yaml.error import YAMLError
from ruamel.yaml.events import DocumentStartEvent

from dvc.exceptions import YAMLFileCorruptedError

Expand All @@ -11,21 +14,43 @@
from yaml import SafeLoader


def load_yaml(path):
with open(path, encoding="utf-8") as fd:
return parse_yaml(fd.read(), path)
class YAMLVersion:
    """YAML specification versions, spelled as ``(major, minor)`` tuples."""

    # Plain tuples so they can be assigned straight to a parser/dumper
    # ``version`` attribute and compared with ``==``.
    V12 = (1, 2)
    V11 = (1, 1)


def _ruamel_load(text, path, *, version=None, typ="safe"):
    """Parse ``text`` with a ``ruamel.yaml`` parser of the given ``typ``.

    When ``version`` is ``None`` or ``YAMLVersion.V11`` the parser's
    ``version`` is pinned to 1.1; for any other value it is left unset.

    Returns the parsed data, or an empty dict when the result is falsy.
    A ``YAMLError`` raised while parsing is re-raised as
    ``YAMLFileCorruptedError(path)``.
    """
    parser = YAML(typ=typ)
    if version is None or version == YAMLVersion.V11:
        parser.version = YAMLVersion.V11

    with reraise(YAMLError, YAMLFileCorruptedError(path)):
        loaded = parser.load(text)
        return loaded or {}


def parse_yaml(text, path):
try:
import yaml
def _parse_yaml_v1_1(text, path):
    """Parse ``text`` with PyYAML's safe loader.

    Returns the parsed data, or an empty dict when the result is falsy.
    PyYAML errors are re-raised as ``YAMLFileCorruptedError(path)``.
    """
    # Imported here rather than at module level, as in the original.
    import yaml

    corrupted = YAMLFileCorruptedError(path)
    with reraise(yaml.error.YAMLError, corrupted):
        parsed = yaml.load(text, Loader=SafeLoader)
        return parsed or {}
except yaml.error.YAMLError as exc:
raise YAMLFileCorruptedError(path) from exc


def parse_yaml_for_update(text, path):
def _parse_yaml_v1_2(text, path):
    """Parse ``text`` via ``_ruamel_load`` with YAML 1.2 requested."""
    requested = YAMLVersion.V12
    return _ruamel_load(text, path, version=requested)


def parse_yaml(text, path, *, version=None):
    """Parse ``text`` with the parser that matches ``version``.

    ``YAMLVersion.V12`` selects the 1.2 parser; every other value,
    including ``None``, falls back to the 1.1 parser.
    """
    if version == YAMLVersion.V12:
        return _parse_yaml_v1_2(text, path)
    return _parse_yaml_v1_1(text, path)


def load_yaml(path, *, version=None):
    """Read ``path`` as UTF-8 text and parse it with ``parse_yaml``."""
    with open(path, encoding="utf-8") as stream:
        contents = stream.read()
        return parse_yaml(contents, path, version=version)


def parse_yaml_for_update(text, path, *, version=None):
"""Parses text into Python structure.

Unlike `parse_yaml()` this returns ordered dicts, values have special
Expand All @@ -34,20 +59,73 @@ def parse_yaml_for_update(text, path):

This one is, however, several times slower than simple `parse_yaml()`.
"""
try:
yaml = YAML()
return yaml.load(text) or {}
except YAMLError as exc:
raise YAMLFileCorruptedError(path) from exc
return _ruamel_load(text, path, version=version, typ="rt")


class _YAMLEmitterNoVersionDirective(Emitter):
    """Emitter that suppresses the ``%YAML`` version directive on dump.

    Setting a yaml version in `dump_yaml()` makes ruamel.yaml print a
    version directive. This emitter skips that. ruamel.yaml will still try
    to add a document start marker line (assuming the version directive was
    written), so `write_indicator()` contains a _hack_ to keep that marker
    out of the stream, as our dvcfiles and, hopefully, params files are
    single-document YAML files.

    NOTE: do not use this emitter during load/parse, only when dumping
    for 1.1.
    """

    MARKER_START_LINE = "---"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # True only while the first document's start is being emitted.
        self._first_document = False

    def write_version_directive(self, version_text):
        """Do not write version directive at all."""

    def expect_first_document_start(self):
        """Emit the first document start with ``_first_document`` raised."""
        # As our yaml files are expected to only have a single document,
        # this is not strictly needed; it just tries to be a bit resilient,
        # but it's not well-thought out.
        self._first_document = True
        try:
            return super().expect_first_document_start()
        finally:
            # Reset even if the parent raises, so the flag cannot stay
            # stuck at True and swallow a later marker.
            self._first_document = False

    # pylint: disable=signature-differs
    def write_indicator(self, indicator, *args, **kwargs):
        """Write ``indicator`` unless it is the implicit first ``---``."""
        # NOTE: if the yaml file already has a directive,
        # this will strip it
        if isinstance(self.event, DocumentStartEvent):
            skip_marker = (
                # see comments in expect_first_document_start()
                getattr(self, "_first_document", False)
                and not self.event.explicit
                and not self.canonical
                and not self.event.tags
            )
            if skip_marker and indicator == self.MARKER_START_LINE:
                return
        super().write_indicator(indicator, *args, **kwargs)


def _dump_yaml(data, stream, *, version=None, with_directive=False):
    """Serialize ``data`` to ``stream`` with ``ruamel.yaml``.

    ``version`` of ``None`` or ``YAMLVersion.V11`` dumps as YAML 1.1; the
    version directive is suppressed unless ``with_directive`` is true.
    ``YAMLVersion.V12`` sets the version explicitly only when
    ``with_directive`` is true.
    """
    dumper = YAML()
    dump_as_v11 = version is None or version == YAMLVersion.V11

    if dump_as_v11:
        dumper.version = YAMLVersion.V11  # Workaround for #4281
        if not with_directive:
            dumper.Emitter = _YAMLEmitterNoVersionDirective
    elif version == YAMLVersion.V12 and with_directive:
        # `ruamel.yaml` dumps in 1.2 by default, so the version only has
        # to be set when the directive should be printed.
        dumper.version = version

    dumper.default_flow_style = False
    # Represent OrderedDict as a normal dict.
    represent_dict = dumper.Representer.represent_dict
    dumper.Representer.add_representer(OrderedDict, represent_dict)
    dumper.dump(data, stream)


def dump_yaml(path, data):
def dump_yaml(path, data, *, version=None, with_directive=False):
with open(path, "w", encoding="utf-8") as fd:
yaml = YAML()
yaml.default_flow_style = False
# tell Dumper to represent OrderedDict as
# normal dict
yaml.Representer.add_representer(
OrderedDict, yaml.Representer.represent_dict
)
yaml.dump(data, fd)
_dump_yaml(data, fd, version=version, with_directive=with_directive)
82 changes: 82 additions & 0 deletions tests/unit/utils/test_yaml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import pytest

from dvc.exceptions import YAMLFileCorruptedError
from dvc.utils.yaml import (
YAMLVersion,
dump_yaml,
parse_yaml,
parse_yaml_for_update,
)

# Short aliases for the YAML versions under test.
V11 = YAMLVersion.V11
V12 = YAMLVersion.V12
# Directive header expected at the top of the file when a dump is
# requested with ``with_directive=True``.
V11_DIRECTIVE = "%YAML 1.1\n---\n"
V12_DIRECTIVE = "%YAML 1.2\n---\n"


@pytest.mark.parametrize("data", [{"x": 3e24}])
@pytest.mark.parametrize("with_directive", [True, False])
@pytest.mark.parametrize(
    "ver, directive, expected",
    [
        # yaml1.2 does not require a dot before the mantissa,
        # whereas yaml1.1 does
        (V12, V12_DIRECTIVE, "x: 3e+24\n"),
        (V11, V11_DIRECTIVE, "x: 3.0e+24\n"),
    ],
)
def test_dump_yaml_with_directive(
    tmp_dir, ver, directive, expected, with_directive, data
):
    """Dumped text matches the version, with or without the directive."""
    dump_yaml("data.yaml", data, version=ver, with_directive=with_directive)
    written = (tmp_dir / "data.yaml").read_text()
    prefix = directive if with_directive else ""
    assert written == prefix + expected


@pytest.mark.parametrize(
    "parser, rt_parser", [(parse_yaml, False), (parse_yaml_for_update, True)]
)
def test_load_yaml(parser, rt_parser):
    """Both parsers read floats per version and wrap syntax errors."""
    # ruamel.yaml.load() complains about dot before mantissa not allowed
    # on 1.1 and goes on anyway to convert this to a float based on 1.2 spec
    str_value = "3e24"
    float_value = float(str_value)
    yaml11_text = "x: 3.0e+24"  # pyyaml parses as str if there's no +/- sign
    # luckily, `ruamel.yaml` always dumps with sign
    yaml12_text = f"x: {str_value}"

    # Default version: only the round-trip parser reads the 1.2-style
    # float as a float.
    assert parser(yaml11_text, "data.yaml") == {"x": float_value}
    assert parser(yaml12_text, "data.yaml") == {
        "x": float_value if rt_parser else str_value
    }

    # Explicit 1.2: both spellings are floats for both parsers.
    assert parser(yaml11_text, "data.yaml", version=V12) == {"x": float_value}
    assert parser(yaml12_text, "data.yaml", version=V12) == {"x": float_value}

    # The call itself must raise, so no `assert` on its return value:
    # that assert could never be evaluated.
    with pytest.raises(YAMLFileCorruptedError):
        parser("invalid: '", "data.yaml")

    with pytest.raises(YAMLFileCorruptedError):
        parser("invalid: '", "data.yaml", version=V12)


def test_comments_are_preserved_on_update_and_dump(tmp_dir):
    """A comment parsed for update survives every dump variant."""
    text = "x: 3 # this is a comment"
    d = parse_yaml_for_update(text, "data.yaml")
    d["w"] = 7e24

    target = tmp_dir / "data.yaml"
    body_v11 = f"{text}\nw: 7.0e+24\n"
    body_v12 = f"{text}\nw: 7e+24\n"

    # (extra keyword arguments for dump_yaml, expected file contents)
    cases = [
        ({}, body_v11),
        ({"with_directive": True}, V11_DIRECTIVE + body_v11),
        ({"version": V12}, body_v12),
        ({"with_directive": True, "version": V12}, V12_DIRECTIVE + body_v12),
    ]
    for dump_kwargs, expected in cases:
        dump_yaml("data.yaml", d, **dump_kwargs)
        assert target.read_text() == expected