diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py
index 4d30a22c768d09..bd88b67bc6cb13 100644
--- a/src/transformers/models/whisper/generation_whisper.py
+++ b/src/transformers/models/whisper/generation_whisper.py
@@ -511,7 +511,6 @@ def generate(
         self._set_language_and_task(
             language=language, task=task, is_multilingual=is_multilingual, generation_config=generation_config
         )
-        self._set_token_ids(generation_config=generation_config, config=self.config, kwargs=kwargs)
         self._set_num_frames(
             return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs
         )
@@ -546,13 +545,13 @@ def generate(
             logits_processor=logits_processor,
             begin_index=begin_index,  # begin index is index of first generated decoder token
             is_shortform=is_shortform,
-            num_beams=kwargs.get("num_beams", 1),
+            num_beams=generation_config.num_beams,
         )

         # 5. If we're in shortform mode, simple generate the whole input at once and return the output
         if is_shortform:
             if temperature is not None:
-                kwargs["temperature"] = temperature
+                generation_config.temperature = temperature

             decoder_input_ids = kwargs.pop("decoder_input_ids", None)
             if decoder_input_ids is None:
@@ -564,8 +563,8 @@ def generate(
                     [prompt_ids[None].repeat(decoder_input_ids.shape[0], 1), decoder_input_ids], dim=-1
                 )

-            if kwargs.get("max_new_tokens", 0) + decoder_input_ids.shape[-1] > self.config.max_target_positions:
-                max_new_tokens = kwargs.get("max_new_tokens", 0)
+            max_new_tokens = generation_config.max_new_tokens if generation_config.max_new_tokens is not None else 0
+            if max_new_tokens + decoder_input_ids.shape[-1] > self.config.max_target_positions:
                 raise ValueError(
                     f"The length of `decoder_input_ids` equal `prompt_ids` plus special start tokens is {decoder_input_ids.shape[-1]}, and the `max_new_tokens` "
                     f"is {max_new_tokens}. Thus, the combined length of "
@@ -666,11 +665,10 @@ def generate(
             )

             # 6.6 set max new tokens or max length
-            kwargs = self._set_max_new_tokens_and_length(
+            self._set_max_new_tokens_and_length(
                 config=self.config,
                 decoder_input_ids=decoder_input_ids,
                 generation_config=generation_config,
-                kwargs=kwargs,
             )

             # 6.7 Set current `begin_index` for all logit processors
@@ -770,9 +768,9 @@ def generate_with_fallback(

         for fallback_idx, temperature in enumerate(temperatures):
             generation_config.do_sample = temperature is not None and temperature > 0.0
-
             generation_config.temperature = temperature if generation_config.do_sample else 1.0
-            generation_config.num_beams = kwargs.get("num_beams", 1) if not generation_config.do_sample else 1
+            if generation_config.do_sample:
+                generation_config.num_beams = 1

             generate_kwargs = copy.copy(kwargs)
             for key in ["do_sample", "temperature", "num_beams"]:
@@ -1095,11 +1093,8 @@ def replace_or_add(lst: List[int], num: int, itr: Iterator[int]):
         task = getattr(generation_config, "task", None)
         language = getattr(generation_config, "language", None)

-        if kwargs.get("forced_decoder_ids", None) is not None:
-            forced_decoder_ids = kwargs["forced_decoder_ids"]
-        elif hasattr(generation_config, "forced_decoder_ids") and generation_config.forced_decoder_ids is not None:
-            forced_decoder_ids = generation_config.forced_decoder_ids
-
+        forced_decoder_ids = generation_config.forced_decoder_ids
+        if forced_decoder_ids is not None:
             if language is None and task is None and forced_decoder_ids[0][1] is None:
                 logger.warning_once(
                     "Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English."
@@ -1107,8 +1102,6 @@ def replace_or_add(lst: List[int], num: int, itr: Iterator[int]):
                 )
         elif hasattr(config, "forced_decoder_ids") and config.forced_decoder_ids is not None:
             forced_decoder_ids = config.forced_decoder_ids
-        else:
-            forced_decoder_ids = None

         if forced_decoder_ids is not None and task is not None:
             logger.info(
@@ -1288,21 +1281,6 @@ def _check_decoder_input_ids(kwargs):
                 "Passing `decoder_input_ids` is deprecated. Consider passing `prompt_ids` instead.",
             )

-    @staticmethod
-    def _set_token_ids(generation_config, config, kwargs):
-        eos_token_id = kwargs.pop("eos_token_id", None)
-        decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
-
-        eos_token_id = eos_token_id if eos_token_id is not None else generation_config.eos_token_id
-        decoder_start_token_id = (
-            decoder_start_token_id if decoder_start_token_id is not None else generation_config.decoder_start_token_id
-        )
-
-        generation_config.eos_token_id = eos_token_id if eos_token_id is not None else config.eos_token_id
-        generation_config.decoder_start_token_id = (
-            decoder_start_token_id if decoder_start_token_id is not None else config.decoder_start_token_id
-        )
-
     @staticmethod
     def _set_num_frames(return_token_timestamps, generation_config, kwargs):
         if return_token_timestamps:
@@ -1313,7 +1291,6 @@ def _set_num_frames(return_token_timestamps, generation_config, kwargs):
                     "Model generation config has no `alignment_heads`, token-level timestamps not available. "
                     "See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config."
                 )
-        generation_config.num_frames = kwargs.pop("num_frames", None)

     @staticmethod
@@ -1517,47 +1494,21 @@ def _prepare_decoder_input_ids(

         return decoder_input_ids, kwargs

     @staticmethod
-    def _set_max_new_tokens_and_length(config, decoder_input_ids, generation_config, kwargs):
+    def _set_max_new_tokens_and_length(config, decoder_input_ids, generation_config):
         num_initial_tokens = min(config.max_target_positions // 2 - 1, decoder_input_ids.shape[-1] - 1)

-        passed_max_length = kwargs.pop("max_length", None)
-        passed_max_new_tokens = kwargs.pop("max_new_tokens", None)
-        max_length_config = getattr(generation_config, "max_length", None)
-        max_new_tokens_config = getattr(generation_config, "max_new_tokens", None)
-
-        max_new_tokens = None
-        max_length = None
-
         # Make sure we don't get larger than `max_length`
-        if passed_max_length is not None and passed_max_new_tokens is None:
-            max_length = min(passed_max_length + num_initial_tokens, config.max_target_positions)
-            logger.info(
-                f"Increase max_length from {passed_max_length} to {max_length} since input is conditioned on previous segment."
-            )
-        elif max_length_config is not None and passed_max_new_tokens is None and max_new_tokens_config is None:
+        if generation_config.max_length is not None and generation_config.max_new_tokens is None:
             max_length = min(generation_config.max_length + num_initial_tokens, config.max_target_positions)
             logger.info(
-                f"Increase max_length from {max_length_config} to {max_length} since input is conditioned on previous segment."
+                f"Increase max_length from {generation_config.max_length} to {max_length} since input is conditioned on previous segment."
             )
         elif (
-            passed_max_new_tokens is not None
-            and passed_max_new_tokens + decoder_input_ids.shape[-1] > config.max_target_positions
+            generation_config.max_new_tokens is not None
+            and generation_config.max_new_tokens + decoder_input_ids.shape[-1] > config.max_target_positions
         ):
             max_new_tokens = config.max_target_positions - decoder_input_ids.shape[-1]
-        elif (
-            passed_max_new_tokens is None
-            and max_new_tokens_config is not None
-            and max_new_tokens_config + decoder_input_ids.shape[-1] > config.max_target_positions
-        ):
-            max_new_tokens = config.max_target_positions - decoder_input_ids.shape[-1]
-
-        if max_new_tokens is not None:
-            kwargs["max_new_tokens"] = max_new_tokens
-
-        if max_length is not None:
-            kwargs["max_length"] = max_length
-
-        return kwargs
+            generation_config.max_new_tokens = max_new_tokens

     @staticmethod
     def _retrieve_compression_ratio(tokens, vocab_size):
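
Note (not part of the diff): a minimal usage sketch of what this refactor implies for callers. Options such as `num_beams`, `temperature`, and `max_new_tokens` are now read from `generation_config` rather than from loose `**kwargs`, so they can be set on a `GenerationConfig` passed to `generate`. The checkpoint name and the silent dummy audio below are illustrative assumptions, not taken from the PR.

```python
import numpy as np
from transformers import GenerationConfig, WhisperForConditionalGeneration, WhisperProcessor

# Illustrative checkpoint; any Whisper checkpoint works the same way.
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# One second of silent 16 kHz audio stands in for a real recording.
audio = np.zeros(16_000, dtype=np.float32)
inputs = processor(audio, sampling_rate=16_000, return_tensors="pt")

# Generation options live on the GenerationConfig instead of being passed as **kwargs.
generation_config = GenerationConfig.from_pretrained("openai/whisper-tiny")
generation_config.max_new_tokens = 64
generation_config.num_beams = 2
generation_config.temperature = 1.0

predicted_ids = model.generate(inputs.input_features, generation_config=generation_config)
print(processor.batch_decode(predicted_ids, skip_special_tokens=True))
```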