Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

parsing: Support dict unpacking in cmd. #7907

Merged
merged 1 commit into from
Jul 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions dvc/config_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,4 +280,8 @@ class RelPath(str):
"plots": str,
"live": str,
},
"parsing": {
"bool": All(Lower, Choices("store_true", "boolean_optional")),
"list": All(Lower, Choices("nargs", "append")),
},
}
2 changes: 1 addition & 1 deletion dvc/parsing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def _resolve(
) -> DictStr:
try:
return context.resolve(
value, skip_interpolation_checks=skip_checks
value, skip_interpolation_checks=skip_checks, key=key
)
except (ParseError, KeyNotInContext) as exc:
format_and_raise(
Expand Down
13 changes: 8 additions & 5 deletions dvc/parsing/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
normalize_key,
recurse,
str_interpolate,
validate_value,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -506,7 +507,7 @@ def set_temporarily(self, to_set: DictStr, reserve: bool = False):
self.data.pop(key, None)

def resolve(
self, src, unwrap=True, skip_interpolation_checks=False
self, src, unwrap=True, skip_interpolation_checks=False, key=None
) -> Any:
"""Recursively resolves interpolation and returns resolved data.

Expand All @@ -522,10 +523,10 @@ def resolve(
{'lst': [1, 2, 3]}
"""
func = recurse(self.resolve_str)
return func(src, unwrap, skip_interpolation_checks)
return func(src, unwrap, skip_interpolation_checks, key)

def resolve_str(
self, src: str, unwrap=True, skip_interpolation_checks=False
self, src: str, unwrap=True, skip_interpolation_checks=False, key=None
) -> str:
"""Resolves interpolated string to its original value,
or in case of multiple interpolations, a combined string.
Expand All @@ -543,10 +544,12 @@ def resolve_str(
expr = get_expression(
matches[0], skip_checks=skip_interpolation_checks
)
return self.select(expr, unwrap=unwrap)
value = self.select(expr, unwrap=unwrap)
validate_value(value, key)
return value
# but not "${num} days"
return str_interpolate(
src, matches, self, skip_checks=skip_interpolation_checks
src, matches, self, skip_checks=skip_interpolation_checks, key=key
)


Expand Down
64 changes: 57 additions & 7 deletions dvc/parsing/interpolate.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import re
import typing
from collections.abc import Mapping
from collections.abc import Iterable, Mapping
from functools import singledispatch

from funcy import memoize, rpartial

from dvc.exceptions import DvcException
from dvc.utils.flatten import flatten

if typing.TYPE_CHECKING:
from typing import List, Match
Expand Down Expand Up @@ -80,6 +81,45 @@ def _(obj: bool):
return "true" if obj else "false"


@to_str.register(dict)
def _(obj: dict):
from dvc.config import Config

config = Config().get("parsing", {})

result = ""
for k, v in flatten(obj).items():

if isinstance(v, bool):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This only works for argparse.BooleanOptionalAction not for store_true/false. Maybe need some better way to handle this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same issue about different options, as commented below #7907 (comment)

Which option do you think we should consider as the most appropriate default?

Note that this can be used for things beyond argparse, for example you can use the interpolation to pass flags to arbitrary executables that you call inside cmd

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added config.parsing.bool with store_true and boolean_optional.

Not sure about store_false because the interaction with dvc params would look kind of strange, IMO

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also have no idea about it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's more than enough to support for now.

if v:
result += f"--{k} "
else:
if config.get("bool", "store_true") == "boolean_optional":
result += f"--no-{k} "

elif isinstance(v, str):
result += f"--{k} '{v}' "

elif isinstance(v, Iterable):
for n, i in enumerate(v):
if isinstance(i, str):
i = f"'{i}'"
elif isinstance(i, Iterable):
raise ParseError(
f"Cannot interpolate nested iterable in '{k}'"
)

if config.get("list", "nargs") == "append":
result += f"--{k} {i} "
else:
result += f"{i} " if n > 0 else f"--{k} {i} "

else:
result += f"--{k} {v} "

return result.rstrip()


def _format_exc_msg(exc: "ParseException"):
from pyparsing import ParseException

Expand Down Expand Up @@ -148,23 +188,33 @@ def get_expression(match: "Match", skip_checks: bool = False):
return inner if skip_checks else parse_expr(inner)


def validate_value(value, key):
    """Check whether ``value`` may be interpolated for the given ``key``.

    ``None`` and instances of ``PRIMITIVES`` always pass, and nothing is
    checked when ``key`` is ``None`` or contains ``"foreach"``; in those
    cases the function returns ``None``. A ``dict`` is accepted only when
    ``key`` is exactly ``"cmd"``, in which case ``True`` is returned.

    Raises:
        ParseError: for any other non-primitive value.
    """
    from .context import PRIMITIVES

    # Both conditions are computed up front, mirroring the evaluation
    # order of the checks regardless of the value's type.
    is_plain = value is None or isinstance(value, PRIMITIVES)
    key_exempt = key is None or "foreach" in key
    if is_plain or key_exempt:
        return None

    # Dict unpacking is supported for the command line only.
    if key == "cmd" and isinstance(value, dict):
        return True

    raise ParseError(
        f"Cannot interpolate data of type '{type(value).__name__}'"
    )


def str_interpolate(
template: str,
matches: "List[Match]",
context: "Context",
skip_checks: bool = False,
key=None,
):
from .context import PRIMITIVES

index, buf = 0, ""
for match in matches:
start, end = match.span(0)
expr = get_expression(match, skip_checks=skip_checks)
value = context.select(expr, unwrap=True)
if value is not None and not isinstance(value, PRIMITIVES):
raise ParseError(
f"Cannot interpolate data of type '{type(value).__name__}'"
)
validate_value(value, key)
buf += template[index:start] + to_str(value)
index = end
buf += template[index:]
Expand Down
20 changes: 18 additions & 2 deletions tests/func/parsing/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,18 +119,34 @@ def test_wdir_failed_to_interpolate(tmp_dir, dvc, wdir, expected_msg):

def test_interpolate_non_string(tmp_dir, dvc):
definition = make_entry_definition(
tmp_dir, "build", {"cmd": "echo ${models}"}, Context(models={})
tmp_dir, "build", {"outs": "${models}"}, Context(models={})
)
with pytest.raises(ResolveError) as exc_info:
definition.resolve()

assert str(exc_info.value) == (
"failed to parse 'stages.build.cmd' in 'dvc.yaml':\n"
"failed to parse 'stages.build.outs' in 'dvc.yaml':\n"
"Cannot interpolate data of type 'dict'"
)
assert definition.context == {"models": {}}


def test_interpolate_nested_iterable(tmp_dir, dvc):
    """A list nested in a dict interpolated into ``cmd`` fails to resolve."""
    context = Context(models={"list": [1, [2, 3]]})
    entry = {"cmd": "echo ${models}"}
    definition = make_entry_definition(tmp_dir, "build", entry, context)

    expected_msg = (
        "failed to parse 'stages.build.cmd' in 'dvc.yaml':\n"
        "Cannot interpolate nested iterable in 'list'"
    )

    with pytest.raises(ResolveError) as exc_info:
        definition.resolve()

    assert str(exc_info.value) == expected_msg


def test_partial_vars_doesnot_exist(tmp_dir, dvc):
(tmp_dir / "test_params.yaml").dump({"sub1": "sub1", "sub2": "sub2"})

Expand Down
56 changes: 56 additions & 0 deletions tests/func/parsing/test_interpolated_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,59 @@ def test_vars_load_partial(tmp_dir, dvc, local, vars_):
d["vars"] = vars_
resolver = DataResolver(dvc, tmp_dir.fs_path, d)
resolver.resolve()


@pytest.mark.parametrize(
"bool_config, list_config",
[(None, None), ("store_true", "nargs"), ("boolean_optional", "append")],
)
def test_cmd_dict(tmp_dir, dvc, bool_config, list_config):
with dvc.config.edit() as conf:
if bool_config:
conf["parsing"]["bool"] = bool_config
if list_config:
conf["parsing"]["list"] = list_config

data = {
dberenbaum marked this conversation as resolved.
Show resolved Hide resolved
"dict": {
"foo": "foo",
"bar": 2,
"string": "spaced string",
"bool": True,
"bool-false": False,
"list": [1, 2, "foo"],
"nested": {"foo": "foo"},
}
}
(tmp_dir / DEFAULT_PARAMS_FILE).dump(data)
resolver = DataResolver(
dvc,
tmp_dir.fs_path,
{"stages": {"stage1": {"cmd": "python script.py ${dict}"}}},
)

if bool_config is None or bool_config == "store_true":
bool_resolved = " --bool"
else:
bool_resolved = " --bool --no-bool-false"

if list_config is None or list_config == "nargs":
list_resolved = " --list 1 2 'foo'"
else:
list_resolved = " --list 1 --list 2 --list 'foo'"

assert_stage_equal(
resolver.resolve(),
{
"stages": {
"stage1": {
"cmd": "python script.py"
" --foo 'foo' --bar 2"
" --string 'spaced string'"
f"{bool_resolved}"
f"{list_resolved}"
" --nested.foo 'foo'"
}
}
},
)