Commit
* Add v2 Transformer layers
* Fix Travis
* Fix Travis
* Fix Travis
1 parent d5bbf9f · commit e0c8e1f
Showing 5 changed files with 203 additions and 3 deletions.
@@ -1,3 +1,4 @@
"""Layers module."""

from opennmt.v2.layers.common import LayerNorm
from opennmt.v2.layers.transformer import FeedForwardNetwork, MultiHeadAttention
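Since the package __init__ above now re-exports these symbols, downstream code can presumably import the new layers directly from opennmt.v2.layers rather than from the transformer submodule. A minimal import sketch (the path is inferred from the re-exports shown in this diff):

# Import path inferred from the __init__ re-exports above.
from opennmt.v2.layers import FeedForwardNetwork, MultiHeadAttention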
@@ -0,0 +1,129 @@
# pylint: disable=arguments-differ

"""Define layers related to Google's Transformer model."""

import tensorflow as tf

from opennmt.layers.transformer import combine_heads, split_heads


class FeedForwardNetwork(tf.keras.layers.Layer):
  """Implements the Transformer's "Feed Forward" layer.

  .. math::

      ffn(x) = max(0, x*W_1 + b_1)*W_2 + b_2
  """

  def __init__(self, inner_dim, output_dim, dropout=0.1, **kwargs):
    """Initializes this layer.

    Args:
      inner_dim: The number of units of the inner linear transformation.
      output_dim: The number of units of the output linear transformation.
      dropout: The probability to drop units from the inner transformation.
      kwargs: Additional layer arguments.
    """
    super(FeedForwardNetwork, self).__init__(**kwargs)
    self.inner = tf.keras.layers.Dense(inner_dim, activation=tf.nn.relu)
    self.outer = tf.keras.layers.Dense(output_dim)
    self.dropout = dropout

  def call(self, inputs, training=None):
    """Runs the layer."""
    inner = self.inner(inputs)
    if training:
      inner = tf.nn.dropout(inner, rate=self.dropout)
    return self.outer(inner)


class MultiHeadAttention(tf.keras.layers.Layer):
  """Computes the multi-head attention as described in
  https://arxiv.org/abs/1706.03762.
  """

  def __init__(self,
               num_heads,
               num_units,
               dropout=0.1,
               return_attention=False,
               **kwargs):
    """Initializes this layer.

    Args:
      num_heads: The number of attention heads.
      num_units: The number of hidden units.
      dropout: The probability to drop units from the inputs.
      return_attention: If ``True``, also return the attention weights of the
        first head.
      kwargs: Additional layer arguments.
    """
    super(MultiHeadAttention, self).__init__(**kwargs)
    if num_units % num_heads != 0:
      raise ValueError("Multi-head attention requires that num_units is a"
                       " multiple of %s" % num_heads)
    self.num_heads = num_heads
    self.num_units = num_units
    self.linear_queries = tf.keras.layers.Dense(num_units)
    self.linear_keys = tf.keras.layers.Dense(num_units)
    self.linear_values = tf.keras.layers.Dense(num_units)
    self.linear_output = tf.keras.layers.Dense(num_units)
    self.dropout = dropout
    self.return_attention = return_attention

  def call(self, inputs, memory=None, mask=None, cache=None, training=None):
    """Runs the layer.

    Args:
      inputs: The sequence of queries. A tensor of shape :math:`[B, T_1, ...]`.
      memory: The sequence to attend. A tensor of shape :math:`[B, T_2, ...]`.
        If ``None``, computes self-attention.
      mask: A ``tf.Tensor`` applied to the dot product.
      cache: A tuple containing the pre-projected keys and values.
      training: Run in training mode.

    Returns:
      A tuple with the attention context, the updated cache and the attention
      probabilities of the first head (if :obj:`return_attention` is ``True``).
    """

    def _compute_kv(x):
      keys = self.linear_keys(x)
      keys = split_heads(keys, self.num_heads)
      values = self.linear_values(x)
      values = split_heads(values, self.num_heads)
      return keys, values

    # Compute queries.
    queries = self.linear_queries(inputs)
    queries = split_heads(queries, self.num_heads)
    queries *= (self.num_units // self.num_heads)**-0.5

    # Compute keys and values.
    if memory is None:
      keys, values = _compute_kv(inputs)
      if cache is not None:
        keys = tf.concat([cache[0], keys], axis=2)
        values = tf.concat([cache[1], values], axis=2)
    else:
      if not cache or tf.equal(tf.shape(cache[0])[2], 0):
        keys, values = _compute_kv(memory)
      else:
        keys, values = cache
    cache = (keys, values)

    # Dot product attention.
    dot = tf.matmul(queries, keys, transpose_b=True)
    if mask is not None:
      mask = tf.expand_dims(tf.cast(mask, tf.float32), 1)  # Broadcast on heads dimension.
      dot = tf.cast(tf.cast(dot, tf.float32) * mask + ((1.0 - mask) * tf.float32.min), dot.dtype)
    # Compute the attention weights in float32 and cast back to the input dtype.
    attn = tf.cast(tf.nn.softmax(tf.cast(dot, tf.float32)), dot.dtype)
    drop_attn = tf.nn.dropout(attn, rate=self.dropout) if training else attn
    heads = tf.matmul(drop_attn, values)

    # Concatenate all heads output.
    combined = combine_heads(heads)
    outputs = self.linear_output(combined)
    if self.return_attention:
      return outputs, cache, attn
    return outputs, cache
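For reference, a short usage sketch of the two layers above. It relies only on the constructor and call signatures shown in this diff; the shapes and hyperparameters are illustrative, not taken from the commit.

import tensorflow as tf

from opennmt.v2.layers.transformer import FeedForwardNetwork, MultiHeadAttention

# Illustrative sizes: a batch of 4 sequences of length 5 and depth 512.
inputs = tf.random.uniform([4, 5, 512])
lengths = [5, 3, 4, 2]

# Position-wise feed-forward block: expand to inner_dim, project back to output_dim.
ffn = FeedForwardNetwork(inner_dim=2048, output_dim=512, dropout=0.1)
outputs = ffn(inputs, training=True)  # [4, 5, 512]

# Multi-head self-attention: memory=None means the queries attend to themselves.
attention = MultiHeadAttention(num_heads=8, num_units=512, dropout=0.1)
mask = tf.expand_dims(tf.sequence_mask(lengths), 1)  # broadcast over the query dimension
context, cache = attention(inputs, mask=mask, training=True)  # context: [4, 5, 512]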
@@ -0,0 +1,68 @@
from parameterized import parameterized

import tensorflow as tf

from opennmt.v2.layers import transformer


class TransformerTest(tf.test.TestCase):

  @parameterized.expand([[tf.float32], [tf.float16]])
  def testFeedForwardNetwork(self, dtype):
    ffn = transformer.FeedForwardNetwork(20, 10)
    x = tf.random.uniform([4, 5, 10], dtype=dtype)
    y = ffn(x)
    self.assertEqual(y.shape, x.shape)
    self.assertEqual(y.dtype, dtype)

  @parameterized.expand([[tf.float32], [tf.float16]])
  def testMultiHeadSelfAttention(self, dtype):
    attention = transformer.MultiHeadAttention(4, 20)
    queries = tf.random.uniform([4, 5, 10], dtype=dtype)
    mask = tf.expand_dims(tf.sequence_mask([4, 3, 5, 2]), 1)
    context, _ = attention(queries, mask=mask)
    self.assertListEqual(context.shape.as_list(), [4, 5, 20])
    self.assertEqual(context.dtype, dtype)

  @parameterized.expand([[tf.float32], [tf.float16]])
  def testMultiHeadSelfAttentionWithCache(self, dtype):
    cache = (tf.zeros([4, 4, 0, 5], dtype=dtype), tf.zeros([4, 4, 0, 5], dtype=dtype))
    attention = transformer.MultiHeadAttention(4, 20)
    x = tf.random.uniform([4, 1, 10], dtype=dtype)
    _, cache = attention(x, cache=cache)
    self.assertEqual(cache[0].shape[2], 1)
    self.assertEqual(cache[1].shape[2], 1)
    _, cache = attention(x, cache=cache)
    self.assertEqual(cache[0].shape[2], 2)
    self.assertEqual(cache[0].dtype, dtype)
    self.assertEqual(cache[1].shape[2], 2)
    self.assertEqual(cache[1].dtype, dtype)

  @parameterized.expand([[tf.float32], [tf.float16]])
  def testMultiHeadAttention(self, dtype):
    attention = transformer.MultiHeadAttention(4, 20)
    queries = tf.random.uniform([4, 5, 10], dtype=dtype)
    memory = tf.random.uniform([4, 3, 10], dtype=dtype)
    mask = tf.expand_dims(tf.sequence_mask([1, 3, 2, 2]), 1)
    context, _ = attention(queries, memory=memory, mask=mask)
    self.assertListEqual(context.shape.as_list(), [4, 5, 20])
    self.assertEqual(context.dtype, dtype)

  @parameterized.expand([[tf.float32], [tf.float16]])
  def testMultiHeadAttentionWithCache(self, dtype):
    cache = (tf.zeros([4, 4, 0, 5], dtype=dtype), tf.zeros([4, 4, 0, 5], dtype=dtype))
    attention = transformer.MultiHeadAttention(4, 20)
    memory = tf.random.uniform([4, 3, 10], dtype=dtype)
    mask = tf.expand_dims(tf.sequence_mask([1, 3, 2, 2]), 1)
    x = tf.random.uniform([4, 1, 10], dtype=dtype)
    y1, cache = attention(x, memory=memory, mask=mask, cache=cache)
    self.assertEqual(cache[0].shape[2], 3)
    self.assertEqual(cache[0].dtype, dtype)
    self.assertEqual(cache[1].shape[2], 3)
    self.assertEqual(cache[1].dtype, dtype)
    y2, cache = attention(x, memory=memory, mask=mask, cache=cache)
    self.assertAllEqual(y1, y2)


if __name__ == "__main__":
  tf.test.main()
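The two cache tests above exercise the incremental decoding pattern: the projected keys and values are carried across calls in the returned cache, which is why attending to the same memory twice yields identical outputs (y1 == y2). A minimal sketch of that loop, with illustrative sizes not taken from the commit:

import tensorflow as tf

from opennmt.v2.layers import transformer

attention = transformer.MultiHeadAttention(num_heads=4, num_units=20)
memory = tf.random.uniform([4, 3, 10])

# Empty cache: [batch, heads, 0 timesteps, depth per head], with 20 units / 4 heads = 5.
cache = (tf.zeros([4, 4, 0, 5]), tf.zeros([4, 4, 0, 5]))

for _ in range(6):  # decode 6 steps, one query at a time
  step = tf.random.uniform([4, 1, 10])
  context, cache = attention(step, memory=memory, cache=cache)
  # After the first step the cache holds the projected memory keys/values
  # ([4, 4, 3, 5]) and is reused on subsequent steps.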
@@ -1,6 +1,7 @@
from setuptools import setup, find_packages

tests_require = [
    "parameterized",
    "nose2"
]