Skip to content

Commit

Permalink
chore(internal): minor core client restructuring (#1199)
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-bot authored Feb 29, 2024
1 parent e41abf7 commit 4314cdc
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 7 deletions.
5 changes: 4 additions & 1 deletion src/openai/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
RAW_RESPONSE_HEADER,
OVERRIDE_CAST_TO_HEADER,
)
from ._streaming import Stream, AsyncStream
from ._streaming import Stream, SSEDecoder, AsyncStream, SSEBytesDecoder
from ._exceptions import (
APIStatusError,
APITimeoutError,
Expand Down Expand Up @@ -431,6 +431,9 @@ def _prepare_url(self, url: str) -> URL:

return merge_url

def _make_sse_decoder(self) -> SSEDecoder | SSEBytesDecoder:
return SSEDecoder()

def _build_request(
self,
options: FinalRequestOptions,
Expand Down
34 changes: 28 additions & 6 deletions src/openai/_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import inspect
from types import TracebackType
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast
from typing_extensions import Self, TypeGuard, override, get_origin
from typing_extensions import Self, Protocol, TypeGuard, override, get_origin, runtime_checkable

import httpx

Expand All @@ -24,6 +24,8 @@ class Stream(Generic[_T]):

response: httpx.Response

_decoder: SSEDecoder | SSEBytesDecoder

def __init__(
self,
*,
Expand All @@ -34,7 +36,7 @@ def __init__(
self.response = response
self._cast_to = cast_to
self._client = client
self._decoder = SSEDecoder()
self._decoder = client._make_sse_decoder()
self._iterator = self.__stream__()

def __next__(self) -> _T:
Expand All @@ -45,7 +47,10 @@ def __iter__(self) -> Iterator[_T]:
yield item

def _iter_events(self) -> Iterator[ServerSentEvent]:
yield from self._decoder.iter(self.response.iter_lines())
if isinstance(self._decoder, SSEBytesDecoder):
yield from self._decoder.iter_bytes(self.response.iter_bytes())
else:
yield from self._decoder.iter(self.response.iter_lines())

def __stream__(self) -> Iterator[_T]:
cast_to = cast(Any, self._cast_to)
Expand Down Expand Up @@ -97,6 +102,8 @@ class AsyncStream(Generic[_T]):

response: httpx.Response

_decoder: SSEDecoder | SSEBytesDecoder

def __init__(
self,
*,
Expand All @@ -107,7 +114,7 @@ def __init__(
self.response = response
self._cast_to = cast_to
self._client = client
self._decoder = SSEDecoder()
self._decoder = client._make_sse_decoder()
self._iterator = self.__stream__()

async def __anext__(self) -> _T:
Expand All @@ -118,8 +125,12 @@ async def __aiter__(self) -> AsyncIterator[_T]:
yield item

async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:
async for sse in self._decoder.aiter(self.response.aiter_lines()):
yield sse
if isinstance(self._decoder, SSEBytesDecoder):
async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()):
yield sse
else:
async for sse in self._decoder.aiter(self.response.aiter_lines()):
yield sse

async def __stream__(self) -> AsyncIterator[_T]:
cast_to = cast(Any, self._cast_to)
Expand Down Expand Up @@ -284,6 +295,17 @@ def decode(self, line: str) -> ServerSentEvent | None:
return None


@runtime_checkable
class SSEBytesDecoder(Protocol):
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
"""Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
...

def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
"""Given an async iterator that yields raw binary data, iterate over it & yield every event encountered"""
...


def is_stream_class_type(typ: type) -> TypeGuard[type[Stream[object]] | type[AsyncStream[object]]]:
"""TypeGuard for determining whether or not the given type is a subclass of `Stream` / `AsyncStream`"""
origin = get_origin(typ) or typ
Expand Down

0 comments on commit 4314cdc

Please sign in to comment.