Skip to content

Commit

Permalink
Add back test
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca committed Jul 20, 2022
1 parent 813baf0 commit 10ac9a1
Showing 1 changed file with 42 additions and 5 deletions.
47 changes: 42 additions & 5 deletions tests/tests_pytorch/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,6 @@
import pytest
import torch
import yaml
from torch.optim import SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR

from pytorch_lightning import __version__, Callback, LightningDataModule, LightningModule, seed_everything, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.cli import (
_JSONARGPARSE_SIGNATURES_AVAILABLE,
instantiate_class,
Expand All @@ -39,13 +34,19 @@
LRSchedulerTypeTuple,
SaveConfigCallback,
)
from torch.optim import SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR

from pytorch_lightning import __version__, Callback, LightningDataModule, LightningModule, seed_everything, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel
from pytorch_lightning.loggers import _COMET_AVAILABLE, _NEPTUNE_AVAILABLE, _WANDB_AVAILABLE, TensorBoardLogger
from pytorch_lightning.plugins.environments import SLURMEnvironment
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
from tests_pytorch.helpers.runif import RunIf
from tests_pytorch.helpers.utils import no_warning_call

Expand Down Expand Up @@ -517,6 +518,42 @@ def __init__(self, submodule1: LightningModule, submodule2: LightningModule, mai
assert isinstance(cli.model.submodule2, BoringModel)


@pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="Tests a bug with torchvision, but it's not available")
def test_lightning_cli_torch_modules(tmpdir):
class TestModule(BoringModel):
def __init__(self, activation: torch.nn.Module = None, transform: Optional[List[torch.nn.Module]] = None):
super().__init__()
self.activation = activation
self.transform = transform

config = """model:
activation:
class_path: torch.nn.LeakyReLU
init_args:
negative_slope: 0.2
transform:
- class_path: torchvision.transforms.Resize
init_args:
size: 64
- class_path: torchvision.transforms.CenterCrop
init_args:
size: 64
"""
config_path = tmpdir / "config.yaml"
with open(config_path, "w") as f:
f.write(config)

cli_args = [f"--trainer.default_root_dir={tmpdir}", f"--config={str(config_path)}"]

with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = LightningCLI(TestModule, run=False)

assert isinstance(cli.model.activation, torch.nn.LeakyReLU)
assert cli.model.activation.negative_slope == 0.2
assert len(cli.model.transform) == 2
assert all(isinstance(v, torch.nn.Module) for v in cli.model.transform)


class BoringModelRequiredClasses(BoringModel):
def __init__(self, num_classes: int, batch_size: int = 8):
super().__init__()
Expand Down

0 comments on commit 10ac9a1

Please sign in to comment.