diff --git a/README.md b/README.md index 5178ada6..3ced129f 100644 --- a/README.md +++ b/README.md @@ -43,16 +43,13 @@ Here's an example of making an HTTP GET request using `httpcore`... ```python with httpcore.SyncConnectionPool() as http: - status_code, headers, stream, ext = http.request( + with http.request( method=b'GET', url=(b'https', b'example.org', 443, b'/'), headers=[(b'host', b'example.org'), (b'user-agent', 'httpcore')] - ) - - try: + ) as response: + status_code, headers, stream, ext = response body = b''.join([chunk for chunk in stream]) - finally: - stream.close() print(status_code, body) ``` @@ -61,16 +58,13 @@ Or, using async... ```python async with httpcore.AsyncConnectionPool() as http: - status_code, headers, stream, ext = await http.arequest( + async with http.arequest( method=b'GET', url=(b'https', b'example.org', 443, b'/'), headers=[(b'host', b'example.org'), (b'user-agent', 'httpcore')] - ) - - try: + ) as response: + status_code, headers, stream, ext = response body = b''.join([chunk async for chunk in stream]) - finally: - await stream.aclose() print(status_code, body) ``` diff --git a/docs/api.md b/docs/api.md index 3bbde423..40b075f2 100644 --- a/docs/api.md +++ b/docs/api.md @@ -2,55 +2,38 @@ ## Async API Overview -The `AsyncHTTPTransport` and `AsyncByteStream` classes provide the base -interface which transport classes need to implement. +The `AsyncHTTPTransport` class provides the base interface which transport classes need to implement. ::: httpcore.AsyncHTTPTransport :docstring: :members: arequest aclose -::: httpcore.AsyncByteStream - :docstring: - :members: __aiter__ aclose - The `AsyncConnectionPool` class is a concrete implementation of `AsyncHTTPTransport`. ::: httpcore.AsyncConnectionPool :docstring: - -The `PlainByteStream` and `AsyncIteratorByteStream` classes are concrete implementations of `AsyncByteStream`. - -::: httpcore.PlainByteStream - :docstring: - -::: httpcore.AsyncIteratorByteStream - :docstring: - --- ## Sync API Overview -The `SyncHTTPTransport` and `SyncByteStream` classes provide the base -interface which transport classes need to implement. +The `SyncHTTPTransport` class provides the base interface which transport classes need to implement. ::: httpcore.SyncHTTPTransport :docstring: :members: request close -::: httpcore.SyncByteStream - :docstring: - :members: __iter__ close - The `SyncConnectionPool` class is a concrete implementation of `SyncHTTPTransport`. ::: httpcore.SyncConnectionPool :docstring: -The `PlainByteStream` and `IteratorByteStream` classes are concrete implementations of `SyncByteStream`. +--- + +## Utilities -::: httpcore.PlainByteStream - :docstring: +The `PlainByteStream` can be used to return a bytestring with both bytes iterable +and async bytes iterable iterfaces. -::: httpcore.IteratorByteStream +::: httpcore.PlainByteStream :docstring: diff --git a/docs/index.md b/docs/index.md index 5178ada6..3ced129f 100644 --- a/docs/index.md +++ b/docs/index.md @@ -43,16 +43,13 @@ Here's an example of making an HTTP GET request using `httpcore`... ```python with httpcore.SyncConnectionPool() as http: - status_code, headers, stream, ext = http.request( + with http.request( method=b'GET', url=(b'https', b'example.org', 443, b'/'), headers=[(b'host', b'example.org'), (b'user-agent', 'httpcore')] - ) - - try: + ) as response: + status_code, headers, stream, ext = response body = b''.join([chunk for chunk in stream]) - finally: - stream.close() print(status_code, body) ``` @@ -61,16 +58,13 @@ Or, using async... ```python async with httpcore.AsyncConnectionPool() as http: - status_code, headers, stream, ext = await http.arequest( + async with http.arequest( method=b'GET', url=(b'https', b'example.org', 443, b'/'), headers=[(b'host', b'example.org'), (b'user-agent', 'httpcore')] - ) - - try: + ) as response: + status_code, headers, stream, ext = response body = b''.join([chunk async for chunk in stream]) - finally: - await stream.aclose() print(status_code, body) ``` diff --git a/httpcore/__init__.py b/httpcore/__init__.py index 4cc7c1b4..fd7776e3 100644 --- a/httpcore/__init__.py +++ b/httpcore/__init__.py @@ -1,7 +1,7 @@ -from ._async.base import AsyncByteStream, AsyncHTTPTransport +from ._async.base import AsyncHTTPTransport from ._async.connection_pool import AsyncConnectionPool from ._async.http_proxy import AsyncHTTPProxy -from ._bytestreams import AsyncIteratorByteStream, IteratorByteStream, PlainByteStream +from ._bytestreams import PlainByteStream from ._exceptions import ( CloseError, ConnectError, @@ -19,20 +19,17 @@ WriteError, WriteTimeout, ) -from ._sync.base import SyncByteStream, SyncHTTPTransport +from ._sync.base import SyncHTTPTransport from ._sync.connection_pool import SyncConnectionPool from ._sync.http_proxy import SyncHTTPProxy __all__ = [ - "AsyncByteStream", "AsyncConnectionPool", "AsyncHTTPProxy", "AsyncHTTPTransport", - "AsyncIteratorByteStream", "CloseError", "ConnectError", "ConnectTimeout", - "IteratorByteStream", "LocalProtocolError", "NetworkError", "PlainByteStream", @@ -42,7 +39,6 @@ "ReadError", "ReadTimeout", "RemoteProtocolError", - "SyncByteStream", "SyncConnectionPool", "SyncHTTPProxy", "SyncHTTPTransport", diff --git a/httpcore/_async/base.py b/httpcore/_async/base.py index cf449f42..024af43d 100644 --- a/httpcore/_async/base.py +++ b/httpcore/_async/base.py @@ -1,6 +1,6 @@ import enum from types import TracebackType -from typing import AsyncIterator, Tuple, Type +from typing import AsyncContextManager, AsyncIterable, Tuple, Type from .._types import URL, Headers, T @@ -32,43 +32,22 @@ class ConnectionState(enum.IntEnum): CLOSED = 5 # Connection closed. -class AsyncByteStream: - """ - The base interface for request and response bodies. - - Concrete implementations should subclass this class, and implement - the `\\__aiter__` method, and optionally the `aclose` method. - """ - - async def __aiter__(self) -> AsyncIterator[bytes]: - """ - Yield bytes representing the request or response body. - """ - yield b"" # pragma: nocover - - async def aclose(self) -> None: - """ - Must be called by the client to indicate that the stream has been closed. - """ - pass # pragma: nocover - - class AsyncHTTPTransport: """ The base interface for sending HTTP requests. - Concete implementations should subclass this class, and implement + Concrete implementations should subclass this class, and implement the `request` method, and optionally the `close` method. """ - async def arequest( + def arequest( self, method: bytes, url: URL, headers: Headers = None, - stream: AsyncByteStream = None, + stream: AsyncIterable[bytes] = None, ext: dict = None, - ) -> Tuple[int, Headers, AsyncByteStream, dict]: + ) -> AsyncContextManager[Tuple[int, Headers, AsyncIterable[bytes], dict]]: """ The interface for sending a single HTTP request, and returning a response. @@ -79,17 +58,17 @@ async def arequest( of (scheme, host, port, path). * **headers** - `Optional[List[Tuple[bytes, bytes]]]` - Any HTTP headers to send with the request. - * **stream** - `Optional[AsyncByteStream]` - The body of the HTTP request. + * **stream** - `Optional[AsyncIterable[bytes]]` - The body of the HTTP request. * **ext** - `Optional[dict]` - A dictionary of optional extensions. ** Returns:** - A four-tuple of: + A context manager yielding a four-tuple of: * **status_code** - `int` - The HTTP status code, such as `200`. * **headers** - `List[Tuple[bytes, bytes]]` - Any HTTP headers included on the response. - * **stream** - `AsyncByteStream` - The body of the HTTP response. + * **stream** - `AsyncIterable[bytes]` - The body of the HTTP response. * **ext** - `dict` - A dictionary of optional extensions. """ raise NotImplementedError() # pragma: nocover diff --git a/httpcore/_async/connection.py b/httpcore/_async/connection.py index 530663d8..67936d27 100644 --- a/httpcore/_async/connection.py +++ b/httpcore/_async/connection.py @@ -1,16 +1,12 @@ from ssl import SSLContext -from typing import Optional, Tuple, cast +from typing import AsyncIterable, AsyncIterator, Optional, Tuple, cast from .._backends.auto import AsyncBackend, AsyncLock, AsyncSocketStream, AutoBackend +from .._compat import asynccontextmanager from .._exceptions import ConnectError, ConnectTimeout from .._types import URL, Headers, Origin, TimeoutDict from .._utils import exponential_backoff, get_logger, url_to_origin -from .base import ( - AsyncByteStream, - AsyncHTTPTransport, - ConnectionState, - NewConnectionRequired, -) +from .base import AsyncHTTPTransport, ConnectionState, NewConnectionRequired from .http import AsyncBaseHTTPConnection from .http11 import AsyncHTTP11Connection @@ -72,14 +68,15 @@ def request_lock(self) -> AsyncLock: self._request_lock = self.backend.create_lock() return self._request_lock + @asynccontextmanager async def arequest( self, method: bytes, url: URL, headers: Headers = None, - stream: AsyncByteStream = None, + stream: AsyncIterable[bytes] = None, ext: dict = None, - ) -> Tuple[int, Headers, AsyncByteStream, dict]: + ) -> AsyncIterator[Tuple[int, Headers, AsyncIterable[bytes], dict]]: assert url_to_origin(url) == self.origin ext = {} if ext is None else ext timeout = cast(TimeoutDict, ext.get("timeout", {})) @@ -103,7 +100,10 @@ async def arequest( logger.trace( "connection.arequest method=%r url=%r headers=%r", method, url, headers ) - return await self.connection.arequest(method, url, headers, stream, ext) + async with self.connection.arequest( + method, url, headers, stream, ext + ) as response: + yield response async def _open_socket(self, timeout: TimeoutDict = None) -> AsyncSocketStream: scheme, hostname, port = self.origin diff --git a/httpcore/_async/connection_pool.py b/httpcore/_async/connection_pool.py index 46ede0ba..7747a35e 100644 --- a/httpcore/_async/connection_pool.py +++ b/httpcore/_async/connection_pool.py @@ -1,8 +1,8 @@ import warnings from ssl import SSLContext from typing import ( + AsyncIterable, AsyncIterator, - Callable, Dict, List, Optional, @@ -14,16 +14,12 @@ from .._backends.auto import AsyncBackend, AsyncLock, AsyncSemaphore from .._backends.base import lookup_async_backend +from .._compat import asynccontextmanager from .._exceptions import LocalProtocolError, PoolTimeout, UnsupportedProtocol from .._threadlock import ThreadLock from .._types import URL, Headers, Origin, TimeoutDict from .._utils import get_logger, origin_to_url_string, url_to_origin -from .base import ( - AsyncByteStream, - AsyncHTTPTransport, - ConnectionState, - NewConnectionRequired, -) +from .base import AsyncHTTPTransport, ConnectionState, NewConnectionRequired from .connection import AsyncHTTPConnection logger = get_logger(__name__) @@ -40,39 +36,6 @@ async def release(self) -> None: return -class ResponseByteStream(AsyncByteStream): - def __init__( - self, - stream: AsyncByteStream, - connection: AsyncHTTPConnection, - callback: Callable, - ) -> None: - """ - A wrapper around the response stream that we return from `.arequest()`. - - Ensures that when `stream.aclose()` is called, the connection pool - is notified via a callback. - """ - self.stream = stream - self.connection = connection - self.callback = callback - - async def __aiter__(self) -> AsyncIterator[bytes]: - async for chunk in self.stream: - yield chunk - - async def aclose(self) -> None: - try: - # Call the underlying stream close callback. - # This will be a call to `AsyncHTTP11Connection._response_closed()` - # or `AsyncHTTP2Stream._response_closed()`. - await self.stream.aclose() - finally: - # Call the connection pool close callback. - # This will be a call to `AsyncConnectionPool._response_closed()`. - await self.callback(self.connection) - - class AsyncConnectionPool(AsyncHTTPTransport): """ A connection pool for making HTTP requests. @@ -178,14 +141,15 @@ def _create_connection( backend=self._backend, ) + @asynccontextmanager async def arequest( self, method: bytes, url: URL, headers: Headers = None, - stream: AsyncByteStream = None, + stream: AsyncIterable[bytes] = None, ext: dict = None, - ) -> Tuple[int, Headers, AsyncByteStream, dict]: + ) -> AsyncIterator[Tuple[int, Headers, AsyncIterable[bytes], dict]]: if url[0] not in (b"http", b"https"): scheme = url[0].decode("latin-1") raise UnsupportedProtocol(f"Unsupported URL protocol {scheme!r}") @@ -215,9 +179,10 @@ async def arequest( logger.trace("reuse connection=%r", connection) try: - response = await connection.arequest( + async with connection.arequest( method, url, headers=headers, stream=stream, ext=ext - ) + ) as response: + yield response except NewConnectionRequired: connection = None except Exception: # noqa: PIE786 @@ -225,11 +190,7 @@ async def arequest( await self._remove_from_pool(connection) raise - status_code, headers, stream, ext = response - wrapped_stream = ResponseByteStream( - stream, connection=connection, callback=self._response_closed - ) - return status_code, headers, wrapped_stream, ext + await self._response_closed(connection) async def _get_connection_from_pool( self, origin: Origin diff --git a/httpcore/_async/http11.py b/httpcore/_async/http11.py index ffa3fa8b..25aeb006 100644 --- a/httpcore/_async/http11.py +++ b/httpcore/_async/http11.py @@ -1,14 +1,15 @@ from ssl import SSLContext -from typing import AsyncIterator, List, Tuple, Union, cast +from typing import AsyncIterable, AsyncIterator, List, Tuple, Union, cast import h11 from .._backends.auto import AsyncSocketStream -from .._bytestreams import AsyncIteratorByteStream, PlainByteStream +from .._bytestreams import PlainByteStream +from .._compat import asynccontextmanager from .._exceptions import LocalProtocolError, RemoteProtocolError, map_exceptions from .._types import URL, Headers, TimeoutDict from .._utils import get_logger -from .base import AsyncByteStream, ConnectionState +from .base import ConnectionState from .http import AsyncBaseHTTPConnection H11Event = Union[ @@ -47,14 +48,15 @@ def mark_as_ready(self) -> None: if self.state == ConnectionState.IDLE: self.state = ConnectionState.READY + @asynccontextmanager async def arequest( self, method: bytes, url: URL, headers: Headers = None, - stream: AsyncByteStream = None, + stream: AsyncIterable[bytes] = None, ext: dict = None, - ) -> Tuple[int, Headers, AsyncByteStream, dict]: + ) -> AsyncIterator[Tuple[int, Headers, AsyncIterable[bytes], dict]]: headers = [] if headers is None else headers stream = PlainByteStream(b"") if stream is None else stream ext = {} if ext is None else ext @@ -70,15 +72,15 @@ async def arequest( reason_phrase, headers, ) = await self._receive_response(timeout) - response_stream = AsyncIteratorByteStream( - aiterator=self._receive_response_data(timeout), - aclose_func=self._response_closed, - ) + response_stream = self._receive_response_data(timeout) ext = { "http_version": http_version.decode("ascii", errors="ignore"), "reason": reason_phrase.decode("ascii", errors="ignore"), } - return (status_code, headers, response_stream, ext) + try: + yield (status_code, headers, response_stream, ext) + finally: + await self._response_closed() async def start_tls( self, hostname: bytes, timeout: TimeoutDict = None @@ -100,7 +102,7 @@ async def _send_request( await self._send_event(event, timeout) async def _send_request_body( - self, stream: AsyncByteStream, timeout: TimeoutDict + self, stream: AsyncIterable[bytes], timeout: TimeoutDict ) -> None: """ Send the request body. diff --git a/httpcore/_async/http2.py b/httpcore/_async/http2.py index 3c7404aa..97234c88 100644 --- a/httpcore/_async/http2.py +++ b/httpcore/_async/http2.py @@ -1,5 +1,5 @@ from ssl import SSLContext -from typing import AsyncIterator, Dict, List, Tuple, cast +from typing import AsyncIterable, AsyncIterator, Dict, List, Tuple, cast import h2.connection import h2.events @@ -8,11 +8,12 @@ from h2.settings import SettingCodes, Settings from .._backends.auto import AsyncBackend, AsyncLock, AsyncSemaphore, AsyncSocketStream -from .._bytestreams import AsyncIteratorByteStream, PlainByteStream +from .._bytestreams import PlainByteStream +from .._compat import asynccontextmanager from .._exceptions import PoolTimeout, RemoteProtocolError from .._types import URL, Headers, TimeoutDict from .._utils import get_logger -from .base import AsyncByteStream, ConnectionState, NewConnectionRequired +from .base import ConnectionState, NewConnectionRequired from .http import AsyncBaseHTTPConnection logger = get_logger(__name__) @@ -85,14 +86,15 @@ def mark_as_ready(self) -> None: if self.state == ConnectionState.IDLE: self.state = ConnectionState.READY + @asynccontextmanager async def arequest( self, method: bytes, url: URL, headers: Headers = None, - stream: AsyncByteStream = None, + stream: AsyncIterable[bytes] = None, ext: dict = None, - ) -> Tuple[int, Headers, AsyncByteStream, dict]: + ) -> AsyncIterator[Tuple[int, Headers, AsyncIterable[bytes], dict]]: ext = {} if ext is None else ext timeout = cast(TimeoutDict, ext.get("timeout", {})) @@ -103,8 +105,7 @@ async def arequest( await self.send_connection_init(timeout) self.sent_connection_init = True - await self.max_streams_semaphore.acquire() - try: + async with self.max_streams_semaphore: try: stream_id = self.h2_state.get_next_available_stream_id() except NoAvailableStreamIDError: @@ -116,10 +117,10 @@ async def arequest( h2_stream = AsyncHTTP2Stream(stream_id=stream_id, connection=self) self.streams[stream_id] = h2_stream self.events[stream_id] = [] - return await h2_stream.arequest(method, url, headers, stream, ext) - except Exception: # noqa: PIE786 - await self.max_streams_semaphore.release() - raise + async with h2_stream.arequest( + method, url, headers, stream, ext + ) as response: + yield response async def send_connection_init(self, timeout: TimeoutDict) -> None: """ @@ -251,18 +252,15 @@ async def acknowledge_received_data( await self.socket.write(data_to_send, timeout) async def close_stream(self, stream_id: int) -> None: - try: - logger.trace("close_stream stream_id=%r", stream_id) - del self.streams[stream_id] - del self.events[stream_id] - - if not self.streams: - if self.state == ConnectionState.ACTIVE: - self.state = ConnectionState.IDLE - elif self.state == ConnectionState.FULL: - await self.aclose() - finally: - await self.max_streams_semaphore.release() + logger.trace("close_stream stream_id=%r", stream_id) + del self.streams[stream_id] + del self.events[stream_id] + + if not self.streams: + if self.state == ConnectionState.ACTIVE: + self.state = ConnectionState.IDLE + elif self.state == ConnectionState.FULL: + await self.aclose() class AsyncHTTP2Stream: @@ -270,14 +268,15 @@ def __init__(self, stream_id: int, connection: AsyncHTTP2Connection) -> None: self.stream_id = stream_id self.connection = connection + @asynccontextmanager async def arequest( self, method: bytes, url: URL, headers: Headers = None, - stream: AsyncByteStream = None, + stream: AsyncIterable[bytes] = None, ext: dict = None, - ) -> Tuple[int, Headers, AsyncByteStream, dict]: + ) -> AsyncIterator[Tuple[int, Headers, AsyncIterable[bytes], dict]]: headers = [] if headers is None else [(k.lower(), v) for (k, v) in headers] stream = PlainByteStream(b"") if stream is None else stream ext = {} if ext is None else ext @@ -295,14 +294,16 @@ async def arequest( # Receive the response. status_code, headers = await self.receive_response(timeout) - response_stream = AsyncIteratorByteStream( - aiterator=self.body_iter(timeout), aclose_func=self._response_closed - ) + response_stream = self.body_iter(timeout) ext = { "http_version": "HTTP/2", } - return (status_code, headers, response_stream, ext) + + try: + yield (status_code, headers, response_stream, ext) + finally: + await self._response_closed() async def send_headers( self, @@ -349,7 +350,9 @@ async def send_headers( await self.connection.send_headers(self.stream_id, headers, end_stream, timeout) - async def send_body(self, stream: AsyncByteStream, timeout: TimeoutDict) -> None: + async def send_body( + self, stream: AsyncIterable[bytes], timeout: TimeoutDict + ) -> None: async for data in stream: while data: max_flow = await self.connection.wait_for_outgoing_flow( diff --git a/httpcore/_async/http_proxy.py b/httpcore/_async/http_proxy.py index d9df762b..031eb1fe 100644 --- a/httpcore/_async/http_proxy.py +++ b/httpcore/_async/http_proxy.py @@ -1,13 +1,13 @@ from http import HTTPStatus from ssl import SSLContext -from typing import Tuple, cast +from typing import AsyncIterable, AsyncIterator, Tuple, cast +from .._compat import AsyncExitStack, asynccontextmanager from .._exceptions import ProxyError from .._types import URL, Headers, TimeoutDict from .._utils import get_logger, url_to_origin -from .base import AsyncByteStream from .connection import AsyncHTTPConnection -from .connection_pool import AsyncConnectionPool, ResponseByteStream +from .connection_pool import AsyncConnectionPool logger = get_logger(__name__) @@ -87,14 +87,15 @@ def __init__( max_keepalive=max_keepalive, ) + @asynccontextmanager async def arequest( self, method: bytes, url: URL, headers: Headers = None, - stream: AsyncByteStream = None, + stream: AsyncIterable[bytes] = None, ext: dict = None, - ) -> Tuple[int, Headers, AsyncByteStream, dict]: + ) -> AsyncIterator[Tuple[int, Headers, AsyncIterable[bytes], dict]]: if self._keepalive_expiry is not None: await self._keepalive_sweep() @@ -109,9 +110,10 @@ async def arequest( method, url, ) - return await self._forward_request( + async with self._forward_request( method, url, headers=headers, stream=stream, ext=ext - ) + ) as response: + yield response else: # By default HTTPS should be tunnelled. logger.trace( @@ -121,18 +123,20 @@ async def arequest( method, url, ) - return await self._tunnel_request( + async with self._tunnel_request( method, url, headers=headers, stream=stream, ext=ext - ) + ) as response: + yield response + @asynccontextmanager async def _forward_request( self, method: bytes, url: URL, headers: Headers = None, - stream: AsyncByteStream = None, + stream: AsyncIterable[bytes] = None, ext: dict = None, - ) -> Tuple[int, Headers, AsyncByteStream, dict]: + ) -> AsyncIterator[Tuple[int, Headers, AsyncIterable[bytes], dict]]: """ Forwarded proxy requests include the entire URL as the HTTP target, rather than just the path. @@ -162,24 +166,23 @@ async def _forward_request( url = self.proxy_origin + (target,) headers = merge_headers(self.proxy_headers, headers) - (status_code, headers, stream, ext) = await connection.arequest( - method, url, headers=headers, stream=stream, ext=ext - ) - - wrapped_stream = ResponseByteStream( - stream, connection=connection, callback=self._response_closed - ) - - return status_code, headers, wrapped_stream, ext + try: + async with connection.arequest( + method, url, headers=headers, stream=stream, ext=ext + ) as response: + yield response + finally: + await self._response_closed(connection) + @asynccontextmanager async def _tunnel_request( self, method: bytes, url: URL, headers: Headers = None, - stream: AsyncByteStream = None, + stream: AsyncIterable[bytes] = None, ext: dict = None, - ) -> Tuple[int, Headers, AsyncByteStream, dict]: + ) -> AsyncIterator[Tuple[int, Headers, AsyncIterable[bytes], dict]]: """ Tunnelled proxy requests require an initial CONNECT request to establish the connection, and then send regular requests. @@ -189,82 +192,78 @@ async def _tunnel_request( origin = url_to_origin(url) connection = await self._get_connection_from_pool(origin) - if connection is None: - scheme, host, port = origin - - # First, create a connection to the proxy server - proxy_connection = AsyncHTTPConnection( - origin=self.proxy_origin, - http2=self._http2, - ssl_context=self._ssl_context, - ) - - # Issue a CONNECT request... - - # CONNECT www.example.org:80 HTTP/1.1 - # [proxy-headers] - target = b"%b:%d" % (host, port) - connect_url = self.proxy_origin + (target,) - connect_headers = [(b"Host", target), (b"Accept", b"*/*")] - connect_headers = merge_headers(connect_headers, self.proxy_headers) + async with AsyncExitStack() as exit_stack: + if connection is None: + scheme, host, port = origin - try: - ( - proxy_status_code, - _, - proxy_stream, - _, - ) = await proxy_connection.arequest( - b"CONNECT", connect_url, headers=connect_headers, ext=ext + # First, create a connection to the proxy server + proxy_connection = AsyncHTTPConnection( + origin=self.proxy_origin, + http2=self._http2, + ssl_context=self._ssl_context, ) - proxy_reason = get_reason_phrase(proxy_status_code) - logger.trace( - "tunnel_response proxy_status_code=%r proxy_reason=%r ", - proxy_status_code, - proxy_reason, + # Issue a CONNECT request... + + # CONNECT www.example.org:80 HTTP/1.1 + # [proxy-headers] + target = b"%b:%d" % (host, port) + connect_url = self.proxy_origin + (target,) + connect_headers = [(b"Host", target), (b"Accept", b"*/*")] + connect_headers = merge_headers(connect_headers, self.proxy_headers) + + try: + proxy_response = await exit_stack.enter_async_context( + proxy_connection.arequest( + b"CONNECT", connect_url, headers=connect_headers, ext=ext + ) + ) + proxy_status_code, _, proxy_stream, _ = proxy_response + proxy_reason = get_reason_phrase(proxy_status_code) + logger.trace( + "tunnel_response proxy_status_code=%r proxy_reason=%r ", + proxy_status_code, + proxy_reason, + ) + # Read the response data without closing the socket + async for _ in proxy_stream: + pass + + # See if the tunnel was successfully established. + if proxy_status_code < 200 or proxy_status_code > 299: + msg = "%d %s" % (proxy_status_code, proxy_reason) + raise ProxyError(msg) + + # Upgrade to TLS if required + # We assume the target speaks TLS on the specified port + if scheme == b"https": + await proxy_connection.start_tls(host, timeout) + except Exception as exc: + await proxy_connection.aclose() + raise ProxyError(exc) + + # The CONNECT request is successful, so we have now SWITCHED PROTOCOLS. + # This means the proxy connection is now unusable, and we must create + # a new one for regular requests, making sure to use the same socket to + # retain the tunnel. + connection = AsyncHTTPConnection( + origin=origin, + http2=self._http2, + ssl_context=self._ssl_context, + socket=proxy_connection.socket, ) - # Read the response data without closing the socket - async for _ in proxy_stream: - pass - - # See if the tunnel was successfully established. - if proxy_status_code < 200 or proxy_status_code > 299: - msg = "%d %s" % (proxy_status_code, proxy_reason) - raise ProxyError(msg) - - # Upgrade to TLS if required - # We assume the target speaks TLS on the specified port - if scheme == b"https": - await proxy_connection.start_tls(host, timeout) - except Exception as exc: - await proxy_connection.aclose() - raise ProxyError(exc) - - # The CONNECT request is successful, so we have now SWITCHED PROTOCOLS. - # This means the proxy connection is now unusable, and we must create - # a new one for regular requests, making sure to use the same socket to - # retain the tunnel. - connection = AsyncHTTPConnection( - origin=origin, - http2=self._http2, - ssl_context=self._ssl_context, - socket=proxy_connection.socket, - ) - await self._add_to_pool(connection, timeout) - - # Once the connection has been established we can send requests on - # it as normal. - (status_code, headers, stream, ext) = await connection.arequest( - method, - url, - headers=headers, - stream=stream, - ext=ext, - ) + await self._add_to_pool(connection, timeout) - wrapped_stream = ResponseByteStream( - stream, connection=connection, callback=self._response_closed - ) - - return status_code, headers, wrapped_stream, ext + # Once the connection has been established we can send requests on + # it as normal. + try: + async with connection.arequest( + method, + url, + headers=headers, + stream=stream, + ext=ext, + ) as response: + yield response + finally: + await self._response_closed(connection) diff --git a/httpcore/_backends/base.py b/httpcore/_backends/base.py index 1ca6e31b..a3027f07 100644 --- a/httpcore/_backends/base.py +++ b/httpcore/_backends/base.py @@ -96,6 +96,17 @@ class AsyncSemaphore: Abstracts away any asyncio-specific interfaces. """ + async def __aenter__(self) -> None: + await self.acquire() + + async def __aexit__( + self, + exc_type: Type[BaseException] = None, + exc_value: BaseException = None, + traceback: TracebackType = None, + ) -> None: + await self.release() + async def acquire(self, timeout: float = None) -> None: raise NotImplementedError() # pragma: no cover diff --git a/httpcore/_backends/sync.py b/httpcore/_backends/sync.py index 25e38ed0..92fde403 100644 --- a/httpcore/_backends/sync.py +++ b/httpcore/_backends/sync.py @@ -109,6 +109,17 @@ def __init__(self, max_value: int, exc_class: type) -> None: self.exc_class = exc_class self._semaphore = threading.Semaphore(max_value) + def __enter__(self) -> None: + self.acquire() + + def __exit__( + self, + exc_type: Type[BaseException] = None, + exc_value: BaseException = None, + traceback: TracebackType = None, + ) -> None: + self.release() + def acquire(self, timeout: float = None) -> None: if not self._semaphore.acquire(timeout=timeout): # type: ignore raise self.exc_class() diff --git a/httpcore/_bytestreams.py b/httpcore/_bytestreams.py index e938aaf9..5eeba2ee 100644 --- a/httpcore/_bytestreams.py +++ b/httpcore/_bytestreams.py @@ -1,10 +1,7 @@ -from typing import AsyncIterator, Callable, Iterator +from typing import AsyncIterator, Iterator -from ._async.base import AsyncByteStream -from ._sync.base import SyncByteStream - -class PlainByteStream(AsyncByteStream, SyncByteStream): +class PlainByteStream: """ A concrete implementation for either sync or async byte streams. Just handles a plain byte string as the content of the stream. @@ -22,57 +19,3 @@ def __iter__(self) -> Iterator[bytes]: async def __aiter__(self) -> AsyncIterator[bytes]: yield self._content - - -class IteratorByteStream(SyncByteStream): - """ - A concrete implementation for sync byte streams. - Handles a byte iterator as the content of the stream. - - ``` - def generate_content(): - ... - - stream = httpcore.IteratorByteStream(generate_content()) - ``` - """ - - def __init__(self, iterator: Iterator[bytes], close_func: Callable = None) -> None: - self._iterator = iterator - self._close_func = close_func - - def __iter__(self) -> Iterator[bytes]: - for chunk in self._iterator: - yield chunk - - def close(self) -> None: - if self._close_func is not None: - self._close_func() - - -class AsyncIteratorByteStream(AsyncByteStream): - """ - A concrete implementation for async byte streams. - Handles an async byte iterator as the content of the stream. - - ``` - async def generate_content(): - ... - - stream = httpcore.AsyncIteratorByteStream(generate_content()) - ``` - """ - - def __init__( - self, aiterator: AsyncIterator[bytes], aclose_func: Callable = None - ) -> None: - self._aiterator = aiterator - self._aclose_func = aclose_func - - async def __aiter__(self) -> AsyncIterator[bytes]: - async for chunk in self._aiterator: - yield chunk - - async def aclose(self) -> None: - if self._aclose_func is not None: - await self._aclose_func() diff --git a/httpcore/_compat.py b/httpcore/_compat.py new file mode 100644 index 00000000..3191536f --- /dev/null +++ b/httpcore/_compat.py @@ -0,0 +1,9 @@ +try: + from contextlib import AsyncExitStack, asynccontextmanager # type: ignore # Py3.6 +except ImportError: # pragma: no cover + # Python 3.6 + from async_exit_stack import AsyncExitStack # type: ignore # noqa: F401 + from async_generator import asynccontextmanager # type: ignore # noqa: F401 + +# These will be imported by the unasynced code. +from contextlib import ExitStack, contextmanager # noqa: F401 diff --git a/httpcore/_sync/base.py b/httpcore/_sync/base.py index 95a434eb..5e6e4fca 100644 --- a/httpcore/_sync/base.py +++ b/httpcore/_sync/base.py @@ -1,6 +1,6 @@ import enum from types import TracebackType -from typing import Iterator, Tuple, Type +from typing import ContextManager, Iterable, Tuple, Type from .._types import URL, Headers, T @@ -32,32 +32,11 @@ class ConnectionState(enum.IntEnum): CLOSED = 5 # Connection closed. -class SyncByteStream: - """ - The base interface for request and response bodies. - - Concrete implementations should subclass this class, and implement - the `\\__iter__` method, and optionally the `close` method. - """ - - def __iter__(self) -> Iterator[bytes]: - """ - Yield bytes representing the request or response body. - """ - yield b"" # pragma: nocover - - def close(self) -> None: - """ - Must be called by the client to indicate that the stream has been closed. - """ - pass # pragma: nocover - - class SyncHTTPTransport: """ The base interface for sending HTTP requests. - Concete implementations should subclass this class, and implement + Concrete implementations should subclass this class, and implement the `request` method, and optionally the `close` method. """ @@ -66,9 +45,9 @@ def request( method: bytes, url: URL, headers: Headers = None, - stream: SyncByteStream = None, + stream: Iterable[bytes] = None, ext: dict = None, - ) -> Tuple[int, Headers, SyncByteStream, dict]: + ) -> ContextManager[Tuple[int, Headers, Iterable[bytes], dict]]: """ The interface for sending a single HTTP request, and returning a response. @@ -79,17 +58,17 @@ def request( of (scheme, host, port, path). * **headers** - `Optional[List[Tuple[bytes, bytes]]]` - Any HTTP headers to send with the request. - * **stream** - `Optional[SyncByteStream]` - The body of the HTTP request. + * **stream** - `Optional[Iterable[bytes]]` - The body of the HTTP request. * **ext** - `Optional[dict]` - A dictionary of optional extensions. ** Returns:** - A four-tuple of: + A context manager yielding a four-tuple of: * **status_code** - `int` - The HTTP status code, such as `200`. * **headers** - `List[Tuple[bytes, bytes]]` - Any HTTP headers included on the response. - * **stream** - `SyncByteStream` - The body of the HTTP response. + * **stream** - `Iterable[bytes]` - The body of the HTTP response. * **ext** - `dict` - A dictionary of optional extensions. """ raise NotImplementedError() # pragma: nocover diff --git a/httpcore/_sync/connection.py b/httpcore/_sync/connection.py index 04042227..2b1bbdc6 100644 --- a/httpcore/_sync/connection.py +++ b/httpcore/_sync/connection.py @@ -1,16 +1,12 @@ from ssl import SSLContext -from typing import Optional, Tuple, cast +from typing import Iterable, Iterator, Optional, Tuple, cast from .._backends.sync import SyncBackend, SyncLock, SyncSocketStream, SyncBackend +from .._compat import contextmanager from .._exceptions import ConnectError, ConnectTimeout from .._types import URL, Headers, Origin, TimeoutDict from .._utils import exponential_backoff, get_logger, url_to_origin -from .base import ( - SyncByteStream, - SyncHTTPTransport, - ConnectionState, - NewConnectionRequired, -) +from .base import SyncHTTPTransport, ConnectionState, NewConnectionRequired from .http import SyncBaseHTTPConnection from .http11 import SyncHTTP11Connection @@ -72,14 +68,15 @@ def request_lock(self) -> SyncLock: self._request_lock = self.backend.create_lock() return self._request_lock + @contextmanager def request( self, method: bytes, url: URL, headers: Headers = None, - stream: SyncByteStream = None, + stream: Iterable[bytes] = None, ext: dict = None, - ) -> Tuple[int, Headers, SyncByteStream, dict]: + ) -> Iterator[Tuple[int, Headers, Iterable[bytes], dict]]: assert url_to_origin(url) == self.origin ext = {} if ext is None else ext timeout = cast(TimeoutDict, ext.get("timeout", {})) @@ -103,7 +100,10 @@ def request( logger.trace( "connection.request method=%r url=%r headers=%r", method, url, headers ) - return self.connection.request(method, url, headers, stream, ext) + with self.connection.request( + method, url, headers, stream, ext + ) as response: + yield response def _open_socket(self, timeout: TimeoutDict = None) -> SyncSocketStream: scheme, hostname, port = self.origin diff --git a/httpcore/_sync/connection_pool.py b/httpcore/_sync/connection_pool.py index 4702184b..129ca0e5 100644 --- a/httpcore/_sync/connection_pool.py +++ b/httpcore/_sync/connection_pool.py @@ -1,8 +1,8 @@ import warnings from ssl import SSLContext from typing import ( + Iterable, Iterator, - Callable, Dict, List, Optional, @@ -14,16 +14,12 @@ from .._backends.sync import SyncBackend, SyncLock, SyncSemaphore from .._backends.base import lookup_sync_backend +from .._compat import contextmanager from .._exceptions import LocalProtocolError, PoolTimeout, UnsupportedProtocol from .._threadlock import ThreadLock from .._types import URL, Headers, Origin, TimeoutDict from .._utils import get_logger, origin_to_url_string, url_to_origin -from .base import ( - SyncByteStream, - SyncHTTPTransport, - ConnectionState, - NewConnectionRequired, -) +from .base import SyncHTTPTransport, ConnectionState, NewConnectionRequired from .connection import SyncHTTPConnection logger = get_logger(__name__) @@ -40,39 +36,6 @@ def release(self) -> None: return -class ResponseByteStream(SyncByteStream): - def __init__( - self, - stream: SyncByteStream, - connection: SyncHTTPConnection, - callback: Callable, - ) -> None: - """ - A wrapper around the response stream that we return from `.request()`. - - Ensures that when `stream.close()` is called, the connection pool - is notified via a callback. - """ - self.stream = stream - self.connection = connection - self.callback = callback - - def __iter__(self) -> Iterator[bytes]: - for chunk in self.stream: - yield chunk - - def close(self) -> None: - try: - # Call the underlying stream close callback. - # This will be a call to `SyncHTTP11Connection._response_closed()` - # or `SyncHTTP2Stream._response_closed()`. - self.stream.close() - finally: - # Call the connection pool close callback. - # This will be a call to `SyncConnectionPool._response_closed()`. - self.callback(self.connection) - - class SyncConnectionPool(SyncHTTPTransport): """ A connection pool for making HTTP requests. @@ -178,14 +141,15 @@ def _create_connection( backend=self._backend, ) + @contextmanager def request( self, method: bytes, url: URL, headers: Headers = None, - stream: SyncByteStream = None, + stream: Iterable[bytes] = None, ext: dict = None, - ) -> Tuple[int, Headers, SyncByteStream, dict]: + ) -> Iterator[Tuple[int, Headers, Iterable[bytes], dict]]: if url[0] not in (b"http", b"https"): scheme = url[0].decode("latin-1") raise UnsupportedProtocol(f"Unsupported URL protocol {scheme!r}") @@ -215,9 +179,10 @@ def request( logger.trace("reuse connection=%r", connection) try: - response = connection.request( + with connection.request( method, url, headers=headers, stream=stream, ext=ext - ) + ) as response: + yield response except NewConnectionRequired: connection = None except Exception: # noqa: PIE786 @@ -225,11 +190,7 @@ def request( self._remove_from_pool(connection) raise - status_code, headers, stream, ext = response - wrapped_stream = ResponseByteStream( - stream, connection=connection, callback=self._response_closed - ) - return status_code, headers, wrapped_stream, ext + self._response_closed(connection) def _get_connection_from_pool( self, origin: Origin diff --git a/httpcore/_sync/http11.py b/httpcore/_sync/http11.py index b827454a..d940e070 100644 --- a/httpcore/_sync/http11.py +++ b/httpcore/_sync/http11.py @@ -1,14 +1,15 @@ from ssl import SSLContext -from typing import Iterator, List, Tuple, Union, cast +from typing import Iterable, Iterator, List, Tuple, Union, cast import h11 from .._backends.sync import SyncSocketStream -from .._bytestreams import IteratorByteStream, PlainByteStream +from .._bytestreams import PlainByteStream +from .._compat import contextmanager from .._exceptions import LocalProtocolError, RemoteProtocolError, map_exceptions from .._types import URL, Headers, TimeoutDict from .._utils import get_logger -from .base import SyncByteStream, ConnectionState +from .base import ConnectionState from .http import SyncBaseHTTPConnection H11Event = Union[ @@ -47,14 +48,15 @@ def mark_as_ready(self) -> None: if self.state == ConnectionState.IDLE: self.state = ConnectionState.READY + @contextmanager def request( self, method: bytes, url: URL, headers: Headers = None, - stream: SyncByteStream = None, + stream: Iterable[bytes] = None, ext: dict = None, - ) -> Tuple[int, Headers, SyncByteStream, dict]: + ) -> Iterator[Tuple[int, Headers, Iterable[bytes], dict]]: headers = [] if headers is None else headers stream = PlainByteStream(b"") if stream is None else stream ext = {} if ext is None else ext @@ -70,15 +72,15 @@ def request( reason_phrase, headers, ) = self._receive_response(timeout) - response_stream = IteratorByteStream( - iterator=self._receive_response_data(timeout), - close_func=self._response_closed, - ) + response_stream = self._receive_response_data(timeout) ext = { "http_version": http_version.decode("ascii", errors="ignore"), "reason": reason_phrase.decode("ascii", errors="ignore"), } - return (status_code, headers, response_stream, ext) + try: + yield (status_code, headers, response_stream, ext) + finally: + self._response_closed() def start_tls( self, hostname: bytes, timeout: TimeoutDict = None @@ -100,7 +102,7 @@ def _send_request( self._send_event(event, timeout) def _send_request_body( - self, stream: SyncByteStream, timeout: TimeoutDict + self, stream: Iterable[bytes], timeout: TimeoutDict ) -> None: """ Send the request body. diff --git a/httpcore/_sync/http2.py b/httpcore/_sync/http2.py index fe2d55eb..dcfe9f01 100644 --- a/httpcore/_sync/http2.py +++ b/httpcore/_sync/http2.py @@ -1,5 +1,5 @@ from ssl import SSLContext -from typing import Iterator, Dict, List, Tuple, cast +from typing import Iterable, Iterator, Dict, List, Tuple, cast import h2.connection import h2.events @@ -8,11 +8,12 @@ from h2.settings import SettingCodes, Settings from .._backends.sync import SyncBackend, SyncLock, SyncSemaphore, SyncSocketStream -from .._bytestreams import IteratorByteStream, PlainByteStream +from .._bytestreams import PlainByteStream +from .._compat import contextmanager from .._exceptions import PoolTimeout, RemoteProtocolError from .._types import URL, Headers, TimeoutDict from .._utils import get_logger -from .base import SyncByteStream, ConnectionState, NewConnectionRequired +from .base import ConnectionState, NewConnectionRequired from .http import SyncBaseHTTPConnection logger = get_logger(__name__) @@ -85,14 +86,15 @@ def mark_as_ready(self) -> None: if self.state == ConnectionState.IDLE: self.state = ConnectionState.READY + @contextmanager def request( self, method: bytes, url: URL, headers: Headers = None, - stream: SyncByteStream = None, + stream: Iterable[bytes] = None, ext: dict = None, - ) -> Tuple[int, Headers, SyncByteStream, dict]: + ) -> Iterator[Tuple[int, Headers, Iterable[bytes], dict]]: ext = {} if ext is None else ext timeout = cast(TimeoutDict, ext.get("timeout", {})) @@ -103,8 +105,7 @@ def request( self.send_connection_init(timeout) self.sent_connection_init = True - self.max_streams_semaphore.acquire() - try: + with self.max_streams_semaphore: try: stream_id = self.h2_state.get_next_available_stream_id() except NoAvailableStreamIDError: @@ -116,10 +117,10 @@ def request( h2_stream = SyncHTTP2Stream(stream_id=stream_id, connection=self) self.streams[stream_id] = h2_stream self.events[stream_id] = [] - return h2_stream.request(method, url, headers, stream, ext) - except Exception: # noqa: PIE786 - self.max_streams_semaphore.release() - raise + with h2_stream.request( + method, url, headers, stream, ext + ) as response: + yield response def send_connection_init(self, timeout: TimeoutDict) -> None: """ @@ -251,18 +252,15 @@ def acknowledge_received_data( self.socket.write(data_to_send, timeout) def close_stream(self, stream_id: int) -> None: - try: - logger.trace("close_stream stream_id=%r", stream_id) - del self.streams[stream_id] - del self.events[stream_id] - - if not self.streams: - if self.state == ConnectionState.ACTIVE: - self.state = ConnectionState.IDLE - elif self.state == ConnectionState.FULL: - self.close() - finally: - self.max_streams_semaphore.release() + logger.trace("close_stream stream_id=%r", stream_id) + del self.streams[stream_id] + del self.events[stream_id] + + if not self.streams: + if self.state == ConnectionState.ACTIVE: + self.state = ConnectionState.IDLE + elif self.state == ConnectionState.FULL: + self.close() class SyncHTTP2Stream: @@ -270,14 +268,15 @@ def __init__(self, stream_id: int, connection: SyncHTTP2Connection) -> None: self.stream_id = stream_id self.connection = connection + @contextmanager def request( self, method: bytes, url: URL, headers: Headers = None, - stream: SyncByteStream = None, + stream: Iterable[bytes] = None, ext: dict = None, - ) -> Tuple[int, Headers, SyncByteStream, dict]: + ) -> Iterator[Tuple[int, Headers, Iterable[bytes], dict]]: headers = [] if headers is None else [(k.lower(), v) for (k, v) in headers] stream = PlainByteStream(b"") if stream is None else stream ext = {} if ext is None else ext @@ -295,14 +294,16 @@ def request( # Receive the response. status_code, headers = self.receive_response(timeout) - response_stream = IteratorByteStream( - iterator=self.body_iter(timeout), close_func=self._response_closed - ) + response_stream = self.body_iter(timeout) ext = { "http_version": "HTTP/2", } - return (status_code, headers, response_stream, ext) + + try: + yield (status_code, headers, response_stream, ext) + finally: + self._response_closed() def send_headers( self, @@ -349,7 +350,9 @@ def send_headers( self.connection.send_headers(self.stream_id, headers, end_stream, timeout) - def send_body(self, stream: SyncByteStream, timeout: TimeoutDict) -> None: + def send_body( + self, stream: Iterable[bytes], timeout: TimeoutDict + ) -> None: for data in stream: while data: max_flow = self.connection.wait_for_outgoing_flow( diff --git a/httpcore/_sync/http_proxy.py b/httpcore/_sync/http_proxy.py index f5576c01..b3823955 100644 --- a/httpcore/_sync/http_proxy.py +++ b/httpcore/_sync/http_proxy.py @@ -1,13 +1,13 @@ from http import HTTPStatus from ssl import SSLContext -from typing import Tuple, cast +from typing import Iterable, Iterator, Tuple, cast +from .._compat import ExitStack, contextmanager from .._exceptions import ProxyError from .._types import URL, Headers, TimeoutDict from .._utils import get_logger, url_to_origin -from .base import SyncByteStream from .connection import SyncHTTPConnection -from .connection_pool import SyncConnectionPool, ResponseByteStream +from .connection_pool import SyncConnectionPool logger = get_logger(__name__) @@ -87,14 +87,15 @@ def __init__( max_keepalive=max_keepalive, ) + @contextmanager def request( self, method: bytes, url: URL, headers: Headers = None, - stream: SyncByteStream = None, + stream: Iterable[bytes] = None, ext: dict = None, - ) -> Tuple[int, Headers, SyncByteStream, dict]: + ) -> Iterator[Tuple[int, Headers, Iterable[bytes], dict]]: if self._keepalive_expiry is not None: self._keepalive_sweep() @@ -109,9 +110,10 @@ def request( method, url, ) - return self._forward_request( + with self._forward_request( method, url, headers=headers, stream=stream, ext=ext - ) + ) as response: + yield response else: # By default HTTPS should be tunnelled. logger.trace( @@ -121,18 +123,20 @@ def request( method, url, ) - return self._tunnel_request( + with self._tunnel_request( method, url, headers=headers, stream=stream, ext=ext - ) + ) as response: + yield response + @contextmanager def _forward_request( self, method: bytes, url: URL, headers: Headers = None, - stream: SyncByteStream = None, + stream: Iterable[bytes] = None, ext: dict = None, - ) -> Tuple[int, Headers, SyncByteStream, dict]: + ) -> Iterator[Tuple[int, Headers, Iterable[bytes], dict]]: """ Forwarded proxy requests include the entire URL as the HTTP target, rather than just the path. @@ -162,24 +166,23 @@ def _forward_request( url = self.proxy_origin + (target,) headers = merge_headers(self.proxy_headers, headers) - (status_code, headers, stream, ext) = connection.request( - method, url, headers=headers, stream=stream, ext=ext - ) - - wrapped_stream = ResponseByteStream( - stream, connection=connection, callback=self._response_closed - ) - - return status_code, headers, wrapped_stream, ext + try: + with connection.request( + method, url, headers=headers, stream=stream, ext=ext + ) as response: + yield response + finally: + self._response_closed(connection) + @contextmanager def _tunnel_request( self, method: bytes, url: URL, headers: Headers = None, - stream: SyncByteStream = None, + stream: Iterable[bytes] = None, ext: dict = None, - ) -> Tuple[int, Headers, SyncByteStream, dict]: + ) -> Iterator[Tuple[int, Headers, Iterable[bytes], dict]]: """ Tunnelled proxy requests require an initial CONNECT request to establish the connection, and then send regular requests. @@ -189,82 +192,78 @@ def _tunnel_request( origin = url_to_origin(url) connection = self._get_connection_from_pool(origin) - if connection is None: - scheme, host, port = origin - - # First, create a connection to the proxy server - proxy_connection = SyncHTTPConnection( - origin=self.proxy_origin, - http2=self._http2, - ssl_context=self._ssl_context, - ) - - # Issue a CONNECT request... - - # CONNECT www.example.org:80 HTTP/1.1 - # [proxy-headers] - target = b"%b:%d" % (host, port) - connect_url = self.proxy_origin + (target,) - connect_headers = [(b"Host", target), (b"Accept", b"*/*")] - connect_headers = merge_headers(connect_headers, self.proxy_headers) + with ExitStack() as exit_stack: + if connection is None: + scheme, host, port = origin - try: - ( - proxy_status_code, - _, - proxy_stream, - _, - ) = proxy_connection.request( - b"CONNECT", connect_url, headers=connect_headers, ext=ext + # First, create a connection to the proxy server + proxy_connection = SyncHTTPConnection( + origin=self.proxy_origin, + http2=self._http2, + ssl_context=self._ssl_context, ) - proxy_reason = get_reason_phrase(proxy_status_code) - logger.trace( - "tunnel_response proxy_status_code=%r proxy_reason=%r ", - proxy_status_code, - proxy_reason, + # Issue a CONNECT request... + + # CONNECT www.example.org:80 HTTP/1.1 + # [proxy-headers] + target = b"%b:%d" % (host, port) + connect_url = self.proxy_origin + (target,) + connect_headers = [(b"Host", target), (b"Accept", b"*/*")] + connect_headers = merge_headers(connect_headers, self.proxy_headers) + + try: + proxy_response = exit_stack.enter_context( + proxy_connection.request( + b"CONNECT", connect_url, headers=connect_headers, ext=ext + ) + ) + proxy_status_code, _, proxy_stream, _ = proxy_response + proxy_reason = get_reason_phrase(proxy_status_code) + logger.trace( + "tunnel_response proxy_status_code=%r proxy_reason=%r ", + proxy_status_code, + proxy_reason, + ) + # Read the response data without closing the socket + for _ in proxy_stream: + pass + + # See if the tunnel was successfully established. + if proxy_status_code < 200 or proxy_status_code > 299: + msg = "%d %s" % (proxy_status_code, proxy_reason) + raise ProxyError(msg) + + # Upgrade to TLS if required + # We assume the target speaks TLS on the specified port + if scheme == b"https": + proxy_connection.start_tls(host, timeout) + except Exception as exc: + proxy_connection.close() + raise ProxyError(exc) + + # The CONNECT request is successful, so we have now SWITCHED PROTOCOLS. + # This means the proxy connection is now unusable, and we must create + # a new one for regular requests, making sure to use the same socket to + # retain the tunnel. + connection = SyncHTTPConnection( + origin=origin, + http2=self._http2, + ssl_context=self._ssl_context, + socket=proxy_connection.socket, ) - # Read the response data without closing the socket - for _ in proxy_stream: - pass - - # See if the tunnel was successfully established. - if proxy_status_code < 200 or proxy_status_code > 299: - msg = "%d %s" % (proxy_status_code, proxy_reason) - raise ProxyError(msg) - - # Upgrade to TLS if required - # We assume the target speaks TLS on the specified port - if scheme == b"https": - proxy_connection.start_tls(host, timeout) - except Exception as exc: - proxy_connection.close() - raise ProxyError(exc) - - # The CONNECT request is successful, so we have now SWITCHED PROTOCOLS. - # This means the proxy connection is now unusable, and we must create - # a new one for regular requests, making sure to use the same socket to - # retain the tunnel. - connection = SyncHTTPConnection( - origin=origin, - http2=self._http2, - ssl_context=self._ssl_context, - socket=proxy_connection.socket, - ) - self._add_to_pool(connection, timeout) - - # Once the connection has been established we can send requests on - # it as normal. - (status_code, headers, stream, ext) = connection.request( - method, - url, - headers=headers, - stream=stream, - ext=ext, - ) + self._add_to_pool(connection, timeout) - wrapped_stream = ResponseByteStream( - stream, connection=connection, callback=self._response_closed - ) - - return status_code, headers, wrapped_stream, ext + # Once the connection has been established we can send requests on + # it as normal. + try: + with connection.request( + method, + url, + headers=headers, + stream=stream, + ext=ext, + ) as response: + yield response + finally: + self._response_closed(connection) diff --git a/setup.py b/setup.py index 90dbd2d8..c6753844 100644 --- a/setup.py +++ b/setup.py @@ -53,7 +53,13 @@ def get_packages(package): packages=get_packages("httpcore"), include_package_data=True, zip_safe=False, - install_requires=["h11==0.*", "sniffio==1.*"], + install_requires=[ + "h11==0.*", + "sniffio==1.*", + # Backports. + "async_generator; python_version<'3.7'", + "async-exit-stack; python_version<'3.7'", + ], extras_require={ "http2": ["h2>=3,<5"], }, diff --git a/tests/async_tests/test_connection_pool.py b/tests/async_tests/test_connection_pool.py index 5b30209d..4052d985 100644 --- a/tests/async_tests/test_connection_pool.py +++ b/tests/async_tests/test_connection_pool.py @@ -1,13 +1,14 @@ -from typing import AsyncIterator, Tuple +from typing import AsyncIterable, AsyncIterator, Tuple import pytest import httpcore from httpcore._async.base import ConnectionState +from httpcore._compat import AsyncExitStack, asynccontextmanager from httpcore._types import URL, Headers -class MockConnection(object): +class MockConnection(httpcore.AsyncHTTPTransport): def __init__(self, http_version): self.origin = (b"http", b"example.org", 80) self.state = ConnectionState.PENDING @@ -15,14 +16,15 @@ def __init__(self, http_version): self.is_http2 = http_version == "HTTP/2" self.stream_count = 0 + @asynccontextmanager async def arequest( self, method: bytes, url: URL, headers: Headers = None, - stream: httpcore.AsyncByteStream = None, + stream: AsyncIterable[bytes] = None, ext: dict = None, - ) -> Tuple[int, Headers, httpcore.AsyncByteStream, dict]: + ) -> AsyncIterator[Tuple[int, Headers, AsyncIterable[bytes], dict]]: self.state = ConnectionState.ACTIVE self.stream_count += 1 @@ -34,11 +36,12 @@ async def on_close(): async def aiterator() -> AsyncIterator[bytes]: yield b"" - stream = httpcore.AsyncIteratorByteStream( - aiterator=aiterator(), aclose_func=on_close - ) + stream = aiterator() - return 200, [], stream, {} + try: + yield 200, [], stream, {} + finally: + await on_close() async def aclose(self): pass @@ -63,14 +66,8 @@ def _create_connection(self, **kwargs): return MockConnection(self.http_version) -async def read_body(stream: httpcore.AsyncByteStream) -> bytes: - try: - body = [] - async for chunk in stream: - body.append(chunk) - return b"".join(body) - finally: - await stream.aclose() +async def read_body(stream: AsyncIterable[bytes]) -> bytes: + return b"".join([chunk async for chunk in stream]) @pytest.mark.trio @@ -80,21 +77,25 @@ async def test_sequential_requests(http_version) -> None: info = await http.get_connection_info() assert info == {} - response = await http.arequest(b"GET", (b"http", b"example.org", None, b"/")) - status_code, headers, stream, ext = response - info = await http.get_connection_info() - assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + async with http.arequest( + b"GET", (b"http", b"example.org", None, b"/") + ) as response: + status_code, headers, stream, ext = response + info = await http.get_connection_info() + assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + await read_body(stream) - await read_body(stream) info = await http.get_connection_info() assert info == {"http://example.org": ["ConnectionState.IDLE"]} - response = await http.arequest(b"GET", (b"http", b"example.org", None, b"/")) - status_code, headers, stream, ext = response - info = await http.get_connection_info() - assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + async with http.arequest( + b"GET", (b"http", b"example.org", None, b"/") + ) as response: + status_code, headers, stream, ext = response + info = await http.get_connection_info() + assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + await read_body(stream) - await read_body(stream) info = await http.get_connection_info() assert info == {"http://example.org": ["ConnectionState.IDLE"]} @@ -105,25 +106,36 @@ async def test_concurrent_requests_h11() -> None: info = await http.get_connection_info() assert info == {} - response_1 = await http.arequest(b"GET", (b"http", b"example.org", None, b"/")) - status_code_1, headers_1, stream_1, ext_1 = response_1 - info = await http.get_connection_info() - assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + async with AsyncExitStack() as exit_stack2: + async with AsyncExitStack() as exit_stack1: + response_1 = await exit_stack1.enter_async_context( + http.arequest(b"GET", (b"http", b"example.org", None, b"/")) + ) + status_code_1, headers_1, stream_1, ext_1 = response_1 + info = await http.get_connection_info() + assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + + response_2 = await exit_stack2.enter_async_context( + http.arequest(b"GET", (b"http", b"example.org", None, b"/")) + ) + status_code_2, headers_2, stream_2, ext_2 = response_2 + info = await http.get_connection_info() + assert info == { + "http://example.org": [ + "ConnectionState.ACTIVE", + "ConnectionState.ACTIVE", + ] + } + + await read_body(stream_1) + + info = await http.get_connection_info() + assert info == { + "http://example.org": ["ConnectionState.ACTIVE", "ConnectionState.IDLE"] + } + + await read_body(stream_2) - response_2 = await http.arequest(b"GET", (b"http", b"example.org", None, b"/")) - status_code_2, headers_2, stream_2, ext_2 = response_2 - info = await http.get_connection_info() - assert info == { - "http://example.org": ["ConnectionState.ACTIVE", "ConnectionState.ACTIVE"] - } - - await read_body(stream_1) - info = await http.get_connection_info() - assert info == { - "http://example.org": ["ConnectionState.ACTIVE", "ConnectionState.IDLE"] - } - - await read_body(stream_2) info = await http.get_connection_info() assert info == { "http://example.org": ["ConnectionState.IDLE", "ConnectionState.IDLE"] @@ -136,20 +148,29 @@ async def test_concurrent_requests_h2() -> None: info = await http.get_connection_info() assert info == {} - response_1 = await http.arequest(b"GET", (b"http", b"example.org", None, b"/")) - status_code_1, headers_1, stream_1, ext_1 = response_1 - info = await http.get_connection_info() - assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + async with AsyncExitStack() as exit_stack2: + async with AsyncExitStack() as exit_stack1: + response_1 = await exit_stack1.enter_async_context( + http.arequest(b"GET", (b"http", b"example.org", None, b"/")) + ) + status_code_1, headers_1, stream_1, ext_1 = response_1 + info = await http.get_connection_info() + assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} - response_2 = await http.arequest(b"GET", (b"http", b"example.org", None, b"/")) - status_code_2, headers_2, stream_2, ext_2 = response_2 - info = await http.get_connection_info() - assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + response_2 = await exit_stack2.enter_async_context( + http.arequest(b"GET", (b"http", b"example.org", None, b"/")) + ) + status_code_2, headers_2, stream_2, ext_2 = response_2 - await read_body(stream_1) - info = await http.get_connection_info() - assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + info = await http.get_connection_info() + assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + + await read_body(stream_1) + + info = await http.get_connection_info() + assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + + await read_body(stream_2) - await read_body(stream_2) info = await http.get_connection_info() assert info == {"http://example.org": ["ConnectionState.IDLE"]} diff --git a/tests/async_tests/test_interfaces.py b/tests/async_tests/test_interfaces.py index ac248325..cea70266 100644 --- a/tests/async_tests/test_interfaces.py +++ b/tests/async_tests/test_interfaces.py @@ -1,9 +1,12 @@ import platform import sys +from functools import partial +from typing import AsyncIterable import pytest import httpcore +from httpcore._compat import AsyncExitStack from httpcore._types import URL from tests.conftest import HTTPS_SERVER_URL from tests.utils import Server, lookup_async_backend @@ -14,14 +17,8 @@ def backend(request): return request.param -async def read_body(stream: httpcore.AsyncByteStream) -> bytes: - try: - body = [] - async for chunk in stream: - body.append(chunk) - return b"".join(body) - finally: - await stream.aclose() +async def read_body(stream: AsyncIterable[bytes]) -> bytes: + return b"".join([chunk async for chunk in stream]) @pytest.mark.anyio @@ -30,8 +27,9 @@ async def test_http_request(backend: str, server: Server) -> None: method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header] - status_code, headers, stream, ext = await http.arequest(method, url, headers) - await read_body(stream) + async with http.arequest(method, url, headers) as response: + status_code, headers, stream, ext = response + await read_body(stream) assert status_code == 200 reason = "OK" if server.sends_reason else "" @@ -45,8 +43,9 @@ async def test_https_request(backend: str, https_server: Server) -> None: method = b"GET" url = (b"https", *https_server.netloc, b"/") headers = [https_server.host_header] - status_code, headers, stream, ext = await http.arequest(method, url, headers) - await read_body(stream) + async with http.arequest(method, url, headers) as response: + status_code, headers, stream, ext = response + await read_body(stream) assert status_code == 200 reason = "OK" if https_server.sends_reason else "" @@ -61,7 +60,8 @@ async def test_request_unsupported_protocol(backend: str) -> None: url = (b"ftp", b"example.org", 443, b"/") headers = [(b"host", b"example.org")] with pytest.raises(httpcore.UnsupportedProtocol): - await http.arequest(method, url, headers) + async with http.arequest(method, url, headers): + pass # pragma: no cover @pytest.mark.anyio @@ -70,8 +70,9 @@ async def test_http2_request(backend: str, https_server: Server) -> None: method = b"GET" url = (b"https", *https_server.netloc, b"/") headers = [https_server.host_header] - status_code, headers, stream, ext = await http.arequest(method, url, headers) - await read_body(stream) + async with http.arequest(method, url, headers) as response: + status_code, headers, stream, ext = response + await read_body(stream) assert status_code == 200 assert ext == {"http_version": "HTTP/2"} @@ -84,8 +85,9 @@ async def test_closing_http_request(backend: str, server: Server) -> None: method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header, (b"connection", b"close")] - status_code, headers, stream, ext = await http.arequest(method, url, headers) - await read_body(stream) + async with http.arequest(method, url, headers) as response: + status_code, headers, stream, ext = response + await read_body(stream) assert status_code == 200 reason = "OK" if server.sends_reason else "" @@ -99,8 +101,9 @@ async def test_http_request_reuse_connection(backend: str, server: Server) -> No method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header] - status_code, headers, stream, ext = await http.arequest(method, url, headers) - await read_body(stream) + async with http.arequest(method, url, headers) as response: + status_code, headers, stream, ext = response + await read_body(stream) assert status_code == 200 reason = "OK" if server.sends_reason else "" @@ -110,8 +113,9 @@ async def test_http_request_reuse_connection(backend: str, server: Server) -> No method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header] - status_code, headers, stream, ext = await http.arequest(method, url, headers) - await read_body(stream) + async with http.arequest(method, url, headers) as response: + status_code, headers, stream, ext = response + await read_body(stream) assert status_code == 200 reason = "OK" if server.sends_reason else "" @@ -127,8 +131,9 @@ async def test_https_request_reuse_connection( method = b"GET" url = (b"https", *https_server.netloc, b"/") headers = [https_server.host_header] - status_code, headers, stream, ext = await http.arequest(method, url, headers) - await read_body(stream) + async with http.arequest(method, url, headers) as response: + status_code, headers, stream, ext = response + await read_body(stream) assert status_code == 200 reason = "OK" if https_server.sends_reason else "" @@ -138,8 +143,9 @@ async def test_https_request_reuse_connection( method = b"GET" url = (b"https", *https_server.netloc, b"/") headers = [https_server.host_header] - status_code, headers, stream, ext = await http.arequest(method, url, headers) - await read_body(stream) + async with http.arequest(method, url, headers) as response: + status_code, headers, stream, ext = response + await read_body(stream) assert status_code == 200 reason = "OK" if https_server.sends_reason else "" @@ -155,8 +161,9 @@ async def test_http_request_cannot_reuse_dropped_connection( method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header] - status_code, headers, stream, ext = await http.arequest(method, url, headers) - await read_body(stream) + async with http.arequest(method, url, headers) as response: + status_code, headers, stream, ext = response + await read_body(stream) assert status_code == 200 reason = "OK" if server.sends_reason else "" @@ -170,8 +177,9 @@ async def test_http_request_cannot_reuse_dropped_connection( method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header] - status_code, headers, stream, ext = await http.arequest(method, url, headers) - await read_body(stream) + async with http.arequest(method, url, headers) as response: + status_code, headers, stream, ext = response + await read_body(stream) assert status_code == 200 reason = "OK" if server.sends_reason else "" @@ -194,8 +202,9 @@ async def test_http_proxy( max_connections=max_connections, backend=backend, ) as http: - status_code, headers, stream, ext = await http.arequest(method, url, headers) - await read_body(stream) + async with http.arequest(method, url, headers) as response: + status_code, headers, stream, ext = response + await read_body(stream) assert status_code == 200 reason = "OK" if server.sends_reason else "" @@ -220,7 +229,8 @@ async def test_proxy_socket_does_not_leak_when_the_connection_hasnt_been_added_t async with httpcore.AsyncHTTPProxy(proxy_server, proxy_mode=proxy_mode) as http: for _ in range(100): try: - _ = await http.arequest(method, url, headers) + async with http.arequest(method, url, headers) as _: + pass except (httpcore.ProxyError, httpcore.RemoteProtocolError): pass @@ -243,8 +253,9 @@ async def test_http_request_local_address(backend: str, server: Server) -> None: method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header] - status_code, headers, stream, ext = await http.arequest(method, url, headers) - await read_body(stream) + async with http.arequest(method, url, headers) as response: + status_code, headers, stream, ext = response + await read_body(stream) assert status_code == 200 reason = "OK" if server.sends_reason else "" @@ -272,8 +283,9 @@ async def test_proxy_https_requests( max_connections=max_connections, http2=http2, ) as http: - status_code, headers, stream, ext = await http.arequest(method, url, headers) - _ = await read_body(stream) + async with http.arequest(method, url, headers) as response: + status_code, headers, stream, ext = response + _ = await read_body(stream) assert status_code == 200 assert ext["http_version"] == "HTTP/2" if http2 else "HTTP/1.1" @@ -325,15 +337,20 @@ async def test_connection_pool_get_connection_info( url = (b"https", *https_server.netloc, b"/") headers = [https_server.host_header] - _, _, stream_1, _ = await http.arequest(method, url, headers) - _, _, stream_2, _ = await http.arequest(method, url, headers) + async with AsyncExitStack() as exit_stack: + _, _, stream_1, _ = await exit_stack.enter_async_context( + http.arequest(method, url, headers) + ) + _, _, stream_2, _ = await exit_stack.enter_async_context( + http.arequest(method, url, headers) + ) - try: - stats = await http.get_connection_info() - assert stats == expected_during_active - finally: - await read_body(stream_1) - await read_body(stream_2) + try: + stats = await http.get_connection_info() + assert stats == expected_during_active + finally: + await read_body(stream_1) + await read_body(stream_2) stats = await http.get_connection_info() assert stats == expected_during_idle @@ -355,12 +372,13 @@ async def test_http_request_unix_domain_socket( method = b"GET" url = (b"http", b"localhost", None, b"/") headers = [(b"host", b"localhost")] - status_code, headers, stream, ext = await http.arequest(method, url, headers) - assert status_code == 200 - reason = "OK" if uds_server.sends_reason else "" - assert ext == {"http_version": "HTTP/1.1", "reason": reason} - body = await read_body(stream) - assert body == b"Hello, world!" + async with http.arequest(method, url, headers) as response: + status_code, headers, stream, ext = response + assert status_code == 200 + expected_reason = "OK" if uds_server.sends_reason else "" + assert ext == {"http_version": "HTTP/1.1", "reason": expected_reason} + body = await read_body(stream) + assert body == b"Hello, world!" @pytest.mark.parametrize("max_keepalive", [1, 3, 5]) @@ -376,19 +394,17 @@ async def test_max_keepalive_connections_handled_correctly( url = (b"http", *server.netloc, b"/") headers = [server.host_header] - connections_streams = [] - for _ in range(connections_number): - _, _, stream, _ = await http.arequest(method, url, headers) - connections_streams.append(stream) + async with AsyncExitStack() as exit_stack: + for _ in range(connections_number): + _, _, stream, _ = await exit_stack.enter_async_context( + http.arequest(method, url, headers) + ) + exit_stack.push_async_callback(partial(read_body, stream)) - try: - for i in range(len(connections_streams)): - await read_body(connections_streams[i]) - finally: - stats = await http.get_connection_info() + stats = await http.get_connection_info() - connections_in_pool = next(iter(stats.values())) - assert len(connections_in_pool) == min(connections_number, max_keepalive) + connections_in_pool = next(iter(stats.values())) + assert len(connections_in_pool) == min(connections_number, max_keepalive) @pytest.mark.anyio @@ -397,8 +413,9 @@ async def test_explicit_backend_name(server: Server) -> None: method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header] - status_code, headers, stream, ext = await http.arequest(method, url, headers) - await read_body(stream) + async with http.arequest(method, url, headers) as response: + status_code, headers, stream, ext = response + await read_body(stream) assert status_code == 200 reason = "OK" if server.sends_reason else "" @@ -425,10 +442,9 @@ async def test_broken_socket_detection_many_open_files( # * Second attempt would have failed without a fix, due to a "filedescriptor # out of range in select()" exception. for _ in range(2): - status_code, response_headers, stream, ext = await http.arequest( - method, url, headers - ) - await read_body(stream) + async with http.arequest(method, url, headers) as response: + status_code, response_headers, stream, ext = response + await read_body(stream) assert status_code == 200 reason = "OK" if server.sends_reason else "" @@ -453,7 +469,8 @@ async def test_cannot_connect_tcp(backend: str, url) -> None: async with httpcore.AsyncConnectionPool(backend=backend) as http: method = b"GET" with pytest.raises(httpcore.ConnectError): - await http.arequest(method, url) + async with http.arequest(method, url) as _: + pass # pragma: no cover @pytest.mark.anyio @@ -466,7 +483,8 @@ async def test_cannot_connect_uds(backend: str) -> None: url = (b"http", b"localhost", None, b"/") async with httpcore.AsyncConnectionPool(backend=backend, uds=uds) as http: with pytest.raises(httpcore.ConnectError): - await http.arequest(method, url) + async with http.arequest(method, url): + pass # pragma: no cover @pytest.mark.skipif( @@ -484,7 +502,8 @@ async def test_connection_timeout_tcp(backend: str, server: Server) -> None: async with httpcore.AsyncConnectionPool(backend=backend) as http: with pytest.raises(httpcore.ConnectTimeout): - await http.arequest(method, url, headers, ext=ext) + async with http.arequest(method, url, headers, ext=ext): + pass # pragma: no cover @pytest.mark.skipif( @@ -504,4 +523,5 @@ async def test_connection_timeout_uds( async with httpcore.AsyncConnectionPool(uds=uds, backend=backend) as http: with pytest.raises(httpcore.ConnectTimeout): - await http.arequest(method, url, headers, ext=ext) + async with http.arequest(method, url, headers, ext=ext): + pass # pragma: no cover diff --git a/tests/async_tests/test_retries.py b/tests/async_tests/test_retries.py index fe05a4ab..35380493 100644 --- a/tests/async_tests/test_retries.py +++ b/tests/async_tests/test_retries.py @@ -1,6 +1,6 @@ import queue import time -from typing import Any, List, Optional +from typing import Any, AsyncIterable, List, Optional import pytest @@ -32,11 +32,8 @@ async def open_tcp_stream(self, *args: Any, **kwargs: Any) -> AsyncSocketStream: return await super().open_tcp_stream(*args, **kwargs) -async def read_body(stream: httpcore.AsyncByteStream) -> bytes: - try: - return b"".join([chunk async for chunk in stream]) - finally: - await stream.aclose() +async def read_body(stream: AsyncIterable[bytes]) -> bytes: + return b"".join([chunk async for chunk in stream]) @pytest.mark.anyio @@ -52,18 +49,20 @@ async def test_no_retries(server: Server) -> None: async with httpcore.AsyncConnectionPool( max_keepalive_connections=0, backend=backend ) as http: - response = await http.arequest(method, url, headers) - status_code, _, stream, _ = response - assert status_code == 200 - await read_body(stream) + async with http.arequest(method, url, headers) as response: + status_code, _, stream, _ = response + assert status_code == 200 + await read_body(stream) backend.push(httpcore.ConnectTimeout(), httpcore.ConnectError()) with pytest.raises(httpcore.ConnectTimeout): - await http.arequest(method, url, headers) + async with http.arequest(method, url, headers) as response: + pass # pragma: no cover with pytest.raises(httpcore.ConnectError): - await http.arequest(method, url, headers) + async with http.arequest(method, url, headers) as response: + pass # pragma: no cover @pytest.mark.anyio @@ -82,21 +81,21 @@ async def test_retries_enabled(server: Server) -> None: retries=retries, max_keepalive_connections=0, backend=backend ) as http: # Standard case, no failures. - response = await http.arequest(method, url, headers) - assert backend.pop_open_tcp_stream_intervals() == [] - status_code, _, stream, _ = response - assert status_code == 200 - await read_body(stream) + async with http.arequest(method, url, headers) as response: + assert backend.pop_open_tcp_stream_intervals() == [] + status_code, _, stream, _ = response + assert status_code == 200 + await read_body(stream) # One failure, then success. backend.push(httpcore.ConnectError(), None) - response = await http.arequest(method, url, headers) - assert backend.pop_open_tcp_stream_intervals() == [ - pytest.approx(0, abs=5e-3), # Retry immediately. - ] - status_code, _, stream, _ = response - assert status_code == 200 - await read_body(stream) + async with http.arequest(method, url, headers) as response: + assert backend.pop_open_tcp_stream_intervals() == [ + pytest.approx(0, abs=5e-3), # Retry immediately. + ] + status_code, _, stream, _ = response + assert status_code == 200 + await read_body(stream) # Three failures, then success. backend.push( @@ -105,22 +104,25 @@ async def test_retries_enabled(server: Server) -> None: httpcore.ConnectTimeout(), None, ) - response = await http.arequest(method, url, headers) - assert backend.pop_open_tcp_stream_intervals() == [ - pytest.approx(0, abs=5e-3), # Retry immediately. - pytest.approx(0.5, rel=0.1), # First backoff. - pytest.approx(1.0, rel=0.1), # Second (increased) backoff. - ] - status_code, _, stream, _ = response - assert status_code == 200 - await read_body(stream) + async with http.arequest(method, url, headers) as response: + assert backend.pop_open_tcp_stream_intervals() == [ + pytest.approx(0, abs=5e-3), # Retry immediately. + pytest.approx(0.5, rel=0.1), # First backoff. + pytest.approx(1.0, rel=0.1), # Second (increased) backoff. + ] + status_code, _, stream, _ = response + assert status_code == 200 + await read_body(stream) # Non-connect exceptions are not retried on. backend.push(httpcore.ReadTimeout(), httpcore.NetworkError()) with pytest.raises(httpcore.ReadTimeout): - await http.arequest(method, url, headers) + async with http.arequest(method, url, headers) as response: + pass # pragma: no cover + with pytest.raises(httpcore.NetworkError): - await http.arequest(method, url, headers) + async with http.arequest(method, url, headers) as response: + pass # pragma: no cover @pytest.mark.anyio @@ -138,12 +140,13 @@ async def test_retries_exceeded(server: Server) -> None: async with httpcore.AsyncConnectionPool( retries=retries, max_keepalive_connections=0, backend=backend ) as http: - response = await http.arequest(method, url, headers) - status_code, _, stream, _ = response - assert status_code == 200 - await read_body(stream) + async with http.arequest(method, url, headers) as response: + status_code, _, stream, _ = response + assert status_code == 200 + await read_body(stream) # First failure is retried on, second one isn't. backend.push(httpcore.ConnectError(), httpcore.ConnectTimeout()) with pytest.raises(httpcore.ConnectTimeout): - await http.arequest(method, url, headers) + async with http.arequest(method, url, headers) as response: + pass # pragma: no cover diff --git a/tests/sync_tests/test_connection_pool.py b/tests/sync_tests/test_connection_pool.py index ca5cb433..d3d605fa 100644 --- a/tests/sync_tests/test_connection_pool.py +++ b/tests/sync_tests/test_connection_pool.py @@ -1,13 +1,14 @@ -from typing import Iterator, Tuple +from typing import Iterable, Iterator, Tuple import pytest import httpcore from httpcore._async.base import ConnectionState +from httpcore._compat import ExitStack, contextmanager from httpcore._types import URL, Headers -class MockConnection(object): +class MockConnection(httpcore.SyncHTTPTransport): def __init__(self, http_version): self.origin = (b"http", b"example.org", 80) self.state = ConnectionState.PENDING @@ -15,14 +16,15 @@ def __init__(self, http_version): self.is_http2 = http_version == "HTTP/2" self.stream_count = 0 + @contextmanager def request( self, method: bytes, url: URL, headers: Headers = None, - stream: httpcore.SyncByteStream = None, + stream: Iterable[bytes] = None, ext: dict = None, - ) -> Tuple[int, Headers, httpcore.SyncByteStream, dict]: + ) -> Iterator[Tuple[int, Headers, Iterable[bytes], dict]]: self.state = ConnectionState.ACTIVE self.stream_count += 1 @@ -34,11 +36,12 @@ def on_close(): def iterator() -> Iterator[bytes]: yield b"" - stream = httpcore.IteratorByteStream( - iterator=iterator(), close_func=on_close - ) + stream = iterator() - return 200, [], stream, {} + try: + yield 200, [], stream, {} + finally: + on_close() def close(self): pass @@ -63,14 +66,8 @@ def _create_connection(self, **kwargs): return MockConnection(self.http_version) -def read_body(stream: httpcore.SyncByteStream) -> bytes: - try: - body = [] - for chunk in stream: - body.append(chunk) - return b"".join(body) - finally: - stream.close() +def read_body(stream: Iterable[bytes]) -> bytes: + return b"".join([chunk for chunk in stream]) @@ -80,21 +77,25 @@ def test_sequential_requests(http_version) -> None: info = http.get_connection_info() assert info == {} - response = http.request(b"GET", (b"http", b"example.org", None, b"/")) - status_code, headers, stream, ext = response - info = http.get_connection_info() - assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + with http.request( + b"GET", (b"http", b"example.org", None, b"/") + ) as response: + status_code, headers, stream, ext = response + info = http.get_connection_info() + assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + read_body(stream) - read_body(stream) info = http.get_connection_info() assert info == {"http://example.org": ["ConnectionState.IDLE"]} - response = http.request(b"GET", (b"http", b"example.org", None, b"/")) - status_code, headers, stream, ext = response - info = http.get_connection_info() - assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + with http.request( + b"GET", (b"http", b"example.org", None, b"/") + ) as response: + status_code, headers, stream, ext = response + info = http.get_connection_info() + assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + read_body(stream) - read_body(stream) info = http.get_connection_info() assert info == {"http://example.org": ["ConnectionState.IDLE"]} @@ -105,25 +106,36 @@ def test_concurrent_requests_h11() -> None: info = http.get_connection_info() assert info == {} - response_1 = http.request(b"GET", (b"http", b"example.org", None, b"/")) - status_code_1, headers_1, stream_1, ext_1 = response_1 - info = http.get_connection_info() - assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + with ExitStack() as exit_stack2: + with ExitStack() as exit_stack1: + response_1 = exit_stack1.enter_context( + http.request(b"GET", (b"http", b"example.org", None, b"/")) + ) + status_code_1, headers_1, stream_1, ext_1 = response_1 + info = http.get_connection_info() + assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + + response_2 = exit_stack2.enter_context( + http.request(b"GET", (b"http", b"example.org", None, b"/")) + ) + status_code_2, headers_2, stream_2, ext_2 = response_2 + info = http.get_connection_info() + assert info == { + "http://example.org": [ + "ConnectionState.ACTIVE", + "ConnectionState.ACTIVE", + ] + } + + read_body(stream_1) + + info = http.get_connection_info() + assert info == { + "http://example.org": ["ConnectionState.ACTIVE", "ConnectionState.IDLE"] + } + + read_body(stream_2) - response_2 = http.request(b"GET", (b"http", b"example.org", None, b"/")) - status_code_2, headers_2, stream_2, ext_2 = response_2 - info = http.get_connection_info() - assert info == { - "http://example.org": ["ConnectionState.ACTIVE", "ConnectionState.ACTIVE"] - } - - read_body(stream_1) - info = http.get_connection_info() - assert info == { - "http://example.org": ["ConnectionState.ACTIVE", "ConnectionState.IDLE"] - } - - read_body(stream_2) info = http.get_connection_info() assert info == { "http://example.org": ["ConnectionState.IDLE", "ConnectionState.IDLE"] @@ -136,20 +148,29 @@ def test_concurrent_requests_h2() -> None: info = http.get_connection_info() assert info == {} - response_1 = http.request(b"GET", (b"http", b"example.org", None, b"/")) - status_code_1, headers_1, stream_1, ext_1 = response_1 - info = http.get_connection_info() - assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + with ExitStack() as exit_stack2: + with ExitStack() as exit_stack1: + response_1 = exit_stack1.enter_context( + http.request(b"GET", (b"http", b"example.org", None, b"/")) + ) + status_code_1, headers_1, stream_1, ext_1 = response_1 + info = http.get_connection_info() + assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} - response_2 = http.request(b"GET", (b"http", b"example.org", None, b"/")) - status_code_2, headers_2, stream_2, ext_2 = response_2 - info = http.get_connection_info() - assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + response_2 = exit_stack2.enter_context( + http.request(b"GET", (b"http", b"example.org", None, b"/")) + ) + status_code_2, headers_2, stream_2, ext_2 = response_2 - read_body(stream_1) - info = http.get_connection_info() - assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + info = http.get_connection_info() + assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + + read_body(stream_1) + + info = http.get_connection_info() + assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + + read_body(stream_2) - read_body(stream_2) info = http.get_connection_info() assert info == {"http://example.org": ["ConnectionState.IDLE"]} diff --git a/tests/sync_tests/test_interfaces.py b/tests/sync_tests/test_interfaces.py index e1db92c8..89b2de77 100644 --- a/tests/sync_tests/test_interfaces.py +++ b/tests/sync_tests/test_interfaces.py @@ -1,9 +1,12 @@ import platform import sys +from functools import partial +from typing import Iterable import pytest import httpcore +from httpcore._compat import ExitStack from httpcore._types import URL from tests.conftest import HTTPS_SERVER_URL from tests.utils import Server, lookup_sync_backend @@ -14,14 +17,8 @@ def backend(request): return request.param -def read_body(stream: httpcore.SyncByteStream) -> bytes: - try: - body = [] - for chunk in stream: - body.append(chunk) - return b"".join(body) - finally: - stream.close() +def read_body(stream: Iterable[bytes]) -> bytes: + return b"".join([chunk for chunk in stream]) @@ -30,8 +27,9 @@ def test_http_request(backend: str, server: Server) -> None: method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header] - status_code, headers, stream, ext = http.request(method, url, headers) - read_body(stream) + with http.request(method, url, headers) as response: + status_code, headers, stream, ext = response + read_body(stream) assert status_code == 200 reason = "OK" if server.sends_reason else "" @@ -45,8 +43,9 @@ def test_https_request(backend: str, https_server: Server) -> None: method = b"GET" url = (b"https", *https_server.netloc, b"/") headers = [https_server.host_header] - status_code, headers, stream, ext = http.request(method, url, headers) - read_body(stream) + with http.request(method, url, headers) as response: + status_code, headers, stream, ext = response + read_body(stream) assert status_code == 200 reason = "OK" if https_server.sends_reason else "" @@ -61,7 +60,8 @@ def test_request_unsupported_protocol(backend: str) -> None: url = (b"ftp", b"example.org", 443, b"/") headers = [(b"host", b"example.org")] with pytest.raises(httpcore.UnsupportedProtocol): - http.request(method, url, headers) + with http.request(method, url, headers): + pass # pragma: no cover @@ -70,8 +70,9 @@ def test_http2_request(backend: str, https_server: Server) -> None: method = b"GET" url = (b"https", *https_server.netloc, b"/") headers = [https_server.host_header] - status_code, headers, stream, ext = http.request(method, url, headers) - read_body(stream) + with http.request(method, url, headers) as response: + status_code, headers, stream, ext = response + read_body(stream) assert status_code == 200 assert ext == {"http_version": "HTTP/2"} @@ -84,8 +85,9 @@ def test_closing_http_request(backend: str, server: Server) -> None: method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header, (b"connection", b"close")] - status_code, headers, stream, ext = http.request(method, url, headers) - read_body(stream) + with http.request(method, url, headers) as response: + status_code, headers, stream, ext = response + read_body(stream) assert status_code == 200 reason = "OK" if server.sends_reason else "" @@ -99,8 +101,9 @@ def test_http_request_reuse_connection(backend: str, server: Server) -> None: method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header] - status_code, headers, stream, ext = http.request(method, url, headers) - read_body(stream) + with http.request(method, url, headers) as response: + status_code, headers, stream, ext = response + read_body(stream) assert status_code == 200 reason = "OK" if server.sends_reason else "" @@ -110,8 +113,9 @@ def test_http_request_reuse_connection(backend: str, server: Server) -> None: method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header] - status_code, headers, stream, ext = http.request(method, url, headers) - read_body(stream) + with http.request(method, url, headers) as response: + status_code, headers, stream, ext = response + read_body(stream) assert status_code == 200 reason = "OK" if server.sends_reason else "" @@ -127,8 +131,9 @@ def test_https_request_reuse_connection( method = b"GET" url = (b"https", *https_server.netloc, b"/") headers = [https_server.host_header] - status_code, headers, stream, ext = http.request(method, url, headers) - read_body(stream) + with http.request(method, url, headers) as response: + status_code, headers, stream, ext = response + read_body(stream) assert status_code == 200 reason = "OK" if https_server.sends_reason else "" @@ -138,8 +143,9 @@ def test_https_request_reuse_connection( method = b"GET" url = (b"https", *https_server.netloc, b"/") headers = [https_server.host_header] - status_code, headers, stream, ext = http.request(method, url, headers) - read_body(stream) + with http.request(method, url, headers) as response: + status_code, headers, stream, ext = response + read_body(stream) assert status_code == 200 reason = "OK" if https_server.sends_reason else "" @@ -155,8 +161,9 @@ def test_http_request_cannot_reuse_dropped_connection( method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header] - status_code, headers, stream, ext = http.request(method, url, headers) - read_body(stream) + with http.request(method, url, headers) as response: + status_code, headers, stream, ext = response + read_body(stream) assert status_code == 200 reason = "OK" if server.sends_reason else "" @@ -170,8 +177,9 @@ def test_http_request_cannot_reuse_dropped_connection( method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header] - status_code, headers, stream, ext = http.request(method, url, headers) - read_body(stream) + with http.request(method, url, headers) as response: + status_code, headers, stream, ext = response + read_body(stream) assert status_code == 200 reason = "OK" if server.sends_reason else "" @@ -194,8 +202,9 @@ def test_http_proxy( max_connections=max_connections, backend=backend, ) as http: - status_code, headers, stream, ext = http.request(method, url, headers) - read_body(stream) + with http.request(method, url, headers) as response: + status_code, headers, stream, ext = response + read_body(stream) assert status_code == 200 reason = "OK" if server.sends_reason else "" @@ -220,7 +229,8 @@ def test_proxy_socket_does_not_leak_when_the_connection_hasnt_been_added_to_pool with httpcore.SyncHTTPProxy(proxy_server, proxy_mode=proxy_mode) as http: for _ in range(100): try: - _ = http.request(method, url, headers) + with http.request(method, url, headers) as _: + pass except (httpcore.ProxyError, httpcore.RemoteProtocolError): pass @@ -243,8 +253,9 @@ def test_http_request_local_address(backend: str, server: Server) -> None: method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header] - status_code, headers, stream, ext = http.request(method, url, headers) - read_body(stream) + with http.request(method, url, headers) as response: + status_code, headers, stream, ext = response + read_body(stream) assert status_code == 200 reason = "OK" if server.sends_reason else "" @@ -272,8 +283,9 @@ def test_proxy_https_requests( max_connections=max_connections, http2=http2, ) as http: - status_code, headers, stream, ext = http.request(method, url, headers) - _ = read_body(stream) + with http.request(method, url, headers) as response: + status_code, headers, stream, ext = response + _ = read_body(stream) assert status_code == 200 assert ext["http_version"] == "HTTP/2" if http2 else "HTTP/1.1" @@ -325,15 +337,20 @@ def test_connection_pool_get_connection_info( url = (b"https", *https_server.netloc, b"/") headers = [https_server.host_header] - _, _, stream_1, _ = http.request(method, url, headers) - _, _, stream_2, _ = http.request(method, url, headers) + with ExitStack() as exit_stack: + _, _, stream_1, _ = exit_stack.enter_context( + http.request(method, url, headers) + ) + _, _, stream_2, _ = exit_stack.enter_context( + http.request(method, url, headers) + ) - try: - stats = http.get_connection_info() - assert stats == expected_during_active - finally: - read_body(stream_1) - read_body(stream_2) + try: + stats = http.get_connection_info() + assert stats == expected_during_active + finally: + read_body(stream_1) + read_body(stream_2) stats = http.get_connection_info() assert stats == expected_during_idle @@ -355,12 +372,13 @@ def test_http_request_unix_domain_socket( method = b"GET" url = (b"http", b"localhost", None, b"/") headers = [(b"host", b"localhost")] - status_code, headers, stream, ext = http.request(method, url, headers) - assert status_code == 200 - reason = "OK" if uds_server.sends_reason else "" - assert ext == {"http_version": "HTTP/1.1", "reason": reason} - body = read_body(stream) - assert body == b"Hello, world!" + with http.request(method, url, headers) as response: + status_code, headers, stream, ext = response + assert status_code == 200 + expected_reason = "OK" if uds_server.sends_reason else "" + assert ext == {"http_version": "HTTP/1.1", "reason": expected_reason} + body = read_body(stream) + assert body == b"Hello, world!" @pytest.mark.parametrize("max_keepalive", [1, 3, 5]) @@ -376,19 +394,17 @@ def test_max_keepalive_connections_handled_correctly( url = (b"http", *server.netloc, b"/") headers = [server.host_header] - connections_streams = [] - for _ in range(connections_number): - _, _, stream, _ = http.request(method, url, headers) - connections_streams.append(stream) + with ExitStack() as exit_stack: + for _ in range(connections_number): + _, _, stream, _ = exit_stack.enter_context( + http.request(method, url, headers) + ) + exit_stack.callback(partial(read_body, stream)) - try: - for i in range(len(connections_streams)): - read_body(connections_streams[i]) - finally: - stats = http.get_connection_info() + stats = http.get_connection_info() - connections_in_pool = next(iter(stats.values())) - assert len(connections_in_pool) == min(connections_number, max_keepalive) + connections_in_pool = next(iter(stats.values())) + assert len(connections_in_pool) == min(connections_number, max_keepalive) @@ -397,8 +413,9 @@ def test_explicit_backend_name(server: Server) -> None: method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header] - status_code, headers, stream, ext = http.request(method, url, headers) - read_body(stream) + with http.request(method, url, headers) as response: + status_code, headers, stream, ext = response + read_body(stream) assert status_code == 200 reason = "OK" if server.sends_reason else "" @@ -425,10 +442,9 @@ def test_broken_socket_detection_many_open_files( # * Second attempt would have failed without a fix, due to a "filedescriptor # out of range in select()" exception. for _ in range(2): - status_code, response_headers, stream, ext = http.request( - method, url, headers - ) - read_body(stream) + with http.request(method, url, headers) as response: + status_code, response_headers, stream, ext = response + read_body(stream) assert status_code == 200 reason = "OK" if server.sends_reason else "" @@ -453,7 +469,8 @@ def test_cannot_connect_tcp(backend: str, url) -> None: with httpcore.SyncConnectionPool(backend=backend) as http: method = b"GET" with pytest.raises(httpcore.ConnectError): - http.request(method, url) + with http.request(method, url) as _: + pass # pragma: no cover @@ -466,7 +483,8 @@ def test_cannot_connect_uds(backend: str) -> None: url = (b"http", b"localhost", None, b"/") with httpcore.SyncConnectionPool(backend=backend, uds=uds) as http: with pytest.raises(httpcore.ConnectError): - http.request(method, url) + with http.request(method, url): + pass # pragma: no cover @pytest.mark.skipif( @@ -484,7 +502,8 @@ def test_connection_timeout_tcp(backend: str, server: Server) -> None: with httpcore.SyncConnectionPool(backend=backend) as http: with pytest.raises(httpcore.ConnectTimeout): - http.request(method, url, headers, ext=ext) + with http.request(method, url, headers, ext=ext): + pass # pragma: no cover @pytest.mark.skipif( @@ -504,4 +523,5 @@ def test_connection_timeout_uds( with httpcore.SyncConnectionPool(uds=uds, backend=backend) as http: with pytest.raises(httpcore.ConnectTimeout): - http.request(method, url, headers, ext=ext) + with http.request(method, url, headers, ext=ext): + pass # pragma: no cover diff --git a/tests/sync_tests/test_retries.py b/tests/sync_tests/test_retries.py index d6a5aa77..c1deaa93 100644 --- a/tests/sync_tests/test_retries.py +++ b/tests/sync_tests/test_retries.py @@ -1,6 +1,6 @@ import queue import time -from typing import Any, List, Optional +from typing import Any, Iterable, List, Optional import pytest @@ -32,11 +32,8 @@ def open_tcp_stream(self, *args: Any, **kwargs: Any) -> SyncSocketStream: return super().open_tcp_stream(*args, **kwargs) -def read_body(stream: httpcore.SyncByteStream) -> bytes: - try: - return b"".join([chunk for chunk in stream]) - finally: - stream.close() +def read_body(stream: Iterable[bytes]) -> bytes: + return b"".join([chunk for chunk in stream]) @@ -52,18 +49,20 @@ def test_no_retries(server: Server) -> None: with httpcore.SyncConnectionPool( max_keepalive_connections=0, backend=backend ) as http: - response = http.request(method, url, headers) - status_code, _, stream, _ = response - assert status_code == 200 - read_body(stream) + with http.request(method, url, headers) as response: + status_code, _, stream, _ = response + assert status_code == 200 + read_body(stream) backend.push(httpcore.ConnectTimeout(), httpcore.ConnectError()) with pytest.raises(httpcore.ConnectTimeout): - http.request(method, url, headers) + with http.request(method, url, headers) as response: + pass # pragma: no cover with pytest.raises(httpcore.ConnectError): - http.request(method, url, headers) + with http.request(method, url, headers) as response: + pass # pragma: no cover @@ -82,21 +81,21 @@ def test_retries_enabled(server: Server) -> None: retries=retries, max_keepalive_connections=0, backend=backend ) as http: # Standard case, no failures. - response = http.request(method, url, headers) - assert backend.pop_open_tcp_stream_intervals() == [] - status_code, _, stream, _ = response - assert status_code == 200 - read_body(stream) + with http.request(method, url, headers) as response: + assert backend.pop_open_tcp_stream_intervals() == [] + status_code, _, stream, _ = response + assert status_code == 200 + read_body(stream) # One failure, then success. backend.push(httpcore.ConnectError(), None) - response = http.request(method, url, headers) - assert backend.pop_open_tcp_stream_intervals() == [ - pytest.approx(0, abs=5e-3), # Retry immediately. - ] - status_code, _, stream, _ = response - assert status_code == 200 - read_body(stream) + with http.request(method, url, headers) as response: + assert backend.pop_open_tcp_stream_intervals() == [ + pytest.approx(0, abs=5e-3), # Retry immediately. + ] + status_code, _, stream, _ = response + assert status_code == 200 + read_body(stream) # Three failures, then success. backend.push( @@ -105,22 +104,25 @@ def test_retries_enabled(server: Server) -> None: httpcore.ConnectTimeout(), None, ) - response = http.request(method, url, headers) - assert backend.pop_open_tcp_stream_intervals() == [ - pytest.approx(0, abs=5e-3), # Retry immediately. - pytest.approx(0.5, rel=0.1), # First backoff. - pytest.approx(1.0, rel=0.1), # Second (increased) backoff. - ] - status_code, _, stream, _ = response - assert status_code == 200 - read_body(stream) + with http.request(method, url, headers) as response: + assert backend.pop_open_tcp_stream_intervals() == [ + pytest.approx(0, abs=5e-3), # Retry immediately. + pytest.approx(0.5, rel=0.1), # First backoff. + pytest.approx(1.0, rel=0.1), # Second (increased) backoff. + ] + status_code, _, stream, _ = response + assert status_code == 200 + read_body(stream) # Non-connect exceptions are not retried on. backend.push(httpcore.ReadTimeout(), httpcore.NetworkError()) with pytest.raises(httpcore.ReadTimeout): - http.request(method, url, headers) + with http.request(method, url, headers) as response: + pass # pragma: no cover + with pytest.raises(httpcore.NetworkError): - http.request(method, url, headers) + with http.request(method, url, headers) as response: + pass # pragma: no cover @@ -138,12 +140,13 @@ def test_retries_exceeded(server: Server) -> None: with httpcore.SyncConnectionPool( retries=retries, max_keepalive_connections=0, backend=backend ) as http: - response = http.request(method, url, headers) - status_code, _, stream, _ = response - assert status_code == 200 - read_body(stream) + with http.request(method, url, headers) as response: + status_code, _, stream, _ = response + assert status_code == 200 + read_body(stream) # First failure is retried on, second one isn't. backend.push(httpcore.ConnectError(), httpcore.ConnectTimeout()) with pytest.raises(httpcore.ConnectTimeout): - http.request(method, url, headers) + with http.request(method, url, headers) as response: + pass # pragma: no cover diff --git a/tests/test_threadsafety.py b/tests/test_threadsafety.py index 81cdd95f..d491833f 100644 --- a/tests/test_threadsafety.py +++ b/tests/test_threadsafety.py @@ -1,4 +1,5 @@ import concurrent.futures +from typing import Iterable import pytest @@ -7,11 +8,8 @@ from .utils import Server -def read_body(stream: httpcore.SyncByteStream) -> bytes: - try: - return b"".join(chunk for chunk in stream) - finally: - stream.close() +def read_body(stream: Iterable[bytes]) -> bytes: + return b"".join(chunk for chunk in stream) @pytest.mark.parametrize( @@ -30,8 +28,9 @@ def request(http: httpcore.SyncHTTPTransport) -> int: method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header] - status_code, headers, stream, ext = http.request(method, url, headers) - read_body(stream) + with http.request(method, url, headers) as response: + status_code, _, stream, _ = response + read_body(stream) return status_code with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: diff --git a/unasync.py b/unasync.py index d3b36993..6617bb22 100755 --- a/unasync.py +++ b/unasync.py @@ -4,8 +4,13 @@ import sys SUBS = [ - ('AsyncIteratorByteStream', 'IteratorByteStream'), ('AsyncIterator', 'Iterator'), + ('AsyncIterable', 'Iterable'), + ('asynccontextmanager', 'contextmanager'), + ('AsyncContextManager', 'ContextManager'), + ('AsyncExitStack', 'ExitStack'), + ('enter_async_context', 'enter_context'), + ('push_async_callback', 'callback'), ('AutoBackend', 'SyncBackend'), ('Async([A-Z][A-Za-z0-9_]*)', r'Sync\2'), ('async def', 'def'),