diff --git a/sentry_sdk/integrations/langchain.py b/sentry_sdk/integrations/langchain.py index 9d213e3a6b..51a06b381c 100644 --- a/sentry_sdk/integrations/langchain.py +++ b/sentry_sdk/integrations/langchain.py @@ -52,7 +52,7 @@ def __init__(self, span, num_tokens=0): self.num_tokens = num_tokens -class SentryLangchainCallback(BaseCallbackHandler): +class SentryLangchainCallback(BaseCallbackHandler): # type: ignore[misc] """Base callback handler that can be used to handle callbacks from langchain.""" span_map = OrderedDict() # type: OrderedDict[UUID, WatchedSpan] @@ -60,15 +60,18 @@ class SentryLangchainCallback(BaseCallbackHandler): max_span_map_size = 0 def __init__(self, max_span_map_size, include_prompts): + # type: (int, bool) -> None self.max_span_map_size = max_span_map_size self.include_prompts = include_prompts def gc_span_map(self): + # type: () -> None + while len(self.span_map) > self.max_span_map_size: self.span_map.popitem(last=False)[1].span.__exit__(None, None, None) def _handle_error(self, run_id, error): - # type: (str, Any) -> None + # type: (UUID, Any) -> None if not run_id or not self.span_map[run_id]: return @@ -80,7 +83,7 @@ def _handle_error(self, run_id, error): del self.span_map[run_id] def _normalize_langchain_message(self, message): - # type: (BaseMessage) -> dict + # type: (BaseMessage) -> Any parsed = {"content": message.content, "role": message.type} parsed.update(message.additional_kwargs) return parsed @@ -113,7 +116,7 @@ def on_llm_start( metadata=None, **kwargs, ): - # type: (Dict[str, Any], List[str], Any, UUID, Optional[List[str]], Optional[UUID], Optional[Dict[str, Any]], Any) -> Any + # type: (SentryLangchainCallback, Dict[str, Any], List[str], UUID, Optional[List[str]], Optional[UUID], Optional[Dict[str, Any]], Dict[str, Any]) -> Any """Run when LLM starts running.""" with capture_internal_exceptions(): if not run_id: @@ -128,7 +131,7 @@ def on_llm_start( set_data_normalized(span, SPANDATA.AI_INPUT_MESSAGES, prompts) def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs): - # type: (Dict[str, Any], List[List[BaseMessage]], Any, UUID, Any) -> Any + # type: (SentryLangchainCallback, Dict[str, Any], List[List[BaseMessage]], UUID, Dict[str, Any]) -> Any """Run when Chat Model starts running.""" if not run_id: return @@ -151,7 +154,7 @@ def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs): ) def on_llm_new_token(self, token, *, run_id, **kwargs): - # type: (str, Any, UUID, Any) -> Any + # type: (SentryLangchainCallback, str, UUID, Dict[str, Any]) -> Any """Run on new LLM token. Only available when streaming is enabled.""" with capture_internal_exceptions(): if not run_id or not self.span_map[run_id]: @@ -162,7 +165,7 @@ def on_llm_new_token(self, token, *, run_id, **kwargs): span_data.num_tokens += 1 def on_llm_end(self, response, *, run_id, **kwargs): - # type: (LLMResult, Any, UUID, Any) -> Any + # type: (SentryLangchainCallback, LLMResult, UUID, Dict[str, Any]) -> Any """Run when LLM ends running.""" with capture_internal_exceptions(): if not run_id: @@ -206,13 +209,13 @@ def on_llm_end(self, response, *, run_id, **kwargs): del self.span_map[run_id] def on_llm_error(self, error, *, run_id, **kwargs): - # type: (Union[Exception, KeyboardInterrupt], Any, UUID, Any) -> Any + # type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Dict[str, Any]) -> Any """Run when LLM errors.""" with capture_internal_exceptions(): self._handle_error(run_id, error) def on_chain_start(self, serialized, inputs, *, run_id, **kwargs): - # type: (Dict[str, Any], Dict[str, Any], Any, UUID, Any) -> Any + # type: (SentryLangchainCallback, Dict[str, Any], Dict[str, Any], UUID, Dict[str, Any]) -> Any """Run when chain starts running.""" with capture_internal_exceptions(): if not run_id: @@ -225,7 +228,7 @@ def on_chain_start(self, serialized, inputs, *, run_id, **kwargs): ) def on_chain_end(self, outputs, *, run_id, **kwargs): - # type: (Dict[str, Any], Any, UUID, Any) -> Any + # type: (SentryLangchainCallback, Dict[str, Any], UUID, Dict[str, Any]) -> Any """Run when chain ends running.""" with capture_internal_exceptions(): if not run_id or not self.span_map[run_id]: @@ -238,12 +241,12 @@ def on_chain_end(self, outputs, *, run_id, **kwargs): del self.span_map[run_id] def on_chain_error(self, error, *, run_id, **kwargs): - # type: (Union[Exception, KeyboardInterrupt], Any, UUID, Any) -> Any + # type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Dict[str, Any]) -> Any """Run when chain errors.""" self._handle_error(run_id, error) def on_tool_start(self, serialized, input_str, *, run_id, **kwargs): - # type: (Dict[str, Any], str, Any, UUID, Any) -> Any + # type: (SentryLangchainCallback, Dict[str, Any], str, UUID, Dict[str, Any]) -> Any """Run when tool starts running.""" with capture_internal_exceptions(): if not run_id: @@ -260,7 +263,7 @@ def on_tool_start(self, serialized, input_str, *, run_id, **kwargs): ) def on_tool_end(self, output, *, run_id, **kwargs): - # type: (str, Any, UUID, Any) -> Any + # type: (SentryLangchainCallback, str, UUID, Dict[str, Any]) -> Any """Run when tool ends running.""" with capture_internal_exceptions(): if not run_id or not self.span_map[run_id]: @@ -275,7 +278,7 @@ def on_tool_end(self, output, *, run_id, **kwargs): del self.span_map[run_id] def on_tool_error(self, error, *args, run_id, **kwargs): - # type: (Union[Exception, KeyboardInterrupt], Any, UUID, Any) -> Any + # type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Dict[str, Any]) -> Any """Run when tool errors.""" self._handle_error(run_id, error) @@ -290,7 +293,7 @@ def new_configure(*args, **kwargs): integration = sentry_sdk.get_client().get_integration(LangchainIntegration) with capture_internal_exceptions(): - new_callbacks = [] + new_callbacks = [] # type: List[BaseCallbackHandler] if "local_callbacks" in kwargs: existing_callbacks = kwargs["local_callbacks"] kwargs["local_callbacks"] = new_callbacks