Do not rebuild input and output layers under reused scope (#372)
This iterates on ea72cb3 to better handle stateful layers.
guillaumekln authored Mar 6, 2019
1 parent 64f59f1 commit b9df179
Showing 4 changed files with 29 additions and 19 deletions.
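The gist of the change: stateful layers (input embeddings, shared output projections) are now constructed in a dedicated `_build` method, and `Model.__call__` only invokes it when the enclosing variable scope is not marked for reuse, so replicated calls share the already-built layers. Below is a minimal sketch of that pattern in TF 1.x terms, assuming `compat.reuse()` simply reports `tf.get_variable_scope().reuse`; class and helper details outside the diff are illustrative, not the repository's exact code.

```python
import tensorflow as tf


class Model(object):
  """Sketch of the build-once pattern introduced by this commit."""

  def __init__(self, name):
    self.name = name

  def __call__(self, features, labels, params, mode):
    with tf.variable_scope(self.name):
      # Build stateful layers only on the first (non-reused) call;
      # reused calls (e.g. other GPU towers) share the existing layers.
      if not tf.get_variable_scope().reuse:  # stand-in for compat.reuse()
        self._build()
      return self._call(features, labels, params, mode)

  def _build(self):
    """Builds stateful layers. Subclasses override this."""
    return

  def _call(self, features, labels, params, mode):
    raise NotImplementedError()
```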
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -16,6 +16,8 @@ OpenNMT-tf follows [semantic versioning 2.0.0](https://semver.org/). The API cov

### Fixes and improvements

* Fix multi GPU training: some variables were not correctly reused when building the graph for other devices

## [1.21.2](https://github.com/OpenNMT/OpenNMT-tf/releases/tag/v1.21.2) (2019-03-05)

### Fixes and improvements
10 changes: 5 additions & 5 deletions opennmt/models/language_model.py
@@ -48,11 +48,7 @@ def auto_config(self, num_devices=1):
}
})

def _call(self, features, labels, params, mode):
training = mode == tf.estimator.ModeKeys.TRAIN
outputs, predictions = None, None

# Initialize input and output layers.
def _build(self):
self.examples_inputter.build()
vocab_size = self.examples_inputter.vocabulary_size
output_layer = None
@@ -64,6 +60,10 @@ def _call(self, features, labels, params, mode):
dtype=self.examples_inputter.dtype)
self.decoder.initialize(vocab_size=vocab_size, output_layer=output_layer)

def _call(self, features, labels, params, mode):
training = mode == tf.estimator.ModeKeys.TRAIN
outputs, predictions = None, None

ids, length = features["ids"], features["length"]
if mode != tf.estimator.ModeKeys.PREDICT:
# For training and evaluation, forward the full sequence.
7 changes: 7 additions & 0 deletions opennmt/models/model.py
@@ -9,6 +9,7 @@

from opennmt import estimator
from opennmt import inputters
from opennmt.utils import compat
from opennmt.utils.optim import optimize_loss


@@ -82,6 +83,8 @@ def __call__(self, features, labels, params, mode, config=None): # pylint: disa
the arguments of this function.
"""
with tf.variable_scope(self.name, initializer=self._initializer(params)):
if not compat.reuse():
self._build() # Always rebuild unless the scope is marked for reuse.
return self._call(features, labels, params, mode)

def _initializer(self, params):
@@ -99,6 +102,10 @@ def _initializer(self, params):
minval=-param_init, maxval=param_init, dtype=self.dtype)
return None

def _build(self):
"""Builds stateful layers."""
return

@abc.abstractmethod
def _call(self, features, labels, params, mode):
"""Creates the graph.
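For context on why the reuse guard matters: with in-graph replication the same model is typically called once per device under one variable scope, with reuse turned on after the first tower. If `_build` ran again for the reused towers, layers such as the shared `Dense` output projection would be re-instantiated and create new variables instead of picking up the existing ones. A rough, hypothetical replication loop (not OpenNMT-tf's actual training code) showing where the reused calls come from:

```python
import tensorflow as tf


def build_towers(model, per_device_features, per_device_labels, params, mode):
  """Hypothetical in-graph replication: the same model is called once per device."""
  tower_outputs = []
  for i, (features, labels) in enumerate(zip(per_device_features, per_device_labels)):
    with tf.device("/gpu:%d" % i):
      # The first call builds the stateful layers; later calls see the scope
      # marked for reuse and skip _build(), sharing the existing variables.
      outputs, _ = model(features, labels, params, mode)
      tower_outputs.append(outputs)
    # Mark the current scope for reuse so subsequent towers share variables.
    tf.get_variable_scope().reuse_variables()
  return tower_outputs
```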
29 changes: 15 additions & 14 deletions opennmt/models/sequence_to_sequence.py
@@ -135,6 +135,7 @@ def __init__(self,
self.encoder = encoder
self.decoder = decoder
self.share_embeddings = share_embeddings
self.output_layer = None

def auto_config(self, num_devices=1):
config = super(SequenceToSequence, self).auto_config(num_devices=num_devices)
@@ -153,9 +154,19 @@ def auto_config(self, num_devices=1):
}
})

def _build(self):
self.examples_inputter.build()
if EmbeddingsSharingLevel.share_target_embeddings(self.share_embeddings):
self.output_layer = layers.Dense(
self.labels_inputter.vocabulary_size,
weight=self.labels_inputter.embedding,
transpose=True,
dtype=self.labels_inputter.dtype)
with tf.name_scope(tf.get_variable_scope().name + "/"):
self.output_layer.build([None, self.decoder.output_size])

def _call(self, features, labels, params, mode):
training = mode == tf.estimator.ModeKeys.TRAIN
self.examples_inputter.build()

features_length = self.features_inputter.get_length(features)
source_inputs = self.features_inputter.make_inputs(features, training=training)
@@ -167,16 +178,6 @@

target_vocab_size = self.labels_inputter.vocabulary_size
target_dtype = self.labels_inputter.dtype
output_layer = None
if EmbeddingsSharingLevel.share_target_embeddings(self.share_embeddings):
output_layer = layers.Dense(
target_vocab_size,
weight=self.labels_inputter.embedding,
transpose=True,
dtype=target_dtype)
with tf.name_scope(tf.get_variable_scope().name + "/"):
output_layer.build([None, self.decoder.output_size])

if labels is not None:
target_inputs = self.labels_inputter.make_inputs(labels, training=training)
with tf.variable_scope("decoder"):
@@ -195,7 +196,7 @@
initial_state=encoder_state,
sampling_probability=sampling_probability,
embedding=self.labels_inputter.embedding,
output_layer=output_layer,
output_layer=self.output_layer,
mode=mode,
memory=encoder_outputs,
memory_sequence_length=encoder_sequence_length,
@@ -228,7 +229,7 @@ def _call(self, features, labels, params, mode):
end_token,
vocab_size=target_vocab_size,
initial_state=encoder_state,
output_layer=output_layer,
output_layer=self.output_layer,
maximum_iterations=maximum_iterations,
minimum_length=minimum_length,
mode=mode,
@@ -247,7 +248,7 @@
end_token,
vocab_size=target_vocab_size,
initial_state=encoder_state,
output_layer=output_layer,
output_layer=self.output_layer,
beam_width=beam_width,
length_penalty=length_penalty,
maximum_iterations=maximum_iterations,
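A small idiom worth calling out in the moved `_build` code: `with tf.name_scope(tf.get_variable_scope().name + "/")`. In TF 1.x, re-entering a variable scope by name uniquifies the associated name scope (e.g. "model_1/"), while passing a name that ends with "/" to `tf.name_scope` re-enters that exact, absolute name scope. Pinning the name scope before manually calling `output_layer.build(...)` keeps the eagerly built projection's naming aligned with the model's scope. A tiny standalone illustration of the behavior, with hypothetical scope names:

```python
import tensorflow as tf

with tf.variable_scope("model"):
  pass

# Re-entering the scope uniquifies the name scope to "model_1/"...
with tf.variable_scope("model", reuse=True):
  a = tf.identity(1.0, name="op_a")  # op created under model_1/
  # ...unless the name scope is pinned back with a trailing "/".
  with tf.name_scope(tf.get_variable_scope().name + "/"):
    b = tf.identity(1.0, name="op_b")  # op created under model/

print(a.name)  # model_1/op_a:0
print(b.name)  # model/op_b:0
```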
