Skip to content

Commit

Permalink
Bug fix: pop special tokens from a copy of the config (not cfg.special_tokens itself) so reloading the tokenizer does not lose entries
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Jan 31, 2024
1 parent a695dff commit dea2d31
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,17 +163,18 @@ def load_tokenizer(cfg):

additional_special_tokens = None
if cfg.special_tokens:
additional_special_tokens = cfg.special_tokens.pop(
special_tokens = cfg.special_tokens.to_dict()
additional_special_tokens = special_tokens.pop(
"additional_special_tokens", None
)
lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
for k, val in cfg.special_tokens.items():
for k, val in special_tokens.items():
# check if new special token is not already in tokenizer and
# is adapter training to make sure lora_modules_to_save is set
# pylint: disable=too-many-boolean-expressions
if (
(getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
and (len(tokenizer.encode(val)) > 1)
and (len(tokenizer.encode(val, add_special_tokens=False)) > 2)
and cfg.adapter
and (
not cfg.lora_modules_to_save
Expand Down Expand Up @@ -229,7 +230,7 @@ def load_tokenizer(cfg):
# ```
if additional_special_tokens is not None:
for token in additional_special_tokens:
if len(tokenizer.encode(token)) > 1:
if len(tokenizer.encode(token, add_special_tokens=False)) > 2:
LOG.warning(f"missing {token} in cfg.tokens, adding to vocabulary.")
tokenizer.add_tokens(
[AddedToken(token, rstrip=False, lstrip=False, normalized=False)]
Expand Down
5 changes: 5 additions & 0 deletions tests/test_tokenizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ def test_add_additional_special_tokens(self):
)
tokenizer = load_tokenizer(cfg)
self.assertEqual(tokenizer("<|im_start|>user")["input_ids"], [1, 32000, 1404])
self.assertEqual(len(tokenizer), 32001)

# ensure reloading the tokenizer again from cfg results in same vocab length
tokenizer = load_tokenizer(cfg)
self.assertEqual(len(tokenizer), 32001)


if __name__ == "__main__":
Expand Down

0 comments on commit dea2d31

Please sign in to comment.