Skip to content

Commit

Permalink
Fix whisper kwargs and generation config (huggingface#30018)
Browse files Browse the repository at this point in the history
* clean-up whisper kwargs

* failing test
  • Loading branch information
zucchini-nlp authored Apr 5, 2024
1 parent 9b5a645 commit 76fa17c
Showing 1 changed file with 15 additions and 64 deletions.
79 changes: 15 additions & 64 deletions src/transformers/models/whisper/generation_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,6 @@ def generate(
self._set_language_and_task(
language=language, task=task, is_multilingual=is_multilingual, generation_config=generation_config
)
self._set_token_ids(generation_config=generation_config, config=self.config, kwargs=kwargs)
self._set_num_frames(
return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs
)
Expand Down Expand Up @@ -546,13 +545,13 @@ def generate(
logits_processor=logits_processor,
begin_index=begin_index, # begin index is index of first generated decoder token
is_shortform=is_shortform,
num_beams=kwargs.get("num_beams", 1),
num_beams=generation_config.num_beams,
)

# 5. If we're in shortform mode, simple generate the whole input at once and return the output
if is_shortform:
if temperature is not None:
kwargs["temperature"] = temperature
generation_config.temperature = temperature

decoder_input_ids = kwargs.pop("decoder_input_ids", None)
if decoder_input_ids is None:
Expand All @@ -564,8 +563,8 @@ def generate(
[prompt_ids[None].repeat(decoder_input_ids.shape[0], 1), decoder_input_ids], dim=-1
)

if kwargs.get("max_new_tokens", 0) + decoder_input_ids.shape[-1] > self.config.max_target_positions:
max_new_tokens = kwargs.get("max_new_tokens", 0)
max_new_tokens = generation_config.max_new_tokens if generation_config.max_new_tokens is not None else 0
if max_new_tokens + decoder_input_ids.shape[-1] > self.config.max_target_positions:
raise ValueError(
f"The length of `decoder_input_ids` equal `prompt_ids` plus special start tokens is {decoder_input_ids.shape[-1]}, and the `max_new_tokens` "
f"is {max_new_tokens}. Thus, the combined length of "
Expand Down Expand Up @@ -666,11 +665,10 @@ def generate(
)

# 6.6 set max new tokens or max length
kwargs = self._set_max_new_tokens_and_length(
self._set_max_new_tokens_and_length(
config=self.config,
decoder_input_ids=decoder_input_ids,
generation_config=generation_config,
kwargs=kwargs,
)

# 6.7 Set current `begin_index` for all logit processors
Expand Down Expand Up @@ -770,9 +768,9 @@ def generate_with_fallback(

for fallback_idx, temperature in enumerate(temperatures):
generation_config.do_sample = temperature is not None and temperature > 0.0

generation_config.temperature = temperature if generation_config.do_sample else 1.0
generation_config.num_beams = kwargs.get("num_beams", 1) if not generation_config.do_sample else 1
if generation_config.do_sample:
generation_config.num_beams = 1

generate_kwargs = copy.copy(kwargs)
for key in ["do_sample", "temperature", "num_beams"]:
Expand Down Expand Up @@ -1095,20 +1093,15 @@ def replace_or_add(lst: List[int], num: int, itr: Iterator[int]):
task = getattr(generation_config, "task", None)
language = getattr(generation_config, "language", None)

if kwargs.get("forced_decoder_ids", None) is not None:
forced_decoder_ids = kwargs["forced_decoder_ids"]
elif hasattr(generation_config, "forced_decoder_ids") and generation_config.forced_decoder_ids is not None:
forced_decoder_ids = generation_config.forced_decoder_ids

forced_decoder_ids = generation_config.forced_decoder_ids
if forced_decoder_ids is not None:
if language is None and task is None and forced_decoder_ids[0][1] is None:
logger.warning_once(
"Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English."
"This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`."
)
elif hasattr(config, "forced_decoder_ids") and config.forced_decoder_ids is not None:
forced_decoder_ids = config.forced_decoder_ids
else:
forced_decoder_ids = None

if forced_decoder_ids is not None and task is not None:
logger.info(
Expand Down Expand Up @@ -1288,21 +1281,6 @@ def _check_decoder_input_ids(kwargs):
"Passing `decoder_input_ids` is deprecated. Consider passing `prompt_ids` instead.",
)

@staticmethod
def _set_token_ids(generation_config, config, kwargs):
eos_token_id = kwargs.pop("eos_token_id", None)
decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)

eos_token_id = eos_token_id if eos_token_id is not None else generation_config.eos_token_id
decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id is not None else generation_config.decoder_start_token_id
)

generation_config.eos_token_id = eos_token_id if eos_token_id is not None else config.eos_token_id
generation_config.decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id is not None else config.decoder_start_token_id
)

@staticmethod
def _set_num_frames(return_token_timestamps, generation_config, kwargs):
if return_token_timestamps:
Expand All @@ -1313,7 +1291,6 @@ def _set_num_frames(return_token_timestamps, generation_config, kwargs):
"Model generation config has no `alignment_heads`, token-level timestamps not available. "
"See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config."
)

generation_config.num_frames = kwargs.pop("num_frames", None)

@staticmethod
Expand Down Expand Up @@ -1517,47 +1494,21 @@ def _prepare_decoder_input_ids(
return decoder_input_ids, kwargs

@staticmethod
def _set_max_new_tokens_and_length(config, decoder_input_ids, generation_config, kwargs):
def _set_max_new_tokens_and_length(config, decoder_input_ids, generation_config):
num_initial_tokens = min(config.max_target_positions // 2 - 1, decoder_input_ids.shape[-1] - 1)

passed_max_length = kwargs.pop("max_length", None)
passed_max_new_tokens = kwargs.pop("max_new_tokens", None)
max_length_config = getattr(generation_config, "max_length", None)
max_new_tokens_config = getattr(generation_config, "max_new_tokens", None)

max_new_tokens = None
max_length = None

# Make sure we don't get larger than `max_length`
if passed_max_length is not None and passed_max_new_tokens is None:
max_length = min(passed_max_length + num_initial_tokens, config.max_target_positions)
logger.info(
f"Increase max_length from {passed_max_length} to {max_length} since input is conditioned on previous segment."
)
elif max_length_config is not None and passed_max_new_tokens is None and max_new_tokens_config is None:
if generation_config.max_length is not None and generation_config.max_new_tokens is None:
max_length = min(generation_config.max_length + num_initial_tokens, config.max_target_positions)
logger.info(
f"Increase max_length from {max_length_config} to {max_length} since input is conditioned on previous segment."
f"Increase max_length from {generation_config.max_length} to {max_length} since input is conditioned on previous segment."
)
elif (
passed_max_new_tokens is not None
and passed_max_new_tokens + decoder_input_ids.shape[-1] > config.max_target_positions
generation_config.max_new_tokens is not None
and generation_config.max_new_tokens + decoder_input_ids.shape[-1] > config.max_target_positions
):
max_new_tokens = config.max_target_positions - decoder_input_ids.shape[-1]
elif (
passed_max_new_tokens is None
and max_new_tokens_config is not None
and max_new_tokens_config + decoder_input_ids.shape[-1] > config.max_target_positions
):
max_new_tokens = config.max_target_positions - decoder_input_ids.shape[-1]

if max_new_tokens is not None:
kwargs["max_new_tokens"] = max_new_tokens

if max_length is not None:
kwargs["max_length"] = max_length

return kwargs
generation_config.max_new_tokens = max_new_tokens

@staticmethod
def _retrieve_compression_ratio(tokens, vocab_size):
Expand Down

0 comments on commit 76fa17c

Please sign in to comment.