-
Notifications
You must be signed in to change notification settings - Fork 19
/
nn_models_sa.py
80 lines (68 loc) · 3.1 KB
/
nn_models_sa.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import chainer.functions as F
from chainer.link import Chain
from fgnt.chainer_extensions.binary_cross_entropy import binary_cross_entropy
from fgnt.chainer_extensions.links.sequence_linear import SequenceLinear
from fgnt.chainer_extensions.links.sequence_lstms import SequenceBLSTM
from fgnt.chainer_extensions.mse import mean_squared_error
# NOTE(review): The module-level string literal below is evaluated and
# discarded; the class inside it is never defined. It preserves an earlier
# MaskEstimator variant whose loss is the binary cross-entropy between the
# predicted masks and mask targets IBM_N / IBM_X (presumably ideal binary
# masks — confirm). In the code shown, it is the only reference to the
# ``binary_cross_entropy`` import. Consider deleting it (version control keeps
# the history) or reviving it as a real subclass.
'''
class MaskEstimator(Chain):
def _propagate(self, Y, dropout=0.):
raise NotImplemented
def calc_masks(self, Y, dropout=0.):
N_mask, X_mask = self._propagate(Y, dropout)
return N_mask, X_mask
def train_and_cv(self, Y, IBM_N, IBM_X, dropout=0.):
N_mask_hat, X_mask_hat = self._propagate(Y, dropout)
loss_X = binary_cross_entropy(X_mask_hat, IBM_X)
loss_N = binary_cross_entropy(N_mask_hat, IBM_N)
loss = (loss_X + loss_N) / 2
return loss
'''
class MaskEstimator(Chain):
    """Base class for networks that estimate a noise mask and a speech mask.

    Subclasses implement ``_propagate`` and return a ``(N_mask, X_mask)``
    pair. This class builds inference (``calc_masks``) and a training /
    cross-validation loss (``train_and_cv``) on top of that method.
    """

    def _propagate(self, Y, dropout=0.):
        """Run the network on ``Y`` and return ``(N_mask, X_mask)``.

        Must be overridden by subclasses.

        Args:
            Y: network input.
            dropout: dropout ratio to use inside the network.

        Raises:
            NotImplementedError: always, in this base class.
        """
        # ``NotImplemented`` is a sentinel for rich comparisons, not an
        # exception class; ``raise NotImplemented`` fails with an unrelated
        # TypeError. Raise the proper exception so the message is meaningful.
        raise NotImplementedError(
            '{} must implement _propagate'.format(type(self).__name__))

    def calc_masks(self, Y, dropout=0.):
        """Return the estimated ``(N_mask, X_mask)`` for input ``Y``.

        Args:
            Y: network input.
            dropout: dropout ratio forwarded to ``_propagate``.

        Returns:
            Tuple ``(N_mask, X_mask)``: the noise mask and the speech mask.
        """
        N_mask, X_mask = self._propagate(Y, dropout)
        return N_mask, X_mask

    def train_and_cv(self, Y, SN, SX, dropout=0.):
        """Compute the loss used for training and cross-validation.

        Each estimated mask is multiplied element-wise with ``Y``. The product
        is compared with its target via ``mean_squared_error``. The returned
        loss is the mean of the speech and noise terms.

        Args:
            Y: network input; also the signal the masks are applied to.
                NOTE(review): presumably a magnitude spectrogram — confirm.
            SN: target for the noise-masked input (``N_mask * Y``).
            SX: target for the speech-masked input (``X_mask * Y``).
            dropout: dropout ratio forwarded to ``_propagate``.

        Returns:
            ``(loss_X + loss_N) / 2``, the average of the two
            ``mean_squared_error`` results.
        """
        N_mask_hat, X_mask_hat = self._propagate(Y, dropout)
        loss_X = mean_squared_error(X_mask_hat * Y, SX)
        loss_N = mean_squared_error(N_mask_hat * Y, SN)
        loss = (loss_X + loss_N) / 2
        return loss
class BLSTMMaskEstimator(MaskEstimator):
    """Recurrent mask estimator.

    Architecture: one ``SequenceBLSTM`` layer, then two ``SequenceLinear``
    layers each followed by ``F.clipped_relu``, then two ``SequenceLinear``
    output layers each followed by ``F.sigmoid`` (noise mask and speech
    mask). All layer sizes are hard-coded (513 / 256). The 513 is presumably
    the number of frequency bins of the input — confirm.
    """

    def __init__(self):
        # The links are constructed in the same order as the keyword
        # arguments are evaluated (left to right). This keeps any random
        # weight initialisation drawn during construction in a fixed order.
        super(BLSTMMaskEstimator, self).__init__(
            blstm_layer=SequenceBLSTM(513, 256, normalized=True),
            relu_1=SequenceLinear(256, 513, normalized=True),
            relu_2=SequenceLinear(513, 513, normalized=True),
            noise_mask_estimate=SequenceLinear(513, 513, normalized=True),
            speech_mask_estimate=SequenceLinear(513, 513, normalized=True),
        )

    def _propagate(self, Y, dropout=0.):
        """Return ``(N_mask, X_mask)`` estimated from ``Y``.

        ``dropout`` is forwarded to the BLSTM and to both hidden layers only.
        The two output layers are called without it.
        """
        h = self.blstm_layer(Y, dropout=dropout)
        for hidden_layer in (self.relu_1, self.relu_2):
            h = F.clipped_relu(hidden_layer(h, dropout=dropout))
        noise_mask = F.sigmoid(self.noise_mask_estimate(h))
        speech_mask = F.sigmoid(self.speech_mask_estimate(h))
        return noise_mask, speech_mask
class SimpleFWMaskEstimator(MaskEstimator):
    """Feed-forward mask estimator.

    Architecture: one ``SequenceLinear`` hidden layer (513 -> 1024) followed
    by ``F.clipped_relu``, then two ``SequenceLinear`` output layers
    (1024 -> 513) each followed by ``F.sigmoid`` (noise mask and speech
    mask).
    """

    def __init__(self):
        # Keyword arguments are evaluated left to right, so the links are
        # created in the same order as listed here.
        super(SimpleFWMaskEstimator, self).__init__(
            relu_1=SequenceLinear(513, 1024, normalized=True),
            noise_mask_estimate=SequenceLinear(1024, 513, normalized=True),
            speech_mask_estimate=SequenceLinear(1024, 513, normalized=True),
        )

    def _propagate(self, Y, dropout=0.):
        """Return ``(N_mask, X_mask)`` estimated from ``Y``.

        ``dropout`` is forwarded to the hidden layer only.
        """
        hidden = F.clipped_relu(self.relu_1(Y, dropout=dropout))
        # Tuple elements evaluate left to right: noise head first, then speech.
        return (F.sigmoid(self.noise_mask_estimate(hidden)),
                F.sigmoid(self.speech_mask_estimate(hidden)))