Add argument to SequenceToSequence models to share embeddings (#195)
Start by supporting simple source and target sharing.
guillaumekln authored Aug 7, 2018
1 parent 214af26 commit 05e7295
Showing 4 changed files with 69 additions and 28 deletions.
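For context, the new argument is passed when defining a model in Python. A minimal sketch of such a definition (the vocabulary key, layer sizes, and model choice are illustrative, not part of this commit; sharing input embeddings assumes source and target use the same vocabulary):

```python
# Hypothetical model definition using the new share_embeddings argument.
# The "shared_vocabulary" key and the dimensions below are illustrative.
import opennmt as onmt

def model():
  return onmt.models.Transformer(
      source_inputter=onmt.inputters.WordEmbedder(
          vocabulary_file_key="shared_vocabulary", embedding_size=512),
      target_inputter=onmt.inputters.WordEmbedder(
          vocabulary_file_key="shared_vocabulary", embedding_size=512),
      num_layers=6,
      num_units=512,
      num_heads=8,
      ffn_inner_dim=2048,
      share_embeddings=onmt.models.EmbeddingsSharingLevel.SOURCE_TARGET_INPUT)
```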
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -15,6 +15,7 @@ OpenNMT-tf follows [semantic versioning 2.0.0](https://semver.org/). The API cov
### New features

* Command line option `--session_config` to configure TensorFlow session parameters (see the "Configuration" documentation)
* `share_embeddings` argument to `SequenceToSequence` models to configure the level of embeddings sharing

### Fixes and improvements

2 changes: 1 addition & 1 deletion opennmt/models/__init__.py
@@ -1,6 +1,6 @@
"""Module defining models."""

from opennmt.models.transformer import Transformer
from opennmt.models.sequence_to_sequence import SequenceToSequence
from opennmt.models.sequence_to_sequence import SequenceToSequence, EmbeddingsSharingLevel
from opennmt.models.sequence_tagger import SequenceTagger
from opennmt.models.sequence_classifier import SequenceClassifier
87 changes: 61 additions & 26 deletions opennmt/models/sequence_to_sequence.py
@@ -43,6 +43,28 @@ def shift_target_sequence(inputter, data):

  return data

def _maybe_reuse_embedding_fn(embedding_fn, scope=None):
  def _scoped_embedding_fn(ids):
    try:
      with tf.variable_scope(scope):
        return embedding_fn(ids)
    except ValueError:
      with tf.variable_scope(scope, reuse=True):
        return embedding_fn(ids)
  return _scoped_embedding_fn


class EmbeddingsSharingLevel(object):
  """Level of embeddings sharing.

  Possible values are:

   * ``NONE``: no sharing (default)
   * ``SOURCE_TARGET_INPUT``: share source and target word embeddings
  """
  NONE = 0
  SOURCE_TARGET_INPUT = 1


class SequenceToSequence(Model):
  """A sequence to sequence model."""
@@ -52,6 +74,7 @@ def __init__(self,
               target_inputter,
               encoder,
               decoder,
               share_embeddings=EmbeddingsSharingLevel.NONE,
               daisy_chain_variables=False,
               name="seq2seq"):
    """Initializes a sequence-to-sequence model.
@@ -64,13 +87,17 @@ def __init__(self,
        :class:`opennmt.inputters.text_inputter.WordEmbedder` is supported.
      encoder: A :class:`opennmt.encoders.encoder.Encoder` to encode the source.
      decoder: A :class:`opennmt.decoders.decoder.Decoder` to decode the target.
      share_embeddings: Level of embeddings sharing, see
        :class:`opennmt.models.sequence_to_sequence.EmbeddingsSharingLevel`
        for possible values.
      daisy_chain_variables: If ``True``, copy variables in a daisy chain
        between devices for this model. Not compatible with RNN based models.
      name: The name of this model.

    Raises:
      TypeError: if :obj:`target_inputter` is not a
        :class:`opennmt.inputters.text_inputter.WordEmbedder` or if
        :class:`opennmt.inputters.text_inputter.WordEmbedder` (same for
        :obj:`source_inputter` when embeddings sharing is enabled) or if
        :obj:`source_inputter` and :obj:`target_inputter` do not have the same
        ``dtype``.
    """
@@ -80,6 +107,10 @@ def __init__(self,
          "saw: {} and {}".format(source_inputter.dtype, target_inputter.dtype))
    if not isinstance(target_inputter, inputters.WordEmbedder):
      raise TypeError("Target inputter must be a WordEmbedder")
    if share_embeddings == EmbeddingsSharingLevel.SOURCE_TARGET_INPUT:
      if not isinstance(source_inputter, inputters.WordEmbedder):
        raise TypeError("Sharing embeddings requires both inputters to be a "
                        "WordEmbedder")

    super(SequenceToSequence, self).__init__(
        name,
@@ -89,39 +120,47 @@ def __init__(self,

    self.encoder = encoder
    self.decoder = decoder
    self.share_embeddings = share_embeddings
    self.source_inputter = source_inputter
    self.target_inputter = target_inputter
    self.target_inputter.add_process_hooks([shift_target_sequence])

  def _scoped_target_embedding_fn(self, mode, scope):
    def _target_embedding_fn(ids):
      try:
        with tf.variable_scope(scope):
          return self.target_inputter.transform(ids, mode=mode)
      except ValueError:
        with tf.variable_scope(scope, reuse=True):
          return self.target_inputter.transform(ids, mode=mode)
    return _target_embedding_fn
  def _get_input_scope(self, default_name=""):
    if self.share_embeddings == EmbeddingsSharingLevel.SOURCE_TARGET_INPUT:
      name = "shared_embeddings"
    else:
      name = default_name
    return tf.VariableScope(None, name=tf.get_variable_scope().name + "/" + name)

  def _build(self, features, labels, params, mode, config=None):
    features_length = self._get_features_length(features)
    log_dir = config.model_dir if config is not None else None

    source_input_scope = self._get_input_scope(default_name="encoder")
    target_input_scope = self._get_input_scope(default_name="decoder")

    source_inputs = _maybe_reuse_embedding_fn(
        lambda ids: self.source_inputter.transform_data(ids, mode=mode, log_dir=log_dir),
        scope=source_input_scope)(features)

    with tf.variable_scope("encoder"):
      source_inputs = self.source_inputter.transform_data(
          features,
          mode=mode,
          log_dir=log_dir)
      encoder_outputs, encoder_state, encoder_sequence_length = self.encoder.encode(
          source_inputs,
          sequence_length=features_length,
          mode=mode)

    target_vocab_size = self.target_inputter.vocabulary_size
    target_dtype = self.target_inputter.dtype
    target_embedding_fn = _maybe_reuse_embedding_fn(
        lambda ids: self.target_inputter.transform(ids, mode=mode),
        scope=target_input_scope)

    with tf.variable_scope("decoder") as decoder_scope:
      if labels is not None:
    if labels is not None:
      target_inputs = _maybe_reuse_embedding_fn(
          lambda ids: self.target_inputter.transform_data(ids, mode=mode, log_dir=log_dir),
          scope=target_input_scope)(labels)

      with tf.variable_scope("decoder"):
        sampling_probability = None
        if mode == tf.estimator.ModeKeys.TRAIN:
          sampling_probability = get_sampling_probability(
@@ -130,25 +169,21 @@ def _build(self, features, labels, params, mode, config=None):
              schedule_type=params.get("scheduled_sampling_type"),
              k=params.get("scheduled_sampling_k"))

        target_inputs = self.target_inputter.transform_data(
            labels,
            mode=mode,
            log_dir=log_dir)
        logits, _, _ = self.decoder.decode(
            target_inputs,
            self._get_labels_length(labels),
            vocab_size=target_vocab_size,
            initial_state=encoder_state,
            sampling_probability=sampling_probability,
            embedding=self._scoped_target_embedding_fn(mode, decoder_scope),
            embedding=target_embedding_fn,
            mode=mode,
            memory=encoder_outputs,
            memory_sequence_length=encoder_sequence_length)
      else:
        logits = None
    else:
      logits = None

    if mode != tf.estimator.ModeKeys.TRAIN:
      with tf.variable_scope(decoder_scope, reuse=labels is not None) as decoder_scope:
      with tf.variable_scope("decoder", reuse=labels is not None):
        batch_size = tf.shape(encoder_sequence_length)[0]
        beam_width = params.get("beam_width", 1)
        maximum_iterations = params.get("maximum_iterations", 250)
@@ -157,7 +192,7 @@

        if beam_width <= 1:
          sampled_ids, _, sampled_length, log_probs, alignment = self.decoder.dynamic_decode(
              self._scoped_target_embedding_fn(mode, decoder_scope),
              target_embedding_fn,
              start_tokens,
              end_token,
              vocab_size=target_vocab_size,
@@ -172,7 +207,7 @@
          length_penalty = params.get("length_penalty", 0)
          sampled_ids, _, sampled_length, log_probs, alignment = (
              self.decoder.dynamic_decode_and_search(
                  self._scoped_target_embedding_fn(mode, decoder_scope),
                  target_embedding_fn,
                  start_tokens,
                  end_token,
                  vocab_size=target_vocab_size,
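For reference, the scoping logic above follows a create-then-reuse pattern: the first call to the wrapped embedding function creates the variable inside the given scope, a later call hits a `ValueError` and is retried with `reuse=True`, so source and target lookups resolve to the same weights whenever both inputters are given the `shared_embeddings` scope returned by `_get_input_scope`. A standalone sketch of that pattern (TensorFlow 1.x; the variable name and shapes are illustrative):

```python
import tensorflow as tf

def embedding_fn(ids):
  # Arbitrary embedding variable for this sketch.
  embeddings = tf.get_variable("w_embs", shape=[100, 8])
  return tf.nn.embedding_lookup(embeddings, ids)

def maybe_reuse_embedding_fn(embedding_fn, scope=None):
  # Create the variable on the first call, reuse it on later calls.
  def scoped_embedding_fn(ids):
    try:
      with tf.variable_scope(scope):
        return embedding_fn(ids)
    except ValueError:
      with tf.variable_scope(scope, reuse=True):
        return embedding_fn(ids)
  return scoped_embedding_fn

shared_scope = tf.VariableScope(None, name="shared_embeddings")
lookup = maybe_reuse_embedding_fn(embedding_fn, scope=shared_scope)
source = lookup(tf.constant([1, 2, 3]))  # creates shared_embeddings/w_embs
target = lookup(tf.constant([4, 5, 6]))  # reuses shared_embeddings/w_embs
assert len(tf.trainable_variables()) == 1
```

With `EmbeddingsSharingLevel.NONE`, `_get_input_scope` instead returns distinct `encoder` and `decoder` scopes, so two separate embedding matrices are created as before.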
7 changes: 6 additions & 1 deletion opennmt/models/transformer.py
@@ -2,7 +2,7 @@

import tensorflow as tf

from opennmt.models.sequence_to_sequence import SequenceToSequence
from opennmt.models.sequence_to_sequence import SequenceToSequence, EmbeddingsSharingLevel
from opennmt.encoders.self_attention_encoder import SelfAttentionEncoder
from opennmt.decoders.self_attention_decoder import SelfAttentionDecoder
from opennmt.layers.position import SinusoidalPositionEncoder
@@ -25,6 +25,7 @@ def __init__(self,
               relu_dropout=0.1,
               position_encoder=SinusoidalPositionEncoder(),
               decoder_self_attention_type="scaled_dot",
               share_embeddings=EmbeddingsSharingLevel.NONE,
               name="transformer"):
    """Initializes a Transformer model.
@@ -46,6 +47,9 @@ def __init__(self,
        apply on the inputs.
      decoder_self_attention_type: Type of self attention in the decoder,
        "scaled_dot" or "average" (case insensitive).
      share_embeddings: Level of embeddings sharing, see
        :class:`opennmt.models.sequence_to_sequence.EmbeddingsSharingLevel`
        for possible values.
      name: The name of this model.
    """
    encoder = SelfAttentionEncoder(
@@ -73,6 +77,7 @@ def __init__(self,
        target_inputter,
        encoder,
        decoder,
        share_embeddings=share_embeddings,
        daisy_chain_variables=True,
        name=name)
