forked from ZJULearning/pixel_link
-
Notifications
You must be signed in to change notification settings - Fork 0
/
config.py
226 lines (177 loc) · 7.43 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
from __future__ import print_function
from pprint import pprint
import numpy as np
from tensorflow.contrib.slim.python.slim.data import parallel_reader
import tensorflow as tf
import util
from nets import pixel_link_symbol
import pixel_link
slim = tf.contrib.slim
#=====================================================================
#====================Pre-processing params START======================
# VGG mean parameters: per-channel means, collected in R, G, B order in
# rgb_mean below.
r_mean = 123.
g_mean = 117.
b_mean = 104.
rgb_mean = [r_mean, g_mean, b_mean]
# scale, crop, filtering and resize parameters
use_rotation = True
rotation_prob = 0.5
max_expand_scale = 1
expand_prob = 0
min_object_covered = 0.1  # Minimum object to be cropped in random crop.
bbox_crop_overlap = 0.2  # Minimum overlap to keep a bbox after cropping.
crop_aspect_ratio_range = (0.5, 2.)  # Distortion ratio during cropping.
area_range = [0.1, 1]
flip = False
using_shorter_side_filtering = True
min_shorter_side = 10
# Unbounded upper limit. ``np.inf`` is the canonical spelling; the
# ``np.infty`` alias was removed in NumPy 2.0 (same value).
max_shorter_side = np.inf
#====================Pre-processing params END========================
#=====================================================================
#=====================================================================
#====================Post-processing params START=====================
# Decoding strategy used when turning predictions into boxes (a constant
# from the pixel_link module).
decode_method = pixel_link.DECODE_METHOD_join
# NOTE(review): presumably the minimum area / height a decoded text box must
# have to be kept — confirm against the post-processing code.
min_area = 300
min_height = 10
#====================Post-processing params END=======================
#=====================================================================
#=====================================================================
#====================Training and model params START =================
dropout_ratio = 0
# NOTE(review): likely the cap on the negative:positive pixel ratio used in
# the loss — confirm in the loss implementation.
max_neg_pos_ratio = 3
# How the multi-level feature maps are fused. The two commented-out lines
# are alternative fusion types from pixel_link_symbol.
feat_fuse_type = pixel_link_symbol.FUSE_TYPE_cascade_conv1x1_upsample_sum
# feat_fuse_type = pixel_link_symbol.FUSE_TYPE_cascade_conv1x1_128_upsamle_sum_conv1x1_2
# feat_fuse_type = pixel_link_symbol.FUSE_TYPE_cascade_conv1x1_128_upsamle_concat_conv1x1_2
# Pixel neighbourhood used for links (presumably 8 neighbours; the
# commented line selects the 4-neighbour variant).
pixel_neighbour_type = pixel_link.PIXEL_NEIGHBOUR_TYPE_8
#pixel_neighbour_type = pixel_link.PIXEL_NEIGHBOUR_TYPE_4
# Alternative setting: additionally use conv2_2, paired with a stride of 2.
#model_type = pixel_link_symbol.MODEL_TYPE_vgg16
#feat_layers = ['conv2_2', 'conv3_3', 'conv4_3', 'conv5_3', 'fc7']
#strides = [2]
model_type = pixel_link_symbol.MODEL_TYPE_vgg16
feat_layers = ['conv3_3', 'conv4_3', 'conv5_3', 'fc7']
# strides[0] is the factor _set_image_shape divides the image height/width
# by to obtain score_map_shape; change it together with feat_layers (see the
# commented-out alternative above).
strides = [4]
pixel_cls_weight_method = pixel_link.PIXEL_CLS_WEIGHT_bbox_balanced
bbox_border_width = 1
pixel_cls_border_weight_lambda = 1.0
pixel_cls_loss_weight_lambda = 2.0
pixel_link_neg_loss_weight_lambda = 1.0
pixel_link_loss_weight = 1.0
#====================Training and model params END ==================
#=====================================================================
#=====================================================================
#====================do-not-change configurations START===============
num_classes = 2  # background_label and text_label below
ignore_label = -1
background_label = 0
text_label = 1
data_format = 'NHWC'
# Default only; _set_train_with_ignored() overwrites it at runtime.
train_with_ignored = False
#====================do-not-change configurations END=================
#=====================================================================
# Runtime configuration. These names are assigned by init_config() (directly
# or through the _set_* helpers), not at import time, so they do not exist as
# module attributes until init_config() has run:
#   weight_decay, train_image_shape, image_shape, score_map_shape,
#   batch_size, batch_size_per_gpu, gpus, num_clones, clone_scopes,
#   num_neighbours, pixel_conf_threshold, link_conf_threshold
# They are deliberately not listed in module-level ``global`` statements:
# at module scope ``global`` is a no-op and neither declares nor
# initialises anything.
def _set_weight_decay(wd):
    """Record ``wd`` as the module-level ``weight_decay`` coefficient."""
    # Writing into the module namespace mapping is equivalent to
    # ``global weight_decay`` followed by a plain assignment.
    globals()['weight_decay'] = wd
def _set_image_shape(shape):
    """Set the training image shape and the score-map shape derived from it.

    Args:
        shape: a ``(height, width)`` pair. Both values must be multiples of 4
            and of ``strides[0]``.

    Side effects:
        Sets the module globals ``train_image_shape`` (``[h, w]``),
        ``image_shape`` (the same list object) and ``score_map_shape``
        (``(h // strides[0], w // strides[0])``, integers).

    Raises:
        AssertionError: if either dimension violates the divisibility rules.
    """
    global train_image_shape
    global score_map_shape
    global image_shape
    h, w = shape
    stride = strides[0]
    assert w % 4 == 0
    assert h % 4 == 0
    # Require exact divisibility so the score map covers the whole image
    # instead of silently dropping trailing rows/columns.
    assert h % stride == 0 and w % stride == 0
    train_image_shape = [h, w]
    # Floor division keeps the shape integral on Python 3 too; plain ``/``
    # would produce floats there.
    score_map_shape = (h // stride, w // stride)
    image_shape = train_image_shape
def _set_batch_size(bz):
    """Publish ``bz`` as the module-level total ``batch_size``."""
    # Same effect as ``global batch_size; batch_size = bz``.
    globals().update(batch_size=bz)
def _set_seg_th(pixel_conf_th, link_conf_th):
    """Store the segmentation confidence thresholds as module globals.

    Args:
        pixel_conf_th: value assigned to ``pixel_conf_threshold``.
        link_conf_th: value assigned to ``link_conf_threshold``.
    """
    global pixel_conf_threshold, link_conf_threshold
    pixel_conf_threshold, link_conf_threshold = pixel_conf_th, link_conf_th
def _set_train_with_ignored(train_with_ignored_):
    """Override the module default ``train_with_ignored`` with the given value."""
    module_namespace = globals()
    module_namespace['train_with_ignored'] = train_with_ignored_
def init_config(image_shape, batch_size = 1,
                weight_decay = 0.0005,
                num_gpus = 1,
                pixel_conf_threshold = 0.6,
                link_conf_threshold = 0.9):
    """Initialise the runtime (non-constant) configuration of this module.

    Args:
        image_shape: ``(height, width)`` of training images, forwarded to
            ``_set_image_shape``.
        batch_size: total batch size, split evenly across the clones.
        weight_decay: stored as the module global ``weight_decay``.
        num_gpus: passed to ``util.tf.get_available_gpus``; one clone is
            created per GPU it returns.
        pixel_conf_threshold: stored as ``pixel_conf_threshold``.
        link_conf_threshold: stored as ``link_conf_threshold``.

    Side effects:
        Sets the module globals ``pixel_conf_threshold``,
        ``link_conf_threshold``, ``weight_decay``, ``train_image_shape``,
        ``image_shape``, ``score_map_shape``, ``gpus``, ``num_clones``,
        ``clone_scopes``, ``batch_size``, ``batch_size_per_gpu`` and
        ``num_neighbours``.

    Raises:
        ValueError: if no GPU is returned (so no clone can be built), or if
            ``batch_size`` is smaller than the number of clones.
    """
    # Only names that are not parameters can be declared global here.
    global gpus, num_clones, clone_scopes, batch_size_per_gpu, num_neighbours
    _set_seg_th(pixel_conf_threshold, link_conf_threshold)
    _set_weight_decay(weight_decay)
    _set_image_shape(image_shape)

    # init batch size
    gpus = util.tf.get_available_gpus(num_gpus)
    num_clones = len(gpus)
    if num_clones == 0:
        # Without this guard the division below fails with ZeroDivisionError.
        raise ValueError('No GPU available to create clones on '
                         '(num_gpus=%r).' % (num_gpus,))
    # ``range`` works on Python 2 and 3; ``xrange`` is Python 2 only.
    clone_scopes = ['clone_%d' % (idx) for idx in range(num_clones)]
    _set_batch_size(batch_size)
    # Floor division: each clone gets a whole number of images on both
    # Python 2 and Python 3 (``/`` yields a float on Python 3). Any remainder
    # of batch_size % num_clones is not assigned to a clone.
    batch_size_per_gpu = batch_size // num_clones
    if batch_size_per_gpu < 1:
        # Implicit concatenation avoids the source indentation leaking into
        # the message, as a backslash continuation inside the literal would.
        raise ValueError('Invalid batch_size [=%d], '
                         'resulting in 0 images per gpu.' % (batch_size))
    num_neighbours = pixel_link.get_neighbours_fn()[1]
def print_config(flags, dataset, save_dir = None, print_to_file = True):
    """Print a report of the training flags, this module's settings and the
    dataset files, to stdout and optionally to a file.

    Args:
        flags: flags object; its ``train_dir``, ``checkpoint_path`` and
            ``__flags`` attributes are read.
        dataset: object whose ``data_sources`` are expanded with
            ``parallel_reader.get_data_files``.
        save_dir: directory for ``training_config.txt``; defaults to
            ``flags.train_dir``. It is created via ``util.io.mkdir``.
        print_to_file: when true, the same report is also appended (mode
            ``"a"``) to ``<save_dir>/training_config.txt``.
    """
    def do_print(stream=None):
        """Write the whole report to ``stream`` (``None`` means stdout)."""
        print(util.log.get_date_str(), file = stream)
        print('\n# =========================================================================== #', file=stream)
        print('# Training flags:', file=stream)
        print('# =========================================================================== #', file=stream)

        def print_ckpt(path):
            """Print the latest checkpoint under ``path``; return whether one exists."""
            ckpt = util.tf.get_latest_ckpt(path)
            if ckpt is not None:
                print('Resume Training from : %s'%(ckpt), file = stream)
                return True
            return False

        # Prefer a checkpoint in train_dir; fall back to checkpoint_path.
        if not print_ckpt(flags.train_dir):
            print_ckpt(flags.checkpoint_path)

        pprint(flags.__flags, stream=stream)

        print('\n# =========================================================================== #', file=stream)
        print('# pixel_link net parameters:', file=stream)
        print('# =========================================================================== #', file=stream)
        # Dump every module global that is a number, string, list or tuple.
        # Named ``module_vars`` so the ``vars`` builtin is not shadowed.
        module_vars = globals()
        for key, var in module_vars.items():
            if (util.dtype.is_number(var) or util.dtype.is_str(var)
                    or util.dtype.is_list(var) or util.dtype.is_tuple(var)):
                pprint('%s=%s'%(key, str(var)), stream = stream)

        print('\n# =========================================================================== #', file=stream)
        print('# Training | Evaluation dataset files:', file=stream)
        print('# =========================================================================== #', file=stream)
        data_files = parallel_reader.get_data_files(dataset.data_sources)
        pprint(sorted(data_files), stream=stream)
        print('', file=stream)

    do_print(None)

    if print_to_file:
        # Save to a text file as well.
        if save_dir is None:
            save_dir = flags.train_dir

        util.io.mkdir(save_dir)
        path = util.io.join_path(save_dir, 'training_config.txt')
        with open(path, "a") as out:
            do_print(out)
def load_config(path):
    """Load the ``config.py`` saved in a directory, or save one there.

    Args:
        path: a directory, or a file path whose directory
            (``util.io.get_dir``) is used instead.

    Returns:
        The module object returned by ``util.mod.load_mod_from_path`` when
        ``<dir>/config.py`` exists; otherwise ``None``, after copying
        ``config.py`` into ``<dir>``.
    """
    if not util.io.is_dir(path):
        path = util.io.get_dir(path)

    config_file = util.io.join_path(path, 'config.py')
    if util.io.exists(config_file):
        tf.logging.info('loading config.py from %s'%(config_file))
        # Hand the loaded module back instead of discarding it in a local.
        return util.mod.load_mod_from_path(config_file)

    # NOTE(review): 'config.py' is resolved relative to the current working
    # directory, not this module's location — confirm callers run from the
    # repository root.
    util.io.copy('config.py', path)
    return None