# train_mvcnn_attention.py
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import os, shutil, json
import argparse
from tools.Trainer import ModelNetTrainer
from tools.ImgDataset import MultiviewImgDataset, SingleImgDataset
#from models.MVCNN import MVCNN, SVCNN
from models.MVCNN_attention import SVCNN, MVCNN_attention, MVCNN_self_attention, MVCNN_attention_fc
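
# Two-stage training pipeline:
#   Stage 1 - fine-tune the single-view CNN (SVCNN) on individual rendered views.
#   Stage 2 - reuse that backbone inside the attention-based multi-view network
#             (MVCNN_attention) and train on groups of num_views views per shape.
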
parser = argparse.ArgumentParser()
parser.add_argument("-name", "--name", type=str, help="Name of the experiment", default="MVCNN")
parser.add_argument("-bs", "--batchSize", type=int, help="Batch size for the second stage", default=8)# it will be *12 images in each batch for mvcnn
parser.add_argument("-num_models", type=int, help="number of models per class", default=1000)
parser.add_argument("-lr", type=float, help="learning rate", default=5e-5)
parser.add_argument("-weight_decay", type=float, help="weight decay", default=0.0)
parser.add_argument("-no_pretraining", dest='no_pretraining', action='store_true')
parser.add_argument("-cnn_name", "--cnn_name", type=str, help="cnn model name", default="vgg11")
parser.add_argument("-num_views", type=int, help="number of views", default=12)
parser.add_argument("-train_path", type=str, default="modelnet40_images_new_12x/*/train")
parser.add_argument("-val_path", type=str, default="modelnet40_images_new_12x/*/test")
parser.set_defaults(train=False)
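
# Example invocation (assumes ModelNet40 renderings under modelnet40_images_new_12x/,
# the default train/val paths above; adjust -name, -bs, and -cnn_name as needed):
#   python train_mvcnn_attention.py -name mvcnn_att -bs 8 -num_views 12 -cnn_name vgg11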
#os.environ['CUDA_VISIBLE_DEVICES']='1'


def create_folder(log_dir):
    # make summary folder
    if not os.path.exists(log_dir):
        os.mkdir(log_dir)
    else:
        print('WARNING: summary folder already exists!! It will be overwritten!!')
        shutil.rmtree(log_dir)
        os.mkdir(log_dir)

if __name__ == '__main__':
    torch.random.manual_seed(10)
    torch.cuda.manual_seed_all(10)
    np.random.seed(10)

    args = parser.parse_args()
    pretraining = not args.no_pretraining
    log_dir = args.name
    create_folder(args.name)
    config_f = open(os.path.join(log_dir, 'config.json'), 'w')
    json.dump(vars(args), config_f)
    config_f.close()

    # STAGE 1
    log_dir = args.name + '_stage_1'
    create_folder(log_dir)
    cnet = SVCNN(args.name, nclasses=40, pretraining=pretraining, cnn_name=args.cnn_name)

    optimizer = optim.Adam(cnet.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    n_models_train = args.num_models * args.num_views

    train_dataset = SingleImgDataset(args.train_path, scale_aug=False, rot_aug=False, num_models=n_models_train, num_views=args.num_views)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0)

    val_dataset = SingleImgDataset(args.val_path, scale_aug=False, rot_aug=False, test_mode=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=0)

    print('num_train_files: ' + str(len(train_dataset.filepaths)))
    print('num_val_files: ' + str(len(val_dataset.filepaths)))
    trainer = ModelNetTrainer(cnet, train_loader, val_loader, optimizer, nn.CrossEntropyLoss(), 'svcnn', log_dir, num_views=1)
    trainer.train(30)
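
    # STAGE 2 below reuses the stage-1 backbone: cnet is passed into MVCNN_attention,
    # which (per its name) aggregates the per-view features with an attention module
    # before classification. Views are grouped by MultiviewImgDataset, so each batch
    # holds batchSize * num_views images.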
    # STAGE 2
    log_dir = args.name + '_stage_2'
    create_folder(log_dir)
    cnet_2 = MVCNN_attention(args.name, cnet, nclasses=40, cnn_name=args.cnn_name, num_views=args.num_views)
    del cnet

    optimizer = optim.Adam(cnet_2.parameters(), lr=args.lr, weight_decay=args.weight_decay, betas=(0.9, 0.999))

    train_dataset = MultiviewImgDataset(args.train_path, scale_aug=False, rot_aug=False, num_models=n_models_train, num_views=args.num_views)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batchSize, shuffle=False, num_workers=0)  # shuffle needs to be false! it's done within the trainer

    val_dataset = MultiviewImgDataset(args.val_path, scale_aug=False, rot_aug=False, num_views=args.num_views, test_mode=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batchSize, shuffle=False, num_workers=0)

    print('num_train_files: ' + str(len(train_dataset.filepaths)))
    print('num_val_files: ' + str(len(val_dataset.filepaths)))
    trainer = ModelNetTrainer(cnet_2, train_loader, val_loader, optimizer, nn.CrossEntropyLoss(), 'mvcnn', log_dir, num_views=args.num_views)
    trainer.train(30)