forked from optuna/optuna-examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
pytorch_lightning_simple.py
173 lines (134 loc) · 5.72 KB
/
pytorch_lightning_simple.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""
Optuna example that optimizes multi-layer perceptrons using PyTorch Lightning.
In this example, we optimize the validation accuracy of fashion-article image classification
using PyTorch Lightning and FashionMNIST. We optimize the neural network architecture. As it is
too time consuming to evaluate on the whole validation split, only a fraction of its batches is
evaluated in each epoch.
You can run this example as follows, pruning can be turned on and off with the `--pruning`
argument.
$ python pytorch_lightning_simple.py [--pruning]
"""
import argparse
import os
from typing import List
from typing import Optional
import optuna
from optuna.integration import PyTorchLightningPruningCallback
from packaging import version
import pytorch_lightning as pl
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision import datasets
from torchvision import transforms
# Refuse to run on PyTorch Lightning releases older than the minimum this example targets.
if version.parse("1.0.2") > version.parse(pl.__version__):
    raise RuntimeError("PyTorch Lightning>=1.0.2 is required for this example.")

# Fraction of validation batches evaluated per epoch (passed as ``limit_val_batches``).
PERCENT_VALID_EXAMPLES = 0.1
# Mini-batch size shared by every dataloader.
BATCHSIZE = 128
# Number of output units of the final linear layer.
CLASSES = 10
# Maximum number of training epochs per trial.
EPOCHS = 10
# Directory the dataset is downloaded to / read from (the current working directory).
DIR = os.getcwd()
class Net(nn.Module):
def __init__(self, dropout: float, output_dims: List[int]):
super().__init__()
layers: List[nn.Module] = []
input_dim: int = 28 * 28
for output_dim in output_dims:
layers.append(nn.Linear(input_dim, output_dim))
layers.append(nn.ReLU())
layers.append(nn.Dropout(dropout))
input_dim = output_dim
layers.append(nn.Linear(input_dim, CLASSES))
self.layers: nn.Module = nn.Sequential(*layers)
def forward(self, data: torch.Tensor) -> torch.Tensor:
logits = self.layers(data)
return F.log_softmax(logits, dim=1)
class LightningNet(pl.LightningModule):
    """LightningModule wrapping ``Net``: NLL-loss training and accuracy logging on validation.

    Args:
        dropout: dropout probability forwarded to ``Net``.
        output_dims: hidden-layer widths forwarded to ``Net``.
    """

    def __init__(self, dropout: float, output_dims: List[int]):
        super().__init__()
        self.model = Net(dropout, output_dims)

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        # Reshape each sample to 28*28 features, the input width ``Net`` expects.
        flattened = data.view(-1, 28 * 28)
        return self.model(flattened)

    def training_step(self, batch, batch_idx: int) -> torch.Tensor:
        inputs, labels = batch
        return F.nll_loss(self(inputs), labels)

    def validation_step(self, batch, batch_idx: int) -> None:
        inputs, labels = batch
        predicted = self(inputs).argmax(dim=1)
        # Fraction of samples whose highest-scoring class matches the label.
        accuracy = predicted.eq(labels.view_as(predicted)).float().mean()
        # "val_acc" is the metric monitored by the pruning callback in ``objective``.
        self.log("val_acc", accuracy)
        self.log("hp_metric", accuracy, on_step=False, on_epoch=True)

    def configure_optimizers(self) -> optim.Optimizer:
        """Optimize all ``Net`` parameters with Adam at its default settings."""
        return optim.Adam(self.model.parameters())
class FashionMNISTDataModule(pl.LightningDataModule):
    """LightningDataModule serving FashionMNIST train/validation/test dataloaders.

    Args:
        data_dir: directory torchvision downloads the dataset into / reads it from.
        batch_size: batch size used by every dataloader.
        split_seed: seed for the train/validation split of the official training set.
            A fixed seed gives every Optuna trial the same train and validation examples,
            so trial scores are comparable. Defaults to 42.
    """

    # Number of official training images held out for validation.
    _N_VAL = 5000

    def __init__(self, data_dir: str, batch_size: int, split_seed: int = 42):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.split_seed = split_seed

    def setup(self, stage: Optional[str] = None) -> None:
        """Load (downloading if needed) FashionMNIST and split the training set."""
        self.mnist_test = datasets.FashionMNIST(
            self.data_dir, train=False, download=True, transform=transforms.ToTensor()
        )
        mnist_full = datasets.FashionMNIST(
            self.data_dir, train=True, download=True, transform=transforms.ToTensor()
        )
        # random_split draws from the global RNG by default, which would give each trial
        # a different split; a dedicated seeded generator makes the split deterministic.
        generator = torch.Generator().manual_seed(self.split_seed)
        n_train = len(mnist_full) - self._N_VAL
        self.mnist_train, self.mnist_val = random_split(
            mnist_full, [n_train, self._N_VAL], generator=generator
        )

    def _loader(self, dataset, shuffle: bool) -> DataLoader:
        """Build a pinned-memory DataLoader over ``dataset`` with the configured batch size."""
        return DataLoader(
            dataset, batch_size=self.batch_size, shuffle=shuffle, pin_memory=True
        )

    def train_dataloader(self) -> DataLoader:
        return self._loader(self.mnist_train, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        return self._loader(self.mnist_val, shuffle=False)

    def test_dataloader(self) -> DataLoader:
        return self._loader(self.mnist_test, shuffle=False)
def _trainer_compat_kwargs() -> dict:
    """Return Trainer kwargs that disable checkpointing and pick the device for this PL version.

    PyTorch Lightning 1.5 introduced ``enable_checkpointing`` and ``accelerator``/``devices``.
    The older ``checkpoint_callback`` argument was removed in 1.7 and ``gpus`` in 2.0, so
    passing the old names unconditionally fails on newer releases.
    """
    use_gpu = torch.cuda.is_available()
    if version.parse(pl.__version__) >= version.parse("1.5.0"):
        return {
            "enable_checkpointing": False,
            "accelerator": "gpu" if use_gpu else "cpu",
            "devices": 1,
        }
    return {"checkpoint_callback": False, "gpus": 1 if use_gpu else None}


def objective(trial: optuna.trial.Trial) -> float:
    """Train one sampled ``LightningNet`` configuration and return its validation accuracy.

    Samples the number of hidden layers (1-3), a shared dropout rate (0.2-0.5) and a
    log-scaled width (4-128) per layer. It then trains for at most ``EPOCHS`` epochs,
    evaluating ``PERCENT_VALID_EXAMPLES`` of the validation batches each epoch.
    ``PyTorchLightningPruningCallback`` reports ``val_acc`` so the study's pruner can stop
    unpromising trials.

    Returns:
        The last ``val_acc`` value recorded in ``trainer.callback_metrics``.
    """
    # We optimize the number of layers, hidden units in each layer and dropouts.
    n_layers = trial.suggest_int("n_layers", 1, 3)
    dropout = trial.suggest_float("dropout", 0.2, 0.5)
    output_dims = [
        trial.suggest_int("n_units_l{}".format(i), 4, 128, log=True) for i in range(n_layers)
    ]

    model = LightningNet(dropout, output_dims)
    datamodule = FashionMNISTDataModule(data_dir=DIR, batch_size=BATCHSIZE)

    trainer = pl.Trainer(
        logger=True,
        limit_val_batches=PERCENT_VALID_EXAMPLES,
        max_epochs=EPOCHS,
        callbacks=[PyTorchLightningPruningCallback(trial, monitor="val_acc")],
        **_trainer_compat_kwargs(),
    )
    hyperparameters = dict(n_layers=n_layers, dropout=dropout, output_dims=output_dims)
    trainer.logger.log_hyperparams(hyperparameters)
    trainer.fit(model, datamodule=datamodule)

    return trainer.callback_metrics["val_acc"].item()
if __name__ == "__main__":
    # Command line: a single optional flag enabling median-based pruning.
    parser = argparse.ArgumentParser(description="PyTorch Lightning example.")
    parser.add_argument(
        "--pruning",
        "-p",
        action="store_true",
        help="Activate the pruning feature. `MedianPruner` stops unpromising "
        "trials at the early stages of training.",
    )
    args = parser.parse_args()

    # NopPruner never prunes, so trials always run to completion without --pruning.
    pruner: optuna.pruners.BasePruner
    if args.pruning:
        pruner = optuna.pruners.MedianPruner()
    else:
        pruner = optuna.pruners.NopPruner()

    # Maximize validation accuracy for up to 100 trials or 600 seconds, whichever ends first.
    study = optuna.create_study(direction="maximize", pruner=pruner)
    study.optimize(objective, n_trials=100, timeout=600)

    print(f"Number of finished trials: {len(study.trials)}")

    print("Best trial:")
    best = study.best_trial
    print(f" Value: {best.value}")
    print(" Params: ")
    for name, chosen in best.params.items():
        print(f" {name}: {chosen}")