Nit-added-tokens (huggingface#26538)
* fix stripping

* nits

* fix another test

* styling

* fix?

* update

* revert bad merge

* found the bug

* YES SIR

* is that change really required?

* make fast even faster

* re order functions
ArthurZucker authored and blbadger committed Nov 8, 2023
1 parent cd09bc1 commit 62cfead
Showing 3 changed files with 46 additions and 27 deletions.
30 changes: 19 additions & 11 deletions src/transformers/tokenization_utils.py
@@ -367,6 +367,25 @@ def __init__(self, **kwargs):

self._decode_use_source_tokenizer = False

@property
def is_fast(self) -> bool:
return False

@property
def vocab_size(self) -> int:
"""
`int`: Size of the base vocabulary (without the added tokens).
"""
raise NotImplementedError

@property
def added_tokens_encoder(self) -> Dict[str, int]:
"""
Returns the sorted mapping from string to index. The added tokens encoder is cached for performance
optimisation in `self._added_tokens_encoder` for the slow tokenizers.
"""
return {k.content: v for v, k in sorted(self._added_tokens_decoder.items(), key=lambda item: item[0])}

@property
def added_tokens_decoder(self) -> Dict[int, AddedToken]:
"""
@@ -389,17 +408,6 @@ def added_tokens_decoder(self, value: Dict[int, Union[AddedToken, str]]) -> Dict
self._added_tokens_decoder[index] = AddedToken(token) if isinstance(token, str) else token
self._added_tokens_encoder[str(token)] = index

@property
def is_fast(self) -> bool:
return False

@property
def vocab_size(self) -> int:
"""
`int`: Size of the base vocabulary (without the added tokens).
"""
raise NotImplementedError

def get_added_vocab(self) -> Dict[str, int]:
"""
Returns the added tokens in the vocabulary as a dictionary of token to index. Results might be different from
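For reference, the relocated `added_tokens_encoder` property is simply the inverse of `_added_tokens_decoder`, iterated in index order. A minimal standalone sketch of that relationship (the decoder contents below are invented for illustration):

from transformers import AddedToken

# Invented decoder state: index -> AddedToken, as a slow tokenizer stores it.
_added_tokens_decoder = {
    32001: AddedToken("<extra_1>"),
    32000: AddedToken("<extra_0>"),
}

# The property inverts it into content -> index, walking the items in index order.
added_tokens_encoder = {
    k.content: v for v, k in sorted(_added_tokens_decoder.items(), key=lambda item: item[0])
}
print(added_tokens_encoder)  # {'<extra_0>': 32000, '<extra_1>': 32001}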
30 changes: 18 additions & 12 deletions src/transformers/tokenization_utils_base.py
@@ -846,15 +846,26 @@ def __init__(self, verbose=True, **kwargs):
# We directly set the hidden value to allow initialization with special tokens
# which are not yet in the vocabulary. Necessary for serialization/de-serialization
# TODO clean this up at some point (probably by switching to fast tokenizers)

for key, value in kwargs.items():
if value is None:
continue
if key in self.SPECIAL_TOKENS_ATTRIBUTES:
if key == "additional_special_tokens":
# TODO THIS IS NASTY! Will always reset tokens to default rstrip and lstrip because self.set_attr on strings
# will not check the addedtokens decoder. WILL FIX TOMORROW
assert isinstance(value, (list, tuple)), f"Value {value} is not a list or tuple"
assert all(
isinstance(t, (str, AddedToken)) for t in value
), "One of the tokens is not a string or an AddedToken"
if hasattr(self, "added_tokens_encoder"):
extended_token = []
for token in value:
if isinstance(token, str) and str(token) in self.added_tokens_encoder:
extended_token.append(self.added_tokens_decoder[self.added_tokens_encoder[str(token)]])
else:
extended_token.append(token)
value = extended_token
setattr(self, key, value)
elif isinstance(value, (str)):
value = AddedToken(value, normalized=False, special=True)
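The loop above is the core of the stripping fix: a plain string passed in `additional_special_tokens` that is already registered gets swapped for the stored `AddedToken`, so its `lstrip`/`rstrip` flags survive instead of being reset to defaults. A standalone sketch of that lookup, with invented mappings standing in for the tokenizer state:

from transformers import AddedToken

# Invented registered state: the token was originally added with rstrip=True.
added_tokens_decoder = {32000: AddedToken("<extra_0>", rstrip=True)}
added_tokens_encoder = {"<extra_0>": 32000}

value = ["<extra_0>", AddedToken("<new_tok>", lstrip=True)]  # mixed str / AddedToken input

extended_token = []
for token in value:
    if isinstance(token, str) and str(token) in added_tokens_encoder:
        # Reuse the stored AddedToken so the stripping metadata is preserved.
        extended_token.append(added_tokens_decoder[added_tokens_encoder[str(token)]])
    else:
        extended_token.append(token)

print(extended_token[0].rstrip)  # True, not the default False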
@@ -1674,14 +1685,6 @@ def _set_processor_class(self, processor_class: str):
"""Sets processor class as an attribute."""
self._processor_class = processor_class

@property
def added_tokens_encoder(self) -> Dict[str, int]:
"""
Returns the sorted mapping from string to index. The added tokens encoder is cached for performance
optimisation in `self._added_tokens_encoder` for the slow tokenizers.
"""
return {k.content: v for v, k in sorted(self._added_tokens_decoder.items(), key=lambda item: item[0])}

@property
def added_tokens_decoder(self) -> Dict[int, AddedToken]:
raise NotImplementedError()
@@ -2196,9 +2199,13 @@ def _from_pretrained(
for idx, token in init_kwargs["added_tokens_decoder"].items():
if isinstance(token, dict):
token = AddedToken(**token)

if isinstance(token, AddedToken):
added_tokens_decoder[int(idx)] = token
if str(token) in additional_special_tokens:
# at this point the token is in `additional_special_tokens` as an str, let's add the AddedToken info
additional_special_tokens.remove(str(token))
if token.special and token not in additional_special_tokens:
additional_special_tokens.append(token)
else:
raise ValueError(
f"Found a {token.__class__} in the saved `added_tokens_decoder`, should be a dictionary."
@@ -2381,9 +2388,8 @@ def save_pretrained(

tokenizer_config = copy.deepcopy(self.init_kwargs)

# TODO: Ensure the modified attributes (those are also in the __init__ kwargs) will give identical tokenizers
# target_keys = self.init_kwargs.keys()
target_keys = ["model_max_length", "clean_up_tokenization_spaces", "additional_special_tokens"]
target_keys = list(self.init_kwargs.keys())
target_keys += ["model_max_length", "clean_up_tokenization_spaces", "additional_special_tokens"]
for k in target_keys:
if hasattr(self, k):
tokenizer_config[k] = getattr(self, k)
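With the widened `target_keys`, `save_pretrained` now writes every `__init__` kwarg that is also a live attribute into the tokenizer config, instead of only the three previously hard-coded keys. A rough sketch of the resulting config-building loop, using an invented stand-in class:

import copy

class DummyTokenizer:  # invented stand-in, only to show the loop
    def __init__(self, model_max_length=512, padding_side="right"):
        self.init_kwargs = {"model_max_length": model_max_length, "padding_side": padding_side}
        self.model_max_length = 1024  # attribute changed after __init__
        self.padding_side = padding_side
        self.additional_special_tokens = []

tok = DummyTokenizer()
tokenizer_config = copy.deepcopy(tok.init_kwargs)

target_keys = list(tok.init_kwargs.keys())
target_keys += ["model_max_length", "clean_up_tokenization_spaces", "additional_special_tokens"]
for k in target_keys:
    if hasattr(tok, k):
        tokenizer_config[k] = getattr(tok, k)  # the live attribute wins over the stored init kwarg

print(tokenizer_config["model_max_length"])  # 1024, i.e. the current value, not the init value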
13 changes: 9 additions & 4 deletions src/transformers/tokenization_utils_fast.py
@@ -185,6 +185,14 @@ def get_vocab(self) -> Dict[str, int]:
def vocab(self) -> Dict[str, int]:
return self.get_vocab()

@property
def added_tokens_encoder(self) -> Dict[str, int]:
"""
Returns the sorted mapping from string to index. The added tokens encoder is cached for performance
optimisation in `self._added_tokens_encoder` for the slow tokenizers.
"""
return {k.content: v for v, k in sorted(self.added_tokens_decoder.items(), key=lambda item: item[0])}

@property
def added_tokens_decoder(self) -> Dict[int, AddedToken]:
"""
@@ -202,10 +210,7 @@ def get_added_vocab(self) -> Dict[str, int]:
Returns:
`Dict[str, int]`: The added tokens.
"""
base_vocab = self._tokenizer.get_vocab(with_added_tokens=False)
full_vocab = self._tokenizer.get_vocab(with_added_tokens=True)
added_vocab = {tok: index for tok, index in full_vocab.items() if tok not in base_vocab}
return added_vocab
return {k.content: v for v, k in sorted(self.added_tokens_decoder.items(), key=lambda item: item[0])}

def __len__(self) -> int:
"""
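This `get_added_vocab` rewrite is the "make fast even faster" part: instead of materialising the full vocabulary twice and diffing the two dicts, the mapping is read straight from `added_tokens_decoder`. A sketch contrasting the two approaches on a fast tokenizer (the checkpoint name is only an example):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")  # any fast tokenizer
tok.add_tokens(["<extra_0>", "<extra_1>"])

# Old approach: build the vocabulary with and without added tokens, then diff the dicts.
base_vocab = tok._tokenizer.get_vocab(with_added_tokens=False)
full_vocab = tok._tokenizer.get_vocab(with_added_tokens=True)
old_style = {t: i for t, i in full_vocab.items() if t not in base_vocab}

# New approach: read the (much smaller) added-tokens mapping directly.
new_style = {k.content: v for v, k in sorted(tok.added_tokens_decoder.items(), key=lambda item: item[0])}

print(old_style)  # only the freshly added tokens
print(new_style)  # may also list special tokens that already sit in the base vocabulary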
