diff --git a/dvc/utils/serialize/__init__.py b/dvc/utils/serialize/__init__.py index 8b0b5490b4..a959f97c3f 100644 --- a/dvc/utils/serialize/__init__.py +++ b/dvc/utils/serialize/__init__.py @@ -2,11 +2,20 @@ from ._common import * # noqa, pylint: disable=wildcard-import from ._json import * # noqa, pylint: disable=wildcard-import +from ._py import * # noqa, pylint: disable=wildcard-import from ._toml import * # noqa, pylint: disable=wildcard-import from ._yaml import * # noqa, pylint: disable=wildcard-import LOADERS = defaultdict(lambda: load_yaml) # noqa: F405 -LOADERS.update({".toml": load_toml, ".json": load_json}) # noqa: F405 +LOADERS.update( + {".toml": load_toml, ".json": load_json, ".py": load_py} # noqa: F405 +) MODIFIERS = defaultdict(lambda: modify_yaml) # noqa: F405 -MODIFIERS.update({".toml": modify_toml, ".json": modify_json}) # noqa: F405 +MODIFIERS.update( + { + ".toml": modify_toml, # noqa: F405 + ".json": modify_json, # noqa: F405 + ".py": modify_py, # noqa: F405 + } +) diff --git a/dvc/utils/serialize/_py.py b/dvc/utils/serialize/_py.py new file mode 100644 index 0000000000..e51f0d4f33 --- /dev/null +++ b/dvc/utils/serialize/_py.py @@ -0,0 +1,181 @@ +import ast +from contextlib import contextmanager + +from funcy import reraise + +from ._common import ParseError, _dump_data, _load_data, _modify_data + +_PARAMS_KEY = "__params_old_key_for_update__" +_PARAMS_TEXT_KEY = "__params_text_key_for_update__" + + +class PythonFileCorruptedError(ParseError): + def __init__(self, path, message="Python file structure is corrupted"): + super().__init__(path, message) + + +def load_py(path, tree=None): + return _load_data(path, parser=parse_py, tree=tree) + + +def parse_py(text, path): + """Parses text from .py file into Python structure.""" + with reraise(SyntaxError, PythonFileCorruptedError(path)): + tree = ast.parse(text, filename=path) + + result = _ast_tree_to_dict(tree) + return result + + +def parse_py_for_update(text, path): + """Parses text into dict for update params.""" + with reraise(SyntaxError, PythonFileCorruptedError(path)): + tree = ast.parse(text, filename=path) + + result = _ast_tree_to_dict(tree) + result.update({_PARAMS_KEY: _ast_tree_to_dict(tree, lineno=True)}) + result.update({_PARAMS_TEXT_KEY: text}) + return result + + +def _dump(data, stream): + + old_params = data[_PARAMS_KEY] + new_params = { + key: value + for key, value in data.items() + if key not in [_PARAMS_KEY, _PARAMS_TEXT_KEY] + } + old_lines = data[_PARAMS_TEXT_KEY].splitlines(True) + + def _update_lines(lines, old_dct, new_dct): + for key, value in new_dct.items(): + if isinstance(value, dict): + lines = _update_lines(lines, old_dct[key], value) + elif value != old_dct[key]["value"]: + lineno = old_dct[key]["lineno"] + lines[lineno] = lines[lineno].replace( + f" = {old_dct[key]['value']}", f" = {value}" + ) + else: + continue + return lines + + new_lines = _update_lines(old_lines, old_params, new_params) + new_text = "".join(new_lines) + + try: + ast.parse(new_text) + except SyntaxError: + raise PythonFileCorruptedError( + stream.name, + "Python file structure is corrupted after update params", + ) + + stream.write(new_text) + stream.close() + + +def dump_py(path, data, tree=None): + return _dump_data(path, data, dumper=_dump, tree=tree) + + +@contextmanager +def modify_py(path, tree=None): + with _modify_data(path, parse_py_for_update, dump_py, tree=tree) as d: + yield d + + +def _ast_tree_to_dict(tree, only_self_params=False, lineno=False): + """Parses ast trees to dict. + + :param tree: ast.Tree + :param only_self_params: get only self params from class __init__ function + :param lineno: add params line number (needed for update) + :return: + """ + result = {} + for _body in tree.body: + try: + if isinstance(_body, (ast.Assign, ast.AnnAssign)): + result.update( + _ast_assign_to_dict(_body, only_self_params, lineno) + ) + elif isinstance(_body, ast.ClassDef): + result.update( + {_body.name: _ast_tree_to_dict(_body, lineno=lineno)} + ) + elif ( + isinstance(_body, ast.FunctionDef) and _body.name == "__init__" + ): + result.update( + _ast_tree_to_dict( + _body, only_self_params=True, lineno=lineno + ) + ) + except ValueError: + continue + except AttributeError: + continue + return result + + +def _ast_assign_to_dict(assign, only_self_params=False, lineno=False): + result = {} + + if isinstance(assign, ast.AnnAssign): + name = _get_ast_name(assign.target, only_self_params) + elif len(assign.targets) == 1: + name = _get_ast_name(assign.targets[0], only_self_params) + else: + raise AttributeError + + if isinstance(assign.value, ast.Dict): + value = {} + for key, val in zip(assign.value.keys, assign.value.values): + if lineno: + value[_get_ast_value(key)] = { + "lineno": assign.lineno - 1, + "value": _get_ast_value(val), + } + else: + value[_get_ast_value(key)] = _get_ast_value(val) + elif isinstance(assign.value, ast.List): + value = [_get_ast_value(val) for val in assign.value.elts] + elif isinstance(assign.value, ast.Set): + values = [_get_ast_value(val) for val in assign.value.elts] + value = set(values) + elif isinstance(assign.value, ast.Tuple): + values = [_get_ast_value(val) for val in assign.value.elts] + value = tuple(values) + else: + value = _get_ast_value(assign.value) + + if lineno and not isinstance(assign.value, ast.Dict): + result[name] = {"lineno": assign.lineno - 1, "value": value} + else: + result[name] = value + + return result + + +def _get_ast_name(target, only_self_params=False): + if hasattr(target, "id") and not only_self_params: + result = target.id + elif hasattr(target, "attr") and target.value.id == "self": + result = target.attr + else: + raise AttributeError + return result + + +def _get_ast_value(value): + if isinstance(value, ast.Num): + result = value.n + elif isinstance(value, ast.Str): + result = value.s + elif isinstance(value, ast.NameConstant): + result = value.value + else: + raise ValueError + return result diff --git a/tests/func/experiments/test_experiments.py b/tests/func/experiments/test_experiments.py index b11a277f5e..edff2f0aa1 100644 --- a/tests/func/experiments/test_experiments.py +++ b/tests/func/experiments/test_experiments.py @@ -1,3 +1,6 @@ +import pytest + +from dvc.utils.serialize import PythonFileCorruptedError from tests.func.test_repro_multistage import COPY_SCRIPT @@ -102,6 +105,75 @@ def test_get_baseline(tmp_dir, scm, dvc): assert dvc.experiments.get_baseline("stash@{0}") == expected +def test_update_py_params(tmp_dir, scm, dvc): + tmp_dir.gen("copy.py", COPY_SCRIPT) + tmp_dir.gen("params.py", "INT = 1\n") + stage = dvc.run( + cmd="python copy.py params.py metrics.py", + metrics_no_cache=["metrics.py"], + params=["params.py:INT"], + name="copy-file", + ) + scm.add(["dvc.yaml", "dvc.lock", "copy.py", "params.py", "metrics.py"]) + scm.commit("init") + + dvc.reproduce( + stage.addressing, experiment=True, params=["params.py:INT=2"] + ) + exp_a = dvc.experiments.scm.get_rev() + + dvc.experiments.checkout(exp_a) + assert (tmp_dir / "params.py").read_text().strip() == "INT = 2" + assert (tmp_dir / "metrics.py").read_text().strip() == "INT = 2" + + tmp_dir.gen( + "params.py", + "INT = 1\nFLOAT = 0.001\nDICT = {'a': 1}\n\n" + "class Train:\n seed = 2020\n\n" + "class Klass:\n def __init__(self):\n self.a = 111\n", + ) + stage = dvc.run( + cmd="python copy.py params.py metrics.py", + metrics_no_cache=["metrics.py"], + params=["params.py:INT,FLOAT,DICT,Train,Klass"], + name="copy-file", + ) + scm.add(["dvc.yaml", "dvc.lock", "copy.py", "params.py", "metrics.py"]) + scm.commit("init") + + dvc.reproduce( + stage.addressing, + experiment=True, + params=["params.py:FLOAT=0.1,Train.seed=2121,Klass.a=222"], + ) + exp_a = dvc.experiments.scm.get_rev() + + result = ( + "INT = 1\nFLOAT = 0.1\nDICT = {'a': 1}\n\n" + "class Train:\n seed = 2121\n\n" + "class Klass:\n def __init__(self):\n self.a = 222" + ) + + dvc.experiments.checkout(exp_a) + assert (tmp_dir / "params.py").read_text().strip() == result + assert (tmp_dir / "metrics.py").read_text().strip() == result + + tmp_dir.gen("params.py", "INT = 1\n") + stage = dvc.run( + cmd="python copy.py params.py metrics.py", + metrics_no_cache=["metrics.py"], + params=["params.py:INT"], + name="copy-file", + ) + scm.add(["dvc.yaml", "dvc.lock", "copy.py", "params.py", "metrics.py"]) + scm.commit("init") + + with pytest.raises(PythonFileCorruptedError): + dvc.reproduce( + stage.addressing, experiment=True, params=["params.py:INT=2a"] + ) + + def test_extend_branch(tmp_dir, scm, dvc): tmp_dir.gen("copy.py", COPY_SCRIPT) tmp_dir.gen("params.yaml", "foo: 1") diff --git a/tests/func/params/test_show.py b/tests/func/params/test_show.py index fd4621fb77..cdd39f5e31 100644 --- a/tests/func/params/test_show.py +++ b/tests/func/params/test_show.py @@ -24,6 +24,21 @@ def test_show_toml(tmp_dir, dvc): } +def test_show_py(tmp_dir, dvc): + tmp_dir.gen( + "params.py", + "CONST = 1\nIS_DIR: bool = True\n\n\nclass Config:\n foo = 42\n", + ) + dvc.run( + cmd="echo params.py", + params=["params.py:CONST,IS_DIR,Config.foo"], + single_stage=True, + ) + assert dvc.params.show() == { + "": {"params.py": {"CONST": 1, "IS_DIR": True, "Config": {"foo": 42}}} + } + + def test_show_multiple(tmp_dir, dvc): tmp_dir.gen("params.yaml", "foo: bar\nbaz: qux\n") dvc.run( diff --git a/tests/unit/dependency/test_params.py b/tests/unit/dependency/test_params.py index a1343a84b6..06168021d8 100644 --- a/tests/unit/dependency/test_params.py +++ b/tests/unit/dependency/test_params.py @@ -119,6 +119,73 @@ def test_read_params_toml(tmp_dir, dvc): assert dep.read_params() == {"some.path.foo": ["val1", "val2"]} +def test_read_params_py(tmp_dir, dvc): + parameters_file = "parameters.py" + tmp_dir.gen( + parameters_file, + "INT: int = 5\n" + "FLOAT = 0.001\n" + "STR = 'abc'\n" + "BOOL: bool = True\n" + "DICT = {'a': 1}\n" + "LIST = [1, 2, 3]\n" + "SET = {4, 5, 6}\n" + "TUPLE = (10, 100)\n" + "NONE = None\n", + ) + dep = ParamsDependency( + Stage(dvc), + parameters_file, + [ + "INT", + "FLOAT", + "STR", + "BOOL", + "DICT", + "LIST", + "SET", + "TUPLE", + "NONE", + ], + ) + assert dep.read_params() == { + "INT": 5, + "FLOAT": 0.001, + "STR": "abc", + "BOOL": True, + "DICT": {"a": 1}, + "LIST": [1, 2, 3], + "SET": {4, 5, 6}, + "TUPLE": (10, 100), + "NONE": None, + } + + tmp_dir.gen( + parameters_file, "class Train:\n foo = 'val1'\n bar = 'val2'\n", + ) + dep = ParamsDependency(Stage(dvc), parameters_file, ["Train.foo"]) + assert dep.read_params() == {"Train.foo": "val1"} + + dep = ParamsDependency(Stage(dvc), parameters_file, ["Train"]) + assert dep.read_params() == { + "Train": {"foo": "val1", "bar": "val2"}, + } + + tmp_dir.gen( + parameters_file, + "x = 4\n" + "config.x = 3\n" + "class Klass:\n" + " def __init__(self):\n" + " self.a = 'val1'\n" + " container.a = 2\n" + " self.container.a = 1\n" + " a = 'val2'\n", + ) + dep = ParamsDependency(Stage(dvc), parameters_file, ["x", "Klass.a"]) + assert dep.read_params() == {"x": 4, "Klass.a": "val1"} + + def test_get_hash_missing_config(dvc): dep = ParamsDependency(Stage(dvc), None, ["foo"]) with pytest.raises(MissingParamsError):