Skip to content

Commit

Permalink
exp run: Support composing and dumping Hydra config.
Browse files Browse the repository at this point in the history
The feature will be used depending on whether `config.hydra.config_dir` has been set or not.

Uses `hydra.initialize_config_dir` and `hydra.compose` (from https://hydra.cc/docs/advanced/compose_api/) to build the config and dump it to `config.hydra.output_file` (defaults to `params.yaml`). The content of the output file will be overwritten.

`config.hydra.config_dir` and `config.hydra.config_name` can be used to customize the values passed to `hydra.initialize_config_dir` and `hydra.compose`.

Can be combined with `--set-param` overrides.

Closes #8082
  • Loading branch information
daavoo committed Aug 17, 2022
1 parent 9180a71 commit 5d61212
Show file tree
Hide file tree
Showing 7 changed files with 189 additions and 21 deletions.
1 change: 1 addition & 0 deletions dvc/config_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,4 +288,5 @@ class RelPath(str):
"bool": All(Lower, Choices("store_true", "boolean_optional")),
"list": All(Lower, Choices("nargs", "append")),
},
"hydra": {"output_file": str, "config_dir": str, "config_name": str},
}
24 changes: 22 additions & 2 deletions dvc/repo/experiments/queue/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from funcy import cached_property

from dvc.dependency import ParamsDependency
from dvc.env import DVCLIVE_RESUME
from dvc.exceptions import DvcException
from dvc.ui import ui
Expand Down Expand Up @@ -524,10 +525,29 @@ def _update_params(self, params: Dict[str, List[str]]):
"""
logger.debug("Using experiment params '%s'", params)

from dvc.utils.hydra import apply_overrides
from dvc.utils.hydra import apply_overrides, compose_and_dump

hydra_config = self.repo.config.get("hydra", {})
hydra_output_file = hydra_config.get(
"output_file", ParamsDependency.DEFAULT_PARAMS_FILE
)
for path, overrides in params.items():
apply_overrides(path, overrides)
if (
hydra_config.get("config_dir") is not None
and path == hydra_output_file
):
config_dir = os.path.join(
self.repo.root_dir, hydra_config.get("config_dir")
)
config_name = hydra_config.get("config_name", "config")
compose_and_dump(
path,
config_dir,
config_name,
overrides,
)
else:
apply_overrides(path, overrides)

# Force params file changes to be staged in git
# Otherwise in certain situations the changes to params file may be
Expand Down
17 changes: 14 additions & 3 deletions dvc/repo/experiments/run.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from typing import Dict, Iterable, Optional

from dvc.dependency.param import ParamsDependency
from dvc.repo import locked
from dvc.ui import ui
from dvc.utils.cli_parse import to_path_overrides
Expand Down Expand Up @@ -31,21 +32,31 @@ def run(
return repo.experiments.reproduce_celery(entries, jobs=jobs)

if params:
params = to_path_overrides(params)
path_overrides = to_path_overrides(params)
else:
path_overrides = {}

hydra_config_dir = repo.config.get("hydra", {}).get("config_dir", None)
hydra_output_file = repo.config.get("hydra", {}).get(
"output_file", ParamsDependency.DEFAULT_PARAMS_FILE
)
if hydra_config_dir and hydra_output_file not in path_overrides:
# Force `_update_params` even if `--set-param` was not used
path_overrides[hydra_output_file] = []

if queue:
if not kwargs.get("checkpoint_resume", None):
kwargs["reset"] = True
queue_entry = repo.experiments.queue_one(
repo.experiments.celery_queue,
targets=targets,
params=params,
params=path_overrides,
**kwargs,
)
name = queue_entry.name or queue_entry.stash_rev[:7]
ui.write(f"Queued experiment '{name}' for future execution.")
return {}

return repo.experiments.reproduce_one(
targets=targets, params=params, tmp_dir=tmp_dir, **kwargs
targets=targets, params=path_overrides, tmp_dir=tmp_dir, **kwargs
)
29 changes: 28 additions & 1 deletion dvc/utils/hydra.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, List

from hydra import compose, initialize_config_dir
from hydra._internal.config_loader_impl import ConfigLoaderImpl
from hydra.core.override_parser.overrides_parser import OverridesParser
from hydra.errors import ConfigCompositionException, OverrideParseException
Expand All @@ -10,12 +11,38 @@
from dvc.exceptions import InvalidArgumentError

from .collections import merge_dicts, remove_missing_keys, to_omegaconf
from .serialize import MODIFIERS
from .serialize import DUMPERS, MODIFIERS

if TYPE_CHECKING:
from dvc.types import StrPath


def compose_and_dump(
output_file: "StrPath",
config_dir: str,
config_name: str,
overrides: List[str],
) -> None:
"""Compose Hydra config and dumpt it to `output_file`.
Args:
output_file: File where the composed config will be dumped.
config_dir: Folder containing the Hydra config files.
Must be absolute file system path.
config_name: Name of the config file containing defaults,
without the .yaml extension.
overrides: List of `Hydra Override`_ patterns.
.. _Hydra Override:
https://hydra.cc/docs/next/advanced/override_grammar/basic/
"""
with initialize_config_dir(config_dir, version_base=None):
cfg = compose(config_name=config_name, overrides=overrides)

dumper = DUMPERS[Path(output_file).suffix.lower()]
dumper(output_file, OmegaConf.to_object(cfg))


def apply_overrides(path: "StrPath", overrides: List[str]) -> None:
"""Update `path` params with the provided `Hydra Override`_ patterns.
Expand Down
14 changes: 0 additions & 14 deletions tests/func/experiments/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,20 +104,6 @@ def test_failed_exp_workspace(
)


@pytest.mark.parametrize(
"changes, expected",
[
[["foo=baz"], "foo: baz\ngoo:\n bag: 3.0\nlorem: false"],
[["params.yaml:foo=baz"], "foo: baz\ngoo:\n bag: 3.0\nlorem: false"],
],
)
def test_modify_params(params_repo, dvc, changes, expected):
dvc.experiments.run(params=changes)
# pylint: disable=unspecified-encoding
with open("params.yaml", mode="r") as fobj:
assert fobj.read().strip() == expected


def test_apply(tmp_dir, scm, dvc, exp_stage):
from dvc.exceptions import InvalidArgumentError
from dvc.repo.experiments.exceptions import ApplyConflictError
Expand Down
73 changes: 73 additions & 0 deletions tests/func/experiments/test_set_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import pytest

from ..utils.test_hydra import hydra_setup


@pytest.mark.parametrize(
"changes, expected",
[
[["foo=baz"], "foo: baz\ngoo:\n bag: 3.0\nlorem: false"],
[["params.yaml:foo=baz"], "foo: baz\ngoo:\n bag: 3.0\nlorem: false"],
],
)
def test_modify_params(params_repo, dvc, changes, expected):
dvc.experiments.run(params=changes)
# pylint: disable=unspecified-encoding
with open("params.yaml", mode="r") as fobj:
assert fobj.read().strip() == expected


@pytest.mark.parametrize(
"output_file,config_dir,config_name",
[
(None, None, None),
("params.yaml", None, "bar"),
("params.yaml", "conf", "bar"),
],
)
def test_hydra_compose_and_dump(
tmp_dir, params_repo, dvc, output_file, config_dir, config_name
):
hydra_setup(
tmp_dir,
config_dir=config_dir or "conf",
config_name=config_name or "config",
)

dvc.experiments.run()
assert (tmp_dir / "params.yaml").parse() == {
"foo": [{"bar": 1}, {"baz": 2}],
"goo": {"bag": 3.0},
"lorem": False,
}

with dvc.config.edit() as conf:
if output_file is not None:
conf["hydra"]["output_file"] = output_file
if config_dir is not None:
conf["hydra"]["config_dir"] = config_dir
if config_name is not None:
conf["hydra"]["config_name"] = config_name

dvc.experiments.run()

if config_dir is not None:
assert (tmp_dir / output_file or "params.yaml").parse() == {
"db": {"driver": "mysql", "user": "omry", "pass": "secret"},
}

dvc.experiments.run(params=["db=postgresql"])
assert (tmp_dir / output_file or "params.yaml").parse() == {
"db": {
"driver": "postgresql",
"user": "foo",
"pass": "bar",
"timeout": 10,
}
}
else:
assert (tmp_dir / "params.yaml").parse() == {
"foo": [{"bar": 1}, {"baz": 2}],
"goo": {"bag": 3.0},
"lorem": False,
}
52 changes: 51 additions & 1 deletion tests/func/utils/test_hydra.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from dvc.exceptions import InvalidArgumentError
from dvc.utils.hydra import apply_overrides
from dvc.utils.hydra import apply_overrides, compose_and_dump


@pytest.mark.parametrize("suffix", ["yaml", "toml", "json"])
Expand Down Expand Up @@ -120,3 +120,53 @@ def test_invalid_overrides(tmp_dir, overrides):
)
with pytest.raises(InvalidArgumentError):
apply_overrides(path=params_file.name, overrides=overrides)


def hydra_setup(tmp_dir, config_dir, config_name):
config_dir = tmp_dir / config_dir
(config_dir / "db").mkdir(parents=True)
(config_dir / f"{config_name}.yaml").dump({"defaults": [{"db": "mysql"}]})
(config_dir / "db" / "mysql.yaml").dump(
{"driver": "mysql", "user": "omry", "pass": "secret"}
)
(config_dir / "db" / "postgresql.yaml").dump(
{"driver": "postgresql", "user": "foo", "pass": "bar", "timeout": 10}
)
return str(config_dir)


@pytest.mark.parametrize("suffix", ["yaml", "toml", "json"])
@pytest.mark.parametrize(
"overrides,expected",
[
([], {"db": {"driver": "mysql", "user": "omry", "pass": "secret"}}),
(
["db=postgresql"],
{
"db": {
"driver": "postgresql",
"user": "foo",
"pass": "bar",
"timeout": 10,
}
},
),
(
["db=postgresql", "db.timeout=20"],
{
"db": {
"driver": "postgresql",
"user": "foo",
"pass": "bar",
"timeout": 20,
}
},
),
],
)
def test_compose_and_dump(tmp_dir, suffix, overrides, expected):
config_name = "config"
config_dir = hydra_setup(tmp_dir, "conf", "config")
output_file = tmp_dir / f"params.{suffix}"
compose_and_dump(output_file, config_dir, config_name, overrides)
assert output_file.parse() == expected

0 comments on commit 5d61212

Please sign in to comment.