diff --git a/kedro/config/omegaconf_config.py b/kedro/config/omegaconf_config.py index 5a48d3cf28..29f7b65389 100644 --- a/kedro/config/omegaconf_config.py +++ b/kedro/config/omegaconf_config.py @@ -20,6 +20,8 @@ _config_logger = logging.getLogger(__name__) +_NO_VALUE = object() + class OmegaConfigLoader(AbstractConfigLoader): """Recursively scan directories (config paths) contained in ``conf_source`` for @@ -316,31 +318,26 @@ def _register_globals_resolver(self): """Register the globals resolver""" OmegaConf.register_new_resolver( "globals", - lambda variable, default_value=None: self._get_globals_value( - variable, default_value - ), + self._get_globals_value, replace=True, ) - def _get_globals_value(self, variable, default_value): + def _get_globals_value(self, variable, default_value=_NO_VALUE): """Return the globals values to the resolver""" if variable.startswith("_"): raise InterpolationResolutionError( "Keys starting with '_' are not supported for globals." ) - keys = variable.split(".") - value = self["globals"] - for k in keys: - value = value.get(k) - if not value: - if default_value: - _config_logger.debug( - f"Using the default value for the global variable {variable}." - ) - return default_value - msg = f"Globals key '{variable}' not found and no default value provided. " - raise InterpolationResolutionError(msg) - return value + global_omegaconf = OmegaConf.create(self["globals"]) + interpolated_value = OmegaConf.select( + global_omegaconf, variable, default=default_value + ) + if interpolated_value != _NO_VALUE: + return interpolated_value + else: + raise InterpolationResolutionError( + f"Globals key '{variable}' not found and no default value provided." + ) @staticmethod def _register_new_resolvers(resolvers: dict[str, Callable]): diff --git a/tests/config/test_omegaconf_config.py b/tests/config/test_omegaconf_config.py index 4713d0da14..824508b5d0 100644 --- a/tests/config/test_omegaconf_config.py +++ b/tests/config/test_omegaconf_config.py @@ -676,7 +676,7 @@ def test_custom_resolvers(self, tmp_path): def test_globals(self, tmp_path): globals_params = tmp_path / _BASE_ENV / "globals.yml" globals_config = { - "x": 34, + "x": 0, } _write_yaml(globals_params, globals_config) conf = OmegaConfigLoader(tmp_path, default_run_env="") @@ -704,7 +704,6 @@ def test_globals_resolution(self, tmp_path): _write_yaml(globals_params, globals_config) _write_yaml(base_catalog, catalog_config) conf = OmegaConfigLoader(tmp_path, default_run_env="") - assert OmegaConf.has_resolver("globals") # Globals are resolved correctly in parameter assert conf["parameters"]["my_param"] == globals_config["x"] # The default value is used if the key does not exist @@ -760,25 +759,68 @@ def test_globals_across_env(self, tmp_path): # Base global value is accessible to local params assert conf["parameters"]["param2"] == base_globals_config["x"] - def test_bad_globals(self, tmp_path): + def test_globals_default(self, tmp_path): + base_params = tmp_path / _BASE_ENV / "parameters.yml" + base_globals = tmp_path / _BASE_ENV / "globals.yml" + base_param_config = { + "int": "${globals:x.NOT_EXIST, 1}", + "str": "${globals: x.NOT_EXIST, '2'}", + "dummy": "${globals: x.DUMMY.DUMMY, '2'}", + } + base_globals_config = {"x": {"DUMMY": 3}} + _write_yaml(base_params, base_param_config) + _write_yaml(base_globals, base_globals_config) + conf = OmegaConfigLoader(tmp_path, default_run_env="") + # Default value is being used as int + assert conf["parameters"]["int"] == 1 + # Default value is being used as str + assert conf["parameters"]["str"] == "2" + # Test when x.DUMMY is not a dictionary it should still work + assert conf["parameters"]["dummy"] == "2" + + def test_globals_default_none(self, tmp_path): base_params = tmp_path / _BASE_ENV / "parameters.yml" base_globals = tmp_path / _BASE_ENV / "globals.yml" base_param_config = { - "param1": "${globals:x.y}", + "zero": "${globals: x.NOT_EXIST, 0}", + "null": "${globals: x.NOT_EXIST, null}", + "null2": "${globals: x.y}", } base_globals_config = { "x": { "z": 23, + "y": None, }, } _write_yaml(base_params, base_param_config) _write_yaml(base_globals, base_globals_config) conf = OmegaConfigLoader(tmp_path, default_run_env="") + # Default value can be 0 or null + assert conf["parameters"]["zero"] == 0 + assert conf["parameters"]["null"] is None + # Global value is null + assert conf["parameters"]["null2"] is None + + def test_globals_missing_default(self, tmp_path): + base_params = tmp_path / _BASE_ENV / "parameters.yml" + globals_params = tmp_path / _BASE_ENV / "globals.yml" + param_config = { + "NOT_OK": "${globals:nested.NOT_EXIST}", + } + globals_config = { + "nested": { + "y": 42, + }, + } + _write_yaml(base_params, param_config) + _write_yaml(globals_params, globals_config) + conf = OmegaConfigLoader(tmp_path, default_run_env="") + with pytest.raises( InterpolationResolutionError, - match=r"Globals key 'x.y' not found and no default value provided.", + match="Globals key 'nested.NOT_EXIST' not found and no default value provided.", ): - conf["parameters"]["param1"] + conf["parameters"]["NOT_OK"] def test_bad_globals_underscore(self, tmp_path): base_params = tmp_path / _BASE_ENV / "parameters.yml"