Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for using custom Environments and Strategies #608

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions configs/environment/default.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
defaults:
- _self_
1 change: 1 addition & 0 deletions configs/environment/lightning.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
_target_: lightning.fabric.plugins.environments.LightningEnvironment
3 changes: 3 additions & 0 deletions configs/environment/slurm.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_target_: lightning.fabric.plugins.environments.SLURMEnvironment
auto_requeue: true
requeue_signal: null
2 changes: 2 additions & 0 deletions configs/eval.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ defaults:
- data: mnist # choose datamodule with `test_dataloader()` for evaluation
- model: mnist
- logger: null
- strategy: default
- trainer: default
- paths: default
- extras: default
- hydra: default
- environment: default

task_name: "eval"

Expand Down
4 changes: 4 additions & 0 deletions configs/strategy/ddp.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
_target_: lightning.pytorch.strategies.DDPStrategy
static_graph: false
gradient_as_bucket_view: false
find_unused_parameters: true
5 changes: 5 additions & 0 deletions configs/strategy/deepspeed.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
_target_: lightning.pytorch.strategies.DeepSpeedStrategy
stage: 2
offload_optimizer: false
allgather_bucket_size: 200_000_000
reduce_bucket_size: 200_000_000
2 changes: 2 additions & 0 deletions configs/strategy/default.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
defaults:
- _self_
12 changes: 12 additions & 0 deletions configs/strategy/fsdp.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
_target_: lightning.pytorch.strategies.FSDPStrategy
sharding_strategy: ${resolve_variable:torch.distributed.fsdp.ShardingStrategy.FULL_SHARD}
cpu_offload: null
activation_checkpointing: null
mixed_precision:
_target_: torch.distributed.fsdp.MixedPrecision
param_dtype: null
reduce_dtype: null
buffer_dtype: null
keep_low_precision_grads: false
cast_forward_inputs: false
cast_root_forward_inputs: true
4 changes: 4 additions & 0 deletions configs/strategy/optimized_ddp.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
_target_: lightning.pytorch.strategies.DDPStrategy
static_graph: true
gradient_as_bucket_view: true
find_unused_parameters: false
2 changes: 2 additions & 0 deletions configs/train.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ defaults:
- model: mnist
- callbacks: default
- logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
- strategy: default
- trainer: default
- paths: default
- extras: default
- hydra: default
- environment: default

# experiment configs allow for version control of specific hyperparameters
# e.g. best hyperparameters for given model and datamodule
Expand Down
33 changes: 33 additions & 0 deletions src/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import importlib
from typing import Any

from omegaconf import OmegaConf


def resolve_omegaconf_variable(variable_path: str) -> Any:
"""Resolve an OmegaConf variable path to its value."""
# split the string into parts using the dot separator
parts = variable_path.rsplit(".", 1)

# get the module name from the first part of the path
module_name = parts[0]

# dynamically import the module using the module name
try:
module = importlib.import_module(module_name)
# use the imported module to get the requested attribute value
attribute = getattr(module, parts[1])
except Exception:
module = importlib.import_module(".".join(module_name.split(".")[:-1]))
inner_module = ".".join(module_name.split(".")[-1:])
# use the imported module to get the requested attribute value
attribute = getattr(getattr(module, inner_module), parts[1])

return attribute


def register_custom_omegaconf_resolvers():
"""Register custom OmegaConf resolvers."""
OmegaConf.register_new_resolver(
"resolve_variable", lambda variable_path: resolve_omegaconf_variable(variable_path)
)
45 changes: 44 additions & 1 deletion src/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import hydra
import rootutils
from lightning import LightningDataModule, LightningModule, Trainer
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
from lightning.pytorch.loggers import Logger
from lightning.pytorch.strategies.strategy import Strategy
from omegaconf import DictConfig

rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
Expand All @@ -24,6 +26,7 @@
# more info: https://github.com/ashleve/rootutils
# ------------------------------------------------------------------------------------ #

from src import register_custom_omegaconf_resolvers, resolve_omegaconf_variable
from src.utils import (
RankedLogger,
extras,
Expand Down Expand Up @@ -56,8 +59,47 @@ def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
log.info("Instantiating loggers...")
logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

plugins = None
if "_target_" in cfg.environment:
log.info(f"Instantiating environment <{cfg.environment._target_}>")
plugins: ClusterEnvironment = hydra.utils.instantiate(cfg.environment)

strategy = getattr(cfg.trainer, "strategy", None)
if "_target_" in cfg.strategy:
log.info(f"Instantiating strategy <{cfg.strategy._target_}>")
strategy: Strategy = hydra.utils.instantiate(cfg.strategy)
if "mixed_precision" in strategy.__dict__ and getattr(strategy, "mixed_precision", None) is not None:
strategy.mixed_precision.param_dtype = (
resolve_omegaconf_variable(cfg.strategy.mixed_precision.param_dtype)
if getattr(cfg.strategy.mixed_precision, "param_dtype", None) is not None
else None
)
strategy.mixed_precision.reduce_dtype = (
resolve_omegaconf_variable(cfg.strategy.mixed_precision.reduce_dtype)
if getattr(cfg.strategy.mixed_precision, "reduce_dtype", None) is not None
else None
)
strategy.mixed_precision.buffer_dtype = (
resolve_omegaconf_variable(cfg.strategy.mixed_precision.buffer_dtype)
if getattr(cfg.strategy.mixed_precision, "buffer_dtype", None) is not None
else None
)

log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)
trainer: Trainer = (
hydra.utils.instantiate(
cfg.trainer,
logger=logger,
plugins=plugins,
strategy=strategy,
)
if strategy is not None
else hydra.utils.instantiate(
cfg.trainer,
logger=logger,
plugins=plugins,
)
)

object_dict = {
"cfg": cfg,
Expand Down Expand Up @@ -96,4 +138,5 @@ def main(cfg: DictConfig) -> None:


if __name__ == "__main__":
register_custom_omegaconf_resolvers()
main()
7 changes: 6 additions & 1 deletion src/models/mnist_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
:param net: The model to train.
:param optimizer: The optimizer to use for training.
:param scheduler: The learning rate scheduler to use for training.
:param compile: Whether to compile the model before training.
"""
super().__init__()

Expand Down Expand Up @@ -198,7 +199,11 @@ def configure_optimizers(self) -> Dict[str, Any]:

:return: A dict containing the configured optimizers and learning-rate schedulers to be used for training.
"""
optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
try:
optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
except TypeError:
# NOTE: strategies such as DeepSpeed require `params` to instead be specified as `model_params`
optimizer = self.hparams.optimizer(model_params=self.trainer.model.parameters())
if self.hparams.scheduler is not None:
scheduler = self.hparams.scheduler(optimizer=optimizer)
return {
Expand Down
58 changes: 55 additions & 3 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import hydra
import lightning as L
import rootutils
import torch
import os
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
from lightning.pytorch.loggers import Logger
from lightning.pytorch.strategies.strategy import Strategy
from omegaconf import DictConfig

rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
Expand All @@ -26,6 +28,7 @@
# more info: https://github.com/ashleve/rootutils
# ------------------------------------------------------------------------------------ #

from src import register_custom_omegaconf_resolvers, resolve_omegaconf_variable
from src.utils import (
RankedLogger,
extras,
Expand Down Expand Up @@ -66,8 +69,49 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
log.info("Instantiating loggers...")
logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

plugins = None
if "_target_" in cfg.environment:
log.info(f"Instantiating environment <{cfg.environment._target_}>")
plugins: ClusterEnvironment = hydra.utils.instantiate(cfg.environment)

strategy = getattr(cfg.trainer, "strategy", None)
if "_target_" in cfg.strategy:
log.info(f"Instantiating strategy <{cfg.strategy._target_}>")
strategy: Strategy = hydra.utils.instantiate(cfg.strategy)
if "mixed_precision" in strategy.__dict__ and getattr(strategy, "mixed_precision", None) is not None:
strategy.mixed_precision.param_dtype = (
resolve_omegaconf_variable(cfg.strategy.mixed_precision.param_dtype)
if getattr(cfg.strategy.mixed_precision, "param_dtype", None) is not None
else None
)
strategy.mixed_precision.reduce_dtype = (
resolve_omegaconf_variable(cfg.strategy.mixed_precision.reduce_dtype)
if getattr(cfg.strategy.mixed_precision, "reduce_dtype", None) is not None
else None
)
strategy.mixed_precision.buffer_dtype = (
resolve_omegaconf_variable(cfg.strategy.mixed_precision.buffer_dtype)
if getattr(cfg.strategy.mixed_precision, "buffer_dtype", None) is not None
else None
)

log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)
trainer: Trainer = (
hydra.utils.instantiate(
cfg.trainer,
callbacks=callbacks,
logger=logger,
plugins=plugins,
strategy=strategy,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As you can see here, now one can specify an optional strategy for Lightning to use e.g., via OmegaConf YAML config files.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

)
if strategy is not None
else hydra.utils.instantiate(
cfg.trainer,
callbacks=callbacks,
logger=logger,
plugins=plugins,
)
)

object_dict = {
"cfg": cfg,
Expand All @@ -84,7 +128,14 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:

if cfg.get("train"):
log.info("Starting training!")
trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
ckpt_path = None
if cfg.get("ckpt_path") and os.path.exists(cfg.get("ckpt_path")):
ckpt_path = cfg.get("ckpt_path")
elif cfg.get("ckpt_path"):
log.warning(
"`ckpt_path` was given, but the path does not exist. Training with new model weights."
)
trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path)

train_metrics = trainer.callback_metrics

Expand Down Expand Up @@ -129,4 +180,5 @@ def main(cfg: DictConfig) -> Optional[float]:


if __name__ == "__main__":
register_custom_omegaconf_resolvers()
main()
Loading