diff --git a/keras/api/golden/v1/tensorflow.keras.layers.-attention.pbtxt b/keras/api/golden/v1/tensorflow.keras.layers.-attention.pbtxt index 711d4e1ac6c..5a5fc4be0cd 100644 --- a/keras/api/golden/v1/tensorflow.keras.layers.-attention.pbtxt +++ b/keras/api/golden/v1/tensorflow.keras.layers.-attention.pbtxt @@ -131,7 +131,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'use_scale\', \'score_mode\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'dot\'], " + argspec: "args=[\'self\', \'use_scale\'], varargs=None, keywords=kwargs, defaults=[\'False\'], " } member_method { name: "add_loss" diff --git a/keras/api/golden/v2/tensorflow.keras.layers.-attention.pbtxt b/keras/api/golden/v2/tensorflow.keras.layers.-attention.pbtxt index 711d4e1ac6c..5a5fc4be0cd 100644 --- a/keras/api/golden/v2/tensorflow.keras.layers.-attention.pbtxt +++ b/keras/api/golden/v2/tensorflow.keras.layers.-attention.pbtxt @@ -131,7 +131,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'use_scale\', \'score_mode\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'dot\'], " + argspec: "args=[\'self\', \'use_scale\'], varargs=None, keywords=kwargs, defaults=[\'False\'], " } member_method { name: "add_loss" diff --git a/keras/layers/dense_attention.py b/keras/layers/dense_attention.py index eb555ba2b5e..8c9ce1b52c6 100644 --- a/keras/layers/dense_attention.py +++ b/keras/layers/dense_attention.py @@ -242,10 +242,6 @@ class Attention(BaseDenseAttention): Defaults to `False`. dropout: Float between 0 and 1. Fraction of the units to drop for the attention scores. Defaults to 0.0. - score_mode: Function to use to compute attention scores, one of - `{"dot", "concat"}`. `"dot"` refers to the dot product between the query - and key vectors. `"concat"` refers to the hyperbolic tangent of the - concatenation of the query and key vectors. Call Args: @@ -323,16 +319,12 @@ class Attention(BaseDenseAttention): ``` """ - def __init__(self, use_scale=False, score_mode='dot', **kwargs): + def __init__(self, use_scale=False, **kwargs): super(Attention, self).__init__(**kwargs) self.use_scale = use_scale - self.score_mode = score_mode - if self.score_mode not in ['dot', 'concat']: - raise ValueError(f'Received: score_mode={score_mode}. Acceptable values ' - 'are: ["dot", "concat"]') def build(self, input_shape): - """Creates variable when `use_scale` is True or `score_mode` is `concat`.""" + """Creates scale variable if use_scale==True.""" if self.use_scale: self.scale = self.add_weight( name='scale', @@ -342,15 +334,6 @@ def build(self, input_shape): trainable=True) else: self.scale = None - if self.score_mode == 'concat': - self.concat_score_weight = self.add_weight( - name='concat_score_weight', - shape=(), - initializer='ones', - dtype=self.dtype, - trainable=True) - else: - self.concat_score_weight = None super(Attention, self).build(input_shape) def _calculate_scores(self, query, key): @@ -362,27 +345,13 @@ def _calculate_scores(self, query, key): Returns: Tensor of shape `[batch_size, Tq, Tv]`. """ - if self.score_mode == 'dot': - scores = tf.matmul(query, key, transpose_b=True) - if self.scale is not None: - scores *= self.scale - elif self.score_mode == 'concat': - # Reshape tensors to enable broadcasting. - # Reshape into [batch_size, Tq, 1, dim]. - q_reshaped = tf.expand_dims(query, axis=-2) - # Reshape into [batch_size, 1, Tv, dim]. - k_reshaped = tf.expand_dims(key, axis=-3) - if self.scale is not None: - scores = self.concat_score_weight * tf.reduce_sum( - tf.tanh(self.scale * (q_reshaped + k_reshaped)), axis=-1) - else: - scores = self.concat_score_weight * tf.reduce_sum( - tf.tanh(q_reshaped + k_reshaped), axis=-1) - + scores = tf.matmul(query, key, transpose_b=True) + if self.scale is not None: + scores *= self.scale return scores def get_config(self): - config = {'use_scale': self.use_scale, 'score_mode': self.score_mode} + config = {'use_scale': self.use_scale} base_config = super(Attention, self).get_config() return dict(list(base_config.items()) + list(config.items())) diff --git a/keras/layers/dense_attention_test.py b/keras/layers/dense_attention_test.py index 51bf2a80410..f54bbae2b25 100644 --- a/keras/layers/dense_attention_test.py +++ b/keras/layers/dense_attention_test.py @@ -204,31 +204,6 @@ def test_calculate_scores_multi_dim(self): dtype=np.float32) self.assertAllClose(expected, actual) - def test_calculate_scores_multi_dim_concat(self): - # Query tensor of shape [1, 2, 4] - q = np.array([[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32) - # Key tensor of shape [1, 3, 4] - k = np.array( - [[[1.5, 1.6, 1.7, 1.8], [2.5, 2.6, 2.7, 2.8], [3.5, 3.6, 3.7, 3.8]]], - dtype=np.float32) - attention_layer = dense_attention.Attention(score_mode='concat') - attention_layer.concat_score_weight = 1 - attention_layer.build(input_shape=([1, 2, 4], [1, 3, 4])) - actual = keras.backend.get_value( - attention_layer._calculate_scores(query=q, key=k)) - - # pylint:disable=line-too-long - # expected000 = tanh(1.+1.5) + tanh(1.1+1.6) + tanh(1.2+1.7) + tanh(1.3+1.8) = 3.96753427840 - # expected001 = tanh(1.+2.5) + tanh(1.1+2.6) + tanh(1.2+2.7) + tanh(1.3+2.8) = 3.99558784825 - # expected002 = tanh(1.+3.5) + tanh(1.1+3.6) + tanh(1.2+3.7) + tanh(1.3+3.8) = 3.99940254147 - # expected010 = tanh(2.+1.5) + tanh(2.1+1.6) + tanh(2.2+1.7) + tanh(2.3+1.8) = 3.99558784825 - # expected011 = tanh(2.+2.5) + tanh(2.1+2.6) + tanh(2.2+2.7) + tanh(2.3+2.8) = 3.99940254147 - # expected012 = tanh(2.+3.5) + tanh(2.1+3.6) + tanh(2.2+3.7) + tanh(2.3+3.8) = 3.99991913657 - expected = np.array([[[3.96753427840, 3.99558784825, 3.99940254147], - [3.99558784825, 3.99940254147, 3.99991913657]]], - dtype=np.float32) - self.assertAllClose(expected, actual) - def test_calculate_scores_one_dim_batch_size_two(self): # Query tensor of shape [2, 1, 1] q = np.array([[[1.1]], [[2.1]]], dtype=np.float32) @@ -260,25 +235,6 @@ def test_calculate_scores_one_dim_with_scale(self): expected = np.array([[[-3.52]]], dtype=np.float32) self.assertAllClose(expected, actual) - def test_calculate_scores_one_dim_with_scale_concat(self): - """Tests that scores are multiplied by scale.""" - # Query tensor of shape [1, 1, 1] - q = np.array([[[1.1]]], dtype=np.float32) - # Key tensor of shape [1, 1, 1] - k = np.array([[[1.6]]], dtype=np.float32) - attention_layer = dense_attention.Attention( - use_scale=True, score_mode='concat') - attention_layer.concat_score_weight = 1 - attention_layer.build(input_shape=([1, 1, 1], [1, 1, 1])) - attention_layer.scale = 2. - actual = keras.backend.get_value( - attention_layer._calculate_scores(query=q, key=k)) - - # Expected tensor of shape [1, 1, 1]. - # expected000 = tanh(2*(1.1+1.6)) = 0.9999592018254402 - expected = np.array([[[0.999959202]]], dtype=np.float32) - self.assertAllClose(expected, actual) - def test_shape(self): # Query tensor of shape [1, 2, 4] q = np.array([[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32) @@ -294,22 +250,6 @@ def test_shape(self): expected_shape = [1, 2, 4] self.assertAllEqual(expected_shape, tf.shape(actual)) - def test_shape_concat(self): - # Query tensor of shape [1, 2, 4] - q = np.array([[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32) - # Value tensor of shape [1, 3, 4] - v = np.array( - [[[1.5, 1.6, 1.7, 1.8], [2.5, 2.6, 2.7, 2.8], [3.5, 3.6, 3.7, 3.8]]], - dtype=np.float32) - # Value mask tensor of shape [1, 3] - v_mask = np.array([[True, True, False]], dtype=np.bool_) - attention_layer = dense_attention.Attention(score_mode='concat') - attention_layer.concat_score_weight = 1 - actual = attention_layer([q, v], mask=[None, v_mask]) - - expected_shape = [1, 2, 4] - self.assertAllEqual(expected_shape, tf.shape(actual)) - def test_shape_with_key(self): # Query tensor of shape [1, 2, 4] q = np.array([[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32) @@ -329,26 +269,6 @@ def test_shape_with_key(self): expected_shape = [1, 2, 4] self.assertAllEqual(expected_shape, tf.shape(actual)) - def test_shape_with_key_concat(self): - # Query tensor of shape [1, 2, 4] - q = np.array([[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32) - # Value tensor of shape [1, 3, 4] - v = np.array( - [[[1.5, 1.6, 1.7, 1.8], [2.5, 2.6, 2.7, 2.8], [3.5, 3.6, 3.7, 3.8]]], - dtype=np.float32) - # Key tensor of shape [1, 3, 4] - k = np.array( - [[[1.5, 1.6, 1.7, 1.8], [2.5, 2.6, 2.7, 2.8], [3.5, 3.6, 3.7, 3.8]]], - dtype=np.float32) - # Value mask tensor of shape [1, 3] - v_mask = np.array([[True, True, False]], dtype=np.bool_) - attention_layer = dense_attention.Attention(score_mode='concat') - attention_layer.concat_score_weight = 1 - actual = attention_layer([q, v, k], mask=[None, v_mask]) - - expected_shape = [1, 2, 4] - self.assertAllEqual(expected_shape, tf.shape(actual)) - def test_multi_dim(self): # Query tensor of shape [1, 1, 1] q = np.array([[[1.1]]], dtype=np.float32)