-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathtrain.py
executable file
·72 lines (60 loc) · 2.87 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import time, os, sys, glob, argparse
import importlib
import numpy as np
import torch
import MinkowskiEngine as ME
from data_loader import PCDataset, make_data_loader
from pcc_model import PCCModel
from trainer import Trainer
def parse_args(argv=None):
    """Parse the command-line options for training.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) reads the real command
            line, exactly as before.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--dataset", default='/home/ubuntu/HardDisk2/color_training_datasets/training_dataset/',
                        help="directory containing the .h5 training files.")
    # argparse only runs `type` on string defaults, so a float literal such as
    # 2e4 would reach the program (and the help text) as 20000.0; use an int.
    parser.add_argument("--dataset_num", type=int, default=20000,
                        help="maximum number of .h5 files to use.")
    parser.add_argument("--alpha", type=float, default=1., help="weights for distortion.")
    parser.add_argument("--beta", type=float, default=1., help="weights for bit rate.")
    parser.add_argument("--init_ckpt", default='')
    parser.add_argument("--lr", type=float, default=8e-4)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--epoch", type=int, default=50)
    parser.add_argument("--check_time", type=float, default=10, help='frequency for recording state (min).')
    parser.add_argument("--prefix", type=str, default='tp', help="prefix of checkpoints/logger, etc.")
    return parser.parse_args(argv)
class TrainingConfig:
    """Container for training hyper-parameters and output directories.

    The log and checkpoint directories are created on construction when they
    do not exist yet.

    Args:
        logdir: Directory for training logs.
        ckptdir: Directory for checkpoints.
        init_ckpt: Checkpoint to start from ('' on the command line means
            none; how it is consumed is up to Trainer — confirm there).
        alpha: Weight of the distortion term.
        beta: Weight of the bit-rate term.
        lr: Learning rate.
        check_time: Interval between state recordings, in minutes.
    """

    def __init__(self, logdir, ckptdir, init_ckpt, alpha, beta, lr, check_time):
        self.logdir = logdir
        # exist_ok=True replaces the exists()/makedirs() check-then-create
        # pair, which could race with another process creating the directory.
        os.makedirs(self.logdir, exist_ok=True)
        self.ckptdir = ckptdir
        os.makedirs(self.ckptdir, exist_ok=True)
        self.init_ckpt = init_ckpt
        self.alpha = alpha
        self.beta = beta
        self.lr = lr
        self.check_time = check_time
if __name__ == '__main__':
# log
args = parse_args()
training_config = TrainingConfig(
logdir=os.path.join('./logs', args.prefix),
ckptdir=os.path.join('./ckpts', args.prefix),
init_ckpt=args.init_ckpt,
alpha=args.alpha,
beta=args.beta,
lr=args.lr,
check_time=args.check_time)
# model
model = PCCModel()
# trainer
trainer = Trainer(config=training_config, model=model)
# dataset
filedirs = sorted(glob.glob(args.dataset+'*.h5'))[:int(args.dataset_num)]
train_dataset = PCDataset(filedirs[round(len(filedirs)/10):])
train_dataloader = make_data_loader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True, repeat=False)
test_dataset = PCDataset(filedirs[:round(len(filedirs)/10)])
test_dataloader = make_data_loader(dataset=test_dataset, batch_size=args.batch_size, shuffle=False, repeat=False)
# training
for epoch in range(0, args.epoch):
if epoch>0: trainer.config.lr = max(trainer.config.lr/2, 1e-5)# update lr
trainer.train(train_dataloader)
trainer.test(test_dataloader, 'Test')