Skip to content

Commit

Permalink
allow continuation of model training (#386)
Browse files Browse the repository at this point in the history
* allow continuation of model training

* ensure ckpt_path is present

---------

Co-authored-by: Benjamin Morris <[email protected]>
  • Loading branch information
benjijamorris and Benjamin Morris authored Jun 5, 2024
1 parent 4258d16 commit 4e8677d
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions cyto_dl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import hydra
import lightning
import pyrootutils
import torch
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers.logger import Logger
from omegaconf import DictConfig, OmegaConf
Expand Down Expand Up @@ -94,6 +95,16 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:

if cfg.get("train"):
log.info("Starting training!")

if cfg.get("weights_only"):
assert cfg.get(
"ckpt_path"
), "ckpt_path must be provided to with argument weights_only=True"
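# Load only the model weights from the checkpoint's state dict and clear ckpt_path so trainer.fit does not resume from the checkpoint; this gets around the trainer.max_epochs limit and is useful for continuing training from existing weights.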
state_dict = torch.load(cfg["ckpt_path"])["state_dict"]
model.load_state_dict(state_dict)
cfg["ckpt_path"] = None

if isinstance(data, LightningDataModule):
trainer.fit(model=model, datamodule=data, ckpt_path=cfg.get("ckpt_path"))
else:
Expand Down

0 comments on commit 4e8677d

Please sign in to comment.