Skip to content

Commit

Permalink
exp run: throw error if new parameters are added by --set-param (#6521)
Browse files Browse the repository at this point in the history
* throw error if new parameters are added by dvc exp run -S

Throws an error if dvc exp run --set-param references a parameter key path
that is not in the target config file. This is done to prevent silent errors
in experimentation when new params are accidentally added instead of modifying
current ones due to typos.

This implementation is based on initial suggestions in issue #5477.

* fix mypy confusion from overwriting a var

* tests: remove modify param test that was adding params

One of the tests in test_modify_params was still adding a parameter.

* tests: remove param modification during exp run in test_untracked

Removes a spurious parameter modification in test_untracked that ran before params.yaml existed.

* remove benedict dependency for verifying params

Instead, the keypath util module from benedict is adapted into the dvc codebase and used
natively to verify that a keypath exists within a dict.

* exp: list all invalid param modifications in error message

* exp-run: more specific exception for new params via --set-param

Uses the established MissingParamsError instead of DvcException. Modified the message to be more similar to other uses of this exception.

* Revert "remove benedict dependency for verifying params"

This reverts commit 6170f16.

* exp: use internal benedict import for validating params

* exp: combine new param validation into merge_params

Moves new param validation into existing merge_params function as an optional check

* exp: minor amendments to error message

* Revert "tests: remove param modification during exp run in test_untracked"

This reverts commit c9eb1a6.

* exp: fix test_untracked to not fail due to new params

Fixed by creating the params file with the new parameter already present.
  • Loading branch information
mattlbeck authored Oct 19, 2021
1 parent 966dc8a commit 26f435a
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 9 deletions.
19 changes: 17 additions & 2 deletions dvc/repo/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from funcy import cached_property, first

from dvc.dependency.param import MissingParamsError
from dvc.env import DVCLIVE_RESUME
from dvc.exceptions import DvcException
from dvc.path_info import PathInfo
Expand Down Expand Up @@ -350,9 +351,19 @@ def _pack_args(self, *args, **kwargs):
)
self.scm.add(self.args_file)

def _format_new_params_msg(self, new_params, config_path):
"""Format an error message for when new parameters are identified"""
new_param_count = len(new_params)
pluralise = "s are" if new_param_count > 1 else " is"
param_list = ", ".join(new_params)
return (
f"{new_param_count} parameter{pluralise} missing "
f"from '{config_path}': {param_list}"
)

def _update_params(self, params: dict):
"""Update experiment params files with the specified values."""
from dvc.utils.collections import merge_params
from dvc.utils.collections import NewParamsFound, merge_params
from dvc.utils.serialize import MODIFIERS

logger.debug("Using experiment params '%s'", params)
Expand All @@ -362,7 +373,11 @@ def _update_params(self, params: dict):
suffix = path.suffix.lower()
modify_data = MODIFIERS[suffix]
with modify_data(path, fs=self.repo.fs) as data:
merge_params(data, params[params_fname])
try:
merge_params(data, params[params_fname], allow_new=False)
except NewParamsFound as e:
msg = self._format_new_params_msg(e.new_params, path)
raise MissingParamsError(msg)

# Force params file changes to be staged in git
# Otherwise in certain situations the changes to params file may be
Expand Down
30 changes: 28 additions & 2 deletions dvc/utils/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,16 @@

from pygtrie import StringTrie as _StringTrie

from dvc.exceptions import DvcException


class NewParamsFound(DvcException):
    """Raised by merge_params when the update would introduce new params.

    The offending keys are available on the ``new_params`` attribute so
    that callers can build their own error message.
    """

    def __init__(self, new_params: List, *args):
        # Keep the keys on the instance before delegating to the base class.
        self.new_params = new_params
        message = "New params found during merge"
        super().__init__(message, *args)


class PathStringTrie(_StringTrie):
"""Trie based on platform-dependent separator for pathname components."""
Expand Down Expand Up @@ -80,11 +90,27 @@ def chunk_dict(d: Dict[_KT, _VT], size: int = 1) -> List[Dict[_KT, _VT]]:
return [{key: d[key] for key in chunk} for chunk in chunks(size, d)]


def merge_params(src: Dict, to_update: Dict) -> Dict:
"""Recursively merges params with benedict's syntax support in-place."""
def merge_params(src: Dict, to_update: Dict, allow_new: bool = True) -> Dict:
    """
    Recursively merges params with benedict's syntax support in-place.

    Args:
        src (dict): source dictionary of parameters
        to_update (dict): dictionary of parameters to merge into src
        allow_new (bool): if False, raises an error if new keys would be
            added to src

    Returns:
        The ``src`` object that was passed in.

    Raises:
        NewParamsFound: if ``allow_new`` is False and ``to_update`` has
            keys that are not among the key paths of ``src``. The
            exception carries the offending keys in sorted order.
    """
    from ._benedict import benedict

    data = benedict(src)

    if not allow_new:
        # Only the keys of ``to_update`` are compared against the key paths
        # (list indexes included) that already exist in ``src``.
        known = set(data.keypaths(indexes=True))
        # Sort so the reported keys do not depend on set iteration order.
        new_params = sorted(
            (key for key in to_update if key not in known), key=str
        )
        if new_params:
            raise NewParamsFound(new_params)

    data.merge(to_update, overwrite=True)
    return src

Expand Down
29 changes: 24 additions & 5 deletions tests/func/experiments/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from funcy import first

from dvc.dvcfile import PIPELINE_FILE
from dvc.exceptions import DvcException
from dvc.repo.experiments.utils import exp_refs_by_rev
from dvc.utils.serialize import PythonFileCorruptedError
from tests.func.test_repro_multistage import COPY_SCRIPT
Expand Down Expand Up @@ -118,10 +119,6 @@ def test_failed_exp(tmp_dir, scm, dvc, exp_stage, mocker, caplog):
["foo[1]=[baz, goo]"],
"{foo: [bar: 1, [baz, goo]], goo: {bag: 3}, lorem: false}",
],
[
["lorem.ipsum=3"],
"{foo: [bar: 1, baz: 2], goo: {bag: 3}, lorem: {ipsum: 3}}",
],
],
)
def test_modify_params(tmp_dir, scm, dvc, mocker, changes, expected):
Expand All @@ -148,6 +145,28 @@ def test_modify_params(tmp_dir, scm, dvc, mocker, changes, expected):
assert fobj.read().strip() == expected


@pytest.mark.parametrize(
    "changes",
    [["lorem.ipsum=3"], ["foo[0].bazar=3"]],
)
def test_add_params(tmp_dir, scm, dvc, changes):
    """Running an experiment with a param absent from params.yaml fails."""
    params_yaml = "{foo: [bar: 1, baz: 2], goo: {bag: 3}, lorem: false}"
    tracked_files = [
        "dvc.yaml",
        "dvc.lock",
        "copy.py",
        "params.yaml",
        "metrics.yaml",
    ]

    tmp_dir.gen("copy.py", COPY_SCRIPT)
    tmp_dir.gen("params.yaml", params_yaml)
    stage = dvc.run(
        cmd="python copy.py params.yaml metrics.yaml",
        metrics_no_cache=["metrics.yaml"],
        params=["foo", "goo", "lorem"],
        name="copy-file",
    )
    scm.add(tracked_files)
    scm.commit("init")

    # Each parametrized change targets a key path that is not in the file.
    with pytest.raises(DvcException):
        dvc.experiments.run(stage.addressing, params=changes)


@pytest.mark.parametrize("queue", [True, False])
def test_apply(tmp_dir, scm, dvc, exp_stage, queue):
from dvc.exceptions import InvalidArgumentError
Expand Down Expand Up @@ -387,7 +406,7 @@ def test_no_scm(tmp_dir):
@pytest.mark.parametrize("workspace", [True, False])
def test_untracked(tmp_dir, scm, dvc, caplog, workspace):
tmp_dir.gen("copy.py", COPY_SCRIPT)
tmp_dir.gen("params.yaml", "foo: 1")
tmp_dir.scm_gen("params.yaml", "foo: 1", commit="track params")
stage = dvc.run(
cmd="python copy.py params.yaml metrics.yaml",
metrics_no_cache=["metrics.yaml"],
Expand Down

0 comments on commit 26f435a

Please sign in to comment.