From 760eaa66ddfecc196bf5a765a34eddda2e41c236 Mon Sep 17 00:00:00 2001
From: Guillaume Klein
Date: Mon, 11 Mar 2019 12:34:56 +0100
Subject: [PATCH] Explicit reshape dimension in case the alignment vector is empty (#383)

---
 CHANGELOG.md                           | 1 +
 opennmt/models/sequence_to_sequence.py | 6 ++++--
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 62c4008b6..81433cb64 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -18,6 +18,7 @@ OpenNMT-tf follows [semantic versioning 2.0.0](https://semver.org/). The API cov
 
 * Fix compatibility issue with legacy TensorFlow 1.4
 * Fix inference of language models
+* Fix inference error when using `replace_unknown_target` and the alignment vector was empty
 
 ## [1.21.4](https://github.com/OpenNMT/OpenNMT-tf/releases/tag/v1.21.4) (2019-03-07)
 
diff --git a/opennmt/models/sequence_to_sequence.py b/opennmt/models/sequence_to_sequence.py
index 6f0a11305..bc4f37a05 100644
--- a/opennmt/models/sequence_to_sequence.py
+++ b/opennmt/models/sequence_to_sequence.py
@@ -10,7 +10,7 @@
 from opennmt.models.model import Model
 from opennmt.utils import compat
 from opennmt.utils.losses import cross_entropy_sequence_loss
-from opennmt.utils.misc import print_bytes, format_translation_output, merge_dict
+from opennmt.utils.misc import print_bytes, format_translation_output, merge_dict, shape_list
 from opennmt.decoders.decoder import get_sampling_probability
 
 
@@ -277,7 +277,9 @@ def _call(self, features, labels, params, mode):
         # Merge batch and beam dimensions.
         original_shape = tf.shape(target_tokens)
         target_tokens = tf.reshape(target_tokens, [-1, original_shape[-1]])
-        attention = tf.reshape(alignment, [-1, tf.shape(alignment)[2], tf.shape(alignment)[3]])
+        align_shape = shape_list(alignment)
+        attention = tf.reshape(
+            alignment, [align_shape[0] * align_shape[1], align_shape[2], align_shape[3]])
         # We don't have attention for </s> but ensure that the attention time dimension matches
         # the tokens time dimension.
         attention = reducer.align_in_time(attention, tf.shape(target_tokens)[1])
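
For reference, a minimal sketch of the failure mode this patch addresses, assuming TensorFlow eager execution; the example shape and the use of plain tf.shape indexing in place of OpenNMT-tf's shape_list helper are illustrative only:

import tensorflow as tf

# Hypothetical empty alignment history: batch=2, beam=4, zero decoded and source steps.
alignment = tf.zeros([2, 4, 0, 0])

# The previous code merged batch and beam with -1. With an empty tensor (0 elements)
# and a 0 in the remaining target dimensions, the -1 cannot be inferred, so
# tf.reshape raises an InvalidArgumentError:
#   tf.reshape(alignment, [-1, tf.shape(alignment)[2], tf.shape(alignment)[3]])

# Computing the merged batch*beam dimension explicitly keeps the target shape fully
# determined (here [8, 0, 0]) even when the tensor is empty.
align_shape = tf.shape(alignment)
attention = tf.reshape(
    alignment, [align_shape[0] * align_shape[1], align_shape[2], align_shape[3]])

print(attention.shape)  # (8, 0, 0)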