-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
exp run: Support composing and dumping Hydra config.
The feature will be used depending on whether `config.hydra.enabled` is True or False. Uses `hydra.initialize_config_dir` and `hydra.compose` (from https://hydra.cc/docs/advanced/compose_api/) to build the config and dump it to `params.yaml`. The content of the output file will be overwritten. `config.hydra.config_dir` and `config.hydra.config_name` can be used to customize the values passed to `hydra.initialize_config_dir` and `hydra.compose`. Can be combined with `--set-param` overrides. Closes #8082
- Loading branch information
Showing
7 changed files
with
188 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import pytest | ||
|
||
from ..utils.test_hydra import hydra_setup | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"changes, expected", | ||
[ | ||
[["foo=baz"], "foo: baz\ngoo:\n bag: 3.0\nlorem: false"], | ||
[["params.yaml:foo=baz"], "foo: baz\ngoo:\n bag: 3.0\nlorem: false"], | ||
], | ||
) | ||
def test_modify_params(params_repo, dvc, changes, expected): | ||
dvc.experiments.run(params=changes) | ||
# pylint: disable=unspecified-encoding | ||
with open("params.yaml", mode="r") as fobj: | ||
assert fobj.read().strip() == expected | ||
|
||
|
||
@pytest.mark.parametrize("hydra_enabled", [True, False]) | ||
@pytest.mark.parametrize( | ||
"config_dir,config_name", | ||
[ | ||
(None, None), | ||
(None, "bar"), | ||
("conf", "bar"), | ||
], | ||
) | ||
def test_hydra_compose_and_dump( | ||
tmp_dir, params_repo, dvc, hydra_enabled, config_dir, config_name | ||
): | ||
hydra_setup( | ||
tmp_dir, | ||
config_dir=config_dir or "conf", | ||
config_name=config_name or "config", | ||
) | ||
|
||
dvc.experiments.run() | ||
assert (tmp_dir / "params.yaml").parse() == { | ||
"foo": [{"bar": 1}, {"baz": 2}], | ||
"goo": {"bag": 3.0}, | ||
"lorem": False, | ||
} | ||
|
||
with dvc.config.edit() as conf: | ||
if hydra_enabled: | ||
conf["hydra"]["enabled"] = True | ||
if config_dir is not None: | ||
conf["hydra"]["config_dir"] = config_dir | ||
if config_name is not None: | ||
conf["hydra"]["config_name"] = config_name | ||
|
||
dvc.experiments.run() | ||
|
||
if hydra_enabled: | ||
assert (tmp_dir / "params.yaml").parse() == { | ||
"db": {"driver": "mysql", "user": "omry", "pass": "secret"}, | ||
} | ||
|
||
dvc.experiments.run(params=["db=postgresql"]) | ||
assert (tmp_dir / "params.yaml").parse() == { | ||
"db": { | ||
"driver": "postgresql", | ||
"user": "foo", | ||
"pass": "bar", | ||
"timeout": 10, | ||
} | ||
} | ||
else: | ||
assert (tmp_dir / "params.yaml").parse() == { | ||
"foo": [{"bar": 1}, {"baz": 2}], | ||
"goo": {"bag": 3.0}, | ||
"lorem": False, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters