Skip to content

Commit

Permalink
refactor(stream): make MessageStream wrap Stream directly (#805)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexmojaki authored Dec 29, 2024
1 parent 9cf1e99 commit 5669399
Showing 1 changed file with 16 additions and 51 deletions.
67 changes: 16 additions & 51 deletions src/anthropic/lib/streaming/_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
from ..._models import build, construct_type
from ..._streaming import Stream, AsyncStream

if TYPE_CHECKING:
from ..._client import Anthropic, AsyncAnthropic


class MessageStream:
text_stream: Iterator[str]
Expand All @@ -33,24 +30,15 @@ class MessageStream:
```
"""

response: httpx.Response

def __init__(
self,
*,
cast_to: type[RawMessageStreamEvent],
response: httpx.Response,
client: Anthropic,
) -> None:
self.response = response
self._cast_to = cast_to
self._client = client

def __init__(self, raw_stream: Stream[RawMessageStreamEvent]) -> None:
self._raw_stream = raw_stream
self.text_stream = self.__stream_text__()
self._iterator = self.__stream__()
self.__final_message_snapshot: Message | None = None

self._iterator = self.__stream__()
self._raw_stream: Stream[RawMessageStreamEvent] = Stream(cast_to=cast_to, response=response, client=client)
@property
def response(self) -> httpx.Response:
return self._raw_stream.response

def __next__(self) -> MessageStreamEvent:
return self._iterator.__next__()
Expand All @@ -76,7 +64,7 @@ def close(self) -> None:
Automatically called if the response body is read to completion.
"""
self.response.close()
self._raw_stream.close()

def get_final_message(self) -> Message:
"""Waits until the stream has been read to completion and returns
Expand Down Expand Up @@ -151,13 +139,7 @@ def __init__(

def __enter__(self) -> MessageStream:
raw_stream = self.__api_request()

self.__stream = MessageStream(
cast_to=raw_stream._cast_to,
response=raw_stream.response,
client=raw_stream._client,
)

self.__stream = MessageStream(raw_stream)
return self.__stream

def __exit__(
Expand All @@ -181,26 +163,15 @@ class AsyncMessageStream:
```
"""

response: httpx.Response

def __init__(
self,
*,
cast_to: type[RawMessageStreamEvent],
response: httpx.Response,
client: AsyncAnthropic,
) -> None:
self.response = response
self._cast_to = cast_to
self._client = client

def __init__(self, raw_stream: AsyncStream[RawMessageStreamEvent]) -> None:
self._raw_stream = raw_stream
self.text_stream = self.__stream_text__()
self._iterator = self.__stream__()
self.__final_message_snapshot: Message | None = None

self._iterator = self.__stream__()
self._raw_stream: AsyncStream[RawMessageStreamEvent] = AsyncStream(
cast_to=cast_to, response=response, client=client
)
@property
def response(self) -> httpx.Response:
return self._raw_stream.response

async def __anext__(self) -> MessageStreamEvent:
return await self._iterator.__anext__()
Expand All @@ -226,7 +197,7 @@ async def close(self) -> None:
Automatically called if the response body is read to completion.
"""
await self.response.aclose()
await self._raw_stream.close()

async def get_final_message(self) -> Message:
"""Waits until the stream has been read to completion and returns
Expand Down Expand Up @@ -303,13 +274,7 @@ def __init__(

async def __aenter__(self) -> AsyncMessageStream:
raw_stream = await self.__api_request

self.__stream = AsyncMessageStream(
cast_to=raw_stream._cast_to,
response=raw_stream.response,
client=raw_stream._client,
)

self.__stream = AsyncMessageStream(raw_stream)
return self.__stream

async def __aexit__(
Expand Down

0 comments on commit 5669399

Please sign in to comment.