-
Notifications
You must be signed in to change notification settings - Fork 83
/
imagenet.py
224 lines (194 loc) · 11.2 KB
/
imagenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
import argparse
import csv
import os
import random
import sys
from datetime import datetime
import torch
import torch.backends.cudnn as cudnn
import torch.nn.parallel
import torch.optim
import torch.utils.data
from torch.optim.lr_scheduler import MultiStepLR
from tqdm import trange
import flops_benchmark
from clr import CyclicLR
from data import get_loaders
from logger import CsvLogger
from model import MobileNet2
from run import train, test, save_checkpoint, find_bounds_clr
parser = argparse.ArgumentParser(description='MobileNetv2 training with PyTorch')
parser.add_argument('--dataroot', required=True, metavar='PATH',
help='Path to ImageNet train and val folders, preprocessed as described in '
'https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md#download-the-imagenet-dataset')
parser.add_argument('--gpus', default=None, help='List of GPUs used for training - e.g 0,1,3')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
help='Number of data loading workers (default: 4)')
parser.add_argument('--type', default='float32', help='Type of tensor: float32, float16, float64. Default: float32')
# Optimization options
parser.add_argument('--epochs', type=int, default=400, help='Number of epochs to train.')
parser.add_argument('-b', '--batch-size', default=64, type=int, metavar='N', help='mini-batch size (default: 64)')
parser.add_argument('--learning_rate', '-lr', type=float, default=0.01, help='The learning rate.')
parser.add_argument('--momentum', '-m', type=float, default=0.9, help='Momentum.')
parser.add_argument('--decay', '-d', type=float, default=4e-5, help='Weight decay (L2 penalty).')
parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma at scheduled epochs.')
parser.add_argument('--schedule', type=int, nargs='+', default=[200, 300],
help='Decrease learning rate at these epochs.')
# CLR
parser.add_argument('--clr', dest='clr', action='store_true', help='Use CLR')
parser.add_argument('--min-lr', type=float, default=1e-5, help='Minimal LR for CLR.')
parser.add_argument('--max-lr', type=float, default=1, help='Maximal LR for CLR.')
parser.add_argument('--epochs-per-step', type=int, default=20,
help='Number of epochs per step in CLR, recommended to be between 2 and 10.')
parser.add_argument('--mode', default='triangular2', help='CLR mode. One of {triangular, triangular2, exp_range}')
parser.add_argument('--find-clr', dest='find_clr', action='store_true',
help='Run search for optimal LR in range (min_lr, max_lr)')
# Checkpoints
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='Just evaluate model')
parser.add_argument('--save', '-s', type=str, default='', help='Folder to save checkpoints.')
parser.add_argument('--results_dir', metavar='RESULTS_DIR', default='./results', help='Directory to store results')
parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)')
parser.add_argument('--log-interval', type=int, default=100, metavar='N', help='Number of batches between log messages')
parser.add_argument('--seed', type=int, default=None, metavar='S', help='random seed (default: 1)')
# Architecture
parser.add_argument('--scaling', type=float, default=1, metavar='SC', help='Scaling of MobileNet (default x1).')
parser.add_argument('--input-size', type=int, default=224, metavar='I',
help='Input size of MobileNet, multiple of 32 (default 224).')
# Reference ImageNet accuracies for MobileNetV2, indexed as
# table[input_size][width_scaling] -> accuracy as a fraction in [0, 1].
# Values are taken from the Keras MobileNetV2 application:
# https://github.com/keras-team/keras/blob/fe066966b5afa96f2f6b9f71ec0c71158b44068d/keras/applications/mobilenetv2.py#L30
claimed_acc_top1 = {
    224: {1.4: 0.75, 1.3: 0.744, 1.0: 0.718, 0.75: 0.698, 0.5: 0.654, 0.35: 0.603},
    192: {1.0: 0.707, 0.75: 0.687, 0.5: 0.639, 0.35: 0.582},
    160: {1.0: 0.688, 0.75: 0.664, 0.5: 0.610, 0.35: 0.557},
    128: {1.0: 0.653, 0.75: 0.632, 0.5: 0.577, 0.35: 0.508},
    96: {1.0: 0.603, 0.75: 0.588, 0.5: 0.512, 0.35: 0.455},
}

# Same layout as claimed_acc_top1, holding the top-5 accuracies.
claimed_acc_top5 = {
    224: {1.4: 0.925, 1.3: 0.921, 1.0: 0.910, 0.75: 0.896, 0.5: 0.864, 0.35: 0.829},
    192: {1.0: 0.901, 0.75: 0.889, 0.5: 0.854, 0.35: 0.812},
    160: {1.0: 0.890, 0.75: 0.873, 0.5: 0.832, 0.35: 0.791},
    128: {1.0: 0.869, 0.75: 0.855, 0.5: 0.808, 0.35: 0.750},
    96: {1.0: 0.832, 0.75: 0.816, 0.5: 0.758, 0.35: 0.704},
}
def _load_checkpoint(resume, model, optimizer, device):
    """Restore model and optimizer state from the checkpoint given by ``resume``.

    ``resume`` may be a checkpoint file, or a directory holding
    ``checkpoint.pth.tar`` together with ``results.csv``.

    Returns ``(start_epoch, best_test, data)`` where ``data`` is the list of
    rows read from ``results.csv`` (``None`` when ``resume`` is a single file),
    or ``None`` when ``resume`` is neither a file nor a directory.
    """
    if os.path.isfile(resume):
        checkpoint_path, csv_path = resume, None
    elif os.path.isdir(resume):
        checkpoint_path = os.path.join(resume, 'checkpoint.pth.tar')
        csv_path = os.path.join(resume, 'results.csv')
    else:
        print("=> no checkpoint found at '{}'".format(resume))
        return None

    print("=> loading checkpoint '{}'".format(checkpoint_path))
    # NOTE(security): torch.load unpickles the file, so only checkpoints from
    # a trusted source should be passed to --resume.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    start_epoch = checkpoint['epoch'] - 1
    best_test = checkpoint['best_prec1']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    print("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))

    data = None
    if csv_path is not None:
        # newline='' is what the csv module recommends for files handed to a reader.
        with open(csv_path, newline='') as csvfile:
            data = list(csv.DictReader(csvfile))
    return start_epoch, best_test, data


def main():
    """Parse the command line, build model/optimizer/scheduler and run the requested mode.

    Depending on the flags this either searches for CLR bounds (``--find-clr``),
    evaluates the model on the validation loader (``--evaluate``) or trains it
    via ``train_network``. Results are stored under
    ``<results_dir>/<save>``; ``results_dir`` is forced to ``/tmp`` when
    evaluating and ``save`` defaults to the current timestamp.

    Raises:
        ValueError: if ``--type`` is not one of float16, float32, float64.
    """
    args = parser.parse_args()

    # Seeding: draw a seed when none was given and print it so the run can be repeated.
    if args.seed is None:
        args.seed = random.randint(1, 10000)
    print("Random Seed: ", args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.gpus:
        torch.cuda.manual_seed_all(args.seed)

    # Output directory.
    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    if args.evaluate:
        args.results_dir = '/tmp'
    # Compare by value: ``is ''`` tests object identity, which is an
    # implementation detail for strings (and a SyntaxWarning on Python >= 3.8).
    if not args.save:
        args.save = time_stamp
    save_path = os.path.join(args.results_dir, args.save)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(save_path, exist_ok=True)

    # Device selection: the first listed GPU is the primary device.
    if args.gpus is not None:
        args.gpus = [int(i) for i in args.gpus.split(',')]
        device = 'cuda:' + str(args.gpus[0])
        cudnn.benchmark = True
    else:
        device = 'cpu'

    supported_dtypes = {'float64': torch.float64, 'float32': torch.float32, 'float16': torch.float16}
    if args.type not in supported_dtypes:
        raise ValueError('Wrong type!')  # TODO int8
    dtype = supported_dtypes[args.type]

    model = MobileNet2(input_size=args.input_size, scale=args.scaling)
    num_parameters = sum(param.nelement() for param in model.parameters())
    print(model)
    print('number of parameters: {}'.format(num_parameters))
    # FLOPs are counted for the share of the batch that a single device receives.
    per_device_batch = args.batch_size // len(args.gpus) if args.gpus is not None else args.batch_size
    print('FLOPs: {}'.format(
        flops_benchmark.count_flops(MobileNet2, per_device_batch, device, dtype, args.input_size, 3, args.scaling)))

    train_loader, val_loader = get_loaders(args.dataroot, args.batch_size, args.batch_size, args.input_size,
                                           args.workers)
    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    if args.gpus is not None:
        model = torch.nn.DataParallel(model, args.gpus)
    model.to(device=device, dtype=dtype)
    criterion.to(device=device, dtype=dtype)

    optimizer = torch.optim.SGD(model.parameters(), args.learning_rate, momentum=args.momentum,
                                weight_decay=args.decay, nesterov=True)

    # One CLR step spans ``epochs_per_step`` epochs worth of batches.
    clr_step_size = args.epochs_per_step * len(train_loader)
    if args.find_clr:
        find_bounds_clr(model, train_loader, optimizer, criterion, device, dtype, min_lr=args.min_lr,
                        max_lr=args.max_lr, step_size=clr_step_size, mode=args.mode, save_path=save_path)
        return

    if args.clr:
        scheduler = CyclicLR(optimizer, base_lr=args.min_lr, max_lr=args.max_lr, step_size=clr_step_size,
                             mode=args.mode)
    else:
        scheduler = MultiStepLR(optimizer, milestones=args.schedule, gamma=args.gamma)

    best_test = 0

    # optionally resume from a checkpoint
    data = None
    if args.resume:
        resumed = _load_checkpoint(args.resume, model, optimizer, device)
        if resumed is not None:
            args.start_epoch, best_test, data = resumed

    if args.evaluate:
        # TODO: the returned (loss, top1, top5) are not reported here.
        test(model, val_loader, criterion, device, dtype)
        return

    csv_logger = CsvLogger(filepath=save_path, data=data)
    csv_logger.save_params(sys.argv, args)

    # Look up published accuracies for this configuration, when available.
    claimed_acc1 = None
    claimed_acc5 = None
    if args.scaling in claimed_acc_top1.get(args.input_size, {}):
        claimed_acc1 = claimed_acc_top1[args.input_size][args.scaling]
        claimed_acc5 = claimed_acc_top5[args.input_size][args.scaling]
        csv_logger.write_text(
            'Claimed accuracies are: {:.2f}% top-1, {:.2f}% top-5'.format(claimed_acc1 * 100., claimed_acc5 * 100.))

    train_network(args.start_epoch, args.epochs, scheduler, model, train_loader, val_loader, optimizer, criterion,
                  device, dtype, args.batch_size, args.log_interval, csv_logger, save_path, claimed_acc1, claimed_acc5,
                  best_test)
def train_network(start_epoch, epochs, scheduler, model, train_loader, val_loader, optimizer, criterion, device, dtype,
                  batch_size, log_interval, csv_logger, save_path, claimed_acc1, claimed_acc5, best_test):
    """Run the train/validate loop, logging results and checkpointing every epoch.

    For every epoch index from ``start_epoch`` through ``epochs`` (inclusive)
    the model is trained on ``train_loader`` and evaluated on ``val_loader``;
    errors and losses are appended to ``csv_logger``, a checkpoint is written
    to ``save_path`` and the progress plot is refreshed. ``best_test`` is the
    best top-1 accuracy seen so far (0 for a fresh run, or the value restored
    from a checkpoint). After the loop the best top-1 accuracy is written to
    the log. Returns ``None``.
    """
    # NOTE(review): the range ends at ``epochs + 1``, i.e. one more iteration
    # than ``epochs - start_epoch`` -- confirm this is intended.
    for epoch in trange(start_epoch, epochs + 1):
        # Only non-cyclic schedulers are stepped per epoch here; a CyclicLR
        # instance is handed to train() instead (presumably stepped per batch
        # there -- confirm against run.train).
        if not isinstance(scheduler, CyclicLR):
            scheduler.step()
        train_loss, train_accuracy1, train_accuracy5 = train(model, train_loader, epoch, optimizer, criterion, device,
                                                             dtype, batch_size, log_interval, scheduler)
        test_loss, test_accuracy1, test_accuracy5 = test(model, val_loader, criterion, device, dtype)
        csv_logger.write({'epoch': epoch + 1, 'val_error1': 1 - test_accuracy1, 'val_error5': 1 - test_accuracy5,
                          'val_loss': test_loss, 'train_error1': 1 - train_accuracy1,
                          'train_error5': 1 - train_accuracy5, 'train_loss': train_loss})

        # Fold this epoch's result into best_test *before* saving, so that the
        # checkpoint's 'best_prec1' is up to date and a resumed run starts from
        # the true best accuracy rather than the previous epoch's value.
        is_best = test_accuracy1 > best_test
        if is_best:
            best_test = test_accuracy1
        save_checkpoint({'epoch': epoch + 1, 'state_dict': model.state_dict(), 'best_prec1': best_test,
                         'optimizer': optimizer.state_dict()}, is_best, filepath=save_path)

        csv_logger.plot_progress(claimed_acc1=claimed_acc1, claimed_acc5=claimed_acc5)

    csv_logger.write_text('Best accuracy is {:.2f}% top-1'.format(best_test * 100.))
# Script entry point: start training/evaluation only when this file is run
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()