-
Notifications
You must be signed in to change notification settings - Fork 4
/
train.py
46 lines (43 loc) · 1.32 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.logging import TensorBoardLogger
from config.SAND_pix_opt import TrainOptions
from model.total_model.model_gpu import *
# Parse the training options once at module load; train() below reads this
# module-level `opt` rather than taking it as a parameter.
opt = TrainOptions().parse()
def train():
    """Build the model, checkpointing, logger and Trainer, then fit.

    All settings are read from the module-level ``opt``. Nothing is
    returned; the function ends by calling ``trainer.fit(model)``.
    """
    model = SAND_pix_Gen_Parsing(opt)

    # Keep the two best weight-only checkpoints, ranked by the lowest
    # monitored 'fid' value.
    # NOTE(review): 'fid' is presumably a metric logged by the model during
    # validation -- confirm it is actually reported under that key.
    model_save_path = "{}/{}/{}/{}/{}".format(
        opt.checkpoints_dir, opt.name, opt.ver, opt.dataset_name, opt.log_name
    )
    checkpoint = ModelCheckpoint(
        filepath=model_save_path,
        save_top_k=2,
        save_weights_only=True,
        monitor='fid',
        mode='min',
        verbose=True,
    )

    logger_path = "{}/{}/{}".format(opt.log_dir, opt.name, opt.dataset_name)
    logger = TensorBoardLogger(
        save_dir=logger_path,
        name=opt.log_name,
        version=opt.ver,
    )

    # Drive every mixed-precision setting from the single `opt.use_amp`
    # switch so the Trainer never receives contradictory arguments
    # (e.g. 16-bit precision together with amp_level 'O0').
    # NOTE(review): `use_amp` looks deprecated in favour of `precision` in
    # Lightning 0.7.x -- confirm against the installed version before
    # dropping either argument.
    amp_level = 'O2' if opt.use_amp else 'O0'
    precision = 16 if opt.use_amp else 32

    trainer = Trainer(
        precision=precision,
        fast_dev_run=opt.debug,
        logger=logger,
        # min == max pins the run to exactly `opt.train_epoch` epochs.
        max_epochs=opt.train_epoch,
        min_epochs=opt.train_epoch,
        checkpoint_callback=checkpoint,
        gpus=opt.gpu,
        check_val_every_n_epoch=opt.every_val_epoch,
        log_save_interval=2000,
        use_amp=opt.use_amp,
        amp_level=amp_level,
    )
    trainer.fit(model)
# Script entry point: starts training as soon as the module is executed.
# NOTE(review): there is no `if __name__ == "__main__":` guard visible here,
# so importing this file would also appear to launch a run -- confirm intent.
train()