Support for multi source in Transformer decoder (#272)
guillaumekln authored Nov 22, 2018
1 parent dce7f27 commit 21d2d6f
Showing 10 changed files with 196 additions and 82 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -14,6 +14,7 @@ OpenNMT-tf follows [semantic versioning 2.0.0](https://semver.org/). The API cov

### New features

* Multi source Transformer architecture with serial attention layers (see the [example model configuration](https://github.com/OpenNMT/OpenNMT-tf/blob/master/config/models/multi_source_transformer.py))
* Inference now accepts the parameter `bucket_width`: if set, the data is sorted by length to increase translation efficiency. Predictions are still output in order as they become available. (Enabled by default when using automatic configuration.)

### Fixes and improvements
28 changes: 28 additions & 0 deletions config/models/multi_source_transformer.py
@@ -0,0 +1,28 @@
"""Defines a dual source Transformer architecture with serial attention layers
and parameter sharing between the encoders.
See for example https://arxiv.org/pdf/1809.00188.pdf.
"""

import opennmt as onmt

def model():
  return onmt.models.Transformer(
      source_inputter=onmt.inputters.ParallelInputter([
          onmt.inputters.WordEmbedder(
              vocabulary_file_key="source_vocabulary_1",
              embedding_size=512),
          onmt.inputters.WordEmbedder(
              vocabulary_file_key="source_vocabulary_2",
              embedding_size=512)]),
      target_inputter=onmt.inputters.WordEmbedder(
          vocabulary_file_key="target_vocabulary",
          embedding_size=512),
      num_layers=6,
      num_units=512,
      num_heads=8,
      ffn_inner_dim=2048,
      dropout=0.1,
      attention_dropout=0.1,
      relu_dropout=0.1,
      share_encoders=True)
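
For context, a run configuration for this model would provide two parallel source files plus the vocabularies referenced by the `vocabulary_file_key` values above. The sketch below writes that data block as a Python dict purely for illustration (in practice it lives in the YAML run configuration); the file names are placeholders and, except for the three vocabulary keys taken from the model definition, the key names are assumptions to verify against the OpenNMT-tf documentation.

data = {
    "train_features_file": ["train.src1", "train.src2"],  # one file per source input
    "train_labels_file": "train.tgt",
    "source_vocabulary_1": "src1.vocab",  # matches vocabulary_file_key above
    "source_vocabulary_2": "src2.vocab",
    "target_vocabulary": "tgt.vocab",
}
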
9 changes: 8 additions & 1 deletion opennmt/decoders/decoder.py
@@ -159,6 +159,13 @@ def decode(self,
_ = embedding
if sampling_probability is not None:
raise ValueError("Scheduled sampling is not supported by this decoder")
if (not self.support_multi_source
and memory is not None
and tf.contrib.framework.nest.is_sequence(memory)):
raise ValueError("Multiple source encodings are passed to this decoder "
"but it does not support multi source context. You should "
"instead configure your encoder to merge the different "
"encodings.")

returned_values = self.decode_from_inputs(
inputs,
@@ -329,7 +336,7 @@ def dynamic_decode_and_search(self,
output_layer = build_output_layer(self.output_size, vocab_size, dtype=dtype)

state = {"decoder": initial_state}
if self.support_alignment_history and not isinstance(memory, (tuple, list)):
if self.support_alignment_history and not tf.contrib.framework.nest.is_sequence(memory):
state["attention"] = tf.zeros([batch_size, 0, tf.shape(memory)[1]], dtype=dtype)

def _symbols_to_logits_fn(ids, step, state):
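
The new guard above distinguishes a single memory tensor from a collection of encoder memories with `tf.contrib.framework.nest.is_sequence`. A minimal, self-contained sketch of that check, assuming TensorFlow 1.x where `tf.contrib` is available:

import tensorflow as tf

nest = tf.contrib.framework.nest

single_memory = tf.zeros([2, 5, 512])                          # one encoder output
multi_memory = (tf.zeros([2, 5, 512]), tf.zeros([2, 7, 512]))  # one output per source

print(nest.is_sequence(single_memory))  # False: any decoder accepts a single memory
print(nest.is_sequence(multi_memory))   # True: only decoders with support_multi_source accept this
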
85 changes: 52 additions & 33 deletions opennmt/decoders/self_attention_decoder.py
@@ -66,15 +66,21 @@ def output_size(self):
def support_alignment_history(self):
return True

def _init_cache(self, batch_size, dtype=tf.float32):
@property
def support_multi_source(self):
return True

def _init_cache(self, batch_size, dtype=tf.float32, num_sources=1):
cache = {}

for l in range(self.num_layers):
proj_cache_shape = [batch_size, self.num_heads, 0, self.num_units // self.num_heads]
layer_cache = {
"memory_keys": tf.zeros(proj_cache_shape, dtype=dtype),
"memory_values": tf.zeros(proj_cache_shape, dtype=dtype),
}
layer_cache = {}
layer_cache["memory"] = [
{
"memory_keys": tf.zeros(proj_cache_shape, dtype=dtype),
"memory_values": tf.zeros(proj_cache_shape, dtype=dtype)
} for _ in range(num_sources)]
if self.self_attention_type == "scaled_dot":
layer_cache["self_keys"] = tf.zeros(proj_cache_shape, dtype=dtype)
layer_cache["self_values"] = tf.zeros(proj_cache_shape, dtype=dtype)
@@ -121,11 +127,15 @@ def _self_attention_stack(self,
decoder_mask = transformer.cumulative_average_mask(
sequence_length, maximum_length=tf.shape(inputs)[1], dtype=inputs.dtype)

if memory is not None and memory_sequence_length is not None:
memory_mask = transformer.build_sequence_mask(
memory_sequence_length,
num_heads=self.num_heads,
maximum_length=tf.shape(memory)[1])
if memory is not None and not tf.contrib.framework.nest.is_sequence(memory):
memory = (memory,)
if memory_sequence_length is not None:
if not tf.contrib.framework.nest.is_sequence(memory_sequence_length):
memory_sequence_length = (memory_sequence_length,)
memory_mask = [
transformer.build_sequence_mask(
length, num_heads=self.num_heads, maximum_length=tf.shape(m)[1])
for m, length in zip(memory, memory_sequence_length)]

for l in range(self.num_layers):
layer_name = "layer_{}".format(l)
@@ -142,7 +152,7 @@ def _self_attention_stack(self,
mask=decoder_mask,
cache=layer_cache,
dropout=self.attention_dropout)
encoded = transformer.drop_and_add(
last_context = transformer.drop_and_add(
inputs,
encoded,
mode,
@@ -160,36 +170,38 @@ def _self_attention_stack(self,
z = tf.layers.dense(tf.concat([x, y], -1), self.num_units * 2)
i, f = tf.split(z, 2, axis=-1)
y = tf.sigmoid(i) * x + tf.sigmoid(f) * y
encoded = transformer.drop_and_add(
last_context = transformer.drop_and_add(
inputs, y, mode, dropout=self.dropout)

if memory is not None:
with tf.variable_scope("multi_head"):
context, last_attention = transformer.multi_head_attention(
self.num_heads,
transformer.norm(encoded),
memory,
mode,
mask=memory_mask,
cache=layer_cache,
dropout=self.attention_dropout,
return_attention=True)
context = transformer.drop_and_add(
encoded,
context,
mode,
dropout=self.dropout)
else:
context = encoded
for i, (mem, mask) in enumerate(zip(memory, memory_mask)):
memory_cache = layer_cache["memory"][i] if layer_cache is not None else None
with tf.variable_scope("multi_head" if i == 0 else "multi_head_%d" % i):
context, last_attention = transformer.multi_head_attention(
self.num_heads,
transformer.norm(last_context),
mem,
mode,
mask=mask,
cache=memory_cache,
dropout=self.attention_dropout,
return_attention=True)
last_context = transformer.drop_and_add(
last_context,
context,
mode,
dropout=self.dropout)
if i > 0: # Do not return attention in case of multi source.
last_attention = None

with tf.variable_scope("ffn"):
transformed = transformer.feed_forward(
transformer.norm(context),
transformer.norm(last_context),
self.ffn_inner_dim,
mode,
dropout=self.relu_dropout)
transformed = transformer.drop_and_add(
context,
last_context,
transformed,
mode,
dropout=self.dropout)
@@ -227,7 +239,13 @@ def step_fn(self,
memory=None,
memory_sequence_length=None,
dtype=tf.float32):
cache = self._init_cache(batch_size, dtype=dtype)
if memory is None:
num_sources = 0
elif tf.contrib.framework.nest.is_sequence(memory):
num_sources = len(memory)
else:
num_sources = 1
cache = self._init_cache(batch_size, dtype=dtype, num_sources=num_sources)
def _fn(step, inputs, cache, mode):
inputs = tf.expand_dims(inputs, 1)
outputs, attention = self._self_attention_stack(
@@ -238,6 +256,7 @@ def _fn(step, inputs, cache, mode):
memory_sequence_length=memory_sequence_length,
step=step)
outputs = tf.squeeze(outputs, axis=1)
attention = tf.squeeze(attention, axis=1)
if attention is not None:
attention = tf.squeeze(attention, axis=1)
return outputs, cache, attention
return _fn, cache
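
To summarize the loop above: each decoder layer now applies one cross-attention block per source memory in series, feeding the output of each block into the next, and only the attention weights over the first source are kept for the alignment history. A schematic sketch of that control flow, where `attend` and `residual` are hypothetical callables standing in for `transformer.multi_head_attention` and `transformer.drop_and_add`:

def serial_multi_source_attention(inputs, memories, attend, residual):
  """Schematic sketch (not the library code) of serial multi source attention.

  Args:
    inputs: Output of the self-attention block of the current layer.
    memories: A sequence of encoder memories, one per source.
    attend: Hypothetical callable (query, memory) -> (context, attention).
    residual: Hypothetical callable (x, y) -> dropout(y) + x.
  """
  last_context = inputs
  last_attention = None
  for i, memory in enumerate(memories):
    context, attention = attend(last_context, memory)
    last_context = residual(last_context, context)
    # As in the diff above, attention is only returned for a single source.
    last_attention = attention if i == 0 else None
  return last_context, last_attention
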
10 changes: 8 additions & 2 deletions opennmt/encoders/encoder.py
@@ -100,7 +100,8 @@ def __init__(self,
outputs_reducer=ConcatReducer(axis=1),
states_reducer=JoinReducer(),
outputs_layer_fn=None,
combined_output_layer_fn=None):
combined_output_layer_fn=None,
share_parameters=False):
"""Initializes the parameters of the encoder.
Args:
@@ -117,6 +118,8 @@ def __init__(self,
output.
combined_output_layer_fn: A callable to apply on the combined output
(i.e. the output of :obj:`outputs_reducer`).
share_parameters: If ``True``, share parameters between the parallel
encoders.
Raises:
ValueError: if :obj:`outputs_layer_fn` is a list with a size not equal
@@ -132,6 +135,7 @@ def __init__(self,
self.states_reducer = states_reducer if states_reducer is not None else JoinReducer()
self.outputs_layer_fn = outputs_layer_fn
self.combined_output_layer_fn = combined_output_layer_fn
self.share_parameters = share_parameters

def encode(self, inputs, sequence_length=None, mode=tf.estimator.ModeKeys.TRAIN):
all_outputs = []
@@ -142,7 +146,9 @@ def encode(self, inputs, sequence_length=None, mode=tf.estimator.ModeKeys.TRAIN)
raise ValueError("ParallelEncoder expects as many inputs as parallel encoders")

for i, encoder in enumerate(self.encoders):
with tf.variable_scope("encoder_{}".format(i)):
scope_name = "encoder_{}".format(i) if not self.share_parameters else "parallel_encoder"
reuse = self.share_parameters and i > 0
with tf.variable_scope(scope_name, reuse=reuse):
if tf.contrib.framework.nest.is_sequence(inputs):
encoder_inputs = inputs[i]
length = sequence_length[i]
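
The parameter sharing added here relies on TensorFlow 1.x variable scopes: when sharing is enabled, every parallel encoder is built under the same scope name and variable reuse is turned on from the second encoder onwards. A reduced sketch of that scoping scheme, with `encode_fn` as a hypothetical per-encoder function:

import tensorflow as tf

def encode_parallel(inputs_list, encode_fn, share_parameters=False):
  """Simplified sketch of the scoping used by ParallelEncoder."""
  outputs = []
  for i, inputs in enumerate(inputs_list):
    # With sharing, all encoders live in one scope and reuse its variables.
    scope_name = "parallel_encoder" if share_parameters else "encoder_{}".format(i)
    reuse = share_parameters and i > 0
    with tf.variable_scope(scope_name, reuse=reuse):
      outputs.append(encode_fn(inputs))
  return outputs
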
22 changes: 16 additions & 6 deletions opennmt/inputters/inputter.py
@@ -5,7 +5,7 @@

import tensorflow as tf

from opennmt.layers.reducer import ConcatReducer
from opennmt.layers.reducer import ConcatReducer, JoinReducer
from opennmt.utils.misc import extract_prefixed_keys


@@ -18,6 +18,11 @@ def __init__(self, dtype=tf.float32):
self.process_hooks = []
self.dtype = dtype

@property
def num_outputs(self):
"""How many parallel outputs does this inputter produce."""
return 1

def add_process_hooks(self, hooks):
"""Adds processing hooks.
@@ -229,7 +234,7 @@ def transform(self, inputs, mode):
class MultiInputter(Inputter):
"""An inputter that gathers multiple inputters."""

def __init__(self, inputters):
def __init__(self, inputters, reducer=None):
if not isinstance(inputters, list) or not inputters:
raise ValueError("inputters must be a non empty list")
dtype = inputters[0].dtype
@@ -238,6 +243,13 @@ def __init__(self, inputters):
raise TypeError("All inputters must have the same dtype")
super(MultiInputter, self).__init__(dtype=dtype)
self.inputters = inputters
self.reducer = reducer

@property
def num_outputs(self):
if self.reducer is None or isinstance(self.reducer, JoinReducer):
return len(self.inputters)
return 1

@abc.abstractmethod
def make_dataset(self, data_file):
@@ -282,8 +294,7 @@ def __init__(self, inputters, reducer=None):
reducer: A :class:`opennmt.layers.reducer.Reducer` to merge all inputs. If
set, parallel inputs are assumed to have the same length.
"""
super(ParallelInputter, self).__init__(inputters)
self.reducer = reducer
super(ParallelInputter, self).__init__(inputters, reducer=reducer)

def get_length(self, data):
lengths = []
@@ -370,8 +381,7 @@ def __init__(self,
reducer: A :class:`opennmt.layers.reducer.Reducer` to merge all inputs.
dropout: The probability to drop units in the merged inputs.
"""
super(MixedInputter, self).__init__(inputters)
self.reducer = reducer
super(MixedInputter, self).__init__(inputters, reducer=reducer)
self.dropout = dropout

def get_length(self, data):
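
As a usage note, the new `num_outputs` property is what lets a model tell a source that stays parallel apart from one that is merged by a reducer. A small sketch, assuming the inputter constructors keep the signatures shown elsewhere in this commit:

import opennmt as onmt
from opennmt.layers.reducer import ConcatReducer

def two_word_embedders():
  return [
      onmt.inputters.WordEmbedder(
          vocabulary_file_key="source_vocabulary_1", embedding_size=512),
      onmt.inputters.WordEmbedder(
          vocabulary_file_key="source_vocabulary_2", embedding_size=512)]

parallel = onmt.inputters.ParallelInputter(two_word_embedders())
merged = onmt.inputters.ParallelInputter(two_word_embedders(), reducer=ConcatReducer())

print(parallel.num_outputs)  # 2: the inputs stay parallel, enabling a multi source model
print(merged.num_outputs)    # 1: the reducer merges the inputs into a single output
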
6 changes: 3 additions & 3 deletions opennmt/inputters/text_inputter.py
@@ -505,7 +505,7 @@ def __init__(self,
dropout=dropout,
tokenizer=tokenizer,
dtype=dtype)
self.num_outputs = num_outputs
self.output_size = num_outputs
self.kernel_size = kernel_size
self.stride = stride
self.num_oov_buckets = 1
@@ -522,15 +522,15 @@ def transform(self, inputs, mode):

outputs = tf.layers.conv1d(
outputs,
self.num_outputs,
self.output_size,
self.kernel_size,
strides=self.stride)

# Max pooling over depth.
outputs = tf.reduce_max(outputs, axis=1)

# Split batch and sequence timesteps dimensions.
outputs = tf.reshape(outputs, [-1, tf.shape(inputs)[1], self.num_outputs])
outputs = tf.reshape(outputs, [-1, tf.shape(inputs)[1], self.output_size])

return outputs

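
The rename in `CharConvEmbedder` avoids a clash with the `num_outputs` property added to the base `Inputter`: `output_size` is now the depth of the convolution output, while `num_outputs` keeps its new meaning of number of parallel outputs (always 1 here). A hedged sketch, assuming the usual constructor keyword arguments:

import opennmt as onmt

# Constructor arguments are assumptions; check the CharConvEmbedder signature.
embedder = onmt.inputters.CharConvEmbedder(
    vocabulary_file_key="source_char_vocabulary",
    embedding_size=30,
    num_outputs=512)  # depth of the convolution output

print(embedder.output_size)  # 512: the renamed attribute used by transform()
print(embedder.num_outputs)  # 1: the base Inputter property (parallel outputs)
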
35 changes: 25 additions & 10 deletions opennmt/models/transformer.py
@@ -3,6 +3,7 @@
import tensorflow as tf

from opennmt.models.sequence_to_sequence import SequenceToSequence, EmbeddingsSharingLevel
from opennmt.encoders.encoder import ParallelEncoder
from opennmt.encoders.self_attention_encoder import SelfAttentionEncoder
from opennmt.decoders.self_attention_decoder import SelfAttentionDecoder
from opennmt.layers.position import SinusoidalPositionEncoder
@@ -27,13 +28,15 @@ def __init__(self,
position_encoder=SinusoidalPositionEncoder(),
decoder_self_attention_type="scaled_dot",
share_embeddings=EmbeddingsSharingLevel.NONE,
share_encoders=False,
alignment_file_key="train_alignments",
name="transformer"):
"""Initializes a Transformer model.
Args:
source_inputter: A :class:`opennmt.inputters.inputter.Inputter` to process
the source data.
the source data. If this inputter returns parallel inputs, a multi
source Transformer architecture will be constructed.
target_inputter: A :class:`opennmt.inputters.inputter.Inputter` to process
the target data. Currently, only the
:class:`opennmt.inputters.text_inputter.WordEmbedder` is supported.
@@ -52,19 +55,31 @@ def __init__(self,
share_embeddings: Level of embeddings sharing, see
:class:`opennmt.models.sequence_to_sequence.EmbeddingsSharingLevel`
for possible values.
share_encoders: In case of a multi source architecture, whether to share the
parameters of the separate encoders or not.
alignment_file_key: The data configuration key of the training alignment
file to support guided alignment.
name: The name of this model.
"""
encoder = SelfAttentionEncoder(
num_layers,
num_units=num_units,
num_heads=num_heads,
ffn_inner_dim=ffn_inner_dim,
dropout=dropout,
attention_dropout=attention_dropout,
relu_dropout=relu_dropout,
position_encoder=position_encoder)
encoders = [
SelfAttentionEncoder(
num_layers,
num_units=num_units,
num_heads=num_heads,
ffn_inner_dim=ffn_inner_dim,
dropout=dropout,
attention_dropout=attention_dropout,
relu_dropout=relu_dropout,
position_encoder=position_encoder)
for _ in range(source_inputter.num_outputs)]
if len(encoders) > 1:
encoder = ParallelEncoder(
encoders,
outputs_reducer=None,
states_reducer=None,
share_parameters=share_encoders)
else:
encoder = encoders[0]
decoder = SelfAttentionDecoder(
num_layers,
num_units=num_units,
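
Following the constructor logic above, a dual source model ends up with one SelfAttentionEncoder per source output, wrapped in a ParallelEncoder that reduces neither outputs nor states so the decoder can attend to each memory separately. A standalone sketch mirroring that construction:

from opennmt.encoders.encoder import ParallelEncoder
from opennmt.encoders.self_attention_encoder import SelfAttentionEncoder

num_sources = 2  # e.g. source_inputter.num_outputs for the example configuration
encoders = [
    SelfAttentionEncoder(
        6,  # num_layers
        num_units=512,
        num_heads=8,
        ffn_inner_dim=2048)
    for _ in range(num_sources)]
encoder = ParallelEncoder(
    encoders,
    outputs_reducer=None,  # keep one memory per source for the decoder
    states_reducer=None,
    share_parameters=True)  # same weights for both encoders (share_encoders=True)
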