From eb811449a2389e48930c45f84c88fd041735cf92 Mon Sep 17 00:00:00 2001
From: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Date: Mon, 4 Nov 2024 21:35:37 +0100
Subject: [PATCH] Fix Whisper CI (#34541)

update

Co-authored-by: ydshieh
---
 src/transformers/generation/utils.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 6e6d5b8bdce71d..53cd2df3a49c84 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1452,10 +1452,11 @@ def _prepare_generated_length(
         ):
             generation_config.max_length -= inputs_tensor.shape[1]
         elif has_default_max_length:  # by default let's always generate 20 new tokens
-            generation_config.max_length = generation_config.max_length + input_ids_length
-            max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
-            if max_position_embeddings is not None:
-                generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
+            if generation_config.max_length == GenerationConfig().max_length:
+                generation_config.max_length = generation_config.max_length + input_ids_length
+                max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
+                if max_position_embeddings is not None:
+                    generation_config.max_length = min(generation_config.max_length, max_position_embeddings)

         # same for min length
         if generation_config.min_new_tokens is not None:
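
A minimal, self-contained sketch (not the transformers source; the helper name prepare_max_length and the standalone constant are illustrative) of the behaviour the guarded branch introduces: the library default max_length from GenerationConfig() is only extended by the prompt length when the user has not overridden it, so an explicit value such as the one in Whisper's generation config is left untouched.

# Hypothetical sketch of the patched branch in _prepare_generated_length,
# assuming the library default max_length is 20 (GenerationConfig().max_length).

LIBRARY_DEFAULT_MAX_LENGTH = 20


def prepare_max_length(configured_max_length, input_ids_length, max_position_embeddings=None):
    """Return the effective max_length, mirroring the guarded branch in the diff above."""
    max_length = configured_max_length
    if max_length == LIBRARY_DEFAULT_MAX_LENGTH:
        # Default case: allow ~20 new tokens on top of the prompt, capped by the model's
        # maximum position embeddings when that value is available.
        max_length = max_length + input_ids_length
        if max_position_embeddings is not None:
            max_length = min(max_length, max_position_embeddings)
    return max_length


assert prepare_max_length(20, input_ids_length=8) == 28    # default gets extended by the prompt length
assert prepare_max_length(448, input_ids_length=8) == 448  # an explicit max_length is preserved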