Skip to content

Commit

Permalink
Support **DictConfig hparam serialization (Lightning-AI#2519)
Browse files Browse the repository at this point in the history
* change to OmegaConf API

Co-authored-by: Omry Yadan <[email protected]>

* Swapped Container for OmegaConf sentinel; Limited ds copying

* Add Namespace check.

* Container removed. Pass local tests.

Co-authored-by: Omry Yadan <[email protected]>
  • Loading branch information
2 people authored and atee committed Aug 17, 2020
1 parent 4a2d62f commit 188c6be
Showing 1 changed file with 16 additions and 14 deletions.
30 changes: 16 additions & 14 deletions pytorch_lightning/core/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@
PRIMITIVE_TYPES = (bool, int, float, str)
ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)
try:
from omegaconf import Container
from omegaconf import OmegaConf
except ImportError:
OMEGACONF_AVAILABLE = False
else:
OMEGACONF_AVAILABLE = True
OmegaConf = None

# the older shall be on the top
CHECKPOINT_PAST_HPARAMS_KEYS = (
Expand Down Expand Up @@ -330,22 +328,26 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
if not gfile.isdir(os.path.dirname(config_yaml)):
raise RuntimeError(f"Missing folder: {os.path.dirname(config_yaml)}.")

if OMEGACONF_AVAILABLE and isinstance(hparams, Container):
from omegaconf import OmegaConf

OmegaConf.save(hparams, config_yaml, resolve=True)
return

# saving the standard way
# convert Namespace or AD to dict
if isinstance(hparams, Namespace):
hparams = vars(hparams)
elif isinstance(hparams, AttributeDict):
hparams = dict(hparams)
assert isinstance(hparams, dict)

with cloud_open(config_yaml, "w", newline="") as fp:
yaml.dump(hparams, fp)
# saving with OmegaConf objects
if OmegaConf is not None:
if OmegaConf.is_config(hparams):
OmegaConf.save(hparams, config_yaml, resolve=True)
return
for v in hparams.values():
if OmegaConf.is_config(v):
OmegaConf.save(OmegaConf.create(hparams), config_yaml, resolve=True)
return

# saving the standard way
assert isinstance(hparams, dict)
with open(config_yaml, 'w', newline='') as fp:
yaml.dump(hparams, fp)

def convert(val: str) -> Union[int, float, bool, str]:
try:
Expand Down

0 comments on commit 188c6be

Please sign in to comment.