Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

serialize: py: try to track segments of the source #5778

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 90 additions & 53 deletions dvc/utils/serialize/_py.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import ast
import dataclasses
import logging
import sys
from contextlib import contextmanager
from functools import partial
from typing import Any, Optional

from funcy import reraise

Expand All @@ -8,6 +13,8 @@
_PARAMS_KEY = "__params_old_key_for_update__"
_PARAMS_TEXT_KEY = "__params_text_key_for_update__"

logger = logging.getLogger(__name__)


class PythonFileCorruptedError(ParseError):
def __init__(self, path, message="Python file structure is corrupted"):
Expand All @@ -23,7 +30,7 @@ def parse_py(text, path):
with reraise(SyntaxError, PythonFileCorruptedError(path)):
tree = ast.parse(text, filename=path)

result = _ast_tree_to_dict(tree)
result = _ast_tree_to_dict(tree, text)
return result


Expand All @@ -32,8 +39,8 @@ def parse_py_for_update(text, path):
with reraise(SyntaxError, PythonFileCorruptedError(path)):
tree = ast.parse(text, filename=path)

result = _ast_tree_to_dict(tree)
result.update({_PARAMS_KEY: _ast_tree_to_dict(tree, lineno=True)})
result = _ast_tree_to_dict(tree, text)
result.update({_PARAMS_KEY: _ast_tree_to_dict(tree, text, lineno=True)})
result.update({_PARAMS_TEXT_KEY: text})
return result

Expand All @@ -49,16 +56,37 @@ def _dump(data, stream):
old_lines = data[_PARAMS_TEXT_KEY].splitlines(True)

def _update_lines(lines, old_dct, new_dct):
if not isinstance(old_dct, dict):
return lines

for key, value in new_dct.items():
if isinstance(value, dict):
lines = _update_lines(lines, old_dct[key], value)
elif value != old_dct[key]["value"]:
lineno = old_dct[key]["lineno"]
lines[lineno] = lines[lineno].replace(
f" = {old_dct[key]['value']}", f" = {value}"
)
old_value = old_dct.get(key)
if isinstance(old_value, dict) and isinstance(value, dict):
lines = _update_lines(lines, old_value, value)
continue

if isinstance(old_value, Node):
if isinstance(value, dict):
logger.trace("Old %s is %s, new value is of type %s", key, old_value, type(value))
continue
else:
continue

if old_value.value is not None and value == old_value.value:
# we should try to reduce amount of updates
# so if things didn't change at all or are equivalent
# we don't need to dump at all.
continue
elif old_value.lineno is not None and (old_value.segment or old_value.value):
old_segment = " = {}".format(old_value.segment or old_value.value)
new_segment = " = {}".format(value)
lineno = old_value.lineno
logger.trace("updating lineno:", lineno)
line = lines[lineno].replace(old_segment, new_segment)
logger.trace("before: ", lines[lineno])
lines[lineno] = line
logger.trace("after: ", lines[lineno])

return lines

new_lines = _update_lines(old_lines, old_params, new_params)
Expand Down Expand Up @@ -86,7 +114,7 @@ def modify_py(path, fs=None):
yield d


def _ast_tree_to_dict(tree, only_self_params=False, lineno=False):
def _ast_tree_to_dict(tree, source, only_self_params=False, lineno=False):
"""Parses ast trees to dict.

:param tree: ast.Tree
Expand All @@ -99,18 +127,24 @@ def _ast_tree_to_dict(tree, only_self_params=False, lineno=False):
try:
if isinstance(_body, (ast.Assign, ast.AnnAssign)):
result.update(
_ast_assign_to_dict(_body, only_self_params, lineno)
_ast_assign_to_dict(
_body, source, only_self_params, lineno
)
)
elif isinstance(_body, ast.ClassDef):
result.update(
{_body.name: _ast_tree_to_dict(_body, lineno=lineno)}
{
_body.name: _ast_tree_to_dict(
_body, source, lineno=lineno
)
}
)
elif (
isinstance(_body, ast.FunctionDef) and _body.name == "__init__"
):
result.update(
_ast_tree_to_dict(
_body, only_self_params=True, lineno=lineno
_body, source, only_self_params=True, lineno=lineno
)
)
except ValueError:
Expand All @@ -120,43 +154,14 @@ def _ast_tree_to_dict(tree, only_self_params=False, lineno=False):
return result


def _ast_assign_to_dict(assign, only_self_params=False, lineno=False):
result = {}

def _ast_assign_to_dict(assign, source, only_self_params=False, lineno=False):
if isinstance(assign, ast.AnnAssign):
name = _get_ast_name(assign.target, only_self_params)
elif len(assign.targets) == 1:
name = _get_ast_name(assign.targets[0], only_self_params)
else:
raise AttributeError

if isinstance(assign.value, ast.Dict):
value = {}
for key, val in zip(assign.value.keys, assign.value.values):
if lineno:
value[_get_ast_value(key)] = {
"lineno": assign.lineno - 1,
"value": _get_ast_value(val),
}
else:
value[_get_ast_value(key)] = _get_ast_value(val)
elif isinstance(assign.value, ast.List):
value = [_get_ast_value(val) for val in assign.value.elts]
elif isinstance(assign.value, ast.Set):
values = [_get_ast_value(val) for val in assign.value.elts]
value = set(values)
elif isinstance(assign.value, ast.Tuple):
values = [_get_ast_value(val) for val in assign.value.elts]
value = tuple(values)
else:
value = _get_ast_value(assign.value)

if lineno and not isinstance(assign.value, ast.Dict):
result[name] = {"lineno": assign.lineno - 1, "value": value}
else:
result[name] = value

return result
return {name: _get_ast_value(assign.value, source, value_only=not lineno)}


def _get_ast_name(target, only_self_params=False):
Expand All @@ -169,13 +174,45 @@ def _get_ast_name(target, only_self_params=False):
return result


def _get_ast_value(value):
if isinstance(value, ast.Num):
result = value.n
elif isinstance(value, ast.Str):
result = value.s
elif isinstance(value, ast.NameConstant):
result = value.value
def get_source_segment(source, node):
if sys.version_info > (3, 8):
return ast.get_source_segment(source, node)

try:
import astunparse

return astunparse.unparse(node).rstrip()
except:
return None


@dataclasses.dataclass
class Node:
value: Any
lineno: Optional[int]
segment: Optional[str]


def _get_ast_value(node, source=None, value_only: bool = False):
from ast import literal_eval

convert = partial(_get_ast_value, source=source, value_only=value_only)
if isinstance(node, ast.Tuple):
result = tuple(map(convert, node.elts))
elif isinstance(node, ast.Set):
result = set(map(convert, node.elts))
elif isinstance(node, ast.Dict):
result = dict(
(_get_ast_value(k, value_only=True), convert(v))
for k, v in zip(node.keys, node.values)
)
else:
raise ValueError
result = literal_eval(node)
if value_only or not source:
return result

lno = node.lineno - 1
segment = get_source_segment(source, node)
return Node(result, lno, segment)

return result
Empty file.
22 changes: 22 additions & 0 deletions tests/unit/utils/serialize/fixtures/classy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
class Prepare:
split: float = 0.20
seed = 20170428
shuffle_dataset = True
bag = {"mango", "apple", "orange"}


class Featurize:
max_features = 3000
ngrams = 2


class Train:
seed = 123
min_split: str
optimizer: str = "Adam"

def __init__(self):
self.seed = 20170428
self.n_est = 100
self.min_split = 64
self.data = {"key1": "value1", "key2": "value2"}
4 changes: 4 additions & 0 deletions tests/unit/utils/serialize/fixtures/composite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
lst = [1, 2, 3, 4, 1]
bag = {"mango", "apple", "orange"}
tuh_pal = ("t", "u", "p", "l", "e")
data = {"key1": "value1", "key2": "value2"}
13 changes: 13 additions & 0 deletions tests/unit/utils/serialize/fixtures/simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
name = "DVC"
features = 43

optimizer = "Adam"
activation = b"relu"
units = 16

seed: int = 42
noise: float = 0.0001
dropout = 0.5
buckets = set()
default = -1
shuffle = False
50 changes: 50 additions & 0 deletions tests/unit/utils/serialize/test_python.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import textwrap

import pytest

from dvc.utils.serialize import modify_py


@pytest.mark.parametrize(
"val",
[
"1000.0",
"1_000_000.000_000",
"1e4",
"1e+4",
"1e04",
"1.0e4",
"1.0e+4",
"1.0e+04",
"1e-4",
"1.0e-4",
"1.0e-04",
"-1e4",
"-1e04",
"-1e+4",
"-1.0e4",
"-1.0e+4",
"-1.0e+04",
"-1e-4",
"-1.0e-4",
"-1.0e-04",
],
)
def test_modify_override_floats(tmp_dir, val):
skshetry marked this conversation as resolved.
Show resolved Hide resolved
source_fmt = textwrap.dedent(
"""\
threshold: float = {}
epochs = 10
"""
)
param_file = tmp_dir / "params.py"
param_file.write_text(source_fmt.format(val))

with modify_py(param_file) as d:
d["threshold"] = 1e3
assert source_fmt.format("1000.0") == param_file.read_text()

parsed = float(val)
with modify_py(param_file) as d:
d["threshold"] = parsed
assert source_fmt.format(str(parsed)) == param_file.read_text()
Empty file.
40 changes: 40 additions & 0 deletions tests/unit/utils/serialize/test_python_load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from pathlib import Path

from dvc.utils.serialize import load_py

fixtures = Path(__file__).parent / "fixtures"


def mod_to_dict(mod):
return {k: getattr(mod, k) for k in dir(mod) if not k.startswith("_")}


def test_simple():
from .fixtures import simple

assert load_py(fixtures / "simple.py") == mod_to_dict(simple)


def test_composite_data():
from .fixtures import composite

assert load_py(fixtures / "composite.py") == mod_to_dict(composite)


def test_classy():
assert load_py(fixtures / "classy.py") == {
"Featurize": {"max_features": 3000, "ngrams": 2},
"Prepare": {
"bag": {"apple", "mango", "orange"},
"seed": 20170428,
"shuffle_dataset": True,
"split": 0.2,
},
"Train": {
"min_split": 64,
"n_est": 100,
"optimizer": "Adam",
"seed": 20170428,
"data": {"key1": "value1", "key2": "value2"},
},
}