diff --git a/dvc/config_schema.py b/dvc/config_schema.py index 3ea66171898..1596e742d97 100644 --- a/dvc/config_schema.py +++ b/dvc/config_schema.py @@ -289,4 +289,5 @@ class RelPath(str): "bool": All(Lower, Choices("store_true", "boolean_optional")), "list": All(Lower, Choices("nargs", "append")), }, + "hydra": {"config_dir": str, "config_name": str}, } diff --git a/dvc/repo/experiments/queue/base.py b/dvc/repo/experiments/queue/base.py index 4ba87c98fd0..b60eafd4ab9 100644 --- a/dvc/repo/experiments/queue/base.py +++ b/dvc/repo/experiments/queue/base.py @@ -19,6 +19,7 @@ from funcy import cached_property +from dvc.dependency import ParamsDependency from dvc.env import DVCLIVE_RESUME from dvc.exceptions import DvcException from dvc.ui import ui @@ -524,10 +525,27 @@ def _update_params(self, params: Dict[str, List[str]]): """ logger.debug("Using experiment params '%s'", params) - from dvc.utils.hydra import apply_overrides + from dvc.utils.hydra import apply_overrides, compose_and_dump + hydra_config = self.repo.config.get("hydra", {}) + hydra_output_file = ParamsDependency.DEFAULT_PARAMS_FILE for path, overrides in params.items(): - apply_overrides(path, overrides) + if ( + hydra_config.get("config_dir") is not None + and path == hydra_output_file + ): + config_dir = os.path.join( + self.repo.root_dir, hydra_config.get("config_dir") + ) + config_name = hydra_config.get("config_name", "config") + compose_and_dump( + path, + config_dir, + config_name, + overrides, + ) + else: + apply_overrides(path, overrides) # Force params file changes to be staged in git # Otherwise in certain situations the changes to params file may be diff --git a/dvc/repo/experiments/run.py b/dvc/repo/experiments/run.py index 5853293417a..ad4609f06bd 100644 --- a/dvc/repo/experiments/run.py +++ b/dvc/repo/experiments/run.py @@ -1,6 +1,7 @@ import logging from typing import Dict, Iterable, Optional +from dvc.dependency.param import ParamsDependency from dvc.repo import locked from dvc.ui 
import ui from dvc.utils.cli_parse import to_path_overrides @@ -31,7 +32,15 @@ def run( return repo.experiments.reproduce_celery(entries, jobs=jobs) if params: - params = to_path_overrides(params) + path_overrides = to_path_overrides(params) + else: + path_overrides = {} + + hydra_config_dir = repo.config.get("hydra", {}).get("config_dir", None) + hydra_output_file = ParamsDependency.DEFAULT_PARAMS_FILE + if hydra_config_dir and hydra_output_file not in path_overrides: + # Force `_update_params` even if `--set-param` was not used + path_overrides[hydra_output_file] = [] if queue: if not kwargs.get("checkpoint_resume", None): @@ -39,7 +48,7 @@ def run( queue_entry = repo.experiments.queue_one( repo.experiments.celery_queue, targets=targets, - params=params, + params=path_overrides, **kwargs, ) name = queue_entry.name or queue_entry.stash_rev[:7] @@ -47,5 +56,5 @@ def run( return {} return repo.experiments.reproduce_one( - targets=targets, params=params, tmp_dir=tmp_dir, **kwargs + targets=targets, params=path_overrides, tmp_dir=tmp_dir, **kwargs ) diff --git a/dvc/utils/hydra.py b/dvc/utils/hydra.py index e03be3e1c62..7019cdde8cd 100644 --- a/dvc/utils/hydra.py +++ b/dvc/utils/hydra.py @@ -1,6 +1,7 @@ from pathlib import Path from typing import TYPE_CHECKING, List +from hydra import compose, initialize_config_dir from hydra._internal.config_loader_impl import ConfigLoaderImpl from hydra.core.override_parser.overrides_parser import OverridesParser from hydra.errors import ConfigCompositionException, OverrideParseException @@ -10,12 +11,38 @@ from dvc.exceptions import InvalidArgumentError from .collections import merge_dicts, remove_missing_keys, to_omegaconf -from .serialize import MODIFIERS +from .serialize import DUMPERS, MODIFIERS if TYPE_CHECKING: from dvc.types import StrPath +def compose_and_dump( + output_file: "StrPath", + config_dir: str, + config_name: str, + overrides: List[str], +) -> None: + """Compose Hydra config and dump it to `output_file`. 
+ + Args: + output_file: File where the composed config will be dumped. + config_dir: Folder containing the Hydra config files. + Must be an absolute file system path. + config_name: Name of the config file containing defaults, + without the .yaml extension. + overrides: List of `Hydra Override`_ patterns. + + .. _Hydra Override: + https://hydra.cc/docs/next/advanced/override_grammar/basic/ + """ + with initialize_config_dir(config_dir, version_base=None): + cfg = compose(config_name=config_name, overrides=overrides) + + dumper = DUMPERS[Path(output_file).suffix.lower()] + dumper(output_file, OmegaConf.to_object(cfg)) + + def apply_overrides(path: "StrPath", overrides: List[str]) -> None: """Update `path` params with the provided `Hydra Override`_ patterns. diff --git a/tests/func/experiments/test_experiments.py b/tests/func/experiments/test_experiments.py index 0f67a1fbc21..48625bf31bf 100644 --- a/tests/func/experiments/test_experiments.py +++ b/tests/func/experiments/test_experiments.py @@ -104,20 +104,6 @@ def test_failed_exp_workspace( ) -@pytest.mark.parametrize( - "changes, expected", - [ - [["foo=baz"], "foo: baz\ngoo:\n bag: 3.0\nlorem: false"], - [["params.yaml:foo=baz"], "foo: baz\ngoo:\n bag: 3.0\nlorem: false"], - ], -) -def test_modify_params(params_repo, dvc, changes, expected): - dvc.experiments.run(params=changes) - # pylint: disable=unspecified-encoding - with open("params.yaml", mode="r") as fobj: - assert fobj.read().strip() == expected - - def test_apply(tmp_dir, scm, dvc, exp_stage): from dvc.exceptions import InvalidArgumentError from dvc.repo.experiments.exceptions import ApplyConflictError diff --git a/tests/func/experiments/test_set_params.py b/tests/func/experiments/test_set_params.py new file mode 100644 index 00000000000..2af62665364 --- /dev/null +++ b/tests/func/experiments/test_set_params.py @@ -0,0 +1,71 @@ +import pytest + +from ..utils.test_hydra import hydra_setup + + +@pytest.mark.parametrize( + "changes, expected", + [ 
[["foo=baz"], "foo: baz\ngoo:\n bag: 3.0\nlorem: false"], + [["params.yaml:foo=baz"], "foo: baz\ngoo:\n bag: 3.0\nlorem: false"], + ], +) +def test_modify_params(params_repo, dvc, changes, expected): + dvc.experiments.run(params=changes) + # pylint: disable=unspecified-encoding + with open("params.yaml", mode="r") as fobj: + assert fobj.read().strip() == expected + + +@pytest.mark.parametrize( + "config_dir,config_name", + [ + (None, None), + (None, "bar"), + ("conf", "bar"), + ], +) +def test_hydra_compose_and_dump( + tmp_dir, params_repo, dvc, config_dir, config_name +): + hydra_setup( + tmp_dir, + config_dir=config_dir or "conf", + config_name=config_name or "config", + ) + + dvc.experiments.run() + assert (tmp_dir / "params.yaml").parse() == { + "foo": [{"bar": 1}, {"baz": 2}], + "goo": {"bag": 3.0}, + "lorem": False, + } + + with dvc.config.edit() as conf: + if config_dir is not None: + conf["hydra"]["config_dir"] = config_dir + if config_name is not None: + conf["hydra"]["config_name"] = config_name + + dvc.experiments.run() + + if config_dir is not None: + assert (tmp_dir / "params.yaml").parse() == { + "db": {"driver": "mysql", "user": "omry", "pass": "secret"}, + } + + dvc.experiments.run(params=["db=postgresql"]) + assert (tmp_dir / "params.yaml").parse() == { + "db": { + "driver": "postgresql", + "user": "foo", + "pass": "bar", + "timeout": 10, + } + } + else: + assert (tmp_dir / "params.yaml").parse() == { + "foo": [{"bar": 1}, {"baz": 2}], + "goo": {"bag": 3.0}, + "lorem": False, + } diff --git a/tests/func/utils/test_hydra.py b/tests/func/utils/test_hydra.py index 462405ab77a..5acb26bbf09 100644 --- a/tests/func/utils/test_hydra.py +++ b/tests/func/utils/test_hydra.py @@ -1,7 +1,7 @@ import pytest from dvc.exceptions import InvalidArgumentError -from dvc.utils.hydra import apply_overrides +from dvc.utils.hydra import apply_overrides, compose_and_dump @pytest.mark.parametrize("suffix", ["yaml", "toml", "json"]) @@ -120,3 +120,53 @@ def 
test_invalid_overrides(tmp_dir, overrides): ) with pytest.raises(InvalidArgumentError): apply_overrides(path=params_file.name, overrides=overrides) + + +def hydra_setup(tmp_dir, config_dir, config_name): + config_dir = tmp_dir / config_dir + (config_dir / "db").mkdir(parents=True) + (config_dir / f"{config_name}.yaml").dump({"defaults": [{"db": "mysql"}]}) + (config_dir / "db" / "mysql.yaml").dump( + {"driver": "mysql", "user": "omry", "pass": "secret"} + ) + (config_dir / "db" / "postgresql.yaml").dump( + {"driver": "postgresql", "user": "foo", "pass": "bar", "timeout": 10} + ) + return str(config_dir) + + +@pytest.mark.parametrize("suffix", ["yaml", "toml", "json"]) +@pytest.mark.parametrize( + "overrides,expected", + [ + ([], {"db": {"driver": "mysql", "user": "omry", "pass": "secret"}}), + ( + ["db=postgresql"], + { + "db": { + "driver": "postgresql", + "user": "foo", + "pass": "bar", + "timeout": 10, + } + }, + ), + ( + ["db=postgresql", "db.timeout=20"], + { + "db": { + "driver": "postgresql", + "user": "foo", + "pass": "bar", + "timeout": 20, + } + }, + ), + ], +) +def test_compose_and_dump(tmp_dir, suffix, overrides, expected): + config_name = "config" + config_dir = hydra_setup(tmp_dir, "conf", "config") + output_file = tmp_dir / f"params.{suffix}" + compose_and_dump(output_file, config_dir, config_name, overrides) + assert output_file.parse() == expected