From 96b75df41a6c2eaa9a438d1ec2ec42ca1adf4f91 Mon Sep 17 00:00:00 2001
From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com>
Date: Thu, 11 Jul 2024 12:38:18 +0200
Subject: [PATCH] Fix LightningCLI saving hyperparameters breaking change
 (#20068)

---
 src/lightning/pytorch/CHANGELOG.md |  2 +-
 src/lightning/pytorch/cli.py       | 24 +++++++++++++----
 tests/tests_pytorch/test_cli.py    | 42 +++++++++++++++++++++++++-----
 3 files changed, 56 insertions(+), 12 deletions(-)

diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 8e026f485fb65..4f8c32bcf8434 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -34,7 +34,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
--
+- Avoid LightningCLI saving hyperparameters with `class_path` and `init_args` since this would be a breaking change ([#20068](https://github.com/Lightning-AI/pytorch-lightning/pull/20068))
 
 -
 
diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py
index 09f025b988089..26af335f7be93 100644
--- a/src/lightning/pytorch/cli.py
+++ b/src/lightning/pytorch/cli.py
@@ -534,7 +534,7 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> No
         self.config = parser.parse_args(args)
 
     def _add_instantiators(self) -> None:
-        self.config_dump = yaml.safe_load(self.parser.dump(self.config, skip_link_targets=False))
+        self.config_dump = yaml.safe_load(self.parser.dump(self.config, skip_link_targets=False, skip_none=False))
         if "subcommand" in self.config:
             self.config_dump = self.config_dump[self.config.subcommand]
 
@@ -791,8 +791,18 @@ def __init__(self, cli: LightningCLI, key: str) -> None:
         self.key = key
 
     def __call__(self, class_type: Type[ModuleType], *args: Any, **kwargs: Any) -> ModuleType:
+        hparams = self.cli.config_dump.get(self.key, {})
+        if "class_path" in hparams:
+            # To make hparams backwards compatible, and so that it is the same irrespective of subclass_mode, the
+            # parameters are stored directly, and the class_path in a special key `_class_path` to clarify its internal
+            # use.
+            hparams = {
+                "_class_path": hparams["class_path"],
+                **hparams.get("init_args", {}),
+                **hparams.get("dict_kwargs", {}),
+            }
         with _given_hyperparameters_context(
-            hparams=self.cli.config_dump.get(self.key, {}),
+            hparams=hparams,
             instantiator="lightning.pytorch.cli.instantiate_module",
         ):
             return class_type(*args, **kwargs)
@@ -800,10 +810,14 @@ def __call__(self, class_type: Type[ModuleType], *args: Any, **kwargs: Any) -> M
 
 def instantiate_module(class_type: Type[ModuleType], config: Dict[str, Any]) -> ModuleType:
     parser = ArgumentParser(exit_on_error=False)
-    if "class_path" in config:
-        parser.add_subclass_arguments(class_type, "module")
+    if "_class_path" in config:
+        parser.add_subclass_arguments(class_type, "module", fail_untyped=False)
+        config = {
+            "class_path": config["_class_path"],
+            "dict_kwargs": {k: v for k, v in config.items() if k != "_class_path"},
+        }
     else:
-        parser.add_class_arguments(class_type, "module")
+        parser.add_class_arguments(class_type, "module", fail_untyped=False)
     cfg = parser.parse_object({"module": config})
     init = parser.instantiate_classes(cfg)
     return init.module
diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py
index 35041d42dce1e..56b58d4d157a1 100644
--- a/tests/tests_pytorch/test_cli.py
+++ b/tests/tests_pytorch/test_cli.py
@@ -900,12 +900,10 @@ def test_lightning_cli_load_from_checkpoint_dependency_injection_subclass_mode(c
 
     expected = {
         "_instantiator": "lightning.pytorch.cli.instantiate_module",
-        "class_path": f"{__name__}.TestModelSaveHparams",
-        "init_args": {
-            "optimizer": "torch.optim.Adam",
-            "scheduler": "torch.optim.lr_scheduler.ConstantLR",
-            "activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}},
-        },
+        "_class_path": f"{__name__}.TestModelSaveHparams",
+        "optimizer": "torch.optim.Adam",
+        "scheduler": "torch.optim.lr_scheduler.ConstantLR",
+        "activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}},
     }
 
     checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None)
@@ -922,6 +920,38 @@ def test_lightning_cli_load_from_checkpoint_dependency_injection_subclass_mode(c
     assert isinstance(lr_scheduler, torch.optim.lr_scheduler.ConstantLR)
 
 
+class TestModelSaveHparamsUntyped(BoringModel):
+    def __init__(self, learning_rate, step_size=None, **kwargs):
+        super().__init__()
+        self.save_hyperparameters()
+        self.learning_rate = learning_rate
+        self.step_size = step_size
+        self.kwargs = kwargs
+
+
+def test_lightning_cli_save_hyperparameters_untyped_module(cleandir):
+    config = {
+        "model": {
+            "class_path": f"{__name__}.TestModelSaveHparamsUntyped",
+            "init_args": {"learning_rate": 1e-2},
+            "dict_kwargs": {"x": 1},
+        }
+    }
+    with mock.patch("sys.argv", ["any.py", f"--config={json.dumps(config)}", "--trainer.max_epochs=1"]):
+        cli = LightningCLI(BoringModel, run=False, auto_configure_optimizers=False, subclass_mode_model=True)
+    cli.trainer.fit(cli.model)
+    assert isinstance(cli.model, TestModelSaveHparamsUntyped)
+    assert cli.model.hparams["learning_rate"] == 1e-2
+    assert cli.model.hparams["step_size"] is None
+    assert cli.model.hparams["x"] == 1
+
+    checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None)
+    model = TestModelSaveHparamsUntyped.load_from_checkpoint(checkpoint_path)
+    assert model.learning_rate == 1e-2
+    assert model.step_size is None
+    assert model.kwargs == {"x": 1}
+
+
 @pytest.mark.parametrize("fn", [fn.value for fn in TrainerFn])
 def test_lightning_cli_trainer_fn(fn):
     class TestCLI(LightningCLI):
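
Illustrative note, not part of the patch: a minimal sketch of the hparams reshaping that the `__call__` method patched above performs when a module was configured in subclass mode. The module path and parameter values below are made up for illustration; the dict shapes mirror the expectations asserted in test_cli.py.

# Config entry as the CLI dumps it in subclass mode (hypothetical values).
saved_by_cli = {
    "class_path": "my_module.MyModel",
    "init_args": {"learning_rate": 0.01},
    "dict_kwargs": {"x": 1},
}

# Flattening applied by the patch: parameters are stored directly and the
# class path moves to the internal `_class_path` key, so saved checkpoints
# keep the previous flat layout (e.g. hparams["learning_rate"] still works).
hparams = {
    "_class_path": saved_by_cli["class_path"],
    **saved_by_cli.get("init_args", {}),
    **saved_by_cli.get("dict_kwargs", {}),
}
assert hparams == {"_class_path": "my_module.MyModel", "learning_rate": 0.01, "x": 1}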