import torch
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self, qrnn_layer, n_layers, kernel_size,
                 hidden_size, emb_size, src_vocab_size):
        super(Encoder, self).__init__()
        # Initialize the source embedding
        self.embedding = nn.Embedding(src_vocab_size, emb_size)
        # Stack n_layers QRNN layers: the first consumes embeddings,
        # the rest consume the previous layer's hidden states
        layers = []
        for layer_idx in range(n_layers):
            input_size = emb_size if layer_idx == 0 else hidden_size
            layers.append(qrnn_layer(input_size, hidden_size, kernel_size, False))
        self.layers = nn.ModuleList(layers)
    def forward(self, inputs, input_len):
        # inputs: [batch_size x length] LongTensor of token ids
        # input_len: [batch_size] LongTensor of true sequence lengths
        # h: [batch_size x length x emb_size]
        h = self.embedding(inputs)
        cell_states, hidden_states = [], []
        for layer in self.layers:
            c, h = layer(h)  # c, h: [batch_size x length x hidden_size]
            # Per-timestep indices, built on the same device as h
            time = torch.arange(h.size(1), device=h.device).unsqueeze(-1).expand_as(h)
            # Mask to support variable sequence lengths: zero out every
            # position at or beyond each sequence's true length
            # TODO: use .masked_fill()
            mask = (input_len.unsqueeze(-1).unsqueeze(-1) > time).float()
            h = h * mask
            # c_last, h_last: [batch_size x hidden_size], the states at the
            # last valid timestep of each sequence
            batch_idx = torch.arange(inputs.size(0), device=inputs.device)
            c_last = c[batch_idx, input_len - 1, :]
            h_last = h[batch_idx, input_len - 1, :]
            cell_states.append(c_last)
            hidden_states.append((h_last, h))
        # Return lists of the cell states and hidden states of each layer
        return cell_states, hidden_states
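

# A minimal sketch of the .masked_fill() alternative the TODO above mentions,
# assuming h: [batch_size x length x hidden_size] and input_len: [batch_size];
# _mask_padded is a hypothetical helper, equivalent to the multiplicative mask
# in Encoder.forward.
def _mask_padded(h, input_len):
    time = torch.arange(h.size(1), device=h.device)     # [length]
    pad = time.unsqueeze(0) >= input_len.unsqueeze(-1)  # [batch_size x length]
    return h.masked_fill(pad.unsqueeze(-1), 0.0)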


class Decoder(nn.Module):
    def __init__(self, qrnn_layer, n_layers, kernel_size,
                 hidden_size, emb_size, tgt_vocab_size):
        super(Decoder, self).__init__()
        # Initialize the target embedding
        self.embedding = nn.Embedding(tgt_vocab_size, emb_size)
        # Stack n_layers QRNN layers; only the last layer attends over the
        # encoder memories
        layers = []
        for layer_idx in range(n_layers):
            input_size = emb_size if layer_idx == 0 else hidden_size
            use_attn = (layer_idx == n_layers - 1)
            layers.append(qrnn_layer(input_size, hidden_size, kernel_size, use_attn))
        self.layers = nn.ModuleList(layers)

    def forward(self, inputs, init_states, memories):
        # One memory (from the encoder) per decoder layer
        assert len(self.layers) == len(memories)
        cell_states, hidden_states = [], []
        # h: [batch_size x length x emb_size]
        h = self.embedding(inputs)
        for layer_idx, layer in enumerate(self.layers):
            state = None if init_states is None else init_states[layer_idx]
            memory = memories[layer_idx]
            c, h = layer(h, state, memory)
            cell_states.append(c)
            hidden_states.append(h)
        # Each state has shape [batch_size x length x hidden_size]
        # Return lists of the cell states and hidden states of each layer
        return cell_states, hidden_states


class QRNNModel(nn.Module):
    def __init__(self, qrnn_layer, n_layers, kernel_size, hidden_size,
                 emb_size, src_vocab_size, tgt_vocab_size):
        super(QRNNModel, self).__init__()
        self.encoder = Encoder(qrnn_layer, n_layers, kernel_size, hidden_size,
                               emb_size, src_vocab_size)
        self.decoder = Decoder(qrnn_layer, n_layers, kernel_size, hidden_size,
                               emb_size, tgt_vocab_size)
        # Projects decoder hidden states onto the target vocabulary
        self.proj_linear = nn.Linear(hidden_size, tgt_vocab_size)

    def encode(self, inputs, input_len):
        return self.encoder(inputs, input_len)

    def decode(self, inputs, init_states, memories):
        cell_states, hidden_states = self.decoder(inputs, init_states, memories)
        # Project the hidden states of the last layer into logits:
        # first reshape to [batch_size * length x hidden_size];
        # after projection: [batch_size * length x tgt_vocab_size]
        h_last = hidden_states[-1]
        return cell_states, self.proj_linear(h_last.view(-1, h_last.size(-1)))

    def forward(self, enc_inputs, enc_len, dec_inputs):
        # Encode the source inputs; the encoder's last cell states seed the
        # decoder's initial states, and its hidden states serve as memories
        init_states, memories = self.encode(enc_inputs, enc_len)
        # logits: [batch_size * length x tgt_vocab_size]
        _, logits = self.decode(dec_inputs, init_states, memories)
        return logits
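

# A minimal usage sketch, assuming qrnn_layer is a class with the signature
# qrnn_layer(input_size, hidden_size, kernel_size, use_attn) whose forward
# returns (c, h); _StubQRNNLayer is a hypothetical stand-in that only mimics
# those shapes so the model's tensor flow can be exercised end to end.
if __name__ == "__main__":
    class _StubQRNNLayer(nn.Module):
        def __init__(self, input_size, hidden_size, kernel_size, use_attn):
            super(_StubQRNNLayer, self).__init__()
            self.linear = nn.Linear(input_size, hidden_size)

        def forward(self, x, state=None, memory=None):
            h = self.linear(x)
            return h, h  # (c, h), each [batch_size x length x hidden_size]

    model = QRNNModel(_StubQRNNLayer, n_layers=2, kernel_size=3,
                      hidden_size=16, emb_size=8,
                      src_vocab_size=100, tgt_vocab_size=100)
    src = torch.randint(0, 100, (4, 10))              # [batch_size x src_length]
    src_len = torch.full((4,), 10, dtype=torch.long)  # every sequence at full length
    tgt = torch.randint(0, 100, (4, 7))               # [batch_size x tgt_length]
    logits = model(src, src_len, tgt)                 # [batch_size * tgt_length x tgt_vocab_size]
    # Dummy loss, only to show that the flattened logits line up with
    # flattened targets (real training would shift the targets)
    loss = nn.functional.cross_entropy(logits, tgt.reshape(-1))
    print(logits.shape, loss.item())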