Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(client): allow binary returns #164

Merged
merged 1 commit into from
Nov 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions src/finch/_base_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import os
import json
import time
import uuid
Expand Down Expand Up @@ -60,6 +61,7 @@
RequestOptions,
UnknownResponse,
ModelBuilderProtocol,
BinaryResponseContent,
)
from ._utils import is_dict, is_given, is_mapping
from ._compat import model_copy, model_dump
Expand Down Expand Up @@ -1672,3 +1674,94 @@ def _merge_mappings(
"""
merged = {**obj1, **obj2}
return {key: value for key, value in merged.items() if not isinstance(value, Omit)}


class HttpxBinaryResponseContent(BinaryResponseContent):
response: httpx.Response

def __init__(self, response: httpx.Response) -> None:
self.response = response

@property
@override
def content(self) -> bytes:
return self.response.content

@property
@override
def text(self) -> str:
return self.response.text

@property
@override
def encoding(self) -> Optional[str]:
return self.response.encoding

@property
@override
def charset_encoding(self) -> Optional[str]:
return self.response.charset_encoding

@override
def json(self, **kwargs: Any) -> Any:
return self.response.json(**kwargs)

@override
def read(self) -> bytes:
return self.response.read()

@override
def iter_bytes(self, chunk_size: Optional[int] = None) -> Iterator[bytes]:
return self.response.iter_bytes(chunk_size)

@override
def iter_text(self, chunk_size: Optional[int] = None) -> Iterator[str]:
return self.response.iter_text(chunk_size)

@override
def iter_lines(self) -> Iterator[str]:
return self.response.iter_lines()

@override
def iter_raw(self, chunk_size: Optional[int] = None) -> Iterator[bytes]:
return self.response.iter_raw(chunk_size)

@override
def stream_to_file(self, file: str | os.PathLike[str]) -> None:
with open(file, mode="wb") as f:
for data in self.response.iter_bytes():
f.write(data)

@override
def close(self) -> None:
return self.response.close()

@override
async def aread(self) -> bytes:
return await self.response.aread()

@override
async def aiter_bytes(self, chunk_size: Optional[int] = None) -> AsyncIterator[bytes]:
return self.response.aiter_bytes(chunk_size)

@override
async def aiter_text(self, chunk_size: Optional[int] = None) -> AsyncIterator[str]:
return self.response.aiter_text(chunk_size)

@override
async def aiter_lines(self) -> AsyncIterator[str]:
return self.response.aiter_lines()

@override
async def aiter_raw(self, chunk_size: Optional[int] = None) -> AsyncIterator[bytes]:
return self.response.aiter_raw(chunk_size)

@override
async def astream_to_file(self, file: str | os.PathLike[str]) -> None:
with open(file, mode="wb") as f:
async for data in self.response.aiter_bytes():
f.write(data)

@override
async def aclose(self) -> None:
return await self.response.aclose()
5 changes: 4 additions & 1 deletion src/finch/_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import httpx
import pydantic

from ._types import NoneType, UnknownResponse
from ._types import NoneType, UnknownResponse, BinaryResponseContent
from ._utils import is_given
from ._models import BaseModel
from ._constants import RAW_RESPONSE_HEADER
Expand Down Expand Up @@ -135,6 +135,9 @@ def _parse(self) -> R:

origin = get_origin(cast_to) or cast_to

if inspect.isclass(origin) and issubclass(origin, BinaryResponseContent):
return cast(R, cast_to(response)) # type: ignore

if origin == APIResponse:
raise RuntimeError("Unexpected state - cast_to is `APIResponse`")

Expand Down
151 changes: 149 additions & 2 deletions src/finch/_types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from os import PathLike
from abc import ABC, abstractmethod
from typing import (
IO,
TYPE_CHECKING,
Expand All @@ -13,8 +14,10 @@
Mapping,
TypeVar,
Callable,
Iterator,
Optional,
Sequence,
AsyncIterator,
)
from typing_extensions import (
Literal,
Expand All @@ -25,7 +28,6 @@
runtime_checkable,
)

import httpx
import pydantic
from httpx import URL, Proxy, Timeout, Response, BaseTransport, AsyncBaseTransport

Expand All @@ -40,6 +42,151 @@
ModelT = TypeVar("ModelT", bound=pydantic.BaseModel)
_T = TypeVar("_T")


class BinaryResponseContent(ABC):
def __init__(
self,
response: Any,
) -> None:
...

@property
@abstractmethod
def content(self) -> bytes:
pass

@property
@abstractmethod
def text(self) -> str:
pass

@property
@abstractmethod
def encoding(self) -> Optional[str]:
"""
Return an encoding to use for decoding the byte content into text.
The priority for determining this is given by...

* `.encoding = <>` has been set explicitly.
* The encoding as specified by the charset parameter in the Content-Type header.
* The encoding as determined by `default_encoding`, which may either be
a string like "utf-8" indicating the encoding to use, or may be a callable
which enables charset autodetection.
"""
pass

@property
@abstractmethod
def charset_encoding(self) -> Optional[str]:
"""
Return the encoding, as specified by the Content-Type header.
"""
pass

@abstractmethod
def json(self, **kwargs: Any) -> Any:
pass

@abstractmethod
def read(self) -> bytes:
"""
Read and return the response content.
"""
pass

@abstractmethod
def iter_bytes(self, chunk_size: Optional[int] = None) -> Iterator[bytes]:
"""
A byte-iterator over the decoded response content.
This allows us to handle gzip, deflate, and brotli encoded responses.
"""
pass

@abstractmethod
def iter_text(self, chunk_size: Optional[int] = None) -> Iterator[str]:
"""
A str-iterator over the decoded response content
that handles both gzip, deflate, etc but also detects the content's
string encoding.
"""
pass

@abstractmethod
def iter_lines(self) -> Iterator[str]:
pass

@abstractmethod
def iter_raw(self, chunk_size: Optional[int] = None) -> Iterator[bytes]:
"""
A byte-iterator over the raw response content.
"""
pass

@abstractmethod
def stream_to_file(self, file: str | PathLike[str]) -> None:
"""
Stream the output to the given file.
"""
pass

@abstractmethod
def close(self) -> None:
"""
Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
pass

@abstractmethod
async def aread(self) -> bytes:
"""
Read and return the response content.
"""
pass

@abstractmethod
async def aiter_bytes(self, chunk_size: Optional[int] = None) -> AsyncIterator[bytes]:
"""
A byte-iterator over the decoded response content.
This allows us to handle gzip, deflate, and brotli encoded responses.
"""
pass

@abstractmethod
async def aiter_text(self, chunk_size: Optional[int] = None) -> AsyncIterator[str]:
"""
A str-iterator over the decoded response content
that handles both gzip, deflate, etc but also detects the content's
string encoding.
"""
pass

@abstractmethod
async def aiter_lines(self) -> AsyncIterator[str]:
pass

@abstractmethod
async def aiter_raw(self, chunk_size: Optional[int] = None) -> AsyncIterator[bytes]:
"""
A byte-iterator over the raw response content.
"""
pass

async def astream_to_file(self, file: str | PathLike[str]) -> None:
"""
Stream the output to the given file.
"""
pass

@abstractmethod
async def aclose(self) -> None:
"""
Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
pass


# Approximates httpx internal ProxiesTypes and RequestFiles types
# while adding support for `PathLike` instances
ProxiesDict = Dict["str | URL", Union[None, str, URL, Proxy]]
Expand Down Expand Up @@ -181,7 +328,7 @@ def get(self, __key: str) -> str | None:

ResponseT = TypeVar(
"ResponseT",
bound="Union[str, None, BaseModel, List[Any], Dict[str, Any], httpx.Response, UnknownResponse, ModelBuilderProtocol]",
bound="Union[str, None, BaseModel, List[Any], Dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]",
)

StrBytesIntFloat = Union[str, bytes, int, float]
Expand Down
31 changes: 27 additions & 4 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,23 @@ class TestFinch:

@pytest.mark.respx(base_url=base_url)
def test_raw_response(self, respx_mock: MockRouter) -> None:
respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
respx_mock.post("/foo").mock(return_value=httpx.Response(200, json='{"foo": "bar"}'))

response = self.client.post("/foo", cast_to=httpx.Response)
assert response.status_code == 200
assert isinstance(response, httpx.Response)
assert response.json() == {"foo": "bar"}
assert response.json() == '{"foo": "bar"}'

@pytest.mark.respx(base_url=base_url)
def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
respx_mock.post("/foo").mock(
return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
)

response = self.client.post("/foo", cast_to=httpx.Response)
assert response.status_code == 200
assert isinstance(response, httpx.Response)
assert response.json() == '{"foo": "bar"}'

def test_copy(self) -> None:
copied = self.client.copy()
Expand Down Expand Up @@ -672,12 +683,24 @@ class TestAsyncFinch:
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_raw_response(self, respx_mock: MockRouter) -> None:
respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
respx_mock.post("/foo").mock(return_value=httpx.Response(200, json='{"foo": "bar"}'))

response = await self.client.post("/foo", cast_to=httpx.Response)
assert response.status_code == 200
assert isinstance(response, httpx.Response)
assert response.json() == '{"foo": "bar"}'

@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
respx_mock.post("/foo").mock(
return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
)

response = await self.client.post("/foo", cast_to=httpx.Response)
assert response.status_code == 200
assert isinstance(response, httpx.Response)
assert response.json() == {"foo": "bar"}
assert response.json() == '{"foo": "bar"}'

def test_copy(self) -> None:
copied = self.client.copy()
Expand Down