Skip to content

Commit

Permalink
exp run: Use hydra for parsing --set-param.
Browse files Browse the repository at this point in the history
- Overriding a config value : foo.bar=value
- Appending a config value : +foo.bar=value
- Appending or overriding a config value : ++foo.bar=value
- Removing a config value : ~foo.bar, ~foo.bar=value

See https://hydra.cc/docs/advanced/override_grammar/basic/#modifying-the-config-object

---

Breaking changes:

To modify a list, `foo[0]=bar` must now be passed as `foo.0=bar`.

Modifying a nested list inside a dictionary is not supported by omegaconf.

---

Closes #4883
Closes #5868
Closes #6129
  • Loading branch information
daavoo committed Aug 5, 2022
1 parent 45ac64d commit d09c3fd
Show file tree
Hide file tree
Showing 12 changed files with 314 additions and 210 deletions.
39 changes: 20 additions & 19 deletions dvc/repo/experiments/queue/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

from funcy import cached_property

from dvc.dependency.param import MissingParamsError
from dvc.env import DVCLIVE_RESUME
from dvc.exceptions import DvcException
from dvc.ui import ui
Expand Down Expand Up @@ -283,7 +282,7 @@ def logs(
def _stash_exp(
self,
*args,
params: Optional[dict] = None,
params: Optional[Dict[str, List[str]]] = None,
resume_rev: Optional[str] = None,
baseline_rev: Optional[str] = None,
branch: Optional[str] = None,
Expand All @@ -292,10 +291,9 @@ def _stash_exp(
) -> QueueEntry:
"""Stash changes from the workspace as an experiment.
Arguments:
params: Optional dictionary of parameter values to be used.
Values take priority over any parameters specified in the
user's workspace.
Args:
params: Dict mapping paths to `Hydra Override`_ patterns,
provided via `exp run --set-param`.
resume_rev: Optional checkpoint resume rev.
baseline_rev: Optional baseline rev for this experiment, defaults
to the current SCM rev.
Expand All @@ -305,6 +303,9 @@ def _stash_exp(
name: Optional experiment name. If specified this will be used as
the human-readable name in the experiment branch ref. Has no
effect of branch is specified.
.. _Hydra Override:
https://hydra.cc/docs/next/advanced/override_grammar/basic/
"""
with self.scm.detach_head(client="dvc") as orig_head:
stash_head = orig_head
Expand Down Expand Up @@ -508,22 +509,22 @@ def _format_new_params_msg(new_params, config_path):
f"from '{config_path}': {param_list}"
)

def _update_params(self, params: dict):
"""Update experiment params files with the specified values."""
from dvc.utils.collections import NewParamsFound, merge_params
from dvc.utils.serialize import MODIFIERS
def _update_params(self, params: Dict[str, List[str]]):
"""Update param files with the provided `Hydra Override`_ patterns.
Args:
params: Dict mapping paths to `Hydra Override`_ patterns,
provided via `exp run --set-param`.
.. _Hydra Override:
https://hydra.cc/docs/next/advanced/override_grammar/basic/
"""
logger.debug("Using experiment params '%s'", params)

for path in params:
suffix = self.repo.fs.path.suffix(path).lower()
modify_data = MODIFIERS[suffix]
with modify_data(path, fs=self.repo.fs) as data:
try:
merge_params(data, params[path], allow_new=False)
except NewParamsFound as e:
msg = self._format_new_params_msg(e.new_params, path)
raise MissingParamsError(msg)
from dvc.utils.hydra import apply_overrides

for path, overrides in params.items():
apply_overrides(path, overrides)

# Force params file changes to be staged in git
# Otherwise in certain situations the changes to params file may be
Expand Down
4 changes: 2 additions & 2 deletions dvc/repo/experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from dvc.repo import locked
from dvc.ui import ui
from dvc.utils.cli_parse import loads_param_overrides
from dvc.utils.cli_parse import to_path_overrides

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -31,7 +31,7 @@ def run(
return repo.experiments.reproduce_celery(entries, jobs=jobs)

if params:
params = loads_param_overrides(params)
params = to_path_overrides(params)

if queue:
if not kwargs.get("checkpoint_resume", None):
Expand Down
26 changes: 0 additions & 26 deletions dvc/utils/_benedict.py

This file was deleted.

39 changes: 13 additions & 26 deletions dvc/utils/cli_parse.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import defaultdict
from typing import Any, Dict, Iterable, List
from typing import Dict, Iterable, List


def parse_params(path_params: Iterable[str]) -> List[Dict[str, List[str]]]:
Expand All @@ -17,35 +17,22 @@ def parse_params(path_params: Iterable[str]) -> List[Dict[str, List[str]]]:
return [{path: params} for path, params in ret.items()]


def loads_param_overrides(
def to_path_overrides(
path_params: Iterable[str],
) -> Dict[str, Dict[str, Any]]:
"""Loads the content of params from the cli as Python object."""
from ruamel.yaml import YAMLError

) -> Dict[str, List[str]]:
"""Group overrides by path"""
from dvc.dependency.param import ParamsDependency
from dvc.exceptions import InvalidArgumentError

from .serialize import loads_yaml

ret: Dict[str, Dict[str, Any]] = defaultdict(dict)

path_overrides = defaultdict(list)
for path_param in path_params:
param_name, _, param_value = path_param.partition("=")
if not param_value:
raise InvalidArgumentError(
f"Must provide a value for parameter '{param_name}'"
)
path, _, param_name = param_name.partition(":")
if not param_name:
param_name = path

path_and_name = path_param.partition("=")[0]
if ":" not in path_and_name:
override = path_param
path = ParamsDependency.DEFAULT_PARAMS_FILE
else:
path, _, override = path_param.partition(":")

try:
ret[path][param_name] = loads_yaml(param_value)
except (ValueError, YAMLError):
raise InvalidArgumentError(
f"Invalid parameter value for '{param_name}': '{param_value}"
)
path_overrides[path].append(override)

return ret
return dict(path_overrides)
82 changes: 47 additions & 35 deletions dvc/utils/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,6 @@
from functools import wraps
from typing import Callable, Dict, Iterable, List, TypeVar, Union

from dvc.exceptions import DvcException


class NewParamsFound(DvcException):
"""Thrown if new params were found during merge_params"""

def __init__(self, new_params: List, *args):
self.new_params = new_params
super().__init__("New params found during merge", *args)


def apply_diff(src, dest):
"""Recursively apply changes from src to dest.
Expand Down Expand Up @@ -61,6 +51,53 @@ def is_same_type(a, b):
)


def to_omegaconf(item):
"""
Some parsers return custom classes (i.e. parse_yaml_for_update)
that can mess up with omegaconf logic.
Cast the custom classes to Python primitives.
"""
if isinstance(item, dict):
item = {k: to_omegaconf(v) for k, v in item.items()}
elif isinstance(item, list):
item = [to_omegaconf(x) for x in item]
return item


def remove_missing_keys(src, to_update):
keys = list(src.keys())
for key in keys:
if key not in to_update:
del src[key]
elif isinstance(src[key], dict):
remove_missing_keys(src[key], to_update[key])

return src


def _merge_item(d, key, value):
if key in d:
item = d.get(key, None)
if isinstance(item, dict) and isinstance(value, dict):
merge_dicts(item, value)
else:
d[key] = value
else:
d[key] = value


def merge_dicts(src: Dict, to_update: Dict) -> Dict:
"""Recursively merges dictionaries.
Args:
src (dict): source dictionary of parameters
to_update (dict): dictionary of parameters to merge into src
"""
for key, value in to_update.items():
_merge_item(src, key, value)
return src


def ensure_list(item: Union[Iterable[str], str, None]) -> List[str]:
if item is None:
return []
Expand All @@ -79,31 +116,6 @@ def chunk_dict(d: Dict[_KT, _VT], size: int = 1) -> List[Dict[_KT, _VT]]:
return [{key: d[key] for key in chunk} for chunk in chunks(size, d)]


def merge_params(src: Dict, to_update: Dict, allow_new: bool = True) -> Dict:
"""
Recursively merges params with benedict's syntax support in-place.
Args:
src (dict): source dictionary of parameters
to_update (dict): dictionary of parameters to merge into src
allow_new (bool): if False, raises an error if new keys would be
added to src
"""
from ._benedict import benedict

data = benedict(src)

if not allow_new:
new_params = list(
set(to_update.keys()) - set(data.keypaths(indexes=True))
)
if new_params:
raise NewParamsFound(new_params)

data.merge(to_update, overwrite=True)
return src


class _NamespacedDict(dict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down
53 changes: 53 additions & 0 deletions dvc/utils/hydra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from pathlib import Path
from typing import TYPE_CHECKING, List

from hydra._internal.config_loader_impl import ConfigLoaderImpl
from hydra.core.override_parser.overrides_parser import OverridesParser
from hydra.errors import ConfigCompositionException, OverrideParseException
from hydra.types import RunMode
from omegaconf import OmegaConf

from dvc.exceptions import InvalidArgumentError

from .collections import merge_dicts, remove_missing_keys, to_omegaconf
from .serialize import MODIFIERS

if TYPE_CHECKING:
from dvc.types import StrPath


def apply_overrides(path: "StrPath", overrides: List[str]) -> None:
"""Update `path` params with the provided `Hydra Override`_ patterns.
Args:
overrides: List of `Hydra Override`_ patterns.
.. _Hydra Override:
https://hydra.cc/docs/next/advanced/override_grammar/basic/
"""
suffix = Path(path).suffix.lower()

hydra_errors = (ConfigCompositionException, OverrideParseException)

modify_data = MODIFIERS[suffix]
with modify_data(path) as original_data:
try:
parser = OverridesParser.create()
parsed = parser.parse_overrides(overrides=overrides)
ConfigLoaderImpl.validate_sweep_overrides_legal(
parsed, run_mode=RunMode.RUN, from_shell=True
)

new_data = OmegaConf.create(
to_omegaconf(original_data),
flags={"allow_objects": True},
)
OmegaConf.set_struct(new_data, True)
# pylint: disable=protected-access
ConfigLoaderImpl._apply_overrides_to_config(parsed, new_data)
new_data = OmegaConf.to_object(new_data)
except hydra_errors as e:
raise InvalidArgumentError("Invalid `--set-param` value") from e

merge_dicts(original_data, new_data)
remove_missing_keys(original_data, new_data)
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ install_requires =
dpath>=2.0.2,<3
shtab>=1.3.4,<2
rich>=10.13.0
python-benedict>=0.24.2
pyparsing>=2.4.7
typing-extensions>=3.7.4
fsspec[http]>=2021.10.1
Expand All @@ -70,6 +69,7 @@ install_requires =
dvc-task==0.1.2
dvclive>=0.10.0
dvc-data==0.1.13
hydra-core>=1.1.0

[options.extras_require]
all =
Expand Down
14 changes: 14 additions & 0 deletions tests/func/experiments/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,17 @@ def http_auth_patch(mocker):
@pytest.fixture(params=[True, False])
def workspace(request, test_queue) -> bool: # noqa
return request.param


@pytest.fixture
def params_repo(tmp_dir, scm, dvc):
(tmp_dir / "params.yaml").dump(
{"foo": [{"bar": 1}, {"baz": 2}], "goo": {"bag": 3.0}, "lorem": False}
)
dvc.run(
cmd="echo foo",
params=["params.yaml:"],
name="foo",
)
scm.add(["dvc.yaml", "dvc.lock", "copy.py", "params.yaml"])
scm.commit("init")
Loading

0 comments on commit d09c3fd

Please sign in to comment.