Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement curio backend #168

Merged
merged 10 commits into from
Sep 5, 2020
4 changes: 4 additions & 0 deletions httpcore/_backends/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ def backend(self) -> AsyncBackend:
from .trio import TrioBackend

self._backend_implementation = TrioBackend()
elif backend == "curio":
from .curio import CurioBackend

self._backend_implementation = CurioBackend()
else: # pragma: nocover
raise RuntimeError(f"Unsupported concurrency backend {backend!r}")
return self._backend_implementation
Expand Down
202 changes: 202 additions & 0 deletions httpcore/_backends/curio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
import select
from ssl import SSLContext, SSLSocket
from typing import Optional

import curio
import curio.io

from .._exceptions import (
ConnectError,
ConnectTimeout,
ReadError,
ReadTimeout,
WriteError,
WriteTimeout,
map_exceptions,
)
from .._types import TimeoutDict
from .._utils import get_logger
from .base import AsyncBackend, AsyncLock, AsyncSemaphore, AsyncSocketStream

logger = get_logger(__name__)

ONE_DAY_IN_SECONDS = float(60 * 60 * 24)


def convert_timeout(value: Optional[float]) -> float:
return value if value is not None else ONE_DAY_IN_SECONDS


class Lock(AsyncLock):
def __init__(self) -> None:
self._lock = curio.Lock()

async def acquire(self) -> None:
await self._lock.acquire()

async def release(self) -> None:
await self._lock.release()


class Semaphore(AsyncSemaphore):
def __init__(self, max_value: int, exc_class: type) -> None:
self.max_value = max_value
self.exc_class = exc_class

@property
def semaphore(self) -> curio.Semaphore:
if not hasattr(self, "_semaphore"):
self._semaphore = curio.Semaphore(value=self.max_value)
return self._semaphore

async def acquire(self, timeout: float = None) -> None:
timeout = convert_timeout(timeout)

try:
return await curio.timeout_after(timeout, self.semaphore.acquire())
except curio.TaskTimeout:
raise self.exc_class()

async def release(self) -> None:
await self.semaphore.release()


class SocketStream(AsyncSocketStream):
def __init__(self, socket: curio.io.Socket) -> None:
self.read_lock = curio.Lock()
self.write_lock = curio.Lock()
self.socket = socket
self.stream = socket.as_stream()

def get_http_version(self) -> str:
if hasattr(self.socket, "_socket"):
raw_socket = self.socket._socket

if isinstance(raw_socket, SSLSocket):
ident = raw_socket.selected_alpn_protocol()
return "HTTP/2" if ident == "h2" else "HTTP/1.1"

return "HTTP/1.1"

async def start_tls(
self, hostname: bytes, ssl_context: SSLContext, timeout: TimeoutDict
) -> "AsyncSocketStream":
connect_timeout = convert_timeout(timeout.get("connect"))
exc_map = {
curio.TaskTimeout: ConnectTimeout,
curio.CurioError: ConnectError,
OSError: ConnectError,
}

with map_exceptions(exc_map):
wrapped_sock = curio.io.Socket(
ssl_context.wrap_socket(
self.socket._socket,
do_handshake_on_connect=False,
server_hostname=hostname.decode("ascii"),
)
)

await curio.timeout_after(
connect_timeout,
wrapped_sock.do_handshake(),
)

return SocketStream(wrapped_sock)

async def read(self, n: int, timeout: TimeoutDict) -> bytes:
read_timeout = convert_timeout(timeout.get("read"))
exc_map = {
curio.TaskTimeout: ReadTimeout,
curio.CurioError: ReadError,
OSError: ReadError,
}

with map_exceptions(exc_map):
async with self.read_lock:
return await curio.timeout_after(read_timeout, self.stream.read(n))

async def write(self, data: bytes, timeout: TimeoutDict) -> None:
write_timeout = convert_timeout(timeout.get("write"))
exc_map = {
curio.TaskTimeout: WriteTimeout,
curio.CurioError: WriteError,
OSError: WriteError,
}

with map_exceptions(exc_map):
async with self.write_lock:
await curio.timeout_after(write_timeout, self.stream.write(data))

async def aclose(self) -> None:
await self.stream.close()
await self.socket.close()

def is_connection_dropped(self) -> bool:
rready, _, _ = select.select([self.socket.fileno()], [], [], 0)

return bool(rready)


class CurioBackend(AsyncBackend):
async def open_tcp_stream(
self,
hostname: bytes,
port: int,
ssl_context: Optional[SSLContext],
timeout: TimeoutDict,
*,
local_address: Optional[str],
) -> AsyncSocketStream:
connect_timeout = convert_timeout(timeout.get("connect"))
exc_map = {
curio.TaskTimeout: ConnectTimeout,
curio.CurioError: ConnectError,
OSError: ConnectError,
}
host = hostname.decode("ascii")
kwargs = (
{} if ssl_context is None else {"ssl": ssl_context, "server_hostname": host}
)

with map_exceptions(exc_map):
sock: curio.io.Socket = await curio.timeout_after(
connect_timeout,
curio.open_connection(hostname, port, **kwargs),
)

return SocketStream(sock)

async def open_uds_stream(
self,
path: str,
hostname: bytes,
ssl_context: Optional[SSLContext],
timeout: TimeoutDict,
) -> AsyncSocketStream:
connect_timeout = convert_timeout(timeout.get("connect"))
exc_map = {
curio.TaskTimeout: ConnectTimeout,
curio.CurioError: ConnectError,
OSError: ConnectError,
}
host = hostname.decode("ascii")
kwargs = (
{} if ssl_context is None else {"ssl": ssl_context, "server_hostname": host}
)

with map_exceptions(exc_map):
sock: curio.io.Socket = await curio.timeout_after(
connect_timeout, curio.open_unix_connection(path, **kwargs)
)

return SocketStream(sock)

def create_lock(self) -> AsyncLock:
return Lock()

def create_semaphore(self, max_value: int, exc_class: type) -> AsyncSemaphore:
return Semaphore(max_value, exc_class)

async def time(self) -> float:
return await curio.clock()
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Optionals
trio
trio-typing
curio

# Docs
mkautodoc
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def get_packages(package):
"Topic :: Internet :: WWW/HTTP",
"Framework :: AsyncIO",
"Framework :: Trio",
"Framework :: Curio",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
Expand Down
26 changes: 23 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,42 @@

from httpcore._types import URL

from .marks.curio import curio_pytest_pycollect_makeitem, curio_pytest_pyfunc_call

PROXY_HOST = "127.0.0.1"
PROXY_PORT = 8080


def pytest_configure(config):
config.addinivalue_line(
"markers",
"curio: mark the test as a coroutine, it will be run using a Curio kernel.",
)


@pytest.mark.tryfirst
def pytest_pycollect_makeitem(collector, name, obj):
curio_pytest_pycollect_makeitem(collector, name, obj)


@pytest.hookimpl(tryfirst=True, hookwrapper=True)
def pytest_pyfunc_call(pyfuncitem):
yield from curio_pytest_pyfunc_call(pyfuncitem)
florimondmanca marked this conversation as resolved.
Show resolved Hide resolved


@pytest.fixture(
params=[
pytest.param("asyncio", marks=pytest.mark.asyncio),
pytest.param("trio", marks=pytest.mark.trio),
pytest.param("curio", marks=pytest.mark.curio),
]
)
def async_environment(request: typing.Any) -> str:
"""
Mark a test function to be run on both asyncio and trio.
Mark a test function to be run on asyncio, trio and curio.
florimondmanca marked this conversation as resolved.
Show resolved Hide resolved

Equivalent to having a pair of tests, each respectively marked with
'@pytest.mark.asyncio' and '@pytest.mark.trio'.
Equivalent to having three tests, each respectively marked with
'@pytest.mark.asyncio', '@pytest.mark.trio' and '@pytest.mark.curio'.

Intended usage:

Expand Down
Empty file added tests/marks/__init__.py
Empty file.
42 changes: 42 additions & 0 deletions tests/marks/curio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import functools
import inspect

import curio
import curio.debug
import curio.meta
import curio.monitor
import pytest


def _is_coroutine(obj):
"""Check to see if an object is really a coroutine."""
return curio.meta.iscoroutinefunction(obj) or inspect.isgeneratorfunction(obj)


@pytest.mark.tryfirst
def curio_pytest_pycollect_makeitem(collector, name, obj):
"""A pytest hook to collect coroutines in a test module."""
if collector.funcnamefilter(name) and _is_coroutine(obj):
item = pytest.Function.from_parent(collector, name=name)
if "curio" in item.keywords:
return list(collector._genfunctions(name, obj)) # pragma: nocover


@pytest.hookimpl(tryfirst=True, hookwrapper=True)
def curio_pytest_pyfunc_call(pyfuncitem):
"""Run curio marked test functions in a Curio kernel
instead of a normal function call."""
if pyfuncitem.get_closest_marker("curio"):
pyfuncitem.obj = wrap_in_sync(pyfuncitem.obj)
yield


def wrap_in_sync(func):
"""Return a sync wrapper around an async function executing it in a Kernel."""

@functools.wraps(func)
def inner(**kwargs):
coro = func(**kwargs)
curio.Kernel().run(coro, shutdown=True)

return inner
1 change: 1 addition & 0 deletions unasync.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
('__aiter__', '__iter__'),
('@pytest.mark.asyncio', ''),
('@pytest.mark.trio', ''),
('@pytest.mark.curio', ''),
('@pytest.mark.usefixtures.*', ''),
]
COMPILED_SUBS = [
Expand Down