Skip to content

Commit

Permalink
Fixed PR remarks (#94)
Browse files Browse the repository at this point in the history
Added timeout handling to Semaphore::acquire and tried to avoid private API usage in SocketStream::get_http_version, also changed is_connection_dropped behaviour
  • Loading branch information
cdeler committed Sep 1, 2020
1 parent 659bd94 commit dc306d0
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 10 deletions.
33 changes: 24 additions & 9 deletions httpcore/_backends/curio.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from ssl import SSLContext
from typing import Optional
import select
import socket
from ssl import SSLContext, SSLSocket
from typing import Dict, Optional, Type, Union

import curio
import curio.io
Expand All @@ -10,6 +12,7 @@
ConnectTimeout,
ReadError,
ReadTimeout,
TimeoutException,
WriteError,
WriteTimeout,
map_exceptions,
Expand Down Expand Up @@ -50,7 +53,13 @@ def semaphore(self) -> curio.Semaphore:
return self._semaphore

async def acquire(self, timeout: float = None) -> None:
await self.semaphore.acquire()
exc_map: Dict[Type[Exception], Type[Exception]] = {
curio.TaskTimeout: TimeoutException,
}
acquire_timeout: int = convert_timeout(timeout)

with map_exceptions(exc_map):
return await curio.timeout_after(acquire_timeout, self.semaphore.acquire())

async def release(self) -> None:
await self.semaphore.release()
Expand All @@ -64,10 +73,14 @@ def __init__(self, socket: curio.io.Socket) -> None:
self.stream = socket.as_stream()

def get_http_version(self) -> str:
if hasattr(self.socket, "_socket") and hasattr(self.socket._socket, "_sslobj"):
ident = self.socket._socket._sslobj.selected_alpn_protocol()
else:
ident = "http/1.1"
ident: Optional[str] = "http/1.1"

if hasattr(self.socket, "_socket"):
raw_socket: Union[SSLSocket, socket.socket] = self.socket._socket

if isinstance(raw_socket, SSLSocket):
ident = raw_socket.selected_alpn_protocol()

return "HTTP/2" if ident == "h2" else "HTTP/1.1"

async def start_tls(
Expand Down Expand Up @@ -118,11 +131,13 @@ async def write(self, data: bytes, timeout: TimeoutDict) -> None:
await curio.timeout_after(write_timeout, self.stream.write(data))

async def aclose(self) -> None:
# we dont need to close the self.socket, since it's closed by stream closing
await self.stream.close()
await self.socket.close()

def is_connection_dropped(self) -> bool:
return self.socket._closed
rready, _, _ = select.select([self.socket.fileno()], [], [], 0)

return bool(rready)


class CurioBackend(AsyncBackend):
Expand Down
2 changes: 1 addition & 1 deletion tests/marks/curio.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def curio_pytest_pycollect_makeitem(collector, name, obj):
if collector.funcnamefilter(name) and _is_coroutine(obj):
item = pytest.Function.from_parent(collector, name=name)
if "curio" in item.keywords:
return list(collector._genfunctions(name, obj))
return list(collector._genfunctions(name, obj)) # pragma: nocover


@pytest.hookimpl(tryfirst=True, hookwrapper=True)
Expand Down

0 comments on commit dc306d0

Please sign in to comment.