Skip to content

Commit

Permalink
fix(client): attempt to parse unknown json content types (#854)
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-bot committed Nov 21, 2023
1 parent 2343e63 commit c26014e
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 16 deletions.
20 changes: 14 additions & 6 deletions src/openai/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,12 @@
RAW_RESPONSE_HEADER,
)
from ._streaming import Stream, AsyncStream
from ._exceptions import APIStatusError, APITimeoutError, APIConnectionError
from ._exceptions import (
APIStatusError,
APITimeoutError,
APIConnectionError,
APIResponseValidationError,
)

log: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -518,13 +523,16 @@ def _process_response_data(
if cast_to is UnknownResponse:
return cast(ResponseT, data)

if inspect.isclass(cast_to) and issubclass(cast_to, ModelBuilderProtocol):
return cast(ResponseT, cast_to.build(response=response, data=data))
try:
if inspect.isclass(cast_to) and issubclass(cast_to, ModelBuilderProtocol):
return cast(ResponseT, cast_to.build(response=response, data=data))

if self._strict_response_validation:
return cast(ResponseT, validate_type(type_=cast_to, value=data))
if self._strict_response_validation:
return cast(ResponseT, validate_type(type_=cast_to, value=data))

return cast(ResponseT, construct_type(type_=cast_to, value=data))
return cast(ResponseT, construct_type(type_=cast_to, value=data))
except pydantic.ValidationError as err:
raise APIResponseValidationError(response=response, body=data) from err

@property
def qs(self) -> Querystring:
Expand Down
13 changes: 13 additions & 0 deletions src/openai/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,19 @@ def _construct_field(value: object, field: FieldInfo, key: str) -> object:
return construct_type(value=value, type_=type_)


def is_basemodel(type_: type) -> bool:
"""Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`"""
origin = get_origin(type_) or type_
if is_union(type_):
for variant in get_args(type_):
if is_basemodel(variant):
return True

return False

return issubclass(origin, BaseModel) or issubclass(origin, GenericModel)


def construct_type(*, value: object, type_: type) -> object:
"""Loose coercion to the expected type with construction of nested values.
Expand Down
31 changes: 21 additions & 10 deletions src/openai/_response.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from __future__ import annotations

import inspect
import logging
import datetime
import functools
from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast
from typing_extensions import Awaitable, ParamSpec, get_args, override, get_origin

import httpx
import pydantic

from ._types import NoneType, UnknownResponse, BinaryResponseContent
from ._utils import is_given
from ._models import BaseModel
from ._models import BaseModel, is_basemodel
from ._constants import RAW_RESPONSE_HEADER
from ._exceptions import APIResponseValidationError

Expand All @@ -23,6 +23,8 @@
P = ParamSpec("P")
R = TypeVar("R")

log: logging.Logger = logging.getLogger(__name__)


class APIResponse(Generic[R]):
_cast_to: type[R]
Expand Down Expand Up @@ -174,6 +176,18 @@ def _parse(self) -> R:
# in the response, e.g. application/json; charset=utf-8
content_type, *_ = response.headers.get("content-type").split(";")
if content_type != "application/json":
if is_basemodel(cast_to):
try:
data = response.json()
except Exception as exc:
log.debug("Could not read JSON from response data due to %s - %s", type(exc), exc)
else:
return self._client._process_response_data(
data=data,
cast_to=cast_to, # type: ignore
response=response,
)

if self._client._strict_response_validation:
raise APIResponseValidationError(
response=response,
Expand All @@ -188,14 +202,11 @@ def _parse(self) -> R:

data = response.json()

try:
return self._client._process_response_data(
data=data,
cast_to=cast_to, # type: ignore
response=response,
)
except pydantic.ValidationError as err:
raise APIResponseValidationError(response=response, body=data) from err
return self._client._process_response_data(
data=data,
cast_to=cast_to, # type: ignore
response=response,
)

@override
def __repr__(self) -> str:
Expand Down
42 changes: 42 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,27 @@ class Model2(BaseModel):
assert isinstance(response, Model1)
assert response.foo == 1

@pytest.mark.respx(base_url=base_url)
def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
"""
Response that sets Content-Type to something other than application/json but returns json data
"""

class Model(BaseModel):
foo: int

respx_mock.get("/foo").mock(
return_value=httpx.Response(
200,
content=json.dumps({"foo": 2}),
headers={"Content-Type": "application/text"},
)
)

response = self.client.get("/foo", cast_to=Model)
assert isinstance(response, Model)
assert response.foo == 2

def test_base_url_env(self) -> None:
with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
client = OpenAI(api_key=api_key, _strict_response_validation=True)
Expand Down Expand Up @@ -939,6 +960,27 @@ class Model2(BaseModel):
assert isinstance(response, Model1)
assert response.foo == 1

@pytest.mark.respx(base_url=base_url)
async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
"""
Response that sets Content-Type to something other than application/json but returns json data
"""

class Model(BaseModel):
foo: int

respx_mock.get("/foo").mock(
return_value=httpx.Response(
200,
content=json.dumps({"foo": 2}),
headers={"Content-Type": "application/text"},
)
)

response = await self.client.get("/foo", cast_to=Model)
assert isinstance(response, Model)
assert response.foo == 2

def test_base_url_env(self) -> None:
with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
client = AsyncOpenAI(api_key=api_key, _strict_response_validation=True)
Expand Down

0 comments on commit c26014e

Please sign in to comment.