From 188c6bea2465a429caaed711554624e1f8db0915 Mon Sep 17 00:00:00 2001 From: Rosario Scalise Date: Wed, 12 Aug 2020 05:10:17 -0700 Subject: [PATCH] Support **DictConfig hparam serialization (#2519) * change to OmegaConf API Co-authored-by: Omry Yadan * Swapped Container for OmegaConf sentinel; Limited ds copying * Add Namespace check. * Container removed. Pass local tests. Co-authored-by: Omry Yadan --- pytorch_lightning/core/saving.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 28501fdcda06a5..dea3fa99dd9d15 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -16,11 +16,9 @@ PRIMITIVE_TYPES = (bool, int, float, str) ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace) try: - from omegaconf import Container + from omegaconf import OmegaConf except ImportError: - OMEGACONF_AVAILABLE = False -else: - OMEGACONF_AVAILABLE = True + OmegaConf = None # the older shall be on the top CHECKPOINT_PAST_HPARAMS_KEYS = ( @@ -330,22 +328,26 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None: if not gfile.isdir(os.path.dirname(config_yaml)): raise RuntimeError(f"Missing folder: {os.path.dirname(config_yaml)}.") - if OMEGACONF_AVAILABLE and isinstance(hparams, Container): - from omegaconf import OmegaConf - - OmegaConf.save(hparams, config_yaml, resolve=True) - return - - # saving the standard way + # convert Namespace or AD to dict if isinstance(hparams, Namespace): hparams = vars(hparams) elif isinstance(hparams, AttributeDict): hparams = dict(hparams) - assert isinstance(hparams, dict) - with cloud_open(config_yaml, "w", newline="") as fp: - yaml.dump(hparams, fp) + # saving with OmegaConf objects + if OmegaConf is not None: + if OmegaConf.is_config(hparams): + OmegaConf.save(hparams, config_yaml, resolve=True) + return + for v in hparams.values(): + if OmegaConf.is_config(v): + OmegaConf.save(OmegaConf.create(hparams), config_yaml, resolve=True) + return + # saving the standard way + assert isinstance(hparams, dict) + with open(config_yaml, 'w', newline='') as fp: + yaml.dump(hparams, fp) def convert(val: str) -> Union[int, float, bool, str]: try: