diff --git a/tests/torchtune/models/llama3/test_llama3_tokenizer.py b/tests/torchtune/models/llama3/test_llama3_tokenizer.py
index e61182db63..3b973c206d 100644
--- a/tests/torchtune/models/llama3/test_llama3_tokenizer.py
+++ b/tests/torchtune/models/llama3/test_llama3_tokenizer.py
@@ -428,3 +428,17 @@ def test_validate_special_tokens(self):
                 "<|python_tag|>": 128255,
             },
         )
+
+    def test_skip_special_tokens(
+        self,
+        tokenizer,
+        user_text_message,
+        assistant_text_message,
+        user_text_a,
+        user_text_b,
+        assistant_text,
+    ):
+        # This should satisfy text = decode(encode(text))
+        tokens = user_text_message[1] + assistant_text_message[1]
+        text = tokenizer.decode(tokens, skip_special_tokens=True)
+        assert text == user_text_a + user_text_b + assistant_text
diff --git a/tests/torchtune/modules/tokenizers/test_tiktoken.py b/tests/torchtune/modules/tokenizers/test_tiktoken.py
index bc31fad381..e7e69f62d3 100644
--- a/tests/torchtune/modules/tokenizers/test_tiktoken.py
+++ b/tests/torchtune/modules/tokenizers/test_tiktoken.py
@@ -38,7 +38,6 @@ def texts(self):
     @pytest.fixture
     def token_ids(self):
         return [
-            0,
             73,
             503,
             654,
@@ -64,17 +63,18 @@ def token_ids(self):
             511,
             115,
             46,
-            -1,
         ]
 
     def test_encode(self, tokenizer, texts, token_ids):
-        assert tokenizer.encode(texts[0]) == token_ids
+        assert tokenizer.encode(texts[0], add_bos=True, add_eos=True) == [
+            0
+        ] + token_ids + [-1]
 
     def test_decode(self, tokenizer, texts, token_ids):
         assert tokenizer.decode(token_ids) == texts[0]
 
     def test_encode_and_decode(self, tokenizer, texts):
-        token_ids = tokenizer.encode(texts[0])
+        token_ids = tokenizer.encode(texts[0], add_bos=False, add_eos=False)
         decoded_text = tokenizer.decode(token_ids)
         assert texts[0] == decoded_text
 
diff --git a/torchtune/models/llama3/_tokenizer.py b/torchtune/models/llama3/_tokenizer.py
index f91cabed3e..50ea0a7581 100644
--- a/torchtune/models/llama3/_tokenizer.py
+++ b/torchtune/models/llama3/_tokenizer.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import re
 from typing import Any, Dict, List, Mapping, Optional, Tuple
 
 from torchtune.data import Message, PromptTemplate, truncate
@@ -113,6 +114,12 @@ def __init__(
 
         self.prompt_template = prompt_template
 
+        # Regex for removing special tokens from the decoded string
+        self._special_token_regex = re.compile(r"<\|.*?\|>")
+        self._special_token_header_regex = re.compile(
+            r"<\|start_header_id\|>.*?<\|end_header_id\|>\n\n"
+        )
+
     def _validate_special_tokens(
         self,
     ):
@@ -131,6 +138,15 @@ def _validate_special_tokens(
             if token not in self.special_tokens:
                 raise ValueError(f"{token} missing from special_tokens")
 
+    def _remove_special_tokens(self, text: str) -> str:
+        """
+        Remove special tokens from the decoded string.
+        """
+        # First remove the headers, then the remaining special tokens
+        return self._special_token_regex.sub(
+            "", self._special_token_header_regex.sub("", text)
+        )
+
     @property
     def base_vocab_size(self) -> int:
         return self.tt_model.base_vocab_size
@@ -166,10 +182,18 @@ def decode(
         Returns:
             str: The decoded string.
         """
-        return self.tt_model.decode(
-            token_ids,
+        # We will remove special tokens manually via regex on the decoded string.
+        # This is because removing all special tokens does not remove the role and
+        # whitespace added from the special tokens, i.e., the "user" and "\n\n" in
+        # "<|start_header_id|>user<|end_header_id|>\n\n"
+        decoded_string = self.tt_model.decode(
+            token_ids=token_ids,
             truncate_at_eos=truncate_at_eos,
-            skip_special_tokens=skip_special_tokens,
+        )
+        return (
+            self._remove_special_tokens(decoded_string)
+            if skip_special_tokens
+            else decoded_string
         )
 
     def _tokenize_header(self, message: Message) -> List[int]:
diff --git a/torchtune/models/phi3/_tokenizer.py b/torchtune/models/phi3/_tokenizer.py
index 544b50c372..bd1466497a 100644
--- a/torchtune/models/phi3/_tokenizer.py
+++ b/torchtune/models/phi3/_tokenizer.py
@@ -101,11 +101,13 @@ def encode(
             trim_leading_whitespace=trim_leading_whitespace,
         )
 
-    def decode(self, ids: List[int]) -> str:
+    def decode(self, ids: List[int], skip_special_tokens: bool = True) -> str:
         """Decode token IDs to strings.
 
         Args:
             ids (List[int]): The input token IDs to be decoded.
+            skip_special_tokens (bool): Whether to show or skip special tokens in the decoded string.
+                Default is True.
 
         Returns:
             str: The decoded text.
@@ -114,7 +116,7 @@ def decode(self, ids: List[int]) -> str:
         for token_id in ids:
             # Filter out special tokens and the placeholder tokens added
             # by the Phi3 team
-            if token_id >= 32_000 and token_id <= 32_064:
+            if skip_special_tokens and (token_id >= 32_000 and token_id <= 32_064):
                 continue
             else:
                 ids_for_decode.append(token_id)
diff --git a/torchtune/modules/tokenizers/_tiktoken.py b/torchtune/modules/tokenizers/_tiktoken.py
index 7fe29b2801..077b22b0cd 100644
--- a/torchtune/modules/tokenizers/_tiktoken.py
+++ b/torchtune/modules/tokenizers/_tiktoken.py
@@ -138,7 +138,6 @@ def decode(
         self,
         token_ids: List[int],
         truncate_at_eos: bool = True,
-        skip_special_tokens: bool = True,
     ) -> str:
         """
         Decode a list of token ids into a string.
@@ -147,8 +146,6 @@ def decode(
             token_ids (List[int]): The list of token ids.
             truncate_at_eos (bool): Whether to truncate the string at the end of sequence token.
                 Default is True.
-            skip_special_tokens (bool): Whether to show or skip special tokens in the decoded string.
-                Default is True.
 
         Returns:
             str: The decoded string.
@@ -160,11 +157,4 @@ def decode(
                 k = None
             if k:
                 token_ids = token_ids[:k]
-        if skip_special_tokens:
-            token_ids = [
-                token_id
-                for token_id in token_ids
-                if token_id not in self.tt_model._special_tokens.values()
-                and token_id != self.bos_id
-            ]
         return self.tt_model.decode(token_ids)
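
For reference, the two-pass regex stripping added to `Llama3Tokenizer` can be exercised standalone. The sketch below uses the same patterns as the diff; the helper name `remove_special_tokens` and the sample decoded string are illustrative, not part of the change. The header pattern runs first so the role name and the `"\n\n"` emitted between `<|start_header_id|>` and the message body are removed along with the tokens themselves; the generic pattern then drops any remaining `<|...|>` markers.

```python
import re

# Patterns copied from the diff; names mirror the new tokenizer attributes.
_special_token_header_regex = re.compile(
    r"<\|start_header_id\|>.*?<\|end_header_id\|>\n\n"
)
_special_token_regex = re.compile(r"<\|.*?\|>")


def remove_special_tokens(text: str) -> str:
    # Strip header blocks first (token pair, role name, and trailing "\n\n"),
    # then any remaining special tokens such as <|eot_id|>.
    return _special_token_regex.sub(
        "", _special_token_header_regex.sub("", text)
    )


# Hypothetical decoded output for a single user message.
decoded = (
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
    "Hello world<|eot_id|>"
)
print(remove_special_tokens(decoded))  # -> "Hello world"
```

Operating on the decoded string rather than on token ids is what makes this work: the id-level filter removed from `_tiktoken.py` could drop `<|start_header_id|>` and `<|end_header_id|>` themselves, but not the role text and whitespace they introduce.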