-
Notifications
You must be signed in to change notification settings - Fork 80
/
Copy pathoptions.py
77 lines (71 loc) · 5.63 KB
/
options.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import argparse
import os, sys
import fdgan.utils.util as util
from reid import models
from reid import datasets
class Options():
def __init__(self):
self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
self.parser.add_argument('--stage', type=int, default=1, help='training stage [1|2]')
self.parser.add_argument('-d', '--dataset', type=str, default='market1501', choices=datasets.names())
# paths
self.parser.add_argument('--dataroot', type=str, default='./datasets/', help='root path to datasets (should have subfolders market1501, dukemtmc, cuhk03, etc)')
self.parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='root path to save models')
self.parser.add_argument('--name', type=str, default='FD-GAN', help='directory to save models')
self.parser.add_argument('--netE-pretrain', type=str, default='pretrained model path for net_E in stage 2')
self.parser.add_argument('--netG-pretrain', type=str, default='pretrained model path for net_G in stage 2')
self.parser.add_argument('--netDp-pretrain', type=str, default='pretrained model path for net_Dp in stage 2')
self.parser.add_argument('--netDi-pretrain', type=str, default='pretrained model path for net_Di in stage 2')
# model structures
self.parser.add_argument('--arch', type=str, default='resnet50', choices=models.names())
self.parser.add_argument('--norm', type=str, default='batch', help='instance normalization or batch normalization')
self.parser.add_argument('--drop', type=float, default=0.2, help='dropout for the netG')
self.parser.add_argument('--connect-layers', type=int, default=0, help='skip connections num for netG')
self.parser.add_argument('--fuse-mode', type=str, default='cat', help='method to fuse reid feature and pose feature [cat|add]')
self.parser.add_argument('--pose-feature-size', type=int, default=128, help='length of feature vector for pose')
self.parser.add_argument('--noise-feature-size', type=int, default=256, help='length of feature vector for noise')
self.parser.add_argument('--pose-aug', type=str, default='no', help='posemap augmentation [no|erase|gauss]')
# dataloader setting
self.parser.add_argument('-b', '--batch-size', type=int, default=16, help='input batch size')
self.parser.add_argument('-j', '--workers', default=4, type=int, help='num threads for loading data')
self.parser.add_argument('--width', type=int, default=128, help='input image width')
self.parser.add_argument('--height', type=int, default=256, help='input image height')
# optimizer setting
self.parser.add_argument('--niter', type=int, default=50, help='# of iter at starting learning rate')
self.parser.add_argument('--niter-decay', type=int, default=50, help='# of iter to linearly decay learning rate to zero')
self.parser.add_argument('--lr', type=float, default=0.001, help='initial learning rate')
self.parser.add_argument('--save-step', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
self.parser.add_argument('--eval-step', type=int, default=10, help='frequency of evaluate checkpoints at the end of epochs')
# visualization setting
self.parser.add_argument('--display-port', type=int, default=6006, help='visdom port of the web display')
self.parser.add_argument('--display-id', type=int, default=1, help='window id of the web display, set 0 for non-usage of visdom')
self.parser.add_argument('--display-winsize', type=int, default=256, help='display window size')
self.parser.add_argument('--display-freq', type=int, default=10, help='frequency of showing training results on screen')
self.parser.add_argument('--display-single-pane-ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
self.parser.add_argument('--update-html-freq', type=int, default=100, help='frequency of saving training results to html')
self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints]/name/web/')
self.parser.add_argument('--print-freq', type=int, default=10, help='frequency of showing training results on console')
# training schedule
self.parser.add_argument('--lambda-recon', type=float, default=1.0, help='loss weight of loss_r')
self.parser.add_argument('--lambda-veri', type=float, default=1.0, help='loss weight of loss_v')
self.parser.add_argument('--lambda-sp', type=float, default=1.0, help='loss weight of loss_sp')
self.parser.add_argument('--smooth-label', action='store_true', help='smooth label or not for GANloss')
self.opt = self.parser.parse_args()
self.show_opt()
def parse(self):
return self.opt
def show_opt(self):
args = vars(self.opt)
print('------------ Options -------------')
for k, v in sorted(args.items()):
print('%s: %s' % (str(k), str(v)))
print('-------------- End ----------------')
# save to the disk
expr_dir = os.path.join(self.opt.checkpoints, self.opt.name)
util.mkdirs(expr_dir)
file_name = os.path.join(expr_dir, 'opt.txt')
with open(file_name, 'wt') as opt_file:
opt_file.write('------------ Options -------------\n')
for k, v in sorted(args.items()):
opt_file.write('%s: %s\n' % (str(k), str(v)))
opt_file.write('-------------- End ----------------\n')