This repository has been archived by the owner on Nov 1, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 90
/
evaluate.py
356 lines (307 loc) · 10.7 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from pathlib import Path
import argparse
import json
import os
import random
import signal
import sys
import time
import urllib
from torch import nn, optim
from torchvision import datasets, transforms
import torch
import resnet
def get_arguments():
    """Build the command-line parser for evaluating a pretrained model on ImageNet.

    ``--data-dir`` and ``--pretrained`` are required: the script reads
    ``<data-dir>/train`` and ``<data-dir>/val`` and loads the checkpoint given by
    ``--pretrained``, so it cannot run without either of them.

    Returns:
        argparse.ArgumentParser: the configured (not yet parsed) parser.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate a pretrained model on ImageNet"
    )

    # Data
    parser.add_argument(
        "--data-dir", type=Path, required=True, help="path to dataset"
    )
    parser.add_argument(
        "--train-percent",
        default=100,
        type=int,
        choices=(100, 10, 1),
        help="size of training set in percent",
    )

    # Checkpoint
    parser.add_argument(
        "--pretrained", type=Path, required=True, help="path to pretrained model"
    )
    parser.add_argument(
        "--exp-dir",
        default="./checkpoint/lincls/",
        type=Path,
        metavar="DIR",
        help="path to checkpoint directory",
    )
    parser.add_argument(
        "--print-freq", default=100, type=int, metavar="N", help="print frequency"
    )

    # Model
    parser.add_argument("--arch", type=str, default="resnet50")

    # Optim
    parser.add_argument(
        "--epochs",
        default=100,
        type=int,
        metavar="N",
        help="number of total epochs to run",
    )
    parser.add_argument(
        "--batch-size", default=256, type=int, metavar="N", help="mini-batch size"
    )
    parser.add_argument(
        "--lr-backbone",
        default=0.0,
        type=float,
        metavar="LR",
        help="backbone base learning rate",
    )
    parser.add_argument(
        "--lr-head",
        default=0.3,
        type=float,
        metavar="LR",
        help="classifier base learning rate",
    )
    parser.add_argument(
        "--weight-decay", default=1e-6, type=float, metavar="W", help="weight decay"
    )
    parser.add_argument(
        "--weights",
        default="freeze",
        type=str,
        choices=("finetune", "freeze"),
        help="finetune or freeze resnet weights",
    )

    # Running
    parser.add_argument(
        "--workers",
        default=8,
        type=int,
        metavar="N",
        help="number of data loader workers",
    )
    return parser
def main():
    """Parse command-line arguments and spawn one evaluation worker per GPU.

    For ``--train-percent`` 1 or 10, the list of training image file names is
    downloaded from the SimCLR repository's ImageNet subset files. It is stored
    on ``args.train_files`` as raw byte lines for the workers. When running under
    SLURM (``SLURM_JOB_ID`` is set), SIGUSR1 and SIGTERM handlers are installed.
    Training then runs as single-node distributed training with one process per
    visible CUDA device.

    Raises:
        RuntimeError: if no CUDA device is visible.
    """
    # `import urllib` alone does not load the `urllib.request` submodule; import
    # it explicitly instead of relying on another library having done so.
    import urllib.request

    parser = get_arguments()
    args = parser.parse_args()
    if args.train_percent in {1, 10}:
        subset_url = (
            "https://raw.githubusercontent.com/google-research/simclr/master/"
            f"imagenet_subsets/{args.train_percent}percent.txt"
        )
        # Close the HTTP response once its lines have been read.
        with urllib.request.urlopen(subset_url) as response:
            args.train_files = response.readlines()

    args.ngpus_per_node = torch.cuda.device_count()
    if args.ngpus_per_node == 0:
        # spawn() with nprocs=0 would silently start nothing and return.
        raise RuntimeError("No CUDA device found: this script needs at least one GPU.")

    if "SLURM_JOB_ID" in os.environ:
        signal.signal(signal.SIGUSR1, handle_sigusr1)
        signal.signal(signal.SIGTERM, handle_sigterm)

    # single-node distributed training
    args.rank = 0
    args.dist_url = f"tcp://localhost:{random.randrange(49152, 65535)}"
    args.world_size = args.ngpus_per_node
    torch.multiprocessing.spawn(main_worker, (args,), args.ngpus_per_node)
def main_worker(gpu, args):
    """Train and evaluate a classifier on top of a pretrained backbone (one process).

    This is the target of ``torch.multiprocessing.spawn`` in ``main``.
    ``gpu`` is the local process index: it is used as the CUDA device and added
    to ``args.rank``. ``args`` is the namespace prepared by ``main``.

    Every rank trains the DDP-wrapped ``backbone -> nn.Linear(embedding, 1000)``
    model on its shard of the training set. Rank 0 additionally:

    * appends log lines to ``<exp_dir>/stats.txt``;
    * evaluates on the full validation set after every epoch;
    * writes ``<exp_dir>/checkpoint.pth``, from which a later run resumes
      automatically.

    Raises:
        ValueError: if ``args.weights`` is neither "finetune" nor "freeze".
    """
    args.rank += gpu
    torch.distributed.init_process_group(
        backend="nccl",
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )

    if args.rank == 0:
        args.exp_dir.mkdir(parents=True, exist_ok=True)
        # Line-buffered so each log record is flushed as soon as it is printed.
        stats_file = open(args.exp_dir / "stats.txt", "a", buffering=1)
        print(" ".join(sys.argv))
        print(" ".join(sys.argv), file=stats_file)

    torch.cuda.set_device(gpu)
    torch.backends.cudnn.benchmark = True

    # Model: pretrained backbone followed by a freshly initialised linear head.
    backbone, embedding = resnet.__dict__[args.arch](zero_init_residual=True)
    state_dict = torch.load(args.pretrained, map_location="cpu")
    if "model" in state_dict:
        # Full training checkpoint: take its model weights and strip the
        # "module.backbone." prefix so the keys match the bare backbone.
        state_dict = state_dict["model"]
        state_dict = {
            key.replace("module.backbone.", ""): value
            for (key, value) in state_dict.items()
        }
    # strict=False tolerates extra keys in the checkpoint (presumably its
    # non-backbone weights). However, it would also silently accept a checkpoint
    # that does not match the backbone at all, leaving those weights randomly
    # initialised, so report any missing keys.
    load_result = backbone.load_state_dict(state_dict, strict=False)
    if args.rank == 0 and load_result.missing_keys:
        print(
            "WARNING: backbone weights missing from checkpoint: "
            f"{load_result.missing_keys}"
        )

    head = nn.Linear(embedding, 1000)
    head.weight.data.normal_(mean=0.0, std=0.01)
    head.bias.data.zero_()
    model = nn.Sequential(backbone, head)
    model.cuda(gpu)
    if args.weights == "freeze":
        # Linear evaluation: only the head receives gradients.
        backbone.requires_grad_(False)
        head.requires_grad_(True)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])

    criterion = nn.CrossEntropyLoss().cuda(gpu)

    param_groups = [dict(params=head.parameters(), lr=args.lr_head)]
    if args.weights == "finetune":
        param_groups.append(dict(params=backbone.parameters(), lr=args.lr_backbone))
    # The positional lr=0 is only the default; every param group sets its own lr.
    optimizer = optim.SGD(param_groups, 0, momentum=0.9, weight_decay=args.weight_decay)
    # Stepped once per epoch below, so T_max is the number of epochs.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)

    # automatically resume from checkpoint if it exists
    # NOTE(review): the checkpoint stores `best_acc` as an argparse.Namespace.
    # Newer PyTorch releases default torch.load to weights_only=True, which may
    # refuse to unpickle it; confirm against the installed torch version.
    if (args.exp_dir / "checkpoint.pth").is_file():
        ckpt = torch.load(args.exp_dir / "checkpoint.pth", map_location="cpu")
        start_epoch = ckpt["epoch"]
        best_acc = ckpt["best_acc"]
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        scheduler.load_state_dict(ckpt["scheduler"])
    else:
        start_epoch = 0
        best_acc = argparse.Namespace(top1=0, top5=0)

    # Data loading code
    traindir = args.data_dir / "train"
    valdir = args.data_dir / "val"
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose(
            [
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        ),
    )
    val_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]
        ),
    )

    if args.train_percent in {1, 10}:
        # Replace the full training set with the downloaded subset. Each file
        # name is expected to start with its class directory name followed by "_".
        train_dataset.samples = []
        for fname in args.train_files:
            fname = fname.decode().strip()
            cls = fname.split("_")[0]
            train_dataset.samples.append(
                (traindir / cls / fname, train_dataset.class_to_idx[cls])
            )

    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    # --batch-size is the global batch size; each process loads its share.
    kwargs = dict(
        batch_size=args.batch_size // args.world_size,
        num_workers=args.workers,
        pin_memory=True,
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset, sampler=train_sampler, **kwargs
    )
    val_loader = torch.utils.data.DataLoader(val_dataset, **kwargs)

    start_time = time.time()
    for epoch in range(start_epoch, args.epochs):
        # train
        if args.weights == "finetune":
            model.train()
        elif args.weights == "freeze":
            # Frozen backbone: stay in eval mode so that any BatchNorm layers
            # do not update their running statistics.
            model.eval()
        else:
            raise ValueError(f"unsupported --weights value: {args.weights!r}")
        train_sampler.set_epoch(epoch)
        # `step` is a global iteration counter across epochs.
        for step, (images, target) in enumerate(
            train_loader, start=epoch * len(train_loader)
        ):
            output = model(images.cuda(gpu, non_blocking=True))
            loss = criterion(output, target.cuda(gpu, non_blocking=True))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step % args.print_freq == 0:
                # Collective call made by every rank. Rank 0 receives the sum of
                # loss / world_size, i.e. the mean loss across ranks.
                torch.distributed.reduce(loss.div_(args.world_size), 0)
                if args.rank == 0:
                    pg = optimizer.param_groups
                    lr_head = pg[0]["lr"]
                    lr_backbone = pg[1]["lr"] if len(pg) == 2 else 0
                    stats = dict(
                        epoch=epoch,
                        step=step,
                        lr_backbone=lr_backbone,
                        lr_head=lr_head,
                        loss=loss.item(),
                        time=int(time.time() - start_time),
                    )
                    print(json.dumps(stats))
                    print(json.dumps(stats), file=stats_file)

        # evaluate (rank 0 only, on the whole validation set)
        model.eval()
        if args.rank == 0:
            top1 = AverageMeter("Acc@1")
            top5 = AverageMeter("Acc@5")
            with torch.no_grad():
                for images, target in val_loader:
                    output = model(images.cuda(gpu, non_blocking=True))
                    acc1, acc5 = accuracy(
                        output, target.cuda(gpu, non_blocking=True), topk=(1, 5)
                    )
                    top1.update(acc1[0].item(), images.size(0))
                    top5.update(acc5[0].item(), images.size(0))
            best_acc.top1 = max(best_acc.top1, top1.avg)
            best_acc.top5 = max(best_acc.top5, top5.avg)
            stats = dict(
                epoch=epoch,
                acc1=top1.avg,
                acc5=top5.avg,
                best_acc1=best_acc.top1,
                best_acc5=best_acc.top5,
            )
            print(json.dumps(stats))
            print(json.dumps(stats), file=stats_file)

        scheduler.step()
        if args.rank == 0:
            state = dict(
                epoch=epoch + 1,
                best_acc=best_acc,
                model=model.state_dict(),
                optimizer=optimizer.state_dict(),
                scheduler=scheduler.state_dict(),
            )
            torch.save(state, args.exp_dir / "checkpoint.pth")

    if args.rank == 0:
        # Release the log file handle opened above (it was previously never closed).
        stats_file.close()
def handle_sigusr1(signum, frame):
    """Requeue the current SLURM job, then exit the process.

    ``main`` installs this handler for SIGUSR1 when ``SLURM_JOB_ID`` is set.
    The job id is read from the environment and passed to ``scontrol requeue``.

    Args:
        signum: signal number (unused).
        frame: current stack frame (unused).
    """
    os.system(f'scontrol requeue {os.getenv("SLURM_JOB_ID")}')
    # sys.exit rather than the `exit` builtin: `exit` is injected by the `site`
    # module for interactive use and is not guaranteed to exist in scripts.
    sys.exit()
def handle_sigterm(signum, frame):
    """No-op handler for SIGTERM.

    ``main`` installs it under SLURM, replacing the default action, so a SIGTERM
    no longer terminates the process. Presumably the intent is to keep running
    until SIGUSR1 triggers a requeue; confirm against the cluster's signal setup.

    Args:
        signum: signal number (unused).
        frame: current stack frame (unused).
    """
    return None
class AverageMeter:
    """Track the latest value and a count-weighted running mean of a metric.

    Attributes (``__str__`` reads them from the instance ``__dict__``):
        name: label printed before the values.
        fmt: format spec, including its leading colon, applied to ``val`` and ``avg``.
        val: value passed to the most recent ``update``.
        sum: running sum of ``val * n``.
        count: running total of ``n``.
        avg: ``sum / count`` after the last ``update``.
    """

    def __init__(self, name, fmt=":f"):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Zero the latest value, the running sum and count, and the average."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and recompute the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = f"{{name}} {{val{self.fmt}}} ({{avg{self.fmt}}})"
        return template.format(**vars(self))
def accuracy(output, target, topk=(1,)):
    """Return top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: score tensor indexed as ``output[sample, class]``; the top-k
            classes are taken along dim 1.
        target: tensor of ground-truth class indices, one per sample.
        topk: the values of k to evaluate.

    Returns:
        A list with one 1-element float tensor per k: the percentage of samples
        whose target is among the k highest-scoring classes.
    """
    with torch.no_grad():
        num_samples = target.size(0)
        _, ranked = output.topk(max(topk), 1, True, True)
        # After transposing, ranked[j, i] is the j-th best class for sample i.
        ranked = ranked.t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))
        scale = 100.0 / num_samples
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale)
            for k in topk
        ]
if __name__ == "__main__":
    # Only launch evaluation when executed as a script, not when imported.
    main()