Skip to content

Commit

Permalink
Add support for async calls in Anthropic and OpenAI integration (#3497)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Anton Pirker <[email protected]>
  • Loading branch information
vetyy and antonpirker authored Oct 17, 2024
1 parent 891afee commit 9ae5820
Show file tree
Hide file tree
Showing 5 changed files with 1,366 additions and 209 deletions.
270 changes: 190 additions & 80 deletions sentry_sdk/integrations/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,15 @@
)

try:
from anthropic.resources import Messages
from anthropic.resources import AsyncMessages, Messages

if TYPE_CHECKING:
from anthropic.types import MessageStreamEvent
except ImportError:
raise DidNotEnable("Anthropic not installed")


if TYPE_CHECKING:
from typing import Any, Iterator
from typing import Any, AsyncIterator, Iterator
from sentry_sdk.tracing import Span


Expand All @@ -46,6 +45,7 @@ def setup_once():
raise DidNotEnable("anthropic 0.16 or newer required.")

Messages.create = _wrap_message_create(Messages.create)
AsyncMessages.create = _wrap_message_create_async(AsyncMessages.create)


def _capture_exception(exc):
Expand Down Expand Up @@ -75,7 +75,9 @@ def _calculate_token_usage(result, span):

def _get_responses(content):
# type: (list[Any]) -> list[dict[str, Any]]
"""Get JSON of a Anthropic responses."""
"""
Get JSON of a Anthropic responses.
"""
responses = []
for item in content:
if hasattr(item, "text"):
Expand All @@ -88,94 +90,202 @@ def _get_responses(content):
return responses


def _collect_ai_data(event, input_tokens, output_tokens, content_blocks):
    # type: (MessageStreamEvent, int, int, list[str]) -> tuple[int, int, list[str]]
    """
    Fold one streaming event into the running token counts and text chunks.

    * ``message_start`` adds the input and output token counts found on
      ``event.message.usage``.
    * ``message_delta`` adds ``event.usage.output_tokens``.
    * ``content_block_delta`` appends ``event.delta.text`` to
      ``content_blocks`` (in place) when the delta has a ``text`` attribute.

    Every other event, including one without a ``type`` attribute, leaves
    the state as it was. The (possibly updated) triple
    ``(input_tokens, output_tokens, content_blocks)`` is returned.
    """
    with capture_internal_exceptions():
        # A missing ``type`` yields None, which matches none of the branches.
        event_type = getattr(event, "type", None)

        if event_type == "message_start":
            start_usage = event.message.usage
            input_tokens += start_usage.input_tokens
            output_tokens += start_usage.output_tokens
        elif event_type == "message_delta":
            output_tokens += event.usage.output_tokens
        elif event_type == "content_block_delta":
            delta = event.delta
            if hasattr(delta, "text"):
                content_blocks.append(delta.text)
        # "content_block_start" / "content_block_stop" carry nothing we record.

    return input_tokens, output_tokens, content_blocks


def _add_ai_data_to_span(
    span, integration, input_tokens, output_tokens, content_blocks
):
    # type: (Span, AnthropicIntegration, int, int, list[str]) -> None
    """
    Write the totals gathered from a streamed response onto ``span``.

    The joined ``content_blocks`` text is stored under
    ``SPANDATA.AI_RESPONSES`` only when ``should_send_default_pii()`` is true
    and the integration has ``include_prompts`` set. Token usage (input,
    output and their sum) is recorded through ``record_token_usage`` and the
    span is flagged as streaming in every case.
    """
    with capture_internal_exceptions():
        prompts_allowed = should_send_default_pii() and integration.include_prompts
        if prompts_allowed:
            response_entry = {"type": "text", "text": "".join(content_blocks)}
            span.set_data(SPANDATA.AI_RESPONSES, [response_entry])

        record_token_usage(
            span, input_tokens, output_tokens, input_tokens + output_tokens
        )
        span.set_data(SPANDATA.AI_STREAMING, True)


def _sentry_patched_create_common(f, *args, **kwargs):
    # type: (Any, *Any, **Any) -> Any
    """
    Generator holding the instrumentation shared by the sync and async
    ``create`` wrappers.

    Protocol expected from the driver:

    * When instrumentation does not apply (the ``integration`` keyword is
      ``None``, there is no ``messages`` keyword, or ``messages`` is not
      iterable) the generator finishes without yielding and its return value
      (``StopIteration.value``) is ``f(*args, **kwargs)``. For an async ``f``
      that value still has to be awaited by the driver.
    * Otherwise a span is opened and ``(f, args, kwargs)`` is yielded. The
      driver performs the call and sends the result back in; the generator
      annotates the span and returns that result via ``StopIteration.value``.

    Results with a ``content`` attribute are treated as complete responses.
    Results with an ``_iterator`` attribute are treated as streams: the
    iterator is replaced by a wrapper that collects usage while forwarding
    events, and the span is closed when that wrapper is exhausted. Events of
    type ``message_stop`` are consumed but not forwarded to the caller.

    The ``integration`` keyword is always removed from ``kwargs`` before ``f``
    is called.
    """
    integration = kwargs.pop("integration")
    if integration is None:
        return f(*args, **kwargs)

    if "messages" not in kwargs:
        return f(*args, **kwargs)

    try:
        iter(kwargs["messages"])
    except TypeError:
        return f(*args, **kwargs)

    span = sentry_sdk.start_span(
        op=OP.ANTHROPIC_MESSAGES_CREATE,
        description="Anthropic messages create",
        origin=AnthropicIntegration.origin,
    )
    span.__enter__()

    result = yield f, args, kwargs

    # add data to span and finish it
    messages = list(kwargs["messages"])
    model = kwargs.get("model")

    # Set once a stream wrapper has been installed; from then on the wrapper
    # is responsible for closing the span.
    span_closed_by_stream = False

    # NOTE(review): capture_internal_exceptions() is expected to suppress
    # errors raised inside it (see sentry_sdk.utils), which is why the span is
    # closed *after* this block rather than as the last statement of a branch
    # that may have been abandoned halfway through.
    with capture_internal_exceptions():
        span.set_data(SPANDATA.AI_MODEL_ID, model)
        span.set_data(SPANDATA.AI_STREAMING, False)

        if should_send_default_pii() and integration.include_prompts:
            span.set_data(SPANDATA.AI_INPUT_MESSAGES, messages)

        if hasattr(result, "content"):
            if should_send_default_pii() and integration.include_prompts:
                span.set_data(SPANDATA.AI_RESPONSES, _get_responses(result.content))
            _calculate_token_usage(result, span)

        # Streaming response
        elif hasattr(result, "_iterator"):
            old_iterator = result._iterator

            def new_iterator():
                # type: () -> Iterator[MessageStreamEvent]
                input_tokens = 0
                output_tokens = 0
                content_blocks = []  # type: list[str]

                for event in old_iterator:
                    input_tokens, output_tokens, content_blocks = _collect_ai_data(
                        event, input_tokens, output_tokens, content_blocks
                    )
                    # getattr: events without ``type`` are forwarded instead
                    # of raising AttributeError into the caller's loop.
                    if getattr(event, "type", None) != "message_stop":
                        yield event

                _add_ai_data_to_span(
                    span, integration, input_tokens, output_tokens, content_blocks
                )
                span.__exit__(None, None, None)

            async def new_iterator_async():
                # type: () -> AsyncIterator[MessageStreamEvent]
                input_tokens = 0
                output_tokens = 0
                content_blocks = []  # type: list[str]

                async for event in old_iterator:
                    input_tokens, output_tokens, content_blocks = _collect_ai_data(
                        event, input_tokens, output_tokens, content_blocks
                    )
                    if getattr(event, "type", None) != "message_stop":
                        yield event

                _add_ai_data_to_span(
                    span, integration, input_tokens, output_tokens, content_blocks
                )
                span.__exit__(None, None, None)

            # Detect async streams through the async-iteration protocol
            # instead of comparing the repr of the iterator's type, so async
            # iterators that are not plain async generators are handled too.
            if hasattr(old_iterator, "__aiter__"):
                result._iterator = new_iterator_async()
            else:
                result._iterator = new_iterator()
            span_closed_by_stream = True

        else:
            span.set_data("unknown_response", True)

    if not span_closed_by_stream:
        # Always reached for non-streaming results, even if annotating the
        # span failed above, so the span cannot be left open.
        with capture_internal_exceptions():
            span.__exit__(None, None, None)

    return result


# Builds the wrapper for a synchronous ``create`` callable ``f`` and returns it.
# NOTE(review): this region reads like a rendered diff whose +/- markers and
# indentation were lost. Lines of the replaced implementation (for example the
# ``def _sentry_patched_create`` header and the trailing ``start_span`` call)
# appear interleaved with the new one -- confirm against the repository before
# relying on statement order here.
def _wrap_message_create(f):
# type: (Any) -> Any
# Drives the generator from ``_sentry_patched_create_common``: the first
# ``next()`` either ends immediately (StopIteration carries the final value)
# or hands back the call to perform; the call's result is then sent back in
# and the generator's return value is what the caller receives.
def _execute_sync(f, *args, **kwargs):
# type: (Any, *Any, **Any) -> Any
gen = _sentry_patched_create_common(f, *args, **kwargs)

try:
f, args, kwargs = next(gen)
except StopIteration as e:
return e.value

try:
try:
result = f(*args, **kwargs)
except Exception as exc:
_capture_exception(exc)
# NOTE(review): the generator is not resumed on this path, so nothing visible
# here closes whatever the generator opened before yielding -- verify.
raise exc from None

return gen.send(result)
except StopIteration as e:
return e.value

@wraps(f)
def _sentry_patched_create(*args, **kwargs):
def _sentry_patched_create_sync(*args, **kwargs):
# type: (*Any, **Any) -> Any
# The integration (possibly None) travels to the generator as a keyword.
integration = sentry_sdk.get_client().get_integration(AnthropicIntegration)
kwargs["integration"] = integration

if integration is None or "messages" not in kwargs:
return f(*args, **kwargs)
return _execute_sync(f, *args, **kwargs)

try:
iter(kwargs["messages"])
except TypeError:
return f(*args, **kwargs)
return _sentry_patched_create_sync

messages = list(kwargs["messages"])
model = kwargs.get("model")

span = sentry_sdk.start_span(
op=OP.ANTHROPIC_MESSAGES_CREATE,
name="Anthropic messages create",
origin=AnthropicIntegration.origin,
)
span.__enter__()
# Builds the wrapper for an awaitable ``create`` callable ``f`` and returns it.
# NOTE(review): this region reads like a rendered diff whose +/- markers and
# indentation were lost. Most of the lines between the first ``try:`` and the
# final ``return`` statements appear to belong to the replaced synchronous
# implementation (inline span handling and ``new_iterator``) rather than to
# this function -- confirm against the repository before editing.
def _wrap_message_create_async(f):
# type: (Any) -> Any
# Async counterpart of the generator driver: same next()/send() handshake with
# ``_sentry_patched_create_common``, but the wrapped call is awaited.
async def _execute_async(f, *args, **kwargs):
# type: (Any, *Any, **Any) -> Any
gen = _sentry_patched_create_common(f, *args, **kwargs)

try:
result = f(*args, **kwargs)
except Exception as exc:
_capture_exception(exc)
span.__exit__(None, None, None)
raise exc from None
f, args, kwargs = next(gen)
except StopIteration as e:
# The generator finished without yielding; its value is awaited here.
return await e.value

with capture_internal_exceptions():
span.set_data(SPANDATA.AI_MODEL_ID, model)
span.set_data(SPANDATA.AI_STREAMING, False)
if should_send_default_pii() and integration.include_prompts:
span.set_data(SPANDATA.AI_INPUT_MESSAGES, messages)
if hasattr(result, "content"):
if should_send_default_pii() and integration.include_prompts:
span.set_data(SPANDATA.AI_RESPONSES, _get_responses(result.content))
_calculate_token_usage(result, span)
span.__exit__(None, None, None)
elif hasattr(result, "_iterator"):
old_iterator = result._iterator

def new_iterator():
# type: () -> Iterator[MessageStreamEvent]
input_tokens = 0
output_tokens = 0
content_blocks = []
with capture_internal_exceptions():
for event in old_iterator:
if hasattr(event, "type"):
if event.type == "message_start":
usage = event.message.usage
input_tokens += usage.input_tokens
output_tokens += usage.output_tokens
elif event.type == "content_block_start":
pass
elif event.type == "content_block_delta":
if hasattr(event.delta, "text"):
content_blocks.append(event.delta.text)
elif event.type == "content_block_stop":
pass
elif event.type == "message_delta":
output_tokens += event.usage.output_tokens
elif event.type == "message_stop":
continue
yield event

if should_send_default_pii() and integration.include_prompts:
complete_message = "".join(content_blocks)
span.set_data(
SPANDATA.AI_RESPONSES,
[{"type": "text", "text": complete_message}],
)
total_tokens = input_tokens + output_tokens
record_token_usage(
span, input_tokens, output_tokens, total_tokens
)
span.set_data(SPANDATA.AI_STREAMING, True)
span.__exit__(None, None, None)
try:
try:
# The awaited call; a failure is reported through ``_capture_exception`` and
# re-raised with exception chaining suppressed (``from None``).
result = await f(*args, **kwargs)
except Exception as exc:
_capture_exception(exc)
raise exc from None

result._iterator = new_iterator()
else:
span.set_data("unknown_response", True)
span.__exit__(None, None, None)
return gen.send(result)
except StopIteration as e:
return e.value

@wraps(f)
async def _sentry_patched_create_async(*args, **kwargs):
# type: (*Any, **Any) -> Any
# The integration (possibly None) travels to the generator as a keyword.
integration = sentry_sdk.get_client().get_integration(AnthropicIntegration)
kwargs["integration"] = integration

return result
return await _execute_async(f, *args, **kwargs)

return _sentry_patched_create
return _sentry_patched_create_async
Loading

0 comments on commit 9ae5820

Please sign in to comment.