Add v2 Transformer layers (#324)
* Add v2 Transformer layers

* Fix Travis

* Fix Travis

* Fix Travis
guillaumekln authored Feb 8, 2019
1 parent d5bbf9f commit e0c8e1f
Showing 5 changed files with 203 additions and 3 deletions.
7 changes: 4 additions & 3 deletions .travis.yml
@@ -25,18 +25,19 @@ matrix:
       env:
         - TF_VERSION="2.0"
       before_install:
-        - pip install tf-nightly-2.0-preview
+        - pip install tf-nightly-2.0-preview pylint==2.1.1
         - pip install -e .[tests]
       script:
         - nose2 -s opennmt/v2/tests
+        - pylint opennmt/v2/
     - python: "3.6"
       env:
         - TF_VERSION="$LATEST_TF_VERSION"
       install:
         - pip install pylint==2.1.1 wheel twine
       script:
-        - nose2
-        - pylint opennmt/
+        - nose2 -s opennmt/tests
+        - pylint --ignore=tests,v2 opennmt/
       after_success:
         - |
           if [[ -n $TRAVIS_TAG ]]; then
1 change: 1 addition & 0 deletions opennmt/v2/layers/__init__.py
@@ -1,3 +1,4 @@
"""Layers module."""

from opennmt.v2.layers.common import LayerNorm
from opennmt.v2.layers.transformer import FeedForwardNetwork, MultiHeadAttention
129 changes: 129 additions & 0 deletions opennmt/v2/layers/transformer.py
@@ -0,0 +1,129 @@
# pylint: disable=arguments-differ

"""Define layers related to the Google's Transformer model."""

import tensorflow as tf

from opennmt.layers.transformer import combine_heads, split_heads


class FeedForwardNetwork(tf.keras.layers.Layer):
  """Implements the Transformer's "Feed Forward" layer.
  .. math::
      ffn(x) = max(0, x*W_1 + b_1)*W_2 + b_2
  """

  def __init__(self, inner_dim, output_dim, dropout=0.1, **kwargs):
    """Initializes this layer.
    Args:
      inner_dim: The number of units of the inner linear transformation.
      output_dim: The number of units of the output linear transformation.
      dropout: The probability to drop units from the inner transformation.
      kwargs: Additional layer arguments.
    """
    super(FeedForwardNetwork, self).__init__(**kwargs)
    self.inner = tf.keras.layers.Dense(inner_dim, activation=tf.nn.relu)
    self.outer = tf.keras.layers.Dense(output_dim)
    self.dropout = dropout

  def call(self, inputs, training=None):
    """Runs the layer."""
    inner = self.inner(inputs)
    if training:
      inner = tf.nn.dropout(inner, rate=self.dropout)
    return self.outer(inner)


class MultiHeadAttention(tf.keras.layers.Layer):
  """Computes the multi-head attention as described in
  https://arxiv.org/abs/1706.03762.
  """

  def __init__(self,
               num_heads,
               num_units,
               dropout=0.1,
               return_attention=False,
               **kwargs):
    """Initializes this layer.
    Args:
      num_heads: The number of attention heads.
      num_units: The number of hidden units.
      dropout: The probability to drop units from the inputs.
      return_attention: If ``True``, also return the attention weights of the
        first head.
      kwargs: Additional layer arguments.
    """
    super(MultiHeadAttention, self).__init__(**kwargs)
    if num_units % num_heads != 0:
      raise ValueError("Multi head attention requires that num_units is a"
                       " multiple of %s" % num_heads)
    self.num_heads = num_heads
    self.num_units = num_units
    self.linear_queries = tf.keras.layers.Dense(num_units)
    self.linear_keys = tf.keras.layers.Dense(num_units)
    self.linear_values = tf.keras.layers.Dense(num_units)
    self.linear_output = tf.keras.layers.Dense(num_units)
    self.dropout = dropout
    self.return_attention = return_attention

  def call(self, inputs, memory=None, mask=None, cache=None, training=None):
    """Runs the layer.
    Args:
      inputs: The sequence of queries. A tensor of shape :math:`[B, T_1, ...]`.
      memory: The sequence to attend. A tensor of shape :math:`[B, T_2, ...]`.
        If ``None``, computes self-attention.
      mask: A ``tf.Tensor`` applied to the dot product; 0 marks masked positions.
      cache: A tuple of pre-projected keys and values from a previous call.
      training: Run in training mode.
    Returns:
      A tuple with the attention context, the updated cache and the attention
      probabilities of the first head (if :obj:`return_attention` is ``True``).
    """

    def _compute_kv(x):
      keys = self.linear_keys(x)
      keys = split_heads(keys, self.num_heads)
      values = self.linear_values(x)
      values = split_heads(values, self.num_heads)
      return keys, values

    # Compute queries.
    queries = self.linear_queries(inputs)
    queries = split_heads(queries, self.num_heads)
    queries *= (self.num_units // self.num_heads)**-0.5

    # Compute keys and values.
    if memory is None:
      keys, values = _compute_kv(inputs)
      if cache is not None:
        keys = tf.concat([cache[0], keys], axis=2)
        values = tf.concat([cache[1], values], axis=2)
    else:
      if not cache or tf.equal(tf.shape(cache[0])[2], 0):
        keys, values = _compute_kv(memory)
      else:
        keys, values = cache
    cache = (keys, values)

    # Dot product attention.
    dot = tf.matmul(queries, keys, transpose_b=True)
    if mask is not None:
      mask = tf.expand_dims(tf.cast(mask, tf.float32), 1)  # Broadcast on heads dimension.
      dot = tf.cast(tf.cast(dot, tf.float32) * mask + ((1.0 - mask) * tf.float32.min), dot.dtype)
    attn = tf.cast(tf.nn.softmax(tf.cast(dot, tf.float32)), dot.dtype)
    drop_attn = tf.nn.dropout(attn, rate=self.dropout) if training else attn
    heads = tf.matmul(drop_attn, values)

    # Concatenate all heads output.
    combined = combine_heads(heads)
    outputs = self.linear_output(combined)
    if self.return_attention:
      return outputs, cache, attn
    return outputs, cache
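
For reference, the attention computed in `call` above is the scaled dot-product attention of the linked paper, applied per head; writing d_k for the per-head depth (num_units / num_heads):

.. math::
    attention(Q, K, V) = softmax(Q*K^T / sqrt(d_k))*V

The code folds the 1/sqrt(d_k) scaling into the queries before the matmul, and applies the mask additively (masked positions receive ``tf.float32.min``) before the softmax.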
68 changes: 68 additions & 0 deletions opennmt/v2/tests/transformer_test.py
@@ -0,0 +1,68 @@
from parameterized import parameterized

import tensorflow as tf

from opennmt.v2.layers import transformer


class TransformerTest(tf.test.TestCase):

  @parameterized.expand([[tf.float32], [tf.float16]])
  def testFeedForwardNetwork(self, dtype):
    ffn = transformer.FeedForwardNetwork(20, 10)
    x = tf.random.uniform([4, 5, 10], dtype=dtype)
    y = ffn(x)
    self.assertEqual(y.shape, x.shape)
    self.assertEqual(y.dtype, dtype)

  @parameterized.expand([[tf.float32], [tf.float16]])
  def testMultiHeadSelfAttention(self, dtype):
    attention = transformer.MultiHeadAttention(4, 20)
    queries = tf.random.uniform([4, 5, 10], dtype=dtype)
    mask = tf.expand_dims(tf.sequence_mask([4, 3, 5, 2]), 1)
    context, _ = attention(queries, mask=mask)
    self.assertListEqual(context.shape.as_list(), [4, 5, 20])
    self.assertEqual(context.dtype, dtype)

  @parameterized.expand([[tf.float32], [tf.float16]])
  def testMultiHeadSelfAttentionWithCache(self, dtype):
    cache = (tf.zeros([4, 4, 0, 5], dtype=dtype), tf.zeros([4, 4, 0, 5], dtype=dtype))
    attention = transformer.MultiHeadAttention(4, 20)
    x = tf.random.uniform([4, 1, 10], dtype=dtype)
    _, cache = attention(x, cache=cache)
    self.assertEqual(cache[0].shape[2], 1)
    self.assertEqual(cache[1].shape[2], 1)
    _, cache = attention(x, cache=cache)
    self.assertEqual(cache[0].shape[2], 2)
    self.assertEqual(cache[0].dtype, dtype)
    self.assertEqual(cache[1].shape[2], 2)
    self.assertEqual(cache[1].dtype, dtype)

  @parameterized.expand([[tf.float32], [tf.float16]])
  def testMultiHeadAttention(self, dtype):
    attention = transformer.MultiHeadAttention(4, 20)
    queries = tf.random.uniform([4, 5, 10], dtype=dtype)
    memory = tf.random.uniform([4, 3, 10], dtype=dtype)
    mask = tf.expand_dims(tf.sequence_mask([1, 3, 2, 2]), 1)
    context, _ = attention(queries, memory=memory, mask=mask)
    self.assertListEqual(context.shape.as_list(), [4, 5, 20])
    self.assertEqual(context.dtype, dtype)

  @parameterized.expand([[tf.float32], [tf.float16]])
  def testMultiHeadAttentionWithCache(self, dtype):
    cache = (tf.zeros([4, 4, 0, 5], dtype=dtype), tf.zeros([4, 4, 0, 5], dtype=dtype))
    attention = transformer.MultiHeadAttention(4, 20)
    memory = tf.random.uniform([4, 3, 10], dtype=dtype)
    mask = tf.expand_dims(tf.sequence_mask([1, 3, 2, 2]), 1)
    x = tf.random.uniform([4, 1, 10], dtype=dtype)
    y1, cache = attention(x, memory=memory, mask=mask, cache=cache)
    self.assertEqual(cache[0].shape[2], 3)
    self.assertEqual(cache[0].dtype, dtype)
    self.assertEqual(cache[1].shape[2], 3)
    self.assertEqual(cache[1].dtype, dtype)
    y2, cache = attention(x, memory=memory, mask=mask, cache=cache)
    self.assertAllEqual(y1, y2)


if __name__ == "__main__":
  tf.test.main()
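
A minimal usage sketch of the cache mechanism exercised by testMultiHeadSelfAttentionWithCache above; the shapes (batch 2, 4 heads, 20 units, so a per-head depth of 5) and the three-step loop are arbitrary choices for illustration:

import tensorflow as tf

from opennmt.v2.layers import transformer

attention = transformer.MultiHeadAttention(4, 20)

# Start from an empty cache: cached keys/values with a time dimension of 0.
cache = (tf.zeros([2, 4, 0, 5]), tf.zeros([2, 4, 0, 5]))

for step in range(3):
  # One new query timestep per decoding step.
  x = tf.random.uniform([2, 1, 10])
  context, cache = attention(x, cache=cache)
  # The cached keys/values grow by one timestep on every call.
  assert cache[0].shape[2] == step + 1
  print(context.shape)  # (2, 1, 20)

Each call projects only the new timestep and concatenates it to the cached keys and values along the time axis, which is what keeps incremental decoding cheap.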
1 change: 1 addition & 0 deletions setup.py
@@ -1,6 +1,7 @@
from setuptools import setup, find_packages

tests_require = [
  "parameterized",
  "nose2"
]

