Mark variable scope for reuse in LM inference (#382)
Some stateful layers like PositionEmbedder recreate their variables
unless the scope is marked for reuse.
guillaumekln authored Mar 11, 2019
1 parent 6a34f95 commit 1020e97
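
The reuse rule mentioned in the commit message can be sketched with a toy variable store. This is a simplified model written for illustration, not TensorFlow or OpenNMT-tf code: `VariableStore`, `reuse_scope`, and `position_embedding` are hypothetical names, and the lookup rule only approximates how `tf.get_variable` treats the `reuse` flag in TensorFlow 1.x.

```python
from contextlib import contextmanager


class VariableStore:
    """Toy stand-in for a graph's variable store (hypothetical, not TensorFlow)."""

    def __init__(self):
        self._variables = {}
        self.reuse = False

    def get_variable(self, name, initializer):
        if name in self._variables:
            if not self.reuse:
                # Without reuse, asking again for an existing name is an error.
                raise ValueError("Variable %s already exists" % name)
            return self._variables[name]
        if self.reuse:
            # With reuse, only existing variables may be returned.
            raise ValueError("Variable %s does not exist" % name)
        self._variables[name] = initializer()
        return self._variables[name]


@contextmanager
def reuse_scope(store):
    """Marks the store for reuse inside the block, then restores the old flag."""
    previous = store.reuse
    store.reuse = True
    try:
        yield store
    finally:
        store.reuse = previous


def position_embedding(store):
    # Hypothetical stateful layer that requests its weights on every call.
    return store.get_variable("position_embedding/w", lambda: [0.0] * 4)


store = VariableStore()
first = position_embedding(store)       # First call creates the variable.
with reuse_scope(store):
    second = position_embedding(store)  # Later calls get the same object back.
```

In this model, the second call succeeds only because it runs inside `reuse_scope`, which mirrors the role of wrapping the decoding loop in a scope opened with `reuse=True`.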
Showing 2 changed files with 12 additions and 10 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -17,6 +17,7 @@ OpenNMT-tf follows [semantic versioning 2.0.0](https://semver.org/). The API cov
 ### Fixes and improvements

 * Fix compatibility issue with legacy TensorFlow 1.4
+* Fix inference of language models

 ## [1.21.4](https://github.com/OpenNMT/OpenNMT-tf/releases/tag/v1.21.4) (2019-03-07)

21 changes: 11 additions & 10 deletions opennmt/models/language_model.py
@@ -87,16 +87,17 @@ def _call(self, features, labels, params, mode):
     name=self.name + "/") # Force the name scope.

 # Iteratively decode from the last decoder state.
-sampled_ids, sampled_length, _ = decoder_util.greedy_decode(
-    self._decode,
-    tf.squeeze(start_ids, 1),
-    constants.END_OF_SENTENCE_ID,
-    decode_length=params.get("maximum_iterations", 250),
-    state=state,
-    min_decode_length=params.get("minimum_decoding_length", 0),
-    last_step_as_input=True,
-    sample_from=params.get("sampling_topk", 1),
-    sample_temperature=params.get("sampling_temperature", 1))
+with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+  sampled_ids, sampled_length, _ = decoder_util.greedy_decode(
+      self._decode,
+      tf.squeeze(start_ids, 1),
+      constants.END_OF_SENTENCE_ID,
+      decode_length=params.get("maximum_iterations", 250),
+      state=state,
+      min_decode_length=params.get("minimum_decoding_length", 0),
+      last_step_as_input=True,
+      sample_from=params.get("sampling_topk", 1),
+      sample_temperature=params.get("sampling_temperature", 1))

 # Build the full prediction.
 full_ids = tf.concat([ids, sampled_ids], 1)