From e0c8e1f0e478f48f01da4c269b798f3e0b94effb Mon Sep 17 00:00:00 2001
From: Guillaume Klein
Date: Fri, 8 Feb 2019 17:46:17 +0100
Subject: [PATCH] Add v2 Transformer layers (#324)

* Add v2 Transformer layers

* Fix Travis

* Fix Travis

* Fix Travis
---
 .travis.yml                          |   7 +-
 opennmt/v2/layers/__init__.py        |   1 +
 opennmt/v2/layers/transformer.py     | 129 +++++++++++++++++++++++++++
 opennmt/v2/tests/transformer_test.py |  68 ++++++++++++++
 setup.py                             |   1 +
 5 files changed, 203 insertions(+), 3 deletions(-)
 create mode 100644 opennmt/v2/layers/transformer.py
 create mode 100644 opennmt/v2/tests/transformer_test.py

diff --git a/.travis.yml b/.travis.yml
index ab43f2c5f..7fcb86066 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -25,18 +25,19 @@ matrix:
       env:
         - TF_VERSION="2.0"
       before_install:
-        - pip install tf-nightly-2.0-preview
+        - pip install tf-nightly-2.0-preview pylint==2.1.1
        - pip install -e .[tests]
       script:
         - nose2 -s opennmt/v2/tests
+        - pylint opennmt/v2/
     - python: "3.6"
       env:
         - TF_VERSION="$LATEST_TF_VERSION"
       install:
         - pip install pylint==2.1.1 wheel twine
       script:
-        - nose2
-        - pylint opennmt/
+        - nose2 -s opennmt/tests
+        - pylint --ignore=tests,v2 opennmt/
       after_success:
         - |
           if [[ -n $TRAVIS_TAG ]]; then
diff --git a/opennmt/v2/layers/__init__.py b/opennmt/v2/layers/__init__.py
index 52e10231f..2380eea08 100644
--- a/opennmt/v2/layers/__init__.py
+++ b/opennmt/v2/layers/__init__.py
@@ -1,3 +1,4 @@
 """Layers module."""
 
 from opennmt.v2.layers.common import LayerNorm
+from opennmt.v2.layers.transformer import FeedForwardNetwork, MultiHeadAttention
diff --git a/opennmt/v2/layers/transformer.py b/opennmt/v2/layers/transformer.py
new file mode 100644
index 000000000..6325398c5
--- /dev/null
+++ b/opennmt/v2/layers/transformer.py
@@ -0,0 +1,129 @@
+# pylint: disable=arguments-differ
+
+"""Define layers related to Google's Transformer model."""
+
+import tensorflow as tf
+
+from opennmt.layers.transformer import combine_heads, split_heads
+
+
+class FeedForwardNetwork(tf.keras.layers.Layer):
+  """Implements the Transformer's "Feed Forward" layer.
+
+  .. math::
+
+      ffn(x) = max(0, x*W_1 + b_1)*W_2 + b_2
+  """
+
+  def __init__(self, inner_dim, output_dim, dropout=0.1, **kwargs):
+    """Initializes this layer.
+
+    Args:
+      inner_dim: The number of units of the inner linear transformation.
+      output_dim: The number of units of the output linear transformation.
+      dropout: The probability to drop units from the inner transformation.
+      kwargs: Additional layer arguments.
+    """
+    super(FeedForwardNetwork, self).__init__(**kwargs)
+    self.inner = tf.keras.layers.Dense(inner_dim, activation=tf.nn.relu)
+    self.outer = tf.keras.layers.Dense(output_dim)
+    self.dropout = dropout
+
+  def call(self, inputs, training=None):
+    """Runs the layer."""
+    inner = self.inner(inputs)
+    if training:
+      inner = tf.nn.dropout(inner, rate=self.dropout)
+    return self.outer(inner)
+
+
+class MultiHeadAttention(tf.keras.layers.Layer):
+  """Computes the multi-head attention as described in
+  https://arxiv.org/abs/1706.03762.
+  """
+
+  def __init__(self,
+               num_heads,
+               num_units,
+               dropout=0.1,
+               return_attention=False,
+               **kwargs):
+    """Initializes this layer.
+
+    Args:
+      num_heads: The number of attention heads.
+      num_units: The number of hidden units.
+      dropout: The probability to drop units from the inputs.
+      return_attention: If ``True``, also return the attention weights of the
+        first head.
+      kwargs: Additional layer arguments.
+    """
+    super(MultiHeadAttention, self).__init__(**kwargs)
+    if num_units % num_heads != 0:
+      raise ValueError("Multi head attention requires that num_units is a"
+                       " multiple of %s" % num_heads)
+    self.num_heads = num_heads
+    self.num_units = num_units
+    self.linear_queries = tf.keras.layers.Dense(num_units)
+    self.linear_keys = tf.keras.layers.Dense(num_units)
+    self.linear_values = tf.keras.layers.Dense(num_units)
+    self.linear_output = tf.keras.layers.Dense(num_units)
+    self.dropout = dropout
+    self.return_attention = return_attention
+
+  def call(self, inputs, memory=None, mask=None, cache=None, training=None):
+    """Runs the layer.
+
+    Args:
+      inputs: The sequence of queries. A tensor of shape :math:`[B, T_1, ...]`.
+      memory: The sequence to attend. A tensor of shape :math:`[B, T_2, ...]`.
+        If ``None``, computes self-attention.
+      mask: A ``tf.Tensor`` applied to the dot product.
+      cache: An optional tuple with the cached keys and values.
+      training: Run in training mode.
+
+    Returns:
+      A tuple with the attention context, the updated cache and the attention
+      probabilities of the first head (if :obj:`return_attention` is ``True``).
+    """
+
+    def _compute_kv(x):
+      keys = self.linear_keys(x)
+      keys = split_heads(keys, self.num_heads)
+      values = self.linear_values(x)
+      values = split_heads(values, self.num_heads)
+      return keys, values
+
+    # Compute queries.
+    queries = self.linear_queries(inputs)
+    queries = split_heads(queries, self.num_heads)
+    queries *= (self.num_units // self.num_heads)**-0.5
+
+    # Compute keys and values.
+    if memory is None:
+      keys, values = _compute_kv(inputs)
+      if cache is not None:
+        keys = tf.concat([cache[0], keys], axis=2)
+        values = tf.concat([cache[1], values], axis=2)
+    else:
+      if not cache or tf.equal(tf.shape(cache[0])[2], 0):
+        keys, values = _compute_kv(memory)
+      else:
+        keys, values = cache
+    cache = (keys, values)
+
+    # Dot product attention.
+    dot = tf.matmul(queries, keys, transpose_b=True)
+    if mask is not None:
+      mask = tf.expand_dims(tf.cast(mask, tf.float32), 1)  # Broadcast on heads dimension.
+      dot = tf.cast(tf.cast(dot, tf.float32) * mask + ((1.0 - mask) * tf.float32.min), dot.dtype)
+    attn = tf.cast(tf.nn.softmax(tf.cast(dot, tf.float32)), dot.dtype)
+    drop_attn = tf.nn.dropout(attn, rate=self.dropout) if training else attn
+    heads = tf.matmul(drop_attn, values)
+
+    # Concatenate all heads output.
+    combined = combine_heads(heads)
+    outputs = self.linear_output(combined)
+    if self.return_attention:
+      return outputs, cache, attn
+    return outputs, cache
diff --git a/opennmt/v2/tests/transformer_test.py b/opennmt/v2/tests/transformer_test.py
new file mode 100644
index 000000000..3b24d6361
--- /dev/null
+++ b/opennmt/v2/tests/transformer_test.py
@@ -0,0 +1,68 @@
+from parameterized import parameterized
+
+import tensorflow as tf
+
+from opennmt.v2.layers import transformer
+
+
+class TransformerTest(tf.test.TestCase):
+
+  @parameterized.expand([[tf.float32], [tf.float16]])
+  def testFeedForwardNetwork(self, dtype):
+    ffn = transformer.FeedForwardNetwork(20, 10)
+    x = tf.random.uniform([4, 5, 10], dtype=dtype)
+    y = ffn(x)
+    self.assertEqual(y.shape, x.shape)
+    self.assertEqual(y.dtype, dtype)
+
+  @parameterized.expand([[tf.float32], [tf.float16]])
+  def testMultiHeadSelfAttention(self, dtype):
+    attention = transformer.MultiHeadAttention(4, 20)
+    queries = tf.random.uniform([4, 5, 10], dtype=dtype)
+    mask = tf.expand_dims(tf.sequence_mask([4, 3, 5, 2]), 1)
+    context, _ = attention(queries, mask=mask)
+    self.assertListEqual(context.shape.as_list(), [4, 5, 20])
+    self.assertEqual(context.dtype, dtype)
+
+  @parameterized.expand([[tf.float32], [tf.float16]])
+  def testMultiHeadSelfAttentionWithCache(self, dtype):
+    cache = (tf.zeros([4, 4, 0, 5], dtype=dtype), tf.zeros([4, 4, 0, 5], dtype=dtype))
+    attention = transformer.MultiHeadAttention(4, 20)
+    x = tf.random.uniform([4, 1, 10], dtype=dtype)
+    _, cache = attention(x, cache=cache)
+    self.assertEqual(cache[0].shape[2], 1)
+    self.assertEqual(cache[1].shape[2], 1)
+    _, cache = attention(x, cache=cache)
+    self.assertEqual(cache[0].shape[2], 2)
+    self.assertEqual(cache[0].dtype, dtype)
+    self.assertEqual(cache[1].shape[2], 2)
+    self.assertEqual(cache[1].dtype, dtype)
+
+  @parameterized.expand([[tf.float32], [tf.float16]])
+  def testMultiHeadAttention(self, dtype):
+    attention = transformer.MultiHeadAttention(4, 20)
+    queries = tf.random.uniform([4, 5, 10], dtype=dtype)
+    memory = tf.random.uniform([4, 3, 10], dtype=dtype)
+    mask = tf.expand_dims(tf.sequence_mask([1, 3, 2, 2]), 1)
+    context, _ = attention(queries, memory=memory, mask=mask)
+    self.assertListEqual(context.shape.as_list(), [4, 5, 20])
+    self.assertEqual(context.dtype, dtype)
+
+  @parameterized.expand([[tf.float32], [tf.float16]])
+  def testMultiHeadAttentionWithCache(self, dtype):
+    cache = (tf.zeros([4, 4, 0, 5], dtype=dtype), tf.zeros([4, 4, 0, 5], dtype=dtype))
+    attention = transformer.MultiHeadAttention(4, 20)
+    memory = tf.random.uniform([4, 3, 10], dtype=dtype)
+    mask = tf.expand_dims(tf.sequence_mask([1, 3, 2, 2]), 1)
+    x = tf.random.uniform([4, 1, 10], dtype=dtype)
+    y1, cache = attention(x, memory=memory, mask=mask, cache=cache)
+    self.assertEqual(cache[0].shape[2], 3)
+    self.assertEqual(cache[0].dtype, dtype)
+    self.assertEqual(cache[1].shape[2], 3)
+    self.assertEqual(cache[1].dtype, dtype)
+    y2, cache = attention(x, memory=memory, mask=mask, cache=cache)
+    self.assertAllEqual(y1, y2)
+
+
+if __name__ == "__main__":
+  tf.test.main()
diff --git a/setup.py b/setup.py
index 8652676bf..0771904c7 100644
--- a/setup.py
+++ b/setup.py
@@ -1,6 +1,7 @@
 from setuptools import setup, find_packages
 
 tests_require = [
+    "parameterized",
     "nose2"
 ]
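Reviewer note (not part of the patch): a minimal usage sketch of the two new layers, assuming a TF 2.0 preview environment as configured in .travis.yml above. The shapes and hyperparameters are illustrative only; the constructor and call signatures come from the patch itself.

  import tensorflow as tf

  from opennmt.v2.layers import FeedForwardNetwork, MultiHeadAttention

  # Position-wise feed forward: 512 inner units projected back to a depth of 128.
  ffn = FeedForwardNetwork(512, 128)
  x = tf.random.uniform([4, 5, 128])  # [batch, time, depth]
  y = ffn(x, training=True)           # dropout is applied on the inner activation
  print(y.shape)                      # (4, 5, 128)

  # Multi-head self-attention with an incremental decoding cache.
  # num_units must be a multiple of num_heads; the depth per head is 128 / 8 = 16.
  attention = MultiHeadAttention(num_heads=8, num_units=128)
  cache = (tf.zeros([4, 8, 0, 16]), tf.zeros([4, 8, 0, 16]))
  for _ in range(3):
    step_input = tf.random.uniform([4, 1, 128])  # one target position per step
    context, cache = attention(step_input, cache=cache)
  print(context.shape)   # (4, 1, 128)
  print(cache[0].shape)  # (4, 8, 3, 16): cached keys grow along the time axis

For encoder-style attention over padded batches, a boolean mask can be passed as in the tests, e.g. mask=tf.expand_dims(tf.sequence_mask(lengths), 1).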