Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add socket-load-balance flag #2472

Draft
wants to merge 13 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions docs/deployment.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ Options:
buffer of an incomplete event.
--factory Treat APP as an application factory, i.e. a
() -> <ASGI app> callable.
--socket-load-balance Use kernel support for socket load balancing
--help Show this message and exit.
```

Expand Down
1 change: 1 addition & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ Options:
buffer of an incomplete event.
--factory Treat APP as an application factory, i.e. a
() -> <ASGI app> callable.
--socket-load-balance Use kernel support for socket load balancing
--help Show this message and exit.
```

Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ filterwarnings = [
source_pkgs = ["uvicorn", "tests"]
plugins = ["coverage_conditional_plugin"]
omit = ["uvicorn/workers.py", "uvicorn/__main__.py"]
concurrency = ["multiprocessing", "thread"]
parallel = true
sigterm = true

[tool.coverage.report]
precision = 2
Expand Down
1 change: 1 addition & 0 deletions scripts/coverage
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ export SOURCE_FILES="uvicorn tests"

set -x

${PREFIX}coverage combine
${PREFIX}coverage report
75 changes: 74 additions & 1 deletion tests/supervisors/test_multiprocess.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
from __future__ import annotations

import dataclasses
import functools
import multiprocessing.managers
import os
import signal
import socket
import sys
import threading
import time
from typing import Any, Callable
from typing import Any, Callable, Generic, TypeVar

import httpx
import pytest

from uvicorn import Config
from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope
from uvicorn.server import Server
from uvicorn.supervisors import Multiprocess
from uvicorn.supervisors.multiprocess import Process

Expand Down Expand Up @@ -169,3 +174,71 @@ def test_multiprocess_sigttou() -> None:
assert len(supervisor.processes) == 1
supervisor.signal_queue.append(signal.SIGINT)
supervisor.join_all()


T = TypeVar("T")


@dataclasses.dataclass
class Box(Generic[T]):
v: T


async def lb_app(
d: multiprocessing.managers.DictProxy,
started: threading.Event,
scope: Scope,
receive: ASGIReceiveCallable,
send: ASGISendCallable,
) -> None: # pragma: py-darwin pragma: py-win32
if scope["type"] == "lifespan":
await receive()
scope["state"]["count"] = box = Box(0)
await send({"type": "lifespan.startup.complete"})
started.set()
await receive()
d[os.getpid()] = box.v
await send({"type": "lifespan.shutdown.complete"})
return

scope["state"]["count"].v += 1
headers = [(b"content-type", b"text/plain")]
await send({"type": "http.response.start", "status": 200, "headers": headers})
await send({"type": "http.response.body", "body": b"hello"})


@pytest.mark.skipif(
not ((sys.platform == "linux" and hasattr(socket, "SO_REUSEPORT")) or hasattr(socket, "SO_REUSEPORT_LB")),
reason="unsupported",
)
def test_multiprocess_socket_balance() -> None: # pragma: py-darwin pragma: py-win32
with multiprocessing.Manager() as m:
started = m.Event()
d = m.dict()
app = functools.partial(lb_app, d, started)
config = Config(app=app, workers=2, socket_load_balance=True, port=0, interface="asgi3")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this needs a test for two independent calls to Multiprocess(...).run() running on the same port:

def test_bind_to_used_port():
    port = ephemeral_port_reserve.reserve() 
    config = Config(app=app, workers=2, socket_load_balance=True, port=port, interface="asgi3")
    with multiprocessing.get_context("spawn").Pool(max_workers=1) as pool:
        try:
            f1 = pool.apply_async(uvicorn.main, ...)
            # wait for started
            with pytest.raises(OSError, match="already bound"):
                uvicorn.main(...)
        finally: pool.terminate()

server = Server(config=config)
with config.bind_socket() as sock:
port = sock.getsockname()[1]
try:
supervisor = Multiprocess(config, target=server.run, sockets=[sock])
threading.Thread(target=supervisor.run, daemon=True).start()
if not started.wait(timeout=5): # pragma: no cover
raise TimeoutError
with httpx.Client():
for i in range(100):
httpx.get(f"http://localhost:{port}/").raise_for_status()
finally:
supervisor.signal_queue.append(signal.SIGINT)
supervisor.join_all()
min_conn, max_conn = sorted(d.values())
assert (max_conn - min_conn) < 25


def test_multiprocess_not_supported(monkeypatch):
monkeypatch.delattr(socket, "SO_REUSEPORT")
config = Config(app=app, workers=2, socket_load_balance=True, port=0, interface="asgi3")
with config.bind_socket() as sock:
supervisor = Multiprocess(config, target=run, sockets=[sock])
with pytest.raises(RuntimeError, match="socket_load_balance not supported"):
supervisor.run()
4 changes: 2 additions & 2 deletions tests/supervisors/test_reload.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def test_should_not_reload_when_exclude_pattern_match_file_is_changed(self, touc
app="tests.test_config:asgi_app",
reload=True,
reload_includes=["*"],
reload_excludes=["*.js"],
reload_excludes=["*.js", ".coverage.*"],
)
reloader = self._setup_reloader(config)

Expand Down Expand Up @@ -242,7 +242,7 @@ def test_override_defaults(self, touch_soon) -> None:
reload=True,
# We need to add *.txt otherwise no regular files will match
reload_includes=[".*", "*.txt"],
reload_excludes=["*.py"],
reload_excludes=["*.py", ".coverage.*"],
)
reloader = self._setup_reloader(config)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import socket
from unittest.mock import patch

from uvicorn._subprocess import SpawnProcess, get_subprocess, subprocess_started
from uvicorn._subprocess import SocketSharePickle, SpawnProcess, get_subprocess, subprocess_started
from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope
from uvicorn.config import Config

Expand Down Expand Up @@ -36,7 +36,7 @@ def test_subprocess_started() -> None:

with patch("tests.test_subprocess.server_run") as mock_run:
with patch.object(config, "configure_logging") as mock_config_logging:
subprocess_started(config, server_run, [fdsock], None)
subprocess_started(config, server_run, [SocketSharePickle(fdsock)], None)
mock_run.assert_called_once()
mock_config_logging.assert_called_once()

Expand Down
47 changes: 42 additions & 5 deletions uvicorn/_subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

import multiprocessing
import os
import socket
import sys
from multiprocessing.context import SpawnProcess
from socket import socket
from typing import Callable

from uvicorn.config import Config
Expand All @@ -18,10 +18,42 @@
spawn = multiprocessing.get_context("spawn")


class SocketSharePickle:
def __init__(self, sock: socket.socket):
self._sock = sock

def get(self) -> socket.socket:
return self._sock


class SocketShareRebind:
def __init__(self, sock: socket.socket):
if not (sys.platform == "linux" and hasattr(socket, "SO_REUSEPORT")) or hasattr(socket, "SO_REUSEPORT_LB"):
raise RuntimeError("socket_load_balance not supported")
else: # pragma: py-darwin pragma: py-win32
sock.setsockopt(socket.SOL_SOCKET, getattr(socket, "SO_REUSEPORT_LB", socket.SO_REUSEPORT), 1)
self._family = sock.family
self._type = sock.type
self._proto = sock.proto
self._sockname = sock.getsockname()

def get(self) -> socket.socket: # pragma: py-darwin pragma: py-win32
try:
sock = socket.socket(family=self._family, type=self._type, proto=self._proto)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.setsockopt(socket.SOL_SOCKET, getattr(socket, "SO_REUSEPORT_LB", socket.SO_REUSEPORT), 1)

sock.bind(self._sockname)
return sock
except BaseException: # pragma: no cover
sock.close()
raise


def get_subprocess(
config: Config,
target: Callable[..., None],
sockets: list[socket],
sockets: list[socket.socket],
) -> SpawnProcess:
"""
Called in the parent process, to instantiate a new child process instance.
Expand All @@ -41,10 +73,15 @@ def get_subprocess(
except (AttributeError, OSError):
stdin_fileno = None

socket_shares: list[SocketShareRebind] | list[SocketSharePickle]
if config.socket_load_balance: # pragma: py-darwin pragma: py-win32
socket_shares = [SocketShareRebind(s) for s in sockets]
else:
socket_shares = [SocketSharePickle(s) for s in sockets]
kwargs = {
"config": config,
"target": target,
"sockets": sockets,
"sockets": socket_shares,
"stdin_fileno": stdin_fileno,
}

Expand All @@ -54,7 +91,7 @@ def get_subprocess(
def subprocess_started(
config: Config,
target: Callable[..., None],
sockets: list[socket],
sockets: list[SocketSharePickle] | list[SocketShareRebind],
stdin_fileno: int | None,
) -> None:
"""
Expand All @@ -77,7 +114,7 @@ def subprocess_started(

try:
# Now we can call into `Server.run(sockets=sockets)`
target(sockets=sockets)
target(sockets=[s.get() for s in sockets])
except KeyboardInterrupt: # pragma: no cover
# supress the exception to avoid a traceback from subprocess.Popen
# the parent already expects us to end, so no vital information is lost
Expand Down
2 changes: 2 additions & 0 deletions uvicorn/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ def __init__(
headers: list[tuple[str, str]] | None = None,
factory: bool = False,
h11_max_incomplete_event_size: int | None = None,
socket_load_balance: bool = False,
):
self.app = app
self.host = host
Expand Down Expand Up @@ -268,6 +269,7 @@ def __init__(
self.encoded_headers: list[tuple[bytes, bytes]] = []
self.factory = factory
self.h11_max_incomplete_event_size = h11_max_incomplete_event_size
self.socket_load_balance = socket_load_balance

self.loaded = False
self.configure_logging()
Expand Down
19 changes: 15 additions & 4 deletions uvicorn/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,13 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No
help="Treat APP as an application factory, i.e. a () -> <ASGI app> callable.",
show_default=True,
)
@click.option(
"--socket-load-balance",
is_flag=True,
default=False,
help="Use kernel support for socket load balancing",
show_default=True,
)
def main(
app: str,
host: str,
Expand Down Expand Up @@ -408,6 +415,7 @@ def main(
app_dir: str,
h11_max_incomplete_event_size: int | None,
factory: bool,
socket_load_balance: bool = False,
) -> None:
run(
app,
Expand Down Expand Up @@ -457,6 +465,7 @@ def main(
factory=factory,
app_dir=app_dir,
h11_max_incomplete_event_size=h11_max_incomplete_event_size,
socket_load_balance=socket_load_balance,
)


Expand Down Expand Up @@ -509,6 +518,7 @@ def run(
app_dir: str | None = None,
factory: bool = False,
h11_max_incomplete_event_size: int | None = None,
socket_load_balance: bool = False,
) -> None:
if app_dir is not None:
sys.path.insert(0, app_dir)
Expand Down Expand Up @@ -560,6 +570,7 @@ def run(
use_colors=use_colors,
factory=factory,
h11_max_incomplete_event_size=h11_max_incomplete_event_size,
socket_load_balance=socket_load_balance,
)
server = Server(config=config)

Expand All @@ -570,11 +581,11 @@ def run(

try:
if config.should_reload:
sock = config.bind_socket()
ChangeReload(config, target=server.run, sockets=[sock]).run()
with config.bind_socket() as sock:
ChangeReload(config, target=server.run, sockets=[sock]).run()
elif config.workers > 1:
sock = config.bind_socket()
Multiprocess(config, target=server.run, sockets=[sock]).run()
with config.bind_socket() as sock:
Multiprocess(config, target=server.run, sockets=[sock]).run()
else:
server.run()
except KeyboardInterrupt:
Expand Down
Loading