Fix LightningCLI saving hyperparameters breaking change (#20068)
mauvilsa authored Jul 11, 2024
1 parent 50af052 commit 96b75df
Showing 3 changed files with 56 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/lightning/pytorch/CHANGELOG.md
@@ -34,7 +34,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
--
+- Avoid LightningCLI saving hyperparameters with `class_path` and `init_args` since this would be a breaking change ([#20068](https://github.com/Lightning-AI/pytorch-lightning/pull/20068))
 
 -
 
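For context, a minimal illustrative sketch (not part of the commit) of the two hyperparameter layouts the changelog entry refers to, as saved for a model run in subclass mode. The shapes mirror the `expected` dicts in the test diff further below; the class path and parameter names are placeholders:

```python
# Nested layout this commit avoids: storing `class_path`/`init_args` in hparams
# would differ from previous releases, i.e. a breaking change.
hparams_nested = {
    "_instantiator": "lightning.pytorch.cli.instantiate_module",
    "class_path": "my_package.MyModel",      # placeholder class path
    "init_args": {"learning_rate": 1e-2},    # placeholder parameter
}

# Layout kept by this fix: parameters stored flat, as in previous releases, with
# the class path only under the internal `_class_path` key.
hparams_flat = {
    "_instantiator": "lightning.pytorch.cli.instantiate_module",
    "_class_path": "my_package.MyModel",
    "learning_rate": 1e-2,
}
```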
24 changes: 19 additions & 5 deletions src/lightning/pytorch/cli.py
@@ -534,7 +534,7 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> No
         self.config = parser.parse_args(args)
 
     def _add_instantiators(self) -> None:
-        self.config_dump = yaml.safe_load(self.parser.dump(self.config, skip_link_targets=False))
+        self.config_dump = yaml.safe_load(self.parser.dump(self.config, skip_link_targets=False, skip_none=False))
         if "subcommand" in self.config:
             self.config_dump = self.config_dump[self.config.subcommand]
 
@@ -791,19 +791,33 @@ def __init__(self, cli: LightningCLI, key: str) -> None:
         self.key = key
 
     def __call__(self, class_type: Type[ModuleType], *args: Any, **kwargs: Any) -> ModuleType:
+        hparams = self.cli.config_dump.get(self.key, {})
+        if "class_path" in hparams:
+            # To make hparams backwards compatible, and so that it is the same irrespective of subclass_mode, the
+            # parameters are stored directly, and the class_path in a special key `_class_path` to clarify its internal
+            # use.
+            hparams = {
+                "_class_path": hparams["class_path"],
+                **hparams.get("init_args", {}),
+                **hparams.get("dict_kwargs", {}),
+            }
         with _given_hyperparameters_context(
-            hparams=self.cli.config_dump.get(self.key, {}),
+            hparams=hparams,
             instantiator="lightning.pytorch.cli.instantiate_module",
         ):
             return class_type(*args, **kwargs)
 
 
 def instantiate_module(class_type: Type[ModuleType], config: Dict[str, Any]) -> ModuleType:
     parser = ArgumentParser(exit_on_error=False)
-    if "class_path" in config:
-        parser.add_subclass_arguments(class_type, "module")
+    if "_class_path" in config:
+        parser.add_subclass_arguments(class_type, "module", fail_untyped=False)
+        config = {
+            "class_path": config["_class_path"],
+            "dict_kwargs": {k: v for k, v in config.items() if k != "_class_path"},
+        }
     else:
-        parser.add_class_arguments(class_type, "module")
+        parser.add_class_arguments(class_type, "module", fail_untyped=False)
     cfg = parser.parse_object({"module": config})
     init = parser.instantiate_classes(cfg)
     return init.module
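As a hedged illustration (not part of the commit) of what the `instantiate_module` entry point above does with the flat layout, the sketch below defines a hypothetical `ToyModel`; everything except the internal `_class_path` key is forwarded to the constructor via jsonargparse's `dict_kwargs`:

```python
from lightning.pytorch import LightningModule
from lightning.pytorch.cli import instantiate_module


class ToyModel(LightningModule):  # hypothetical stand-in for a user model
    def __init__(self, hidden_dim: int = 8):
        super().__init__()
        self.hidden_dim = hidden_dim


# hparams as now stored in subclass mode: flat parameters plus the internal
# `_class_path` key added by _InstantiatorFn above.
hparams = {"_class_path": f"{__name__}.ToyModel", "hidden_dim": 16}
model = instantiate_module(ToyModel, hparams)
assert isinstance(model, ToyModel) and model.hidden_dim == 16
```

In the checkpoint-loading path this function is not called directly; `load_from_checkpoint` resolves it from the `_instantiator` entry saved alongside the hyperparameters, which is what the tests below exercise.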
42 changes: 36 additions & 6 deletions tests/tests_pytorch/test_cli.py
@@ -900,12 +900,10 @@ def test_lightning_cli_load_from_checkpoint_dependency_injection_subclass_mode(c
 
     expected = {
         "_instantiator": "lightning.pytorch.cli.instantiate_module",
-        "class_path": f"{__name__}.TestModelSaveHparams",
-        "init_args": {
-            "optimizer": "torch.optim.Adam",
-            "scheduler": "torch.optim.lr_scheduler.ConstantLR",
-            "activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}},
-        },
+        "_class_path": f"{__name__}.TestModelSaveHparams",
+        "optimizer": "torch.optim.Adam",
+        "scheduler": "torch.optim.lr_scheduler.ConstantLR",
+        "activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}},
     }
 
     checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None)
@@ -922,6 +920,38 @@ def test_lightning_cli_load_from_checkpoint_dependency_injection_subclass_mode(c
     assert isinstance(lr_scheduler, torch.optim.lr_scheduler.ConstantLR)
 
 
+class TestModelSaveHparamsUntyped(BoringModel):
+    def __init__(self, learning_rate, step_size=None, **kwargs):
+        super().__init__()
+        self.save_hyperparameters()
+        self.learning_rate = learning_rate
+        self.step_size = step_size
+        self.kwargs = kwargs
+
+
+def test_lightning_cli_save_hyperparameters_untyped_module(cleandir):
+    config = {
+        "model": {
+            "class_path": f"{__name__}.TestModelSaveHparamsUntyped",
+            "init_args": {"learning_rate": 1e-2},
+            "dict_kwargs": {"x": 1},
+        }
+    }
+    with mock.patch("sys.argv", ["any.py", f"--config={json.dumps(config)}", "--trainer.max_epochs=1"]):
+        cli = LightningCLI(BoringModel, run=False, auto_configure_optimizers=False, subclass_mode_model=True)
+        cli.trainer.fit(cli.model)
+    assert isinstance(cli.model, TestModelSaveHparamsUntyped)
+    assert cli.model.hparams["learning_rate"] == 1e-2
+    assert cli.model.hparams["step_size"] is None
+    assert cli.model.hparams["x"] == 1
+
+    checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None)
+    model = TestModelSaveHparamsUntyped.load_from_checkpoint(checkpoint_path)
+    assert model.learning_rate == 1e-2
+    assert model.step_size is None
+    assert model.kwargs == {"x": 1}
+
+
 @pytest.mark.parametrize("fn", [fn.value for fn in TrainerFn])
 def test_lightning_cli_trainer_fn(fn):
     class TestCLI(LightningCLI):
