diff --git a/httpcore/_backends/curio.py b/httpcore/_backends/curio.py index c959ddbb..ef9a11c1 100644 --- a/httpcore/_backends/curio.py +++ b/httpcore/_backends/curio.py @@ -1,5 +1,7 @@ -from ssl import SSLContext -from typing import Optional +import select +import socket +from ssl import SSLContext, SSLSocket +from typing import Dict, Optional, Type, Union import curio import curio.io @@ -10,6 +12,7 @@ ConnectTimeout, ReadError, ReadTimeout, + TimeoutException, WriteError, WriteTimeout, map_exceptions, @@ -50,7 +53,13 @@ def semaphore(self) -> curio.Semaphore: return self._semaphore async def acquire(self, timeout: float = None) -> None: - await self.semaphore.acquire() + exc_map: Dict[Type[Exception], Type[Exception]] = { + curio.TaskTimeout: TimeoutException, + } + acquire_timeout: int = convert_timeout(timeout) + + with map_exceptions(exc_map): + return await curio.timeout_after(acquire_timeout, self.semaphore.acquire()) async def release(self) -> None: await self.semaphore.release() @@ -64,10 +73,14 @@ def __init__(self, socket: curio.io.Socket) -> None: self.stream = socket.as_stream() def get_http_version(self) -> str: - if hasattr(self.socket, "_socket") and hasattr(self.socket._socket, "_sslobj"): - ident = self.socket._socket._sslobj.selected_alpn_protocol() - else: - ident = "http/1.1" + ident: Optional[str] = "http/1.1" + + if hasattr(self.socket, "_socket"): + raw_socket: Union[SSLSocket, socket.socket] = self.socket._socket + + if isinstance(raw_socket, SSLSocket): + ident = raw_socket.selected_alpn_protocol() + return "HTTP/2" if ident == "h2" else "HTTP/1.1" async def start_tls( @@ -118,11 +131,13 @@ async def write(self, data: bytes, timeout: TimeoutDict) -> None: await curio.timeout_after(write_timeout, self.stream.write(data)) async def aclose(self) -> None: - # we dont need to close the self.socket, since it's closed by stream closing await self.stream.close() + await self.socket.close() def is_connection_dropped(self) -> bool: - return self.socket._closed + rready, _, _ = select.select([self.socket.fileno()], [], [], 0) + + return bool(rready) class CurioBackend(AsyncBackend): diff --git a/tests/marks/curio.py b/tests/marks/curio.py index 7504ee8f..616b2766 100644 --- a/tests/marks/curio.py +++ b/tests/marks/curio.py @@ -19,7 +19,7 @@ def curio_pytest_pycollect_makeitem(collector, name, obj): if collector.funcnamefilter(name) and _is_coroutine(obj): item = pytest.Function.from_parent(collector, name=name) if "curio" in item.keywords: - return list(collector._genfunctions(name, obj)) + return list(collector._genfunctions(name, obj)) # pragma: nocover @pytest.hookimpl(tryfirst=True, hookwrapper=True)