From cd04e2190f025b2ed934eddfffecb79b0b24e4fc Mon Sep 17 00:00:00 2001 From: benjijamorris <54606172+benjijamorris@users.noreply.github.com> Date: Fri, 9 Aug 2024 09:43:50 -0700 Subject: [PATCH] add checkpoint key in config (#412) * add checkpoint key in config * precommit * update checkpoint.ckpt_path, not ckpt_path --------- Co-authored-by: Benjamin Morris --- configs/eval.yaml | 5 ++++- configs/experiment/im2im/segmentation_plugin.yaml | 5 ++++- configs/train.yaml | 6 ++++-- cyto_dl/api/cyto_dl_model/cyto_dl_base_model.py | 2 +- cyto_dl/compile.py | 2 +- cyto_dl/eval.py | 4 ++-- cyto_dl/train.py | 15 ++++++++------- cyto_dl/utils/template_utils.py | 2 +- tests/conftest.py | 2 +- tests/test_array_models.py | 2 +- tests/test_eval.py | 2 +- tests/test_train.py | 2 +- 12 files changed, 29 insertions(+), 20 deletions(-) diff --git a/configs/eval.yaml b/configs/eval.yaml index 352327d6d..b03052fb9 100644 --- a/configs/eval.yaml +++ b/configs/eval.yaml @@ -20,4 +20,7 @@ task_name: "eval" tags: ["dev"] # passing checkpoint path is necessary for evaluation -ckpt_path: ??? +checkpoint: + ckpt_path: ???
+ weights_only: null + strict: True diff --git a/configs/experiment/im2im/segmentation_plugin.yaml b/configs/experiment/im2im/segmentation_plugin.yaml index f8c5a3ac5..f5fd632b5 100644 --- a/configs/experiment/im2im/segmentation_plugin.yaml +++ b/configs/experiment/im2im/segmentation_plugin.yaml @@ -17,7 +17,10 @@ defaults: tags: ["dev"] seed: 12345 -ckpt_path: null # must override for prediction +checkpoint: + ckpt_path: null # must override for prediction + weights_only: null + strict: False experiment_name: experiment_name run_name: run_name diff --git a/configs/train.yaml b/configs/train.yaml index 8b297edc7..f755de870 100644 --- a/configs/train.yaml +++ b/configs/train.yaml @@ -47,7 +47,9 @@ train: True test: True # simply provide checkpoint path to resume training -ckpt_path: null - +checkpoint: + ckpt_path: null + weights_only: null + strict: True # seed for random number generators in pytorch, numpy and python.random seed: null diff --git a/cyto_dl/api/cyto_dl_model/cyto_dl_base_model.py b/cyto_dl/api/cyto_dl_model/cyto_dl_base_model.py index 2261b951d..a06c4a740 100644 --- a/cyto_dl/api/cyto_dl_model/cyto_dl_base_model.py +++ b/cyto_dl/api/cyto_dl_model/cyto_dl_base_model.py @@ -108,7 +108,7 @@ def _set_training_config(self, train: bool): self._set_cfg("task_name", "train" if train else "predict") def _set_ckpt(self, ckpt: Optional[Path]) -> None: - self._set_cfg("ckpt_path", str(ckpt.resolve()) if ckpt else ckpt) + self._set_cfg("checkpoint.ckpt_path", str(ckpt.resolve()) if ckpt else ckpt) # does experiment name have any effect? 
def set_experiment_name(self, name: str) -> None: diff --git a/cyto_dl/compile.py b/cyto_dl/compile.py index d7ae37529..c05787b72 100644 --- a/cyto_dl/compile.py +++ b/cyto_dl/compile.py @@ -83,7 +83,7 @@ def compile(cfg: DictConfig) -> Tuple[dict, dict]: { "model_file": str(Path(pkg_root) / cfg.model_file), "handler": str(Path(pkg_root) / cfg.handler_file), - "serialized_file": cfg.ckpt_path, + "serialized_file": cfg.checkpoint.ckpt_path, "model_name": name, "version": version, "extra_files": str(cfg_path), diff --git a/cyto_dl/eval.py b/cyto_dl/eval.py index 7961a151b..9e37b4bd8 100644 --- a/cyto_dl/eval.py +++ b/cyto_dl/eval.py @@ -33,7 +33,7 @@ def evaluate(cfg: DictConfig, data=None) -> Tuple[dict, dict, dict]: Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects. """ - if not cfg.ckpt_path: + if not cfg.checkpoint.ckpt_path: raise ValueError("Checkpoint path must be included for testing") # resolve config to avoid unresolvable interpolations in the stored config @@ -84,7 +84,7 @@ def evaluate(cfg: DictConfig, data=None) -> Tuple[dict, dict, dict]: log.info("Starting testing!") method = trainer.test if cfg.get("test", False) else trainer.predict - output = method(model=model, dataloaders=data, ckpt_path=cfg.ckpt_path) + output = method(model=model, dataloaders=data, ckpt_path=cfg.checkpoint.ckpt_path) metric_dict = trainer.callback_metrics return metric_dict, object_dict, output diff --git a/cyto_dl/train.py b/cyto_dl/train.py index 0c3e03c50..a6f64398d 100644 --- a/cyto_dl/train.py +++ b/cyto_dl/train.py @@ -96,23 +96,24 @@ def train(cfg: DictConfig, data=None) -> Tuple[dict, dict]: if cfg.get("train"): log.info("Starting training!") - if cfg.get("weights_only"): - assert cfg.get( + load_params = cfg.get("checkpoint") + if load_params.get("weights_only"): + assert load_params.get( "ckpt_path" ), "ckpt_path must be provided to with argument weights_only=True" # load model from state dict to get around trainer.max_epochs limit, useful for 
resuming model training from existing weights - state_dict = torch.load(cfg["ckpt_path"])["state_dict"] - model.load_state_dict(state_dict) - cfg["ckpt_path"] = None + state_dict = torch.load(load_params["ckpt_path"])["state_dict"] + model.load_state_dict(state_dict, strict=load_params.get("strict", True)) + load_params["ckpt_path"] = None if isinstance(data, LightningDataModule): - trainer.fit(model=model, datamodule=data, ckpt_path=cfg.get("ckpt_path")) + trainer.fit(model=model, datamodule=data, ckpt_path=load_params.get("ckpt_path")) else: trainer.fit( model=model, train_dataloaders=data.train_dataloaders, val_dataloaders=data.val_dataloaders, - ckpt_path=cfg.get("ckpt_path"), + ckpt_path=load_params.get("ckpt_path"), ) train_metrics = trainer.callback_metrics diff --git a/cyto_dl/utils/template_utils.py b/cyto_dl/utils/template_utils.py index 90ef51725..78bd0b856 100644 --- a/cyto_dl/utils/template_utils.py +++ b/cyto_dl/utils/template_utils.py @@ -181,7 +181,7 @@ def log_hyperparameters(object_dict: dict) -> None: hparams["task_name"] = cfg.get("task_name") hparams["tags"] = cfg.get("tags") - hparams["ckpt_path"] = cfg.get("ckpt_path") + hparams["ckpt_path"] = cfg.checkpoint.get("ckpt_path") hparams["seed"] = cfg.get("seed") try: diff --git a/tests/conftest.py b/tests/conftest.py index 6c66b577a..48815e211 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -53,7 +53,7 @@ def cfg_eval_global(request) -> DictConfig: config_name="eval.yaml", return_hydra_config=True, overrides=[ - "ckpt_path=.", + "checkpoint.ckpt_path=.", f"experiment=im2im/{request.param}.yaml", "trainer=cpu.yaml", ], diff --git a/tests/test_array_models.py b/tests/test_array_models.py index 6d3695995..6c8071638 100644 --- a/tests/test_array_models.py +++ b/tests/test_array_models.py @@ -42,7 +42,7 @@ def test_array_train_predict(tmp_path): "logger": None, "trainer.accelerator": "cpu", "trainer.devices": 1, - "ckpt_path": ckpt_path, + "checkpoint.ckpt_path": ckpt_path, } 
model.load_default_experiment( diff --git a/tests/test_eval.py b/tests/test_eval.py index f5d20cfe8..1b86b5779 100644 --- a/tests/test_eval.py +++ b/tests/test_eval.py @@ -35,7 +35,7 @@ def test_train_eval(tmp_path, cfg_train, cfg_eval, spatial_dims): assert "last.ckpt" in os.listdir(tmp_path / "checkpoints") with open_dict(cfg_eval): - cfg_eval.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt") + cfg_eval.checkpoint.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt") cfg_eval.test = True cfg_eval.spatial_dims = spatial_dims diff --git a/tests/test_train.py b/tests/test_train.py index 22a88daf3..54c959f3d 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -63,7 +63,7 @@ def test_train_resume(tmp_path, cfg_train, spatial_dims): assert "epoch_000.ckpt" in files with open_dict(cfg_train): - cfg_train.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt") + cfg_train.checkpoint.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt") cfg_train.trainer.max_epochs = 2 metric_dict_2, _ = train(cfg_train)