diff --git a/dvc/utils/serialize/__init__.py b/dvc/utils/serialize/__init__.py index e42cb8cae3..bd8ca43cf5 100644 --- a/dvc/utils/serialize/__init__.py +++ b/dvc/utils/serialize/__init__.py @@ -2,6 +2,7 @@ from typing import DefaultDict from ._common import * # noqa: F403 +from ._ini import * # noqa: F403 from ._json import * # noqa: F403 from ._py import * # noqa: F403 from ._toml import * # noqa: F403 @@ -10,13 +11,27 @@ LOADERS: DefaultDict[str, LoaderFn] = defaultdict( # noqa: F405 lambda: load_yaml # noqa: F405 ) -LOADERS.update({".toml": load_toml, ".json": load_json, ".py": load_py}) # noqa: F405 +LOADERS.update( + { + ".toml": load_toml, # noqa: F405 + ".json": load_json, # noqa: F405 + ".py": load_py, # noqa: F405 + ".cfg": load_ini, # noqa: F405 + ".ini": load_ini, # noqa: F405 + } +) PARSERS: DefaultDict[str, ParserFn] = defaultdict( # noqa: F405 lambda: parse_yaml # noqa: F405 ) PARSERS.update( - {".toml": parse_toml, ".json": parse_json, ".py": parse_py} # noqa: F405 + { + ".toml": parse_toml, # noqa: F405 + ".json": parse_json, # noqa: F405 + ".py": parse_py, # noqa: F405 + ".cfg": parse_ini, # noqa: F405 + ".ini": parse_ini, # noqa: F405 + } ) @@ -29,7 +44,15 @@ def load_path(fs_path, fs, **kwargs): DUMPERS: DefaultDict[str, DumperFn] = defaultdict( # noqa: F405 lambda: dump_yaml # noqa: F405 ) -DUMPERS.update({".toml": dump_toml, ".json": dump_json, ".py": dump_py}) # noqa: F405 +DUMPERS.update( + { + ".toml": dump_toml, # noqa: F405 + ".json": dump_json, # noqa: F405 + ".py": dump_py, # noqa: F405 + ".cfg": dump_ini, # noqa: F405 + ".ini": dump_ini, # noqa: F405 + } +) MODIFIERS: DefaultDict[str, ModifierFn] = defaultdict( # noqa: F405 lambda: modify_yaml # noqa: F405 @@ -39,5 +62,7 @@ def load_path(fs_path, fs, **kwargs): ".toml": modify_toml, # noqa: F405 ".json": modify_json, # noqa: F405 ".py": modify_py, # noqa: F405 + ".cfg": modify_ini, # noqa: F405 + ".ini": modify_ini, # noqa: F405 } ) diff --git a/dvc/utils/serialize/_ini.py 
import json
import re
from ast import literal_eval
from contextlib import contextmanager
from typing import Any, Dict

from funcy import reraise

from ._common import ParseError, _dump_data, _load_data, _modify_data

# One component of a dotted section path: single-quoted, double-quoted or
# bare text, followed by either a "." separator or the end of the string.
_PATH_PART = re.compile(r"(?:'([^']*)'|\"([^\"]*)\"|([^.]*))(?:[.]|$)")


class ConfigFileCorruptedError(ParseError):
    """Raised when INI text cannot be turned into a nested config mapping."""

    def __init__(self, path):
        super().__init__(path, "Config file structure is corrupted")


def split_path(path: str):
    """Split a dotted section name into its components.

    Quoted components may themselves contain dots, e.g.
    ``split_path('foo."bar.test"') == ["foo", "bar.test"]``. The surrounding
    quotes are dropped from the returned parts.
    """
    offset = 0
    result = []
    for match in _PATH_PART.finditer(path):
        # Internal invariant, not input validation: every component is
        # expected to start exactly where the previous one ended.
        assert match.start() == offset, f"Malformed path: {path!r} in config"
        offset = match.end()
        # Exactly one of the three alternatives captured something.
        result.append(next(g for g in match.groups() if g is not None))
        if offset == len(path):
            # Stop before finditer yields a trailing empty match at the end.
            break
    return result


def join_path(path):
    """Join path components into a dotted section name.

    Components containing a dot are written as their ``repr`` so that
    ``split_path`` can tell them apart from separators.
    """
    # This is required to handle sections like `[foo."bar.baz".qux]`
    return ".".join(repr(x) if "." in x else x for x in path)


def config_literal_eval(s: str):
    """Convert an INI option string into a Python value.

    The text is tried as a Python literal first, then as JSON (which covers
    ``true``/``false``/``null``). If neither parser accepts it, the original
    string is returned unchanged.
    """
    try:
        return literal_eval(s)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # literal_eval is documented to raise any of these on malformed input
        # (e.g. TypeError for `{[1]: 2}`); all of them mean "not a literal".
        pass
    try:
        return json.loads(s)
    except (ValueError, RecursionError):
        return s


def config_literal_dump(v: Any):
    """Render a value as INI option text that ``config_literal_eval`` reads back.

    A string is written verbatim only when it survives the round trip as-is;
    everything else (including strings that look like literals, such as
    ``"42"``) is JSON-encoded.
    """
    if (
        isinstance(v, str)
        # configparser strips surrounding whitespace when reading values, so
        # such strings must be quoted to round-trip.
        and v == v.strip()
        and config_literal_eval(v) == v
    ):
        return str(v)
    return json.dumps(v)


def flatten_sections(root: Dict[str, Any]) -> Dict[str, Any]:
    """Flatten a nested mapping into ``{section_name: {option: value}}``.

    Nested dicts become sections named by their joined path; non-dict values
    become options of the section they sit in.

    NOTE: non-dict values placed directly in ``root`` are collected under the
    empty section name and then discarded, so they are not part of the result.
    """
    res: Dict[str, Dict[str, Any]] = {}

    def rec(d, path):
        name = join_path(path)
        # Register the section before descending so parents precede children.
        res.setdefault(name, {})
        options = {}
        for k, v in d.items():
            if isinstance(v, dict):
                rec(v, (*path, k))
            else:
                options[k] = v
        res[name].update(options)

    rec(root, ())
    res.pop("", None)
    return res


def load_ini(path, fs=None, **kwargs):
    """Load the INI file at ``path`` via ``_load_data`` using ``parse_ini``.

    Extra keyword arguments are accepted but not forwarded.
    """
    return _load_data(path, parser=parse_ini, fs=fs)


def parse_ini(text, path, **kwargs):
    """Parse INI ``text`` into a nested dict keyed by dotted section parts.

    Option names keep their case, interpolation is disabled, and each value
    goes through ``config_literal_eval``.

    Raises:
        ConfigFileCorruptedError: if configparser rejects the text, or if a
            section path runs through a name already used for a plain option.
    """
    import configparser

    with reraise(configparser.Error, ConfigFileCorruptedError(path)):
        parser = configparser.ConfigParser(interpolation=None)
        # Keep option names case-sensitive (the default lower-cases them).
        parser.optionxform = str  # type: ignore[assignment,method-assign]
        parser.read_string(text)
        config: Dict = {}
        for section in parser.sections():
            current = config
            for part in split_path(section):
                current = current.setdefault(part, {})
                if not isinstance(current, dict):
                    # e.g. `[a]` defines option `b` and a later section is
                    # named `[a.b]`: there is no mapping to descend into.
                    raise ConfigFileCorruptedError(path)
            current.update(
                {k: config_literal_eval(v) for k, v in parser.items(section)}
            )

    return config


def _dump(data, stream):
    """Write nested ``data`` to ``stream`` in INI format.

    Sections that end up with no options of their own are not written.
    """
    import configparser

    prepared = flatten_sections(data)

    parser = configparser.ConfigParser(interpolation=None)
    # Keep option names case-sensitive (the default lower-cases them).
    parser.optionxform = str  # type: ignore[assignment,method-assign]
    for section_name, section in prepared.items():
        content = {k: config_literal_dump(v) for k, v in section.items()}
        if content:
            parser.add_section(section_name)
            parser[section_name].update(content)

    return parser.write(stream)


def dump_ini(path, data, fs=None, **kwargs):
    """Dump ``data`` to the INI file at ``path`` via ``_dump_data``."""
    return _dump_data(path, data, dumper=_dump, fs=fs, **kwargs)


@contextmanager
def modify_ini(path, fs=None):
    """
    Yield the parsed contents of the INI file at ``path`` for modification.

    NOTE: As configparser does not parse comments, those will be stripped
    from the modified config file
    """
    with _modify_data(path, parse_ini, _dump, fs=fs) as d:
        yield d
def test_apply_ini_overrides(tmp_dir):
    """Hydra overrides are applied to an INI params file section by section."""
    from dvc.utils.hydra import apply_overrides

    initial_params = {
        "section": {
            "foo": [{"bar": 1}, {"baz": 2}],
            "goo": {"bag": 3.0},
            "lorem": False,
        },
    }
    expected_params = {
        "section": {
            "foo": "baz",
            "goo": "bar",
            "lorem": False,
            "items": [1, 2, 3],
        }
    }

    ini_file = tmp_dir / "params.ini"
    ini_file.dump(initial_params)

    # Valid overrides replace existing options and append a new one.
    apply_overrides(
        path=ini_file.name,
        overrides=[
            "section.foo=baz",
            "section.goo=bar",
            "+section.items=[1,2,3]",
        ],
    )
    assert ini_file.parse() == expected_params

    # Invalid overrides are rejected.
    with pytest.raises(InvalidArgumentError):
        apply_overrides(
            path=ini_file.name,
            overrides=["foobar=2", "lorem=3,2", "+lorem=3", "foo[0]=bar"],
        )
import pytest

from dvc.utils.serialize import ConfigFileCorruptedError


def test_update(tmp_dir):
    """Modifying one option rewrites the file without its comment line."""
    from dvc.utils.serialize._ini import modify_ini

    original_text = """\
#A Title
[section.foo]
bar = 42
baz = [1, 2]
ref = ${section.foo.bar}

[section.foo."bar.baz".qux]
value = 42
"""
    expected_text = """\
[section.foo]
bar = 21
baz = [1, 2]
ref = ${section.foo.bar}

[section.foo.'bar.baz'.qux]
value = 42

"""
    tmp_dir.gen("params.ini", original_text)

    with modify_ini("params.ini") as params:
        foo = params["section"]["foo"]
        foo["bar"] = foo["bar"] // 2

    written_text = (tmp_dir / "params.ini").read_text()
    assert written_text == expected_text


def test_parse_error():
    """Options that appear before any section header are a parse error."""
    from dvc.utils.serialize._ini import parse_ini

    headerless_text = "# A Title [foo]\nbar = 42# meaning of life\nbaz = [1, 2]\n"

    with pytest.raises(ConfigFileCorruptedError):
        parse_ini(headerless_text, ".")


def test_split_path():
    """Dots inside a quoted component do not split it."""
    from dvc.utils.serialize._ini import split_path

    cases = [
        ("foo.bar", ["foo", "bar"]),
        ('foo."bar.test"', ["foo", "bar.test"]),
    ]
    for dotted, parts in cases:
        assert split_path(dotted) == parts