This repository has been archived by the owner on Jun 15, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 23
/
Copy pathmask_rcnn_train_chain.py
executable file
·116 lines (101 loc) · 5.21 KB
/
mask_rcnn_train_chain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import numpy as np
import chainer
from chainer import cuda
import chainer.functions as F
from chainercv.links.model.faster_rcnn.utils.anchor_target_creator import AnchorTargetCreator
from utils.proposal_target_creator import ProposalTargetCreator
from chainer import computational_graph as c
from chainercv.links import PixelwiseSoftmaxClassifier
class MaskRCNNTrainChain(chainer.Chain):
    """Chain that computes the combined training loss of a Mask R-CNN model.

    The total loss is the sum of the RPN localization loss, the RPN
    classification loss, the head localization loss, the head classification
    loss and the mask loss weighted by ``gamma``. An exponential moving
    average of the total loss is kept in ``avg_loss``.

    Args:
        mask_rcnn: Model providing ``extractor``, ``rpn``, ``head.res5head``,
            ``head.boxhead``, ``head.maskhead``, ``loc_normalize_mean`` and
            ``loc_normalize_std``.
        rpn_sigma: Sigma of the smooth L1 loss used for the RPN.
        roi_sigma: Sigma of the smooth L1 loss used for the head.
        gamma: Weight applied to the mask loss in the total loss.
        anchor_target_creator: Callable producing RPN targets from
            ``(bbox, anchor, img_size)``.
        roi_size: ``roi_size // 2`` is forwarded to ``ProposalTargetCreator``.
    """

    def __init__(self, mask_rcnn, rpn_sigma=3., roi_sigma=1., gamma=1,
                 anchor_target_creator=AnchorTargetCreator(),
                 roi_size=14):
        super(MaskRCNNTrainChain, self).__init__()
        with self.init_scope():
            self.mask_rcnn = mask_rcnn
        self.rpn_sigma = rpn_sigma
        self.roi_sigma = roi_sigma
        self.anchor_target_creator = anchor_target_creator
        # NOTE(review): the mask target size is half of roi_size; presumably
        # the mask head upsamples by 2 -- confirm against the head definition.
        self.proposal_target_creator = ProposalTargetCreator(
            roi_size=roi_size // 2)
        self.loc_normalize_mean = mask_rcnn.loc_normalize_mean
        self.loc_normalize_std = mask_rcnn.loc_normalize_std
        # Decay rate of the exponential moving average stored in avg_loss.
        self.decayrate = 0.99
        self.avg_loss = None
        self.gamma = gamma

    def __call__(self, imgs, bboxes, labels, scale, masks, i):
        """Run a forward pass and return the total training loss.

        Args:
            imgs: 4-D image batch; the last two axes are height and width.
            bboxes: Ground-truth boxes; the first axis is the batch axis.
            labels: Ground-truth labels; the first axis is the batch axis.
            scale: Single-element array (or Variable) with the image scale.
            masks: Ground-truth masks; the first axis is the batch axis.
            i: Unused; kept so existing callers keep working.

        Returns:
            chainer.Variable: The total loss. The individual losses, the
            total loss and its moving average are also reported through
            ``chainer.reporter``.

        Raises:
            ValueError: If the batch size is not 1.
        """
        if isinstance(bboxes, chainer.Variable):
            bboxes = bboxes.data
        if isinstance(labels, chainer.Variable):
            labels = labels.data
        if isinstance(scale, chainer.Variable):
            scale = scale.data
        if isinstance(masks, chainer.Variable):
            masks = masks.data
        # ndarray.item() is the drop-in replacement for np.asscalar, which
        # was deprecated in NumPy 1.16 and removed in NumPy 1.23.
        scale = cuda.to_cpu(scale).item()
        n = bboxes.shape[0]
        if n != 1:
            raise ValueError('only batch size 1 is supported')
        _, _, H, W = imgs.shape
        img_size = (H, W)

        # Extractor : img -> features
        # NOTE(review): the extent of this using_config block was
        # reconstructed from a source without indentation; confirm that only
        # the extractor is meant to run with train=False.
        with chainer.using_config('train', False):
            features = self.mask_rcnn.extractor(imgs)

        # Region Proposal Network : features -> rpn_locs, rpn_scores, rois
        rpn_locs, rpn_scores, rois, _, anchor = self.mask_rcnn.rpn(
            features, img_size, scale)
        # batch size is 1, so take the only element along the batch axis
        bbox, label, mask, rpn_score, rpn_loc, roi = \
            bboxes[0], labels[0], masks[0], rpn_scores[0], rpn_locs[0], rois

        # proposal target : roi(proposed), bbox(GT), label(GT), mask(GT)
        #   -> sample_roi, gt_roi_loc, gt_roi_label, gt_roi_mask
        # the targets are compared with the head output.
        sample_roi, gt_roi_loc, gt_roi_label, gt_roi_mask = \
            self.proposal_target_creator(
                roi, bbox, label, mask,
                self.loc_normalize_mean, self.loc_normalize_std)
        # every sampled RoI belongs to image 0 of the batch
        sample_roi_index = self.xp.zeros((len(sample_roi),), dtype=np.int32)

        # Head Network : features, sample_roi -> roi_cls_loc, roi_score, mask
        # NOTE(review): as above, the extent of this using_config block was
        # reconstructed; confirm that only res5head runs with train=False.
        with chainer.using_config('train', False):
            hres5 = self.mask_rcnn.head.res5head(
                features, sample_roi, sample_roi_index)
        roi_cls_loc, roi_score = self.mask_rcnn.head.boxhead(hres5)
        roi_cls_mask = self.mask_rcnn.head.maskhead(hres5)
        del hres5

        # RPN losses
        gt_rpn_loc, gt_rpn_label = self.anchor_target_creator(
            bbox, anchor, img_size)
        rpn_loc_loss = _fast_rcnn_loc_loss(
            rpn_loc, gt_rpn_loc, gt_rpn_label, self.rpn_sigma)
        rpn_cls_loss = F.sigmoid_cross_entropy(rpn_score, gt_rpn_label)

        # Head output losses: pick the prediction of the ground-truth class
        n_sample = roi_cls_loc.shape[0]
        roi_cls_loc = roi_cls_loc.reshape((n_sample, -1, 4))
        roi_loc = roi_cls_loc[self.xp.arange(n_sample), gt_roi_label]
        roi_mask = roi_cls_mask[self.xp.arange(n_sample), gt_roi_label]
        roi_loc_loss = _fast_rcnn_loc_loss(
            roi_loc, gt_roi_loc, gt_roi_label, self.roi_sigma)
        roi_cls_loss = F.softmax_cross_entropy(roi_score, gt_roi_label)

        # mask loss: average binary cross-entropy loss, computed on the
        # leading RoIs for which a ground-truth mask was produced
        mask_loss = F.sigmoid_cross_entropy(
            roi_mask[0:gt_roi_mask.shape[0]], gt_roi_mask)
        weighted_mask_loss = self.gamma * mask_loss

        # total loss
        loss = (rpn_loc_loss + rpn_cls_loss + roi_loc_loss + roi_cls_loss
                + weighted_mask_loss)

        # avg loss calculation (exponential moving average)
        if self.avg_loss is None:
            self.avg_loss = loss.data
        else:
            self.avg_loss = (self.avg_loss * self.decayrate
                             + loss.data * (1 - self.decayrate))

        chainer.reporter.report({'rpn_loc_loss': rpn_loc_loss,
                                 'rpn_cls_loss': rpn_cls_loss,
                                 'roi_loc_loss': roi_loc_loss,
                                 'roi_cls_loss': roi_cls_loss,
                                 'roi_mask_loss': weighted_mask_loss,
                                 'avg_loss': self.avg_loss,
                                 'loss': loss}, self)
        return loss
def _smooth_l1_loss(x, t, in_weight, sigma):
    """Return the summed smooth L1 loss between ``x`` and ``t``.

    The element-wise difference ``x - t`` is first multiplied by
    ``in_weight``. Elements whose absolute weighted difference is below
    ``1 / sigma ** 2`` use the quadratic term ``sigma ** 2 / 2 * d ** 2``;
    all others use the linear term ``|d| - 0.5 / sigma ** 2``.

    Args:
        x: Predicted values.
        t: Target values.
        in_weight: Element-wise weight applied to the difference.
        sigma: Controls where the loss switches from quadratic to linear.

    Returns:
        The sum of the loss over all elements.
    """
    sigma_sq = sigma ** 2
    weighted = in_weight * (x - t)
    magnitude = F.absolute(weighted)
    # 1.0 where the quadratic branch applies, 0.0 where the linear one does;
    # computed on raw data so the selector is not part of the graph.
    use_quadratic = (magnitude.data < (1. / sigma_sq)).astype(np.float32)
    quadratic_part = use_quadratic * (sigma_sq / 2.) * F.square(weighted)
    linear_part = (1 - use_quadratic) * (magnitude - 0.5 / sigma_sq)
    return F.sum(quadratic_part + linear_part)
def _fast_rcnn_loc_loss(pred_loc, gt_loc, gt_label, sigma):
    """Return the normalized localization loss.

    Only entries whose label is strictly positive contribute to the smooth
    L1 loss; the summed loss is divided by the number of entries whose
    label is non-negative.

    Args:
        pred_loc: Predicted localization values.
        gt_loc: Target localization values; the weight array takes its
            shape and dtype from this argument.
        gt_label: Labels used to select contributing entries
            (presumably negative values mark ignored samples -- confirm
            against the target creators).
        sigma: Sigma forwarded to the smooth L1 loss.

    Returns:
        The localization loss divided by the count of non-negative labels.
    """
    xp = chainer.cuda.get_array_module(pred_loc)
    # Weight of 1 for positive samples, 0 for everything else.
    positive_weight = xp.zeros_like(gt_loc)
    positive_weight[gt_label > 0] = 1
    n_counted = xp.sum(gt_label >= 0)
    normalized = _smooth_l1_loss(pred_loc, gt_loc, positive_weight, sigma)
    normalized /= n_counted
    return normalized