Skip to content

Commit

Permalink
Support extending list from override
Browse files Browse the repository at this point in the history
  • Loading branch information
jesszzzz committed Jan 16, 2025
1 parent ca4d25c commit 3751505
Showing 1 changed file with 45 additions and 3 deletions.
48 changes: 45 additions & 3 deletions hydra/_internal/config_loader_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,15 @@
from textwrap import dedent
from typing import Any, List, MutableSequence, Optional, Tuple

from omegaconf import Container, DictConfig, OmegaConf, flag_override, open_dict
from omegaconf import (
Container,
DictConfig,
ListConfig,
Node,
OmegaConf,
flag_override,
open_dict,
)
from omegaconf.errors import (
ConfigAttributeError,
ConfigKeyError,
Expand Down Expand Up @@ -475,7 +483,8 @@ def _load_single_config(
):
hydra = config.pop("hydra")

merged = OmegaConf.merge(schema.config, config)
override_cfg = _convert_list_extend_overrides(schema.config, config)
merged = OmegaConf.merge(schema.config, override_cfg) # TODO
assert isinstance(merged, DictConfig)

if hydra is not None:
Expand Down Expand Up @@ -548,7 +557,8 @@ def _compose_config_from_defaults_list(
for default in defaults:
loaded = self._load_single_config(default=default, repo=repo)
try:
cfg.merge_with(loaded.config)
override_config = _convert_list_extend_overrides(cfg, loaded.config)
cfg.merge_with(override_config) # TODO
except OmegaConfBaseException as e:
raise ConfigCompositionException(
f"In '{default.config_path}': {type(e).__name__} raised while"
Expand Down Expand Up @@ -594,6 +604,38 @@ def compute_defaults_list(
return defaults_list


def _convert_list_extend_overrides(base_cfg: Node, override_cfg: Node) -> Node:
if isinstance(base_cfg, ListConfig) and isinstance(override_cfg, DictConfig):
if "_extend_" in override_cfg and isinstance(override_cfg._extend_, ListConfig):
new_list = copy.copy(override_cfg._extend_)
new_list.extend(base_cfg)
return new_list
# This is invalid, we don't bother trying to handle it
return override_cfg

if (
isinstance(base_cfg, DictConfig)
and isinstance(override_cfg, DictConfig)
and not override_cfg._is_missing()
and not override_cfg._is_none()
):
items_to_change = {}
for key in override_cfg.keys():
value = override_cfg._get_node(key)
if not value:
continue
base_value = base_cfg._get_node(key, False) or ListConfig([])
new_value = _convert_list_extend_overrides(base_value, value)
if not new_value is value:
items_to_change[key] = new_value
if len(items_to_change) > 0:
new_dict = copy.copy(override_cfg)
new_dict.update(items_to_change)
return new_dict

return override_cfg


def get_overrides_dirname(
overrides: List[Override], exclude_keys: List[str], item_sep: str, kv_sep: str
) -> str:
Expand Down

0 comments on commit 3751505

Please sign in to comment.