Skip to content

Commit

Permalink
feat(openai): Make tiktoken encoding name configurable + tiktoken usa…
Browse files Browse the repository at this point in the history
…ge opt-in (getsentry#3289)

Make tiktoken encoding name configurable + tiktoken usage opt-in

---------

Co-authored-by: Ivana Kellyer <[email protected]>
  • Loading branch information
2 people authored and arjenzorgdoc committed Sep 30, 2024
1 parent b68d41d commit 5407576
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 64 deletions.
55 changes: 25 additions & 30 deletions sentry_sdk/integrations/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,28 +27,6 @@
raise DidNotEnable("langchain not installed")


try:
import tiktoken # type: ignore

enc = tiktoken.get_encoding("cl100k_base")

def count_tokens(s):
# type: (str) -> int
return len(enc.encode_ordinary(s))

logger.debug("[langchain] using tiktoken to count tokens")
except ImportError:
logger.info(
"The Sentry Python SDK requires 'tiktoken' in order to measure token usage from streaming langchain calls."
"Please install 'tiktoken' if you aren't receiving accurate token usage in Sentry."
"See https://docs.sentry.io/platforms/python/integrations/langchain/ for more information."
)

def count_tokens(s):
# type: (str) -> int
return 1


DATA_FIELDS = {
"temperature": SPANDATA.AI_TEMPERATURE,
"top_p": SPANDATA.AI_TOP_P,
Expand Down Expand Up @@ -78,10 +56,13 @@ class LangchainIntegration(Integration):
# The most number of spans (e.g., LLM calls) that can be processed at the same time.
max_spans = 1024

def __init__(self, include_prompts=True, max_spans=1024):
# type: (LangchainIntegration, bool, int) -> None
def __init__(
self, include_prompts=True, max_spans=1024, tiktoken_encoding_name=None
):
# type: (LangchainIntegration, bool, int, Optional[str]) -> None
self.include_prompts = include_prompts
self.max_spans = max_spans
self.tiktoken_encoding_name = tiktoken_encoding_name

@staticmethod
def setup_once():
Expand Down Expand Up @@ -109,11 +90,23 @@ class SentryLangchainCallback(BaseCallbackHandler): # type: ignore[misc]

max_span_map_size = 0

def __init__(self, max_span_map_size, include_prompts):
# type: (int, bool) -> None
def __init__(self, max_span_map_size, include_prompts, tiktoken_encoding_name=None):
# type: (int, bool, Optional[str]) -> None
self.max_span_map_size = max_span_map_size
self.include_prompts = include_prompts

self.tiktoken_encoding = None
if tiktoken_encoding_name is not None:
import tiktoken # type: ignore

self.tiktoken_encoding = tiktoken.get_encoding(tiktoken_encoding_name)

def count_tokens(self, s):
# type: (str) -> int
if self.tiktoken_encoding is not None:
return len(self.tiktoken_encoding.encode_ordinary(s))
return 0

def gc_span_map(self):
# type: () -> None

Expand Down Expand Up @@ -244,9 +237,9 @@ def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
if not watched_span.no_collect_tokens:
for list_ in messages:
for message in list_:
self.span_map[run_id].num_prompt_tokens += count_tokens(
self.span_map[run_id].num_prompt_tokens += self.count_tokens(
message.content
) + count_tokens(message.type)
) + self.count_tokens(message.type)

def on_llm_new_token(self, token, *, run_id, **kwargs):
# type: (SentryLangchainCallback, str, UUID, Any) -> Any
Expand All @@ -257,7 +250,7 @@ def on_llm_new_token(self, token, *, run_id, **kwargs):
span_data = self.span_map[run_id]
if not span_data or span_data.no_collect_tokens:
return
span_data.num_completion_tokens += count_tokens(token)
span_data.num_completion_tokens += self.count_tokens(token)

def on_llm_end(self, response, *, run_id, **kwargs):
# type: (SentryLangchainCallback, LLMResult, UUID, Any) -> Any
Expand Down Expand Up @@ -461,7 +454,9 @@ def new_configure(*args, **kwargs):
if not already_added:
new_callbacks.append(
SentryLangchainCallback(
integration.max_spans, integration.include_prompts
integration.max_spans,
integration.include_prompts,
integration.tiktoken_encoding_name,
)
)
return f(*args, **kwargs)
Expand Down
57 changes: 25 additions & 32 deletions sentry_sdk/integrations/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.integrations import DidNotEnable, Integration
from sentry_sdk.utils import (
logger,
capture_internal_exceptions,
event_from_exception,
ensure_integration_enabled,
Expand All @@ -29,45 +28,33 @@
except ImportError:
raise DidNotEnable("OpenAI not installed")

try:
import tiktoken # type: ignore

enc = None # lazy initialize

def count_tokens(s):
# type: (str) -> int
global enc
if enc is None:
enc = tiktoken.get_encoding("cl100k_base")
return len(enc.encode_ordinary(s))

logger.debug("[OpenAI] using tiktoken to count tokens")
except ImportError:
logger.info(
"The Sentry Python SDK requires 'tiktoken' in order to measure token usage from some OpenAI APIs"
"Please install 'tiktoken' if you aren't receiving token usage in Sentry."
"See https://docs.sentry.io/platforms/python/integrations/openai/ for more information."
)

def count_tokens(s):
# type: (str) -> int
return 0


class OpenAIIntegration(Integration):
identifier = "openai"
origin = f"auto.ai.{identifier}"

def __init__(self, include_prompts=True):
# type: (OpenAIIntegration, bool) -> None
def __init__(self, include_prompts=True, tiktoken_encoding_name=None):
# type: (OpenAIIntegration, bool, Optional[str]) -> None
self.include_prompts = include_prompts

self.tiktoken_encoding = None
if tiktoken_encoding_name is not None:
import tiktoken # type: ignore

self.tiktoken_encoding = tiktoken.get_encoding(tiktoken_encoding_name)

@staticmethod
def setup_once():
# type: () -> None
Completions.create = _wrap_chat_completion_create(Completions.create)
Embeddings.create = _wrap_embeddings_create(Embeddings.create)

def count_tokens(self, s):
# type: (OpenAIIntegration, str) -> int
if self.tiktoken_encoding is not None:
return len(self.tiktoken_encoding.encode_ordinary(s))
return 0


def _capture_exception(exc):
# type: (Any) -> None
Expand All @@ -80,9 +67,9 @@ def _capture_exception(exc):


def _calculate_chat_completion_usage(
messages, response, span, streaming_message_responses=None
messages, response, span, streaming_message_responses, count_tokens
):
# type: (Iterable[ChatCompletionMessageParam], Any, Span, Optional[List[str]]) -> None
# type: (Iterable[ChatCompletionMessageParam], Any, Span, Optional[List[str]], Callable[..., Any]) -> None
completion_tokens = 0 # type: Optional[int]
prompt_tokens = 0 # type: Optional[int]
total_tokens = 0 # type: Optional[int]
Expand Down Expand Up @@ -173,7 +160,9 @@ def new_chat_completion(*args, **kwargs):
"ai.responses",
list(map(lambda x: x.message, res.choices)),
)
_calculate_chat_completion_usage(messages, res, span)
_calculate_chat_completion_usage(
messages, res, span, None, integration.count_tokens
)
span.__exit__(None, None, None)
elif hasattr(res, "_iterator"):
data_buf: list[list[str]] = [] # one for each choice
Expand Down Expand Up @@ -208,7 +197,11 @@ def new_iterator():
span, SPANDATA.AI_RESPONSES, all_responses
)
_calculate_chat_completion_usage(
messages, res, span, all_responses
messages,
res,
span,
all_responses,
integration.count_tokens,
)
span.__exit__(None, None, None)

Expand Down Expand Up @@ -266,7 +259,7 @@ def new_embeddings_create(*args, **kwargs):
total_tokens = response.usage.total_tokens

if prompt_tokens == 0:
prompt_tokens = count_tokens(kwargs["input"] or "")
prompt_tokens = integration.count_tokens(kwargs["input"] or "")

record_token_usage(span, prompt_tokens, None, total_tokens or prompt_tokens)

Expand Down
16 changes: 15 additions & 1 deletion tests/integrations/langchain/test_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@ def _llm_type(self) -> str:
return llm_type


def tiktoken_encoding_if_installed():
try:
import tiktoken # type: ignore # noqa # pylint: disable=unused-import

return "cl100k_base"
except ImportError:
return None


@pytest.mark.parametrize(
"send_default_pii, include_prompts, use_unknown_llm_type",
[
Expand All @@ -62,7 +71,12 @@ def test_langchain_agent(
llm_type = "acme-llm" if use_unknown_llm_type else "openai-chat"

sentry_init(
integrations=[LangchainIntegration(include_prompts=include_prompts)],
integrations=[
LangchainIntegration(
include_prompts=include_prompts,
tiktoken_encoding_name=tiktoken_encoding_if_installed(),
)
],
traces_sample_rate=1.0,
send_default_pii=send_default_pii,
)
Expand Down
16 changes: 15 additions & 1 deletion tests/integrations/openai/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,15 @@ def test_nonstreaming_chat_completion(
assert span["measurements"]["ai_total_tokens_used"]["value"] == 30


def tiktoken_encoding_if_installed():
try:
import tiktoken # type: ignore # noqa # pylint: disable=unused-import

return "cl100k_base"
except ImportError:
return None


# noinspection PyTypeChecker
@pytest.mark.parametrize(
"send_default_pii, include_prompts",
Expand All @@ -87,7 +96,12 @@ def test_streaming_chat_completion(
sentry_init, capture_events, send_default_pii, include_prompts
):
sentry_init(
integrations=[OpenAIIntegration(include_prompts=include_prompts)],
integrations=[
OpenAIIntegration(
include_prompts=include_prompts,
tiktoken_encoding_name=tiktoken_encoding_if_installed(),
)
],
traces_sample_rate=1.0,
send_default_pii=send_default_pii,
)
Expand Down

0 comments on commit 5407576

Please sign in to comment.