diff --git a/train_models/mtcnn_model.py b/train_models/mtcnn_model.py index 420b014..749aa0d 100755 --- a/train_models/mtcnn_model.py +++ b/train_models/mtcnn_model.py @@ -4,83 +4,90 @@ from tensorflow.contrib.tensorboard.plugins import projector import numpy as np num_keep_radio = 0.7 + + #define prelu def prelu(inputs): - alphas = tf.get_variable("alphas", shape=inputs.get_shape()[-1], dtype=tf.float32, initializer=tf.constant_initializer(0.25)) + alphas = tf.get_variable("alphas", + shape=inputs.get_shape()[-1], + dtype=tf.float32, + initializer=tf.constant_initializer(0.25)) pos = tf.nn.relu(inputs) - neg = alphas * (inputs-abs(inputs))*0.5 + neg = alphas * (inputs - abs(inputs)) * 0.5 return pos + neg -def dense_to_one_hot(labels_dense,num_classes): + +def dense_to_one_hot(labels_dense, num_classes): num_labels = labels_dense.shape[0] - index_offset = np.arange(num_labels)*num_classes + index_offset = np.arange(num_labels) * num_classes #num_sample*num_classes - labels_one_hot = np.zeros((num_labels,num_classes)) + labels_one_hot = np.zeros((num_labels, num_classes)) labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1 return labels_one_hot + + #cls_prob:batch*2 #label:batch - def cls_ohem(cls_prob, label): - zeros = tf.zeros_like(label) - #label=-1 --> label=0net_factory - - #pos -> 1, neg -> 0, others -> 0 - label_filter_invalid = tf.where(tf.less(label,0), zeros, label) - num_cls_prob = tf.size(cls_prob) - cls_prob_reshape = tf.reshape(cls_prob,[num_cls_prob,-1]) - label_int = tf.cast(label_filter_invalid,tf.int32) - # get the number of rows of class_prob - num_row = tf.to_int32(cls_prob.get_shape()[0]) - #row = [0,2,4.....] - row = tf.range(num_row)*2 - indices_ = row + label_int - label_prob = tf.squeeze(tf.gather(cls_prob_reshape, indices_)) - loss = -tf.log(label_prob+1e-10) - zeros = tf.zeros_like(label_prob, dtype=tf.float32) - ones = tf.ones_like(label_prob,dtype=tf.float32) + # pos -> 1, neg -> 0, others -> 0 + label_filter_invalid = tf.nn.relu(label) + loss = tf.losses.sparse_softmax_cross_entropy( + tf.cast(label_filter_invalid, tf.int32), + cls_prob, + reduction=tf.losses.Reduction.NONE) # set pos and neg to be 1, rest to be 0 - valid_inds = tf.where(label < zeros,zeros,ones) + zeros = tf.zeros_like(label, dtype=tf.float32) + ones = tf.ones_like(label, dtype=tf.float32) + valid_inds = tf.where(label < zeros, zeros, ones) # get the number of POS and NEG examples - num_valid = tf.reduce_sum(valid_inds) + valid_num = tf.reduce_sum(valid_inds) - keep_num = tf.cast(num_valid*num_keep_radio,dtype=tf.int32) - #FILTER OUT PART AND LANDMARK DATA + keep_num = tf.cast(valid_num * num_keep_radio, dtype=tf.int32) + # FILTER OUT PART AND LANDMARK DATA loss = loss * valid_inds - loss,_ = tf.nn.top_k(loss, k=keep_num) + loss, _ = tf.nn.top_k(loss, k=keep_num) return tf.reduce_mean(loss) -def bbox_ohem_smooth_L1_loss(bbox_pred,bbox_target,label): +def bbox_ohem_smooth_L1_loss(bbox_pred, bbox_target, label): sigma = tf.constant(1.0) - threshold = 1.0/(sigma**2) + threshold = 1.0 / (sigma**2) zeros_index = tf.zeros_like(label, dtype=tf.float32) - valid_inds = tf.where(label!=zeros_index,tf.ones_like(label,dtype=tf.float32),zeros_index) - abs_error = tf.abs(bbox_pred-bbox_target) - loss_smaller = 0.5*((abs_error*sigma)**2) - loss_larger = abs_error-0.5/(sigma**2) - smooth_loss = tf.reduce_sum(tf.where(abs_error