Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add py support for ParamsDependency #4456

Merged
merged 2 commits into from
Sep 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions dvc/utils/serialize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,20 @@

from ._common import * # noqa, pylint: disable=wildcard-import
from ._json import * # noqa, pylint: disable=wildcard-import
from ._py import * # noqa, pylint: disable=wildcard-import
from ._toml import * # noqa, pylint: disable=wildcard-import
from ._yaml import * # noqa, pylint: disable=wildcard-import

# Registry of loader callables keyed by file extension. Extensions without an
# explicit entry fall back to ``load_yaml`` through the default factory
# (assuming ``defaultdict`` is ``collections.defaultdict``; its import is not
# visible in this excerpt).
LOADERS = defaultdict(lambda: load_yaml) # noqa: F405
# NOTE(review): this listing appears to interleave the pre- and post-change
# lines of a diff; the update below registers a superset of these keys.
LOADERS.update({".toml": load_toml, ".json": load_json}) # noqa: F405
LOADERS.update(
{".toml": load_toml, ".json": load_json, ".py": load_py} # noqa: F405
)

# Registry of modifier callables keyed by file extension, with ``modify_yaml``
# as the fallback for extensions that have no explicit entry.
MODIFIERS = defaultdict(lambda: modify_yaml) # noqa: F405
# NOTE(review): as above, the update below registers a superset of these keys.
MODIFIERS.update({".toml": modify_toml, ".json": modify_json}) # noqa: F405
MODIFIERS.update(
{
".toml": modify_toml, # noqa: F405
".json": modify_json, # noqa: F405
".py": modify_py, # noqa: F405
}
)
181 changes: 181 additions & 0 deletions dvc/utils/serialize/_py.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import ast
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this is not cross-python-version compatible, right? Not a blocker though.

from contextlib import contextmanager

from funcy import reraise

from ._common import ParseError, _dump_data, _load_data, _modify_data

# Reserved bookkeeping keys that ``parse_py_for_update`` adds to its result
# and that ``_dump`` reads and strips again before rewriting the file.
# Holds the originally parsed params annotated with their line numbers.
_PARAMS_KEY = "__params_old_key_for_update__"
# Holds the original, unmodified text of the params file.
_PARAMS_TEXT_KEY = "__params_text_key_for_update__"


class PythonFileCorruptedError(ParseError):
    """Raised when the text of a Python params file is not valid Python."""

    def __init__(self, path, message="Python file structure is corrupted"):
        super().__init__(path, message)


def load_py(path, tree=None):
    """Load the parameters defined in the Python file at ``path``.

    Delegates to ``_load_data`` with ``parse_py`` as the parser; ``tree`` is
    forwarded unchanged.
    """
    return _load_data(path, tree=tree, parser=parse_py)


def parse_py(text, path):
    """Parse the source text of a .py file into a dict of parameters.

    :param text: Python source code
    :param path: file name, used for parsing and for error reporting
    :raises PythonFileCorruptedError: when ``text`` has a syntax error
    """
    with reraise(SyntaxError, PythonFileCorruptedError(path)):
        module = ast.parse(text, filename=path)

    return _ast_tree_to_dict(module)


def parse_py_for_update(text, path):
    """Parse source text into a params dict prepared for a later update.

    Besides the plain parameters, the result carries two bookkeeping
    entries: ``_PARAMS_KEY`` with the parameters annotated with their line
    numbers and ``_PARAMS_TEXT_KEY`` with the original text.

    :raises PythonFileCorruptedError: when ``text`` has a syntax error
    """
    with reraise(SyntaxError, PythonFileCorruptedError(path)):
        module = ast.parse(text, filename=path)

    params = _ast_tree_to_dict(module)
    params[_PARAMS_KEY] = _ast_tree_to_dict(module, lineno=True)
    params[_PARAMS_TEXT_KEY] = text
    return params


def _dump(data, stream):
    """Write changed parameter values back into the original source text.

    ``data`` is expected to be the mapping produced by
    ``parse_py_for_update``: the current parameter values plus the
    ``_PARAMS_KEY`` entry (originally parsed values annotated with
    zero-based line indexes) and the ``_PARAMS_TEXT_KEY`` entry (original
    file text). Each value that differs from the original is substituted
    textually on the line it was parsed from, and the resulting text is
    re-parsed before anything is written.

    :param data: params dict including the two bookkeeping entries
    :param stream: writable text stream; it is closed after a successful write
    :raises PythonFileCorruptedError: if the rewritten text is not valid
        Python; nothing is written in that case
    """
    old_params = data[_PARAMS_KEY]
    new_params = {
        key: value
        for key, value in data.items()
        if key not in (_PARAMS_KEY, _PARAMS_TEXT_KEY)
    }
    # keepends=True so that joining the lines reproduces the text exactly.
    old_lines = data[_PARAMS_TEXT_KEY].splitlines(True)

    def _update_lines(lines, old_dct, new_dct):
        for key, value in new_dct.items():
            if isinstance(value, dict):
                # Nested params (class bodies, dict literals): recurse.
                lines = _update_lines(lines, old_dct[key], value)
            elif value != old_dct[key]["value"]:
                lineno = old_dct[key]["lineno"]
                # Plain text substitution of the right-hand side.
                lines[lineno] = lines[lineno].replace(
                    f" = {old_dct[key]['value']}", f" = {value}"
                )
        return lines

    new_lines = _update_lines(old_lines, old_params, new_params)
    new_text = "".join(new_lines)

    # Refuse to write a file that no longer parses.
    try:
        ast.parse(new_text)
    except SyntaxError as exc:
        # Chain the original error so the offending location stays visible.
        raise PythonFileCorruptedError(
            stream.name,
            "Python file structure is corrupted after update params",
        ) from exc

    stream.write(new_text)
    # NOTE(review): the caller presumably also closes the stream - confirm
    # against _dump_data before removing this.
    stream.close()


def dump_py(path, data, tree=None):
    """Write ``data`` to the Python file at ``path``.

    Delegates to ``_dump_data`` with ``_dump`` as the dumper; ``tree`` is
    forwarded unchanged.
    """
    return _dump_data(path, data, tree=tree, dumper=_dump)


@contextmanager
def modify_py(path, tree=None):
    """Context manager yielding the params of the Python file at ``path``.

    Wraps ``_modify_data`` with ``parse_py_for_update`` as the parser and
    ``dump_py`` as the dumper (presumably the yielded dict is dumped back on
    exit - see ``_modify_data``).
    """
    with _modify_data(
        path, parse_py_for_update, dump_py, tree=tree
    ) as params:
        yield params


def _ast_tree_to_dict(tree, only_self_params=False, lineno=False):
    """Collect the parameters defined directly in the body of an AST node.

    Plain and annotated assignments become entries, class definitions
    become nested dicts keyed by the class name, and the body of an
    ``__init__`` function contributes its ``self.<name>`` assignments.
    Statements that cannot be converted are skipped.

    :param tree: AST node with a ``body`` (module, class or function)
    :param only_self_params: only accept ``self.<name>`` assignment targets
    :param lineno: annotate every value with its line index (for updates)
    :return: dict of parameter names to values
    """
    params = {}
    for node in tree.body:
        try:
            if isinstance(node, (ast.Assign, ast.AnnAssign)):
                parsed = _ast_assign_to_dict(node, only_self_params, lineno)
            elif isinstance(node, ast.ClassDef):
                parsed = {node.name: _ast_tree_to_dict(node, lineno=lineno)}
            elif isinstance(node, ast.FunctionDef) and (
                node.name == "__init__"
            ):
                parsed = _ast_tree_to_dict(
                    node, only_self_params=True, lineno=lineno
                )
            else:
                continue
        except (ValueError, AttributeError):
            # Unsupported target or value: this statement is not a param.
            continue
        params.update(parsed)
    return params


def _ast_assign_to_dict(assign, only_self_params=False, lineno=False):
    """Convert one assignment node into a single-entry ``{name: value}`` dict.

    Supported right-hand sides are simple literals and dict, list, set and
    tuple displays of simple literals. With ``lineno`` set, each value is
    wrapped as ``{"lineno": <zero-based line index>, "value": <value>}``;
    for dict literals the wrapping is applied per key.

    :param assign: ``ast.Assign`` or ``ast.AnnAssign`` node
    :param only_self_params: only accept ``self.<name>`` targets
    :param lineno: annotate values with the assignment's line index
    :raises AttributeError: for unsupported targets, including assignments
        with more than one target
    :raises ValueError: for unsupported values
    """
    if isinstance(assign, ast.AnnAssign):
        target = assign.target
    elif len(assign.targets) == 1:
        target = assign.targets[0]
    else:
        # Chained assignments such as ``a = b = 1`` are not supported.
        raise AttributeError
    name = _get_ast_name(target, only_self_params)

    node = assign.value
    line_index = assign.lineno - 1

    if isinstance(node, ast.Dict):
        parsed = {}
        for key_node, val_node in zip(node.keys, node.values):
            item = _get_ast_value(val_node)
            if lineno:
                item = {"lineno": line_index, "value": item}
            parsed[_get_ast_value(key_node)] = item
        return {name: parsed}

    if isinstance(node, ast.List):
        parsed = [_get_ast_value(elt) for elt in node.elts]
    elif isinstance(node, ast.Set):
        parsed = {_get_ast_value(elt) for elt in node.elts}
    elif isinstance(node, ast.Tuple):
        parsed = tuple(_get_ast_value(elt) for elt in node.elts)
    else:
        parsed = _get_ast_value(node)

    if lineno:
        return {name: {"lineno": line_index, "value": parsed}}
    return {name: parsed}


def _get_ast_name(target, only_self_params=False):
    """Return the parameter name bound by an assignment target.

    A plain name (``x``) is accepted unless ``only_self_params`` is set; an
    attribute is accepted only in the form ``self.<name>``.

    :raises AttributeError: for any other kind of target
    """
    if not only_self_params and hasattr(target, "id"):
        return target.id
    # ``target.value.id`` itself raises AttributeError for nested attributes
    # such as ``self.container.a``.
    if hasattr(target, "attr") and target.value.id == "self":
        return target.attr
    raise AttributeError


def _get_ast_value(value):
    """Return the Python value of a simple literal AST node.

    Accepted literals are numbers, strings, booleans and ``None``. Anything
    else (names, calls, containers, bytes, ``...``, signed numbers such as
    ``-1``, which parse as unary operations) is rejected.

    ``ast.Num``, ``ast.Str`` and ``ast.NameConstant`` are deprecated since
    Python 3.8, so they are not referenced as attributes of ``ast`` here.

    :param value: AST node taken from the right-hand side of an assignment
    :raises ValueError: if the node is not a supported literal
    """
    if isinstance(value, ast.Constant):
        literal = value.value
        # ast.Constant also covers bytes and Ellipsis, which were never
        # accepted here; keep rejecting them.
        if literal is None or isinstance(
            literal, (bool, int, float, complex, str)
        ):
            return literal
        raise ValueError

    # Python < 3.8 parses literals into dedicated node classes. Match them by
    # class name to avoid touching the deprecated attributes on ``ast``.
    legacy_attr = {"Num": "n", "Str": "s", "NameConstant": "value"}.get(
        type(value).__name__
    )
    if legacy_attr is None:
        raise ValueError
    return getattr(value, legacy_attr)
72 changes: 72 additions & 0 deletions tests/func/experiments/test_experiments.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import pytest

from dvc.utils.serialize import PythonFileCorruptedError
from tests.func.test_repro_multistage import COPY_SCRIPT


Expand Down Expand Up @@ -102,6 +105,75 @@ def test_get_baseline(tmp_dir, scm, dvc):
assert dvc.experiments.get_baseline("stash@{0}") == expected


def test_update_py_params(tmp_dir, scm, dvc):
    """Experiment param overrides must be written back into ``params.py``."""
    # Scenario 1: a single module-level constant is overridden.
    tmp_dir.gen("copy.py", COPY_SCRIPT)
    tmp_dir.gen("params.py", "INT = 1\n")
    stage = dvc.run(
        cmd="python copy.py params.py metrics.py",
        metrics_no_cache=["metrics.py"],
        params=["params.py:INT"],
        name="copy-file",
    )
    scm.add(["dvc.yaml", "dvc.lock", "copy.py", "params.py", "metrics.py"])
    scm.commit("init")

    dvc.reproduce(
        stage.addressing, experiment=True, params=["params.py:INT=2"]
    )
    exp_a = dvc.experiments.scm.get_rev()

    dvc.experiments.checkout(exp_a)
    assert (tmp_dir / "params.py").read_text().strip() == "INT = 2"
    assert (tmp_dir / "metrics.py").read_text().strip() == "INT = 2"

    # Scenario 2: constants, a class attribute and an ``__init__`` attribute.
    # The nested bodies must be properly indented for the file to be valid
    # Python.
    tmp_dir.gen(
        "params.py",
        "INT = 1\nFLOAT = 0.001\nDICT = {'a': 1}\n\n"
        "class Train:\n    seed = 2020\n\n"
        "class Klass:\n    def __init__(self):\n        self.a = 111\n",
    )
    stage = dvc.run(
        cmd="python copy.py params.py metrics.py",
        metrics_no_cache=["metrics.py"],
        params=["params.py:INT,FLOAT,DICT,Train,Klass"],
        name="copy-file",
    )
    scm.add(["dvc.yaml", "dvc.lock", "copy.py", "params.py", "metrics.py"])
    scm.commit("init")

    dvc.reproduce(
        stage.addressing,
        experiment=True,
        params=["params.py:FLOAT=0.1,Train.seed=2121,Klass.a=222"],
    )
    exp_a = dvc.experiments.scm.get_rev()

    result = (
        "INT = 1\nFLOAT = 0.1\nDICT = {'a': 1}\n\n"
        "class Train:\n    seed = 2121\n\n"
        "class Klass:\n    def __init__(self):\n        self.a = 222"
    )

    dvc.experiments.checkout(exp_a)
    assert (tmp_dir / "params.py").read_text().strip() == result
    assert (tmp_dir / "metrics.py").read_text().strip() == result

    # Scenario 3: an override that yields invalid Python must be rejected.
    tmp_dir.gen("params.py", "INT = 1\n")
    stage = dvc.run(
        cmd="python copy.py params.py metrics.py",
        metrics_no_cache=["metrics.py"],
        params=["params.py:INT"],
        name="copy-file",
    )
    scm.add(["dvc.yaml", "dvc.lock", "copy.py", "params.py", "metrics.py"])
    scm.commit("init")

    with pytest.raises(PythonFileCorruptedError):
        dvc.reproduce(
            stage.addressing, experiment=True, params=["params.py:INT=2a"]
        )


def test_extend_branch(tmp_dir, scm, dvc):
tmp_dir.gen("copy.py", COPY_SCRIPT)
tmp_dir.gen("params.yaml", "foo: 1")
Expand Down
15 changes: 15 additions & 0 deletions tests/func/params/test_show.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,21 @@ def test_show_toml(tmp_dir, dvc):
}


def test_show_py(tmp_dir, dvc):
    """``dvc.params.show()`` reports params tracked from a ``.py`` file."""
    source = (
        "CONST = 1\nIS_DIR: bool = True\n\n\nclass Config:\n foo = 42\n"
    )
    tmp_dir.gen("params.py", source)

    tracked = ["params.py:CONST,IS_DIR,Config.foo"]
    dvc.run(cmd="echo params.py", params=tracked, single_stage=True)

    expected = {"CONST": 1, "IS_DIR": True, "Config": {"foo": 42}}
    assert dvc.params.show() == {"": {"params.py": expected}}


def test_show_multiple(tmp_dir, dvc):
tmp_dir.gen("params.yaml", "foo: bar\nbaz: qux\n")
dvc.run(
Expand Down
67 changes: 67 additions & 0 deletions tests/unit/dependency/test_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,73 @@ def test_read_params_toml(tmp_dir, dvc):
assert dep.read_params() == {"some.path.foo": ["val1", "val2"]}


def test_read_params_py(tmp_dir, dvc):
    """``ParamsDependency.read_params`` understands Python params files."""
    parameters_file = "parameters.py"

    # Module-level constants of every supported literal/container kind.
    tmp_dir.gen(
        parameters_file,
        "INT: int = 5\n"
        "FLOAT = 0.001\n"
        "STR = 'abc'\n"
        "BOOL: bool = True\n"
        "DICT = {'a': 1}\n"
        "LIST = [1, 2, 3]\n"
        "SET = {4, 5, 6}\n"
        "TUPLE = (10, 100)\n"
        "NONE = None\n",
    )
    dep = ParamsDependency(
        Stage(dvc),
        parameters_file,
        [
            "INT",
            "FLOAT",
            "STR",
            "BOOL",
            "DICT",
            "LIST",
            "SET",
            "TUPLE",
            "NONE",
        ],
    )
    assert dep.read_params() == {
        "INT": 5,
        "FLOAT": 0.001,
        "STR": "abc",
        "BOOL": True,
        "DICT": {"a": 1},
        "LIST": [1, 2, 3],
        "SET": {4, 5, 6},
        "TUPLE": (10, 100),
        "NONE": None,
    }

    # Class attributes, addressed individually and as a whole class.
    tmp_dir.gen(
        parameters_file, "class Train:\n foo = 'val1'\n bar = 'val2'\n",
    )
    dep = ParamsDependency(Stage(dvc), parameters_file, ["Train.foo"])
    assert dep.read_params() == {"Train.foo": "val1"}

    dep = ParamsDependency(Stage(dvc), parameters_file, ["Train"])
    assert dep.read_params() == {
        "Train": {"foo": "val1", "bar": "val2"},
    }

    # Only ``self.<name>`` assignments inside ``__init__`` count as params;
    # attribute assignments on other objects and plain locals are ignored.
    # The method body must be indented for the file to be valid Python.
    tmp_dir.gen(
        parameters_file,
        "x = 4\n"
        "config.x = 3\n"
        "class Klass:\n"
        "    def __init__(self):\n"
        "        self.a = 'val1'\n"
        "        container.a = 2\n"
        "        self.container.a = 1\n"
        "        a = 'val2'\n",
    )
    dep = ParamsDependency(Stage(dvc), parameters_file, ["x", "Klass.a"])
    assert dep.read_params() == {"x": 4, "Klass.a": "val1"}


def test_get_hash_missing_config(dvc):
dep = ParamsDependency(Stage(dvc), None, ["foo"])
with pytest.raises(MissingParamsError):
Expand Down