Skip to content

Commit

Permalink
add checkpoint key in config (#412)
Browse files Browse the repository at this point in the history
* add checkpoint keyin config

* precommit

* update checkpoint.ckpt_path, not ckpt_path

---------

Co-authored-by: Benjamin Morris <[email protected]>
  • Loading branch information
benjijamorris and Benjamin Morris authored Aug 9, 2024
1 parent cee8402 commit cd04e21
Show file tree
Hide file tree
Showing 12 changed files with 29 additions and 20 deletions.
5 changes: 4 additions & 1 deletion configs/eval.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,7 @@ task_name: "eval"
tags: ["dev"]

# passing checkpoint path is necessary for evaluation
ckpt_path: ???
checkpoint:
ckpt_path: ???
weights_only: null
strict: True
5 changes: 4 additions & 1 deletion configs/experiment/im2im/segmentation_plugin.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ defaults:
tags: ["dev"]
seed: 12345

ckpt_path: null # must override for prediction
checkpoint:
ckpt_path: null # must override for prediction
weights_only: null
strict: False

experiment_name: experiment_name
run_name: run_name
Expand Down
6 changes: 4 additions & 2 deletions configs/train.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ train: True
test: True

# simply provide checkpoint path to resume training
ckpt_path: null

checkpoint:
ckpt_path: null
weights_only: null
strict: True
# seed for random number generators in pytorch, numpy and python.random
seed: null
2 changes: 1 addition & 1 deletion cyto_dl/api/cyto_dl_model/cyto_dl_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def _set_training_config(self, train: bool):
self._set_cfg("task_name", "train" if train else "predict")

def _set_ckpt(self, ckpt: Optional[Path]) -> None:
self._set_cfg("ckpt_path", str(ckpt.resolve()) if ckpt else ckpt)
self._set_cfg("checkpoint.ckpt_path", str(ckpt.resolve()) if ckpt else ckpt)

# does experiment name have any effect?
def set_experiment_name(self, name: str) -> None:
Expand Down
2 changes: 1 addition & 1 deletion cyto_dl/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def compile(cfg: DictConfig) -> Tuple[dict, dict]:
{
"model_file": str(Path(pkg_root) / cfg.model_file),
"handler": str(Path(pkg_root) / cfg.handler_file),
"serialized_file": cfg.ckpt_path,
"serialized_file": cfg.checkpoint.ckpt_path,
"model_name": name,
"version": version,
"extra_files": str(cfg_path),
Expand Down
4 changes: 2 additions & 2 deletions cyto_dl/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def evaluate(cfg: DictConfig, data=None) -> Tuple[dict, dict, dict]:
Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
"""

if not cfg.ckpt_path:
if not cfg.checkpoint.ckpt_path:
raise ValueError("Checkpoint path must be included for testing")

# resolve config to avoid unresolvable interpolations in the stored config
Expand Down Expand Up @@ -84,7 +84,7 @@ def evaluate(cfg: DictConfig, data=None) -> Tuple[dict, dict, dict]:

log.info("Starting testing!")
method = trainer.test if cfg.get("test", False) else trainer.predict
output = method(model=model, dataloaders=data, ckpt_path=cfg.ckpt_path)
output = method(model=model, dataloaders=data, ckpt_path=cfg.checkpoint.ckpt_path)
metric_dict = trainer.callback_metrics

return metric_dict, object_dict, output
Expand Down
15 changes: 8 additions & 7 deletions cyto_dl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,23 +96,24 @@ def train(cfg: DictConfig, data=None) -> Tuple[dict, dict]:
if cfg.get("train"):
log.info("Starting training!")

if cfg.get("weights_only"):
assert cfg.get(
load_params = cfg.get("checkpoint")
if load_params.get("weights_only"):
assert load_params.get(
"ckpt_path"
), "ckpt_path must be provided to with argument weights_only=True"
# load model from state dict to get around trainer.max_epochs limit, useful for resuming model training from existing weights
state_dict = torch.load(cfg["ckpt_path"])["state_dict"]
model.load_state_dict(state_dict)
cfg["ckpt_path"] = None
state_dict = torch.load(load_params["ckpt_path"])["state_dict"]
model.load_state_dict(state_dict, strict=load_params.get("strict", True))
load_params["ckpt_path"] = None

if isinstance(data, LightningDataModule):
trainer.fit(model=model, datamodule=data, ckpt_path=cfg.get("ckpt_path"))
trainer.fit(model=model, datamodule=data, ckpt_path=load_params.get("ckpt_path"))
else:
trainer.fit(
model=model,
train_dataloaders=data.train_dataloaders,
val_dataloaders=data.val_dataloaders,
ckpt_path=cfg.get("ckpt_path"),
ckpt_path=load_params.get("ckpt_path"),
)

train_metrics = trainer.callback_metrics
Expand Down
2 changes: 1 addition & 1 deletion cyto_dl/utils/template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def log_hyperparameters(object_dict: dict) -> None:

hparams["task_name"] = cfg.get("task_name")
hparams["tags"] = cfg.get("tags")
hparams["ckpt_path"] = cfg.get("ckpt_path")
hparams["ckpt_path"] = cfg.checkpoint.get("ckpt_path")
hparams["seed"] = cfg.get("seed")

try:
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def cfg_eval_global(request) -> DictConfig:
config_name="eval.yaml",
return_hydra_config=True,
overrides=[
"ckpt_path=.",
"checkpoint.ckpt_path=.",
f"experiment=im2im/{request.param}.yaml",
"trainer=cpu.yaml",
],
Expand Down
2 changes: 1 addition & 1 deletion tests/test_array_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_array_train_predict(tmp_path):
"logger": None,
"trainer.accelerator": "cpu",
"trainer.devices": 1,
"ckpt_path": ckpt_path,
"checkpoint.ckpt_path": ckpt_path,
}

model.load_default_experiment(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_train_eval(tmp_path, cfg_train, cfg_eval, spatial_dims):
assert "last.ckpt" in os.listdir(tmp_path / "checkpoints")

with open_dict(cfg_eval):
cfg_eval.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt")
cfg_eval.checkpoint.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt")
cfg_eval.test = True

cfg_eval.spatial_dims = spatial_dims
Expand Down
2 changes: 1 addition & 1 deletion tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_train_resume(tmp_path, cfg_train, spatial_dims):
assert "epoch_000.ckpt" in files

with open_dict(cfg_train):
cfg_train.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt")
cfg_train.checkpoint.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt")
cfg_train.trainer.max_epochs = 2

metric_dict_2, _ = train(cfg_train)
Expand Down

0 comments on commit cd04e21

Please sign in to comment.