Skip to content

Commit

Permalink
Support transition layers in SequentialEncoder (#111)
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaumekln authored Apr 28, 2018
1 parent 02fa935 commit 5500493
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ OpenNMT-tf follows [semantic versioning 2.0.0](https://semver.org/). The API cov

* Return alignment history when decoding from an `AttentionalRNNDecoder` (requires TensorFlow 1.8+ when decoding with beam search)
* Boolean parameter `replace_unknown_target` to replace unknown target tokens by the source token with the highest attention (requires a decoder that returns the alignment history)
* Support for arbitrary transition layers in `SequentialEncoder`

### Fixes and improvements

Expand Down
29 changes: 27 additions & 2 deletions opennmt/encoders/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,24 +28,47 @@ def encode(self, inputs, sequence_length=None, mode=tf.estimator.ModeKeys.TRAIN)


class SequentialEncoder(Encoder):
"""An encoder that executes multiple encoders sequentially."""
"""An encoder that executes multiple encoders sequentially with optional
transition layers.
def __init__(self, encoders, states_reducer=JoinReducer()):
See for example "Cascaded Encoder" in https://arxiv.org/abs/1804.09849.
"""

def __init__(self, encoders, states_reducer=JoinReducer(), transition_layer_fn=None):
  """Initializes the parameters of the encoder.

  Args:
    encoders: A list of :class:`opennmt.encoders.encoder.Encoder`.
    states_reducer: A :class:`opennmt.layers.reducer.Reducer` to merge all
      states.
    transition_layer_fn: A callable or list of callables applied to the
      output of an encoder before passing it as input to the next. If it is
      a single callable, it is applied between every pair of consecutive
      encoders. Otherwise, the ``i``-th callable will be applied between
      encoder ``i`` and ``i + 1``.

  Raises:
    ValueError: if :obj:`transition_layer_fn` is a list with a size not equal
      to the number of encoder transitions ``len(encoders) - 1``.
  """
  # A single callable needs no length check; isinstance(None, list) is
  # False, so an explicit "is not None" guard is redundant here.
  if (isinstance(transition_layer_fn, list)
      and len(transition_layer_fn) != len(encoders) - 1):
    raise ValueError("The number of transition layers must match the number of encoder "
                     "transitions, expected %d layers but got %d."
                     % (len(encoders) - 1, len(transition_layer_fn)))
  self.encoders = encoders
  self.states_reducer = states_reducer
  self.transition_layer_fn = transition_layer_fn

def encode(self, inputs, sequence_length=None, mode=tf.estimator.ModeKeys.TRAIN):
encoder_state = []

for i, encoder in enumerate(self.encoders):
with tf.variable_scope("encoder_{}".format(i)):
if i > 0 and self.transition_layer_fn is not None:
if isinstance(self.transition_layer_fn, list):
inputs = self.transition_layer_fn[i - 1](inputs)
else:
inputs = self.transition_layer_fn(inputs)
inputs, state, sequence_length = encoder.encode(
inputs,
sequence_length=sequence_length,
Expand All @@ -65,6 +88,8 @@ class ParallelEncoder(Encoder):
sequence (e.g. the non reduced output of a
:class:`opennmt.inputters.inputter.ParallelInputter`), each encoder will encode
its corresponding input in the sequence.
See for example "Multi-Column Encoder" in https://arxiv.org/abs/1804.09849.
"""

def __init__(self,
Expand Down
24 changes: 21 additions & 3 deletions opennmt/tests/encoder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,14 @@ def testPyramidalEncoder(self):
self.assertAllEqual([3, 6, 10], outputs.shape)
self.assertAllEqual([4, 5, 5], encoded_length)

def testSequentialEncoder(self):
def _testSequentialEncoder(self, transition_layer_fn=None):
sequence_length = [17, 21, 20]
inputs = _build_dummy_sequences(sequence_length)
encoder = encoders.SequentialEncoder([
encoders_sequence = [
encoders.UnidirectionalRNNEncoder(1, 20),
encoders.PyramidalRNNEncoder(3, 10, reduction_factor=2)])
encoders.PyramidalRNNEncoder(3, 10, reduction_factor=2)]
encoder = encoders.SequentialEncoder(
encoders_sequence, transition_layer_fn=transition_layer_fn)
_, state, encoded_length = encoder.encode(
inputs, sequence_length=sequence_length)
self.assertEqual(4, len(state))
Expand All @@ -74,6 +76,22 @@ def testSequentialEncoder(self):
encoded_length = sess.run(encoded_length)
self.assertAllEqual([4, 5, 5], encoded_length)

def testSequentialEncoder(self):
  """Default configuration: no transition layer between the encoders."""
  self._testSequentialEncoder(transition_layer_fn=None)

def testSequentialEncoderWithTransitionLayer(self):
  """A single transition callable is shared by all encoder transitions."""
  def _layer_norm(x):
    return tf.contrib.layers.layer_norm(x, begin_norm_axis=-1)
  self._testSequentialEncoder(transition_layer_fn=_layer_norm)

def testSequentialEncoderWithTransitionLayerList(self):
  """A one-element list of transitions for the single encoder transition."""
  def _layer_norm(x):
    return tf.contrib.layers.layer_norm(x, begin_norm_axis=-1)
  self._testSequentialEncoder(transition_layer_fn=[_layer_norm])

def testSequentialEncoderWithInvalidTransitionLayerList(self):
  """Two transition layers for a 2-encoder stack (1 transition) must raise."""
  def _layer_norm(x):
    return tf.contrib.layers.layer_norm(x, begin_norm_axis=-1)
  transitions = [_layer_norm, _layer_norm]
  with self.assertRaises(ValueError):
    self._testSequentialEncoder(transition_layer_fn=transitions)

def _testGoogleRNNEncoder(self, num_layers):
sequence_length = [17, 21, 20]
inputs = _build_dummy_sequences(sequence_length)
Expand Down

0 comments on commit 5500493

Please sign in to comment.