diff --git a/dvc/repo/experiments/queue/base.py b/dvc/repo/experiments/queue/base.py index 4c5309c1755..d61bf34c175 100644 --- a/dvc/repo/experiments/queue/base.py +++ b/dvc/repo/experiments/queue/base.py @@ -19,22 +19,28 @@ from funcy import cached_property -from dvc.dependency.param import MissingParamsError from dvc.env import DVCLIVE_RESUME from dvc.exceptions import DvcException -from dvc.ui import ui - -from ..exceptions import CheckpointExistsError, ExperimentExistsError -from ..executor.base import ( +from dvc.repo.experiments.exceptions import ( + CheckpointExistsError, + ExperimentExistsError, +) +from dvc.repo.experiments.executor.base import ( EXEC_PID_DIR, EXEC_TMP_DIR, BaseExecutor, ExecutorResult, ) -from ..executor.local import WorkspaceExecutor -from ..refs import EXEC_BASELINE, EXEC_HEAD, EXEC_MERGE, ExpRefInfo -from ..stash import ExpStash, ExpStashEntry -from ..utils import exp_refs_by_rev, scm_locked +from dvc.repo.experiments.executor.local import WorkspaceExecutor +from dvc.repo.experiments.refs import ( + EXEC_BASELINE, + EXEC_HEAD, + EXEC_MERGE, + ExpRefInfo, +) +from dvc.repo.experiments.stash import ExpStash, ExpStashEntry +from dvc.repo.experiments.utils import exp_refs_by_rev, scm_locked +from dvc.ui import ui if TYPE_CHECKING: from scmrepo.git import Git @@ -283,7 +289,7 @@ def logs( def _stash_exp( self, *args, - params: Optional[dict] = None, + params: Optional[Dict[str, List[str]]] = None, resume_rev: Optional[str] = None, baseline_rev: Optional[str] = None, branch: Optional[str] = None, @@ -292,10 +298,9 @@ def _stash_exp( ) -> QueueEntry: """Stash changes from the workspace as an experiment. - Arguments: - params: Optional dictionary of parameter values to be used. - Values take priority over any parameters specified in the - user's workspace. + Args: + params: Dict mapping paths to `Hydra Override`_ patterns, + provided via `exp run --set-param`. resume_rev: Optional checkpoint resume rev. baseline_rev: Optional baseline rev for this experiment, defaults to the current SCM rev. @@ -305,6 +310,9 @@ def _stash_exp( name: Optional experiment name. If specified this will be used as the human-readable name in the experiment branch ref. Has no effect of branch is specified. + + .. _Hydra Override: + https://hydra.cc/docs/next/advanced/override_grammar/basic/ """ with self.scm.detach_head(client="dvc") as orig_head: stash_head = orig_head @@ -508,22 +516,22 @@ def _format_new_params_msg(new_params, config_path): f"from '{config_path}': {param_list}" ) - def _update_params(self, params: dict): - """Update experiment params files with the specified values.""" - from dvc.utils.collections import NewParamsFound, merge_params - from dvc.utils.serialize import MODIFIERS + def _update_params(self, params: Dict[str, List[str]]): + """Update param files with the provided `Hydra Override`_ patterns. + + Args: + params: Dict mapping paths to `Hydra Override`_ patterns, + provided via `exp run --set-param`. + .. _Hydra Override: + https://hydra.cc/docs/next/advanced/override_grammar/basic/ + """ logger.debug("Using experiment params '%s'", params) - for path in params: - suffix = self.repo.fs.path.suffix(path).lower() - modify_data = MODIFIERS[suffix] - with modify_data(path, fs=self.repo.fs) as data: - try: - merge_params(data, params[path], allow_new=False) - except NewParamsFound as e: - msg = self._format_new_params_msg(e.new_params, path) - raise MissingParamsError(msg) + from dvc.utils.hydra import apply_overrides + + for path, overrides in params.items(): + apply_overrides(path, overrides) # Force params file changes to be staged in git # Otherwise in certain situations the changes to params file may be diff --git a/dvc/repo/experiments/run.py b/dvc/repo/experiments/run.py index b1a601b0f70..5853293417a 100644 --- a/dvc/repo/experiments/run.py +++ b/dvc/repo/experiments/run.py @@ -3,7 +3,7 @@ from dvc.repo import locked from dvc.ui import ui -from dvc.utils.cli_parse import loads_param_overrides +from dvc.utils.cli_parse import to_path_overrides logger = logging.getLogger(__name__) @@ -31,7 +31,7 @@ def run( return repo.experiments.reproduce_celery(entries, jobs=jobs) if params: - params = loads_param_overrides(params) + params = to_path_overrides(params) if queue: if not kwargs.get("checkpoint_resume", None): diff --git a/dvc/utils/_benedict.py b/dvc/utils/_benedict.py deleted file mode 100644 index 219c96e674a..00000000000 --- a/dvc/utils/_benedict.py +++ /dev/null @@ -1,26 +0,0 @@ -""" -Rollbacks monkeypatching of `json.encoder` by benedict. - -It monkeypatches json to use Python-based encoder instead of C-based -during import. - -We rollback that monkeypatch by keeping reference to that C-based -encoder and reinstate them after importing benedict. -See: https://github.com/iterative/dvc/issues/6423 - https://github.com/fabiocaccamo/python-benedict/issues/62 -and the source of truth: -https://github.com/fabiocaccamo/python-benedict/blob/c98c471065/benedict/dicts/__init__.py#L282-L285 -""" -from json import encoder - -try: - c_make_encoder = encoder.c_make_encoder # type: ignore[attr-defined] -except AttributeError: - c_make_encoder = None - - -from benedict import benedict # noqa: E402 - -encoder.c_make_encoder = c_make_encoder # type: ignore[attr-defined] -# Please import benedict from here lazily -__all__ = ["benedict"] diff --git a/dvc/utils/cli_parse.py b/dvc/utils/cli_parse.py index c2885484681..4972ec53f93 100644 --- a/dvc/utils/cli_parse.py +++ b/dvc/utils/cli_parse.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Any, Dict, Iterable, List +from typing import Dict, Iterable, List def parse_params(path_params: Iterable[str]) -> List[Dict[str, List[str]]]: @@ -17,35 +17,22 @@ def parse_params(path_params: Iterable[str]) -> List[Dict[str, List[str]]]: return [{path: params} for path, params in ret.items()] -def loads_param_overrides( +def to_path_overrides( path_params: Iterable[str], -) -> Dict[str, Dict[str, Any]]: - """Loads the content of params from the cli as Python object.""" - from ruamel.yaml import YAMLError - +) -> Dict[str, List[str]]: + """Group overrides by path""" from dvc.dependency.param import ParamsDependency - from dvc.exceptions import InvalidArgumentError - - from .serialize import loads_yaml - - ret: Dict[str, Dict[str, Any]] = defaultdict(dict) + path_overrides = defaultdict(list) for path_param in path_params: - param_name, _, param_value = path_param.partition("=") - if not param_value: - raise InvalidArgumentError( - f"Must provide a value for parameter '{param_name}'" - ) - path, _, param_name = param_name.partition(":") - if not param_name: - param_name = path + + path_and_name = path_param.partition("=")[0] + if ":" not in path_and_name: + override = path_param path = ParamsDependency.DEFAULT_PARAMS_FILE + else: + path, _, override = path_param.partition(":") - try: - ret[path][param_name] = loads_yaml(param_value) - except (ValueError, YAMLError): - raise InvalidArgumentError( - f"Invalid parameter value for '{param_name}': '{param_value}" - ) + path_overrides[path].append(override) - return ret + return dict(path_overrides) diff --git a/dvc/utils/collections.py b/dvc/utils/collections.py index 8e2b62c66dc..8b36e92e575 100644 --- a/dvc/utils/collections.py +++ b/dvc/utils/collections.py @@ -3,16 +3,6 @@ from functools import wraps from typing import Callable, Dict, Iterable, List, TypeVar, Union -from dvc.exceptions import DvcException - - -class NewParamsFound(DvcException): - """Thrown if new params were found during merge_params""" - - def __init__(self, new_params: List, *args): - self.new_params = new_params - super().__init__("New params found during merge", *args) - def apply_diff(src, dest): """Recursively apply changes from src to dest. @@ -61,6 +51,53 @@ def is_same_type(a, b): ) +def to_omegaconf(item): + """ + Some parsers return custom classes (i.e. parse_yaml_for_update) + that can mess up with omegaconf logic. + Cast the custom classes to Python primitives. + """ + if isinstance(item, dict): + item = {k: to_omegaconf(v) for k, v in item.items()} + elif isinstance(item, list): + item = [to_omegaconf(x) for x in item] + return item + + +def remove_missing_keys(src, to_update): + keys = list(src.keys()) + for key in keys: + if key not in to_update: + del src[key] + elif isinstance(src[key], dict): + remove_missing_keys(src[key], to_update[key]) + + return src + + +def _merge_item(d, key, value): + if key in d: + item = d.get(key, None) + if isinstance(item, dict) and isinstance(value, dict): + merge_dicts(item, value) + else: + d[key] = value + else: + d[key] = value + + +def merge_dicts(src: Dict, to_update: Dict) -> Dict: + """Recursively merges dictionaries. + + Args: + src (dict): source dictionary of parameters + to_update (dict): dictionary of parameters to merge into src + """ + for key, value in to_update.items(): + _merge_item(src, key, value) + return src + + def ensure_list(item: Union[Iterable[str], str, None]) -> List[str]: if item is None: return [] @@ -79,31 +116,6 @@ def chunk_dict(d: Dict[_KT, _VT], size: int = 1) -> List[Dict[_KT, _VT]]: return [{key: d[key] for key in chunk} for chunk in chunks(size, d)] -def merge_params(src: Dict, to_update: Dict, allow_new: bool = True) -> Dict: - """ - Recursively merges params with benedict's syntax support in-place. - - Args: - src (dict): source dictionary of parameters - to_update (dict): dictionary of parameters to merge into src - allow_new (bool): if False, raises an error if new keys would be - added to src - """ - from ._benedict import benedict - - data = benedict(src) - - if not allow_new: - new_params = list( - set(to_update.keys()) - set(data.keypaths(indexes=True)) - ) - if new_params: - raise NewParamsFound(new_params) - - data.merge(to_update, overwrite=True) - return src - - class _NamespacedDict(dict): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/dvc/utils/hydra.py b/dvc/utils/hydra.py new file mode 100644 index 00000000000..e03be3e1c62 --- /dev/null +++ b/dvc/utils/hydra.py @@ -0,0 +1,53 @@ +from pathlib import Path +from typing import TYPE_CHECKING, List + +from hydra._internal.config_loader_impl import ConfigLoaderImpl +from hydra.core.override_parser.overrides_parser import OverridesParser +from hydra.errors import ConfigCompositionException, OverrideParseException +from hydra.types import RunMode +from omegaconf import OmegaConf + +from dvc.exceptions import InvalidArgumentError + +from .collections import merge_dicts, remove_missing_keys, to_omegaconf +from .serialize import MODIFIERS + +if TYPE_CHECKING: + from dvc.types import StrPath + + +def apply_overrides(path: "StrPath", overrides: List[str]) -> None: + """Update `path` params with the provided `Hydra Override`_ patterns. + + Args: + overrides: List of `Hydra Override`_ patterns. + + .. _Hydra Override: + https://hydra.cc/docs/next/advanced/override_grammar/basic/ + """ + suffix = Path(path).suffix.lower() + + hydra_errors = (ConfigCompositionException, OverrideParseException) + + modify_data = MODIFIERS[suffix] + with modify_data(path) as original_data: + try: + parser = OverridesParser.create() + parsed = parser.parse_overrides(overrides=overrides) + ConfigLoaderImpl.validate_sweep_overrides_legal( + parsed, run_mode=RunMode.RUN, from_shell=True + ) + + new_data = OmegaConf.create( + to_omegaconf(original_data), + flags={"allow_objects": True}, + ) + OmegaConf.set_struct(new_data, True) + # pylint: disable=protected-access + ConfigLoaderImpl._apply_overrides_to_config(parsed, new_data) + new_data = OmegaConf.to_object(new_data) + except hydra_errors as e: + raise InvalidArgumentError("Invalid `--set-param` value") from e + + merge_dicts(original_data, new_data) + remove_missing_keys(original_data, new_data) diff --git a/setup.cfg b/setup.cfg index bf55af6ef74..d81eeb37f3c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -60,7 +60,6 @@ install_requires = dpath>=2.0.2,<3 shtab>=1.3.4,<2 rich>=10.13.0 - python-benedict>=0.24.2 pyparsing>=2.4.7 typing-extensions>=3.7.4 fsspec[http]>=2021.10.1 @@ -70,6 +69,7 @@ install_requires = dvc-task==0.1.0 dvclive>=0.10.0 dvc-data==0.1.8 + hydra-core>=1.1.0 [options.extras_require] all = diff --git a/tests/func/experiments/conftest.py b/tests/func/experiments/conftest.py index cb295c8160d..cffecfb10e6 100644 --- a/tests/func/experiments/conftest.py +++ b/tests/func/experiments/conftest.py @@ -133,3 +133,17 @@ def http_auth_patch(mocker): @pytest.fixture(params=[True, False]) def workspace(request, test_queue) -> bool: # noqa return request.param + + +@pytest.fixture +def params_repo(tmp_dir, scm, dvc): + (tmp_dir / "params.yaml").dump( + {"foo": [{"bar": 1}, {"baz": 2}], "goo": {"bag": 3.0}, "lorem": False} + ) + dvc.run( + cmd="echo foo", + params=["params.yaml:"], + name="foo", + ) + scm.add(["dvc.yaml", "dvc.lock", "copy.py", "params.yaml"]) + scm.commit("init") diff --git a/tests/func/experiments/test_experiments.py b/tests/func/experiments/test_experiments.py index 09eb725f003..0f67a1fbc21 100644 --- a/tests/func/experiments/test_experiments.py +++ b/tests/func/experiments/test_experiments.py @@ -7,7 +7,7 @@ from funcy import first from dvc.dvcfile import PIPELINE_FILE -from dvc.exceptions import DvcException, ReproductionError +from dvc.exceptions import ReproductionError from dvc.repo.experiments.queue.base import BaseStashQueue from dvc.repo.experiments.utils import exp_refs_by_rev from dvc.scm import resolve_rev @@ -107,69 +107,17 @@ def test_failed_exp_workspace( @pytest.mark.parametrize( "changes, expected", [ - [["foo=baz"], "{foo: baz, goo: {bag: 3}, lorem: false}"], - [["foo=baz", "goo=bar"], "{foo: baz, goo: bar, lorem: false}"], - [ - ["goo.bag=4"], - "{foo: [bar: 1, baz: 2], goo: {bag: 4}, lorem: false}", - ], - [["foo[0]=bar"], "{foo: [bar, baz: 2], goo: {bag: 3}, lorem: false}"], - [ - ["foo[1].baz=3"], - "{foo: [bar: 1, baz: 3], goo: {bag: 3}, lorem: false}", - ], - [ - ["foo[1]=[baz, goo]"], - "{foo: [bar: 1, [baz, goo]], goo: {bag: 3}, lorem: false}", - ], + [["foo=baz"], "foo: baz\ngoo:\n bag: 3.0\nlorem: false"], + [["params.yaml:foo=baz"], "foo: baz\ngoo:\n bag: 3.0\nlorem: false"], ], ) -def test_modify_params(tmp_dir, scm, dvc, mocker, changes, expected): - tmp_dir.gen("copy.py", COPY_SCRIPT) - tmp_dir.gen( - "params.yaml", "{foo: [bar: 1, baz: 2], goo: {bag: 3}, lorem: false}" - ) - stage = dvc.run( - cmd="python copy.py params.yaml metrics.yaml", - metrics_no_cache=["metrics.yaml"], - params=["foo", "goo", "lorem"], - name="copy-file", - ) - scm.add(["dvc.yaml", "dvc.lock", "copy.py", "params.yaml", "metrics.yaml"]) - scm.commit("init") - - new_mock = mocker.spy(dvc.experiments, "new") - results = dvc.experiments.run(stage.addressing, params=changes) - exp = first(results) - - new_mock.assert_called_once() - fs = scm.get_fs(exp) - with fs.open("metrics.yaml", mode="r") as fobj: +def test_modify_params(params_repo, dvc, changes, expected): + dvc.experiments.run(params=changes) + # pylint: disable=unspecified-encoding + with open("params.yaml", mode="r") as fobj: assert fobj.read().strip() == expected -@pytest.mark.parametrize( - "changes", - [["lorem.ipsum=3"], ["foo[0].bazar=3"]], -) -def test_add_params(tmp_dir, scm, dvc, changes): - tmp_dir.gen("copy.py", COPY_SCRIPT) - tmp_dir.gen( - "params.yaml", "{foo: [bar: 1, baz: 2], goo: {bag: 3}, lorem: false}" - ) - stage = dvc.run( - cmd="python copy.py params.yaml metrics.yaml", - metrics_no_cache=["metrics.yaml"], - params=["foo", "goo", "lorem"], - name="copy-file", - ) - scm.add(["dvc.yaml", "dvc.lock", "copy.py", "params.yaml", "metrics.yaml"]) - scm.commit("init") - - with pytest.raises(DvcException): - dvc.experiments.run(stage.addressing, params=changes) - - def test_apply(tmp_dir, scm, dvc, exp_stage): from dvc.exceptions import InvalidArgumentError from dvc.repo.experiments.exceptions import ApplyConflictError diff --git a/tests/func/utils/test_hydra.py b/tests/func/utils/test_hydra.py new file mode 100644 index 00000000000..5fc283cc264 --- /dev/null +++ b/tests/func/utils/test_hydra.py @@ -0,0 +1,97 @@ +import pytest + +from dvc.exceptions import InvalidArgumentError +from dvc.utils.hydra import apply_overrides + + +@pytest.mark.parametrize("suffix", ["yaml", "toml", "json"]) +@pytest.mark.parametrize( + "overrides, expected", + [ + # Overriding + (["foo=baz"], {"foo": "baz", "goo": {"bag": 3.0}, "lorem": False}), + (["foo=baz", "goo=bar"], {"foo": "baz", "goo": "bar", "lorem": False}), + ( + ["foo.0=bar"], + {"foo": ["bar", {"baz": 2}], "goo": {"bag": 3.0}, "lorem": False}, + ), + ( + ["foo.1.baz=3"], + { + "foo": [{"bar": 1}, {"baz": 3}], + "goo": {"bag": 3.0}, + "lorem": False, + }, + ), + ( + ["goo.bag=4.0"], + { + "foo": [{"bar": 1}, {"baz": 2}], + "goo": {"bag": 4.0}, + "lorem": False, + }, + ), + ( + ["++goo={bag: 1, b: 2}"], + { + "foo": [{"bar": 1}, {"baz": 2}], + "goo": {"bag": 1, "b": 2}, + "lorem": False, + }, + ), + # 6129 + (["foo="], {"foo": "", "goo": {"bag": 3.0}, "lorem": False}), + # 6129 + (["foo=null"], {"foo": None, "goo": {"bag": 3.0}, "lorem": False}), + # 5868 + ( + ["foo=1992-11-20"], + {"foo": "1992-11-20", "goo": {"bag": 3.0}, "lorem": False}, + ), + # 5868 + ( + ["foo='1992-11-20'"], + {"foo": "1992-11-20", "goo": {"bag": 3.0}, "lorem": False}, + ), + # Appending + ( + ["+a=1"], + { + "foo": [{"bar": 1}, {"baz": 2}], + "goo": {"bag": 3.0}, + "lorem": False, + "a": 1, + }, + ), + # Removing + (["~foo"], {"goo": {"bag": 3.0}, "lorem": False}), + ], +) +def test_apply_overrides(tmp_dir, suffix, overrides, expected): + if suffix == "toml": + if overrides == ["foo.0=bar"]: + # TOML dumper breaks when overriding a dict with a string. + pytest.xfail() + elif overrides == ["foo=null"]: + # TOML will discardd null value. + expected = {"goo": {"bag": 3.0}, "lorem": False} + + params_file = tmp_dir / f"params.{suffix}" + params_file.dump( + {"foo": [{"bar": 1}, {"baz": 2}], "goo": {"bag": 3.0}, "lorem": False} + ) + apply_overrides(path=params_file.name, overrides=overrides) + assert params_file.parse() == expected + + +@pytest.mark.parametrize( + "overrides", + [["foobar=2"], ["lorem=3,2"], ["+lorem=3"], ["foo[0]=bar"]], +) +def test_invalid_overrides(tmp_dir, overrides): + params_file = tmp_dir / "params.yaml" + params_file.dump( + {"foo": [{"bar": 1}, {"baz": 2}], "goo": {"bag": 3.0}, "lorem": False} + ) + with pytest.raises(InvalidArgumentError): + apply_overrides(path=params_file.name, overrides=overrides) diff --git a/tests/unit/utils/test_cli_parse.py b/tests/unit/utils/test_cli_parse.py index 267047dc931..d2875a02474 100644 --- a/tests/unit/utils/test_cli_parse.py +++ b/tests/unit/utils/test_cli_parse.py @@ -1,4 +1,6 @@ -from dvc.utils.cli_parse import parse_params +import pytest + +from dvc.utils.cli_parse import parse_params, to_path_overrides def test_parse_params(): @@ -18,3 +20,23 @@ def test_parse_params(): {"file2": ["param2"]}, {"file3": []}, ] + + +@pytest.mark.parametrize( + "params,expected", + [ + (["foo=1"], {"params.yaml": ["foo=1"]}), + (["foo={bar: 1}"], {"params.yaml": ["foo={bar: 1}"]}), + (["foo.0=bar"], {"params.yaml": ["foo.0=bar"]}), + (["params.json:foo={bar: 1}"], {"params.json": ["foo={bar: 1}"]}), + ( + ["params.json:foo={bar: 1}", "baz=2", "goo=3"], + { + "params.json": ["foo={bar: 1}"], + "params.yaml": ["baz=2", "goo=3"], + }, + ), + ], +) +def test_to_path_overrides(params, expected): + assert to_path_overrides(params) == expected diff --git a/tests/unit/utils/test_collections.py b/tests/unit/utils/test_collections.py index c1eaef2655c..1a4918e7282 100644 --- a/tests/unit/utils/test_collections.py +++ b/tests/unit/utils/test_collections.py @@ -1,6 +1,5 @@ # pylint: disable=unidiomatic-typecheck import json -from json import encoder from unittest.mock import create_autospec import pytest @@ -8,7 +7,9 @@ from dvc.utils.collections import ( apply_diff, chunk_dict, - merge_params, + merge_dicts, + remove_missing_keys, + to_omegaconf, validate, ) from dvc.utils.serialize import dumps_yaml @@ -153,6 +154,24 @@ def is_serializable(d): return True +def test_to_omegaconf(): + class CustomDict(dict): + pass + + class CustomList(list): + pass + + data = { + "foo": CustomDict(bar=1, bag=CustomList([1, 2])), + "goo": CustomList([CustomDict(goobar=1)]), + } + new_data = to_omegaconf(data) + assert not isinstance(new_data["foo"], CustomDict) + assert not isinstance(new_data["foo"]["bag"], CustomList) + assert not isinstance(new_data["goo"], CustomList) + assert not isinstance(new_data["goo"][0], CustomDict) + + @pytest.mark.parametrize( "changes, expected", [ @@ -162,11 +181,11 @@ def is_serializable(d): {"foo": "baz", "goo": "bar", "lorem": False}, ], [ - {"goo.bag": 4}, + {"goo": {"bag": 4}}, {"foo": {"bar": 1, "baz": 2}, "goo": {"bag": 4}, "lorem": False}, ], [ - {"foo[0]": "bar"}, + {"foo": {"bar": 1, "baz": 2, 0: "bar"}}, { "foo": {"bar": 1, "baz": 2, 0: "bar"}, "goo": {"bag": 3}, @@ -174,23 +193,7 @@ def is_serializable(d): }, ], [ - {"foo[1].baz": 3}, - { - "foo": {"bar": 1, "baz": 2, 1: {"baz": 3}}, - "goo": {"bag": 3}, - "lorem": False, - }, - ], - [ - {"foo[1]": ["baz", "goo"]}, - { - "foo": {"bar": 1, "baz": 2, 1: ["baz", "goo"]}, - "goo": {"bag": 3}, - "lorem": False, - }, - ], - [ - {"lorem.ipsum": 3}, + {"lorem": {"ipsum": 3}}, { "foo": {"bar": 1, "baz": 2}, "goo": {"bag": 3}, @@ -200,35 +203,28 @@ def is_serializable(d): [{}, {"foo": {"bar": 1, "baz": 2}, "goo": {"bag": 3}, "lorem": False}], ], ) -def test_merge_params(changes, expected): +def test_merge_dicts(changes, expected): params = {"foo": {"bar": 1, "baz": 2}, "goo": {"bag": 3}, "lorem": False} - merged = merge_params(params, changes) + merged = merge_dicts(params, changes) assert merged == expected == params assert params is merged # references should be preserved - assert encoder.c_make_encoder assert is_serializable(params) @pytest.mark.parametrize( "changes, expected", [ - [{"foo": "baz"}, {"foo": "baz"}], - [{"foo": "baz", "goo": "bar"}, {"foo": "baz", "goo": "bar"}], - [{"foo[1]": ["baz", "goo"]}, {"foo": [None, ["baz", "goo"]]}], - [{"foo.bar": "baz"}, {"foo": {"bar": "baz"}}], + [{"foo": "baz"}, {"foo": {"baz": 2}}], + [ + {"foo": "baz", "goo": "bag"}, + {"foo": {"baz": 2}, "goo": {"bag": 3}}, + ], + [{}, {}], ], ) -def test_merge_params_on_empty_src(changes, expected): - params = {} - merged = merge_params(params, changes) - assert merged == expected == params - assert params is merged # references should be preserved - assert encoder.c_make_encoder +def test_remove_missing_keys(changes, expected): + params = {"foo": {"bar": 1, "baz": 2}, "goo": {"bag": 3}, "lorem": False} + removed = remove_missing_keys(params, changes) + assert removed == expected == params + assert params is removed # references should be preserved assert is_serializable(params) - - -def test_benedict_rollback_its_monkeypatch(): - from dvc.utils._benedict import benedict - - assert benedict({"foo": "foo"}) == {"foo": "foo"} - assert encoder.c_make_encoder