From aaf1e559f10a8e27829029dbe63b997820e244a6 Mon Sep 17 00:00:00 2001 From: "Ng, Yen Ting" Date: Sat, 20 Apr 2024 07:07:36 -0700 Subject: [PATCH 1/2] Fixed TextGenerationPipeline._sanitize_parameters default params --- src/transformers/pipelines/text_generation.py | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index 0b358291717ee0..591897bf50e53b 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -117,19 +117,25 @@ def _sanitize_parameters( prefix=None, handle_long_generation=None, stop_sequence=None, - add_special_tokens=False, truncation=None, - padding=False, max_length=None, **generate_kwargs, ): - preprocess_params = { - "add_special_tokens": add_special_tokens, - "truncation": truncation, - "padding": padding, - "max_length": max_length, - } + preprocess_params = {} + + add_special_tokens = False + if "add_special_tokens" in generate_kwargs: + preprocess_params["add_special_tokens"] = generate_kwargs["add_special_tokens"] + add_special_tokens = generate_kwargs["add_special_tokens"] + + if "padding" in generate_kwargs: + preprocess_params["padding"] = generate_kwargs["padding"] + + if truncation is not None: + preprocess_params["truncation"] = truncation + if max_length is not None: + preprocess_params["max_length"] = max_length generate_kwargs["max_length"] = max_length if prefix is not None: From 45600291fc847f4edf8927c9bbc790f9056b5b00 Mon Sep 17 00:00:00 2001 From: yting27 Date: Tue, 23 Apr 2024 22:53:40 +0800 Subject: [PATCH 2/2] removed empty spaces --- src/transformers/pipelines/text_generation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index da9ee1df6d28c0..4cd61450fa7cda 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -133,7 +133,7 @@ def _sanitize_parameters( **generate_kwargs, ): preprocess_params = {} - + add_special_tokens = False if "add_special_tokens" in generate_kwargs: preprocess_params["add_special_tokens"] = generate_kwargs["add_special_tokens"] @@ -141,7 +141,7 @@ def _sanitize_parameters( if "padding" in generate_kwargs: preprocess_params["padding"] = generate_kwargs["padding"] - + if truncation is not None: preprocess_params["truncation"] = truncation