
[FEATURE] Add raw attention scores to the AttentionCell #951 #964

Merged: 5 commits, Oct 25, 2019
Changes from 2 commits
68 changes: 50 additions & 18 deletions src/gluonnlp/model/attention_cell.py
@@ -38,7 +38,9 @@ def _masked_softmax(F, att_score, mask, dtype):
         Shape (batch_size, query_length, memory_length)
     Returns
     -------
-    att_weights : Symborl or NDArray
+    att_weights : Symbol or NDArray
         Shape (batch_size, query_length, memory_length)
+    att_score : Symbol or NDArray
+        Shape (batch_size, query_length, memory_length)
     """
     if mask is not None:
@@ -58,7 +60,7 @@ def _masked_softmax(F, att_score, mask, dtype):
         att_weights = F.softmax(att_score, axis=-1) * mask
     else:
         att_weights = F.softmax(att_score, axis=-1)
-    return att_weights
+    return att_weights, att_score
 
 
 # TODO(sxjscience) In the future, we should support setting mask/att_weights as sparse tensors
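The widened return value of `_masked_softmax` is the root of the whole patch. A minimal NumPy sketch of the resulting contract (the fill value for masked positions and the handling in the collapsed lines above are assumptions, not the GluonNLP source):

```python
import numpy as np

def _softmax(x):
    # Numerically stable softmax over the last axis
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def masked_softmax(att_score, mask=None, neg=-1e18):
    # Sketch of the patched contract: return the raw (pre-softmax) scores
    # alongside the normalized weights. `neg` is an assumed fill value for
    # masked positions; the real code derives it from the `dtype` argument.
    if mask is not None:
        att_score = np.where(mask.astype(bool), att_score, neg)
        att_weights = _softmax(att_score) * mask
    else:
        att_weights = _softmax(att_score)
    return att_weights, att_score
```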
@@ -70,10 +72,19 @@ class AttentionCell(HybridBlock):
         cell = AttentionCell()
         out = cell(query, key, value, mask)
 
+    Parameters
+    ----------
+    prefix : str or None, default None
+        See document of `Block`.
+    params : str or None, default None
+        See document of `Block`.
+    return_attention_scores : bool, default False
+        If True, also return the raw attention scores.
     """
-    def __init__(self, prefix=None, params=None):
+    def __init__(self, prefix=None, params=None, return_attention_scores=False):
         self._dtype = np.float32
         super(AttentionCell, self).__init__(prefix=prefix, params=params)
+        self._return_attention_scores = return_attention_scores
 
     def cast(self, dtype):
         self._dtype = dtype
@@ -148,6 +159,8 @@ def __call__(self, query, key, value=None, mask=None):  # pylint: disable=arguments-differ
             Shape (batch_size, query_length, context_vec_dim)
         att_weights : Symbol or NDArray
             Attention weights. Shape (batch_size, query_length, memory_length)
+        [att_scores] : Symbol or NDArray
+            Attention scores. Shape (batch_size, query_length, memory_length)
         """
         return super(AttentionCell, self).__call__(query, key, value, mask)
 
@@ -160,9 +173,12 @@ def forward(self, query, key, value=None, mask=None):  # pylint: disable=arguments-differ
         return super(AttentionCell, self).forward(query, key, value, mask)
 
     def hybrid_forward(self, F, query, key, value, mask=None):  # pylint: disable=arguments-differ
-        att_weights = self._compute_weight(F, query, key, mask)
+        att_weights, att_scores = self._compute_weight(F, query, key, mask)
         context_vec = self._read_by_weight(F, att_weights, value)
-        return context_vec, att_weights
+        if self._return_attention_scores:
+            return context_vec, att_weights, att_scores
+        else:
+            return context_vec, att_weights
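With the flag threaded through `hybrid_forward`, any concrete cell now yields a three-tuple on request. A usage sketch with the patched `DotProductAttentionCell` (defined further down in this diff); shapes follow the docstrings above:

```python
import mxnet as mx
from gluonnlp.model.attention_cell import DotProductAttentionCell

cell = DotProductAttentionCell(return_attention_scores=True)
cell.initialize()

query = mx.nd.random.normal(shape=(2, 5, 8))  # (batch_size, query_length, dim)
key = mx.nd.random.normal(shape=(2, 7, 8))    # (batch_size, memory_length, dim)

# value defaults to key; with the flag set, the cell returns a 3-tuple
context_vec, att_weights, att_scores = cell(query, key)
print(att_scores.shape)  # (2, 5, 7): raw pre-softmax scores
```

Without the flag the call keeps returning the old two-tuple, so existing callers are unaffected.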


@@ -198,10 +214,14 @@ class MultiHeadAttentionCell(AttentionCell):
         See document of `Block`.
     params : str or None, default None
         See document of `Block`.
+    return_attention_scores : bool, default False
+        If True, also return the raw attention scores.
     """
     def __init__(self, base_cell, query_units, key_units, value_units, num_heads, use_bias=True,
-                 weight_initializer=None, bias_initializer='zeros', prefix=None, params=None):
-        super(MultiHeadAttentionCell, self).__init__(prefix=prefix, params=params)
+                 weight_initializer=None, bias_initializer='zeros', prefix=None, params=None,
+                 return_attention_scores=False):
+        super(MultiHeadAttentionCell, self).__init__(
+            prefix=prefix, params=params, return_attention_scores=return_attention_scores)
         self._base_cell = base_cell
         self._num_heads = num_heads
         self._use_bias = use_bias
@@ -244,6 +264,8 @@ def __call__(self, query, key, value=None, mask=None):
         att_weights : Symbol or NDArray
             Attention weights of multiple heads.
             Shape (batch_size, num_heads, query_length, memory_length)
+        [att_scores] : Symbol or NDArray
+            Attention scores. Shape (batch_size, num_heads, query_length, memory_length)
         """
         return super(MultiHeadAttentionCell, self).__call__(query, key, value, mask)
 
@@ -263,8 +285,10 @@ def _compute_weight(self, F, query, key, mask=None):
             mask = F.broadcast_axis(F.expand_dims(mask, axis=1),
                                     axis=1, size=self._num_heads)\
                        .reshape(shape=(-1, 0, 0), reverse=True)
-        att_weights = self._base_cell._compute_weight(F, query, key, mask)
-        return att_weights.reshape(shape=(-1, self._num_heads, 0, 0), reverse=True)
+        att_weights, att_scores = self._base_cell._compute_weight(F, query, key, mask)
+        att_scores = att_scores.reshape(shape=(-1, self._num_heads, 0, 0), reverse=True)
+        att_weights = att_weights.reshape(shape=(-1, self._num_heads, 0, 0), reverse=True)
+        return att_weights, att_scores

     def _read_by_weight(self, F, att_weights, value):
         att_weights = att_weights.reshape(shape=(-1, 0, 0), reverse=True)
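Because every `_compute_weight` now returns both tensors, the base cell of a multi-head wrapper does not need the flag itself; only the wrapper's own `return_attention_scores` decides what `__call__` exposes. A sketch under that reading of the diff:

```python
import mxnet as mx
from gluonnlp.model.attention_cell import (DotProductAttentionCell,
                                           MultiHeadAttentionCell)

base = DotProductAttentionCell(scaled=True)
cell = MultiHeadAttentionCell(base_cell=base, query_units=16, key_units=16,
                              value_units=16, num_heads=4,
                              return_attention_scores=True)
cell.initialize()

query = mx.nd.random.normal(shape=(2, 5, 8))
key = mx.nd.random.normal(shape=(2, 7, 8))
context_vec, att_weights, att_scores = cell(query, key)
print(att_scores.shape)  # (2, 4, 5, 7): raw scores per head
```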
@@ -311,10 +335,13 @@ class MLPAttentionCell(AttentionCell):
         See document of `Block`.
     params : ParameterDict or None, default None
         See document of `Block`.
+    return_attention_scores : bool, default False
+        If True, also return the raw attention scores.
     """
 
     def __init__(self, units, act=nn.Activation('tanh'), normalized=False, dropout=0.0,
-                 weight_initializer=None, bias_initializer='zeros', prefix=None, params=None):
+                 weight_initializer=None, bias_initializer='zeros', prefix=None, params=None,
+                 return_attention_scores=False):
         # Define a temporary class to implement the normalized version
         # TODO(sxjscience) Find a better solution
         class _NormalizedScoreProj(HybridBlock):
@@ -334,7 +361,8 @@ def hybrid_forward(self, F, x, g, v):  # pylint: disable=arguments-differ
                                flatten=False, name='fwd')
                 return out
 
-        super(MLPAttentionCell, self).__init__(prefix=prefix, params=params)
+        super(MLPAttentionCell, self).__init__(prefix=prefix, params=params,
+                                               return_attention_scores=return_attention_scores)
         self._units = units
         self._act = act
         self._normalized = normalized
@@ -366,8 +394,9 @@ def _compute_weight(self, F, query, key, mask=None):
                                          F.expand_dims(mapped_key, axis=1))
         mid_feat = self._act(mid_feat)
         att_score = self._attention_score(mid_feat).reshape(shape=(0, 0, 0))
-        att_weights = self._dropout_layer(_masked_softmax(F, att_score, mask, self._dtype))
-        return att_weights
+        att_weights, att_score = _masked_softmax(F, att_score, mask, self._dtype)
+        att_weights = self._dropout_layer(att_weights)
+        return att_weights, att_score
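Note the reordering: dropout is now applied to the weights only, after `_masked_softmax`, so the returned scores bypass the dropout layer entirely. The same pattern works for `MLPAttentionCell`; a short sketch, assuming the patch is applied:

```python
import mxnet as mx
from gluonnlp.model.attention_cell import MLPAttentionCell

cell = MLPAttentionCell(units=32, return_attention_scores=True)
cell.initialize()

query = mx.nd.random.normal(shape=(2, 5, 8))
key = mx.nd.random.normal(shape=(2, 7, 8))
context_vec, att_weights, att_scores = cell(query, key)
print(att_scores.shape)  # (2, 5, 7)
```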


@@ -426,11 +455,14 @@ class DotProductAttentionCell(AttentionCell):
         See document of `Block`.
     params : str or None, default None
         See document of `Block`.
+    return_attention_scores : bool, default False
+        If True, also return the raw attention scores.
     """
     def __init__(self, units=None, luong_style=False, scaled=True, normalized=False, use_bias=True,
                  dropout=0.0, weight_initializer=None, bias_initializer='zeros',
-                 prefix=None, params=None):
-        super(DotProductAttentionCell, self).__init__(prefix=prefix, params=params)
+                 prefix=None, params=None, return_attention_scores=False):
+        super(DotProductAttentionCell, self).__init__(
+            prefix=prefix, params=params, return_attention_scores=return_attention_scores)
         self._units = units
         self._scaled = scaled
         self._normalized = normalized
@@ -472,6 +504,6 @@ def _compute_weight(self, F, query, key, mask=None):
             query = F.contrib.div_sqrt_dim(query)
 
         att_score = F.batch_dot(query, key, transpose_b=True)
-
-        att_weights = self._dropout_layer(_masked_softmax(F, att_score, mask, self._dtype))
-        return att_weights
+        att_weights, att_score = _masked_softmax(F, att_score, mask, self._dtype)
+        att_weights = self._dropout_layer(att_weights)
+        return att_weights, att_score
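`DotProductAttentionCell._compute_weight` gets the same reordering, so masked positions can be inspected through the raw scores. A final sketch with an explicit mask; the semantics follow the visible code, where weights are multiplied by the mask, so a 0 entry zeroes that position:

```python
import mxnet as mx
from gluonnlp.model.attention_cell import DotProductAttentionCell

cell = DotProductAttentionCell(dropout=0.1, return_attention_scores=True)
cell.initialize()

query = mx.nd.random.normal(shape=(2, 5, 8))
key = mx.nd.random.normal(shape=(2, 7, 8))
# Mask out the last two memory positions for every query
mask = mx.nd.concat(mx.nd.ones((2, 5, 5)), mx.nd.zeros((2, 5, 2)), dim=2)

context_vec, att_weights, att_scores = cell(query, key, key, mask)
# att_weights are exactly zero at masked positions; att_scores expose the
# pre-softmax values (masked entries hold _masked_softmax's internal fill)
```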