PR #15867: add scoring methods in Luong-style attention
Imported from GitHub PR #15867

Luong-style attention uses three types of scoring methods, namely dot, general, and concat. This can be found on the 3rd page of...

PiperOrigin-RevId: 422950111
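For reference, the three Luong score functions can be sketched in a few lines of NumPy. This is an illustrative sketch only, not code from this change; the weights `W_g`, `W_c`, and `v_a` are hypothetical stand-ins for the paper's learned parameters.

# Illustrative NumPy sketch of Luong's three score functions; the weight
# names W_g, W_c, and v_a are hypothetical stand-ins for learned parameters.
import numpy as np

dim = 4
rng = np.random.default_rng(0)
h_t = rng.normal(size=(dim,))          # decoder (query) state
h_s = rng.normal(size=(dim,))          # encoder (key) state
W_g = rng.normal(size=(dim, dim))      # weights for the "general" score
W_c = rng.normal(size=(dim, 2 * dim))  # weights for the "concat" score
v_a = rng.normal(size=(dim,))          # projection vector for "concat"

score_dot = h_t @ h_s                                           # dot
score_general = h_t @ W_g @ h_s                                 # general
score_concat = v_a @ np.tanh(W_c @ np.concatenate([h_t, h_s]))  # concat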
qlzh727 authored and tensorflower-gardener committed Jan 20, 2022
1 parent 063aa1d commit 571b796
Showing 4 changed files with 8 additions and 119 deletions.
@@ -131,7 +131,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'use_scale\', \'score_mode\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'dot\'], "
+    argspec: "args=[\'self\', \'use_scale\'], varargs=None, keywords=kwargs, defaults=[\'False\'], "
   }
   member_method {
     name: "add_loss"
@@ -131,7 +131,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'use_scale\', \'score_mode\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'dot\'], "
+    argspec: "args=[\'self\', \'use_scale\'], varargs=None, keywords=kwargs, defaults=[\'False\'], "
   }
   member_method {
     name: "add_loss"
43 changes: 6 additions & 37 deletions keras/layers/dense_attention.py
@@ -242,10 +242,6 @@ class Attention(BaseDenseAttention):
       Defaults to `False`.
     dropout: Float between 0 and 1. Fraction of the units to drop for the
       attention scores. Defaults to 0.0.
-    score_mode: Function to use to compute attention scores, one of
-      `{"dot", "concat"}`. `"dot"` refers to the dot product between the query
-      and key vectors. `"concat"` refers to the hyperbolic tangent of the
-      concatenation of the query and key vectors.

   Call Args:
@@ -323,16 +319,12 @@ class Attention(BaseDenseAttention):
   ```
   """

-  def __init__(self, use_scale=False, score_mode='dot', **kwargs):
+  def __init__(self, use_scale=False, **kwargs):
     super(Attention, self).__init__(**kwargs)
     self.use_scale = use_scale
-    self.score_mode = score_mode
-    if self.score_mode not in ['dot', 'concat']:
-      raise ValueError(f'Received: score_mode={score_mode}. Acceptable values '
-                       'are: ["dot", "concat"]')

   def build(self, input_shape):
-    """Creates variable when `use_scale` is True or `score_mode` is `concat`."""
+    """Creates scale variable if use_scale==True."""
     if self.use_scale:
       self.scale = self.add_weight(
           name='scale',
@@ -342,15 +334,6 @@ def build(self, input_shape):
           trainable=True)
     else:
       self.scale = None
-    if self.score_mode == 'concat':
-      self.concat_score_weight = self.add_weight(
-          name='concat_score_weight',
-          shape=(),
-          initializer='ones',
-          dtype=self.dtype,
-          trainable=True)
-    else:
-      self.concat_score_weight = None
     super(Attention, self).build(input_shape)

   def _calculate_scores(self, query, key):
@@ -362,27 +345,13 @@ def _calculate_scores(self, query, key):
     Returns:
       Tensor of shape `[batch_size, Tq, Tv]`.
     """
-    if self.score_mode == 'dot':
-      scores = tf.matmul(query, key, transpose_b=True)
-      if self.scale is not None:
-        scores *= self.scale
-    elif self.score_mode == 'concat':
-      # Reshape tensors to enable broadcasting.
-      # Reshape into [batch_size, Tq, 1, dim].
-      q_reshaped = tf.expand_dims(query, axis=-2)
-      # Reshape into [batch_size, 1, Tv, dim].
-      k_reshaped = tf.expand_dims(key, axis=-3)
-      if self.scale is not None:
-        scores = self.concat_score_weight * tf.reduce_sum(
-            tf.tanh(self.scale * (q_reshaped + k_reshaped)), axis=-1)
-      else:
-        scores = self.concat_score_weight * tf.reduce_sum(
-            tf.tanh(q_reshaped + k_reshaped), axis=-1)
-
+    scores = tf.matmul(query, key, transpose_b=True)
+    if self.scale is not None:
+      scores *= self.scale
     return scores

   def get_config(self):
-    config = {'use_scale': self.use_scale, 'score_mode': self.score_mode}
+    config = {'use_scale': self.use_scale}
     base_config = super(Attention, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
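After this change, `keras.layers.Attention` computes only the (optionally scaled) dot-product score. A minimal usage sketch, with arbitrary example shapes:

# Minimal usage sketch of the dot-product Attention layer; shapes are
# arbitrary example values.
import numpy as np
import tensorflow as tf

query = tf.constant(np.random.rand(1, 2, 4), dtype=tf.float32)  # [batch, Tq, dim]
value = tf.constant(np.random.rand(1, 3, 4), dtype=tf.float32)  # [batch, Tv, dim]

layer = tf.keras.layers.Attention(use_scale=True)
output = layer([query, value])  # key defaults to value
print(output.shape)  # (1, 2, 4)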
80 changes: 0 additions & 80 deletions keras/layers/dense_attention_test.py
@@ -204,31 +204,6 @@ def test_calculate_scores_multi_dim(self):
         dtype=np.float32)
     self.assertAllClose(expected, actual)

-  def test_calculate_scores_multi_dim_concat(self):
-    # Query tensor of shape [1, 2, 4]
-    q = np.array([[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32)
-    # Key tensor of shape [1, 3, 4]
-    k = np.array(
-        [[[1.5, 1.6, 1.7, 1.8], [2.5, 2.6, 2.7, 2.8], [3.5, 3.6, 3.7, 3.8]]],
-        dtype=np.float32)
-    attention_layer = dense_attention.Attention(score_mode='concat')
-    attention_layer.concat_score_weight = 1
-    attention_layer.build(input_shape=([1, 2, 4], [1, 3, 4]))
-    actual = keras.backend.get_value(
-        attention_layer._calculate_scores(query=q, key=k))
-
-    # pylint:disable=line-too-long
-    # expected000 = tanh(1.+1.5) + tanh(1.1+1.6) + tanh(1.2+1.7) + tanh(1.3+1.8) = 3.96753427840
-    # expected001 = tanh(1.+2.5) + tanh(1.1+2.6) + tanh(1.2+2.7) + tanh(1.3+2.8) = 3.99558784825
-    # expected002 = tanh(1.+3.5) + tanh(1.1+3.6) + tanh(1.2+3.7) + tanh(1.3+3.8) = 3.99940254147
-    # expected010 = tanh(2.+1.5) + tanh(2.1+1.6) + tanh(2.2+1.7) + tanh(2.3+1.8) = 3.99558784825
-    # expected011 = tanh(2.+2.5) + tanh(2.1+2.6) + tanh(2.2+2.7) + tanh(2.3+2.8) = 3.99940254147
-    # expected012 = tanh(2.+3.5) + tanh(2.1+3.6) + tanh(2.2+3.7) + tanh(2.3+3.8) = 3.99991913657
-    expected = np.array([[[3.96753427840, 3.99558784825, 3.99940254147],
-                          [3.99558784825, 3.99940254147, 3.99991913657]]],
-                        dtype=np.float32)
-    self.assertAllClose(expected, actual)
-
   def test_calculate_scores_one_dim_batch_size_two(self):
     # Query tensor of shape [2, 1, 1]
     q = np.array([[[1.1]], [[2.1]]], dtype=np.float32)
@@ -260,25 +235,6 @@ def test_calculate_scores_one_dim_with_scale(self):
     expected = np.array([[[-3.52]]], dtype=np.float32)
     self.assertAllClose(expected, actual)

-  def test_calculate_scores_one_dim_with_scale_concat(self):
-    """Tests that scores are multiplied by scale."""
-    # Query tensor of shape [1, 1, 1]
-    q = np.array([[[1.1]]], dtype=np.float32)
-    # Key tensor of shape [1, 1, 1]
-    k = np.array([[[1.6]]], dtype=np.float32)
-    attention_layer = dense_attention.Attention(
-        use_scale=True, score_mode='concat')
-    attention_layer.concat_score_weight = 1
-    attention_layer.build(input_shape=([1, 1, 1], [1, 1, 1]))
-    attention_layer.scale = 2.
-    actual = keras.backend.get_value(
-        attention_layer._calculate_scores(query=q, key=k))
-
-    # Expected tensor of shape [1, 1, 1].
-    # expected000 = tanh(2*(1.1+1.6)) = 0.9999592018254402
-    expected = np.array([[[0.999959202]]], dtype=np.float32)
-    self.assertAllClose(expected, actual)
-
   def test_shape(self):
     # Query tensor of shape [1, 2, 4]
     q = np.array([[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32)
@@ -294,22 +250,6 @@ def test_shape(self):
     expected_shape = [1, 2, 4]
     self.assertAllEqual(expected_shape, tf.shape(actual))

-  def test_shape_concat(self):
-    # Query tensor of shape [1, 2, 4]
-    q = np.array([[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32)
-    # Value tensor of shape [1, 3, 4]
-    v = np.array(
-        [[[1.5, 1.6, 1.7, 1.8], [2.5, 2.6, 2.7, 2.8], [3.5, 3.6, 3.7, 3.8]]],
-        dtype=np.float32)
-    # Value mask tensor of shape [1, 3]
-    v_mask = np.array([[True, True, False]], dtype=np.bool_)
-    attention_layer = dense_attention.Attention(score_mode='concat')
-    attention_layer.concat_score_weight = 1
-    actual = attention_layer([q, v], mask=[None, v_mask])
-
-    expected_shape = [1, 2, 4]
-    self.assertAllEqual(expected_shape, tf.shape(actual))
-
   def test_shape_with_key(self):
     # Query tensor of shape [1, 2, 4]
     q = np.array([[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32)
@@ -329,26 +269,6 @@ def test_shape_with_key(self):
     expected_shape = [1, 2, 4]
     self.assertAllEqual(expected_shape, tf.shape(actual))

-  def test_shape_with_key_concat(self):
-    # Query tensor of shape [1, 2, 4]
-    q = np.array([[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32)
-    # Value tensor of shape [1, 3, 4]
-    v = np.array(
-        [[[1.5, 1.6, 1.7, 1.8], [2.5, 2.6, 2.7, 2.8], [3.5, 3.6, 3.7, 3.8]]],
-        dtype=np.float32)
-    # Key tensor of shape [1, 3, 4]
-    k = np.array(
-        [[[1.5, 1.6, 1.7, 1.8], [2.5, 2.6, 2.7, 2.8], [3.5, 3.6, 3.7, 3.8]]],
-        dtype=np.float32)
-    # Value mask tensor of shape [1, 3]
-    v_mask = np.array([[True, True, False]], dtype=np.bool_)
-    attention_layer = dense_attention.Attention(score_mode='concat')
-    attention_layer.concat_score_weight = 1
-    actual = attention_layer([q, v, k], mask=[None, v_mask])
-
-    expected_shape = [1, 2, 4]
-    self.assertAllEqual(expected_shape, tf.shape(actual))
-
   def test_multi_dim(self):
     # Query tensor of shape [1, 1, 1]
     q = np.array([[[1.1]]], dtype=np.float32)
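As a sanity check, the hand-computed expectations in the removed concat tests can be re-derived independently; a short NumPy sketch (not part of the commit):

# Re-derives the expected scores from the removed
# test_calculate_scores_multi_dim_concat: with concat_score_weight = 1 and
# no scale, score[i, j] = sum_d tanh(q[i, d] + k[j, d]).
import numpy as np

q = np.array([[1., 1.1, 1.2, 1.3],
              [2., 2.1, 2.2, 2.3]])   # [Tq=2, dim=4]
k = np.array([[1.5, 1.6, 1.7, 1.8],
              [2.5, 2.6, 2.7, 2.8],
              [3.5, 3.6, 3.7, 3.8]])  # [Tv=3, dim=4]

scores = np.tanh(q[:, None, :] + k[None, :, :]).sum(axis=-1)  # [Tq, Tv]
print(scores)  # [[3.96753 3.99559 3.99940]
               #  [3.99559 3.99940 3.99992]]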
