Skip to content

Commit

Permalink
Fix some type issues
Browse files Browse the repository at this point in the history
  • Loading branch information
colin-sentry committed Mar 27, 2024
1 parent 24301ad commit 988cab6
Showing 1 changed file with 18 additions and 15 deletions.
33 changes: 18 additions & 15 deletions sentry_sdk/integrations/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,23 +52,26 @@ def __init__(self, span, num_tokens=0):
self.num_tokens = num_tokens

Check warning on line 52 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L51-L52

Added lines #L51 - L52 were not covered by tests


class SentryLangchainCallback(BaseCallbackHandler):
class SentryLangchainCallback(BaseCallbackHandler): # type: ignore[misc]
"""Base callback handler that can be used to handle callbacks from langchain."""

span_map = OrderedDict() # type: OrderedDict[UUID, WatchedSpan]

Check warning on line 58 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L58

Added line #L58 was not covered by tests

max_span_map_size = 0

Check warning on line 60 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L60

Added line #L60 was not covered by tests

def __init__(self, max_span_map_size, include_prompts):

Check warning on line 62 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L62

Added line #L62 was not covered by tests
# type: (int, bool) -> None
self.max_span_map_size = max_span_map_size
self.include_prompts = include_prompts

Check warning on line 65 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L64-L65

Added lines #L64 - L65 were not covered by tests

def gc_span_map(self):

Check warning on line 67 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L67

Added line #L67 was not covered by tests
# type: () -> None

while len(self.span_map) > self.max_span_map_size:
self.span_map.popitem(last=False)[1].span.__exit__(None, None, None)

Check warning on line 71 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L71

Added line #L71 was not covered by tests

def _handle_error(self, run_id, error):

Check warning on line 73 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L73

Added line #L73 was not covered by tests
# type: (str, Any) -> None
# type: (UUID, Any) -> None
if not run_id or not self.span_map[run_id]:
return

Check warning on line 76 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L76

Added line #L76 was not covered by tests

Expand All @@ -80,7 +83,7 @@ def _handle_error(self, run_id, error):
del self.span_map[run_id]

Check warning on line 83 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L80-L83

Added lines #L80 - L83 were not covered by tests

def _normalize_langchain_message(self, message):

Check warning on line 85 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L85

Added line #L85 was not covered by tests
# type: (BaseMessage) -> dict
# type: (BaseMessage) -> Any
parsed = {"content": message.content, "role": message.type}
parsed.update(message.additional_kwargs)
return parsed

Check warning on line 89 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L87-L89

Added lines #L87 - L89 were not covered by tests
Expand Down Expand Up @@ -113,7 +116,7 @@ def on_llm_start(
metadata=None,
**kwargs,
):
# type: (Dict[str, Any], List[str], Any, UUID, Optional[List[str]], Optional[UUID], Optional[Dict[str, Any]], Any) -> Any
# type: (SentryLangchainCallback, Dict[str, Any], List[str], UUID, Optional[List[str]], Optional[UUID], Optional[Dict[str, Any]], Dict[str, Any]) -> Any
"""Run when LLM starts running."""
with capture_internal_exceptions():
if not run_id:
Expand All @@ -128,7 +131,7 @@ def on_llm_start(
set_data_normalized(span, SPANDATA.AI_INPUT_MESSAGES, prompts)

Check warning on line 131 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L131

Added line #L131 was not covered by tests

def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):

Check warning on line 133 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L133

Added line #L133 was not covered by tests
# type: (Dict[str, Any], List[List[BaseMessage]], Any, UUID, Any) -> Any
# type: (SentryLangchainCallback, Dict[str, Any], List[List[BaseMessage]], UUID, Dict[str, Any]) -> Any
"""Run when Chat Model starts running."""
if not run_id:
return

Check warning on line 137 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L137

Added line #L137 was not covered by tests
Expand All @@ -151,7 +154,7 @@ def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
)

def on_llm_new_token(self, token, *, run_id, **kwargs):

Check warning on line 156 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L156

Added line #L156 was not covered by tests
# type: (str, Any, UUID, Any) -> Any
# type: (SentryLangchainCallback, str, UUID, Dict[str, Any]) -> Any
"""Run on new LLM token. Only available when streaming is enabled."""
with capture_internal_exceptions():
if not run_id or not self.span_map[run_id]:
Expand All @@ -162,7 +165,7 @@ def on_llm_new_token(self, token, *, run_id, **kwargs):
span_data.num_tokens += 1

Check warning on line 165 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L164-L165

Added lines #L164 - L165 were not covered by tests

def on_llm_end(self, response, *, run_id, **kwargs):

Check warning on line 167 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L167

Added line #L167 was not covered by tests
# type: (LLMResult, Any, UUID, Any) -> Any
# type: (SentryLangchainCallback, LLMResult, UUID, Dict[str, Any]) -> Any
"""Run when LLM ends running."""
with capture_internal_exceptions():
if not run_id:
Expand Down Expand Up @@ -206,13 +209,13 @@ def on_llm_end(self, response, *, run_id, **kwargs):
del self.span_map[run_id]

Check warning on line 209 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L208-L209

Added lines #L208 - L209 were not covered by tests

def on_llm_error(self, error, *, run_id, **kwargs):

Check warning on line 211 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L211

Added line #L211 was not covered by tests
# type: (Union[Exception, KeyboardInterrupt], Any, UUID, Any) -> Any
# type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Dict[str, Any]) -> Any
"""Run when LLM errors."""
with capture_internal_exceptions():
self._handle_error(run_id, error)

Check warning on line 215 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L215

Added line #L215 was not covered by tests

def on_chain_start(self, serialized, inputs, *, run_id, **kwargs):

Check warning on line 217 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L217

Added line #L217 was not covered by tests
# type: (Dict[str, Any], Dict[str, Any], Any, UUID, Any) -> Any
# type: (SentryLangchainCallback, Dict[str, Any], Dict[str, Any], UUID, Dict[str, Any]) -> Any
"""Run when chain starts running."""
with capture_internal_exceptions():
if not run_id:
Expand All @@ -225,7 +228,7 @@ def on_chain_start(self, serialized, inputs, *, run_id, **kwargs):
)

def on_chain_end(self, outputs, *, run_id, **kwargs):

Check warning on line 230 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L230

Added line #L230 was not covered by tests
# type: (Dict[str, Any], Any, UUID, Any) -> Any
# type: (SentryLangchainCallback, Dict[str, Any], UUID, Dict[str, Any]) -> Any
"""Run when chain ends running."""
with capture_internal_exceptions():
if not run_id or not self.span_map[run_id]:
Expand All @@ -238,12 +241,12 @@ def on_chain_end(self, outputs, *, run_id, **kwargs):
del self.span_map[run_id]

Check warning on line 241 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L239-L241

Added lines #L239 - L241 were not covered by tests

def on_chain_error(self, error, *, run_id, **kwargs):

Check warning on line 243 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L243

Added line #L243 was not covered by tests
# type: (Union[Exception, KeyboardInterrupt], Any, UUID, Any) -> Any
# type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Dict[str, Any]) -> Any
"""Run when chain errors."""
self._handle_error(run_id, error)

Check warning on line 246 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L246

Added line #L246 was not covered by tests

def on_tool_start(self, serialized, input_str, *, run_id, **kwargs):

Check warning on line 248 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L248

Added line #L248 was not covered by tests
# type: (Dict[str, Any], str, Any, UUID, Any) -> Any
# type: (SentryLangchainCallback, Dict[str, Any], str, UUID, Dict[str, Any]) -> Any
"""Run when tool starts running."""
with capture_internal_exceptions():
if not run_id:
Expand All @@ -260,7 +263,7 @@ def on_tool_start(self, serialized, input_str, *, run_id, **kwargs):
)

def on_tool_end(self, output, *, run_id, **kwargs):

Check warning on line 265 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L265

Added line #L265 was not covered by tests
# type: (str, Any, UUID, Any) -> Any
# type: (SentryLangchainCallback, str, UUID, Dict[str, Any]) -> Any
"""Run when tool ends running."""
with capture_internal_exceptions():
if not run_id or not self.span_map[run_id]:
Expand All @@ -275,7 +278,7 @@ def on_tool_end(self, output, *, run_id, **kwargs):
del self.span_map[run_id]

Check warning on line 278 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L276-L278

Added lines #L276 - L278 were not covered by tests

def on_tool_error(self, error, *args, run_id, **kwargs):

Check warning on line 280 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L280

Added line #L280 was not covered by tests
# type: (Union[Exception, KeyboardInterrupt], Any, UUID, Any) -> Any
# type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Dict[str, Any]) -> Any
"""Run when tool errors."""
self._handle_error(run_id, error)

Check warning on line 283 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L283

Added line #L283 was not covered by tests

Expand All @@ -290,7 +293,7 @@ def new_configure(*args, **kwargs):
integration = sentry_sdk.get_client().get_integration(LangchainIntegration)

Check warning on line 293 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L293

Added line #L293 was not covered by tests

with capture_internal_exceptions():
new_callbacks = []
new_callbacks = [] # type: List[BaseCallbackHandler]

Check warning on line 296 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L296

Added line #L296 was not covered by tests
if "local_callbacks" in kwargs:
existing_callbacks = kwargs["local_callbacks"]
kwargs["local_callbacks"] = new_callbacks

Check warning on line 299 in sentry_sdk/integrations/langchain.py

View check run for this annotation

Codecov / codecov/patch

sentry_sdk/integrations/langchain.py#L298-L299

Added lines #L298 - L299 were not covered by tests
Expand Down

0 comments on commit 988cab6

Please sign in to comment.