From d9acfd427e3d7d8c6bc3d6ed8994194a07ed6a92 Mon Sep 17 00:00:00 2001 From: Robert Craigie Date: Thu, 30 May 2024 18:09:26 +0100 Subject: [PATCH] refactor(streaming)!: remove old event_handler API (#532) this wasn't actually seeing any usage and has been replaced by a more ergonomic iterator API. --- README.md | 2 +- src/anthropic/lib/streaming/__init__.py | 2 - src/anthropic/lib/streaming/_messages.py | 353 +++++------------------ src/anthropic/resources/messages.py | 163 +---------- tests/lib/streaming/test_messages.py | 40 +-- 5 files changed, 95 insertions(+), 465 deletions(-) diff --git a/README.md b/README.md index 703351d4..f27c10e8 100644 --- a/README.md +++ b/README.md @@ -159,7 +159,7 @@ async def main() -> None: asyncio.run(main()) ``` -Streaming with `client.messages.stream(...)` exposes [various helpers for your convenience](helpers.md) including event handlers and accumulation. +Streaming with `client.messages.stream(...)` exposes [various helpers for your convenience](helpers.md) including accumulation & SDK-specific events. Alternatively, you can use `client.messages.create(..., stream=True)` which only returns an async iterable of the events in the stream and thus uses less memory (it does not build up a final message object for you). diff --git a/src/anthropic/lib/streaming/__init__.py b/src/anthropic/lib/streaming/__init__.py index 83074e4a..0ab41209 100644 --- a/src/anthropic/lib/streaming/__init__.py +++ b/src/anthropic/lib/streaming/__init__.py @@ -7,9 +7,7 @@ ) from ._messages import ( MessageStream as MessageStream, - MessageStreamT as MessageStreamT, AsyncMessageStream as AsyncMessageStream, - AsyncMessageStreamT as AsyncMessageStreamT, MessageStreamManager as MessageStreamManager, AsyncMessageStreamManager as AsyncMessageStreamManager, ) diff --git a/src/anthropic/lib/streaming/_messages.py b/src/anthropic/lib/streaming/_messages.py index 78e52c33..f1073fea 100644 --- a/src/anthropic/lib/streaming/_messages.py +++ b/src/anthropic/lib/streaming/_messages.py @@ -1,8 +1,7 @@ from __future__ import annotations -import asyncio from types import TracebackType -from typing import TYPE_CHECKING, Any, Generic, TypeVar, Callable, cast +from typing import TYPE_CHECKING, Any, Callable, cast from typing_extensions import Self, Iterator, Awaitable, AsyncIterator, assert_never import httpx @@ -78,7 +77,6 @@ def close(self) -> None: Automatically called if the response body is read to completion. """ self.response.close() - self.on_end() def get_final_message(self) -> Message: """Waits until the stream has been read to completion and returns @@ -117,137 +115,24 @@ def current_message_snapshot(self) -> Message: assert self.__final_message_snapshot is not None return self.__final_message_snapshot - # event handlers - def on_stream_event(self, event: RawMessageStreamEvent) -> None: - """Callback that is fired for every Server-Sent-Event""" - - def on_message(self, message: Message) -> None: - """Callback that is fired when a full Message object is accumulated. - - This corresponds to the `message_stop` SSE type. - """ - - def on_content_block(self, content_block: ContentBlock) -> None: - """Callback that is fired whenever a full ContentBlock is accumulated. - - This corresponds to the `content_block_stop` SSE type. - """ - - def on_text(self, text: str, snapshot: str) -> None: - """Callback that is fired whenever a `text` ContentBlock is yielded. - - The first argument is the text delta and the second is the current accumulated - text, for example: - - ```py - on_text("Hello", "Hello") - on_text(" there", "Hello there") - on_text("!", "Hello there!") - ``` - """ - - def on_input_json(self, delta: str, snapshot: object) -> None: - """Callback that is fired whenever a `input_json_delta` ContentBlock is yielded. - - The first argument is the json string delta and the second is the current accumulated - parsed object, for example: - - ``` - on_input_json('{"locations": ["San ', {"locations": []}) - on_input_json('Francisco"]', {"locations": ["San Francisco"]}) - ``` - """ - - def on_exception(self, exception: Exception) -> None: - """Fires if any exception occurs""" - - def on_end(self) -> None: - ... - - def on_timeout(self) -> None: - """Fires if the request times out""" - def __stream__(self) -> Iterator[MessageStreamEvent]: - try: - for sse_event in self._raw_stream: - self.__final_message_snapshot = accumulate_event( - event=sse_event, - current_snapshot=self.__final_message_snapshot, - ) + for sse_event in self._raw_stream: + self.__final_message_snapshot = accumulate_event( + event=sse_event, + current_snapshot=self.__final_message_snapshot, + ) - events_to_fire = self._emit_sse_event(sse_event) - for event in events_to_fire: - yield event - except (httpx.TimeoutException, asyncio.TimeoutError) as exc: - self.on_timeout() - self.on_exception(exc) - raise - except Exception as exc: - self.on_exception(exc) - raise - finally: - self.on_end() + events_to_fire = build_events(event=sse_event, message_snapshot=self.current_message_snapshot) + for event in events_to_fire: + yield event def __stream_text__(self) -> Iterator[str]: for chunk in self: if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta": yield chunk.delta.text - def _emit_sse_event(self, event: RawMessageStreamEvent) -> list[MessageStreamEvent]: - self.on_stream_event(event) - events_to_fire: list[MessageStreamEvent] = [] - - if event.type == "message_start": - events_to_fire.append(event) - elif event.type == "message_delta": - events_to_fire.append(event) - elif event.type == "message_stop": - self.on_message(self.current_message_snapshot) - events_to_fire.append(MessageStopEvent(type="message_stop", message=self.current_message_snapshot)) - elif event.type == "content_block_start": - events_to_fire.append(event) - elif event.type == "content_block_delta": - events_to_fire.append(event) - - content_block = self.current_message_snapshot.content[event.index] - if event.delta.type == "text_delta" and content_block.type == "text": - self.on_text(event.delta.text, content_block.text) - events_to_fire.append( - TextEvent( - type="text", - text=event.delta.text, - snapshot=content_block.text, - ) - ) - elif event.delta.type == "input_json_delta" and content_block.type == "tool_use": - self.on_input_json(event.delta.partial_json, content_block.input) - events_to_fire.append( - InputJsonEvent( - type="input_json", - partial_json=event.delta.partial_json, - snapshot=content_block.input, - ) - ) - elif event.type == "content_block_stop": - content_block = self.current_message_snapshot.content[event.index] - self.on_content_block(content_block) - - events_to_fire.append( - ContentBlockStopEvent(type="content_block_stop", index=event.index, content_block=content_block), - ) - else: - # we only want exhaustive checking for linters, not at runtime - if TYPE_CHECKING: # type: ignore[unreachable] - assert_never(event) - - return events_to_fire - - -MessageStreamT = TypeVar("MessageStreamT", bound=MessageStream) - - -class MessageStreamManager(Generic[MessageStreamT]): +class MessageStreamManager: """Wrapper over MessageStream that is returned by `.stream()`. ```py @@ -260,22 +145,20 @@ class MessageStreamManager(Generic[MessageStreamT]): def __init__( self, api_request: Callable[[], Stream[RawMessageStreamEvent]], - event_handler_cls: type[MessageStreamT], ) -> None: - self.__event_handler: MessageStreamT | None = None - self.__event_handler_cls: type[MessageStreamT] = event_handler_cls + self.__stream: MessageStream | None = None self.__api_request = api_request - def __enter__(self) -> MessageStreamT: + def __enter__(self) -> MessageStream: raw_stream = self.__api_request() - self.__event_handler = self.__event_handler_cls( + self.__stream = MessageStream( cast_to=raw_stream._cast_to, response=raw_stream.response, client=raw_stream._client, ) - return self.__event_handler + return self.__stream def __exit__( self, @@ -283,8 +166,8 @@ def __exit__( exc: BaseException | None, exc_tb: TracebackType | None, ) -> None: - if self.__event_handler is not None: - self.__event_handler.close() + if self.__stream is not None: + self.__stream.close() class AsyncMessageStream: @@ -344,7 +227,6 @@ async def close(self) -> None: Automatically called if the response body is read to completion. """ await self.response.aclose() - await self.on_end() async def get_final_message(self) -> Message: """Waits until the stream has been read to completion and returns @@ -383,146 +265,24 @@ def current_message_snapshot(self) -> Message: assert self.__final_message_snapshot is not None return self.__final_message_snapshot - # event handlers - async def on_stream_event(self, event: RawMessageStreamEvent) -> None: - """Callback that is fired for every Server-Sent-Event""" - - async def on_message(self, message: Message) -> None: - """Callback that is fired when a full Message object is accumulated. - - This corresponds to the `message_stop` SSE type. - """ - - async def on_content_block(self, content_block: ContentBlock) -> None: - """Callback that is fired whenever a full ContentBlock is accumulated. - - This corresponds to the `content_block_stop` SSE type. - """ - - async def on_text(self, text: str, snapshot: str) -> None: - """Callback that is fired whenever a `text` ContentBlock is yielded. - - The first argument is the text delta and the second is the current accumulated - text, for example: - - ``` - on_text("Hello", "Hello") - on_text(" there", "Hello there") - on_text("!", "Hello there!") - ``` - """ - - async def on_input_json(self, delta: str, snapshot: object) -> None: - """Callback that is fired whenever a `input_json_delta` ContentBlock is yielded. - - The first argument is the json string delta and the second is the current accumulated - parsed object, for example: - - ``` - on_input_json('{"locations": ["San ', {"locations": []}) - on_input_json('Francisco"]', {"locations": ["San Francisco"]}) - ``` - """ - - async def on_final_text(self, text: str) -> None: - """Callback that is fired whenever a full `text` ContentBlock is accumulated. - - This corresponds to the `content_block_stop` SSE type. - """ - - async def on_exception(self, exception: Exception) -> None: - """Fires if any exception occurs""" - - async def on_end(self) -> None: - ... - - async def on_timeout(self) -> None: - """Fires if the request times out""" - async def __stream__(self) -> AsyncIterator[MessageStreamEvent]: - try: - async for sse_event in self._raw_stream: - self.__final_message_snapshot = accumulate_event( - event=sse_event, - current_snapshot=self.__final_message_snapshot, - ) + async for sse_event in self._raw_stream: + self.__final_message_snapshot = accumulate_event( + event=sse_event, + current_snapshot=self.__final_message_snapshot, + ) - events_to_fire = await self._emit_sse_event(sse_event) - for event in events_to_fire: - yield event - except (httpx.TimeoutException, asyncio.TimeoutError) as exc: - await self.on_timeout() - await self.on_exception(exc) - raise - except Exception as exc: - await self.on_exception(exc) - raise - finally: - await self.on_end() + events_to_fire = build_events(event=sse_event, message_snapshot=self.current_message_snapshot) + for event in events_to_fire: + yield event async def __stream_text__(self) -> AsyncIterator[str]: async for chunk in self: if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta": yield chunk.delta.text - async def _emit_sse_event(self, event: RawMessageStreamEvent) -> list[MessageStreamEvent]: - await self.on_stream_event(event) - - events_to_fire: list[MessageStreamEvent] = [] - - if event.type == "message_start": - events_to_fire.append(event) - elif event.type == "message_delta": - events_to_fire.append(event) - elif event.type == "message_stop": - await self.on_message(self.current_message_snapshot) - events_to_fire.append(MessageStopEvent(type="message_stop", message=self.current_message_snapshot)) - elif event.type == "content_block_start": - events_to_fire.append(event) - elif event.type == "content_block_delta": - events_to_fire.append(event) - - content_block = self.current_message_snapshot.content[event.index] - if event.delta.type == "text_delta" and content_block.type == "text": - await self.on_text(event.delta.text, content_block.text) - events_to_fire.append( - TextEvent( - type="text", - text=event.delta.text, - snapshot=content_block.text, - ) - ) - elif event.delta.type == "input_json_delta" and content_block.type == "tool_use": - await self.on_input_json(event.delta.partial_json, content_block.input) - events_to_fire.append( - InputJsonEvent( - type="input_json", - partial_json=event.delta.partial_json, - snapshot=content_block.input, - ) - ) - elif event.type == "content_block_stop": - content_block = self.current_message_snapshot.content[event.index] - await self.on_content_block(content_block) - - if content_block.type == "text": - await self.on_final_text(content_block.text) - - events_to_fire.append( - ContentBlockStopEvent(type="content_block_stop", index=event.index, content_block=content_block), - ) - else: - # we only want exhaustive checking for linters, not at runtime - if TYPE_CHECKING: # type: ignore[unreachable] - assert_never(event) - - return events_to_fire - - -AsyncMessageStreamT = TypeVar("AsyncMessageStreamT", bound=AsyncMessageStream) - -class AsyncMessageStreamManager(Generic[AsyncMessageStreamT]): +class AsyncMessageStreamManager: """Wrapper over AsyncMessageStream that is returned by `.stream()` so that an async context manager can be used without `await`ing the original client call. @@ -537,22 +297,20 @@ class AsyncMessageStreamManager(Generic[AsyncMessageStreamT]): def __init__( self, api_request: Awaitable[AsyncStream[RawMessageStreamEvent]], - event_handler_cls: type[AsyncMessageStreamT], ) -> None: - self.__event_handler: AsyncMessageStreamT | None = None - self.__event_handler_cls: type[AsyncMessageStreamT] = event_handler_cls + self.__stream: AsyncMessageStream | None = None self.__api_request = api_request - async def __aenter__(self) -> AsyncMessageStreamT: + async def __aenter__(self) -> AsyncMessageStream: raw_stream = await self.__api_request - self.__event_handler = self.__event_handler_cls( + self.__stream = AsyncMessageStream( cast_to=raw_stream._cast_to, response=raw_stream.response, client=raw_stream._client, ) - return self.__event_handler + return self.__stream async def __aexit__( self, @@ -560,8 +318,57 @@ async def __aexit__( exc: BaseException | None, exc_tb: TracebackType | None, ) -> None: - if self.__event_handler is not None: - await self.__event_handler.close() + if self.__stream is not None: + await self.__stream.close() + + +def build_events( + *, + event: RawMessageStreamEvent, + message_snapshot: Message, +) -> list[MessageStreamEvent]: + events_to_fire: list[MessageStreamEvent] = [] + + if event.type == "message_start": + events_to_fire.append(event) + elif event.type == "message_delta": + events_to_fire.append(event) + elif event.type == "message_stop": + events_to_fire.append(MessageStopEvent(type="message_stop", message=message_snapshot)) + elif event.type == "content_block_start": + events_to_fire.append(event) + elif event.type == "content_block_delta": + events_to_fire.append(event) + + content_block = message_snapshot.content[event.index] + if event.delta.type == "text_delta" and content_block.type == "text": + events_to_fire.append( + TextEvent( + type="text", + text=event.delta.text, + snapshot=content_block.text, + ) + ) + elif event.delta.type == "input_json_delta" and content_block.type == "tool_use": + events_to_fire.append( + InputJsonEvent( + type="input_json", + partial_json=event.delta.partial_json, + snapshot=content_block.input, + ) + ) + elif event.type == "content_block_stop": + content_block = message_snapshot.content[event.index] + + events_to_fire.append( + ContentBlockStopEvent(type="content_block_stop", index=event.index, content_block=content_block), + ) + else: + # we only want exhaustive checking for linters, not at runtime + if TYPE_CHECKING: # type: ignore[unreachable] + assert_never(event) + + return events_to_fire JSON_BUF_PROPERTY = "__json_buf" diff --git a/src/anthropic/resources/messages.py b/src/anthropic/resources/messages.py index 0e3de5ad..9579d98e 100644 --- a/src/anthropic/resources/messages.py +++ b/src/anthropic/resources/messages.py @@ -23,14 +23,7 @@ from .._base_client import ( make_request_options, ) -from ..lib.streaming import ( - MessageStream, - MessageStreamT, - AsyncMessageStream, - AsyncMessageStreamT, - MessageStreamManager, - AsyncMessageStreamManager, -) +from ..lib.streaming import MessageStreamManager, AsyncMessageStreamManager from ..types.message import Message from ..types.tool_param import ToolParam from ..types.message_param import MessageParam @@ -930,42 +923,6 @@ def create( stream_cls=Stream[RawMessageStreamEvent], ) - @overload - def stream( - self, - *, - max_tokens: int, - messages: Iterable[MessageParam], - model: Union[ - str, - Literal[ - "claude-3-opus-20240229", - "claude-3-sonnet-20240229", - "claude-3-haiku-20240307", - "claude-2.1", - "claude-2.0", - "claude-instant-1.2", - ], - ], - metadata: message_create_params.Metadata | NotGiven = NOT_GIVEN, - stop_sequences: List[str] | NotGiven = NOT_GIVEN, - system: str | NotGiven = NOT_GIVEN, - temperature: float | NotGiven = NOT_GIVEN, - top_k: int | NotGiven = NOT_GIVEN, - top_p: float | NotGiven = NOT_GIVEN, - tool_choice: message_create_params.ToolChoice | NotGiven = NOT_GIVEN, - tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> MessageStreamManager[MessageStream]: - """Create a Message stream""" - ... - - @overload def stream( self, *, @@ -990,53 +947,16 @@ def stream( top_p: float | NotGiven = NOT_GIVEN, tool_choice: message_create_params.ToolChoice | NotGiven = NOT_GIVEN, tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN, - event_handler: type[MessageStreamT], - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> MessageStreamManager[MessageStreamT]: - """Create a Message stream""" - ... - - def stream( # pyright: ignore[reportInconsistentOverload] - self, - *, - max_tokens: int, - messages: Iterable[MessageParam], - model: Union[ - str, - Literal[ - "claude-3-opus-20240229", - "claude-3-sonnet-20240229", - "claude-3-haiku-20240307", - "claude-2.1", - "claude-2.0", - "claude-instant-1.2", - ], - ], - metadata: message_create_params.Metadata | NotGiven = NOT_GIVEN, - stop_sequences: List[str] | NotGiven = NOT_GIVEN, - system: str | NotGiven = NOT_GIVEN, - temperature: float | NotGiven = NOT_GIVEN, - top_k: int | NotGiven = NOT_GIVEN, - top_p: float | NotGiven = NOT_GIVEN, - tool_choice: message_create_params.ToolChoice | NotGiven = NOT_GIVEN, - tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN, - event_handler: type[MessageStreamT] = MessageStream, # type: ignore[assignment] # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> MessageStreamManager[MessageStream] | MessageStreamManager[MessageStreamT]: + ) -> MessageStreamManager: """Create a Message stream""" extra_headers = { "X-Stainless-Stream-Helper": "messages", - "X-Stainless-Custom-Event-Handler": "true" if event_handler != MessageStream else "false", **(extra_headers or {}), } make_request = partial( @@ -1066,7 +986,7 @@ def stream( # pyright: ignore[reportInconsistentOverload] stream=True, stream_cls=Stream[RawMessageStreamEvent], ) - return MessageStreamManager(make_request, event_handler) + return MessageStreamManager(make_request) class AsyncMessages(AsyncAPIResource): @@ -1960,7 +1880,6 @@ async def create( stream_cls=AsyncStream[RawMessageStreamEvent], ) - @overload def stream( self, *, @@ -1991,82 +1910,10 @@ def stream( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> AsyncMessageStreamManager[AsyncMessageStream]: - """Create a Message stream""" - ... - - @overload - def stream( - self, - *, - max_tokens: int, - messages: Iterable[MessageParam], - model: Union[ - str, - Literal[ - "claude-3-opus-20240229", - "claude-3-sonnet-20240229", - "claude-3-haiku-20240307", - "claude-2.1", - "claude-2.0", - "claude-instant-1.2", - ], - ], - metadata: message_create_params.Metadata | NotGiven = NOT_GIVEN, - stop_sequences: List[str] | NotGiven = NOT_GIVEN, - system: str | NotGiven = NOT_GIVEN, - temperature: float | NotGiven = NOT_GIVEN, - top_k: int | NotGiven = NOT_GIVEN, - top_p: float | NotGiven = NOT_GIVEN, - tool_choice: message_create_params.ToolChoice | NotGiven = NOT_GIVEN, - tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN, - event_handler: type[AsyncMessageStreamT], - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> AsyncMessageStreamManager[AsyncMessageStreamT]: - """Create a Message stream""" - ... - - def stream( # pyright: ignore[reportInconsistentOverload] - self, - *, - max_tokens: int, - messages: Iterable[MessageParam], - model: Union[ - str, - Literal[ - "claude-3-opus-20240229", - "claude-3-sonnet-20240229", - "claude-3-haiku-20240307", - "claude-2.1", - "claude-2.0", - "claude-instant-1.2", - ], - ], - metadata: message_create_params.Metadata | NotGiven = NOT_GIVEN, - stop_sequences: List[str] | NotGiven = NOT_GIVEN, - system: str | NotGiven = NOT_GIVEN, - temperature: float | NotGiven = NOT_GIVEN, - top_k: int | NotGiven = NOT_GIVEN, - top_p: float | NotGiven = NOT_GIVEN, - tool_choice: message_create_params.ToolChoice | NotGiven = NOT_GIVEN, - tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN, - event_handler: type[AsyncMessageStreamT] = AsyncMessageStream, # type: ignore[assignment] - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> AsyncMessageStreamManager[AsyncMessageStream] | AsyncMessageStreamManager[AsyncMessageStreamT]: + ) -> AsyncMessageStreamManager: """Create a Message stream""" extra_headers = { "X-Stainless-Stream-Helper": "messages", - "X-Stainless-Custom-Event-Handler": "true" if event_handler != AsyncMessageStream else "false", **(extra_headers or {}), } request = self._post( @@ -2095,7 +1942,7 @@ def stream( # pyright: ignore[reportInconsistentOverload] stream=True, stream_cls=AsyncStream[RawMessageStreamEvent], ) - return AsyncMessageStreamManager(request, event_handler) + return AsyncMessageStreamManager(request) class MessagesWithRawResponse: diff --git a/tests/lib/streaming/test_messages.py b/tests/lib/streaming/test_messages.py index c09baade..408546de 100644 --- a/tests/lib/streaming/test_messages.py +++ b/tests/lib/streaming/test_messages.py @@ -3,16 +3,15 @@ import os import inspect from typing import Any, TypeVar, cast -from typing_extensions import Iterator, AsyncIterator, override +from typing_extensions import Iterator, AsyncIterator import httpx import pytest from respx import MockRouter from anthropic import Stream, Anthropic, AsyncStream, AsyncAnthropic -from anthropic.lib.streaming import MessageStream, AsyncMessageStream +from anthropic.lib.streaming import MessageStreamEvent from anthropic.types.message import Message -from anthropic.types.message_stream_event import MessageStreamEvent base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") api_key = "my-anthropic-api-key" @@ -63,29 +62,7 @@ async def to_async_iter(iter: Iterator[_T]) -> AsyncIterator[_T]: yield event -class SyncEventTracker(MessageStream): - def __init__(self, *, cast_to: type[MessageStreamEvent], response: httpx.Response, client: Anthropic) -> None: - super().__init__(cast_to=cast_to, response=response, client=client) - - self._events: list[MessageStreamEvent] = [] - - @override - def on_stream_event(self, event: MessageStreamEvent) -> None: - self._events.append(event) - - -class AsyncEventTracker(AsyncMessageStream): - def __init__(self, *, cast_to: type[MessageStreamEvent], response: httpx.Response, client: AsyncAnthropic) -> None: - super().__init__(cast_to=cast_to, response=response, client=client) - - self._events: list[MessageStreamEvent] = [] - - @override - async def on_stream_event(self, event: MessageStreamEvent) -> None: - self._events.append(event) - - -def assert_basic_response(stream: SyncEventTracker | AsyncEventTracker, message: Message) -> None: +def assert_basic_response(events: list[MessageStreamEvent], message: Message) -> None: assert message.id == "msg_4QpJur2dWWDjF6C758FbBw5vm12BaVipnK" assert message.model == "claude-3-opus-20240229" assert message.role == "assistant" @@ -98,12 +75,15 @@ def assert_basic_response(stream: SyncEventTracker | AsyncEventTracker, message: assert content.type == "text" assert content.text == "Hello there!" - assert [e.type for e in stream._events] == [ + assert [e.type for e in events] == [ "message_start", "content_block_start", "content_block_delta", + "text", "content_block_delta", + "text", "content_block_delta", + "text", "content_block_stop", "message_delta", ] @@ -123,12 +103,11 @@ def test_basic_response(self, respx_mock: MockRouter) -> None: } ], model="claude-3-opus-20240229", - event_handler=SyncEventTracker, ) as stream: with pytest.warns(DeprecationWarning): assert isinstance(cast(Any, stream), Stream) - assert_basic_response(stream, stream.get_final_message()) + assert_basic_response([event for event in stream], stream.get_final_message()) @pytest.mark.respx(base_url=base_url) def test_context_manager(self, respx_mock: MockRouter) -> None: @@ -165,12 +144,11 @@ async def test_basic_response(self, respx_mock: MockRouter) -> None: } ], model="claude-3-opus-20240229", - event_handler=AsyncEventTracker, ) as stream: with pytest.warns(DeprecationWarning): assert isinstance(cast(Any, stream), AsyncStream) - assert_basic_response(stream, await stream.get_final_message()) + assert_basic_response([event async for event in stream], await stream.get_final_message()) @pytest.mark.asyncio @pytest.mark.respx(base_url=base_url)