fix: Tokenizers - Fixed Tokenizer.compute_tokens
PiperOrigin-RevId: 671042779
happy-qiao authored and copybara-github committed Sep 4, 2024
1 parent 6624ebe commit c29fa5d
Showing 2 changed files with 40 additions and 30 deletions.
39 changes: 27 additions & 12 deletions tests/system/vertexai/test_tokenization.py
@@ -20,7 +20,10 @@
 from nltk.corpus import udhr
 from google.cloud import aiplatform
 from vertexai.preview.tokenization import (
-    get_tokenizer_for_model,
+    get_tokenizer_for_model as tokenizer_preview,
+)
+from vertexai.tokenization._tokenizers import (
+    get_tokenizer_for_model as tokenizer_ga,
 )
 from vertexai.generative_models import (
     GenerativeModel,
@@ -44,8 +47,10 @@
 _CORPUS_LIB = [
     udhr,
 ]
+_VERSIONED_TOKENIZER = [tokenizer_preview, tokenizer_ga]
 _MODEL_CORPUS_PARAMS = [
-    (model_name, corpus_name, corpus_lib)
+    (get_tokenizer_for_model, model_name, corpus_name, corpus_lib)
+    for get_tokenizer_for_model in _VERSIONED_TOKENIZER
     for model_name in _MODELS
     for (corpus_name, corpus_lib) in zip(_CORPUS, _CORPUS_LIB)
 ]
@@ -125,11 +130,16 @@ def setup_method(self, api_endpoint_env_name):
        )
 
    @pytest.mark.parametrize(
-        "model_name, corpus_name, corpus_lib",
+        "get_tokenizer_for_model, model_name, corpus_name, corpus_lib",
        _MODEL_CORPUS_PARAMS,
    )
    def test_count_tokens_local(
-        self, model_name, corpus_name, corpus_lib, api_endpoint_env_name
+        self,
+        get_tokenizer_for_model,
+        model_name,
+        corpus_name,
+        corpus_lib,
+        api_endpoint_env_name,
    ):
        # The Gemini 1.5 flash model requires the model version
        # number suffix (001) in staging only
@@ -145,11 +155,16 @@ def test_count_tokens_local(
        assert service_result.total_tokens == local_result.total_tokens
 
    @pytest.mark.parametrize(
-        "model_name, corpus_name, corpus_lib",
+        "get_tokenizer_for_model, model_name, corpus_name, corpus_lib",
        _MODEL_CORPUS_PARAMS,
    )
    def test_compute_tokens(
-        self, model_name, corpus_name, corpus_lib, api_endpoint_env_name
+        self,
+        get_tokenizer_for_model,
+        model_name,
+        corpus_name,
+        corpus_lib,
+        api_endpoint_env_name,
    ):
        # The Gemini 1.5 flash model requires the model version
        # number suffix (001) in staging only
@@ -171,7 +186,7 @@
        _MODELS,
    )
    def test_count_tokens_system_instruction(self, model_name):
-        tokenizer = get_tokenizer_for_model(model_name)
+        tokenizer = tokenizer_preview(model_name)
        model = GenerativeModel(model_name, system_instruction=["You are a chatbot."])
 
        assert (
@@ -188,7 +203,7 @@ def test_count_tokens_system_instruction(self, model_name):
    def test_count_tokens_system_instruction_is_function_call(self, model_name):
        part = Part._from_gapic(gapic_content_types.Part(function_call=_FUNCTION_CALL))
 
-        tokenizer = get_tokenizer_for_model(model_name)
+        tokenizer = tokenizer_preview(model_name)
        model = GenerativeModel(model_name, system_instruction=[part])
 
        assert (
@@ -204,7 +219,7 @@ def test_count_tokens_system_instruction_is_function_response(self, model_name):
        part = Part._from_gapic(
            gapic_content_types.Part(function_response=_FUNCTION_RESPONSE)
        )
-        tokenizer = get_tokenizer_for_model(model_name)
+        tokenizer = tokenizer_preview(model_name)
        model = GenerativeModel(model_name, system_instruction=[part])
 
        assert tokenizer.count_tokens(part, system_instruction=[part]).total_tokens
@@ -218,7 +233,7 @@ def test_count_tokens_system_instruction_is_function_response(self, model_name):
        _MODELS,
    )
    def test_count_tokens_tool_is_function_declaration(self, model_name):
-        tokenizer = get_tokenizer_for_model(model_name)
+        tokenizer = tokenizer_preview(model_name)
        model = GenerativeModel(model_name)
        tool1 = Tool._from_gapic(
            gapic_tool_types.Tool(function_declarations=[_FUNCTION_DECLARATION_1])
@@ -241,7 +256,7 @@ def test_count_tokens_tool_is_function_declaration(self, model_name):
    )
    def test_count_tokens_content_is_function_call(self, model_name):
        part = Part._from_gapic(gapic_content_types.Part(function_call=_FUNCTION_CALL))
-        tokenizer = get_tokenizer_for_model(model_name)
+        tokenizer = tokenizer_preview(model_name)
        model = GenerativeModel(model_name)
 
        assert tokenizer.count_tokens(part).total_tokens
@@ -258,7 +273,7 @@ def test_count_tokens_content_is_function_response(self, model_name):
        part = Part._from_gapic(
            gapic_content_types.Part(function_response=_FUNCTION_RESPONSE)
        )
-        tokenizer = get_tokenizer_for_model(model_name)
+        tokenizer = tokenizer_preview(model_name)
        model = GenerativeModel(model_name)
 
        assert tokenizer.count_tokens(part).total_tokens
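The test change threads the tokenizer factory through the parametrization, so each (model, corpus) case now runs once against the preview entry point and once against the GA one. The following is a minimal, self-contained sketch of that pattern rather than part of the commit; "gemini-1.5-flash-001" is a hypothetical stand-in for the real _MODELS entries, and the smoke assertion assumes count_tokens accepts a plain string:

import pytest

from vertexai.preview.tokenization import (
    get_tokenizer_for_model as tokenizer_preview,
)
from vertexai.tokenization._tokenizers import (
    get_tokenizer_for_model as tokenizer_ga,
)

_MODELS = ["gemini-1.5-flash-001"]  # hypothetical stand-in for the real list
_VERSIONED_TOKENIZER = [tokenizer_preview, tokenizer_ga]

# Each case is a (factory, model) pair, so one test body exercises both
# the preview and the GA code paths.
_PARAMS = [
    (factory, model_name)
    for factory in _VERSIONED_TOKENIZER
    for model_name in _MODELS
]


@pytest.mark.parametrize("factory, model_name", _PARAMS)
def test_count_tokens_smoke(factory, model_name):
    tokenizer = factory(model_name)
    assert tokenizer.count_tokens("hello world").total_tokens > 0

A failure report then identifies both the factory and the model, which is the point of folding the version into the parameter list instead of duplicating every test.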
31 changes: 13 additions & 18 deletions vertexai/tokenization/_tokenizers.py
@@ -53,20 +53,6 @@ class TokensInfo:
     role: str = None
 
 
-@dataclasses.dataclass(frozen=True)
-class ComputeTokensResult:
-    tokens_info: Sequence[TokensInfo]
-
-
-class PreviewComputeTokensResult(ComputeTokensResult):
-    def token_info_list(self) -> Sequence[TokensInfo]:
-        import warnings
-
-        message = "PreviewComputeTokensResult.token_info_list is deprecated. Use ComputeTokensResult.tokens_info instead."
-        warnings.warn(message, DeprecationWarning, stacklevel=2)
-        return self.tokens_info
-
-
 @dataclasses.dataclass(frozen=True)
 class ComputeTokensResult:
     """Represents token string pieces and ids output in compute_tokens function.
@@ -78,11 +64,18 @@ class ComputeTokensResult:
            item represents each string instance. Each token
            info consists tokens list, token_ids list and
            a role.
-        token_info_list: the value in this field equal to tokens_info.
    """
 
    tokens_info: Sequence[TokensInfo]
-    token_info_list: Sequence[TokensInfo]
+
+
+class PreviewComputeTokensResult(ComputeTokensResult):
+    def token_info_list(self) -> Sequence[TokensInfo]:
+        import warnings
+
+        message = "PreviewComputeTokensResult.token_info_list is deprecated. Use ComputeTokensResult.tokens_info instead."
+        warnings.warn(message, DeprecationWarning, stacklevel=2)
+        return self.tokens_info
 
 
 @dataclasses.dataclass(frozen=True)
@@ -169,7 +162,7 @@ def compute_tokens(
                    role=role,
                )
            )
-    return ComputeTokensResult(token_info_list=token_infos, tokens_info=token_infos)
+    return ComputeTokensResult(tokens_info=token_infos)
 
 
 def _to_gapic_contents(
@@ -539,7 +532,9 @@ def compute_tokens(self, contents: ContentsType) -> ComputeTokensResult:
 
 class PreviewTokenizer(Tokenizer):
    def compute_tokens(self, contents: ContentsType) -> PreviewComputeTokensResult:
-        return PreviewComputeTokensResult(tokens_info=super().compute_tokens(contents))
+        return PreviewComputeTokensResult(
+            tokens_info=super().compute_tokens(contents).tokens_info
+        )
 
 
 def _get_tokenizer_for_model_preview(model_name: str) -> PreviewTokenizer:
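Taken together, the library-side hunks fix three related problems: the module previously defined ComputeTokensResult twice (the second, documented dataclass silently shadowed the first), the GA result carried a redundant token_info_list field duplicating tokens_info, and PreviewTokenizer.compute_tokens wrapped an entire ComputeTokensResult object in the tokens_info field instead of the TokensInfo sequence inside it. Below is a minimal usage sketch of the fixed surface, not taken from the commit; it assumes "gemini-1.5-flash-001" is a supported model name and that the tokenizer assets can be loaded:

from vertexai.preview.tokenization import get_tokenizer_for_model

# Hypothetical model name; substitute any model the tokenizer supports.
tokenizer = get_tokenizer_for_model("gemini-1.5-flash-001")
result = tokenizer.compute_tokens(["hello world"])

# Canonical access path: the tokens_info dataclass field.
for info in result.tokens_info:
    print(info.tokens, info.token_ids, info.role)

# Preview-only compatibility shim: token_info_list() returns the same
# sequence but now emits a DeprecationWarning pointing at tokens_info.
assert result.token_info_list() == result.tokens_info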
