Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(weave): Add callback system and reducers op #3028

Draft
wants to merge 17 commits into
base: master
Choose a base branch
from
Draft
4 changes: 4 additions & 0 deletions tests/integrations/anthropic/anthropic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,10 @@ async def test_async_anthropic_stream(

assert call.exception is None and call.ended_at is not None
output = call.output

print(f"{output=}")
print(f"{message=}")

assert output.id == message.id
assert output.model == message.model
assert output.stop_reason == "end_turn"
Expand Down
240 changes: 99 additions & 141 deletions weave/integrations/anthropic/anthropic_sdk.py
Original file line number Diff line number Diff line change
@@ -1,89 +1,97 @@
from __future__ import annotations

import importlib
from collections.abc import AsyncIterator, Iterator
from functools import wraps
from typing import (
TYPE_CHECKING,
Any,
Callable,
Optional,
Union,
)
from typing import Any, Callable

from anthropic import MessageStopEvent

import weave
from weave.trace.op_extensions.accumulator import _IteratorWrapper, add_accumulator
from weave.trace.op import Callback
from weave.trace.patcher import MultiPatcher, SymbolPatcher
from weave.trace.weave_client import Call


def should_accumulate(call: Call) -> bool:
    """Return True when the call was made with a truthy ``stream`` input.

    Args:
        call: The traced call whose ``inputs`` are inspected.

    Returns:
        ``True`` if ``call.inputs`` exposes a ``get`` method and
        ``get("stream")`` is truthy; ``False`` otherwise, including when
        ``inputs`` is ``None`` or is not mapping-like.
    """
    # Look the accessor up defensively: a call whose inputs are missing or
    # not mapping-like should simply not be accumulated, rather than raise
    # AttributeError from inside the tracing machinery.
    lookup = getattr(call.inputs, "get", None)
    if not callable(lookup):
        return False
    return bool(lookup("stream"))


class AnthropicCallback(Callback):
def __init__(self):
self.acc = None

def after_yield(self, call: Call, value: Any) -> None:
from anthropic.types import (
ContentBlockDeltaEvent,
Message,
MessageDeltaEvent,
TextBlock,
Usage,
)

print(f"{value=}, {self.acc=}")

if TYPE_CHECKING:
from anthropic.lib.streaming import MessageStream
from anthropic.types import Message, MessageStreamEvent


def anthropic_accumulator(
acc: Optional["Message"],
value: "MessageStreamEvent",
) -> "Message":
from anthropic.types import (
ContentBlockDeltaEvent,
Message,
MessageDeltaEvent,
TextBlock,
Usage,
)

if acc is None:
if hasattr(value, "message"):
acc = Message(
if self.acc is None:
if not hasattr(value, "message"):
raise ValueError("Initial event must contain a message")
self.acc = Message(
id=value.message.id,
role=value.message.role,
content=[],
model=value.message.model,
stop_reason=value.message.stop_reason,
stop_sequence=value.message.stop_sequence,
type=value.message.type, # Include the type field
type=value.message.type,
usage=Usage(input_tokens=0, output_tokens=0),
)
else:
raise ValueError("Initial event must contain a message")

# Merge in the usage info if available
if hasattr(value, "message") and value.message.usage is not None:
acc.usage.input_tokens += value.message.usage.input_tokens
# Merge in the usage info if available
if hasattr(value, "message") and value.message.usage is not None:
self.acc.usage.input_tokens += value.message.usage.input_tokens

# Accumulate the content if it's a ContentBlockDeltaEvent
if isinstance(value, ContentBlockDeltaEvent) and hasattr(value.delta, "text"):
if acc.content and isinstance(acc.content[-1], TextBlock):
acc.content[-1].text += value.delta.text
else:
acc.content.append(TextBlock(type="text", text=value.delta.text))
# Accumulate the content if it's a ContentBlockDeltaEvent
if isinstance(value, ContentBlockDeltaEvent) and hasattr(value.delta, "text"):
if self.acc.content and isinstance(self.acc.content[-1], TextBlock):
self.acc.content[-1].text += value.delta.text
else:
self.acc.content.append(TextBlock(type="text", text=value.delta.text))

# Handle MessageDeltaEvent for stop_reason and stop_sequence
if isinstance(value, MessageDeltaEvent):
if hasattr(value.delta, "stop_reason") and value.delta.stop_reason:
acc.stop_reason = value.delta.stop_reason
if hasattr(value.delta, "stop_sequence") and value.delta.stop_sequence:
acc.stop_sequence = value.delta.stop_sequence
if hasattr(value, "usage") and value.usage.output_tokens:
acc.usage.output_tokens = value.usage.output_tokens
# Handle MessageDeltaEvent for stop_reason and stop_sequence
if isinstance(value, MessageDeltaEvent):
if hasattr(value.delta, "stop_reason") and value.delta.stop_reason:
self.acc.stop_reason = value.delta.stop_reason
if hasattr(value.delta, "stop_sequence") and value.delta.stop_sequence:
self.acc.stop_sequence = value.delta.stop_sequence
if hasattr(value, "usage") and value.usage.output_tokens:
self.acc.usage.output_tokens = value.usage.output_tokens

return acc
def after_yield_all(self, call: Call) -> None:
call.output = self.acc


# Unlike other integrations, streaming here is selected by the `stream` input flag
def should_use_accumulator(inputs: dict) -> bool:
    """Return True only when ``inputs`` is a dict with a truthy ``stream`` entry."""
    if not isinstance(inputs, dict):
        return False
    return bool(inputs.get("stream"))
class AnthropicStreamingCallback:
def __init__(self):
self.acc = None

def after_yield(self, call: Call, value: Any) -> None:
print(f"{value=}, {self.acc=}")

def create_wrapper_sync(
name: str,
) -> Callable[[Callable], Callable]:
if self.acc is None:
self.acc = ""
if isinstance(value, MessageStopEvent):
self.acc = value.message

def after_yield_all(self, call: Call) -> None:
call.output = self.acc


def create_wrapper_sync(name: str) -> Callable[[Callable], Callable]:
def wrapper(fn: Callable) -> Callable:
"We need to do this so we can check if `stream` is used"
op = weave.op()(fn)
op.name = name # type: ignore
return add_accumulator(
op, # type: ignore
make_accumulator=lambda inputs: anthropic_accumulator,
should_accumulate=should_use_accumulator,
return weave.op(
fn,
name=name,
callbacks=[AnthropicCallback()],
__should_accumulate=should_accumulate,
)

return wrapper
Expand All @@ -92,9 +100,7 @@ def wrapper(fn: Callable) -> Callable:
# Surprisingly, the async `client.messages.create` does not pass
# `inspect.iscoroutinefunction`, so we can't dispatch on it and must write
# the async wrapper manually here...
def create_wrapper_async(
name: str,
) -> Callable[[Callable], Callable]:
def create_wrapper_async(name: str) -> Callable[[Callable], Callable]:
def wrapper(fn: Callable) -> Callable:
def _fn_wrapper(fn: Callable) -> Callable:
@wraps(fn)
Expand All @@ -104,92 +110,44 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any:
return _async_wrapper

"We need to do this so we can check if `stream` is used"
op = weave.op()(_fn_wrapper(fn))
op.name = name # type: ignore
return add_accumulator(
op, # type: ignore
make_accumulator=lambda inputs: anthropic_accumulator,
should_accumulate=should_use_accumulator,
return weave.op(
_fn_wrapper(fn),
name=name,
callbacks=[AnthropicCallback()],
__should_accumulate=should_accumulate,
)

return wrapper


def create_wrapper_stream(name: str) -> Callable[[Callable], Callable]:
    """Build the decorator used to wrap a ``stream`` method as a weave op.

    Args:
        name: Op name forwarded to ``weave.op``.

    Returns:
        A decorator that takes the function to wrap and returns whatever
        ``weave.op`` produces for it.
    """

    def wrapper(fn: Callable) -> Callable:
        # NOTE(review): one AnthropicStreamingCallback is created each time
        # `wrapper` runs (i.e. per wrapped function), not per invocation of
        # the resulting op. Confirm weave does not share its `acc` state
        # between calls.
        callbacks = [AnthropicStreamingCallback()]
        # Both double-underscore options are passed as always-true predicates.
        return weave.op(
            fn,
            name=name,
            callbacks=callbacks,
            __should_accumulate=lambda call: True,
            __should_use_contextmanager=lambda f: True,
        )

    return wrapper


## This part of the code deals with the other way of streaming:
## calling Messages.stream.
## That API can be consumed in two ways: event-based or text-based.
## This code handles both cases by subclassing _IteratorWrapper
## and adding a text_stream property to it.


def anthropic_stream_accumulator(
    acc: Optional["Message"],
    value: "MessageStream",
) -> "Message":
    """Fold one streamed item into the accumulated result.

    Args:
        acc: The value accumulated so far, or ``None`` on the first item.
        value: The item just yielded by the stream.

    Returns:
        ``value.message`` when ``value`` is a ``MessageStopEvent``; otherwise
        ``acc`` unchanged, with ``""`` substituted when ``acc`` is ``None``.
    """
    # Imported lazily, as in the rest of this integration.
    from anthropic.lib.streaming._types import MessageStopEvent

    # The stop event carries the final message, so it replaces whatever
    # was accumulated before.
    if isinstance(value, MessageStopEvent):
        return value.message
    return "" if acc is None else acc


class AnthropicIteratorWrapper(_IteratorWrapper):
    """Iterator wrapper that adds a ``text_stream`` view over streamed events.

    Attribute lookups that fail on the wrapper are forwarded to the wrapped
    iterator / context manager, except for the wrapper's own bookkeeping
    names listed in ``__getattr__``.
    """

    def __getattr__(self, name: str) -> Any:
        """Delegate all other attributes to the wrapped iterator."""
        # Names that belong to the wrapper itself are resolved on the wrapper
        # and never forwarded to the wrapped object.
        own_names = (
            "_iterator_or_ctx_manager",
            "_on_yield",
            "_on_error",
            "_on_close",
            "_on_finished_called",
            "_call_on_error_once",
            "text_stream",
        )
        if name not in own_names:
            return getattr(self._iterator_or_ctx_manager, name)
        return object.__getattribute__(self, name)

    def __stream_text__(self) -> Union[Iterator[str], AsyncIterator[str]]:
        """Pick the sync or async text generator based on the wrapped object."""
        wrapped = self._iterator_or_ctx_manager
        if not isinstance(wrapped, AsyncIterator):
            return self.__sync_stream_text__()
        return self.__async_stream_text__()

    def __sync_stream_text__(self) -> Iterator[str]:  # type: ignore
        """Yield ``delta.text`` for every text delta seen while iterating ``self``."""
        for event in self:  # type: ignore
            if event.type != "content_block_delta":  # type: ignore
                continue
            if event.delta.type == "text_delta":  # type: ignore
                yield event.delta.text  # type: ignore

    async def __async_stream_text__(self) -> AsyncIterator[str]:  # type: ignore
        """Async counterpart of ``__sync_stream_text__``."""
        async for event in self:  # type: ignore
            if event.type != "content_block_delta":  # type: ignore
                continue
            if event.delta.type == "text_delta":  # type: ignore
                yield event.delta.text  # type: ignore

    @property
    def text_stream(self) -> Union[Iterator[str], AsyncIterator[str]]:
        """A fresh text-only iterator (sync or async) over this stream."""
        return self.__stream_text__()


def create_stream_wrapper(
name: str,
) -> Callable[[Callable], Callable]:
def create_wrapper_async_stream(name: str) -> Callable[[Callable], Callable]:
def wrapper(fn: Callable) -> Callable:
op = weave.op()(fn)
op.name = name # type: ignore
return add_accumulator(
op, # type: ignore
make_accumulator=lambda _: anthropic_stream_accumulator,
should_accumulate=lambda _: True,
iterator_wrapper=AnthropicIteratorWrapper, # type: ignore
return weave.op(
fn,
name=name,
callbacks=[AnthropicStreamingCallback()],
__should_accumulate=lambda call: True,
__should_use_contextmanager=lambda f: True,
)

return wrapper


anthropic_patcher = MultiPatcher(
[
# Patch the sync messages.create method for all messages.create methods
SymbolPatcher(
lambda: importlib.import_module("anthropic.resources.messages"),
"Messages.create",
Expand All @@ -203,12 +161,12 @@ def wrapper(fn: Callable) -> Callable:
SymbolPatcher(
lambda: importlib.import_module("anthropic.resources.messages"),
"Messages.stream",
create_stream_wrapper(name="anthropic.Messages.stream"),
create_wrapper_stream(name="anthropic.Messages.stream"),
),
SymbolPatcher(
lambda: importlib.import_module("anthropic.resources.messages"),
"AsyncMessages.stream",
create_stream_wrapper(name="anthropic.AsyncMessages.stream"),
create_wrapper_async_stream(name="anthropic.AsyncMessages.stream"),
),
]
)
Loading
Loading