diff --git a/sentry_sdk/integrations/langchain.py b/sentry_sdk/integrations/langchain.py
index 305b445b2e..60c791fa12 100644
--- a/sentry_sdk/integrations/langchain.py
+++ b/sentry_sdk/integrations/langchain.py
@@ -27,28 +27,6 @@
     raise DidNotEnable("langchain not installed")
 
 
-try:
-    import tiktoken  # type: ignore
-
-    enc = tiktoken.get_encoding("cl100k_base")
-
-    def count_tokens(s):
-        # type: (str) -> int
-        return len(enc.encode_ordinary(s))
-
-    logger.debug("[langchain] using tiktoken to count tokens")
-except ImportError:
-    logger.info(
-        "The Sentry Python SDK requires 'tiktoken' in order to measure token usage from streaming langchain calls."
-        "Please install 'tiktoken' if you aren't receiving accurate token usage in Sentry."
-        "See https://docs.sentry.io/platforms/python/integrations/langchain/ for more information."
-    )
-
-    def count_tokens(s):
-        # type: (str) -> int
-        return 1
-
-
 DATA_FIELDS = {
     "temperature": SPANDATA.AI_TEMPERATURE,
     "top_p": SPANDATA.AI_TOP_P,
@@ -78,10 +56,13 @@ class LangchainIntegration(Integration):
     # The most number of spans (e.g., LLM calls) that can be processed at the same time.
     max_spans = 1024
 
-    def __init__(self, include_prompts=True, max_spans=1024):
-        # type: (LangchainIntegration, bool, int) -> None
+    def __init__(
+        self, include_prompts=True, max_spans=1024, tiktoken_encoding_name=None
+    ):
+        # type: (LangchainIntegration, bool, int, Optional[str]) -> None
         self.include_prompts = include_prompts
         self.max_spans = max_spans
+        self.tiktoken_encoding_name = tiktoken_encoding_name
 
     @staticmethod
     def setup_once():
@@ -109,11 +90,23 @@ class SentryLangchainCallback(BaseCallbackHandler):  # type: ignore[misc]
 
     max_span_map_size = 0
 
-    def __init__(self, max_span_map_size, include_prompts):
-        # type: (int, bool) -> None
+    def __init__(self, max_span_map_size, include_prompts, tiktoken_encoding_name=None):
+        # type: (int, bool, Optional[str]) -> None
         self.max_span_map_size = max_span_map_size
         self.include_prompts = include_prompts
 
+        self.tiktoken_encoding = None
+        if tiktoken_encoding_name is not None:
+            import tiktoken  # type: ignore
+
+            self.tiktoken_encoding = tiktoken.get_encoding(tiktoken_encoding_name)
+
+    def count_tokens(self, s):
+        # type: (str) -> int
+        if self.tiktoken_encoding is not None:
+            return len(self.tiktoken_encoding.encode_ordinary(s))
+        return 0
+
     def gc_span_map(self):
         # type: () -> None
 
@@ -244,9 +237,9 @@ def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
             if not watched_span.no_collect_tokens:
                 for list_ in messages:
                     for message in list_:
-                        self.span_map[run_id].num_prompt_tokens += count_tokens(
+                        self.span_map[run_id].num_prompt_tokens += self.count_tokens(
                             message.content
-                        ) + count_tokens(message.type)
+                        ) + self.count_tokens(message.type)
 
     def on_llm_new_token(self, token, *, run_id, **kwargs):
         # type: (SentryLangchainCallback, str, UUID, Any) -> Any
@@ -257,7 +250,7 @@ def on_llm_new_token(self, token, *, run_id, **kwargs):
         span_data = self.span_map[run_id]
         if not span_data or span_data.no_collect_tokens:
             return
-        span_data.num_completion_tokens += count_tokens(token)
+        span_data.num_completion_tokens += self.count_tokens(token)
 
     def on_llm_end(self, response, *, run_id, **kwargs):
         # type: (SentryLangchainCallback, LLMResult, UUID, Any) -> Any
@@ -461,7 +454,9 @@ def new_configure(*args, **kwargs):
             if not already_added:
                 new_callbacks.append(
                     SentryLangchainCallback(
-                        integration.max_spans, integration.include_prompts
+                        integration.max_spans,
+                        integration.include_prompts,
+                        integration.tiktoken_encoding_name,
                     )
                 )
         return f(*args, **kwargs)
diff --git a/sentry_sdk/integrations/openai.py b/sentry_sdk/integrations/openai.py
index 052d65f7a6..d06c188712 100644
--- a/sentry_sdk/integrations/openai.py
+++ b/sentry_sdk/integrations/openai.py
@@ -14,7 +14,6 @@
 from sentry_sdk.scope import should_send_default_pii
 from sentry_sdk.integrations import DidNotEnable, Integration
 from sentry_sdk.utils import (
-    logger,
     capture_internal_exceptions,
     event_from_exception,
     ensure_integration_enabled,
@@ -29,45 +28,33 @@
 except ImportError:
     raise DidNotEnable("OpenAI not installed")
 
-try:
-    import tiktoken  # type: ignore
-
-    enc = None  # lazy initialize
-
-    def count_tokens(s):
-        # type: (str) -> int
-        global enc
-        if enc is None:
-            enc = tiktoken.get_encoding("cl100k_base")
-        return len(enc.encode_ordinary(s))
-
-    logger.debug("[OpenAI] using tiktoken to count tokens")
-except ImportError:
-    logger.info(
-        "The Sentry Python SDK requires 'tiktoken' in order to measure token usage from some OpenAI APIs"
-        "Please install 'tiktoken' if you aren't receiving token usage in Sentry."
-        "See https://docs.sentry.io/platforms/python/integrations/openai/ for more information."
-    )
-
-    def count_tokens(s):
-        # type: (str) -> int
-        return 0
-
 
 class OpenAIIntegration(Integration):
     identifier = "openai"
     origin = f"auto.ai.{identifier}"
 
-    def __init__(self, include_prompts=True):
-        # type: (OpenAIIntegration, bool) -> None
+    def __init__(self, include_prompts=True, tiktoken_encoding_name=None):
+        # type: (OpenAIIntegration, bool, Optional[str]) -> None
         self.include_prompts = include_prompts
 
+        self.tiktoken_encoding = None
+        if tiktoken_encoding_name is not None:
+            import tiktoken  # type: ignore
+
+            self.tiktoken_encoding = tiktoken.get_encoding(tiktoken_encoding_name)
+
     @staticmethod
     def setup_once():
         # type: () -> None
         Completions.create = _wrap_chat_completion_create(Completions.create)
         Embeddings.create = _wrap_embeddings_create(Embeddings.create)
 
+    def count_tokens(self, s):
+        # type: (OpenAIIntegration, str) -> int
+        if self.tiktoken_encoding is not None:
+            return len(self.tiktoken_encoding.encode_ordinary(s))
+        return 0
+
 
 def _capture_exception(exc):
     # type: (Any) -> None
@@ -80,9 +67,9 @@ def _capture_exception(exc):
 
 
 def _calculate_chat_completion_usage(
-    messages, response, span, streaming_message_responses=None
+    messages, response, span, streaming_message_responses, count_tokens
 ):
-    # type: (Iterable[ChatCompletionMessageParam], Any, Span, Optional[List[str]]) -> None
+    # type: (Iterable[ChatCompletionMessageParam], Any, Span, Optional[List[str]], Callable[..., Any]) -> None
     completion_tokens = 0  # type: Optional[int]
     prompt_tokens = 0  # type: Optional[int]
     total_tokens = 0  # type: Optional[int]
@@ -173,7 +160,9 @@ def new_chat_completion(*args, **kwargs):
                     "ai.responses",
                     list(map(lambda x: x.message, res.choices)),
                 )
-                _calculate_chat_completion_usage(messages, res, span)
+                _calculate_chat_completion_usage(
+                    messages, res, span, None, integration.count_tokens
+                )
                 span.__exit__(None, None, None)
             elif hasattr(res, "_iterator"):
                 data_buf: list[list[str]] = []  # one for each choice
@@ -208,7 +197,11 @@ def new_iterator():
                             span, SPANDATA.AI_RESPONSES, all_responses
                         )
                     _calculate_chat_completion_usage(
-                        messages, res, span, all_responses
+                        messages,
+                        res,
+                        span,
+                        all_responses,
+                        integration.count_tokens,
                     )
                     span.__exit__(None, None, None)
 
@@ -266,7 +259,7 @@ def new_embeddings_create(*args, **kwargs):
             total_tokens = response.usage.total_tokens
 
         if prompt_tokens == 0:
-            prompt_tokens = count_tokens(kwargs["input"] or "")
+            prompt_tokens = integration.count_tokens(kwargs["input"] or "")
 
         record_token_usage(span, prompt_tokens, None, total_tokens or prompt_tokens)
diff --git a/tests/integrations/langchain/test_langchain.py b/tests/integrations/langchain/test_langchain.py
index 5e7ebbbf1d..b9e5705b88 100644
--- a/tests/integrations/langchain/test_langchain.py
+++ b/tests/integrations/langchain/test_langchain.py
@@ -46,6 +46,15 @@ def _llm_type(self) -> str:
         return llm_type
 
 
+def tiktoken_encoding_if_installed():
+    try:
+        import tiktoken  # type: ignore # noqa # pylint: disable=unused-import
+
+        return "cl100k_base"
+    except ImportError:
+        return None
+
+
 @pytest.mark.parametrize(
     "send_default_pii, include_prompts, use_unknown_llm_type",
     [
@@ -62,7 +71,12 @@ def test_langchain_agent(
     llm_type = "acme-llm" if use_unknown_llm_type else "openai-chat"
 
     sentry_init(
-        integrations=[LangchainIntegration(include_prompts=include_prompts)],
+        integrations=[
+            LangchainIntegration(
+                include_prompts=include_prompts,
+                tiktoken_encoding_name=tiktoken_encoding_if_installed(),
+            )
+        ],
         traces_sample_rate=1.0,
         send_default_pii=send_default_pii,
     )
diff --git a/tests/integrations/openai/test_openai.py b/tests/integrations/openai/test_openai.py
index 9cd8761fd6..b0ffc9e768 100644
--- a/tests/integrations/openai/test_openai.py
+++ b/tests/integrations/openai/test_openai.py
@@ -78,6 +78,15 @@ def test_nonstreaming_chat_completion(
     assert span["measurements"]["ai_total_tokens_used"]["value"] == 30
 
 
+def tiktoken_encoding_if_installed():
+    try:
+        import tiktoken  # type: ignore # noqa # pylint: disable=unused-import
+
+        return "cl100k_base"
+    except ImportError:
+        return None
+
+
 # noinspection PyTypeChecker
 @pytest.mark.parametrize(
     "send_default_pii, include_prompts",
@@ -87,7 +96,12 @@ def test_streaming_chat_completion(
     sentry_init, capture_events, send_default_pii, include_prompts
 ):
     sentry_init(
-        integrations=[OpenAIIntegration(include_prompts=include_prompts)],
+        integrations=[
+            OpenAIIntegration(
+                include_prompts=include_prompts,
+                tiktoken_encoding_name=tiktoken_encoding_if_installed(),
+            )
+        ],
         traces_sample_rate=1.0,
         send_default_pii=send_default_pii,
     )
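
Usage sketch: with this patch, tiktoken-based token counting becomes opt-in
through the new tiktoken_encoding_name constructor argument. A minimal init
call wiring it up might look as follows (the DSN is a placeholder, and
"cl100k_base" mirrors the encoding the tests above select when tiktoken is
installed):

    import sentry_sdk
    from sentry_sdk.integrations.langchain import LangchainIntegration
    from sentry_sdk.integrations.openai import OpenAIIntegration

    sentry_sdk.init(
        dsn="https://examplePublicKey@o0.ingest.sentry.io/0",  # placeholder DSN
        traces_sample_rate=1.0,
        integrations=[
            # Opt in to local token counting for streamed responses;
            # requires tiktoken to be installed at runtime.
            OpenAIIntegration(tiktoken_encoding_name="cl100k_base"),
            LangchainIntegration(tiktoken_encoding_name="cl100k_base"),
        ],
    )

Left at its default of None, the argument means no tiktoken import is
attempted and count_tokens() returns 0, which is why the module-level
try/except blocks and their "please install 'tiktoken'" log messages could
be removed.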