Skip to content

Commit

Permalink
params: Use ast.literal_eval for python params.
Browse files Browse the repository at this point in the history
Add unit tests for `parse_py`

Closes  #6953
  • Loading branch information
daavoo authored and efiop committed Nov 15, 2021
1 parent fdba79c commit 7df6090
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 19 deletions.
26 changes: 7 additions & 19 deletions dvc/utils/serialize/_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,22 +134,22 @@ def _ast_assign_to_dict(assign, only_self_params=False, lineno=False):
value = {}
for key, val in zip(assign.value.keys, assign.value.values):
if lineno:
value[_get_ast_value(key)] = {
value[ast.literal_eval(key)] = {
"lineno": assign.lineno - 1,
"value": _get_ast_value(val),
"value": ast.literal_eval(val),
}
else:
value[_get_ast_value(key)] = _get_ast_value(val)
value[ast.literal_eval(key)] = ast.literal_eval(val)
elif isinstance(assign.value, ast.List):
value = [_get_ast_value(val) for val in assign.value.elts]
value = [ast.literal_eval(val) for val in assign.value.elts]
elif isinstance(assign.value, ast.Set):
values = [_get_ast_value(val) for val in assign.value.elts]
values = [ast.literal_eval(val) for val in assign.value.elts]
value = set(values)
elif isinstance(assign.value, ast.Tuple):
values = [_get_ast_value(val) for val in assign.value.elts]
values = [ast.literal_eval(val) for val in assign.value.elts]
value = tuple(values)
else:
value = _get_ast_value(assign.value)
value = ast.literal_eval(assign.value)

if lineno and not isinstance(assign.value, ast.Dict):
result[name] = {"lineno": assign.lineno - 1, "value": value}
Expand All @@ -167,15 +167,3 @@ def _get_ast_name(target, only_self_params=False):
else:
raise AttributeError
return result


def _get_ast_value(value):
if isinstance(value, ast.Num):
result = value.n
elif isinstance(value, ast.Str):
result = value.s
elif isinstance(value, ast.NameConstant):
result = value.value
else:
raise ValueError
return result
45 changes: 45 additions & 0 deletions tests/unit/utils/serialize/test_python.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pytest

from dvc.utils.serialize import parse_py


@pytest.mark.parametrize(
"text,result",
[
("BOOL = True", {"BOOL": True}),
("INT = 5", {"INT": 5}),
("FLOAT = 0.001", {"FLOAT": 0.001}),
("STR = 'abc'", {"STR": "abc"}),
("DICT = {'a': 1, 'b': 2}", {"DICT": {"a": 1, "b": 2}}),
("LIST = [1, 2, 3]", {"LIST": [1, 2, 3]}),
("SET = {1, 2, 3}", {"SET": {1, 2, 3}}),
("TUPLE = (10, 100)", {"TUPLE": (10, 100)}),
("NONE = None", {"NONE": None}),
("UNARY_OP = -1", {"UNARY_OP": -1}),
(
"""class TrainConfig:
EPOCHS = 70
def __init__(self):
self.layers = 5
self.layers = 9 # TrainConfig.layers param will be 9
bar = 3 # Will NOT be found since it's locally scoped
""",
{"TrainConfig": {"EPOCHS": 70, "layers": 9}},
),
],
)
def test_parse_valid_types(text, result):
assert parse_py(text, "foo") == result


@pytest.mark.parametrize(
"text",
[
"CONSTRUCTOR = dict(a=1, b=2)",
"SUM = 1 + 2",
],
)
def test_parse_invalid_types(text):
assert parse_py(text, "foo") == {}

0 comments on commit 7df6090

Please sign in to comment.