From 4a9a962c5af50d394520491c7398c9b45dfd4d98 Mon Sep 17 00:00:00 2001
From: Anthony Bisulco
Date: Fri, 17 Jul 2020 21:42:08 -0400
Subject: [PATCH] Define both target and cls in ObjectConf stores

---
 pl_examples/hydra_examples/conf/optimizer.py | 20 ++++---
 pl_examples/hydra_examples/conf/scheduler.py | 44 +++++++++++---
 pl_examples/models/hydra_config_model.py     |  2 +-
 pytorch_lightning/trainer/trainer_conf.py    | 60 ++++++++++++++++----
 4 files changed, 99 insertions(+), 27 deletions(-)

diff --git a/pl_examples/hydra_examples/conf/optimizer.py b/pl_examples/hydra_examples/conf/optimizer.py
index 9dd82a5354b20c..17bdbeaac94d4e 100644
--- a/pl_examples/hydra_examples/conf/optimizer.py
+++ b/pl_examples/hydra_examples/conf/optimizer.py
@@ -17,10 +17,10 @@ class AdamConf:
 cs.store(
-    group="opt", name="adam", node=ObjectConf(target="torch.optim.Adam", params=AdamConf()),
+    group="opt", name="adam", node=ObjectConf(target="torch.optim.Adam", cls="torch.optim.Adam", params=AdamConf()),
 )
 
 cs.store(
-    group="opt", name="adamw", node=ObjectConf(target="torch.optim.AdamW", params=AdamConf()),
+    group="opt", name="adamw", node=ObjectConf(target="torch.optim.AdamW", cls="torch.optim.AdamW", params=AdamConf()),
 )
@@ -33,7 +33,9 @@ class AdamaxConf:
 cs.store(
-    group="opt", name="adamax", node=ObjectConf(target="torch.optim.Adamax", params=AdamaxConf()),
+    group="opt",
+    name="adamax",
+    node=ObjectConf(target="torch.optim.Adamax", cls="torch.optim.Adamax", params=AdamaxConf()),
 )
@@ -47,7 +49,7 @@ class ASGDConf:
 cs.store(
-    group="opt", name="asgd", node=ObjectConf(target="torch.optim.ASGD", params=ASGDConf()),
+    group="opt", name="asgd", node=ObjectConf(target="torch.optim.ASGD", cls="torch.optim.ASGD", params=ASGDConf()),
 )
@@ -63,7 +65,7 @@ class LBFGSConf:
 cs.store(
-    group="opt", name="lbfgs", node=ObjectConf(target="torch.optim.LBFGS", params=LBFGSConf()),
+    group="opt", name="lbfgs", node=ObjectConf(target="torch.optim.LBFGS", cls="torch.optim.LBFGS", params=LBFGSConf()),
 )
@@ -78,7 +80,9 @@ class RMSpropConf:
 cs.store(
-    group="opt", name="rmsprop", node=ObjectConf(target="torch.optim.RMSprop", params=RMSpropConf()),
+    group="opt",
+    name="rmsprop",
+    node=ObjectConf(target="torch.optim.RMSprop", cls="torch.optim.RMSprop", params=RMSpropConf()),
 )
@@ -90,7 +94,7 @@ class RpropConf:
 cs.store(
-    group="opt", name="rprop", node=ObjectConf(target="torch.optim.Rprop", params=RpropConf()),
+    group="opt", name="rprop", node=ObjectConf(target="torch.optim.Rprop", cls="torch.optim.Rprop", params=RpropConf()),
 )
@@ -104,5 +108,5 @@ class SGDConf:
 cs.store(
-    group="opt", name="sgd", node=ObjectConf(target="torch.optim.SGD", params=SGDConf()),
+    group="opt", name="sgd", node=ObjectConf(target="torch.optim.SGD", cls="torch.optim.SGD", params=SGDConf()),
 )
diff --git a/pl_examples/hydra_examples/conf/scheduler.py b/pl_examples/hydra_examples/conf/scheduler.py
index 4081f4be611d7e..16feb48e98b140 100644
--- a/pl_examples/hydra_examples/conf/scheduler.py
+++ b/pl_examples/hydra_examples/conf/scheduler.py
@@ -17,7 +17,11 @@ class CosineConf:
 cs.store(
     group="scheduler",
     name="cosine",
-    node=ObjectConf(target="torch.optim.lr_scheduler.CosineAnnealingLR", params=CosineConf()),
+    node=ObjectConf(
+        target="torch.optim.lr_scheduler.CosineAnnealingLR",
+        cls="torch.optim.lr_scheduler.CosineAnnealingLR",
+        params=CosineConf(),
+    ),
 )
@@ -32,7 +36,11 @@ class CosineWarmConf:
 cs.store(
     group="scheduler",
     name="cosinewarm",
-    node=ObjectConf(target="torch.optim.lr_scheduler.CosineAnnealingLR", params=CosineWarmConf()),
+    node=ObjectConf(
target="torch.optim.lr_scheduler.CosineAnnealingLR", + cls="torch.optim.lr_scheduler.CosineAnnealingLR", + params=CosineWarmConf(), + ), ) @@ -53,7 +61,11 @@ class CyclicConf: cs.store( - group="scheduler", name="cyclic", node=ObjectConf(target="torch.optim.lr_scheduler.CyclicLR", params=CyclicConf()), + group="scheduler", + name="cyclic", + node=ObjectConf( + target="torch.optim.lr_scheduler.CyclicLR", cls="torch.optim.lr_scheduler.CyclicLR", params=CyclicConf() + ), ) @@ -66,7 +78,11 @@ class ExponentialConf: cs.store( group="scheduler", name="exponential", - node=ObjectConf(target="torch.optim.lr_scheduler.ExponentialLR", params=ExponentialConf()), + node=ObjectConf( + target="torch.optim.lr_scheduler.ExponentialLR", + cls="torch.optim.lr_scheduler.ExponentialLR", + params=ExponentialConf(), + ), ) @@ -86,7 +102,11 @@ class RedPlatConf: cs.store( group="scheduler", name="redplat", - node=ObjectConf(target="torch.optim.lr_scheduler.ReduceLROnPlateau", params=RedPlatConf()), + node=ObjectConf( + target="torch.optim.lr_scheduler.ReduceLROnPlateau", + cls="torch.optim.lr_scheduler.ReduceLROnPlateau", + params=RedPlatConf(), + ), ) @@ -100,7 +120,11 @@ class MultiStepConf: cs.store( group="scheduler", name="multistep", - node=ObjectConf(target="torch.optim.lr_scheduler.MultiStepLR", params=MultiStepConf()), + node=ObjectConf( + target="torch.optim.lr_scheduler.MultiStepLR", + cls="torch.optim.lr_scheduler.MultiStepLR", + params=MultiStepConf(), + ), ) @@ -123,7 +147,9 @@ class OneCycleConf: cs.store( group="scheduler", name="onecycle", - node=ObjectConf(target="torch.optim.lr_scheduler.OneCycleLR", params=OneCycleConf()), + node=ObjectConf( + target="torch.optim.lr_scheduler.OneCycleLR", cls="torch.optim.lr_scheduler.OneCycleLR", params=OneCycleConf() + ), ) @@ -135,5 +161,7 @@ class StepConf: cs.store( - group="scheduler", name="step", node=ObjectConf(target="torch.optim.lr_scheduler.StepLR", params=StepConf()), + group="scheduler", + name="step", + node=ObjectConf(target="torch.optim.lr_scheduler.StepLR", cls="torch.optim.lr_scheduler.StepLR", params=StepConf()), ) diff --git a/pl_examples/models/hydra_config_model.py b/pl_examples/models/hydra_config_model.py index 8e04333f7e862c..62e4d0ba35ff5b 100644 --- a/pl_examples/models/hydra_config_model.py +++ b/pl_examples/models/hydra_config_model.py @@ -18,7 +18,7 @@ class LightningTemplateModel(LightningModule): def __init__(self, cfg) -> "LightningTemplateModel": # init superclass super().__init__() - # self.save_hyperparameters() + self.save_hyperparameters() self.model = cfg.model self.data = cfg.data self.opt = cfg.opt diff --git a/pytorch_lightning/trainer/trainer_conf.py b/pytorch_lightning/trainer/trainer_conf.py index a0c152f4f0833a..1a5606b66a6d7f 100644 --- a/pytorch_lightning/trainer/trainer_conf.py +++ b/pytorch_lightning/trainer/trainer_conf.py @@ -75,7 +75,11 @@ class ModelCheckpointConf: cs.store( group="checkpoint", name="modelckpt", - node=ObjectConf(target="pytorch_lightning.callbacks.ModelCheckpoint", params=ModelCheckpointConf()), + node=ObjectConf( + target="pytorch_lightning.callbacks.ModelCheckpoint", + cls="pytorch_lightning.callbacks.ModelCheckpoint", + params=ModelCheckpointConf(), + ), ) @@ -92,7 +96,11 @@ class EarlyStoppingConf: cs.store( group="early_stopping", name="earlystop", - node=ObjectConf(target="pytorch_lightning.callbacks.EarlyStopping", params=EarlyStoppingConf()), + node=ObjectConf( + target="pytorch_lightning.callbacks.EarlyStopping", + cls="pytorch_lightning.callbacks.EarlyStopping", + 
+        params=EarlyStoppingConf(),
+    ),
 )
@@ -110,13 +118,21 @@ class AdvancedProfilerConf:
 cs.store(
     group="profiler",
     name="simple",
-    node=ObjectConf(target="pytorch_lightning.profiler.SimpleProfiler", params=SimpleProfilerConf()),
+    node=ObjectConf(
+        target="pytorch_lightning.profiler.SimpleProfiler",
+        cls="pytorch_lightning.profiler.SimpleProfiler",
+        params=SimpleProfilerConf(),
+    ),
 )
 
 cs.store(
     group="profiler",
     name="advanced",
-    node=ObjectConf(target="pytorch_lightning.profiler.AdvancedProfiler", params=AdvancedProfilerConf()),
+    node=ObjectConf(
+        target="pytorch_lightning.profiler.AdvancedProfiler",
+        cls="pytorch_lightning.profiler.AdvancedProfiler",
+        params=AdvancedProfilerConf(),
+    ),
 )
@@ -134,7 +150,11 @@ class CometLoggerConf:
 cs.store(
     group="logger",
     name="comet",
-    node=ObjectConf(target="pytorch_lightning.loggers.comet.CometLogger", params=CometLoggerConf()),
+    node=ObjectConf(
+        target="pytorch_lightning.loggers.comet.CometLogger",
+        cls="pytorch_lightning.loggers.comet.CometLogger",
+        params=CometLoggerConf(),
+    ),
 )
@@ -149,7 +169,11 @@ class MLFlowLoggerConf:
 cs.store(
     group="logger",
     name="mlflow",
-    node=ObjectConf(target="pytorch_lightning.loggers.mlflow.MLFlowLogger", params=MLFlowLoggerConf()),
+    node=ObjectConf(
+        target="pytorch_lightning.loggers.mlflow.MLFlowLogger",
+        cls="pytorch_lightning.loggers.mlflow.MLFlowLogger",
+        params=MLFlowLoggerConf(),
+    ),
 )
@@ -169,7 +193,11 @@ class NeptuneLoggerConf:
 cs.store(
     group="logger",
     name="neptune",
-    node=ObjectConf(target="pytorch_lightning.loggers.neptune.NeptuneLogger", params=NeptuneLoggerConf()),
+    node=ObjectConf(
+        target="pytorch_lightning.loggers.neptune.NeptuneLogger",
+        cls="pytorch_lightning.loggers.neptune.NeptuneLogger",
+        params=NeptuneLoggerConf(),
+    ),
 )
@@ -183,7 +211,11 @@ class TensorboardLoggerConf:
 cs.store(
     group="logger",
     name="tensorboard",
-    node=ObjectConf(target="pytorch_lightning.loggers.tensorboard.TensorBoardLogger", params=TensorboardLoggerConf()),
+    node=ObjectConf(
+        target="pytorch_lightning.loggers.tensorboard.TensorBoardLogger",
+        cls="pytorch_lightning.loggers.tensorboard.TensorBoardLogger",
+        params=TensorboardLoggerConf(),
+    ),
 )
@@ -200,7 +232,11 @@ class TestTubeLoggerConf:
 cs.store(
     group="logger",
     name="testtube",
-    node=ObjectConf(target="pytorch_lightning.loggers.test_tube.TestTubeLogger", params=TestTubeLoggerConf()),
+    node=ObjectConf(
+        target="pytorch_lightning.loggers.test_tube.TestTubeLogger",
+        cls="pytorch_lightning.loggers.test_tube.TestTubeLogger",
+        params=TestTubeLoggerConf(),
+    ),
 )
@@ -223,7 +259,11 @@ class WandbConf:
 cs.store(
     group="logger",
     name="wandb",
-    node=ObjectConf(target="pytorch_lightning.loggers.wandb.WandbLogger", params=WandbConf()),
+    node=ObjectConf(
+        target="pytorch_lightning.loggers.wandb.WandbLogger",
+        cls="pytorch_lightning.loggers.wandb.WandbLogger",
+        params=WandbConf(),
+    ),
 )
 
 # Bug: can't merge config when set to None
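
Note on the change above: newer Hydra releases rename `ObjectConf.cls` to `ObjectConf.target`, so each `cs.store` call now sets both fields to the same import path, keeping these configs loadable whether the installed Hydra reads the new field name or the deprecated one. The sketch below shows how one of the stored configs might be consumed. It is illustrative only: the `conf/config.yaml` entry point (with a defaults list selecting the `opt` and `scheduler` groups) and the `torch.nn.Linear` stand-in model are assumptions, not part of this patch.

import hydra
import torch
from omegaconf import DictConfig


@hydra.main(config_path="conf", config_name="config")
def train(cfg: DictConfig) -> None:
    model = torch.nn.Linear(10, 2)  # stand-in for the LightningModule (assumption)
    # instantiate() imports the class named by `target` (or the deprecated `cls`)
    # and calls it with `params` plus any positional arguments.
    optimizer = hydra.utils.instantiate(cfg.opt, model.parameters())
    scheduler = hydra.utils.instantiate(cfg.scheduler, optimizer)
    print(type(optimizer).__name__, type(scheduler).__name__)


if __name__ == "__main__":
    train()

With the groups registered as above, a run such as `python train.py opt=sgd scheduler=step` swaps the optimizer and scheduler from the command line without touching any code.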