Skip to content

Commit

Permalink
Add back test
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca committed Jul 20, 2022
1 parent 813baf0 commit 30efcca
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions tests/tests_pytorch/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,41 @@ def __init__(self, submodule1: LightningModule, submodule2: LightningModule, mai
assert isinstance(cli.model.submodule2, BoringModel)


def test_lightning_cli_torch_modules(tmpdir):
class TestModule(BoringModel):
def __init__(self, activation: torch.nn.Module = None, transform: Optional[List[torch.nn.Module]] = None):
super().__init__()
self.activation = activation
self.transform = transform

config = """model:
activation:
class_path: torch.nn.LeakyReLU
init_args:
negative_slope: 0.2
transform:
- class_path: torchvision.transforms.Resize
init_args:
size: 64
- class_path: torchvision.transforms.CenterCrop
init_args:
size: 64
"""
config_path = tmpdir / "config.yaml"
with open(config_path, "w") as f:
f.write(config)

cli_args = [f"--trainer.default_root_dir={tmpdir}", f"--config={str(config_path)}"]

with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = LightningCLI(TestModule, run=False)

assert isinstance(cli.model.activation, torch.nn.LeakyReLU)
assert cli.model.activation.negative_slope == 0.2
assert len(cli.model.transform) == 2
assert all(isinstance(v, torch.nn.Module) for v in cli.model.transform)


class BoringModelRequiredClasses(BoringModel):
def __init__(self, num_classes: int, batch_size: int = 8):
super().__init__()
Expand Down

0 comments on commit 30efcca

Please sign in to comment.