From c1f70215d3c5bb2de796f2fbf6a8cd1a016a4b56 Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Tue, 4 Jun 2024 18:44:46 +0500 Subject: [PATCH 1/5] Fix conflicting key in init kwargs in PreTrainedTokenizerBase --- src/transformers/tokenization_utils_base.py | 3 +++ tests/test_tokenization_common.py | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 116fbfdf7bbbf0..3cbb9fa44419a9 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1567,6 +1567,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): def __init__(self, **kwargs): # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``) self.init_inputs = () + for key in kwargs.keys(): + if hasattr(self, key) and callable(getattr(self, key)): + raise AttributeError(f"{key} conflicts with the method {key} in {self.__class__.__name__}") self.init_kwargs = copy.deepcopy(kwargs) self.name_or_path = kwargs.pop("name_or_path", "") self._processor_class = kwargs.pop("processor_class", None) diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 8b0ad38795f26c..4865c1991b16ef 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -4319,3 +4319,7 @@ def test_special_token_addition(self): replace_additional_special_tokens=False, ) self.assertEqual(tokenizer_2.additional_special_tokens, ["", "", ""]) + + def test_tokenizer_initialization_with_conflicting_key(self): + with self.assertRaises(AttributeError, msg="conflicts with the method"): + self.get_tokenizer(add_special_tokens=True) From cff6bfc7b33f42042f3b9ac1637a8c5dd47342af Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Wed, 5 Jun 2024 13:00:25 +0500 Subject: [PATCH 2/5] Update code to check for callable key in save_pretrained --- src/transformers/tokenization_utils_base.py | 5 +---- tests/test_tokenization_common.py | 13 ++++++++++--- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 3cbb9fa44419a9..947adaa08b56d7 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1567,9 +1567,6 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): def __init__(self, **kwargs): # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``) self.init_inputs = () - for key in kwargs.keys(): - if hasattr(self, key) and callable(getattr(self, key)): - raise AttributeError(f"{key} conflicts with the method {key} in {self.__class__.__name__}") self.init_kwargs = copy.deepcopy(kwargs) self.name_or_path = kwargs.pop("name_or_path", "") self._processor_class = kwargs.pop("processor_class", None) @@ -2467,7 +2464,7 @@ def save_pretrained( target_keys.update(["model_max_length", "clean_up_tokenization_spaces"]) for k in target_keys: - if hasattr(self, k): + if hasattr(self, k) and not callable(getattr(self, k)): tokenizer_config[k] = getattr(self, k) # Let's make sure we properly save the special tokens. diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 4865c1991b16ef..fa7cbf13d85ee9 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -4320,6 +4320,13 @@ def test_special_token_addition(self): ) self.assertEqual(tokenizer_2.additional_special_tokens, ["", "", ""]) - def test_tokenizer_initialization_with_conflicting_key(self): - with self.assertRaises(AttributeError, msg="conflicts with the method"): - self.get_tokenizer(add_special_tokens=True) + def test_tokenizer_save_pretrained_with_conflicting_key(self): + if self.test_rust_tokenizer: + tokenizer = self.get_rust_tokenizer(add_special_tokens=True) + else: + tokenizer = self.get_tokenizer(add_special_tokens=True) + + with tempfile.TemporaryDirectory() as tmp_dir_1: + tokenizer.save_pretrained(tmp_dir_1) + loaded_tokenizer = tokenizer.from_pretrained(tmp_dir_1) + assert loaded_tokenizer.init_kwargs.get("add_special_tokens") From 171a02365a440d1da919b75d31f6905e0410460b Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Tue, 18 Jun 2024 18:28:30 +0500 Subject: [PATCH 3/5] Apply PR suggestions --- src/transformers/tokenization_utils_base.py | 9 ++++++++- tests/test_tokenization_common.py | 7 +++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index ef14c958916ea1..b1bae45f89b065 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1569,6 +1569,13 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): def __init__(self, **kwargs): # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``) self.init_inputs = () + if "add_special_tokens" in kwargs: + kwargs["_add_special_tokens"] = kwargs.pop("add_special_tokens") + + for key in kwargs: + if hasattr(self, key) and callable(getattr(self, key)): + raise AttributeError(f"{key} conflicts with the method {key} in {self.__class__.__name__}") + self.init_kwargs = copy.deepcopy(kwargs) self.name_or_path = kwargs.pop("name_or_path", "") self._processor_class = kwargs.pop("processor_class", None) @@ -2518,7 +2525,7 @@ def save_pretrained( target_keys.update(["model_max_length", "clean_up_tokenization_spaces"]) for k in target_keys: - if hasattr(self, k) and not callable(getattr(self, k)): + if hasattr(self, k): tokenizer_config[k] = getattr(self, k) # Let's make sure we properly save the special tokens. diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index fa7cbf13d85ee9..ee874b65716ff2 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -4320,7 +4320,7 @@ def test_special_token_addition(self): ) self.assertEqual(tokenizer_2.additional_special_tokens, ["", "", ""]) - def test_tokenizer_save_pretrained_with_conflicting_key(self): + def test_tokenizer_initialization_with_conflicting_key(self): if self.test_rust_tokenizer: tokenizer = self.get_rust_tokenizer(add_special_tokens=True) else: @@ -4329,4 +4329,7 @@ def test_tokenizer_save_pretrained_with_conflicting_key(self): with tempfile.TemporaryDirectory() as tmp_dir_1: tokenizer.save_pretrained(tmp_dir_1) loaded_tokenizer = tokenizer.from_pretrained(tmp_dir_1) - assert loaded_tokenizer.init_kwargs.get("add_special_tokens") + assert loaded_tokenizer.init_kwargs.get("_add_special_tokens") + + with self.assertRaises(AttributeError, msg="conflicts with the method"): + self.get_tokenizer(get_vocab=True) From 43558f6d3ad45a9b6281fc2930e354005ffbdb55 Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Tue, 23 Jul 2024 14:03:15 +0500 Subject: [PATCH 4/5] Invoke CI From 3f05429ea347dcba0ae4b296eb82ac7917cf5af8 Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Tue, 30 Jul 2024 18:40:57 +0500 Subject: [PATCH 5/5] Updates based on PR suggestion --- src/transformers/tokenization_utils_base.py | 3 --- tests/test_tokenization_common.py | 14 ++++---------- 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index aec8c817359f63..0e5cd26c8fc07b 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1569,9 +1569,6 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): def __init__(self, **kwargs): # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``) self.init_inputs = () - if "add_special_tokens" in kwargs: - kwargs["_add_special_tokens"] = kwargs.pop("add_special_tokens") - for key in kwargs: if hasattr(self, key) and callable(getattr(self, key)): raise AttributeError(f"{key} conflicts with the method {key} in {self.__class__.__name__}") diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index b869fcb04464a3..56a4922d386730 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -4456,15 +4456,9 @@ def test_special_token_addition(self): self.assertEqual(tokenizer_2.additional_special_tokens, ["", "", ""]) def test_tokenizer_initialization_with_conflicting_key(self): - if self.test_rust_tokenizer: - tokenizer = self.get_rust_tokenizer(add_special_tokens=True) - else: - tokenizer = self.get_tokenizer(add_special_tokens=True) - - with tempfile.TemporaryDirectory() as tmp_dir_1: - tokenizer.save_pretrained(tmp_dir_1) - loaded_tokenizer = tokenizer.from_pretrained(tmp_dir_1) - assert loaded_tokenizer.init_kwargs.get("_add_special_tokens") + get_tokenizer_func = self.get_rust_tokenizer if self.test_rust_tokenizer else self.get_tokenizer + with self.assertRaises(AttributeError, msg="conflicts with the method"): + get_tokenizer_func(add_special_tokens=True) with self.assertRaises(AttributeError, msg="conflicts with the method"): - self.get_tokenizer(get_vocab=True) + get_tokenizer_func(get_vocab=True)