diff --git a/RELEASE.md b/RELEASE.md index 6a82c79b57..e2f762ff73 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -11,7 +11,11 @@ # Upcoming Release 0.18.9 ## Major features and improvements +* `kedro run --params` now updates interpolated parameters correctly when using `OmegaConfigLoader`. + ## Bug fixes and other changes +* `OmegaConfigLoader` will return a `dict` instead of `DictConfig`. + ## Breaking changes to the API * `kedro package` does not produce `.egg` files anymore, and now relies exclusively on `.whl` files. ## Upcoming deprecations for Kedro 0.19.0 diff --git a/kedro/config/abstract_config.py b/kedro/config/abstract_config.py index 70a463fc15..776ec6c836 100644 --- a/kedro/config/abstract_config.py +++ b/kedro/config/abstract_config.py @@ -24,7 +24,7 @@ def __init__( super().__init__() self.conf_source = conf_source self.env = env - self.runtime_params = runtime_params + self.runtime_params = runtime_params or {} class BadConfigException(Exception): diff --git a/kedro/config/omegaconf_config.py b/kedro/config/omegaconf_config.py index d04ea35b3b..285ea1f02d 100644 --- a/kedro/config/omegaconf_config.py +++ b/kedro/config/omegaconf_config.py @@ -168,7 +168,7 @@ def __getitem__(self, key) -> dict[str, Any]: else: base_path = str(Path(self._fs.ls("", detail=False)[-1]) / self.base_env) base_config = self.load_and_merge_dir_config( - base_path, patterns, read_environment_variables + base_path, patterns, key, read_environment_variables ) config = base_config @@ -179,7 +179,7 @@ def __getitem__(self, key) -> dict[str, Any]: else: env_path = str(Path(self._fs.ls("", detail=False)[-1]) / run_env) env_config = self.load_and_merge_dir_config( - env_path, patterns, read_environment_variables + env_path, patterns, key, read_environment_variables ) # Destructively merge the two env dirs. The chosen env will override base. @@ -211,6 +211,7 @@ def load_and_merge_dir_config( self, conf_path: str, patterns: Iterable[str], + key: str, read_environment_variables: bool | None = False, ) -> dict[str, Any]: """Recursively load and merge all configuration files in a directory using OmegaConf, @@ -219,6 +220,7 @@ def load_and_merge_dir_config( Args: conf_path: Path to configuration directory. patterns: List of glob patterns to match the filenames against. + key: Key of the configuration type to fetch. read_environment_variables: Whether to resolve environment variables. Raises: @@ -275,9 +277,13 @@ def load_and_merge_dir_config( if not aggregate_config: return {} - if len(aggregate_config) == 1: - return list(aggregate_config)[0] - return dict(OmegaConf.merge(*aggregate_config)) + + if key == "parameters": + # Merge with runtime parameters only for "parameters" + return OmegaConf.to_container( + OmegaConf.merge(*aggregate_config, self.runtime_params), resolve=True + ) + return OmegaConf.to_container(OmegaConf.merge(*aggregate_config), resolve=True) def _is_valid_config_path(self, path): """Check if given path is a file path and file type is yaml or json.""" diff --git a/kedro/framework/session/session.py b/kedro/framework/session/session.py index bd2b062da5..448dd0612d 100644 --- a/kedro/framework/session/session.py +++ b/kedro/framework/session/session.py @@ -201,7 +201,7 @@ def create( # pylint: disable=too-many-arguments def _get_logging_config(self) -> dict[str, Any]: logging_config = self._get_config_loader()["logging"] if isinstance(logging_config, omegaconf.DictConfig): - logging_config = OmegaConf.to_container(logging_config) + logging_config = OmegaConf.to_container(logging_config) # pragma: no cover # turn relative paths in logging config into absolute path # before initialising loggers logging_config = _convert_paths_to_absolute_posix( diff --git a/tests/config/test_omegaconf_config.py b/tests/config/test_omegaconf_config.py index a93b284853..bc80c08711 100644 --- a/tests/config/test_omegaconf_config.py +++ b/tests/config/test_omegaconf_config.py @@ -70,14 +70,30 @@ def local_config(tmp_path): @pytest.fixture def create_config_dir(tmp_path, base_config, local_config): - proj_catalog = tmp_path / _BASE_ENV / "catalog.yml" + base_catalog = tmp_path / _BASE_ENV / "catalog.yml" + base_logging = tmp_path / _BASE_ENV / "logging.yml" + base_spark = tmp_path / _BASE_ENV / "spark.yml" + base_catalog = tmp_path / _BASE_ENV / "catalog.yml" + local_catalog = tmp_path / _DEFAULT_RUN_ENV / "catalog.yml" + parameters = tmp_path / _BASE_ENV / "parameters.json" - project_parameters = {"param1": 1, "param2": 2} + base_parameters = {"param1": 1, "param2": 2, "interpolated_param": "${test_env}"} + base_global_parameters = {"test_env": "base"} + local_global_parameters = {"test_env": "local"} - _write_yaml(proj_catalog, base_config) + _write_yaml(base_catalog, base_config) _write_yaml(local_catalog, local_config) - _write_json(parameters, project_parameters) + + # Empty Config + _write_yaml(base_logging, {"version": 1}) + _write_yaml(base_spark, {"dummy": 1}) + + _write_json(parameters, base_parameters) + _write_json(tmp_path / _BASE_ENV / "parameters_global.json", base_global_parameters) + _write_json( + tmp_path / _DEFAULT_RUN_ENV / "parameters_global.json", local_global_parameters + ) @pytest.fixture @@ -531,3 +547,43 @@ def zipdir(path, ziph): conf = OmegaConfigLoader(conf_source=f"{tmp_path}/Python.zip") catalog = conf["catalog"] assert catalog["trains"]["type"] == "MemoryDataSet" + + @use_config_dir + def test_variable_interpolation_with_correct_env(self, tmp_path): + """Make sure the parameters is interpolated with the correct environment""" + conf = OmegaConfigLoader(str(tmp_path)) + params = conf["parameters"] + # Making sure it is not override by local/parameters_global.yml + assert params["interpolated_param"] == "base" + + @use_config_dir + def test_runtime_params_override_interpolated_value(self, tmp_path): + """Make sure interpolated value is updated correctly with runtime_params""" + conf = OmegaConfigLoader(str(tmp_path), runtime_params={"test_env": "dummy"}) + params = conf["parameters"] + assert params["interpolated_param"] == "dummy" + + @use_config_dir + @use_credentials_env_variable_yml + def test_runtime_params_not_propogate_non_parameters_config(self, tmp_path): + """Make sure `catalog`, `credentials`, `logging` or any config other than + `parameters` are not updated by `runtime_params`.""" + # https://github.com/kedro-org/kedro/pull/2467 + key = "test_env" + runtime_params = {key: "dummy"} + conf = OmegaConfigLoader( + str(tmp_path), + config_patterns={"spark": ["spark*", "spark*/**", "**/spark*"]}, + runtime_params=runtime_params, + ) + parameters = conf["parameters"] + catalog = conf["catalog"] + credentials = conf["credentials"] + logging = conf["logging"] + spark = conf["spark"] + + assert key in parameters + assert key not in catalog + assert key not in credentials + assert key not in logging + assert key not in spark