From 0b3e2bd5ebba12f3a6a93d42e169d6afecf62441 Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Thu, 8 Oct 2020 19:56:16 +0200 Subject: [PATCH 1/3] Switch to request context manager interface --- docs/advanced.md | 3 +- docs/api.md | 6 +- docs/async.md | 18 +- docs/exceptions.md | 4 - httpx/__init__.py | 2 - httpx/_api.py | 46 +-- httpx/_client.py | 472 +++++++++++++++--------------- httpx/_compat.py | 5 + httpx/_exceptions.py | 15 - httpx/_models.py | 30 -- httpx/_utils.py | 32 ++ tests/client/test_async_client.py | 3 +- tests/client/test_client.py | 6 +- tests/client/test_event_hooks.py | 8 +- tests/client/test_redirects.py | 14 +- tests/models/test_responses.py | 35 --- 16 files changed, 325 insertions(+), 374 deletions(-) create mode 100644 httpx/_compat.py diff --git a/docs/advanced.md b/docs/advanced.md index 1b0ecee7c5..8c9bcc3290 100644 --- a/docs/advanced.md +++ b/docs/advanced.md @@ -200,7 +200,8 @@ To dispatch a `Request` instance across to the network, create a [`Client` insta ```python with httpx.Client() as client: - response = client.send(request) + with client.send(request) as response: + response.read() ... ``` diff --git a/docs/api.md b/docs/api.md index 817467ec55..896ca170d5 100644 --- a/docs/api.md +++ b/docs/api.md @@ -94,7 +94,11 @@ what gets sent over the wire.* ```pycon >>> request = httpx.Request("GET", "https://example.org", headers={'host': 'example.org'}) ->>> response = client.send(request) +>>> with client.send(request) as response: +... response.read() +... +>>> response.status_code +200 ``` * `def __init__(method, url, [params], [data], [json], [headers], [cookies])` diff --git a/docs/async.md b/docs/async.md index ba6c6ad704..fd388ef57a 100644 --- a/docs/async.md +++ b/docs/async.md @@ -80,11 +80,12 @@ The async response streaming methods are: * `Response.aiter_raw()` - For streaming the raw response bytes, without applying content decoding. * `Response.aclose()` - For closing the response. You don't usually need this, since `.stream` block closes the response automatically on exit. -For situations when context block usage is not practical, it is possible to enter "manual mode" by sending a [`Request` instance](./advanced.md#request-instances) using `client.send(..., stream=True)`. +For situations when context block usage is not practical, it is possible to enter "manual mode" by sending a [`Request` instance](./advanced.md#request-instances) using `client.send(...)`. Example in the context of forwarding the response to a streaming web endpoint with [Starlette](https://www.starlette.io): ```python +import contextlib import httpx from starlette.background import BackgroundTask from starlette.responses import StreamingResponse @@ -93,12 +94,19 @@ client = httpx.AsyncClient() async def home(request): req = client.build_request("GET", "https://www.example.com/") - r = await client.send(req, stream=True) - return StreamingResponse(r.aiter_text(), background=BackgroundTask(r.aclose)) + exit_stack = contextlib.AsyncExitStack() + r = await exit_stack.enter_async_context(client.send(req)) + return StreamingResponse(r.aiter_text(), background=BackgroundTask(exit_stack.aclose)) ``` -!!! warning - When using this "manual streaming mode", it is your duty as a developer to make sure that `Response.aclose()` is called eventually. Failing to do so would leave connections open, most likely resulting in resource leaks down the line. +**Note**: When using this "manual streaming mode", it is your duty as a developer to make sure that the response is eventually properly closed. Failing to do so would leave connections open, most likely resulting in resource leaks down the line. In the above example, we use an `AsyncExitStack` to properly enter and then clean up the context manager returned by `client.send()`. This approach is the least error-prone. Alternatively, you could enter the context manually, and call `.aclose()` on the response: + +```python +async def home(request): + req = client.build_request("GET", "https://www.example.com/") + r = await client.send(req).__aenter__() + return StreamingResponse(r.aiter_text(), background=BackgroundTask(r.aclose)) +``` ### Streaming requests diff --git a/docs/exceptions.md b/docs/exceptions.md index 1bc86100db..170fe1e215 100644 --- a/docs/exceptions.md +++ b/docs/exceptions.md @@ -81,7 +81,6 @@ except httpx.HTTPStatusError as exc: * StreamConsumed * ResponseNotRead * RequestNotRead - * ResponseClosed --- @@ -167,6 +166,3 @@ except httpx.HTTPStatusError as exc: ::: httpx.RequestNotRead :docstring: - -::: httpx.ResponseClosed - :docstring: diff --git a/httpx/__init__.py b/httpx/__init__.py index 489ffeae4f..78e3f22dfc 100644 --- a/httpx/__init__.py +++ b/httpx/__init__.py @@ -22,7 +22,6 @@ RemoteProtocolError, RequestError, RequestNotRead, - ResponseClosed, ResponseNotRead, StreamConsumed, StreamError, @@ -83,7 +82,6 @@ "RequestError", "RequestNotRead", "Response", - "ResponseClosed", "ResponseNotRead", "StatusCode", "stream", diff --git a/httpx/_api.py b/httpx/_api.py index 985e2ab938..2cd780c844 100644 --- a/httpx/_api.py +++ b/httpx/_api.py @@ -1,8 +1,9 @@ import typing +from contextlib import contextmanager -from ._client import Client, StreamContextManager +from ._client import Client from ._config import DEFAULT_TIMEOUT_CONFIG -from ._models import Request, Response +from ._models import Response from ._types import ( AuthTypes, CertTypes, @@ -105,6 +106,7 @@ def request( ) +@contextmanager def stream( method: str, url: URLTypes, @@ -123,7 +125,7 @@ def stream( verify: VerifyTypes = True, cert: CertTypes = None, trust_env: bool = True, -) -> StreamContextManager: +) -> typing.Iterator[Response]: """ Alternative to `httpx.request()` that streams the response body instead of loading it into memory at once. @@ -134,26 +136,24 @@ def stream( [0]: /quickstart#streaming-responses """ - client = Client(proxies=proxies, cert=cert, verify=verify, trust_env=trust_env) - request = Request( - method=method, - url=url, - params=params, - content=content, - data=data, - files=files, - json=json, - headers=headers, - cookies=cookies, - ) - return StreamContextManager( - client=client, - request=request, - auth=auth, - timeout=timeout, - allow_redirects=allow_redirects, - close_client=True, - ) + with Client( + proxies=proxies, cert=cert, verify=verify, trust_env=trust_env + ) as client: + with client.stream( + method=method, + url=url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + allow_redirects=allow_redirects, + timeout=timeout, + ) as response: + yield response def get( diff --git a/httpx/_client.py b/httpx/_client.py index d15c004530..9ce956acfb 100644 --- a/httpx/_client.py +++ b/httpx/_client.py @@ -2,12 +2,14 @@ import enum import typing import warnings +from contextlib import contextmanager from types import TracebackType import httpcore from .__version__ import __version__ from ._auth import Auth, BasicAuth, FunctionAuth +from ._compat import asynccontextmanager from ._config import ( DEFAULT_LIMITS, DEFAULT_MAX_REDIRECTS, @@ -50,6 +52,8 @@ NetRCInfo, Timer, URLPattern, + ensure_async_context_manager, + ensure_context_manager, get_environment_proxies, get_logger, same_origin, @@ -234,51 +238,6 @@ def params(self) -> QueryParams: def params(self, params: QueryParamTypes) -> None: self._params = QueryParams(params) - def stream( - self, - method: str, - url: URLTypes, - *, - content: RequestContent = None, - data: RequestData = None, - files: RequestFiles = None, - json: typing.Any = None, - params: QueryParamTypes = None, - headers: HeaderTypes = None, - cookies: CookieTypes = None, - auth: typing.Union[AuthTypes, UnsetType] = UNSET, - allow_redirects: bool = True, - timeout: typing.Union[TimeoutTypes, UnsetType] = UNSET, - ) -> "StreamContextManager": - """ - Alternative to `httpx.request()` that streams the response body - instead of loading it into memory at once. - - **Parameters**: See `httpx.request`. - - See also: [Streaming Responses][0] - - [0]: /quickstart#streaming-responses - """ - request = self.build_request( - method=method, - url=url, - content=content, - data=data, - files=files, - json=json, - params=params, - headers=headers, - cookies=cookies, - ) - return StreamContextManager( - client=self, - request=request, - auth=auth, - allow_redirects=allow_redirects, - timeout=timeout, - ) - def build_request( self, method: str, @@ -710,7 +669,9 @@ def request( ```python request = client.build_request(...) - response = client.send(request, ...) + with client.send(request, ...) as response: + response.read() + # Use `response`... ``` See `Client.build_request()`, `Client.send()` and @@ -730,19 +691,72 @@ def request( headers=headers, cookies=cookies, ) - return self.send( - request, auth=auth, allow_redirects=allow_redirects, timeout=timeout + + with self.send( + request, + auth=auth, + allow_redirects=allow_redirects, + timeout=timeout, + ) as response: + response.read() + + return response + + @contextmanager + def stream( + self, + method: str, + url: URLTypes, + *, + content: RequestContent = None, + data: RequestData = None, + files: RequestFiles = None, + json: typing.Any = None, + params: QueryParamTypes = None, + headers: HeaderTypes = None, + cookies: CookieTypes = None, + auth: typing.Union[AuthTypes, UnsetType] = UNSET, + allow_redirects: bool = True, + timeout: typing.Union[TimeoutTypes, UnsetType] = UNSET, + ) -> typing.Iterator[Response]: + """ + Alternative to `httpx.request()` that streams the response body + instead of loading it into memory at once. + + **Parameters**: See `httpx.request`. + + See also: [Streaming Responses][0] + + [0]: /quickstart#streaming-responses + """ + request = self.build_request( + method=method, + url=url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, ) + with self.send( + request, + auth=auth, + allow_redirects=allow_redirects, + timeout=timeout, + ) as response: + yield response + @contextmanager def send( self, request: Request, *, - stream: bool = False, auth: typing.Union[AuthTypes, UnsetType] = UNSET, allow_redirects: bool = True, timeout: typing.Union[TimeoutTypes, UnsetType] = UNSET, - ) -> Response: + ) -> typing.Iterator[Response]: """ Send a request. @@ -764,29 +778,18 @@ def send( auth = self._build_request_auth(request, auth) - response = self._send_handling_auth( + with self._send_handling_auth( request, auth=auth, timeout=timeout, allow_redirects=allow_redirects, history=[], - ) - - if not stream: - try: - response.read() - finally: - response.close() - - try: + ) as response: for hook in self._event_hooks["response"]: hook(response) - except Exception: - response.close() - raise - - return response + yield response + @contextmanager def _send_handling_auth( self, request: Request, @@ -794,7 +797,7 @@ def _send_handling_auth( timeout: Timeout, allow_redirects: bool, history: typing.List[Response], - ) -> Response: + ) -> typing.Iterator[Response]: auth_flow = auth.sync_auth_flow(request) request = next(auth_flow) @@ -802,54 +805,58 @@ def _send_handling_auth( hook(request) while True: - response = self._send_handling_redirects( + with self._send_handling_redirects( request, timeout=timeout, allow_redirects=allow_redirects, history=history, - ) - try: - next_request = auth_flow.send(response) - except StopIteration: - return response - except BaseException as exc: - response.close() - raise exc from None - else: + ) as response: + try: + next_request = auth_flow.send(response) + except StopIteration: + yield response + break + response.history = list(history) response.read() request = next_request history.append(response) + @contextmanager def _send_handling_redirects( self, request: Request, timeout: Timeout, allow_redirects: bool, history: typing.List[Response], - ) -> Response: + ) -> typing.Iterator[Response]: while True: if len(history) > self.max_redirects: raise TooManyRedirects( "Exceeded maximum allowed redirects.", request=request ) - response = self._send_single_request(request, timeout) - response.history = list(history) + with self._send_single_request(request, timeout) as response: + response.history = list(history) - if not response.is_redirect: - return response + if not response.is_redirect: + yield response + break - if allow_redirects: - response.read() - request = self._build_redirect_request(request, response) - history = history + [response] + if allow_redirects: + response.read() + request = self._build_redirect_request(request, response) + history = history + [response] - if not allow_redirects: - response.next_request = request - return response + if not allow_redirects: + response.next_request = request + yield response + break - def _send_single_request(self, request: Request, timeout: Timeout) -> Response: + @contextmanager + def _send_single_request( + self, request: Request, timeout: Timeout + ) -> typing.Iterator[Response]: """ Sends a single request, without handling any redirections. """ @@ -858,35 +865,34 @@ def _send_single_request(self, request: Request, timeout: Timeout) -> Response: timer.sync_start() with map_exceptions(HTTPCORE_EXC_MAP, request=request): - (status_code, headers, stream, ext) = transport.request( - request.method.encode(), - request.url.raw, - headers=request.headers.raw, - stream=request.stream, # type: ignore - ext={"timeout": timeout.as_dict()}, - ) + with ensure_context_manager( + transport.request( + request.method.encode(), + request.url.raw, + headers=request.headers.raw, + stream=request.stream, # type: ignore + ext={"timeout": timeout.as_dict()}, + ) + ) as (status_code, headers, stream, ext): + response = Response( + status_code, + headers=headers, + stream=stream, # type: ignore + ext=ext, + request=request, + ) - def on_close(response: Response) -> None: - response.elapsed = datetime.timedelta(seconds=timer.sync_elapsed()) - if hasattr(stream, "close"): - stream.close() + self.cookies.extract_cookies(response) - response = Response( - status_code, - headers=headers, - stream=stream, # type: ignore - ext=ext, - request=request, - on_close=on_close, - ) - - self.cookies.extract_cookies(response) + status = f"{response.status_code} {response.reason_phrase}" + response_line = f"{response.http_version} {status}" + logger.debug( + f'HTTP Request: {request.method} {request.url} "{response_line}"' + ) - status = f"{response.status_code} {response.reason_phrase}" - response_line = f"{response.http_version} {status}" - logger.debug(f'HTTP Request: {request.method} {request.url} "{response_line}"') + yield response - return response + response.elapsed = datetime.timedelta(seconds=timer.sync_elapsed()) def get( self, @@ -1348,7 +1354,9 @@ async def request( ```python request = client.build_request(...) - response = await client.send(request, ...) + async with client.send(request, ...) as response: + await response.aread() + # Use `response`... ``` See `AsyncClient.build_request()`, `AsyncClient.send()` @@ -1368,20 +1376,67 @@ async def request( headers=headers, cookies=cookies, ) - response = await self.send( + async with self.send( request, auth=auth, allow_redirects=allow_redirects, timeout=timeout - ) + ) as response: + await response.aread() return response + @asynccontextmanager + async def stream( + self, + method: str, + url: URLTypes, + *, + content: RequestContent = None, + data: RequestData = None, + files: RequestFiles = None, + json: typing.Any = None, + params: QueryParamTypes = None, + headers: HeaderTypes = None, + cookies: CookieTypes = None, + auth: typing.Union[AuthTypes, UnsetType] = UNSET, + allow_redirects: bool = True, + timeout: typing.Union[TimeoutTypes, UnsetType] = UNSET, + ) -> typing.AsyncIterator[Response]: + """ + Alternative to `httpx.request()` that streams the response body + instead of loading it into memory at once. + + **Parameters**: See `httpx.request`. + + See also: [Streaming Responses][0] + + [0]: /quickstart#streaming-responses + """ + request = self.build_request( + method=method, + url=url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + ) + async with self.send( + request, + auth=auth, + allow_redirects=allow_redirects, + timeout=timeout, + ) as response: + yield response + + @asynccontextmanager async def send( self, request: Request, *, - stream: bool = False, auth: typing.Union[AuthTypes, UnsetType] = UNSET, allow_redirects: bool = True, timeout: typing.Union[TimeoutTypes, UnsetType] = UNSET, - ) -> Response: + ) -> typing.AsyncIterator[Response]: """ Send a request. @@ -1403,29 +1458,18 @@ async def send( auth = self._build_request_auth(request, auth) - response = await self._send_handling_auth( + async with self._send_handling_auth( request, auth=auth, timeout=timeout, allow_redirects=allow_redirects, history=[], - ) - - if not stream: - try: - await response.aread() - finally: - await response.aclose() - - try: + ) as response: for hook in self._event_hooks["response"]: await hook(response) - except Exception: - await response.aclose() - raise - - return response + yield response + @asynccontextmanager async def _send_handling_auth( self, request: Request, @@ -1433,7 +1477,7 @@ async def _send_handling_auth( timeout: Timeout, allow_redirects: bool, history: typing.List[Response], - ) -> Response: + ) -> typing.AsyncIterator[Response]: auth_flow = auth.async_auth_flow(request) request = await auth_flow.__anext__() @@ -1441,56 +1485,58 @@ async def _send_handling_auth( await hook(request) while True: - response = await self._send_handling_redirects( + async with self._send_handling_redirects( request, timeout=timeout, allow_redirects=allow_redirects, history=history, - ) - try: - next_request = await auth_flow.asend(response) - except StopAsyncIteration: - return response - except BaseException as exc: - await response.aclose() - raise exc from None - else: + ) as response: + try: + next_request = await auth_flow.asend(response) + except StopAsyncIteration: + yield response + break + response.history = list(history) await response.aread() request = next_request history.append(response) + @asynccontextmanager async def _send_handling_redirects( self, request: Request, timeout: Timeout, allow_redirects: bool, history: typing.List[Response], - ) -> Response: + ) -> typing.AsyncIterator[Response]: while True: if len(history) > self.max_redirects: raise TooManyRedirects( "Exceeded maximum allowed redirects.", request=request ) - response = await self._send_single_request(request, timeout) - response.history = list(history) + async with self._send_single_request(request, timeout) as response: + response.history = list(history) - if not response.is_redirect: - return response + if not response.is_redirect: + yield response + break - if allow_redirects: - await response.aread() - request = self._build_redirect_request(request, response) - history = history + [response] + if allow_redirects: + await response.aread() + request = self._build_redirect_request(request, response) + history = history + [response] - if not allow_redirects: - response.next_request = request - return response + if not allow_redirects: + response.next_request = request + yield response + break + @asynccontextmanager async def _send_single_request( self, request: Request, timeout: Timeout - ) -> Response: + ) -> typing.AsyncIterator[Response]: """ Sends a single request, without handling any redirections. """ @@ -1499,35 +1545,34 @@ async def _send_single_request( await timer.async_start() with map_exceptions(HTTPCORE_EXC_MAP, request=request): - (status_code, headers, stream, ext,) = await transport.arequest( - request.method.encode(), - request.url.raw, - headers=request.headers.raw, - stream=request.stream, # type: ignore - ext={"timeout": timeout.as_dict()}, - ) - - async def on_close(response: Response) -> None: - response.elapsed = datetime.timedelta(seconds=await timer.async_elapsed()) - if hasattr(stream, "aclose"): - await stream.aclose() + async with ensure_async_context_manager( + transport.arequest( + request.method.encode(), + request.url.raw, + headers=request.headers.raw, + stream=request.stream, # type: ignore + ext={"timeout": timeout.as_dict()}, + ) + ) as (status_code, headers, stream, ext): + response = Response( + status_code, + headers=headers, + stream=stream, # type: ignore + ext=ext, + request=request, + ) - response = Response( - status_code, - headers=headers, - stream=stream, # type: ignore - ext=ext, - request=request, - on_close=on_close, - ) + self.cookies.extract_cookies(response) - self.cookies.extract_cookies(response) + status = f"{response.status_code} {response.reason_phrase}" + response_line = f"{response.http_version} {status}" + logger.debug( + f'HTTP Request: {request.method} {request.url} "{response_line}"' + ) - status = f"{response.status_code} {response.reason_phrase}" - response_line = f"{response.http_version} {status}" - logger.debug(f'HTTP Request: {request.method} {request.url} "{response_line}"') + yield response - return response + response.elapsed = datetime.timedelta(seconds=await timer.async_elapsed()) async def get( self, @@ -1783,64 +1828,3 @@ def __del__(self) -> None: "See https://www.python-httpx.org/async/#opening-and-closing-clients " "for details." ) - - -class StreamContextManager: - def __init__( - self, - client: BaseClient, - request: Request, - *, - auth: typing.Union[AuthTypes, UnsetType] = UNSET, - allow_redirects: bool = True, - timeout: typing.Union[TimeoutTypes, UnsetType] = UNSET, - close_client: bool = False, - ) -> None: - self.client = client - self.request = request - self.auth = auth - self.allow_redirects = allow_redirects - self.timeout = timeout - self.close_client = close_client - - def __enter__(self) -> "Response": - assert isinstance(self.client, Client) - self.response = self.client.send( - request=self.request, - auth=self.auth, - allow_redirects=self.allow_redirects, - timeout=self.timeout, - stream=True, - ) - return self.response - - def __exit__( - self, - exc_type: typing.Type[BaseException] = None, - exc_value: BaseException = None, - traceback: TracebackType = None, - ) -> None: - assert isinstance(self.client, Client) - self.response.close() - if self.close_client: - self.client.close() - - async def __aenter__(self) -> "Response": - assert isinstance(self.client, AsyncClient) - self.response = await self.client.send( - request=self.request, - auth=self.auth, - allow_redirects=self.allow_redirects, - timeout=self.timeout, - stream=True, - ) - return self.response - - async def __aexit__( - self, - exc_type: typing.Type[BaseException] = None, - exc_value: BaseException = None, - traceback: TracebackType = None, - ) -> None: - assert isinstance(self.client, AsyncClient) - await self.response.aclose() diff --git a/httpx/_compat.py b/httpx/_compat.py new file mode 100644 index 0000000000..ff9c1117dc --- /dev/null +++ b/httpx/_compat.py @@ -0,0 +1,5 @@ +try: + from contextlib import asynccontextmanager # type: ignore # Py3.6 +except ImportError: # pragma: no cover + # Python 3.6 + from async_generator import asynccontextmanager # type: ignore # noqa: F401 diff --git a/httpx/_exceptions.py b/httpx/_exceptions.py index bade9f9b81..454631fc06 100644 --- a/httpx/_exceptions.py +++ b/httpx/_exceptions.py @@ -29,7 +29,6 @@ x StreamConsumed x ResponseNotRead x RequestNotRead - x ResponseClosed """ import contextlib import typing @@ -303,20 +302,6 @@ def __init__(self) -> None: super().__init__(message) -class ResponseClosed(StreamError): - """ - Attempted to read or stream response content, but the request has been - closed. - """ - - def __init__(self) -> None: - message = ( - "Attempted to read or stream response content, but the request has " - "been closed." - ) - super().__init__(message) - - @contextlib.contextmanager def map_exceptions( mapping: typing.Mapping[typing.Type[Exception], typing.Type[Exception]], diff --git a/httpx/_models.py b/httpx/_models.py index c981c740bf..15cf15ee51 100644 --- a/httpx/_models.py +++ b/httpx/_models.py @@ -28,7 +28,6 @@ HTTPStatusError, InvalidURL, RequestNotRead, - ResponseClosed, ResponseNotRead, StreamConsumed, map_exceptions, @@ -902,7 +901,6 @@ def __init__( request: Request = None, ext: dict = None, history: typing.List["Response"] = None, - on_close: typing.Callable = None, ): self.status_code = status_code self.headers = Headers(headers) @@ -917,9 +915,7 @@ def __init__( self.ext = {} if ext is None else ext self.history = [] if history is None else list(history) - self._on_close = on_close - self.is_closed = False self.is_stream_consumed = False if stream is not None: @@ -1203,8 +1199,6 @@ def iter_raw(self) -> typing.Iterator[bytes]: """ if self.is_stream_consumed: raise StreamConsumed() - if self.is_closed: - raise ResponseClosed() if not isinstance(self.stream, typing.Iterable): raise RuntimeError("Attempted to call a sync iterator on an async stream.") @@ -1214,17 +1208,6 @@ def iter_raw(self) -> typing.Iterator[bytes]: for part in self.stream: self._num_bytes_downloaded += len(part) yield part - self.close() - - def close(self) -> None: - """ - Close the response and release the connection. - Automatically called if the response body is read to completion. - """ - if not self.is_closed: - self.is_closed = True - if self._on_close is not None: - self._on_close(self) async def aread(self) -> bytes: """ @@ -1275,8 +1258,6 @@ async def aiter_raw(self) -> typing.AsyncIterator[bytes]: """ if self.is_stream_consumed: raise StreamConsumed() - if self.is_closed: - raise ResponseClosed() if not isinstance(self.stream, typing.AsyncIterable): raise RuntimeError("Attempted to call a async iterator on a sync stream.") @@ -1286,17 +1267,6 @@ async def aiter_raw(self) -> typing.AsyncIterator[bytes]: async for part in self.stream: self._num_bytes_downloaded += len(part) yield part - await self.aclose() - - async def aclose(self) -> None: - """ - Close the response and release the connection. - Automatically called if the response body is read to completion. - """ - if not self.is_closed: - self.is_closed = True - if self._on_close is not None: - await self._on_close(self) class Cookies(MutableMapping): diff --git a/httpx/_utils.py b/httpx/_utils.py index 072db3f1e8..99399b64da 100644 --- a/httpx/_utils.py +++ b/httpx/_utils.py @@ -9,6 +9,7 @@ import time import typing import warnings +from contextlib import contextmanager from pathlib import Path from urllib.request import getproxies @@ -19,6 +20,8 @@ if typing.TYPE_CHECKING: # pragma: no cover from ._models import URL +T = typing.TypeVar("T") + _HTML5_FORM_ENCODING_REPLACEMENTS = {'"': "%22", "\\": "\\\\"} _HTML5_FORM_ENCODING_REPLACEMENTS.update( @@ -539,3 +542,32 @@ def __eq__(self, other: typing.Any) -> bool: def warn_deprecated(message: str) -> None: # pragma: nocover warnings.warn(message, DeprecationWarning, stacklevel=2) + + +@contextmanager +def ensure_context_manager( + value: typing.Union[T, typing.ContextManager[T]] +) -> typing.Iterator[T]: + if isinstance(value, typing.ContextManager): + with value as val: + yield val + else: + yield value + + +# mypy isn't able to resolve generics when using @asynccontextmanager here, but we'd +# *really* like it to resolve generics. +class ensure_async_context_manager(typing.AsyncContextManager[T]): + def __init__( + self, value: typing.Union[typing.Awaitable[T], typing.AsyncContextManager[T]] + ) -> None: + self._value = value + + async def __aenter__(self) -> T: + if isinstance(self._value, typing.AsyncContextManager): + return await self._value.__aenter__() + return await self._value + + async def __aexit__(self, *args: typing.Any) -> None: + if isinstance(self._value, typing.AsyncContextManager): + await self._value.__aexit__(*args) diff --git a/tests/client/test_async_client.py b/tests/client/test_async_client.py index 44ff90fe51..a7b1414168 100644 --- a/tests/client/test_async_client.py +++ b/tests/client/test_async_client.py @@ -42,7 +42,8 @@ async def test_build_request(server): async with httpx.AsyncClient() as client: request = client.build_request("GET", url) request.headers.update(headers) - response = await client.send(request) + async with client.send(request) as response: + await response.aread() assert response.status_code == 200 assert response.url == url diff --git a/tests/client/test_client.py b/tests/client/test_client.py index a41f4232fb..345433727c 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -45,7 +45,8 @@ def test_build_request(server): with httpx.Client() as client: request = client.build_request("GET", url) request.headers.update(headers) - response = client.send(request) + with client.send(request) as response: + response.read() assert response.status_code == 200 assert response.url == url @@ -60,7 +61,8 @@ def test_build_post_request(server): with httpx.Client() as client: request = client.build_request("POST", url) request.headers.update(headers) - response = client.send(request) + with client.send(request) as response: + response.read() assert response.status_code == 200 assert response.url == url diff --git a/tests/client/test_event_hooks.py b/tests/client/test_event_hooks.py index a81f31e1e5..4f8efe19c7 100644 --- a/tests/client/test_event_hooks.py +++ b/tests/client/test_event_hooks.py @@ -54,10 +54,8 @@ def raise_on_4xx_5xx(response): event_hooks = {"response": [raise_on_4xx_5xx]} with httpx.Client(event_hooks=event_hooks, transport=MockTransport(app)) as http: - try: + with pytest.raises(httpx.HTTPStatusError): http.get("http://127.0.0.1:8000/status/400") - except httpx.HTTPStatusError as exc: - assert exc.response.is_closed @pytest.mark.usefixtures("async_environment") @@ -106,10 +104,8 @@ async def raise_on_4xx_5xx(response): async with httpx.AsyncClient( event_hooks=event_hooks, transport=MockTransport(app) ) as http: - try: + with pytest.raises(httpx.HTTPStatusError): await http.get("http://127.0.0.1:8000/status/400") - except httpx.HTTPStatusError as exc: - assert exc.response.is_closed def test_event_hooks_with_redirect(): diff --git a/tests/client/test_redirects.py b/tests/client/test_redirects.py index df43f53291..61f80a400d 100644 --- a/tests/client/test_redirects.py +++ b/tests/client/test_redirects.py @@ -142,12 +142,14 @@ def test_redirect_303(): def test_next_request(): client = httpx.Client(transport=MockTransport(redirects)) request = client.build_request("POST", "https://example.org/redirect_303") - response = client.send(request, allow_redirects=False) + with client.send(request, allow_redirects=False) as response: + response.read() assert response.status_code == httpx.codes.SEE_OTHER assert response.url == "https://example.org/redirect_303" assert response.next_request is not None - response = client.send(response.next_request, allow_redirects=False) + with client.send(response.next_request, allow_redirects=False) as response: + response.read() assert response.status_code == httpx.codes.OK assert response.url == "https://example.org/" assert response.next_request is None @@ -157,12 +159,14 @@ def test_next_request(): async def test_async_next_request(): client = httpx.AsyncClient(transport=MockTransport(redirects)) request = client.build_request("POST", "https://example.org/redirect_303") - response = await client.send(request, allow_redirects=False) + async with client.send(request, allow_redirects=False) as response: + await response.aread() assert response.status_code == httpx.codes.SEE_OTHER assert response.url == "https://example.org/redirect_303" assert response.next_request is not None - response = await client.send(response.next_request, allow_redirects=False) + async with client.send(response.next_request, allow_redirects=False) as response: + await response.aread() assert response.status_code == httpx.codes.OK assert response.url == "https://example.org/" assert response.next_request is None @@ -307,7 +311,7 @@ def test_can_stream_if_no_redirect(): client = httpx.Client(transport=MockTransport(redirects)) url = "https://example.org/redirect_301" with client.stream("GET", url, allow_redirects=False) as response: - assert not response.is_closed + pass assert response.status_code == httpx.codes.MOVED_PERMANENTLY assert response.headers["location"] == "https://example.org/" diff --git a/tests/models/test_responses.py b/tests/models/test_responses.py index ef26beda09..6bad0e0ecd 100644 --- a/tests/models/test_responses.py +++ b/tests/models/test_responses.py @@ -272,13 +272,11 @@ def test_read(): assert response.status_code == 200 assert response.text == "Hello, world!" assert response.encoding is None - assert response.is_closed content = response.read() assert content == b"Hello, world!" assert response.content == b"Hello, world!" - assert response.is_closed def test_empty_read(): @@ -287,13 +285,11 @@ def test_empty_read(): assert response.status_code == 200 assert response.text == "" assert response.encoding is None - assert response.is_closed content = response.read() assert content == b"" assert response.content == b"" - assert response.is_closed @pytest.mark.asyncio @@ -306,13 +302,11 @@ async def test_aread(): assert response.status_code == 200 assert response.text == "Hello, world!" assert response.encoding is None - assert response.is_closed content = await response.aread() assert content == b"Hello, world!" assert response.content == b"Hello, world!" - assert response.is_closed @pytest.mark.asyncio @@ -322,13 +316,11 @@ async def test_empty_aread(): assert response.status_code == 200 assert response.text == "" assert response.encoding is None - assert response.is_closed content = await response.aread() assert content == b"" assert response.content == b"" - assert response.is_closed def test_iter_raw(): @@ -487,13 +479,11 @@ def test_sync_streaming_response(): ) assert response.status_code == 200 - assert not response.is_closed content = response.read() assert content == b"Hello, world!" assert response.content == b"Hello, world!" - assert response.is_closed @pytest.mark.asyncio @@ -504,13 +494,11 @@ async def test_async_streaming_response(): ) assert response.status_code == 200 - assert not response.is_closed content = await response.aread() assert content == b"Hello, world!" assert response.content == b"Hello, world!" - assert response.is_closed def test_cannot_read_after_stream_consumed(): @@ -542,29 +530,6 @@ async def test_cannot_aread_after_stream_consumed(): await response.aread() -def test_cannot_read_after_response_closed(): - response = httpx.Response( - 200, - content=streaming_body(), - ) - - response.close() - with pytest.raises(httpx.ResponseClosed): - response.read() - - -@pytest.mark.asyncio -async def test_cannot_aread_after_response_closed(): - response = httpx.Response( - 200, - content=async_streaming_body(), - ) - - await response.aclose() - with pytest.raises(httpx.ResponseClosed): - await response.aread() - - @pytest.mark.asyncio async def test_elapsed_not_available_until_closed(): response = httpx.Response( From df2fa1cd7814351da788441802e2fb4ecfa71a94 Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Thu, 8 Oct 2020 20:10:22 +0200 Subject: [PATCH 2/3] Coverage, py36 backports --- httpx/_utils.py | 6 +++--- setup.py | 2 ++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/httpx/_utils.py b/httpx/_utils.py index 99399b64da..5ab5dcca69 100644 --- a/httpx/_utils.py +++ b/httpx/_utils.py @@ -548,7 +548,7 @@ def warn_deprecated(message: str) -> None: # pragma: nocover def ensure_context_manager( value: typing.Union[T, typing.ContextManager[T]] ) -> typing.Iterator[T]: - if isinstance(value, typing.ContextManager): + if isinstance(value, typing.ContextManager): # pragma: no cover with value as val: yield val else: @@ -564,10 +564,10 @@ def __init__( self._value = value async def __aenter__(self) -> T: - if isinstance(self._value, typing.AsyncContextManager): + if isinstance(self._value, typing.AsyncContextManager): # pragma: no cover return await self._value.__aenter__() return await self._value async def __aexit__(self, *args: typing.Any) -> None: - if isinstance(self._value, typing.AsyncContextManager): + if isinstance(self._value, typing.AsyncContextManager): # pragma: no cover await self._value.__aexit__(*args) diff --git a/setup.py b/setup.py index 075673ee6f..1b4fc49660 100644 --- a/setup.py +++ b/setup.py @@ -59,6 +59,8 @@ def get_packages(package): "sniffio", "rfc3986[idna2008]>=1.3,<2", "httpcore==0.12.*", + # Backports. + "async_generator; python_version<'3.7'", ], extras_require={ "http2": "h2==3.*", From 6721b265603e82e05650dd6a245c34a41088d2e0 Mon Sep 17 00:00:00 2001 From: Florimond Manca Date: Mon, 1 Mar 2021 12:10:59 +0100 Subject: [PATCH 3/3] Bring back Response.close/aclose --- docs/exceptions.md | 4 ++++ httpx/__init__.py | 2 ++ httpx/_client.py | 2 ++ httpx/_exceptions.py | 15 +++++++++++++++ httpx/_models.py | 22 ++++++++++++++++++++++ 5 files changed, 45 insertions(+) diff --git a/docs/exceptions.md b/docs/exceptions.md index 144165ab86..949ac47a19 100644 --- a/docs/exceptions.md +++ b/docs/exceptions.md @@ -81,6 +81,7 @@ except httpx.HTTPStatusError as exc: * StreamConsumed * ResponseNotRead * RequestNotRead + * ResponseClosed --- @@ -166,3 +167,6 @@ except httpx.HTTPStatusError as exc: ::: httpx.RequestNotRead :docstring: + +::: httpx.ResponseClosed + :docstring: diff --git a/httpx/__init__.py b/httpx/__init__.py index a356f4cc5e..96d9e0c2f8 100644 --- a/httpx/__init__.py +++ b/httpx/__init__.py @@ -22,6 +22,7 @@ RemoteProtocolError, RequestError, RequestNotRead, + ResponseClosed, ResponseNotRead, StreamConsumed, StreamError, @@ -87,6 +88,7 @@ "RequestError", "RequestNotRead", "Response", + "ResponseClosed", "ResponseNotRead", "StatusCode", "stream", diff --git a/httpx/_client.py b/httpx/_client.py index 8e94e6a8d1..7df8570d9a 100644 --- a/httpx/_client.py +++ b/httpx/_client.py @@ -848,6 +848,8 @@ def _send_handling_redirects( yield response break + response.close() + @contextmanager def _send_single_request( self, request: Request, timeout: Timeout diff --git a/httpx/_exceptions.py b/httpx/_exceptions.py index 454631fc06..bade9f9b81 100644 --- a/httpx/_exceptions.py +++ b/httpx/_exceptions.py @@ -29,6 +29,7 @@ x StreamConsumed x ResponseNotRead x RequestNotRead + x ResponseClosed """ import contextlib import typing @@ -302,6 +303,20 @@ def __init__(self) -> None: super().__init__(message) +class ResponseClosed(StreamError): + """ + Attempted to read or stream response content, but the request has been + closed. + """ + + def __init__(self) -> None: + message = ( + "Attempted to read or stream response content, but the request has " + "been closed." + ) + super().__init__(message) + + @contextlib.contextmanager def map_exceptions( mapping: typing.Mapping[typing.Type[Exception], typing.Type[Exception]], diff --git a/httpx/_models.py b/httpx/_models.py index 637954c738..5c76dc581e 100644 --- a/httpx/_models.py +++ b/httpx/_models.py @@ -30,6 +30,7 @@ HTTPStatusError, InvalidURL, RequestNotRead, + ResponseClosed, ResponseNotRead, StreamConsumed, map_exceptions, @@ -918,6 +919,7 @@ def __init__( self.ext = {} if ext is None else ext self.history = [] if history is None else list(history) + self.is_closed = False self.is_stream_consumed = False if stream is not None: @@ -1217,6 +1219,8 @@ def iter_raw(self, chunk_size: int = None) -> typing.Iterator[bytes]: """ if self.is_stream_consumed: raise StreamConsumed() + if self.is_closed: + raise ResponseClosed() if not isinstance(self.stream, typing.Iterable): raise RuntimeError("Attempted to call a sync iterator on an async stream.") @@ -1233,6 +1237,14 @@ def iter_raw(self, chunk_size: int = None) -> typing.Iterator[bytes]: for chunk in chunker.flush(): yield chunk + def close(self) -> None: + """ + Mark the response as closed. + Automatically called if the response body is read to completion. + """ + if not self.is_closed: + self.is_closed = True + async def aread(self) -> bytes: """ Read and return the response content. @@ -1298,6 +1310,8 @@ async def aiter_raw(self, chunk_size: int = None) -> typing.AsyncIterator[bytes] """ if self.is_stream_consumed: raise StreamConsumed() + if self.is_closed: + raise ResponseClosed() if not isinstance(self.stream, typing.AsyncIterable): raise RuntimeError("Attempted to call a async iterator on a sync stream.") @@ -1314,6 +1328,14 @@ async def aiter_raw(self, chunk_size: int = None) -> typing.AsyncIterator[bytes] for chunk in chunker.flush(): yield chunk + async def aclose(self) -> None: + """ + Mark the response as closed. + Automatically called if the response body is read to completion. + """ + if not self.is_closed: + self.is_closed = True + class Cookies(MutableMapping): """