Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(internal): minor utils restructuring #992

Merged
merged 1 commit into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions src/openai/_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
import datetime
import functools
from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast
from typing_extensions import Awaitable, ParamSpec, get_args, override, get_origin
from typing_extensions import Awaitable, ParamSpec, override, get_origin

import httpx

from ._types import NoneType, UnknownResponse, BinaryResponseContent
from ._utils import is_given
from ._utils import is_given, extract_type_var_from_base
from ._models import BaseModel, is_basemodel
from ._constants import RAW_RESPONSE_HEADER
from ._exceptions import APIResponseValidationError
Expand Down Expand Up @@ -221,12 +221,13 @@ def __init__(self) -> None:


def _extract_stream_chunk_type(stream_cls: type) -> type:
args = get_args(stream_cls)
if not args:
raise TypeError(
f"Expected stream_cls to have been given a generic type argument, e.g. Stream[Foo] but received {stream_cls}",
)
return cast(type, args[0])
from ._base_client import Stream, AsyncStream

return extract_type_var_from_base(
stream_cls,
index=0,
generic_bases=cast("tuple[type, ...]", (Stream, AsyncStream)),
)


def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, APIResponse[R]]:
Expand Down
71 changes: 56 additions & 15 deletions src/openai/_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,31 @@
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any, Generic, Iterator, AsyncIterator
from typing_extensions import override
from types import TracebackType
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast
from typing_extensions import Self, override

import httpx

from ._types import ResponseT
from ._utils import is_mapping
from ._exceptions import APIError

if TYPE_CHECKING:
from ._client import OpenAI, AsyncOpenAI


class Stream(Generic[ResponseT]):
_T = TypeVar("_T")


class Stream(Generic[_T]):
"""Provides the core interface to iterate over a synchronous stream response."""

response: httpx.Response

def __init__(
self,
*,
cast_to: type[ResponseT],
cast_to: type[_T],
response: httpx.Response,
client: OpenAI,
) -> None:
Expand All @@ -33,18 +36,18 @@ def __init__(
self._decoder = SSEDecoder()
self._iterator = self.__stream__()

def __next__(self) -> ResponseT:
def __next__(self) -> _T:
return self._iterator.__next__()

def __iter__(self) -> Iterator[ResponseT]:
def __iter__(self) -> Iterator[_T]:
for item in self._iterator:
yield item

def _iter_events(self) -> Iterator[ServerSentEvent]:
yield from self._decoder.iter(self.response.iter_lines())

def __stream__(self) -> Iterator[ResponseT]:
cast_to = self._cast_to
def __stream__(self) -> Iterator[_T]:
cast_to = cast(Any, self._cast_to)
response = self.response
process_data = self._client._process_response_data
iterator = self._iter_events()
Expand All @@ -68,16 +71,35 @@ def __stream__(self) -> Iterator[ResponseT]:
for _sse in iterator:
...

def __enter__(self) -> Self:
return self

def __exit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self.close()

def close(self) -> None:
"""
Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
self.response.close()

class AsyncStream(Generic[ResponseT]):

class AsyncStream(Generic[_T]):
"""Provides the core interface to iterate over an asynchronous stream response."""

response: httpx.Response

def __init__(
self,
*,
cast_to: type[ResponseT],
cast_to: type[_T],
response: httpx.Response,
client: AsyncOpenAI,
) -> None:
Expand All @@ -87,19 +109,19 @@ def __init__(
self._decoder = SSEDecoder()
self._iterator = self.__stream__()

async def __anext__(self) -> ResponseT:
async def __anext__(self) -> _T:
return await self._iterator.__anext__()

async def __aiter__(self) -> AsyncIterator[ResponseT]:
async def __aiter__(self) -> AsyncIterator[_T]:
async for item in self._iterator:
yield item

async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:
async for sse in self._decoder.aiter(self.response.aiter_lines()):
yield sse

async def __stream__(self) -> AsyncIterator[ResponseT]:
cast_to = self._cast_to
async def __stream__(self) -> AsyncIterator[_T]:
cast_to = cast(Any, self._cast_to)
response = self.response
process_data = self._client._process_response_data
iterator = self._iter_events()
Expand All @@ -123,6 +145,25 @@ async def __stream__(self) -> AsyncIterator[ResponseT]:
async for _sse in iterator:
...

async def __aenter__(self) -> Self:
return self

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.close()

async def close(self) -> None:
"""
Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
await self.response.aclose()


class ServerSentEvent:
def __init__(
Expand Down
14 changes: 14 additions & 0 deletions src/openai/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,3 +353,17 @@ def get(self, __key: str) -> str | None:
IncEx: TypeAlias = "set[int] | set[str] | dict[int, Any] | dict[str, Any] | None"

PostParser = Callable[[Any], Any]


@runtime_checkable
class InheritsGeneric(Protocol):
"""Represents a type that has inherited from `Generic`
The `__orig_bases__` property can be used to determine the resolved
type variable for a given base class.
"""

__orig_bases__: tuple[_GenericAlias]


class _GenericAlias(Protocol):
__origin__: type[object]
15 changes: 9 additions & 6 deletions src/openai/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,32 @@
from ._utils import parse_date as parse_date
from ._utils import is_sequence as is_sequence
from ._utils import coerce_float as coerce_float
from ._utils import is_list_type as is_list_type
from ._utils import is_mapping_t as is_mapping_t
from ._utils import removeprefix as removeprefix
from ._utils import removesuffix as removesuffix
from ._utils import extract_files as extract_files
from ._utils import is_sequence_t as is_sequence_t
from ._utils import is_union_type as is_union_type
from ._utils import required_args as required_args
from ._utils import coerce_boolean as coerce_boolean
from ._utils import coerce_integer as coerce_integer
from ._utils import file_from_path as file_from_path
from ._utils import parse_datetime as parse_datetime
from ._utils import strip_not_given as strip_not_given
from ._utils import deepcopy_minimal as deepcopy_minimal
from ._utils import extract_type_arg as extract_type_arg
from ._utils import is_required_type as is_required_type
from ._utils import get_async_library as get_async_library
from ._utils import is_annotated_type as is_annotated_type
from ._utils import maybe_coerce_float as maybe_coerce_float
from ._utils import get_required_header as get_required_header
from ._utils import maybe_coerce_boolean as maybe_coerce_boolean
from ._utils import maybe_coerce_integer as maybe_coerce_integer
from ._utils import strip_annotated_type as strip_annotated_type
from ._typing import is_list_type as is_list_type
from ._typing import is_union_type as is_union_type
from ._typing import extract_type_arg as extract_type_arg
from ._typing import is_required_type as is_required_type
from ._typing import is_annotated_type as is_annotated_type
from ._typing import strip_annotated_type as strip_annotated_type
from ._typing import extract_type_var_from_base as extract_type_var_from_base
from ._streams import consume_sync_iterator as consume_sync_iterator
from ._streams import consume_async_iterator as consume_async_iterator
from ._transform import PropertyInfo as PropertyInfo
from ._transform import transform as transform
from ._transform import maybe_transform as maybe_transform
12 changes: 12 additions & 0 deletions src/openai/_utils/_streams.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import Any
from typing_extensions import Iterator, AsyncIterator


def consume_sync_iterator(iterator: Iterator[Any]) -> None:
for _ in iterator:
...


async def consume_async_iterator(iterator: AsyncIterator[Any]) -> None:
async for _ in iterator:
...
5 changes: 2 additions & 3 deletions src/openai/_utils/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@

import pydantic

from ._utils import (
is_list,
is_mapping,
from ._utils import is_list, is_mapping
from ._typing import (
is_list_type,
is_union_type,
extract_type_arg,
Expand Down
80 changes: 80 additions & 0 deletions src/openai/_utils/_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from __future__ import annotations

from typing import Any, cast
from typing_extensions import Required, Annotated, get_args, get_origin

from .._types import InheritsGeneric
from .._compat import is_union as _is_union


def is_annotated_type(typ: type) -> bool:
return get_origin(typ) == Annotated


def is_list_type(typ: type) -> bool:
return (get_origin(typ) or typ) == list


def is_union_type(typ: type) -> bool:
return _is_union(get_origin(typ))


def is_required_type(typ: type) -> bool:
return get_origin(typ) == Required


# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
def strip_annotated_type(typ: type) -> type:
if is_required_type(typ) or is_annotated_type(typ):
return strip_annotated_type(cast(type, get_args(typ)[0]))

return typ


def extract_type_arg(typ: type, index: int) -> type:
args = get_args(typ)
try:
return cast(type, args[index])
except IndexError as err:
raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err


def extract_type_var_from_base(typ: type, *, generic_bases: tuple[type, ...], index: int) -> type:
"""Given a type like `Foo[T]`, returns the generic type variable `T`.
This also handles the case where a concrete subclass is given, e.g.
```py
class MyResponse(Foo[bytes]):
...
extract_type_var(MyResponse, bases=(Foo,), index=0) -> bytes
```
"""
cls = cast(object, get_origin(typ) or typ)
if cls in generic_bases:
# we're given the class directly
return extract_type_arg(typ, index)

# if a subclass is given
# ---
# this is needed as __orig_bases__ is not present in the typeshed stubs
# because it is intended to be for internal use only, however there does
# not seem to be a way to resolve generic TypeVars for inherited subclasses
# without using it.
if isinstance(cls, InheritsGeneric):
target_base_class: Any | None = None
for base in cls.__orig_bases__:
if base.__origin__ in generic_bases:
target_base_class = base
break

if target_base_class is None:
raise RuntimeError(
"Could not find the generic base class;\n"
"This should never happen;\n"
f"Does {cls} inherit from one of {generic_bases} ?"
)

return extract_type_arg(target_base_class, index)

raise RuntimeError(f"Could not resolve inner type variable at index {index} for {typ}")
35 changes: 1 addition & 34 deletions src/openai/_utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@
overload,
)
from pathlib import Path
from typing_extensions import Required, Annotated, TypeGuard, get_args, get_origin
from typing_extensions import TypeGuard

import sniffio

from .._types import Headers, NotGiven, FileTypes, NotGivenOr, HeadersLike
from .._compat import is_union as _is_union
from .._compat import parse_date as parse_date
from .._compat import parse_datetime as parse_datetime

Expand Down Expand Up @@ -166,38 +165,6 @@ def is_list(obj: object) -> TypeGuard[list[object]]:
return isinstance(obj, list)


def is_annotated_type(typ: type) -> bool:
return get_origin(typ) == Annotated


def is_list_type(typ: type) -> bool:
return (get_origin(typ) or typ) == list


def is_union_type(typ: type) -> bool:
return _is_union(get_origin(typ))


def is_required_type(typ: type) -> bool:
return get_origin(typ) == Required


# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
def strip_annotated_type(typ: type) -> type:
if is_required_type(typ) or is_annotated_type(typ):
return strip_annotated_type(cast(type, get_args(typ)[0]))

return typ


def extract_type_arg(typ: type, index: int) -> type:
args = get_args(typ)
try:
return cast(type, args[index])
except IndexError as err:
raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err


def deepcopy_minimal(item: _T) -> _T:
"""Minimal reimplementation of copy.deepcopy() that will only copy certain object types:
Expand Down