From 375150544d80db7670b66c76214e2c420c76933a Mon Sep 17 00:00:00 2001 From: "Jessica Zhang (NY)" Date: Thu, 16 Jan 2025 16:52:42 -0500 Subject: [PATCH] Support extending list from override --- hydra/_internal/config_loader_impl.py | 48 +++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 3 deletions(-) diff --git a/hydra/_internal/config_loader_impl.py b/hydra/_internal/config_loader_impl.py index 6b3cb5ffc2..4a5bb10503 100644 --- a/hydra/_internal/config_loader_impl.py +++ b/hydra/_internal/config_loader_impl.py @@ -7,7 +7,15 @@ from textwrap import dedent from typing import Any, List, MutableSequence, Optional, Tuple -from omegaconf import Container, DictConfig, OmegaConf, flag_override, open_dict +from omegaconf import ( + Container, + DictConfig, + ListConfig, + Node, + OmegaConf, + flag_override, + open_dict, +) from omegaconf.errors import ( ConfigAttributeError, ConfigKeyError, @@ -475,7 +483,8 @@ def _load_single_config( ): hydra = config.pop("hydra") - merged = OmegaConf.merge(schema.config, config) + override_cfg = _convert_list_extend_overrides(schema.config, config) + merged = OmegaConf.merge(schema.config, override_cfg) # TODO assert isinstance(merged, DictConfig) if hydra is not None: @@ -548,7 +557,8 @@ def _compose_config_from_defaults_list( for default in defaults: loaded = self._load_single_config(default=default, repo=repo) try: - cfg.merge_with(loaded.config) + override_config = _convert_list_extend_overrides(cfg, loaded.config) + cfg.merge_with(override_config) # TODO except OmegaConfBaseException as e: raise ConfigCompositionException( f"In '{default.config_path}': {type(e).__name__} raised while" @@ -594,6 +604,38 @@ def compute_defaults_list( return defaults_list +def _convert_list_extend_overrides(base_cfg: Node, override_cfg: Node) -> Node: + if isinstance(base_cfg, ListConfig) and isinstance(override_cfg, DictConfig): + if "_extend_" in override_cfg and isinstance(override_cfg._extend_, ListConfig): + new_list = copy.copy(override_cfg._extend_) + new_list.extend(base_cfg) + return new_list + # This is invalid, we don't bother trying to handle it + return override_cfg + + if ( + isinstance(base_cfg, DictConfig) + and isinstance(override_cfg, DictConfig) + and not override_cfg._is_missing() + and not override_cfg._is_none() + ): + items_to_change = {} + for key in override_cfg.keys(): + value = override_cfg._get_node(key) + if not value: + continue + base_value = base_cfg._get_node(key, False) or ListConfig([]) + new_value = _convert_list_extend_overrides(base_value, value) + if not new_value is value: + items_to_change[key] = new_value + if len(items_to_change) > 0: + new_dict = copy.copy(override_cfg) + new_dict.update(items_to_change) + return new_dict + + return override_cfg + + def get_overrides_dirname( overrides: List[Override], exclude_keys: List[str], item_sep: str, kv_sep: str ) -> str: