From 681aab7bc972899b27be4a71876914aea1219513 Mon Sep 17 00:00:00 2001 From: xingxy <775415344@qq.com> Date: Fri, 24 Dec 2021 22:52:00 +0800 Subject: [PATCH] Update g2pM.py According to https://pytorch.org/docs/1.7.1/generated/torch.nn.LSTM.html?highlight=lstm#torch.nn.LSTM `i_t` and `f_t` are calculated incorrectly --- g2pM/g2pM.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/g2pM/g2pM.py b/g2pM/g2pM.py index 30f71c4..3824e03 100644 --- a/g2pM/g2pM.py +++ b/g2pM/g2pM.py @@ -85,10 +85,16 @@ def fw_lstm_cell(self, inputs, init_states=None): if_hh = ifgo_hh[:, :ifgo_hh.shape[-1] * 2 // 4] go_hh = ifgo_hh[:, ifgo_hh.shape[-1] * 2 // 4:] - if_gate = self.sigmoid(if_ih + if_hh) + # if_gate = self.sigmoid(if_ih + if_hh) + # + # i, f = if_gate[:, :if_gate.shape[-1] // + # 2], if_gate[:, if_gate.shape[-1] // 2:] - i, f = if_gate[:, :if_gate.shape[-1] // - 2], if_gate[:, if_gate.shape[-1] // 2:] + i_ih, f_ih = if_ih[:, :if_ih.shape[-1]//2], if_ih[:, if_ih.shape[-1]//2:] + i_hh, f_hh = if_hh[:, :if_hh.shape[-1]//2], if_hh[:, if_hh.shape[-1]//2:] + + i = self.sigmoid(i_ih + i_hh) + f = self.sigmoid(f_ih + f_hh) g_ih, o_ih = go_ih[:, :go_ih.shape[-1] // 2], go_ih[:, go_ih.shape[-1] // 2:]