Skip to content

Commit

Permalink
Introduce a sentinel value _NO_VALUE to improve Global resolvers to s…
Browse files Browse the repository at this point in the history
…upport defaults `0` or `None` (#2976)

* Add None support to globals

Signed-off-by: Ankita Katiyar <[email protected]>

* Add warning when default value is used

Signed-off-by: Ankita Katiyar <[email protected]>

* Check keys

Signed-off-by: Ankita Katiyar <[email protected]>

* Nok's suggestions

Signed-off-by: Ankita Katiyar <[email protected]>

* Create the test to check the non-existing keys

Signed-off-by: Nok <[email protected]>

* add more tests to catch case when global key is not a dict

Signed-off-by: Nok <[email protected]>

* Fix the null test

Signed-off-by: Nok <[email protected]>

* Introduce sentinel value _NO_VALUE

Signed-off-by: Nok <[email protected]>

* rename test

Signed-off-by: Nok <[email protected]>

* Improve error mesasge and raise InterpolationResolutionError when key does not exist and no default

Signed-off-by: Nok <[email protected]>

* Fix non exist default test

Signed-off-by: Nok <[email protected]>

* Fix test

Signed-off-by: Nok <[email protected]>

* Use omegaconf to replace the custom resolving logic

Signed-off-by: Nok <[email protected]>

* uncommented some tests

Signed-off-by: Nok <[email protected]>

* Remove dead code

Signed-off-by: Ankita Katiyar <[email protected]>

* Update error message

Signed-off-by: Ankita Katiyar <[email protected]>

---------

Signed-off-by: Ankita Katiyar <[email protected]>
Signed-off-by: Nok <[email protected]>
Co-authored-by: Nok <[email protected]>
  • Loading branch information
ankatiyar and noklam authored Aug 30, 2023
1 parent 5a11941 commit dfee643
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 23 deletions.
31 changes: 14 additions & 17 deletions kedro/config/omegaconf_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

_config_logger = logging.getLogger(__name__)

_NO_VALUE = object()


class OmegaConfigLoader(AbstractConfigLoader):
"""Recursively scan directories (config paths) contained in ``conf_source`` for
Expand Down Expand Up @@ -316,31 +318,26 @@ def _register_globals_resolver(self):
"""Register the globals resolver"""
OmegaConf.register_new_resolver(
"globals",
lambda variable, default_value=None: self._get_globals_value(
variable, default_value
),
self._get_globals_value,
replace=True,
)

def _get_globals_value(self, variable, default_value):
def _get_globals_value(self, variable, default_value=_NO_VALUE):
"""Return the globals values to the resolver"""
if variable.startswith("_"):
raise InterpolationResolutionError(
"Keys starting with '_' are not supported for globals."
)
keys = variable.split(".")
value = self["globals"]
for k in keys:
value = value.get(k)
if not value:
if default_value:
_config_logger.debug(
f"Using the default value for the global variable {variable}."
)
return default_value
msg = f"Globals key '{variable}' not found and no default value provided. "
raise InterpolationResolutionError(msg)
return value
global_omegaconf = OmegaConf.create(self["globals"])
interpolated_value = OmegaConf.select(
global_omegaconf, variable, default=default_value
)
if interpolated_value != _NO_VALUE:
return interpolated_value
else:
raise InterpolationResolutionError(
f"Globals key '{variable}' not found and no default value provided."
)

@staticmethod
def _register_new_resolvers(resolvers: dict[str, Callable]):
Expand Down
54 changes: 48 additions & 6 deletions tests/config/test_omegaconf_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,7 @@ def test_custom_resolvers(self, tmp_path):
def test_globals(self, tmp_path):
globals_params = tmp_path / _BASE_ENV / "globals.yml"
globals_config = {
"x": 34,
"x": 0,
}
_write_yaml(globals_params, globals_config)
conf = OmegaConfigLoader(tmp_path, default_run_env="")
Expand Down Expand Up @@ -704,7 +704,6 @@ def test_globals_resolution(self, tmp_path):
_write_yaml(globals_params, globals_config)
_write_yaml(base_catalog, catalog_config)
conf = OmegaConfigLoader(tmp_path, default_run_env="")
assert OmegaConf.has_resolver("globals")
# Globals are resolved correctly in parameter
assert conf["parameters"]["my_param"] == globals_config["x"]
# The default value is used if the key does not exist
Expand Down Expand Up @@ -760,25 +759,68 @@ def test_globals_across_env(self, tmp_path):
# Base global value is accessible to local params
assert conf["parameters"]["param2"] == base_globals_config["x"]

def test_bad_globals(self, tmp_path):
def test_globals_default(self, tmp_path):
base_params = tmp_path / _BASE_ENV / "parameters.yml"
base_globals = tmp_path / _BASE_ENV / "globals.yml"
base_param_config = {
"int": "${globals:x.NOT_EXIST, 1}",
"str": "${globals: x.NOT_EXIST, '2'}",
"dummy": "${globals: x.DUMMY.DUMMY, '2'}",
}
base_globals_config = {"x": {"DUMMY": 3}}
_write_yaml(base_params, base_param_config)
_write_yaml(base_globals, base_globals_config)
conf = OmegaConfigLoader(tmp_path, default_run_env="")
# Default value is being used as int
assert conf["parameters"]["int"] == 1
# Default value is being used as str
assert conf["parameters"]["str"] == "2"
# Test when x.DUMMY is not a dictionary it should still work
assert conf["parameters"]["dummy"] == "2"

def test_globals_default_none(self, tmp_path):
base_params = tmp_path / _BASE_ENV / "parameters.yml"
base_globals = tmp_path / _BASE_ENV / "globals.yml"
base_param_config = {
"param1": "${globals:x.y}",
"zero": "${globals: x.NOT_EXIST, 0}",
"null": "${globals: x.NOT_EXIST, null}",
"null2": "${globals: x.y}",
}
base_globals_config = {
"x": {
"z": 23,
"y": None,
},
}
_write_yaml(base_params, base_param_config)
_write_yaml(base_globals, base_globals_config)
conf = OmegaConfigLoader(tmp_path, default_run_env="")
# Default value can be 0 or null
assert conf["parameters"]["zero"] == 0
assert conf["parameters"]["null"] is None
# Global value is null
assert conf["parameters"]["null2"] is None

def test_globals_missing_default(self, tmp_path):
base_params = tmp_path / _BASE_ENV / "parameters.yml"
globals_params = tmp_path / _BASE_ENV / "globals.yml"
param_config = {
"NOT_OK": "${globals:nested.NOT_EXIST}",
}
globals_config = {
"nested": {
"y": 42,
},
}
_write_yaml(base_params, param_config)
_write_yaml(globals_params, globals_config)
conf = OmegaConfigLoader(tmp_path, default_run_env="")

with pytest.raises(
InterpolationResolutionError,
match=r"Globals key 'x.y' not found and no default value provided.",
match="Globals key 'nested.NOT_EXIST' not found and no default value provided.",
):
conf["parameters"]["param1"]
conf["parameters"]["NOT_OK"]

def test_bad_globals_underscore(self, tmp_path):
base_params = tmp_path / _BASE_ENV / "parameters.yml"
Expand Down

0 comments on commit dfee643

Please sign in to comment.