-
Notifications
You must be signed in to change notification settings - Fork 0
/
args.py
125 lines (101 loc) · 5.23 KB
/
args.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import shutil
import datetime
import argparse
import torch
import numpy as np
def _str2bool(value):
    """Convert a command-line token into a bool for use as an argparse ``type``.

    ``type=bool`` is not usable for this purpose: ``bool("False")`` is True,
    so every non-empty value would enable the option.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised
            true/false spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string was already falsy under the old ``type=bool`` parsing.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


def parse_train_args():
    """Parse the training command line and prepare the run's output directory.

    Besides parsing ``sys.argv``, this function has side effects:

    * chooses an output directory under ``./model_weights/`` (``<uid>`` when
      ``--uid`` is given, otherwise ``<random id>-<timestamp>``) and creates it;
    * when the directory already exists and ``--force`` is set, copies the
      existing ``log.txt`` / ``args.txt`` to the first free ``.bk<m>`` slot
      (m in 1..9) before they are overwritten;
    * truncates ``log.txt`` and writes the parsed arguments to ``args.txt``;
    * sets ``torch.backends.cudnn.benchmark`` according to ``--use_cudnn``.

    Returns:
        argparse.Namespace: the parsed options plus ``save_path``, ``log`` and
        ``arg`` (output directory, log file path, args file path).

    Raises:
        FileExistsError: if the output directory already exists and
            ``--force`` was not given.
    """
    parser = argparse.ArgumentParser()

    # parameters
    # Model Selection
    parser.add_argument('--model', type=str, default='resnet18')
    parser.add_argument('--no-bias', dest='bias', action='store_false')
    parser.add_argument('--ETF_fc', dest='ETF_fc', action='store_true')
    parser.add_argument('--fixdim', dest='fixdim', type=int, default=0)
    parser.add_argument('--SOTA', dest='SOTA', action='store_true')

    # Hardware Setting
    parser.add_argument('--gpu_id', type=int, default=0)
    parser.add_argument('--seed', type=int, default=6)
    # An explicit converter is required so that "--use_cudnn False" is False.
    parser.add_argument('--use_cudnn', type=_str2bool, default=True)

    # Directory Setting
    parser.add_argument('--dataset', type=str, choices=['cifar10', 'mini_imagenet'], default='cifar10')
    parser.add_argument('--data_dir', type=str, default='../data')
    parser.add_argument('--uid', type=str, default=None)
    parser.add_argument('--force', action='store_true', help='force to override the given uid')

    # Learning Options
    parser.add_argument('--epochs', type=int, default=200, help='Max Epochs')
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size')
    parser.add_argument('--loss', type=str, default='MSE', help='loss function configuration')
    parser.add_argument('--sample_size', type=int, default=None, help='sample size PER CLASS')
    parser.add_argument('--M', type=float, default=1, help='length value for rescaled MSE')
    parser.add_argument('--k', type=float, default=1, help='rescale value for rescaled MSE')

    # Optimization specifications
    parser.add_argument('--lr', type=float, default=0.1, help='learning rate')
    parser.add_argument('--patience', type=int, default=40, help='learning rate decay per N epochs')
    parser.add_argument('--decay_type', type=str, default='step', help='learning rate decay type')
    parser.add_argument('--gamma', type=float, default=0.1, help='learning rate decay factor for step decay')
    parser.add_argument('--optimizer', default='SGD', help='optimizer to use')
    parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay')
    # The following two should be specified when testing adding wd on Features
    parser.add_argument('--feature_decay_rate', type=float, default=1e-4, help='weight decay for last layer feature')
    parser.add_argument('--history_size', type=int, default=10, help='history size for LBFGS')

    args = parser.parse_args()

    if args.uid is None:
        # No uid supplied: invent one and add a timestamp to the directory name.
        unique_id = str(np.random.randint(0, 100000))
        print("revise the unique id to a random number " + str(unique_id))
        args.uid = unique_id
        timestamp = datetime.datetime.now().strftime("%a-%b-%d-%H-%M")
        save_path = './model_weights/' + args.uid + '-' + timestamp
    else:
        save_path = './model_weights/' + str(args.uid)

    log_file = save_path + '/log.txt'
    arg_file = save_path + '/args.txt'

    if not os.path.exists(save_path):
        os.makedirs(save_path, exist_ok=True)
    elif not args.force:
        # Raising a bare string is itself a TypeError; raise a real exception.
        raise FileExistsError("please use another uid ")
    else:
        print("override this uid" + args.uid)
        # Keep the previous run's files in the first unused backup slot.
        # If all nine slots are taken, no backup is made.
        for m in range(1, 10):
            suffix = ".bk" + str(m)
            if not os.path.exists(log_file + suffix):
                for current in (log_file, arg_file):
                    # The directory may exist without these files.
                    if os.path.exists(current):
                        shutil.copy(current, current + suffix)
                break

    # Attach the derived paths to the namespace already parsed. Re-parsing the
    # command line here would reset a generated uid back to None.
    args.save_path = save_path
    args.log = log_file
    args.arg = arg_file

    # Start the run with an empty log file.
    with open(args.log, 'w'):
        pass

    # Echo the configuration to stdout and record it next to the weights.
    with open(args.arg, 'w') as f:
        print(args)
        print(args, file=f)

    if args.use_cudnn:
        print("cudnn is used")
        torch.backends.cudnn.benchmark = True
    else:
        print("cudnn is not used")
        torch.backends.cudnn.benchmark = False

    return args
def parse_eval_args():
    """Parse the evaluation command line from ``sys.argv``.

    Returns:
        argparse.Namespace: model, hardware, dataset/directory and
        learning options, with defaults filled in for any option not given.
    """
    # Each entry is (flag, keyword arguments for add_argument); the order
    # here is the order in which the options are registered.
    option_table = (
        # Model Selection
        ('--model', {'type': str, 'default': 'resnet18'}),
        ('--no-bias', {'dest': 'bias', 'action': 'store_false'}),
        ('--ETF_fc', {'dest': 'ETF_fc', 'action': 'store_true'}),
        ('--fixdim', {'dest': 'fixdim', 'type': int, 'default': 0}),
        ('--SOTA', {'dest': 'SOTA', 'action': 'store_true'}),
        # Hardware Setting
        ('--gpu_id', {'type': int, 'default': 0}),
        ('--seed', {'type': int, 'default': 6}),
        # Directory Setting
        ('--dataset', {'type': str, 'choices': ['cifar10', 'mini_imagenet'], 'default': 'cifar10'}),
        ('--data_dir', {'type': str, 'default': '../data'}),
        ('--load_path', {'type': str, 'default': None}),
        # Learning Options
        ('--epochs', {'type': int, 'default': 200, 'help': 'Max Epochs'}),
        ('--batch_size', {'type': int, 'default': 128, 'help': 'Batch size'}),
        ('--sample_size', {'type': int, 'default': None, 'help': 'sample size PER CLASS'}),
    )

    eval_parser = argparse.ArgumentParser()
    for flag, settings in option_table:
        eval_parser.add_argument(flag, **settings)

    return eval_parser.parse_args()