From 30efcca4223653374940e08875c71e57b8d81e8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 20 Jul 2022 21:29:09 +0200 Subject: [PATCH] Add back test --- tests/tests_pytorch/test_cli.py | 35 +++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index 1847bca6e40046..79982db82879fb 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -517,6 +517,41 @@ def __init__(self, submodule1: LightningModule, submodule2: LightningModule, mai assert isinstance(cli.model.submodule2, BoringModel) +def test_lightning_cli_torch_modules(tmpdir): + class TestModule(BoringModel): + def __init__(self, activation: torch.nn.Module = None, transform: Optional[List[torch.nn.Module]] = None): + super().__init__() + self.activation = activation + self.transform = transform + + config = """model: + activation: + class_path: torch.nn.LeakyReLU + init_args: + negative_slope: 0.2 + transform: + - class_path: torchvision.transforms.Resize + init_args: + size: 64 + - class_path: torchvision.transforms.CenterCrop + init_args: + size: 64 + """ + config_path = tmpdir / "config.yaml" + with open(config_path, "w") as f: + f.write(config) + + cli_args = [f"--trainer.default_root_dir={tmpdir}", f"--config={str(config_path)}"] + + with mock.patch("sys.argv", ["any.py"] + cli_args): + cli = LightningCLI(TestModule, run=False) + + assert isinstance(cli.model.activation, torch.nn.LeakyReLU) + assert cli.model.activation.negative_slope == 0.2 + assert len(cli.model.transform) == 2 + assert all(isinstance(v, torch.nn.Module) for v in cli.model.transform) + + class BoringModelRequiredClasses(BoringModel): def __init__(self, num_classes: int, batch_size: int = 8): super().__init__()