Skip to content

Commit

Permalink
fix(client): ensure retried requests are closed (#204)
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-bot authored Nov 30, 2023
1 parent b5b9f79 commit 0659932
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 21 deletions.
100 changes: 80 additions & 20 deletions src/finch/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
DEFAULT_TIMEOUT,
DEFAULT_MAX_RETRIES,
RAW_RESPONSE_HEADER,
STREAMED_RAW_RESPONSE_HEADER,
)
from ._streaming import Stream, AsyncStream
from ._exceptions import (
Expand Down Expand Up @@ -363,14 +364,21 @@ def _make_status_error_from_response(
self,
response: httpx.Response,
) -> APIStatusError:
err_text = response.text.strip()
body = err_text
if response.is_closed and not response.is_stream_consumed:
# We can't read the response body as it has been closed
# before it was read. This can happen if an event hook
# raises a status error.
body = None
err_msg = f"Error code: {response.status_code}"
else:
err_text = response.text.strip()
body = err_text

try:
body = json.loads(err_text)
err_msg = f"Error code: {response.status_code} - {body}"
except Exception:
err_msg = err_text or f"Error code: {response.status_code}"
try:
body = json.loads(err_text)
err_msg = f"Error code: {response.status_code} - {body}"
except Exception:
err_msg = err_text or f"Error code: {response.status_code}"

return self._make_status_error(err_msg, body=body, response=response)

Expand Down Expand Up @@ -534,6 +542,12 @@ def _process_response_data(
except pydantic.ValidationError as err:
raise APIResponseValidationError(response=response, body=data) from err

def _should_stream_response_body(self, *, request: httpx.Request) -> bool:
if request.headers.get(STREAMED_RAW_RESPONSE_HEADER) == "true":
return True

return False

@property
def qs(self) -> Querystring:
return Querystring()
Expand Down Expand Up @@ -606,7 +620,7 @@ def _calculate_retry_timeout(
if response_headers is not None:
retry_header = response_headers.get("retry-after")
try:
retry_after = int(retry_header)
retry_after = float(retry_header)
except Exception:
retry_date_tuple = email.utils.parsedate_tz(retry_header)
if retry_date_tuple is None:
Expand Down Expand Up @@ -862,14 +876,21 @@ def _request(
request = self._build_request(options)
self._prepare_request(request)

response = None

try:
response = self._client.send(request, auth=self.custom_auth, stream=stream)
response = self._client.send(
request,
auth=self.custom_auth,
stream=stream or self._should_stream_response_body(request=request),
)
log.debug(
'HTTP Request: %s %s "%i %s"', request.method, request.url, response.status_code, response.reason_phrase
)
response.raise_for_status()
except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code
if retries > 0 and self._should_retry(err.response):
err.response.close()
return self._retry_request(
options,
cast_to,
Expand All @@ -881,27 +902,39 @@ def _request(

# If the response is streamed then we need to explicitly read the response
# to completion before attempting to access the response text.
err.response.read()
if not err.response.is_closed:
err.response.read()

raise self._make_status_error_from_response(err.response) from None
except httpx.TimeoutException as err:
if response is not None:
response.close()

if retries > 0:
return self._retry_request(
options,
cast_to,
retries,
stream=stream,
stream_cls=stream_cls,
response_headers=response.headers if response is not None else None,
)

raise APITimeoutError(request=request) from err
except Exception as err:
if response is not None:
response.close()

if retries > 0:
return self._retry_request(
options,
cast_to,
retries,
stream=stream,
stream_cls=stream_cls,
response_headers=response.headers if response is not None else None,
)

raise APIConnectionError(request=request) from err

return self._process_response(
Expand All @@ -917,7 +950,7 @@ def _retry_request(
options: FinalRequestOptions,
cast_to: Type[ResponseT],
remaining_retries: int,
response_headers: Optional[httpx.Headers] = None,
response_headers: httpx.Headers | None,
*,
stream: bool,
stream_cls: type[_StreamT] | None,
Expand Down Expand Up @@ -1303,14 +1336,21 @@ async def _request(
request = self._build_request(options)
await self._prepare_request(request)

response = None

try:
response = await self._client.send(request, auth=self.custom_auth, stream=stream)
response = await self._client.send(
request,
auth=self.custom_auth,
stream=stream or self._should_stream_response_body(request=request),
)
log.debug(
'HTTP Request: %s %s "%i %s"', request.method, request.url, response.status_code, response.reason_phrase
)
response.raise_for_status()
except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code
if retries > 0 and self._should_retry(err.response):
await err.response.aclose()
return await self._retry_request(
options,
cast_to,
Expand All @@ -1322,19 +1362,39 @@ async def _request(

# If the response is streamed then we need to explicitly read the response
# to completion before attempting to access the response text.
await err.response.aread()
if not err.response.is_closed:
await err.response.aread()

raise self._make_status_error_from_response(err.response) from None
except httpx.ConnectTimeout as err:
if retries > 0:
return await self._retry_request(options, cast_to, retries, stream=stream, stream_cls=stream_cls)
raise APITimeoutError(request=request) from err
except httpx.TimeoutException as err:
if response is not None:
await response.aclose()

if retries > 0:
return await self._retry_request(options, cast_to, retries, stream=stream, stream_cls=stream_cls)
return await self._retry_request(
options,
cast_to,
retries,
stream=stream,
stream_cls=stream_cls,
response_headers=response.headers if response is not None else None,
)

raise APITimeoutError(request=request) from err
except Exception as err:
if response is not None:
await response.aclose()

if retries > 0:
return await self._retry_request(options, cast_to, retries, stream=stream, stream_cls=stream_cls)
return await self._retry_request(
options,
cast_to,
retries,
stream=stream,
stream_cls=stream_cls,
response_headers=response.headers if response is not None else None,
)

raise APIConnectionError(request=request) from err

return self._process_response(
Expand All @@ -1350,7 +1410,7 @@ async def _retry_request(
options: FinalRequestOptions,
cast_to: Type[ResponseT],
remaining_retries: int,
response_headers: Optional[httpx.Headers] = None,
response_headers: httpx.Headers | None,
*,
stream: bool,
stream_cls: type[_AsyncStreamT] | None,
Expand Down
1 change: 1 addition & 0 deletions src/finch/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import httpx

RAW_RESPONSE_HEADER = "X-Stainless-Raw-Response"
STREAMED_RAW_RESPONSE_HEADER = "X-Stainless-Streamed-Raw-Response"

# default timeout is 1 minute
DEFAULT_TIMEOUT = httpx.Timeout(timeout=60.0, connect=5.0)
Expand Down
53 changes: 52 additions & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from finch._types import Omit
from finch._client import Finch, AsyncFinch
from finch._models import BaseModel, FinalRequestOptions
from finch._exceptions import APIResponseValidationError
from finch._exceptions import APIStatusError, APIResponseValidationError
from finch._base_client import (
DEFAULT_TIMEOUT,
HTTPX_DEFAULT_TIMEOUT,
Expand Down Expand Up @@ -704,6 +704,31 @@ def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str
calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType]

@pytest.mark.respx(base_url=base_url)
def test_status_error_within_httpx(self, respx_mock: MockRouter) -> None:
respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

def on_response(response: httpx.Response) -> None:
raise httpx.HTTPStatusError(
"Simulating an error inside httpx",
response=response,
request=response.request,
)

client = Finch(
base_url=base_url,
access_token=access_token,
_strict_response_validation=True,
http_client=httpx.Client(
event_hooks={
"response": [on_response],
}
),
max_retries=0,
)
with pytest.raises(APIStatusError):
client.post("/foo", cast_to=httpx.Response)


class TestAsyncFinch:
client = AsyncFinch(base_url=base_url, access_token=access_token, _strict_response_validation=True)
Expand Down Expand Up @@ -1377,3 +1402,29 @@ async def test_parse_retry_after_header(self, remaining_retries: int, retry_afte
options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType]

@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_status_error_within_httpx(self, respx_mock: MockRouter) -> None:
respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

def on_response(response: httpx.Response) -> None:
raise httpx.HTTPStatusError(
"Simulating an error inside httpx",
response=response,
request=response.request,
)

client = AsyncFinch(
base_url=base_url,
access_token=access_token,
_strict_response_validation=True,
http_client=httpx.AsyncClient(
event_hooks={
"response": [on_response],
}
),
max_retries=0,
)
with pytest.raises(APIStatusError):
await client.post("/foo", cast_to=httpx.Response)

0 comments on commit 0659932

Please sign in to comment.