diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py index fa2902cfc25126..2ceed1b46d4899 100644 --- a/src/transformers/tokenization_utils.py +++ b/src/transformers/tokenization_utils.py @@ -367,6 +367,25 @@ def __init__(self, **kwargs): self._decode_use_source_tokenizer = False + @property + def is_fast(self) -> bool: + return False + + @property + def vocab_size(self) -> int: + """ + `int`: Size of the base vocabulary (without the added tokens). + """ + raise NotImplementedError + + @property + def added_tokens_encoder(self) -> Dict[str, int]: + """ + Returns the sorted mapping from string to index. The added tokens encoder is cached for performance + optimisation in `self._added_tokens_encoder` for the slow tokenizers. + """ + return {k.content: v for v, k in sorted(self._added_tokens_decoder.items(), key=lambda item: item[0])} + @property def added_tokens_decoder(self) -> Dict[int, AddedToken]: """ @@ -389,17 +408,6 @@ def added_tokens_decoder(self, value: Dict[int, Union[AddedToken, str]]) -> Dict self._added_tokens_decoder[index] = AddedToken(token) if isinstance(token, str) else token self._added_tokens_encoder[str(token)] = index - @property - def is_fast(self) -> bool: - return False - - @property - def vocab_size(self) -> int: - """ - `int`: Size of the base vocabulary (without the added tokens). - """ - raise NotImplementedError - def get_added_vocab(self) -> Dict[str, int]: """ Returns the added tokens in the vocabulary as a dictionary of token to index. Results might be different from diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 9bc20aaef80423..1f2cf6e436f3da 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -846,15 +846,26 @@ def __init__(self, verbose=True, **kwargs): # We directly set the hidden value to allow initialization with special tokens # which are not yet in the vocabulary. Necessary for serialization/de-serialization # TODO clean this up at some point (probably by switching to fast tokenizers) + for key, value in kwargs.items(): if value is None: continue if key in self.SPECIAL_TOKENS_ATTRIBUTES: if key == "additional_special_tokens": + # TODO THIS IS NASTY! Will always reset tokens to default rstrip and lstrip because self.set_attr on strings + # will not check the addedtokens decoder. WILL FIX TOMORROW assert isinstance(value, (list, tuple)), f"Value {value} is not a list or tuple" assert all( isinstance(t, (str, AddedToken)) for t in value ), "One of the tokens is not a string or an AddedToken" + if hasattr(self, "added_tokens_encoder"): + extended_token = [] + for token in value: + if isinstance(token, str) and str(token) in self.added_tokens_encoder: + extended_token.append(self.added_tokens_decoder[self.added_tokens_encoder[str(token)]]) + else: + extended_token.append(token) + value = extended_token setattr(self, key, value) elif isinstance(value, (str)): value = AddedToken(value, normalized=False, special=True) @@ -1674,14 +1685,6 @@ def _set_processor_class(self, processor_class: str): """Sets processor class as an attribute.""" self._processor_class = processor_class - @property - def added_tokens_encoder(self) -> Dict[str, int]: - """ - Returns the sorted mapping from string to index. The added tokens encoder is cached for performance - optimisation in `self._added_tokens_encoder` for the slow tokenizers. - """ - return {k.content: v for v, k in sorted(self._added_tokens_decoder.items(), key=lambda item: item[0])} - @property def added_tokens_decoder(self) -> Dict[int, AddedToken]: raise NotImplementedError() @@ -2196,9 +2199,13 @@ def _from_pretrained( for idx, token in init_kwargs["added_tokens_decoder"].items(): if isinstance(token, dict): token = AddedToken(**token) - if isinstance(token, AddedToken): added_tokens_decoder[int(idx)] = token + if str(token) in additional_special_tokens: + # at this point the token is in `additional_special_tokens` as an str, let's add the AddedToken info + additional_special_tokens.remove(str(token)) + if token.special and token not in additional_special_tokens: + additional_special_tokens.append(token) else: raise ValueError( f"Found a {token.__class__} in the saved `added_tokens_decoder`, should be a dictionary." @@ -2381,9 +2388,8 @@ def save_pretrained( tokenizer_config = copy.deepcopy(self.init_kwargs) - # TODO: Ensure the modified attributes (those are also in the __init__ kwargs) will give identical tokenizers - # target_keys = self.init_kwargs.keys() - target_keys = ["model_max_length", "clean_up_tokenization_spaces", "additional_special_tokens"] + target_keys = list(self.init_kwargs.keys()) + target_keys += ["model_max_length", "clean_up_tokenization_spaces", "additional_special_tokens"] for k in target_keys: if hasattr(self, k): tokenizer_config[k] = getattr(self, k) diff --git a/src/transformers/tokenization_utils_fast.py b/src/transformers/tokenization_utils_fast.py index 45a6639e1caab8..2c6b3c167fecd4 100644 --- a/src/transformers/tokenization_utils_fast.py +++ b/src/transformers/tokenization_utils_fast.py @@ -185,6 +185,14 @@ def get_vocab(self) -> Dict[str, int]: def vocab(self) -> Dict[str, int]: return self.get_vocab() + @property + def added_tokens_encoder(self) -> Dict[str, int]: + """ + Returns the sorted mapping from string to index. The added tokens encoder is cached for performance + optimisation in `self._added_tokens_encoder` for the slow tokenizers. + """ + return {k.content: v for v, k in sorted(self.added_tokens_decoder.items(), key=lambda item: item[0])} + @property def added_tokens_decoder(self) -> Dict[int, AddedToken]: """ @@ -202,10 +210,7 @@ def get_added_vocab(self) -> Dict[str, int]: Returns: `Dict[str, int]`: The added tokens. """ - base_vocab = self._tokenizer.get_vocab(with_added_tokens=False) - full_vocab = self._tokenizer.get_vocab(with_added_tokens=True) - added_vocab = {tok: index for tok, index in full_vocab.items() if tok not in base_vocab} - return added_vocab + return {k.content: v for v, k in sorted(self.added_tokens_decoder.items(), key=lambda item: item[0])} def __len__(self) -> int: """