diff --git a/tests/async_tests/test_interfaces.py b/tests/async_tests/test_interfaces.py index f4ba9e40..8603a854 100644 --- a/tests/async_tests/test_interfaces.py +++ b/tests/async_tests/test_interfaces.py @@ -191,6 +191,36 @@ async def test_http_proxy( assert ext == {"http_version": "HTTP/1.1", "reason": "OK"} +@pytest.mark.parametrize("proxy_mode", ["DEFAULT", "FORWARD_ONLY", "TUNNEL_ONLY"]) +@pytest.mark.parametrize("protocol,port", [(b"http", 80), (b"https", 443)]) +@pytest.mark.trio +async def test_proxy_socket_does_not_leak_when_the_connection_hasnt_been_added_to_pool( + proxy_server: URL, + server: Server, + proxy_mode: str, + protocol: bytes, + port: int, +): + method = b"GET" + url = (protocol, b"example.com", port, b"/") + headers = [(b"host", b"example.org")] + + with pytest.warns(None) as recorded_warnings: + async with httpcore.AsyncHTTPProxy(proxy_server, proxy_mode=proxy_mode) as http: + for _ in range(100): + try: + _ = await http.arequest(method, url, headers) + except httpcore.RemoteProtocolError: + pass + + # have to filter out https://github.com/encode/httpx/issues/825 from other tests + warnings_list = [ + *filter(lambda warn: "asyncio" not in warn.filename, recorded_warnings.list) + ] + + assert len(warnings_list) == 0 + + @pytest.mark.anyio async def test_http_request_local_address(backend: str, server: Server) -> None: if backend == "auto" and lookup_async_backend() == "trio": diff --git a/tests/sync_tests/test_interfaces.py b/tests/sync_tests/test_interfaces.py index 1e128290..5d384c00 100644 --- a/tests/sync_tests/test_interfaces.py +++ b/tests/sync_tests/test_interfaces.py @@ -191,6 +191,36 @@ def test_http_proxy( assert ext == {"http_version": "HTTP/1.1", "reason": "OK"} +@pytest.mark.parametrize("proxy_mode", ["DEFAULT", "FORWARD_ONLY", "TUNNEL_ONLY"]) +@pytest.mark.parametrize("protocol,port", [(b"http", 80), (b"https", 443)]) + +def test_proxy_socket_does_not_leak_when_the_connection_hasnt_been_added_to_pool( + proxy_server: URL, + server: Server, + proxy_mode: str, + protocol: bytes, + port: int, +): + method = b"GET" + url = (protocol, b"example.com", port, b"/") + headers = [(b"host", b"example.org")] + + with pytest.warns(None) as recorded_warnings: + with httpcore.SyncHTTPProxy(proxy_server, proxy_mode=proxy_mode) as http: + for _ in range(100): + try: + _ = http.request(method, url, headers) + except httpcore.RemoteProtocolError: + pass + + # have to filter out https://github.com/encode/httpx/issues/825 from other tests + warnings_list = [ + *filter(lambda warn: "asyncio" not in warn.filename, recorded_warnings.list) + ] + + assert len(warnings_list) == 0 + + def test_http_request_local_address(backend: str, server: Server) -> None: if backend == "sync" and lookup_sync_backend() == "trio": diff --git a/tests/utils.py b/tests/utils.py index a6762bde..bb6a125d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,8 +1,9 @@ import contextlib import socket import subprocess +import tempfile import time -from typing import Tuple +from typing import List, Tuple import sniffio @@ -48,13 +49,34 @@ def host_header(self) -> Tuple[bytes, bytes]: def http_proxy_server(proxy_host: str, proxy_port: int): proc = None - try: - command = ["pproxy", "-l", f"http://{proxy_host}:{proxy_port}/"] - proc = subprocess.Popen(command) - _wait_can_connect(proxy_host, proxy_port) + with create_proxy_block_file(["example.com"]) as block_file_name: + try: + command = [ + "pproxy", + "-b", + block_file_name, + "-l", + f"http://{proxy_host}:{proxy_port}/", + ] + proc = subprocess.Popen(command) + + _wait_can_connect(proxy_host, proxy_port) + + yield b"http", proxy_host.encode(), proxy_port, b"/" + finally: + if proc is not None: + proc.kill() + + +@contextlib.contextmanager +def create_proxy_block_file(blocked_domains: List[str]): + with tempfile.NamedTemporaryFile(delete=True, mode="w+") as file: + + for domain in blocked_domains: + file.write(domain) + file.write("\n") + + file.flush() - yield b"http", proxy_host.encode(), proxy_port, b"/" - finally: - if proc is not None: - proc.kill() + yield file.name