-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
global_attention.py
221 lines (176 loc) · 7.18 KB
/
global_attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
""" Global attention modules (Luong / Bahdanau) """
import torch
import torch.nn as nn
from onmt.utils.misc import aeq, sequence_mask
# This class is mainly used by decoder.py for RNNs but also
# by the CNN / transformer decoder when copy attention is used
# CNN has its own attention mechanism ConvMultiStepAttention
# Transformer has its own MultiHeadedAttention
class GlobalAttention(nn.Module):
    r"""
    Global attention takes a matrix and a query vector. It
    then computes a parameterized convex combination of the matrix
    based on the input query.

    Constructs a unit mapping a query `q` of size `dim`
    and a source matrix `H` of size `n x dim`, to an output
    of size `dim`.

    .. mermaid::

       graph BT
          A[Query]
          subgraph RNN
            C[H 1]
            D[H 2]
            E[H N]
          end
          F[Attn]
          G[Output]
          A --> F
          C --> F
          D --> F
          E --> F
          C -.-> G
          D -.-> G
          E -.-> G
          F --> G

    All models compute the output as
    :math:`c = \sum_{j=1}^{SeqLength} a_j H_j` where
    :math:`a_j` is the softmax of a score function.
    They then apply a projection layer to the concatenation [c, q]
    (followed by a tanh for the `dot` and `general` variants).

    However they differ on how they compute the attention score.

    * Luong Attention (dot, general):
       * dot: :math:`score(H_j, q) = H_j^T q`
       * general: :math:`score(H_j, q) = H_j^T W_a q`

    * Bahdanau Attention (mlp):
       * :math:`score(H_j, q) = v_a^T \tanh(W_a q + U_a h_j)`

    Args:
       dim (int): dimensionality of query and key
       coverage (bool): use coverage term (creates the `linear_cover`
           layer that `forward` needs when it is given a coverage vector)
       attn_type (str): type of attention to use, options [dot,general,mlp]
    """

    def __init__(self, dim, coverage=False, attn_type="dot"):
        super(GlobalAttention, self).__init__()

        self.dim = dim
        self.attn_type = attn_type
        assert (self.attn_type in ["dot", "general", "mlp"]), (
            "Please select a valid attention type.")

        # Score-function parameters. `dot` has none.
        if self.attn_type == "general":
            self.linear_in = nn.Linear(dim, dim, bias=False)
        elif self.attn_type == "mlp":
            self.linear_context = nn.Linear(dim, dim, bias=False)
            self.linear_query = nn.Linear(dim, dim, bias=True)
            self.v = nn.Linear(dim, 1, bias=False)

        # mlp wants it with bias
        out_bias = self.attn_type == "mlp"
        self.linear_out = nn.Linear(dim * 2, dim, bias=out_bias)

        self.softmax = nn.Softmax(dim=-1)
        self.tanh = nn.Tanh()

        if coverage:
            self.linear_cover = nn.Linear(1, dim, bias=False)

    def score(self, h_t, h_s):
        """
        Compute the unnormalized attention score of every query against
        every source position.

        Args:
          h_t (`FloatTensor`): sequence of queries `[batch x tgt_len x dim]`
          h_s (`FloatTensor`): sequence of sources `[batch x src_len x dim]`

        Returns:
          :obj:`FloatTensor`:
           raw attention scores (unnormalized) for each src index
          `[batch x tgt_len x src_len]`
        """
        # Check input sizes
        src_batch, src_len, src_dim = h_s.size()
        tgt_batch, tgt_len, tgt_dim = h_t.size()
        aeq(src_batch, tgt_batch)
        aeq(src_dim, tgt_dim)
        aeq(self.dim, src_dim)

        if self.attn_type in ["general", "dot"]:
            if self.attn_type == "general":
                # nn.Linear acts on the last dimension, so no flattening
                # is needed (and non-contiguous queries are accepted).
                h_t = self.linear_in(h_t)
            # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len)
            return torch.bmm(h_t, h_s.transpose(1, 2))

        # mlp: broadcast the projected queries against the projected
        # sources instead of materialising explicit expand() copies.
        wq = self.linear_query(h_t).unsqueeze(2)    # (batch, t_len, 1, d)
        uh = self.linear_context(h_s).unsqueeze(1)  # (batch, 1, s_len, d)
        wquh = self.tanh(wq + uh)                   # (batch, t_len, s_len, d)

        # (batch, t_len, s_len, 1) --> (batch, t_len, s_len)
        return self.v(wquh).squeeze(-1)

    def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
        """
        Attend over `memory_bank` with the queries in `source`.

        Args:
          source (`FloatTensor`): query vectors `[batch x tgt_len x dim]`,
              or `[batch x dim]` for a single decoding step
          memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
          memory_lengths (`LongTensor`): the source context lengths `[batch]`;
              when given, positions past each length are excluded from
              the softmax
          coverage (`FloatTensor`): optional coverage vector
              `[batch x src_len]`; requires the module to have been
              constructed with `coverage=True`

        Returns:
          (`FloatTensor`, `FloatTensor`):

          * Computed vector `[tgt_len x batch x dim]`
            (`[batch x dim]` for one-step input)
          * Attention distributions for each query
            `[tgt_len x batch x src_len]`
            (`[batch x src_len]` for one-step input)
        """
        # one step input
        one_step = source.dim() == 2
        if one_step:
            source = source.unsqueeze(1)

        batch, source_l, dim = memory_bank.size()
        batch_, target_l, dim_ = source.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        aeq(self.dim, dim)

        if coverage is not None:
            batch_, source_l_ = coverage.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)

            # (batch, s_len) --> (batch, s_len, 1) --> (batch, s_len, d)
            cover = self.linear_cover(coverage.unsqueeze(-1))
            # Out-of-place on purpose: an in-place `+=` would overwrite
            # the caller's memory bank (and any graph that still needs it).
            memory_bank = self.tanh(memory_bank + cover)

        # compute attention scores, as in Luong et al.
        align = self.score(source, memory_bank)

        if memory_lengths is not None:
            mask = sequence_mask(memory_lengths, max_len=align.size(-1))
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            # `mask == 0` selects the same positions the former `1 - mask`
            # did, but also works when the mask is a bool tensor (bool
            # tensors do not support the `-` operator).
            align.masked_fill_(mask == 0, -float('inf'))

        # Softmax over the source dimension to normalize attention weights
        align_vectors = self.softmax(align)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, memory_bank)

        # concatenate [c, q] and project back to `dim`
        attn_h = self.linear_out(torch.cat([c, source], 2))
        if self.attn_type in ["general", "dot"]:
            attn_h = self.tanh(attn_h)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

            # Check output sizes
            batch_, dim_ = attn_h.size()
            aeq(batch, batch_)
            aeq(dim, dim_)
            batch_, source_l_ = align_vectors.size()
            aeq(batch, batch_)
            aeq(source_l, source_l_)
        else:
            # Return time-major tensors.
            attn_h = attn_h.transpose(0, 1).contiguous()
            align_vectors = align_vectors.transpose(0, 1).contiguous()

            # Check output sizes
            target_l_, batch_, dim_ = attn_h.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(dim, dim_)
            target_l_, batch_, source_l_ = align_vectors.size()
            aeq(target_l, target_l_)
            aeq(batch, batch_)
            aeq(source_l, source_l_)

        return attn_h, align_vectors