Skip to content

Commit

Permalink
ChoiceLayer: prefix decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-wilken committed Sep 14, 2021
1 parent c74613b commit e7bbfd2
Showing 1 changed file with 69 additions and 3 deletions.
72 changes: 69 additions & 3 deletions returnn/tf/layers/rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -4504,7 +4504,7 @@ def __init__(self, beam_size, keep_beams=False,
length_normalization=True,
length_normalization_exponent=1.0,
custom_score_combine=None,
source_beam_sizes=None, scheduled_sampling=False, cheating=False,
source_beam_sizes=None, scheduled_sampling=False, cheating=False, prefix_target=None,
explicit_search_sources=None,
**kwargs):
"""
Expand All @@ -4525,6 +4525,8 @@ def __init__(self, beam_size, keep_beams=False,
:param dict|None scheduled_sampling:
:param bool|str cheating: if True, will always add the true target in the beam.
if "exclusive", enables cheating_exclusive. see :func:`TFUtil.beam_search`.
:param str|None prefix_target: If given, this data stream will be enforced to be the prefix of the layer output,
i.e. for the first n positions, the beam choices will be overwritten by the labels from "prefix_target".
:param list[LayerBase]|None explicit_search_sources: will mark it as an additional dependency.
You might use these also in custom_score_combine.
:param callable|None custom_score_combine:
Expand All @@ -4546,6 +4548,7 @@ def __init__(self, beam_size, keep_beams=False,
self.search_scores_base = None
self.search_scores_combined = None
# We assume log-softmax here, inside the rec layer.
self.prefix_target = prefix_target

if self.search_flag:
if cheating:
Expand Down Expand Up @@ -4699,6 +4702,13 @@ def __init__(self, beam_size, keep_beams=False,
cheating_exclusive=cheating_exclusive)
self.search_choices.set_src_beams(src_beams) # (batch, beam) -> beam_in idx
labels = tf.reshape(labels, [net_batch_dim * beam_size]) # (batch * beam)

if self.prefix_target:
assert len(self.sources) == 1, "Prefix decoding not yet implemented for multiple sources."
labels, scores = self._enforce_prefixes(
top_k_labels=labels, all_scores=scores_comb, top_k_scores=scores, batch_dim=net_batch_dim,
beam_size=beam_size)

labels = tf.cast(labels, self.output.dtype)

if len(self.sources) > 1:
Expand Down Expand Up @@ -5008,6 +5018,59 @@ def _get_cheating_targets_and_src_beam_idxs(self, scores):
src_beams = src_beams[:, 0] # (batch,)
return cheating_gold_targets, src_beams

def _enforce_prefixes(self, top_k_labels, all_scores, top_k_scores, batch_dim, beam_size):
  """
  Overwrites the labels chosen by beam search with the labels given by the target prefixes, for all
  sequences where search is still at a position inside the prefix. The scores are replaced as well,
  so that they correspond to a prediction of the prefix.

  :param tf.Tensor top_k_labels: target labels from beam search, shape (batch * beam,)
  :param tf.Tensor all_scores: scores before beam pruning, used to look up prefix scores, shape (batch, beam, dim)
  :param tf.Tensor top_k_scores: scores after beam pruning, shape (batch, beam)
  :param tf.Tensor|int batch_dim: number of sequences in batch (without beam)
  :param int beam_size: outgoing beam size of this layer
  :return: labels (batch * beam,) and scores (batch, beam) taken from self.prefix_target while within the
    prefix, and top_k_labels / top_k_scores from beam search afterwards
  :rtype: (tf.Tensor, tf.Tensor)
  """
  assert self.prefix_target

  # The enforced prefix labels. They are zero-padded, so any position beyond the length of a prefix
  # reads as label 0.
  target_prefix_labels = self._get_target_value(
    target=self.prefix_target).get_placeholder_as_batch_major()  # (batch * beam,), int32

  # A zero label means the prefix is shorter than the current time step, i.e. it has already ended.
  prefix_is_over = tf.equal(target_prefix_labels, 0)

  # Keep the freely decoded labels once the prefix has ended, otherwise force the prefix label.
  labels = tf.where(prefix_is_over, top_k_labels, target_prefix_labels)

  # Within the prefix all entries of the incoming beam are identical, so the beam axis is redundant;
  # keep only its first entry.
  first_beam_prefix_labels = tf.reshape(target_prefix_labels, [batch_dim, beam_size])[:, 0]  # (batch,)

  # Same reasoning for the scores: restrict to the first beam entry before the lookup.
  first_beam_scores = all_scores[:, 0, :]  # (batch, dim)

  # Look up the scores of the enforced prefix labels.
  from returnn.tf.util.basic import nd_indices
  gather_indices = nd_indices(first_beam_prefix_labels)
  prefix_scores = tf.expand_dims(tf.gather_nd(first_beam_scores, gather_indices), axis=-1)  # (batch, 1)

  # Build an artificial beam where every entry but the first is (near) -inf. Simply tiling the single
  # score we have would fill the rest of the search with identical hypotheses.
  # Conceptually similar to TFUtil.filter_ended_scores().
  score_padding = tf.fill([batch_dim, beam_size - 1], -1.e30)
  prefix_scores = tf.concat([prefix_scores, score_padding], axis=1)

  # Use the prefix scores wherever the prefix has not ended yet.
  prefix_is_over = tf.reshape(prefix_is_over, [batch_dim, beam_size])
  scores = tf.where(prefix_is_over, top_k_scores, prefix_scores)  # (batch, beam)

  return labels, scores

@classmethod
def transform_config_dict(cls, d, network, get_layer):
"""
Expand Down Expand Up @@ -5063,8 +5126,8 @@ def _create_search_beam(cls, name, beam_size, sources, network):
name="%s%s" % (network.get_absolute_name_prefix(), name))

@classmethod
def get_out_data_from_opts(cls, name, sources, target, network,
beam_size, search=NotSpecified, scheduled_sampling=False, cheating=False, **kwargs):
def get_out_data_from_opts(cls, name, sources, target, network, beam_size, search=NotSpecified,
scheduled_sampling=False, cheating=False, prefix_target=None, **kwargs):
"""
:param str name:
:param list[LayerBase] sources:
Expand All @@ -5074,6 +5137,7 @@ def get_out_data_from_opts(cls, name, sources, target, network,
:param NotSpecified|bool search:
:param dict|bool scheduled_sampling:
:param bool cheating:
:param str prefix_target:
:rtype: Data
"""
search = NotSpecified.resolve(search, network.search_flag)
Expand All @@ -5099,6 +5163,8 @@ def get_out_data_from_opts(cls, name, sources, target, network,
out_data.batch = out_data.batch.copy_set_beam(out_data.beam)
if cheating or scheduled_sampling or not search:
cls._static_get_target_value(target=target, network=network, mark_data_key_as_used=True) # mark as used
if search and prefix_target:
cls._static_get_target_value(target=prefix_target, network=network, mark_data_key_as_used=True) # mark as used
return out_data

def get_sub_layer(self, layer_name):
Expand Down

0 comments on commit e7bbfd2

Please sign in to comment.