Support for additional_special_tokens (#1221) [skip ci]
* Support for additional_special_tokens

* Support for additional_special_tokens. Adjust whitespace.

* Support for additional_special_tokens. Use correct quotes.

* Support for additional_special_tokens. Safe pop.

* Support for additional_special_tokens. nt.

* Support for additional_special_tokens. cfg.special_tokens may be None.

* add token if not in vocabulary when adding additional_special_tokens

* fix logic for copy/pasta

* bugfix for popping from config and tokenizer reload

* no need to add tokens manually now with previous bugfix

---------

Co-authored-by: Wing Lian <[email protected]>
DreamGenX and winglian authored Jan 31, 2024
1 parent 52c83d3 commit 25e037f
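
For orientation, here is a minimal sketch of how the new option is exercised end to end. It mirrors the test added in this commit; the import paths are assumed to match axolotl's test suite, and the printed values depend on the huggyllama/llama-7b tokenizer. The equivalent YAML form appears in the comment added to load_tokenizer below.

    # Sketch only: mirrors the new test in this commit; DictDefault/load_tokenizer
    # import paths are assumed to match axolotl's test suite.
    from axolotl.utils.dict import DictDefault
    from axolotl.utils.models import load_tokenizer

    cfg = DictDefault(
        {
            "tokenizer_config": "huggyllama/llama-7b",
            "special_tokens": {
                # ChatML-style markers that are not in the base llama vocabulary
                "additional_special_tokens": ["<|im_start|>", "<|im_end|>"],
            },
        }
    )

    tokenizer = load_tokenizer(cfg)
    # Each marker is appended to the vocabulary (if missing) and encodes as a single id.
    print(len(tokenizer))
    print(tokenizer("<|im_start|>user")["input_ids"])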
Showing 2 changed files with 37 additions and 2 deletions.
src/axolotl/utils/models.py (22 additions, 2 deletions)
@@ -161,15 +161,20 @@ def load_tokenizer(cfg):
             if getattr(tokenizer, attr_name) is None:
                 setattr(tokenizer, attr_name, "<|endoftext|>")
 
+    additional_special_tokens = None
     if cfg.special_tokens:
+        special_tokens = cfg.special_tokens.to_dict()
+        additional_special_tokens = special_tokens.pop(
+            "additional_special_tokens", None
+        )
         lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
-        for k, val in cfg.special_tokens.items():
+        for k, val in special_tokens.items():
             # check if new special token is not already in tokenizer and
             # is adapter training to make sure lora_modules_to_save is set
             # pylint: disable=too-many-boolean-expressions
             if (
                 (getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
-                and (len(tokenizer.encode(val)) > 1)
+                and (len(tokenizer.encode(val, add_special_tokens=False)) > 2)
                 and cfg.adapter
                 and (
                     not cfg.lora_modules_to_save
@@ -213,6 +218,21 @@ def load_tokenizer(cfg):
             ]
         )
 
+    # Additional special tokens are a List, and need to be treated differently than regular special
+    # tokens. We add them after we have called `add_tokens` in case these additional special tokens
+    # are new tokens.
+    #
+    # Usage:
+    #
+    # ```py
+    # special_tokens:
+    #   additional_special_tokens: ["<|im_start|>", "<|im_end|>"]
+    # ```
+    if additional_special_tokens is not None:
+        tokenizer.add_special_tokens(
+            {"additional_special_tokens": additional_special_tokens}
+        )
+
     LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
     LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
     LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
tests/test_tokenizers.py (15 additions, 0 deletions)
@@ -67,6 +67,21 @@ def test_special_tokens_modules_to_save(self):
         )
         load_tokenizer(cfg)
 
+    def test_add_additional_special_tokens(self):
+        cfg = DictDefault(
+            {
+                "tokenizer_config": "huggyllama/llama-7b",
+                "special_tokens": {"additional_special_tokens": ["<|im_start|>"]},
+            }
+        )
+        tokenizer = load_tokenizer(cfg)
+        self.assertEqual(tokenizer("<|im_start|>user")["input_ids"], [1, 32000, 1404])
+        self.assertEqual(len(tokenizer), 32001)
+
+        # ensure reloading the tokenizer again from cfg results in same vocab length
+        tokenizer = load_tokenizer(cfg)
+        self.assertEqual(len(tokenizer), 32001)
+
 
 if __name__ == "__main__":
     unittest.main()
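
The second load_tokenizer(cfg) call in the test guards against the "popping from config" bug fixed late in this PR: popping additional_special_tokens straight off cfg.special_tokens would mutate the config, so any later reload would silently skip those tokens. A plain-dict sketch of the difference (dicts stand in for axolotl's config object):

    # Plain-dict sketch of the copy-then-pop pattern used in load_tokenizer above.
    cfg_special_tokens = {
        "bos_token": "<s>",
        "additional_special_tokens": ["<|im_start|>"],
    }

    # Pop from a copy: the original config keeps the key, so calling
    # load_tokenizer(cfg) a second time still adds the special tokens.
    special_tokens = dict(cfg_special_tokens)
    additional = special_tokens.pop("additional_special_tokens", None)
    assert "additional_special_tokens" in cfg_special_tokens

    # Popping from cfg_special_tokens directly would drop the key for every
    # later call, which is what the vocab-length check on reload catches.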
