#!/usr/bin/env python
# Adapted for the paper "From Patches to Objects: Exploiting Spatial
# Reasoning for Better Visual Representations"
# Based on:
# The MIT License (MIT)
# Copyright (c) 2020 Massimiliano Patacchiola
# Paper: "Self-Supervised Relational Reasoning for Representation Learning", M. Patacchiola & A. Storkey, NeurIPS 2020
# GitHub: https://github.com/mpatacchiola/self-supervised-relational-reasoning
# Manages the linear evaluation phase for different datasets/backbones/methods.
# This script should be run after the self-supervised training phase to evaluate the learned representations.
# Example command:
#
# python train_linear_evaluation.py --dataset="cifar10" --method="relationnet" --backbone="conv4" --seed=3 --data_size=128 --gpu=0 --epochs=100 --checkpoint="./checkpoint/relationnet/cifar10/relationnet_cifar10_conv4_seed_3_epoch_200.tar"
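#
# Illustrative patch-based variant (the patch flag values below are examples only, not taken from the original runs):
#
# python train_linear_evaluation.py --dataset="cifar10" --method="relationnet" --backbone="conv4" --seed=3 --data_size=128 --gpu=0 --epochs=100 --patchsize=8 --patchcount=9 --checkpoint="./checkpoint/relationnet/cifar10/relationnet_cifar10_conv4_seed_3_epoch_200.tar"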
import os
import argparse
parser = argparse.ArgumentParser(description="Linear evaluation script")
parser.add_argument("--seed", default=-1, type=int, help="Seed for Numpy and PyTorch. Default: -1 (None)")
parser.add_argument("--epochs", default=100, type=int, help="Total number of epochs")
parser.add_argument("--dataset", default="cifar10", help="Dataset: cifar10|100, supercifar100, tiny, slim, stl10")
parser.add_argument("--backbone", default="conv4", help="Backbone: conv4, resnet|8|32|34|56")
parser.add_argument("--method", default="standard", help="Method name (used just as string in the checkpoint name)")
parser.add_argument("--data_size", default=128, type=int, help="Size of the mini-batch")
parser.add_argument("--checkpoint", default="./", help="Address of the checkpoint file")
parser.add_argument("--finetune", default="False", type=str, help="Finetune the backbone during training (default: False)")
parser.add_argument("--num_workers", default=4, type=int, help="Number of torchvision workers used to load data (default: 8)")
parser.add_argument("--id", default="", help="Additional string appended when saving the checkpoints")
parser.add_argument("--gpu", default=0, type=int, help="GPU id in case of multiple GPUs")
parser.add_argument("--patchsize", default=0, type=int, help="Size of patches to be used in the linear evaluation, default: 0 (only image)")
parser.add_argument("--patchcount", default=9, type=int, help="Number of patches to be used in the linear evaluation, default: 0 (only image)")
parser.add_argument("--additive",action='store_true', help="Enables additive mode for patches")
args = parser.parse_args()
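# Select the target GPU before importing torch, so that only this device is visible to CUDA.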
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
import torch
import torch.optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
import numpy as np
import random
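# Reproducibility: with a non-negative seed, fix the Python, NumPy, and PyTorch RNGs and force deterministic cuDNN.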
if(args.seed>=0):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print("[INFO] Setting SEED: " + str(args.seed))
else:
    print("[INFO] Setting SEED: None")
if(torch.cuda.is_available() == False): print("[WARNING] CUDA is not available.")
if(args.finetune=="True" or args.finetune=="true"):
    print("[INFO] Finetune set to True, the backbone will be finetuned.")
print("[INFO] Found", str(torch.cuda.device_count()), "GPU(s) available.")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("[INFO] Device type:", str(device))
from datamanager import DataManager
manager = DataManager(args.seed)
num_classes = manager.get_num_classes(args.dataset)
train_transform = manager.get_train_transforms("lineval", args.dataset)
if(args.finetune=="True" or args.finetune=="true"):
    train_transform = manager.get_train_transforms("finetune", args.dataset)
train_loader, _ = manager.get_train_loader(dataset=args.dataset,
                                           data_type="single" if args.patchsize==0 else "patch_eval",
                                           data_size=args.data_size,
                                           train_transform=train_transform,
                                           repeat_augmentations=None,
                                           num_workers=args.num_workers,
                                           drop_last=False,
                                           patch_size=args.patchsize,
                                           patch_count=args.patchcount,
                                           additive=args.additive)
test_loader = manager.get_test_loader(args.dataset, args.data_size, patch_size=args.patchsize, num_workers=args.num_workers, patch_count=args.patchcount)
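# Backbone selection: instantiate the feature extractor matching the architecture stored in the checkpoint.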
if(args.backbone=="conv4"):
    from backbones.conv4 import Conv4
    feature_extractor = Conv4(flatten=True)
elif(args.backbone=="resnet8"):
    from backbones.resnet_small import ResNet, BasicBlock
    feature_extractor = ResNet(BasicBlock, [1, 1, 1], channels=[16, 32, 64], flatten=True)
elif(args.backbone=="resnet32"):
    from backbones.resnet_small import ResNet, BasicBlock
    feature_extractor = ResNet(BasicBlock, [5, 5, 5], channels=[16, 32, 64], flatten=True)
elif(args.backbone=="resnet56"):
    from backbones.resnet_small import ResNet, BasicBlock
    feature_extractor = ResNet(BasicBlock, [9, 9, 9], channels=[16, 32, 64], flatten=True)
elif(args.backbone=="resnet34"):
    from backbones.resnet_large import ResNet, BasicBlock
    feature_extractor = ResNet(BasicBlock, layers=[3, 4, 6, 3], zero_init_residual=False,
                               groups=1, width_per_group=64, replace_stride_with_dilation=None,
                               norm_layer=None)
else:
    raise RuntimeError("[ERROR] the backbone " + str(args.backbone) + " is not supported.")
print("[INFO]", str(args.backbone), "loaded in memory.")
print("[INFO] Feature size:", str(feature_extractor.feature_size))
def main():
    print("[INFO] Loading checkpoint...")
    checkpoint = torch.load(args.checkpoint, map_location=device)
    feature_extractor.load_state_dict(checkpoint["backbone"])
    from methods.standard import StandardModel
    model = StandardModel(feature_extractor, num_classes, patch_size=args.patchsize, patch_count=args.patchcount)
    model.to(device)
    if not os.path.exists("./checkpoint/"+str(args.method)+"/"+str(args.dataset)):
        os.makedirs("./checkpoint/"+str(args.method)+"/"+str(args.dataset))
    if(args.id!=""):
        header = str(args.method) + "_" + str(args.id) + "_" + str(args.dataset) + "_" + str(args.backbone) + "_seed_" + str(args.seed)
    else:
        header = str(args.method) + "_" + str(args.dataset) + "_" + str(args.backbone) + "_seed_" + str(args.seed)
    accuracy_per_epoch = {}  # test accuracy recorded after each finetune epoch
    for epoch in range(0, args.epochs):
        if(args.finetune=="True" or args.finetune=="true"):
            print(f"Current epoch {epoch}")
            loss_train, accuracy_train = model.finetune(epoch, train_loader)
            loss_test, accuracy_test = model.test(test_loader)
            accuracy_per_epoch[epoch] = accuracy_test
            print("Test accuracy: " + str(accuracy_test) + "%")
        else:
            loss_train, accuracy_train = model.linear_evaluation(epoch, train_loader)
    print(accuracy_per_epoch)
    loss_test, accuracy_test = model.test(test_loader)
    print("Test accuracy: " + str(accuracy_test) + "%")

if __name__ == "__main__":
    main()