-
Notifications
You must be signed in to change notification settings - Fork 9
/
train.py
48 lines (38 loc) · 1.2 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import os
from argparse import ArgumentParser
import numpy as np
import torch
from Unet import Unet
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
def _version_sort_key(name):
    """Order log entries by trailing integer so 'version_10' follows 'version_9'."""
    suffix = name.rpartition('_')[2]
    # Entries without a numeric suffix rank below every numbered version.
    return (int(suffix) if suffix.isdecimal() else -1, name)


def _resolve_log_dir(root):
    """Return the path, under ``root``, of the highest-numbered existing entry.

    Falls back to ``<root>/version_0`` when ``root`` contains no entries.
    """
    entries = os.listdir(root)
    if not entries:
        return os.path.join(root, 'version_0')
    # os.listdir yields bare names, so the result must be joined back onto
    # root; otherwise the path would resolve against the working directory.
    return os.path.join(root, max(entries, key=_version_sort_key))


def main(hparams):
    """Build the Unet model and train it with checkpointing and early stopping.

    Args:
        hparams: Parsed command-line namespace. This function reads
            ``hparams.log_dir`` directly; the whole namespace is handed to
            ``Unet``.
    """
    model = Unet(hparams)

    # Make sure the log root exists before it is listed.
    os.makedirs(hparams.log_dir, exist_ok=True)
    log_dir = _resolve_log_dir(hparams.log_dir)

    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(log_dir, 'checkpoints'),
        save_best_only=False,
        verbose=True,
    )
    stop_callback = EarlyStopping(
        monitor='val_loss',
        mode='auto',
        patience=5,
        verbose=True,
    )
    trainer = Trainer(
        gpus=1,
        checkpoint_callback=checkpoint_callback,
        early_stop_callback=stop_callback,
    )

    trainer.fit(model)
if __name__ == '__main__':
    # Options owned by this script: the dataset location (mandatory) and the
    # directory that receives training logs.
    base_parser = ArgumentParser(add_help=False)
    base_parser.add_argument('--dataset', required=True)
    base_parser.add_argument('--log_dir', default='lightning_logs')

    # Unet extends the base parser with its own arguments; the parsed
    # namespace is then passed straight to main().
    main(Unet.add_model_specific_args(base_parser).parse_args())