Commit abb2c13

update beam search
gante committed Jun 21, 2022
1 parent 81c47f8 commit abb2c13
Showing 1 changed file with 45 additions and 124 deletions.
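The gist of the change: the XLA-unfriendly `tf.TensorArray` buffers used by beam search are replaced with fixed-shape tensors that are pre-padded to `max_length` up front and written into with `tf.tensor_scatter_nd_update`. A minimal, standalone sketch of that pattern, mirroring the greedy/sample hunks below (toy values; the variable names follow the diff, everything else is illustrative):

```python
import tensorflow as tf

batch_size, max_length, pad_token_id = 2, 8, 0
input_ids = tf.constant([[5, 6], [7, 8]], dtype=tf.int32)  # (batch_size, prompt_len)
cur_len = input_ids.shape[1]

# pre-pad to the final shape once, instead of growing a buffer step by step
input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * pad_token_id
generated = tf.concat([input_ids, input_ids_padding], axis=-1)  # (batch_size, max_length)

# write one new token per batch item at position `cur_len`
next_tokens = tf.constant([11, 12], dtype=tf.int32)
update_indices = tf.stack([tf.range(batch_size), tf.broadcast_to(cur_len, [batch_size])], axis=-1)
generated = tf.tensor_scatter_nd_update(tensor=generated, indices=update_indices, updates=next_tokens)
print(generated.numpy())  # [[ 5  6 11  0  0  0  0  0] [ 7  8 12  0  0  0  0  0]]
```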
169 changes: 45 additions & 124 deletions src/transformers/generation_tf_utils.py
@@ -16,7 +16,6 @@

import inspect
from dataclasses import dataclass
-from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
@@ -1992,8 +1991,8 @@ def greedy_search(
batch_size, cur_len = input_ids.shape

# initialize `generated` (`input_ids` padded with `pad_token_id`), `finished_sequences`
-generated_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * pad_token_id
-generated = tf.concat([input_ids, generated_padding], axis=1)
+input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * pad_token_id
+generated = tf.concat([input_ids, input_ids_padding], axis=-1)
finished_sequences = tf.zeros((batch_size,), dtype=tf.bool)

# 4. define "xla-compile-able" stop-condition and auto-regressive function
@@ -2049,7 +2048,7 @@ def greedy_search_body_fn(generated, finished_sequences, cur_len, model_kwargs):
finished_sequences = finished_sequences | (next_tokens == eos_token_id)

# update `generated` and `cur_len`
-update_indices = tf.stack([tf.range(batch_size), tf.broadcast_to(cur_len, (batch_size,))], axis=-1)
+update_indices = tf.stack([tf.range(batch_size), tf.broadcast_to(cur_len, [batch_size])], axis=-1)
generated = tf.tensor_scatter_nd_update(tensor=generated, indices=update_indices, updates=next_tokens)
cur_len += 1

@@ -2250,8 +2249,8 @@ def sample(
batch_size, cur_len = input_ids.shape

# initialize `generated` (pre-populated with `pad_token_id`), `finished_sequences`
-generated_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * pad_token_id
-generated = tf.concat([input_ids, generated_padding], axis=1)
+input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * pad_token_id
+generated = tf.concat([input_ids, input_ids_padding], axis=-1)
finished_sequences = tf.zeros((batch_size,), dtype=tf.bool)

# 4. define "xla-compile-able" stop-condition and auto-regressive function
@@ -2313,7 +2312,7 @@ def sample_body_fn(generated, finished_sequences, cur_len, model_kwargs):
finished_sequences = finished_sequences | (next_tokens == eos_token_id)

# update `generated` and `cur_len`
-update_indices = tf.stack([tf.range(batch_size), tf.broadcast_to(cur_len, (batch_size,))], axis=-1)
+update_indices = tf.stack([tf.range(batch_size), tf.broadcast_to(cur_len, [batch_size])], axis=-1)
generated = tf.tensor_scatter_nd_update(tensor=generated, indices=update_indices, updates=next_tokens)
cur_len += 1

@@ -2559,6 +2558,8 @@ def gather_fn(tensor):
# GPT2 and other models have a slightly different cache structure, with a different batch axis
model_name = str(self.decoder) if "EncoderDecoder" in str(self) else str(self)
cache_batch_axis = 1 if any([model_prefix in model_name for model_prefix in ("TFGPT2", "TFCTRL")]) else 0
+# some models, like XLNet, need more than the last token in the presence of past
+needs_full_input = "use_mems" in set(inspect.signature(self.prepare_inputs_for_generation).parameters.keys())

# 2. init `attentions`, `hidden_states`, and `scores` tuples
scores = [] if (return_dict_in_generate and output_scores) else None
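The `needs_full_input` flag added above keys off the model's `prepare_inputs_for_generation` signature. A small illustration of how that detection behaves (the two stub functions are hypothetical stand-ins, not transformers code):

```python
import inspect

def prepare_inputs_gpt2_style(input_ids, past=None, **kwargs):  # no `use_mems`
    ...

def prepare_inputs_xlnet_style(input_ids, past=None, use_mems=None, **kwargs):
    ...

def needs_full_input(prepare_fn):
    return "use_mems" in set(inspect.signature(prepare_fn).parameters.keys())

print(needs_full_input(prepare_inputs_gpt2_style))   # False -> can feed only the last token
print(needs_full_input(prepare_inputs_xlnet_style))  # True  -> keep feeding the full prefix
```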
@@ -2568,41 +2569,11 @@ def gather_fn(tensor):

# 3. init tensors to use for "xla-compileable" generate function
batch_size, num_beams, cur_len = input_ids.shape
-input_ids_length = cur_len

# per batch, beam-item holding current token in loop, pre-populated with `pad_token_id`
-sequences = tf.TensorArray(
-    element_shape=(batch_size, num_beams),
-    dtype=tf.int32,
-    dynamic_size=False,
-    size=max_length,
-    clear_after_read=False,
-)
-running_sequences = tf.TensorArray(
-    element_shape=(batch_size, num_beams),
-    dtype=tf.int32,
-    dynamic_size=False,
-    size=max_length,
-    clear_after_read=False,
-)
-intermediary_running_sequences = tf.TensorArray(
-    element_shape=(batch_size, num_beams * 2),
-    dtype=tf.int32,
-    dynamic_size=False,
-    size=max_length,
-    clear_after_read=False,
-)
-if pad_token_id:  # ignores the cases when it is 0 or None
-    for i in range(max_length):
-        sequences = sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams)))
-        running_sequences = running_sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams)))
-        intermediary_running_sequences = intermediary_running_sequences.write(
-            i, tf.broadcast_to(pad_token_id, (batch_size, num_beams * 2))
-        )
-
-# write prompt to running_sequences
-for i in range(cur_len):
-    running_sequences = running_sequences.write(i, input_ids[:, :, i])
+input_ids_padding = tf.ones((batch_size, num_beams, max_length - cur_len), dtype=tf.int32) * pad_token_id
+running_sequences = tf.concat([input_ids, input_ids_padding], axis=-1)
+sequences = tf.ones((batch_size, num_beams, max_length), dtype=tf.int32) * pad_token_id

# per batch, beam-item state bit indicating if sentence has finished.
is_sent_finished = tf.zeros((batch_size, num_beams), dtype=tf.bool)
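Same pattern as in the greedy/sample hunks, now with a beam axis: both buffers are allocated at their final (batch_size, num_beams, max_length) shape before the loop. A toy check of the shapes (values illustrative):

```python
import tensorflow as tf

batch_size, num_beams, max_length, pad_token_id = 2, 3, 6, 0
input_ids = tf.fill((batch_size, num_beams, 2), 9)  # prompt already tiled across beams
cur_len = input_ids.shape[2]

input_ids_padding = tf.ones((batch_size, num_beams, max_length - cur_len), dtype=tf.int32) * pad_token_id
running_sequences = tf.concat([input_ids, input_ids_padding], axis=-1)  # prompt followed by pads
sequences = tf.ones((batch_size, num_beams, max_length), dtype=tf.int32) * pad_token_id  # all pads

print(running_sequences.shape, sequences.shape)  # (2, 3, 6) (2, 3, 6)
```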
@@ -2630,7 +2601,6 @@ def beam_search_cond_fn(
sequences,
scores,
is_sent_finished,
-input_ids_length,
model_kwargs,
):
"""
@@ -2659,27 +2629,18 @@ def beam_search_body_fn(
sequences,
scores,
is_sent_finished,
-input_ids_length,
model_kwargs,
-intermediary_running_sequences=None,
):
"""
Beam Search iterative update function -- each iteration adds a new token and updates the best sequences
seen so far
"""
-# TODO (joao): this loop is probably faster with gather/scatters, instead of using `tf.TensorArray`.
-# Alternatively, attempt to rewrite function with permuted axis, when enabling XLA.

# 1. Forward current tokens

-# TF places the dynamic dimension (seq_len) in the first axis, we want it in the last
-running_sequences_seq_last = tf.transpose(running_sequences.stack(), perm=[1, 2, 0])
-input_token = tf.slice(
-    running_sequences_seq_last,
-    (0, 0, cur_len - input_ids_length),
-    (batch_size, num_beams, input_ids_length),
-)
-model_inputs = self.prepare_inputs_for_generation(flatten_beam_dim(input_token), **model_kwargs)
+if model_kwargs.get("past") is None or needs_full_input:
+    input_ids = running_sequences[:, :, :cur_len]
+else:
+    input_ids = tf.expand_dims(running_sequences[:, :, cur_len - 1], -1)
+model_inputs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), **model_kwargs)
model_outputs = self(
**model_inputs,
return_dict=True,
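The branch added in the hunk above replaces the old `tf.slice` bookkeeping: without a cache (or for models flagged by `needs_full_input`) the model sees the whole prefix, otherwise only the newest token. A toy version of the two slices (tensor contents illustrative):

```python
import tensorflow as tf

running_sequences = tf.constant([[[5, 6, 7, 0, 0, 0],
                                  [5, 6, 8, 0, 0, 0]]], dtype=tf.int32)  # (1 batch, 2 beams, 6)
cur_len = 3

full_input = running_sequences[:, :, :cur_len]                          # (1, 2, 3), no cache
last_token = tf.expand_dims(running_sequences[:, :, cur_len - 1], -1)   # (1, 2, 1), with cache
print(full_input.shape, last_token.shape)
```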
@@ -2708,9 +2669,7 @@ def beam_search_body_fn(
# get log probabilities from logits, process logits with processors (*e.g.* min_length, ...), and
# add new logprobs to existing running logprobs scores.
log_probs = tf.nn.log_softmax(logits)
-log_probs = logits_processor(
-    flatten_beam_dim(running_sequences_seq_last), flatten_beam_dim(log_probs), cur_len
-)
+log_probs = logits_processor(flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), cur_len)
log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
log_probs = log_probs + tf.expand_dims(running_scores, axis=2)
vocab_size = log_probs.shape[2]
@@ -2729,23 +2688,28 @@ def beam_search_body_fn(
beams_to_keep = 2 * num_beams
topk_log_probs, topk_indices = tf.math.top_k(log_probs, k=beams_to_keep)
topk_beam_indices = topk_indices // vocab_size
-topk_running_sequences_seq_last = gather_beams(running_sequences_seq_last, topk_beam_indices)
+topk_running_sequences = gather_beams(running_sequences, topk_beam_indices)
topk_ids = topk_indices % vocab_size

# writes the new token
-intermediary_running_sequences = intermediary_running_sequences.unstack(
-    tf.transpose(topk_running_sequences_seq_last, perm=[2, 0, 1])
-)
-topk_sequences = intermediary_running_sequences.write(cur_len, topk_ids)
-topk_sequences_seq_last = tf.transpose(topk_sequences.stack(), perm=[1, 2, 0])
+indices_batch = tf.repeat(tf.range(batch_size), [beams_to_keep])
+indices_beam = tf.tile(tf.range(beams_to_keep), [batch_size])
+update_indices = tf.stack(
+    [indices_batch, indices_beam, tf.broadcast_to(cur_len, [batch_size * beams_to_keep])], axis=-1
+)
+topk_sequences = tf.tensor_scatter_nd_update(
+    tensor=topk_running_sequences,
+    indices=update_indices,
+    updates=tf.reshape(topk_ids, [batch_size * beams_to_keep]),
+)

# 4. Check which sequences have ended
# Update current sequences: Did the top `num_beams` sequences reach an end marker?
# To prevent these just-finished sequences from being added to the set of active beam
# search sequences, set their log probs to a very large negative value.
-eos_in_next_token = topk_sequences_seq_last[:, :, cur_len] == eos_token_id
+eos_in_next_token = topk_sequences[:, :, cur_len] == eos_token_id
if eos_token_id is None:
-    eos_in_next_token = tf.broadcast_to(eos_in_next_token, topk_sequences_seq_last[:, :, cur_len].shape)
+    eos_in_next_token = tf.broadcast_to(eos_in_next_token, topk_sequences[:, :, cur_len].shape)
did_topk_just_finished = eos_in_next_token & tf.broadcast_to(
tf.concat((tf.ones((num_beams), dtype=tf.bool), tf.zeros((num_beams), dtype=tf.bool)), axis=0),
eos_in_next_token.shape,
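The scatter in the hunk above needs one (batch, beam, position) coordinate per candidate beam; `tf.repeat`/`tf.tile` build the batch and beam columns. Worked out for batch_size=2 and beams_to_keep=4 (toy numbers):

```python
import tensorflow as tf

batch_size, beams_to_keep, cur_len = 2, 4, 3
indices_batch = tf.repeat(tf.range(batch_size), [beams_to_keep])  # [0 0 0 0 1 1 1 1]
indices_beam = tf.tile(tf.range(beams_to_keep), [batch_size])     # [0 1 2 3 0 1 2 3]
update_indices = tf.stack(
    [indices_batch, indices_beam, tf.broadcast_to(cur_len, [batch_size * beams_to_keep])], axis=-1
)
print(update_indices.numpy())  # eight rows, each a (batch, beam, cur_len) coordinate
```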
@@ -2759,8 +2723,8 @@ def beam_search_body_fn(
# Determine the top k beam indices (from top 2*k beams) from log probs and gather top k beams
# (from top 2*k beams).
next_topk_indices = tf.math.top_k(running_topk_log_probs, k=num_beams)[1]
-next_running_sequences_seq_last, next_running_scores = gather_beams(
-    [topk_sequences_seq_last, running_topk_log_probs], next_topk_indices
+next_running_sequences, next_running_scores = gather_beams(
+    [topk_sequences, running_topk_log_probs], next_topk_indices
)

# 6. Process topk logits
Expand All @@ -2781,18 +2745,18 @@ def beam_search_body_fn(
# 7. Get scores, sequences, is sentence finished for next.
# Combine sequences, scores, and flags along the beam dimension, compare new finished sequence scores
# to existing finished scores, and select the best from the new set of beams
-sequences_seq_last = tf.transpose(sequences.stack(), perm=[1, 2, 0])
-merged_sequences = tf.concat([sequences_seq_last, topk_sequences_seq_last], axis=1)
+merged_sequences = tf.concat([sequences, topk_sequences], axis=1)
merged_scores = tf.concat([scores, topk_log_probs], axis=1)
merged_is_sent_finished = tf.concat([is_sent_finished, did_topk_just_finished], axis=1)
topk_merged_indices = tf.math.top_k(merged_scores, k=num_beams)[1]
-next_sequences_seq_last, next_scores, next_is_sent_finished = gather_beams(
+next_sequences, next_scores, next_is_sent_finished = gather_beams(
[merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices
)

# 8. Prepare data for the next iteration
# Determine the top k beam indices from the original set of all beams. With these, gather the top k
# beam-associated caches.
+cur_len = cur_len + 1
if "past_key_values" in model_outputs:
cache = tf.nest.map_structure(
lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams, batch_axis=cache_batch_axis),
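`gather_beams` (defined elsewhere in this file, not shown in the diff) performs the per-batch beam selection used throughout. An equivalent effect can be sketched with `tf.gather` and `batch_dims=1`, applied to the merge-and-pick step above (toy data, not the file's actual helper):

```python
import tensorflow as tf

batch_size, num_merged_beams, seq_len = 2, 4, 5
merged_scores = tf.constant([[0.1, 0.9, 0.5, 0.7],
                             [0.8, 0.2, 0.6, 0.4]])
merged_sequences = tf.reshape(
    tf.range(batch_size * num_merged_beams * seq_len), (batch_size, num_merged_beams, seq_len)
)

topk_merged_indices = tf.math.top_k(merged_scores, k=2)[1]  # best beam ids per batch item
next_sequences = tf.gather(merged_sequences, topk_merged_indices, batch_dims=1)
print(topk_merged_indices.numpy())  # [[1 3] [0 2]]
print(next_sequences.shape)         # (2, 2, 5)
```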
@@ -2815,35 +2779,20 @@ def beam_search_body_fn(

# if we don't cache past key values we need the whole input
if model_kwargs.get("past", None) is None:
-next_input_ids_length = cur_len + 1
# let's throw out `past` since we don't want `None` tensors
model_kwargs.pop("past", None)
-else:
-    next_input_ids_length = 1

-# 9. Prepare the `tf.TensorArray` for the next iteration
-next_sequences = sequences.unstack(tf.transpose(next_sequences_seq_last, perm=[2, 0, 1]))
-next_running_sequences = running_sequences.unstack(
-    tf.transpose(next_running_sequences_seq_last, perm=[2, 0, 1])
-)

return (
-cur_len + 1,
+cur_len,
next_running_sequences,
next_running_scores,
next_sequences,
next_scores,
next_is_sent_finished,
-next_input_ids_length,
next_model_kwargs,
)

# 5. run generation
-# Adds the `intermediary_running_sequences` TensorArray into the body, needed as a scratchpad
-beam_search_body_fn = partial(
-    beam_search_body_fn, intermediary_running_sequences=intermediary_running_sequences
-)

# 1st generation step has to be run beforehand to initialize `past` (if active)
(
cur_len,
@@ -2852,66 +2801,38 @@ def beam_search_body_fn(
sequences,
scores,
is_sent_finished,
-input_ids_length,
model_kwargs,
) = beam_search_body_fn(
-cur_len,
-running_sequences,
-running_scores,
-sequences,
-scores,
-is_sent_finished,
-input_ids_length,
-model_kwargs,
+cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs
)

# 2-to-n generation steps can then be run in autoregressive fashion (only if the 1st generation step
# does NOT yield an EOS token, though)
if beam_search_cond_fn(
-cur_len,
-running_sequences,
-running_scores,
-sequences,
-scores,
-is_sent_finished,
-input_ids_length,
-model_kwargs,
+cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs
):
maximum_iterations = max_length - cur_len
-cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, _, _ = tf.while_loop(
+cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, _ = tf.while_loop(
beam_search_cond_fn,
beam_search_body_fn,
-(
-    cur_len,
-    running_sequences,
-    running_scores,
-    sequences,
-    scores,
-    is_sent_finished,
-    input_ids_length,
-    model_kwargs,
-),
+(cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs),
maximum_iterations=maximum_iterations,
)

# 6. prepare outputs
-# convert the sequences to tf.Tensor with shape (batch_size, num_beams, seq_len)
-sequences_seq_last = tf.transpose(sequences.stack(), perm=[1, 2, 0])
-running_sequences_seq_last = tf.transpose(running_sequences.stack(), perm=[1, 2, 0])

# Account for the edge-case where there are no finished sequences for a particular batch item. If so, return
# running sequences for that batch item.
none_finished = tf.math.reduce_any(is_sent_finished, axis=1)
-sequences_seq_last = tf.where(none_finished[:, None, None], sequences_seq_last, running_sequences_seq_last)
+sequences = tf.where(none_finished[:, None, None], sequences, running_sequences)
scores = tf.where(none_finished[:, None], scores, running_scores)

# Take best beams for each batch (the score is sorted in ascending order)
-sequences_seq_last = flatten_beam_dim(sequences_seq_last[:, :num_return_sequences, :])
+sequences = flatten_beam_dim(sequences[:, :num_return_sequences, :])
scores = flatten_beam_dim(scores[:, :num_return_sequences])

if not use_xla:
# Cut for backward compatibility
-sequences_seq_last = sequences_seq_last[:, :cur_len]
+sequences = sequences[:, :cur_len]

if return_dict_in_generate:
if self.config.is_encoder_decoder:
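The `tf.where` fallback in the hunk above broadcasts one boolean per batch item over the (beam, length) axes: batch items with at least one finished beam keep `sequences`, the rest fall back to `running_sequences`. A toy run of that selection (values illustrative):

```python
import tensorflow as tf

is_sent_finished = tf.constant([[True, False], [False, False]])  # (batch, beams)
none_finished = tf.math.reduce_any(is_sent_finished, axis=1)     # [True, False]

sequences = tf.fill((2, 2, 4), 1)          # buffer of finished beams
running_sequences = tf.fill((2, 2, 4), 2)  # still-running beams

# batch item 0 keeps its finished beams; batch item 1 falls back to running beams
out = tf.where(none_finished[:, None, None], sequences, running_sequences)
print(out[:, 0, 0].numpy())  # [1 2]
```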
@@ -2922,7 +2843,7 @@ def beam_search_body_fn(
)

return TFBeamSearchEncoderDecoderOutput(
-sequences=sequences_seq_last,
+sequences=sequences,
scores=scores,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
Expand All @@ -2932,13 +2853,13 @@ def beam_search_body_fn(
)
else:
return TFBeamSearchDecoderOnlyOutput(
-sequences=sequences_seq_last,
+sequences=sequences,
scores=scores,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
)
else:
-return sequences_seq_last
+return sequences


def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
