cfg init removal
anthonytec2 committed Jul 18, 2020
1 parent 50232e0 commit 06f57ad
Showing 4 changed files with 20 additions and 25 deletions.
25 changes: 14 additions & 11 deletions pl_examples/hydra_examples/pl_template.py
@@ -3,8 +3,8 @@
 """
 
 import hydra
-from omegaconf import DictConfig
-from pl_examples.hydra_examples.user_config import conf_init
+import pl_examples.hydra_examples.user_config
+from omegaconf import DictConfig, OmegaConf
 
 from pl_examples.models.hydra_config_model import LightningTemplateModel
 from pytorch_lightning import Callback, seed_everything, Trainer
@@ -20,21 +20,24 @@ def main(cfg: DictConfig):
     # ------------------------
     # 1 INIT LIGHTNING MODEL
     # ------------------------
-    model = LightningTemplateModel(cfg)
+    model = LightningTemplateModel(OmegaConf.masked_copy(cfg, ["data", "model", "scheduler", "opt"]))
     # ------------------------
     # 2 INIT TRAINER
     # ------------------------
-
-    callbacks = (
-        [hydra.utils.instantiate(c) for c in cfg.callbacks.callbacks_list] if "callbacks_list" in cfg.callbacks else []
-    )
+    callbacks = None
+    if cfg.callbacks:
+        callbacks = (
+            [hydra.utils.instantiate(c) for c in cfg.callbacks.callbacks_list]
+            if "callbacks_list" in cfg.callbacks
+            else []
+        )
 
     trainer = Trainer(
         **cfg.trainer,
-        logger=conf_init(cfg, "logger"),
-        profiler=conf_init(cfg, "profiler"),
-        checkpoint_callback=conf_init(cfg, "checkpoint"),
-        early_stop_callback=conf_init(cfg, "early_stopping"),
+        logger=hydra.utils.instantiate(getattr(cfg, "logger")),
+        profiler=hydra.utils.instantiate(getattr(cfg, "profiler")),
+        checkpoint_callback=hydra.utils.instantiate(getattr(cfg, "checkpoint")),
+        early_stop_callback=hydra.utils.instantiate(getattr(cfg, "early_stopping")),
        callbacks=callbacks,
    )
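Note: the OmegaConf.masked_copy call above narrows the config handed to the model to just the groups it actually reads, rather than passing the whole experiment config. A minimal sketch of the behavior, with hypothetical config values:

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        "data": {"batch_size": 32},
        "model": {"hidden_dim": 128},
        "scheduler": {"gamma": 0.1},
        "opt": {"lr": 1e-3},
        "trainer": {"max_epochs": 10},  # trainer settings stay out of the model
    })

    # masked_copy keeps only the listed top-level keys
    model_cfg = OmegaConf.masked_copy(cfg, ["data", "model", "scheduler", "opt"])
    assert "trainer" not in model_cfg
    assert model_cfg.model.hidden_dim == 128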

10 changes: 1 addition & 9 deletions pl_examples/hydra_examples/user_config.py
@@ -49,16 +49,8 @@ class UserConfig(PLConfig):
     model: Any = MISSING
     scheduler: ObjectConf = MISSING
     opt: ObjectConf = MISSING
-    callbacks: Any = MISSING
+    callbacks: Any = None
 
 
 # Stored as config node, for top level config used for type checking.
 cs.store(name="config", node=UserConfig)
-
-
-def conf_init(cfg, key):
-    # Function to be removed after https://github.com/facebookresearch/hydra/issues/785 fixed
-    try:
-        return hydra.utils.instantiate(getattr(cfg, key))
-    except:
-        return None
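The deleted conf_init wrapper existed only to return None when a config group was unset, and its bare except also swallowed genuine instantiation failures. With callbacks defaulting to None here (and the logger/profiler/checkpoint/early_stopping groups getting None defaults in trainer_conf.py below), the call sites in pl_template.py can hand the node to hydra.utils.instantiate directly. A reduced sketch of the new callbacks handling, using a hypothetical stand-in for UserConfig:

    from dataclasses import dataclass
    from typing import Any

    import hydra
    from omegaconf import OmegaConf

    @dataclass
    class MiniUserConfig:
        # reduced stand-in for UserConfig: the callbacks group may be omitted
        callbacks: Any = None

    cfg = OmegaConf.structured(MiniUserConfig)

    # mirrors pl_template.py: the None default makes the truthiness check safe
    callbacks = None
    if cfg.callbacks:
        callbacks = [hydra.utils.instantiate(c) for c in cfg.callbacks.callbacks_list]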
2 changes: 1 addition & 1 deletion pl_examples/models/hydra_config_model.py
@@ -18,7 +18,7 @@ class LightningTemplateModel(LightningModule):
     def __init__(self, cfg) -> "LightningTemplateModel":
         # init superclass
         super().__init__()
-        self.save_hyperparameters()
+        # self.save_hyperparameters()
         self.model = cfg.model
         self.data = cfg.data
         self.opt = cfg.opt
8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/trainer_conf.py
@@ -229,9 +229,9 @@ class WandbConf:
 # Bug cant merge config when set to None
 @dataclass
 class PLConfig(DictConfig):
-    logger: Optional[ObjectConf]
-    profiler: Optional[ObjectConf]
-    checkpoint: Optional[ObjectConf]
-    early_stopping: Optional[ObjectConf]
+    logger: Optional[ObjectConf] = None
+    profiler: Optional[ObjectConf] = None
+    checkpoint: Optional[ObjectConf] = None
+    early_stopping: Optional[ObjectConf] = None
     trainer: LightningTrainerConf = MISSING
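Without an explicit default, OmegaConf treats a field of a structured config as mandatory (MISSING), so reading an unset group raises rather than yielding None; the new None defaults are what let pl_template.py call hydra.utils.instantiate(getattr(cfg, ...)) on possibly-absent groups, assuming instantiate passes a None config through as those call sites rely on. A minimal sketch of the distinction, with a reduced hypothetical config:

    from dataclasses import dataclass
    from typing import Optional

    from omegaconf import OmegaConf

    @dataclass
    class WithoutDefault:
        logger: Optional[str]  # no default: interpreted as MISSING

    @dataclass
    class WithDefault:
        logger: Optional[str] = None  # explicit default: safe to read

    required = OmegaConf.structured(WithoutDefault)
    optional = OmegaConf.structured(WithDefault)

    print(OmegaConf.is_missing(required, "logger"))  # True: must be provided
    print(optional.logger)  # None: usable as-is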
