Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(client): support parsing custom response types #1111

Merged
merged 1 commit into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/openai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ._types import NoneType, Transport, ProxiesTypes
from ._utils import file_from_path
from ._client import Client, OpenAI, Stream, Timeout, Transport, AsyncClient, AsyncOpenAI, AsyncStream, RequestOptions
from ._models import BaseModel
from ._version import __title__, __version__
from ._response import APIResponse as APIResponse, AsyncAPIResponse as AsyncAPIResponse
from ._exceptions import (
Expand Down Expand Up @@ -59,6 +60,7 @@
"OpenAI",
"AsyncOpenAI",
"file_from_path",
"BaseModel",
]

from .lib import azure as _azure
Expand Down
102 changes: 70 additions & 32 deletions src/openai/_legacy_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,28 @@
import logging
import datetime
import functools
from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, Iterator, AsyncIterator, cast
from typing_extensions import Awaitable, ParamSpec, get_args, override, deprecated, get_origin
from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, Iterator, AsyncIterator, cast, overload
from typing_extensions import Awaitable, ParamSpec, override, deprecated, get_origin

import anyio
import httpx
import pydantic

from ._types import NoneType
from ._utils import is_given
from ._models import BaseModel, is_basemodel
from ._constants import RAW_RESPONSE_HEADER
from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type
from ._exceptions import APIResponseValidationError

if TYPE_CHECKING:
from ._models import FinalRequestOptions
from ._base_client import Stream, BaseClient, AsyncStream
from ._base_client import BaseClient


P = ParamSpec("P")
R = TypeVar("R")
_T = TypeVar("_T")

log: logging.Logger = logging.getLogger(__name__)

Expand All @@ -43,7 +46,7 @@ class LegacyAPIResponse(Generic[R]):

_cast_to: type[R]
_client: BaseClient[Any, Any]
_parsed: R | None
_parsed_by_type: dict[type[Any], Any]
_stream: bool
_stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None
_options: FinalRequestOptions
Expand All @@ -62,27 +65,60 @@ def __init__(
) -> None:
self._cast_to = cast_to
self._client = client
self._parsed = None
self._parsed_by_type = {}
self._stream = stream
self._stream_cls = stream_cls
self._options = options
self.http_response = raw

@overload
def parse(self, *, to: type[_T]) -> _T:
...

@overload
def parse(self) -> R:
...

def parse(self, *, to: type[_T] | None = None) -> R | _T:
"""Returns the rich python representation of this response's data.

NOTE: For the async client: this will become a coroutine in the next major version.

For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.

NOTE: For the async client: this will become a coroutine in the next major version.
You can customise the type that the response is parsed into through
the `to` argument, e.g.

```py
from openai import BaseModel


class MyModel(BaseModel):
foo: str


obj = response.parse(to=MyModel)
print(obj.foo)
```

We support parsing:
- `BaseModel`
- `dict`
- `list`
- `Union`
- `str`
- `httpx.Response`
"""
if self._parsed is not None:
return self._parsed
cache_key = to if to is not None else self._cast_to
cached = self._parsed_by_type.get(cache_key)
if cached is not None:
return cached # type: ignore[no-any-return]

parsed = self._parse()
parsed = self._parse(to=to)
if is_given(self._options.post_parser):
parsed = self._options.post_parser(parsed)

self._parsed = parsed
self._parsed_by_type[cache_key] = parsed
return parsed

@property
Expand Down Expand Up @@ -135,13 +171,29 @@ def elapsed(self) -> datetime.timedelta:
"""The time taken for the complete request/response cycle to complete."""
return self.http_response.elapsed

def _parse(self) -> R:
def _parse(self, *, to: type[_T] | None = None) -> R | _T:
if self._stream:
if to:
if not is_stream_class_type(to):
raise TypeError(f"Expected custom parse type to be a subclass of {Stream} or {AsyncStream}")

return cast(
_T,
to(
cast_to=extract_stream_chunk_type(
to,
failure_message="Expected custom stream type to be passed with a type argument, e.g. Stream[ChunkType]",
),
response=self.http_response,
client=cast(Any, self._client),
),
)

if self._stream_cls:
return cast(
R,
self._stream_cls(
cast_to=_extract_stream_chunk_type(self._stream_cls),
cast_to=extract_stream_chunk_type(self._stream_cls),
response=self.http_response,
client=cast(Any, self._client),
),
Expand All @@ -160,7 +212,7 @@ def _parse(self) -> R:
),
)

cast_to = self._cast_to
cast_to = to if to is not None else self._cast_to
if cast_to is NoneType:
return cast(R, None)

Expand All @@ -186,14 +238,9 @@ def _parse(self) -> R:
raise ValueError(f"Subclasses of httpx.Response cannot be passed to `cast_to`")
return cast(R, response)

# The check here is necessary as we are subverting the type system
# with casts as the relationship between TypeVars and Types is very strict
# which means we must return *exactly* what was input or transform it in a
# way that retains the TypeVar state. As we cannot do that in this function
# then we have to resort to using `cast`. At the time of writing, we know this
# to be safe as we have handled all the types that could be bound to the
# `ResponseT` TypeVar, however if that TypeVar is ever updated in the future, then
# this function would become unsafe but a type checker would not report an error.
if inspect.isclass(origin) and not issubclass(origin, BaseModel) and issubclass(origin, pydantic.BaseModel):
raise TypeError("Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`")

if (
cast_to is not object
and not origin is list
Expand All @@ -202,12 +249,12 @@ def _parse(self) -> R:
and not issubclass(origin, BaseModel)
):
raise RuntimeError(
f"Invalid state, expected {cast_to} to be a subclass type of {BaseModel}, {dict}, {list} or {Union}."
f"Unsupported type, expected {cast_to} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}."
)

# split is required to handle cases where additional information is included
# in the response, e.g. application/json; charset=utf-8
content_type, *_ = response.headers.get("content-type").split(";")
content_type, *_ = response.headers.get("content-type", "*").split(";")
if content_type != "application/json":
if is_basemodel(cast_to):
try:
Expand Down Expand Up @@ -253,15 +300,6 @@ def __init__(self) -> None:
)


def _extract_stream_chunk_type(stream_cls: type) -> type:
args = get_args(stream_cls)
if not args:
raise TypeError(
f"Expected stream_cls to have been given a generic type argument, e.g. Stream[Foo] but received {stream_cls}",
)
return cast(type, args[0])


def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, LegacyAPIResponse[R]]:
"""Higher order function that takes one of our bound API methods and wraps it
to support returning the raw `APIResponse` object directly.
Expand Down
Loading