Skip entire header for llama3 decode #1656

Merged 6 commits on Oct 10, 2024

tests/torchtune/models/llama3/test_llama3_tokenizer.py (14 additions, 0 deletions)

```diff
@@ -428,3 +428,17 @@ def test_validate_special_tokens(self):
                 "<|python_tag|>": 128255,
             },
         )
+
+    def test_skip_special_tokens(
+        self,
+        tokenizer,
+        user_text_message,
+        assistant_text_message,
+        user_text_a,
+        user_text_b,
+        assistant_text,
+    ):
+        # This should satisfy text = decode(encode(text))
+        tokens = user_text_message[1] + assistant_text_message[1]
+        text = tokenizer.decode(tokens, skip_special_tokens=True)
+        assert text == user_text_a + user_text_b + assistant_text
```
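
The new test pins down the round-trip property. A rough usage sketch, assuming `tokenizer` is a llama3 tokenizer instance, `msgs` is a list of torchtune `Message` objects, and `Message.text_content` holds the raw text (the fixture names above are test-local):

```python
# Hypothetical round trip mirroring test_skip_special_tokens:
# tokenize_messages returns (token_ids, mask); decoding with
# skip_special_tokens=True should recover only the message text,
# with headers and special-token markers removed.
tokens, mask = tokenizer.tokenize_messages(msgs)
text = tokenizer.decode(tokens, skip_special_tokens=True)
assert text == "".join(m.text_content for m in msgs)
```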

tests/torchtune/modules/tokenizers/test_tiktoken.py (4 additions, 4 deletions)

```diff
@@ -38,7 +38,6 @@ def texts(self):
     @pytest.fixture
     def token_ids(self):
         return [
-            0,
             73,
             503,
             654,
@@ -64,17 +63,18 @@ def token_ids(self):
             511,
             115,
             46,
-            -1,
         ]

     def test_encode(self, tokenizer, texts, token_ids):
-        assert tokenizer.encode(texts[0]) == token_ids
+        assert tokenizer.encode(texts[0], add_bos=True, add_eos=True) == [
+            0
+        ] + token_ids + [-1]

     def test_decode(self, tokenizer, texts, token_ids):
         assert tokenizer.decode(token_ids) == texts[0]

     def test_encode_and_decode(self, tokenizer, texts):
-        token_ids = tokenizer.encode(texts[0])
+        token_ids = tokenizer.encode(texts[0], add_bos=False, add_eos=False)
         decoded_text = tokenizer.decode(token_ids)
         assert texts[0] == decoded_text
```
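
The fixture no longer bakes the BOS/EOS ids into `token_ids`; the tests now request them explicitly. A sketch of the behavior being asserted, assuming the test tokenizer uses BOS id 0 and EOS id -1 (fixture-specific values):

```python
# Plain ids, no BOS/EOS markers.
ids = tokenizer.encode(text, add_bos=False, add_eos=False)

# Asking for BOS/EOS wraps the same ids.
assert tokenizer.encode(text, add_bos=True, add_eos=True) == [0] + ids + [-1]

# decode() round-trips the unwrapped ids back to the original text.
assert tokenizer.decode(ids) == text
```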

torchtune/models/llama3/_tokenizer.py (27 additions, 3 deletions)

```diff
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import re
 from typing import Any, Dict, List, Mapping, Optional, Tuple

 from torchtune.data import Message, PromptTemplate, truncate
@@ -113,6 +114,12 @@ def __init__(

         self.prompt_template = prompt_template

+        # Regex for removing special tokens from the decoded string
+        self._special_token_regex = re.compile(r"<\|.*?\|>")
+        self._special_token_header_regex = re.compile(
+            r"<\|start_header_id\|>.*?<\|end_header_id\|>\n\n"
+        )
+
     def _validate_special_tokens(
         self,
     ):
@@ -131,6 +138,15 @@ def _validate_special_tokens(
             if token not in self.special_tokens:
                 raise ValueError(f"{token} missing from special_tokens")

+    def _remove_special_tokens(self, text: str) -> str:
+        """
+        Remove special tokens from the decoded string.
+        """
+        # First remove the headers, then the remaining special tokens
+        return self._special_token_regex.sub(
+            "", self._special_token_header_regex.sub("", text)
+        )
+
     @property
     def base_vocab_size(self) -> int:
         return self.tt_model.base_vocab_size
@@ -166,10 +182,18 @@ def decode(
         Returns:
             str: The decoded string.
         """
-        return self.tt_model.decode(
-            token_ids,
+        # We will remove special tokens manually via regex on the decoded string.
+        # This is because removing all special tokens does not remove the role and
+        # whitespace added from the special tokens, i.e., the "user" and "\n\n" in
+        # "<|start_header_id|>user<|end_header_id|>\n\n"
+        decoded_string = self.tt_model.decode(
+            token_ids=token_ids,
            truncate_at_eos=truncate_at_eos,
-            skip_special_tokens=skip_special_tokens,
         )
+        return (
+            self._remove_special_tokens(decoded_string)
+            if skip_special_tokens
+            else decoded_string
+        )

     def _tokenize_header(self, message: Message) -> List[int]:
```

Review thread on the comment block at lines +185 to +188:

Reviewer (Contributor): I would maybe move this comment up to where you define self._special_token_regex and self._special_token_header_regex.

Author (Contributor): This is where it actually happens, so it makes more sense to keep it here? No strong opinions.

Reviewer (Contributor): Yeah, I feel the same, fine to keep it here then.
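
To see why headers are removed as whole units rather than token by token, here is a standalone sketch of the two-pass removal using only `re`, with the patterns copied from the diff above and an illustrative decoded string:

```python
import re

# Same patterns as in Llama3Tokenizer.__init__ above.
special_token_regex = re.compile(r"<\|.*?\|>")
special_token_header_regex = re.compile(
    r"<\|start_header_id\|>.*?<\|end_header_id\|>\n\n"
)

decoded = (
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
    "Hello<|eot_id|>"
)

# Pass 1: drop whole headers, including the role name ("user") and the
# trailing "\n\n" that a plain special-token filter would leave behind.
no_headers = special_token_header_regex.sub("", decoded)

# Pass 2: drop the remaining standalone special tokens.
clean = special_token_regex.sub("", no_headers)
assert clean == "Hello"
```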

torchtune/models/phi3/_tokenizer.py (4 additions, 2 deletions)

```diff
@@ -101,11 +101,13 @@ def encode(
             trim_leading_whitespace=trim_leading_whitespace,
         )

-    def decode(self, ids: List[int]) -> str:
+    def decode(self, ids: List[int], skip_special_tokens: bool = True) -> str:
         """Decode token IDs to strings.

         Args:
             ids (List[int]): The input token IDs to be decoded.
+            skip_special_tokens (bool): Whether to show or skip special tokens in the decoded string.
+                Default is True.

         Returns:
             str: The decoded text.
@@ -114,7 +116,7 @@ def decode(self, ids: List[int]) -> str:
         for token_id in ids:
             # Filter out special tokens and the placeholder tokens added
             # by the Phi3 team
-            if token_id >= 32_000 and token_id <= 32_064:
+            if skip_special_tokens and (token_id >= 32_000 and token_id <= 32_064):
                 continue
             else:
                 ids_for_decode.append(token_id)
```
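
A standalone sketch of that gating logic; `filter_ids` is a hypothetical helper (not torchtune API), and ids 32_000 through 32_064 are the special and placeholder tokens reserved by the Phi3 team, per the comment in the diff:

```python
from typing import List

def filter_ids(ids: List[int], skip_special_tokens: bool = True) -> List[int]:
    """Mirrors the filtering loop inside the Phi3 tokenizer's decode."""
    ids_for_decode = []
    for token_id in ids:
        # Drop the reserved range only when skipping is requested.
        if skip_special_tokens and (token_id >= 32_000 and token_id <= 32_064):
            continue
        ids_for_decode.append(token_id)
    return ids_for_decode

assert filter_ids([15, 32_000, 99, 32_064]) == [15, 99]
assert filter_ids([15, 32_000, 99], skip_special_tokens=False) == [15, 32_000, 99]
```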

torchtune/modules/tokenizers/_tiktoken.py (0 additions, 10 deletions)

```diff
@@ -138,7 +138,6 @@ def decode(
         self,
         token_ids: List[int],
         truncate_at_eos: bool = True,
-        skip_special_tokens: bool = True,
     ) -> str:
         """
         Decode a list of token ids into a string.
@@ -147,8 +146,6 @@ def decode(
             token_ids (List[int]): The list of token ids.
             truncate_at_eos (bool): Whether to truncate the string at the end of
                 sequence token. Default is True.
-            skip_special_tokens (bool): Whether to show or skip special tokens in the decoded string.
-                Default is True.

         Returns:
             str: The decoded string.
@@ -160,11 +157,4 @@ def decode(
                 k = None
             if k:
                 token_ids = token_ids[:k]
-        if skip_special_tokens:
-            token_ids = [
-                token_id
-                for token_id in token_ids
-                if token_id not in self.tt_model._special_tokens.values()
-                and token_id != self.bos_id
-            ]
         return self.tt_model.decode(token_ids)
```
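
With the base-class filtering gone, TikTokenBaseTokenizer.decode only truncates at EOS and decodes; each model tokenizer now owns its own special-token handling. A minimal usage sketch, assuming `base_tok` is a TikTokenBaseTokenizer and `llama3_tok` a Llama3Tokenizer:

```python
# Base tokenizer: raw decode; any special tokens stay in the string.
raw = base_tok.decode(token_ids, truncate_at_eos=True)

# Model tokenizer: additionally strips headers and "<|...|>" markers
# via the regexes shown earlier.
clean = llama3_tok.decode(token_ids, skip_special_tokens=True)
```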