diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 00000000..41bf0c98 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,32 @@ +name: lint + +on: + pull_request: + types: [opened, synchronize] + branches: + - master + +env: + MATURIN_VERSION: 1.2.3 + PYTHON_VERSION: 3.11 + +jobs: + lint: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ env.PYTHON_VERSION }} + uses: actions/setup-python@v4 + with: + python-version: ${{ env.PYTHON_VERSION }} + - name: Install + run: | + python -m venv .venv + source .venv/bin/activate + pip install maturin==${{ env.MATURIN_VERSION }} + maturin develop --extras=lint + - name: Lint + run: | + source .venv/bin/activate + make lint diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..883e8a08 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,26 @@ +fail_fast: true + +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.0.1 + hooks: + - id: check-yaml + - id: check-toml + - id: end-of-file-fixer + - id: trailing-whitespace + - id: check-added-large-files + +- repo: local + hooks: + - id: lint-python + name: Lint Python + entry: make lint-python + types: [python] + language: system + pass_filenames: false + - id: lint-rust + name: Lint Rust + entry: make lint-rust + types: [rust] + language: system + pass_filenames: false diff --git a/.rustfmt.toml b/.rustfmt.toml new file mode 100644 index 00000000..75306517 --- /dev/null +++ b/.rustfmt.toml @@ -0,0 +1 @@ +max_width = 120 diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..32b26753 --- /dev/null +++ b/Makefile @@ -0,0 +1,67 @@ +.DEFAULT_GOAL := all +black = black granian tests +ruff = ruff granian tests + +.PHONY: build-dev +build-dev: + @rm -f granian/*.so + maturin develop --extras test + +.PHONY: format +format: + $(black) + $(ruff) --fix --exit-zero + cargo fmt + +.PHONY: lint-python +lint-python: + $(ruff) + $(black) --check --diff + +.PHONY: lint-rust +lint-rust: + cargo fmt --version + cargo fmt --all -- --check + cargo clippy --version + cargo clippy --tests -- \ + -D warnings \ + -W clippy::pedantic \ + -W clippy::dbg_macro \ + -W clippy::print_stdout \ + -A clippy::cast-possible-truncation \ + -A clippy::cast-possible-wrap \ + -A clippy::cast-precision-loss \ + -A clippy::cast-sign-loss \ + -A clippy::declare-interior-mutable-const \ + -A clippy::float-cmp \ + -A clippy::fn-params-excessive-bools \ + -A clippy::if-not-else \ + -A clippy::inline-always \ + -A clippy::manual-let-else \ + -A clippy::match-bool \ + -A clippy::match-same-arms \ + -A clippy::missing-errors-doc \ + -A clippy::missing-panics-doc \ + -A clippy::module-name-repetitions \ + -A clippy::must-use-candidate \ + -A clippy::needless-pass-by-value \ + -A clippy::similar-names \ + -A clippy::single-match-else \ + -A clippy::struct-excessive-bools \ + -A clippy::too-many-arguments \ + -A clippy::too-many-lines \ + -A clippy::type-complexity \ + -A clippy::unnecessary-wraps \ + -A clippy::unused-self \ + -A clippy::used-underscore-binding \ + -A clippy::wrong-self-convention + +.PHONY: lint +lint: lint-python lint-rust + +.PHONY: test +test: + pytest -v test + +.PHONY: all +all: format build-dev lint test diff --git a/granian/__init__.py b/granian/__init__.py index ce03110a..37ebb755 100644 --- a/granian/__init__.py +++ b/granian/__init__.py @@ -1 +1 @@ -from .server import Granian +from .server import Granian # noqa diff --git a/granian/__main__.py b/granian/__main__.py index c368da26..f1f1f304 100644 --- a/granian/__main__.py +++ b/granian/__main__.py @@ -1,3 +1,4 @@ from granian.cli import cli + cli() diff --git a/granian/__version__.py b/granian/__version__.py index 906d362f..ef7eb44d 100644 --- a/granian/__version__.py +++ b/granian/__version__.py @@ -1 +1 @@ -__version__ = "0.6.0" +__version__ = '0.6.0' diff --git a/granian/_granian.pyi b/granian/_granian.pyi index 797d95f0..836b1283 100644 --- a/granian/_granian.pyi +++ b/granian/_granian.pyi @@ -1,12 +1,10 @@ -from typing import Any, Dict, List, Tuple, Optional +from typing import Any, Dict, List, Optional, Tuple from ._types import WebsocketMessage - class ASGIScope: def as_dict(self, root_path: str) -> Dict[str, Any]: ... - class RSGIHeaders: def __contains__(self, key: str) -> bool: ... def keys(self) -> List[str]: ... @@ -14,7 +12,6 @@ class RSGIHeaders: def items(self) -> List[Tuple[str]]: ... def get(self, key: str, default: Any = None) -> Any: ... - class RSGIScope: proto: str http_version: str @@ -29,12 +26,10 @@ class RSGIScope: @property def headers(self) -> RSGIHeaders: ... - class RSGIHTTPStreamTransport: async def send_bytes(self, data: bytes): ... async def send_str(self, data: str): ... - class RSGIHTTPProtocol: async def __call__(self) -> bytes: ... def response_empty(self, status: int, headers: List[Tuple[str, str]]): ... @@ -43,25 +38,17 @@ class RSGIHTTPProtocol: def response_file(self, status: int, headers: List[Tuple[str, str]], file: str): ... def response_stream(self, status: int, headers: List[Tuple[str, str]]) -> RSGIHTTPStreamTransport: ... - class RSGIWebsocketTransport: async def receive(self) -> WebsocketMessage: ... async def send_bytes(self, data: bytes): ... async def send_str(self, data: str): ... - class RSGIWebsocketProtocol: async def accept(self) -> RSGIWebsocketTransport: ... def close(self, status: Optional[int]) -> Tuple[int, bool]: ... - -class RSGIProtocolError(RuntimeError): - ... - - -class RSGIProtocolClosed(RuntimeError): - ... - +class RSGIProtocolError(RuntimeError): ... +class RSGIProtocolClosed(RuntimeError): ... class WSGIScope: def to_environ(self, environ: Dict[str, Any]) -> Dict[str, Any]: ... diff --git a/granian/_internal.py b/granian/_internal.py index 26fedbdf..6ebc300a 100644 --- a/granian/_internal.py +++ b/granian/_internal.py @@ -2,22 +2,21 @@ import re import sys import traceback - from types import ModuleType from typing import Callable, List, Optional def get_import_components(path: str) -> List[Optional[str]]: - return (re.split(r":(?![\\/])", path, 1) + [None])[:2] + return (re.split(r':(?![\\/])', path, 1) + [None])[:2] def prepare_import(path: str) -> str: path = os.path.realpath(path) fname, ext = os.path.splitext(path) - if ext == ".py": + if ext == '.py': path = fname - if os.path.basename(path) == "__init__": + if os.path.basename(path) == '__init__': path = os.path.dirname(path) module_name = [] @@ -27,26 +26,22 @@ def prepare_import(path: str) -> str: path, name = os.path.split(path) module_name.append(name) - if not os.path.exists(os.path.join(path, "__init__.py")): + if not os.path.exists(os.path.join(path, '__init__.py')): break if sys.path[0] != path: sys.path.insert(0, path) - return ".".join(module_name[::-1]) + return '.'.join(module_name[::-1]) -def load_module( - module_name: str, - raise_on_failure: bool = True -) -> Optional[ModuleType]: +def load_module(module_name: str, raise_on_failure: bool = True) -> Optional[ModuleType]: try: __import__(module_name) except ImportError: if sys.exc_info()[-1].tb_next: raise RuntimeError( - f"While importing '{module_name}', an ImportError was raised:" - f"\n\n{traceback.format_exc()}" + f"While importing '{module_name}', an ImportError was raised:" f"\n\n{traceback.format_exc()}" ) elif raise_on_failure: raise RuntimeError(f"Could not import '{module_name}'.") @@ -58,9 +53,9 @@ def load_module( def load_target(target: str) -> Callable[..., None]: path, name = get_import_components(target) path = prepare_import(path) if path else None - name = name or "app" + name = name or 'app' module = load_module(path) rv = module - for element in name.split("."): + for element in name.split('.'): rv = getattr(rv, element) return rv diff --git a/granian/_loops.py b/granian/_loops.py index b38f0adc..22145c46 100644 --- a/granian/_loops.py +++ b/granian/_loops.py @@ -2,12 +2,11 @@ import os import signal import sys - from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple class Registry: - __slots__ = ["_data"] + __slots__ = ['_data'] def __init__(self): self._data: Dict[str, Callable[..., Any]] = {} @@ -22,6 +21,7 @@ def register(self, key: str) -> Callable[[], Callable[..., Any]]: def wrap(builder: Callable[..., Any]) -> Callable[..., Any]: self._data[key] = builder return builder + return wrap def get(self, key: str) -> Callable[..., Any]: @@ -31,18 +31,13 @@ def get(self, key: str) -> Callable[..., Any]: raise RuntimeError(f"'{key}' implementation not available.") - class BuilderRegistry(Registry): __slots__ = [] def __init__(self): self._data: Dict[str, Tuple[Callable[..., Any], List[str]]] = {} - def register( - self, - key: str, - packages: Optional[List[str]] = None - ) -> Callable[[], Callable[..., Any]]: + def register(self, key: str, packages: Optional[List[str]] = None) -> Callable[[], Callable[..., Any]]: packages = packages or [] def wrap(builder: Callable[..., Any]) -> Callable[..., Any]: @@ -56,6 +51,7 @@ def wrap(builder: Callable[..., Any]) -> Callable[..., Any]: if implemented: self._data[key] = (builder, loaded_packages) return builder + return wrap def get(self, key: str) -> Callable[..., Any]: diff --git a/granian/asgi.py b/granian/asgi.py index e0b3fd5d..bb955e54 100644 --- a/granian/asgi.py +++ b/granian/asgi.py @@ -1,5 +1,4 @@ import asyncio - from functools import wraps from ._granian import ASGIScope as Scope @@ -22,12 +21,7 @@ def __init__(self, callable): async def handle(self): try: await self.callable( - { - "type": "lifespan", - "asgi": {"version": "3.0", "spec_version": "2.3"} - }, - self.receive, - self.send + {'type': 'lifespan', 'asgi': {'version': '3.0', 'spec_version': '2.3'}}, self.receive, self.send ) except Exception: self.errored = True @@ -43,7 +37,7 @@ async def startup(self): loop = asyncio.get_event_loop() _handler_task = loop.create_task(self.handle()) - await self.event_queue.put({"type": "lifespan.startup"}) + await self.event_queue.put({'type': 'lifespan.startup'}) await self.event_startup.wait() if self.failure_startup or (self.errored and not self.unsupported): @@ -53,7 +47,7 @@ async def shutdown(self): if self.errored: return - await self.event_queue.put({"type": "lifespan.shutdown"}) + await self.event_queue.put({'type': 'lifespan.shutdown'}) await self.event_shutdown.wait() if self.failure_shutdown or (self.errored and not self.unsupported): @@ -89,14 +83,14 @@ def _handle_shutdown_failed(self, message): # self.logger.error(message["message"]) _event_handlers = { - "lifespan.startup.complete": _handle_startup_complete, - "lifespan.startup.failed": _handle_startup_failed, - "lifespan.shutdown.complete": _handle_shutdown_complete, - "lifespan.shutdown.failed": _handle_shutdown_failed + 'lifespan.startup.complete': _handle_startup_complete, + 'lifespan.startup.failed': _handle_startup_failed, + 'lifespan.shutdown.complete': _handle_shutdown_complete, + 'lifespan.shutdown.failed': _handle_shutdown_failed, } async def send(self, message): - handler = self._event_handlers[message["type"]] + handler = self._event_handlers[message['type']] handler(self, message) @@ -108,6 +102,7 @@ def _send_wrapper(proto): @wraps(proto) def send(data): return proto(_noop_coro, data) + return send @@ -116,9 +111,6 @@ def _callback_wrapper(callback, scope_opts): @wraps(callback) def wrapper(scope: Scope, proto): - return callback( - scope.as_dict(root_url_path), - proto.receive, - _send_wrapper(proto.send) - ) + return callback(scope.as_dict(root_url_path), proto.receive, _send_wrapper(proto.send)) + return wrapper diff --git a/granian/cli.py b/granian/cli.py index c37ec96a..0bbbce3b 100644 --- a/granian/cli.py +++ b/granian/cli.py @@ -1,107 +1,55 @@ import json - from pathlib import Path from typing import Optional import typer from .__version__ import __version__ -from .constants import Interfaces, HTTPModes, Loops, ThreadModes +from .constants import HTTPModes, Interfaces, Loops, ThreadModes from .log import LogLevels from .server import Granian -cli = typer.Typer(name="granian", context_settings={"ignore_unknown_options": True}) +cli = typer.Typer(name='granian', context_settings={'ignore_unknown_options': True}) def version_callback(value: bool): if value: - typer.echo(f"{cli.info.name} {__version__}") + typer.echo(f'{cli.info.name} {__version__}') raise typer.Exit() @cli.command() def main( - app: str = typer.Argument(..., help="Application target to serve."), - host: str = typer.Option("127.0.0.1", help="Host address to bind to."), - port: int = typer.Option(8000, help="Port to bind to."), - interface: Interfaces = typer.Option( - Interfaces.RSGI.value, - help="Application interface type." - ), - http: HTTPModes = typer.Option( - HTTPModes.auto.value, - help="HTTP version." - ), - websockets: bool = typer.Option( - True, - "--ws/--no-ws", - help="Enable websockets handling", - show_default="enabled" - ), - workers: int = typer.Option(1, min=1, help="Number of worker processes."), - threads: int = typer.Option(1, min=1, help="Number of threads."), - threading_mode: ThreadModes = typer.Option( - ThreadModes.workers.value, - help="Threading mode to use." - ), - loop: Loops = typer.Option(Loops.auto.value, help="Event loop implementation"), - loop_opt: bool = typer.Option( - False, - "--opt/--no-opt", - help="Enable loop optimizations", - show_default="disabled" - ), - backlog: int = typer.Option( - 1024, - min=128, - help="Maximum number of connections to hold in backlog." - ), - log_level: LogLevels = typer.Option( - LogLevels.info.value, - help="Log level", - case_sensitive=False - ), + app: str = typer.Argument(..., help='Application target to serve.'), + host: str = typer.Option('127.0.0.1', help='Host address to bind to.'), + port: int = typer.Option(8000, help='Port to bind to.'), + interface: Interfaces = typer.Option(Interfaces.RSGI.value, help='Application interface type.'), + http: HTTPModes = typer.Option(HTTPModes.auto.value, help='HTTP version.'), + websockets: bool = typer.Option(True, '--ws/--no-ws', help='Enable websockets handling', show_default='enabled'), + workers: int = typer.Option(1, min=1, help='Number of worker processes.'), + threads: int = typer.Option(1, min=1, help='Number of threads.'), + threading_mode: ThreadModes = typer.Option(ThreadModes.workers.value, help='Threading mode to use.'), + loop: Loops = typer.Option(Loops.auto.value, help='Event loop implementation'), + loop_opt: bool = typer.Option(False, '--opt/--no-opt', help='Enable loop optimizations', show_default='disabled'), + backlog: int = typer.Option(1024, min=128, help='Maximum number of connections to hold in backlog.'), + log_level: LogLevels = typer.Option(LogLevels.info.value, help='Log level', case_sensitive=False), log_config: Optional[Path] = typer.Option( - None, - help="Logging configuration file (json)", - exists=True, - file_okay=True, - dir_okay=False, - readable=True + None, help='Logging configuration file (json)', exists=True, file_okay=True, dir_okay=False, readable=True ), ssl_keyfile: Optional[Path] = typer.Option( - None, - help="SSL key file", - exists=True, - file_okay=True, - dir_okay=False, - readable=True + None, help='SSL key file', exists=True, file_okay=True, dir_okay=False, readable=True ), ssl_certificate: Optional[Path] = typer.Option( - None, - help="SSL certificate file", - exists=True, - file_okay=True, - dir_okay=False, - readable=True - ), - url_path_prefix: Optional[str] = typer.Option( - None, - help="URL path prefix the app is mounted on" + None, help='SSL certificate file', exists=True, file_okay=True, dir_okay=False, readable=True ), + url_path_prefix: Optional[str] = typer.Option(None, help='URL path prefix the app is mounted on'), reload: bool = typer.Option( - False, - "--reload/--no-reload", - help="Enable auto reload on application's files changes" + False, '--reload/--no-reload', help="Enable auto reload on application's files changes" ), _: Optional[bool] = typer.Option( - None, - "--version", - callback=version_callback, - is_eager=True, - help="Shows the version and exit." - ) + None, '--version', callback=version_callback, is_eager=True, help='Shows the version and exit.' + ), ): log_dictconfig = None if log_config: @@ -109,7 +57,7 @@ def main( try: log_dictconfig = json.loads(log_config_file.read()) except Exception: - print("Unable to parse provided logging config.") + print('Unable to parse provided logging config.') raise typer.Exit(1) Granian( @@ -131,5 +79,5 @@ def main( ssl_cert=ssl_certificate, ssl_key=ssl_keyfile, url_path_prefix=url_path_prefix, - reload=reload + reload=reload, ).serve() diff --git a/granian/constants.py b/granian/constants.py index 1579afea..f0b2100d 100644 --- a/granian/constants.py +++ b/granian/constants.py @@ -2,23 +2,23 @@ class Interfaces(str, Enum): - ASGI = "asgi" - RSGI = "rsgi" - WSGI = "wsgi" + ASGI = 'asgi' + RSGI = 'rsgi' + WSGI = 'wsgi' class HTTPModes(str, Enum): - auto = "auto" - http1 = "1" - http2 = "2" + auto = 'auto' + http1 = '1' + http2 = '2' class ThreadModes(str, Enum): - runtime = "runtime" - workers = "workers" + runtime = 'runtime' + workers = 'workers' class Loops(str, Enum): - auto = "auto" - asyncio = "asyncio" - uvloop = "uvloop" + auto = 'auto' + asyncio = 'asyncio' + uvloop = 'uvloop' diff --git a/granian/log.py b/granian/log.py index 378e9a4b..df21b69f 100644 --- a/granian/log.py +++ b/granian/log.py @@ -1,18 +1,17 @@ import copy import logging import logging.config - from enum import Enum from typing import Any, Dict, Optional class LogLevels(str, Enum): - critical = "critical" - error = "error" - warning = "warning" - warn = "warn" - info = "info" - debug = "debug" + critical = 'critical' + error = 'error' + warning = 'warning' + warn = 'warn' + info = 'info' + debug = 'debug' log_levels_map = { @@ -21,27 +20,21 @@ class LogLevels(str, Enum): LogLevels.warning: logging.WARNING, LogLevels.warn: logging.WARN, LogLevels.info: logging.INFO, - LogLevels.debug: logging.DEBUG + LogLevels.debug: logging.DEBUG, } LOGGING_CONFIG = { - "version": 1, - "disable_existing_loggers": False, - "root": {"level": "INFO", "handlers": ["console"]}, - "formatters": { - "generic": { - "()": "logging.Formatter", - "fmt": "[%(levelname)s] %(message)s", - "datefmt": "[%Y-%m-%d %H:%M:%S %z]" + 'version': 1, + 'disable_existing_loggers': False, + 'root': {'level': 'INFO', 'handlers': ['console']}, + 'formatters': { + 'generic': { + '()': 'logging.Formatter', + 'fmt': '[%(levelname)s] %(message)s', + 'datefmt': '[%Y-%m-%d %H:%M:%S %z]', } }, - "handlers": { - "console": { - "formatter": "generic", - "class": "logging.StreamHandler", - "stream": "ext://sys.stdout", - } - } + 'handlers': {'console': {'formatter': 'generic', 'class': 'logging.StreamHandler', 'stream': 'ext://sys.stdout'}}, } logger = logging.getLogger() @@ -51,5 +44,5 @@ def configure_logging(level: LogLevels, config: Optional[Dict[str, Any]] = None) log_config = copy.deepcopy(LOGGING_CONFIG) if config: log_config.update(config) - log_config["root"]["level"] = log_levels_map[level] + log_config['root']['level'] = log_levels_map[level] logging.config.dictConfig(log_config) diff --git a/granian/net.py b/granian/net.py index 26a60d3f..d0b2bc5f 100644 --- a/granian/net.py +++ b/granian/net.py @@ -2,7 +2,5 @@ from ._granian import ListenerHolder as SocketHolder -copyreg.pickle( - SocketHolder, - lambda v: (SocketHolder, v.__getstate__()) -) + +copyreg.pickle(SocketHolder, lambda v: (SocketHolder, v.__getstate__())) diff --git a/granian/rsgi.py b/granian/rsgi.py index 64555e01..4f9ecbba 100644 --- a/granian/rsgi.py +++ b/granian/rsgi.py @@ -2,12 +2,12 @@ from typing import Union from ._granian import ( - RSGIHTTPProtocol as HTTPProtocol, - RSGIWebsocketProtocol as WebsocketProtocol, - RSGIHeaders as Headers, - RSGIScope as Scope, - RSGIProtocolError as ProtocolError, - RSGIProtocolClosed as ProtocolClosed + RSGIHeaders as Headers, # noqa + RSGIHTTPProtocol as HTTPProtocol, # noqa + RSGIProtocolClosed as ProtocolClosed, # noqa + RSGIProtocolError as ProtocolError, # noqa + RSGIScope as Scope, # noqa + RSGIWebsocketProtocol as WebsocketProtocol, # noqa ) diff --git a/granian/server.py b/granian/server.py index a9800571..f0f610e9 100644 --- a/granian/server.py +++ b/granian/server.py @@ -5,7 +5,6 @@ import ssl import sys import threading - from functools import partial from pathlib import Path from typing import Any, Callable, Dict, List, Optional @@ -16,11 +15,12 @@ from ._granian import ASGIWorker, RSGIWorker, WSGIWorker from ._internal import load_target from .asgi import LifespanProtocol, _callback_wrapper as _asgi_call_wrap -from .constants import Interfaces, HTTPModes, Loops, ThreadModes +from .constants import HTTPModes, Interfaces, Loops, ThreadModes from .log import LogLevels, configure_logging, logger from .net import SocketHolder from .wsgi import _callback_wrapper as _wsgi_call_wrap + multiprocessing.allow_connection_pickling() @@ -30,7 +30,7 @@ class Granian: def __init__( self, target: str, - address: str = "127.0.0.1", + address: str = '127.0.0.1', port: int = 8000, interface: Interfaces = Interfaces.RSGI, workers: int = 1, @@ -48,7 +48,7 @@ def __init__( ssl_cert: Optional[Path] = None, ssl_key: Optional[Path] = None, url_path_prefix: Optional[str] = None, - reload: bool = False + reload: bool = False, ): self.target = target self.bind_addr = address @@ -75,11 +75,7 @@ def __init__( self.procs: List[multiprocessing.Process] = [] self.exit_event = threading.Event() - def build_ssl_context( - self, - cert: Optional[Path], - key: Optional[Path] - ): + def build_ssl_context(self, cert: Optional[Path], key: Optional[Path]): if not (cert and key): self.ssl_ctx = (False, None, None) return @@ -108,7 +104,7 @@ def _spawn_asgi_worker( log_level, log_config, ssl_ctx, - scope_opts + scope_opts, ): from granian._loops import loops, set_loop_signals @@ -129,29 +125,12 @@ def _spawn_asgi_worker( wcallback = future_watcher_wrapper(wcallback) worker = ASGIWorker( - worker_id, - sfd, - threads, - pthreads, - http_mode, - http1_buffer_size, - websockets, - loop_opt, - *ssl_ctx - ) - serve = getattr(worker, { - ThreadModes.runtime: "serve_rth", - ThreadModes.workers: "serve_wth" - }[threading_mode]) - serve( - wcallback, - loop, - contextvars.copy_context(), - shutdown_event.wait() + worker_id, sfd, threads, pthreads, http_mode, http1_buffer_size, websockets, loop_opt, *ssl_ctx ) + serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode]) + serve(wcallback, loop, contextvars.copy_context(), shutdown_event.wait()) loop.run_until_complete(lifespan_handler.shutdown()) - @staticmethod def _spawn_rsgi_worker( worker_id, @@ -168,7 +147,7 @@ def _spawn_rsgi_worker( log_level, log_config, ssl_ctx, - scope_opts + scope_opts, ): from granian._loops import loops, set_loop_signals @@ -176,38 +155,23 @@ def _spawn_rsgi_worker( loop = loops.get(loop_impl) sfd = socket.fileno() target = callback_loader() - callback = ( - getattr(target, '__rsgi__') if hasattr(target, '__rsgi__') else - target - ) + callback = getattr(target, '__rsgi__') if hasattr(target, '__rsgi__') else target callback_init = ( - getattr(target, '__rsgi_init__') if hasattr(target, '__rsgi_init__') else - lambda *args, **kwargs: None + getattr(target, '__rsgi_init__') if hasattr(target, '__rsgi_init__') else lambda *args, **kwargs: None ) shutdown_event = set_loop_signals(loop, [signal.SIGTERM, signal.SIGINT]) callback_init(loop) worker = RSGIWorker( - worker_id, - sfd, - threads, - pthreads, - http_mode, - http1_buffer_size, - websockets, - loop_opt, - *ssl_ctx + worker_id, sfd, threads, pthreads, http_mode, http1_buffer_size, websockets, loop_opt, *ssl_ctx ) - serve = getattr(worker, { - ThreadModes.runtime: "serve_rth", - ThreadModes.workers: "serve_wth" - }[threading_mode]) + serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode]) serve( future_watcher_wrapper(callback) if not loop_opt else callback, loop, contextvars.copy_context(), - shutdown_event.wait() + shutdown_event.wait(), ) @staticmethod @@ -226,7 +190,7 @@ def _spawn_wsgi_worker( log_level, log_config, ssl_ctx, - scope_opts + scope_opts, ): from granian._loops import loops, set_loop_signals @@ -237,46 +201,20 @@ def _spawn_wsgi_worker( shutdown_event = set_loop_signals(loop, [signal.SIGTERM, signal.SIGINT]) - worker = WSGIWorker( - worker_id, - sfd, - threads, - pthreads, - http_mode, - http1_buffer_size, - *ssl_ctx - ) - serve = getattr(worker, { - ThreadModes.runtime: "serve_rth", - ThreadModes.workers: "serve_wth" - }[threading_mode]) - serve( - _wsgi_call_wrap(callback, scope_opts), - loop, - contextvars.copy_context(), - shutdown_event.wait() - ) + worker = WSGIWorker(worker_id, sfd, threads, pthreads, http_mode, http1_buffer_size, *ssl_ctx) + serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode]) + serve(_wsgi_call_wrap(callback, scope_opts), loop, contextvars.copy_context(), shutdown_event.wait()) def _init_shared_socket(self): - self._shd = SocketHolder.from_address( - self.bind_addr, - self.bind_port, - self.backlog - ) + self._shd = SocketHolder.from_address(self.bind_addr, self.bind_port, self.backlog) self._sfd = self._shd.get_fd() def signal_handler(self, *args, **kwargs): self.exit_event.set() - def _spawn_proc( - self, - id, - target, - callback_loader, - socket_loader - ) -> multiprocessing.Process: + def _spawn_proc(self, id, target, callback_loader, socket_loader) -> multiprocessing.Process: return multiprocessing.get_context().Process( - name="granian-worker", + name='granian-worker', target=target, args=( id, @@ -293,10 +231,8 @@ def _spawn_proc( self.log_level, self.log_config, self.ssl_ctx, - { - "url_path_prefix": self.url_path_prefix - } - ) + {'url_path_prefix': self.url_path_prefix}, + ), ) def _spawn_workers(self, sock, spawn_target, target_loader): @@ -305,14 +241,11 @@ def socket_loader(): for idx in range(self.workers): proc = self._spawn_proc( - id=idx + 1, - target=spawn_target, - callback_loader=target_loader, - socket_loader=socket_loader + id=idx + 1, target=spawn_target, callback_loader=target_loader, socket_loader=socket_loader ) proc.start() self.procs.append(proc) - logger.info(f"Spawning worker-{idx + 1} with pid: {proc.pid}") + logger.info(f'Spawning worker-{idx + 1} with pid: {proc.pid}') def _stop_workers(self): for proc in self.procs: @@ -321,7 +254,7 @@ def _stop_workers(self): proc.join() def startup(self, spawn_target, target_loader): - logger.info("Starting granian") + logger.info('Starting granian') for sig in self.SIGNALS: signal.signal(sig, self.signal_handler) @@ -329,13 +262,13 @@ def startup(self, spawn_target, target_loader): self._init_shared_socket() sock = socket.socket(fileno=self._sfd) sock.set_inheritable(True) - logger.info(f"Listening at: {self.bind_addr}:{self.bind_port}") + logger.info(f'Listening at: {self.bind_addr}:{self.bind_port}') self._spawn_workers(sock, spawn_target, target_loader) return sock def shutdown(self): - logger.info("Shutting down granian") + logger.info('Shutting down granian') self._stop_workers() def _serve(self, spawn_target, target_loader): @@ -360,12 +293,12 @@ def serve( self, spawn_target: Optional[Callable[..., None]] = None, target_loader: Optional[Callable[..., Callable[..., Any]]] = None, - wrap_loader: bool = True + wrap_loader: bool = True, ): default_spawners = { Interfaces.ASGI: self._spawn_asgi_worker, Interfaces.RSGI: self._spawn_rsgi_worker, - Interfaces.WSGI: self._spawn_wsgi_worker + Interfaces.WSGI: self._spawn_wsgi_worker, } if target_loader: if wrap_loader: @@ -383,8 +316,5 @@ def serve( "Number of workers will now fallback to 1." ) - serve_method = ( - self._serve_with_reloader if self.reload_on_changes else - self._serve - ) + serve_method = self._serve_with_reloader if self.reload_on_changes else self._serve serve_method(spawn_target, target_loader) diff --git a/granian/wsgi.py b/granian/wsgi.py index 6ba53db4..6b73eaed 100644 --- a/granian/wsgi.py +++ b/granian/wsgi.py @@ -1,6 +1,5 @@ import os import sys - from functools import wraps from typing import Any, List, Tuple @@ -14,30 +13,27 @@ def __init__(self): self.status = 200 self.headers = [] - def __call__( - self, - status: str, - headers: List[Tuple[str, str]], - exc_info: Any = None - ): + def __call__(self, status: str, headers: List[Tuple[str, str]], exc_info: Any = None): self.status = int(status.split(' ', 1)[0]) self.headers = headers def _callback_wrapper(callback, scope_opts): basic_env = dict(os.environ) - basic_env.update({ - 'GATEWAY_INTERFACE': 'CGI/1.1', - 'SCRIPT_NAME': scope_opts.get('url_path_prefix') or '', - 'SERVER_SOFTWARE': 'Granian', - 'wsgi.errors': sys.stderr, - 'wsgi.input_terminated': True, - 'wsgi.input': None, - 'wsgi.multiprocess': False, - 'wsgi.multithread': False, - 'wsgi.run_once': False, - 'wsgi.version': (1, 0) - }) + basic_env.update( + { + 'GATEWAY_INTERFACE': 'CGI/1.1', + 'SCRIPT_NAME': scope_opts.get('url_path_prefix') or '', + 'SERVER_SOFTWARE': 'Granian', + 'wsgi.errors': sys.stderr, + 'wsgi.input_terminated': True, + 'wsgi.input': None, + 'wsgi.multiprocess': False, + 'wsgi.multithread': False, + 'wsgi.run_once': False, + 'wsgi.version': (1, 0), + } + ) @wraps(callback) def wrapper(scope: Scope) -> Tuple[int, List[Tuple[str, str]], bytes]: @@ -46,7 +42,7 @@ def wrapper(scope: Scope) -> Tuple[int, List[Tuple[str, str]], bytes]: if isinstance(rv, list): resp_type = 0 - rv = b"".join(rv) + rv = b''.join(rv) else: resp_type = 1 rv = iter(rv) diff --git a/pyproject.toml b/pyproject.toml index d2a79e0b..0f4d07b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,65 +1,111 @@ [project] -name = "granian" +name = 'granian' authors = [ - {name = "Giovanni Barillari", email = "g@baro.dev"} + {name = 'Giovanni Barillari', email = 'g@baro.dev'} ] classifiers = [ - "Development Status :: 5 - Production/Stable", - "Intended Audience :: Developers", - "License :: OSI Approved :: BSD License", - "Operating System :: MacOS", - "Operating System :: Microsoft :: Windows", - "Operating System :: POSIX :: Linux", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: Implementation :: CPython", - "Programming Language :: Python :: Implementation :: PyPy", - "Programming Language :: Python", - "Programming Language :: Rust", - "Topic :: Internet :: WWW/HTTP" + 'Development Status :: 5 - Production/Stable', + 'Intended Audience :: Developers', + 'License :: OSI Approved :: BSD License', + 'Operating System :: MacOS', + 'Operating System :: Microsoft :: Windows', + 'Operating System :: POSIX :: Linux', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: Implementation :: CPython', + 'Programming Language :: Python :: Implementation :: PyPy', + 'Programming Language :: Python', + 'Programming Language :: Rust', + 'Topic :: Internet :: WWW/HTTP' ] dynamic = [ - "description", - "keywords", - "license", - "readme", - "version" + 'description', + 'keywords', + 'license', + 'readme', + 'version' ] -requires-python = ">=3.8" +requires-python = '>=3.8' dependencies = [ - "watchfiles~=0.18", - "typer~=0.4", - "uvloop~=0.17.0; sys_platform != 'win32' and platform_python_implementation == 'CPython'" + 'watchfiles~=0.18', + 'typer~=0.4', + 'uvloop~=0.17.0; sys_platform != "win32" and platform_python_implementation == "CPython"' ] [project.optional-dependencies] +lint = [ + 'black~=23.7.0', + 'ruff~=0.0.287' +] test = [ - "httpx~=0.23.0", - "pytest~=7.1.2", - "pytest-asyncio~=0.18.3", - "websockets~=10.3" + 'httpx~=0.23.0', + 'pytest~=7.1.2', + 'pytest-asyncio~=0.18.3', + 'websockets~=10.3' ] [project.urls] -Homepage = "https://github.com/emmett-framework/granian" -Funding = "https://github.com/sponsors/gi0baro" -Source = "https://github.com/emmett-framework/granian" +Homepage = 'https://github.com/emmett-framework/granian' +Funding = 'https://github.com/sponsors/gi0baro' +Source = 'https://github.com/emmett-framework/granian' [project.scripts] -granian = "granian:cli.cli" +granian = 'granian:cli.cli' [build-system] -requires = ["maturin>=1.1.0,<1.3.0"] -build-backend = "maturin" +requires = ['maturin>=1.1.0,<1.3.0'] +build-backend = 'maturin' [tool.maturin] -module-name = "granian._granian" -bindings = "pyo3" +module-name = 'granian._granian' +bindings = 'pyo3' + +[tool.ruff] +line-length = 120 +extend-select = [ + # E and F are enabled by default + 'B', # flake8-bugbear + 'C4', # flake8-comprehensions + 'C90', # mccabe + 'I', # isort + 'N', # pep8-naming + 'Q', # flake8-quotes + 'RUF100', # ruff (unused noqa) + 'S', # flake8-bandit + 'W' # pycodestyle +] +extend-ignore = [ + 'B008', # function calls in args defaults are fine + 'B009', # getattr with constants is fine + 'B034', # re.split won't confuse us + 'B904', # rising without from is fine + 'E501', # leave line length to black + 'N818', # leave to us exceptions naming + 'S101' # assert is fine +] +flake8-quotes = { inline-quotes = 'single', multiline-quotes = 'double' } +mccabe = { max-complexity = 13 } + +[tool.ruff.isort] +combine-as-imports = true +lines-after-imports = 2 +known-first-party = ['granian', 'tests'] + +[tool.ruff.per-file-ignores] +'granian/_granian.pyi' = ['I001'] +'tests/**' = ['B018', 'S110', 'S501'] + +[tool.black] +color = true +line-length = 120 +target-version = ['py38', 'py39', 'py310', 'py311'] +skip-string-normalization = true # leave this to ruff +skip-magic-trailing-comma = true # leave this to ruff [tool.pytest.ini_options] -asyncio_mode = "auto" +asyncio_mode = 'auto' diff --git a/src/asgi/callbacks.rs b/src/asgi/callbacks.rs index 5af7766f..96c100fb 100644 --- a/src/asgi/callbacks.rs +++ b/src/asgi/callbacks.rs @@ -3,45 +3,33 @@ use pyo3::prelude::*; use pyo3_asyncio::TaskLocals; use tokio::sync::oneshot; +use super::{ + io::{ASGIHTTPProtocol as HTTPProtocol, ASGIWebsocketProtocol as WebsocketProtocol}, + types::ASGIScope as Scope, +}; use crate::{ callbacks::{ - CallbackWrapper, - callback_impl_run, - callback_impl_run_pytask, - callback_impl_loop_run, - callback_impl_loop_pytask, - callback_impl_loop_step, - callback_impl_loop_wake, - callback_impl_loop_err + callback_impl_loop_err, callback_impl_loop_pytask, callback_impl_loop_run, callback_impl_loop_step, + callback_impl_loop_wake, callback_impl_run, callback_impl_run_pytask, CallbackWrapper, }, runtime::RuntimeRef, - ws::{HyperWebsocket, UpgradeData} + ws::{HyperWebsocket, UpgradeData}, }; -use super::{ - io::{ASGIHTTPProtocol as HTTPProtocol, ASGIWebsocketProtocol as WebsocketProtocol}, - types::ASGIScope as Scope -}; - #[pyclass] pub(crate) struct CallbackRunnerHTTP { proto: Py, context: TaskLocals, - cb: PyObject + cb: PyObject, } impl CallbackRunnerHTTP { - pub fn new( - py: Python, - cb: CallbackWrapper, - proto: HTTPProtocol, - scope: Scope - ) -> Self { + pub fn new(py: Python, cb: CallbackWrapper, proto: HTTPProtocol, scope: Scope) -> Self { let pyproto = Py::new(py, proto).unwrap(); Self { proto: pyproto.clone(), context: cb.context, - cb: cb.callback.call1(py, (scope, pyproto)).unwrap() + cb: cb.callback.call1(py, (scope, pyproto)).unwrap(), } } @@ -64,14 +52,14 @@ macro_rules! callback_impl_done_http { let _ = tx.send(res); } } - } + }; } macro_rules! callback_impl_done_err { ($self:expr, $py:expr) => { log::warn!("Application callable raised an exception"); $self.done($py) - } + }; } #[pyclass] @@ -79,22 +67,17 @@ pub(crate) struct CallbackTaskHTTP { proto: Py, context: TaskLocals, pycontext: PyObject, - cb: PyObject + cb: PyObject, } impl CallbackTaskHTTP { - pub fn new( - py: Python, - cb: PyObject, - proto: Py, - context: TaskLocals - ) -> PyResult { + pub fn new(py: Python, cb: PyObject, proto: Py, context: TaskLocals) -> PyResult { let pyctx = context.context(py); Ok(Self { proto, context, pycontext: pyctx.call_method0(pyo3::intern!(py, "copy"))?.into(), - cb + cb, }) } @@ -128,21 +111,16 @@ pub(crate) struct CallbackWrappedRunnerHTTP { context: TaskLocals, cb: PyObject, #[pyo3(get)] - scope: PyObject + scope: PyObject, } impl CallbackWrappedRunnerHTTP { - pub fn new( - py: Python, - cb: CallbackWrapper, - proto: HTTPProtocol, - scope: Scope - ) -> Self { + pub fn new(py: Python, cb: CallbackWrapper, proto: HTTPProtocol, scope: Scope) -> Self { Self { proto: Py::new(py, proto).unwrap(), context: cb.context, cb: cb.callback, - scope: scope.into_py(py) + scope: scope.into_py(py), } } @@ -168,21 +146,16 @@ impl CallbackWrappedRunnerHTTP { pub(crate) struct CallbackRunnerWebsocket { proto: Py, context: TaskLocals, - cb: PyObject + cb: PyObject, } impl CallbackRunnerWebsocket { - pub fn new( - py: Python, - cb: CallbackWrapper, - proto: WebsocketProtocol, - scope: Scope - ) -> Self { + pub fn new(py: Python, cb: CallbackWrapper, proto: WebsocketProtocol, scope: Scope) -> Self { let pyproto = Py::new(py, proto).unwrap(); Self { proto: pyproto.clone(), context: cb.context, - cb: cb.callback.call1(py, (scope, pyproto)).unwrap() + cb: cb.callback.call1(py, (scope, pyproto)).unwrap(), } } @@ -203,7 +176,7 @@ macro_rules! callback_impl_done_ws { let _ = tx.send(res); } } - } + }; } #[pyclass] @@ -211,22 +184,17 @@ pub(crate) struct CallbackTaskWebsocket { proto: Py, context: TaskLocals, pycontext: PyObject, - cb: PyObject + cb: PyObject, } impl CallbackTaskWebsocket { - pub fn new( - py: Python, - cb: PyObject, - proto: Py, - context: TaskLocals - ) -> PyResult { + pub fn new(py: Python, cb: PyObject, proto: Py, context: TaskLocals) -> PyResult { let pyctx = context.context(py); Ok(Self { proto, context, pycontext: pyctx.call_method0(pyo3::intern!(py, "copy"))?.into(), - cb + cb, }) } @@ -260,21 +228,16 @@ pub(crate) struct CallbackWrappedRunnerWebsocket { context: TaskLocals, cb: PyObject, #[pyo3(get)] - scope: PyObject + scope: PyObject, } impl CallbackWrappedRunnerWebsocket { - pub fn new( - py: Python, - cb: CallbackWrapper, - proto: WebsocketProtocol, - scope: Scope - ) -> Self { + pub fn new(py: Python, cb: CallbackWrapper, proto: WebsocketProtocol, scope: Scope) -> Self { Self { proto: Py::new(py, proto).unwrap(), context: cb.context, cb: cb.callback, - scope: scope.into_py(py) + scope: scope.into_py(py), } } @@ -325,7 +288,7 @@ macro_rules! call_impl_rtb_http { cb: CallbackWrapper, rt: RuntimeRef, req: Request, - scope: Scope + scope: Scope, ) -> oneshot::Receiver> { let (tx, rx) = oneshot::channel(); let protocol = HTTPProtocol::new(rt, req, tx); @@ -345,7 +308,7 @@ macro_rules! call_impl_rtt_http { cb: CallbackWrapper, rt: RuntimeRef, req: Request, - scope: Scope + scope: Scope, ) -> oneshot::Receiver> { let (tx, rx) = oneshot::channel(); let protocol = HTTPProtocol::new(rt, req, tx); @@ -368,7 +331,7 @@ macro_rules! call_impl_rtb_ws { rt: RuntimeRef, ws: HyperWebsocket, upgrade: UpgradeData, - scope: Scope + scope: Scope, ) -> oneshot::Receiver { let (tx, rx) = oneshot::channel(); let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade); @@ -389,7 +352,7 @@ macro_rules! call_impl_rtt_ws { rt: RuntimeRef, ws: HyperWebsocket, upgrade: UpgradeData, - scope: Scope + scope: Scope, ) -> oneshot::Receiver { let (tx, rx) = oneshot::channel(); let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade); diff --git a/src/asgi/errors.rs b/src/asgi/errors.rs index 6b788f60..8a6dda40 100644 --- a/src/asgi/errors.rs +++ b/src/asgi/errors.rs @@ -88,5 +88,5 @@ macro_rules! error_message { } pub(crate) use error_flow; -pub(crate) use error_transport; pub(crate) use error_message; +pub(crate) use error_transport; diff --git a/src/asgi/http.rs b/src/asgi/http.rs index cbf295f7..1fa62773 100644 --- a/src/asgi/http.rs +++ b/src/asgi/http.rs @@ -1,34 +1,22 @@ use hyper::{ - Body, - Request, - Response, - StatusCode, - header::SERVER as HK_SERVER, - http::response::Builder as ResponseBuilder + header::SERVER as HK_SERVER, http::response::Builder as ResponseBuilder, Body, Request, Response, StatusCode, }; use std::net::SocketAddr; use tokio::sync::mpsc; -use crate::{ - callbacks::CallbackWrapper, - http::{HV_SERVER, response_500}, - runtime::RuntimeRef, - ws::{UpgradeData, is_upgrade_request as is_ws_upgrade, upgrade_intent as ws_upgrade} -}; use super::{ callbacks::{ - call_rtb_http, - call_rtb_http_pyw, - call_rtb_ws, - call_rtb_ws_pyw, - call_rtt_http, - call_rtt_http_pyw, - call_rtt_ws, - call_rtt_ws_pyw + call_rtb_http, call_rtb_http_pyw, call_rtb_ws, call_rtb_ws_pyw, call_rtt_http, call_rtt_http_pyw, call_rtt_ws, + call_rtt_ws_pyw, }, - types::ASGIScope as Scope + types::ASGIScope as Scope, +}; +use crate::{ + callbacks::CallbackWrapper, + http::{response_500, HV_SERVER}, + runtime::RuntimeRef, + ws::{is_upgrade_request as is_ws_upgrade, upgrade_intent as ws_upgrade, UpgradeData}, }; - macro_rules! default_scope { ($server_addr:expr, $client_addr:expr, $req:expr, $scheme:expr) => { @@ -39,7 +27,7 @@ macro_rules! default_scope { $req.method().as_ref(), $server_addr, $client_addr, - $req.headers() + $req.headers(), ) }; } @@ -53,7 +41,7 @@ macro_rules! handle_http_response { response_500() } } - } + }; } macro_rules! handle_request { @@ -64,7 +52,7 @@ macro_rules! handle_request { server_addr: SocketAddr, client_addr: SocketAddr, req: Request, - scheme: &str + scheme: &str, ) -> Response { let scope = default_scope!(server_addr, client_addr, &req, scheme); handle_http_response!($handler, rt, callback, req, scope) @@ -80,7 +68,7 @@ macro_rules! handle_request_with_ws { server_addr: SocketAddr, client_addr: SocketAddr, req: Request, - scheme: &str + scheme: &str, ) -> Response { let mut scope = default_scope!(server_addr, client_addr, &req, scheme); @@ -95,24 +83,20 @@ macro_rules! handle_request_with_ws { rt.inner.spawn(async move { let tx_ref = restx.clone(); - match $handler_ws( - callback, - rth, - ws, - UpgradeData::new(res, restx), - scope - ).await { + match $handler_ws(callback, rth, ws, UpgradeData::new(res, restx), scope).await { Ok(consumed) => { if !consumed { - let _ = tx_ref.send( - ResponseBuilder::new() - .status(StatusCode::FORBIDDEN) - .header(HK_SERVER, HV_SERVER) - .body(Body::from("")) - .unwrap() - ).await; + let _ = tx_ref + .send( + ResponseBuilder::new() + .status(StatusCode::FORBIDDEN) + .header(HK_SERVER, HV_SERVER) + .body(Body::from("")) + .unwrap(), + ) + .await; }; - }, + } _ => { log::error!("ASGI protocol failure"); let _ = tx_ref.send(response_500()).await; @@ -124,10 +108,10 @@ macro_rules! handle_request_with_ws { Some(res) => { resrx.close(); res - }, - _ => response_500() + } + _ => response_500(), } - }, + } Err(err) => { return ResponseBuilder::new() .status(StatusCode::BAD_REQUEST) diff --git a/src/asgi/io.rs b/src/asgi/io.rs index 4b2f25eb..d206aaf2 100644 --- a/src/asgi/io.rs +++ b/src/asgi/io.rs @@ -1,33 +1,34 @@ use bytes::Bytes; -use futures::{sink::SinkExt, stream::{SplitSink, SplitStream, StreamExt}}; +use futures::{ + sink::SinkExt, + stream::{SplitSink, SplitStream, StreamExt}, +}; use hyper::{ - Request, - Response, body::{Body, HttpBody, Sender as BodySender}, - header::{HeaderName, HeaderValue, HeaderMap, SERVER as HK_SERVER} + header::{HeaderMap, HeaderName, HeaderValue, SERVER as HK_SERVER}, + Request, Response, }; use pyo3::prelude::*; use pyo3::types::{PyBytes, PyDict}; use std::sync::Arc; +use tokio::sync::{oneshot, Mutex}; use tokio_tungstenite::WebSocketStream; -use tokio::sync::{Mutex, oneshot}; use tungstenite::Message; +use super::{ + errors::{error_flow, error_message, error_transport, UnsupportedASGIMessage}, + types::ASGIMessageType, +}; use crate::{ http::HV_SERVER, - runtime::{RuntimeRef, future_into_py_iter, future_into_py_futlike}, - ws::{HyperWebsocket, UpgradeData} + runtime::{future_into_py_futlike, future_into_py_iter, RuntimeRef}, + ws::{HyperWebsocket, UpgradeData}, }; -use super::{ - errors::{UnsupportedASGIMessage, error_flow, error_transport, error_message}, - types::ASGIMessageType -}; - const EMPTY_BYTES: Vec = Vec::new(); const EMPTY_STRING: String = String::new(); -#[pyclass(module="granian._granian")] +#[pyclass(module = "granian._granian")] pub(crate) struct ASGIHTTPProtocol { rt: RuntimeRef, tx: Option>>, @@ -36,15 +37,11 @@ pub(crate) struct ASGIHTTPProtocol { response_chunked: bool, response_status: Option, response_headers: Option, - body_tx: Option>> + body_tx: Option>>, } impl ASGIHTTPProtocol { - pub fn new( - rt: RuntimeRef, - request: Request, - tx: oneshot::Sender> - ) -> Self { + pub fn new(rt: RuntimeRef, request: Request, tx: oneshot::Sender>) -> Self { Self { rt, tx: Some(tx), @@ -53,7 +50,7 @@ impl ASGIHTTPProtocol { response_chunked: false, response_status: None, response_headers: None, - body_tx: None + body_tx: None, } } @@ -71,7 +68,7 @@ impl ASGIHTTPProtocol { fn send_body<'p>(&self, py: Python<'p>, tx: Arc>, body: Vec) -> PyResult<&'p PyAny> { future_into_py_futlike(self.rt.clone(), py, async move { let mut tx = tx.lock().await; - match (&mut *tx).send_data(body.into()).await { + match (*tx).send_data(body.into()).await { Ok(_) => Ok(()), Err(err) => { log::warn!("ASGI transport tx error: {:?}", err); @@ -94,18 +91,18 @@ impl ASGIHTTPProtocol { let mut bodym = body_ref.lock().await; let body = &mut *bodym; let mut more_body = false; - let chunk = body.data().await.map_or_else(|| Bytes::new(), |buf| { - buf.map_or_else(|_| Bytes::new(), |buf| { - more_body = !body.is_end_stream(); - buf - }) + let chunk = body.data().await.map_or_else(Bytes::new, |buf| { + buf.map_or_else( + |_| Bytes::new(), + |buf| { + more_body = !body.is_end_stream(); + buf + }, + ) }); Python::with_gil(|py| { let dict = PyDict::new(py); - dict.set_item( - pyo3::intern!(py, "type"), - pyo3::intern!(py, "http.request") - )?; + dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "http.request"))?; dict.set_item(pyo3::intern!(py, "body"), PyBytes::new(py, &chunk[..]))?; dict.set_item(pyo3::intern!(py, "more_body"), more_body)?; Ok(dict.to_object(py)) @@ -115,16 +112,14 @@ impl ASGIHTTPProtocol { fn send<'p>(&mut self, py: Python<'p>, asyncw: &'p PyAny, data: &'p PyDict) -> PyResult<&'p PyAny> { match adapt_message_type(data) { - Ok(ASGIMessageType::HTTPStart) => { - match self.response_started { - false => { - self.response_status = Some(adapt_status_code(data)?); - self.response_headers = Some(adapt_headers(data)); - self.response_started = true; - asyncw.call0() - }, - true => error_flow!() + Ok(ASGIMessageType::HTTPStart) => match self.response_started { + false => { + self.response_status = Some(adapt_status_code(data)?); + self.response_headers = Some(adapt_headers(data)); + self.response_started = true; + asyncw.call0() } + true => error_flow!(), }, Ok(ASGIMessageType::HTTPBody) => { let (body, more) = adapt_body(data); @@ -133,7 +128,7 @@ impl ASGIHTTPProtocol { let headers = self.response_headers.take().unwrap(); self.send_response(self.response_status.unwrap(), headers, body.into()); asyncw.call0() - }, + } (true, true, false) => { self.response_chunked = true; let headers = self.response_headers.take().unwrap(); @@ -142,37 +137,31 @@ impl ASGIHTTPProtocol { self.body_tx = Some(tx.clone()); self.send_response(self.response_status.unwrap(), headers, body_stream); self.send_body(py, tx, body) - }, - (true, true, true) => { - match self.body_tx.as_mut() { - Some(tx) => { - let tx = tx.clone(); - self.send_body(py, tx, body) - }, - _ => error_flow!() + } + (true, true, true) => match self.body_tx.as_mut() { + Some(tx) => { + let tx = tx.clone(); + self.send_body(py, tx, body) } + _ => error_flow!(), }, - (true, false, true) => { - match self.body_tx.take() { - Some(tx) => { - match body.is_empty() { - false => self.send_body(py, tx, body), - true => asyncw.call0() - } - }, - _ => error_flow!() - } + (true, false, true) => match self.body_tx.take() { + Some(tx) => match body.is_empty() { + false => self.send_body(py, tx, body), + true => asyncw.call0(), + }, + _ => error_flow!(), }, - _ => error_flow!() + _ => error_flow!(), } - }, + } Err(err) => Err(err.into()), - _ => error_message!() + _ => error_message!(), } } } -#[pyclass(module="granian._granian")] +#[pyclass(module = "granian._granian")] pub(crate) struct ASGIWebsocketProtocol { rt: RuntimeRef, tx: Option>, @@ -181,16 +170,11 @@ pub(crate) struct ASGIWebsocketProtocol { ws_tx: Arc, Message>>>>, ws_rx: Arc>>>>, accepted: Arc>, - closed: bool + closed: bool, } impl ASGIWebsocketProtocol { - pub fn new( - rt: RuntimeRef, - tx: oneshot::Sender, - websocket: HyperWebsocket, - upgrade: UpgradeData - ) -> Self { + pub fn new(rt: RuntimeRef, tx: oneshot::Sender, websocket: HyperWebsocket, upgrade: UpgradeData) -> Self { Self { rt, tx: Some(tx), @@ -199,7 +183,7 @@ impl ASGIWebsocketProtocol { ws_tx: Arc::new(Mutex::new(None)), ws_rx: Arc::new(Mutex::new(None)), accepted: Arc::new(Mutex::new(false)), - closed: false + closed: false, } } @@ -211,7 +195,7 @@ impl ASGIWebsocketProtocol { let tx = self.ws_tx.clone(); let rx = self.ws_rx.clone(); future_into_py_iter(self.rt.clone(), py, async move { - if let Ok(_) = upgrade.send().await { + if (upgrade.send().await).is_ok() { if let Ok(stream) = websocket.await { let mut wtx = tx.lock().await; let mut wrx = rx.lock().await; @@ -220,7 +204,7 @@ impl ASGIWebsocketProtocol { *wtx = Some(tx); *wrx = Some(rx); *accepted = true; - return Ok(()) + return Ok(()); } } error_flow!() @@ -228,18 +212,14 @@ impl ASGIWebsocketProtocol { } #[inline(always)] - fn send_message<'p>( - &self, - py: Python<'p>, - data: &'p PyDict - ) -> PyResult<&'p PyAny> { + fn send_message<'p>(&self, py: Python<'p>, data: &'p PyDict) -> PyResult<&'p PyAny> { let transport = self.ws_tx.clone(); let message = ws_message_into_rs(data); future_into_py_iter(self.rt.clone(), py, async move { if let Ok(message) = message { if let Some(ws) = &mut *(transport.lock().await) { - if let Ok(_) = ws.send(message).await { - return Ok(()) + if (ws.send(message).await).is_ok() { + return Ok(()); } }; }; @@ -253,8 +233,8 @@ impl ASGIWebsocketProtocol { let transport = self.ws_tx.clone(); future_into_py_iter(self.rt.clone(), py, async move { if let Some(ws) = &mut *(transport.lock().await) { - if let Ok(_) = ws.close().await { - return Ok(()) + if (ws.close().await).is_ok() { + return Ok(()); } }; error_flow!() @@ -262,10 +242,7 @@ impl ASGIWebsocketProtocol { } fn consumed(&self) -> bool { - match &self.upgrade { - Some(_) => false, - _ => true - } + self.upgrade.is_none() } pub fn tx(&mut self) -> (Option>, bool) { @@ -278,44 +255,26 @@ impl ASGIWebsocketProtocol { fn receive<'p>(&mut self, py: Python<'p>) -> PyResult<&'p PyAny> { let transport = self.ws_rx.clone(); let accepted = self.accepted.clone(); - let closed = self.closed.clone(); + let closed = self.closed; future_into_py_futlike(self.rt.clone(), py, async move { let accepted = accepted.lock().await; match (*accepted, closed) { (false, false) => { return Python::with_gil(|py| { let dict = PyDict::new(py); - dict.set_item( - pyo3::intern!(py, "type"), - pyo3::intern!(py, "websocket.connect") - )?; + dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "websocket.connect"))?; Ok(dict.to_object(py)) }) - }, - (true, false) => {}, - _ => { - return error_flow!() } + (true, false) => {} + _ => return error_flow!(), } if let Some(ws) = &mut *(transport.lock().await) { - loop { - match ws.next().await { - Some(recv) => { - match recv { - Ok(Message::Ping(_)) => { - continue - }, - Ok(message) => { - return ws_message_into_py(message) - }, - _ => { - break - } - } - }, - _ => { - break - } + while let Some(recv) = ws.next().await { + match recv { + Ok(Message::Ping(_)) => continue, + Ok(message) => return ws_message_into_py(message), + _ => break, } } } @@ -325,25 +284,17 @@ impl ASGIWebsocketProtocol { fn send<'p>(&mut self, py: Python<'p>, _asyncw: &'p PyAny, data: &'p PyDict) -> PyResult<&'p PyAny> { match (adapt_message_type(data), self.closed) { - (Ok(ASGIMessageType::WSAccept), _) => { - self.accept(py) - }, - (Ok(ASGIMessageType::WSClose), false) => { - self.close(py) - }, - (Ok(ASGIMessageType::WSMessage), false) => { - self.send_message(py, data) - }, + (Ok(ASGIMessageType::WSAccept), _) => self.accept(py), + (Ok(ASGIMessageType::WSClose), false) => self.close(py), + (Ok(ASGIMessageType::WSMessage), false) => self.send_message(py, data), (Err(err), _) => Err(err.into()), - _ => error_message!() + _ => error_message!(), } } } #[inline(never)] -fn adapt_message_type( - message: &PyDict -) -> Result { +fn adapt_message_type(message: &PyDict) -> Result { match message.get_item("type") { Some(item) => { let message_type: &str = item.extract()?; @@ -353,20 +304,18 @@ fn adapt_message_type( "websocket.accept" => Ok(ASGIMessageType::WSAccept), "websocket.close" => Ok(ASGIMessageType::WSClose), "websocket.send" => Ok(ASGIMessageType::WSMessage), - _ => error_message!() + _ => error_message!(), } - }, - _ => error_message!() + } + _ => error_message!(), } } #[inline(always)] fn adapt_status_code(message: &PyDict) -> Result { match message.get_item("status") { - Some(item) => { - Ok(item.extract()?) - }, - _ => error_message!() + Some(item) => Ok(item.extract()?), + _ => error_message!(), } } @@ -377,34 +326,26 @@ fn adapt_headers(message: &PyDict) -> HeaderMap { match message.get_item("headers") { Some(item) => { let accum: Vec> = item.extract().unwrap_or(Vec::new()); - for tup in accum.iter() { - match ( - HeaderName::from_bytes(tup[0]), - HeaderValue::from_bytes(tup[1]) - ) { - (Ok(key), Ok(val)) => { ret.append(key, val); }, - _ => {} + for tup in &accum { + if let (Ok(key), Ok(val)) = (HeaderName::from_bytes(tup[0]), HeaderValue::from_bytes(tup[1])) { + ret.append(key, val); } - }; + } ret - }, - _ => ret + } + _ => ret, } } #[inline(always)] fn adapt_body(message: &PyDict) -> (Vec, bool) { let body = match message.get_item("body") { - Some(item) => { - item.extract().unwrap_or(EMPTY_BYTES) - }, - _ => EMPTY_BYTES + Some(item) => item.extract().unwrap_or(EMPTY_BYTES), + _ => EMPTY_BYTES, }; let more = match message.get_item("more_body") { - Some(item) => { - item.extract().unwrap_or(false) - }, - _ => false + Some(item) => item.extract().unwrap_or(false), + _ => false, }; (body, more) } @@ -412,22 +353,12 @@ fn adapt_body(message: &PyDict) -> (Vec, bool) { #[inline(always)] fn ws_message_into_rs(message: &PyDict) -> PyResult { match (message.get_item("bytes"), message.get_item("text")) { - (Some(item), None) => { - Ok(Message::Binary(item.extract().unwrap_or(EMPTY_BYTES))) - }, - (None, Some(item)) => { - Ok(Message::Text(item.extract().unwrap_or(EMPTY_STRING))) - }, - (Some(itemb), Some(itemt)) => { - match (itemb.extract().unwrap_or(None), itemt.extract().unwrap_or(None)) { - (Some(msgb), None) => { - Ok(Message::Binary(msgb)) - }, - (None, Some(msgt)) => { - Ok(Message::Text(msgt)) - }, - _ => error_flow!() - } + (Some(item), None) => Ok(Message::Binary(item.extract().unwrap_or(EMPTY_BYTES))), + (None, Some(item)) => Ok(Message::Text(item.extract().unwrap_or(EMPTY_STRING))), + (Some(itemb), Some(itemt)) => match (itemb.extract().unwrap_or(None), itemt.extract().unwrap_or(None)) { + (Some(msgb), None) => Ok(Message::Binary(msgb)), + (None, Some(msgt)) => Ok(Message::Text(msgt)), + _ => error_flow!(), }, _ => { error_flow!() @@ -438,41 +369,23 @@ fn ws_message_into_rs(message: &PyDict) -> PyResult { #[inline(always)] fn ws_message_into_py(message: Message) -> PyResult { match message { - Message::Binary(message) => { - Python::with_gil(|py| { - let dict = PyDict::new(py); - dict.set_item( - pyo3::intern!(py, "type"), - pyo3::intern!(py, "websocket.receive") - )?; - dict.set_item( - pyo3::intern!(py, "bytes"), - PyBytes::new(py, &message[..]) - )?; - Ok(dict.to_object(py)) - }) - }, - Message::Text(message) => { - Python::with_gil(|py| { - let dict = PyDict::new(py); - dict.set_item( - pyo3::intern!(py, "type"), - pyo3::intern!(py, "websocket.receive") - )?; - dict.set_item(pyo3::intern!(py, "text"), message)?; - Ok(dict.to_object(py)) - }) - }, - Message::Close(_) => { - Python::with_gil(|py| { - let dict = PyDict::new(py); - dict.set_item( - pyo3::intern!(py, "type"), - pyo3::intern!(py, "websocket.disconnect") - )?; - Ok(dict.to_object(py)) - }) - }, + Message::Binary(message) => Python::with_gil(|py| { + let dict = PyDict::new(py); + dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "websocket.receive"))?; + dict.set_item(pyo3::intern!(py, "bytes"), PyBytes::new(py, &message[..]))?; + Ok(dict.to_object(py)) + }), + Message::Text(message) => Python::with_gil(|py| { + let dict = PyDict::new(py); + dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "websocket.receive"))?; + dict.set_item(pyo3::intern!(py, "text"), message)?; + Ok(dict.to_object(py)) + }), + Message::Close(_) => Python::with_gil(|py| { + let dict = PyDict::new(py); + dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "websocket.disconnect"))?; + Ok(dict.to_object(py)) + }), v => { log::warn!("Unsupported websocket message received {:?}", v); error_flow!() diff --git a/src/asgi/serve.rs b/src/asgi/serve.rs index a30f15c3..6243cdc6 100644 --- a/src/asgi/serve.rs +++ b/src/asgi/serve.rs @@ -1,29 +1,14 @@ use pyo3::prelude::*; -use crate::{ - workers::{ - WorkerConfig, - serve_rth, - serve_wth, - serve_rth_ssl, - serve_wth_ssl - } -}; use super::http::{ - handle_rtb, - handle_rtb_pyw, - handle_rtt, - handle_rtt_pyw, - handle_rtb_ws, - handle_rtb_ws_pyw, - handle_rtt_ws, - handle_rtt_ws_pyw + handle_rtb, handle_rtb_pyw, handle_rtb_ws, handle_rtb_ws_pyw, handle_rtt, handle_rtt_pyw, handle_rtt_ws, + handle_rtt_ws_pyw, }; +use crate::workers::{serve_rth, serve_rth_ssl, serve_wth, serve_wth_ssl, WorkerConfig}; - -#[pyclass(module="granian._granian")] +#[pyclass(module = "granian._granian")] pub struct ASGIWorker { - config: WorkerConfig + config: WorkerConfig, } impl ASGIWorker { @@ -74,7 +59,7 @@ impl ASGIWorker { opt_enabled: bool, ssl_enabled: bool, ssl_cert: Option<&str>, - ssl_key: Option<&str> + ssl_key: Option<&str>, ) -> PyResult { Ok(Self { config: WorkerConfig::new( @@ -88,22 +73,16 @@ impl ASGIWorker { opt_enabled, ssl_enabled, ssl_cert, - ssl_key - ) + ssl_key, + ), }) } - fn serve_rth( - &self, - callback: PyObject, - event_loop: &PyAny, - context: &PyAny, - signal_rx: PyObject - ) { + fn serve_rth(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) { match ( self.config.websockets_enabled, self.config.ssl_enabled, - self.config.opt_enabled + self.config.opt_enabled, ) { (false, false, true) => self._serve_rth(callback, event_loop, context, signal_rx), (false, false, false) => self._serve_rth_pyw(callback, event_loop, context, signal_rx), @@ -112,21 +91,15 @@ impl ASGIWorker { (false, true, true) => self._serve_rth_ssl(callback, event_loop, context, signal_rx), (false, true, false) => self._serve_rth_ssl_pyw(callback, event_loop, context, signal_rx), (true, true, true) => self._serve_rth_ssl_ws(callback, event_loop, context, signal_rx), - (true, true, false) => self._serve_rth_ssl_ws_pyw(callback, event_loop, context, signal_rx) + (true, true, false) => self._serve_rth_ssl_ws_pyw(callback, event_loop, context, signal_rx), } } - fn serve_wth( - &self, - callback: PyObject, - event_loop: &PyAny, - context: &PyAny, - signal_rx: PyObject - ) { + fn serve_wth(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) { match ( self.config.websockets_enabled, self.config.ssl_enabled, - self.config.opt_enabled + self.config.opt_enabled, ) { (false, false, true) => self._serve_wth(callback, event_loop, context, signal_rx), (false, false, false) => self._serve_wth_pyw(callback, event_loop, context, signal_rx), @@ -135,7 +108,7 @@ impl ASGIWorker { (false, true, true) => self._serve_wth_ssl(callback, event_loop, context, signal_rx), (false, true, false) => self._serve_wth_ssl_pyw(callback, event_loop, context, signal_rx), (true, true, true) => self._serve_wth_ssl_ws(callback, event_loop, context, signal_rx), - (true, true, false) => self._serve_wth_ssl_ws_pyw(callback, event_loop, context, signal_rx) + (true, true, false) => self._serve_wth_ssl_ws_pyw(callback, event_loop, context, signal_rx), } } } diff --git a/src/asgi/types.rs b/src/asgi/types.rs index 1841d0e6..8e1056e2 100644 --- a/src/asgi/types.rs +++ b/src/asgi/types.rs @@ -1,10 +1,9 @@ -use hyper::{Uri, Version, header::HeaderMap}; +use hyper::{header::HeaderMap, Uri, Version}; use once_cell::sync::OnceCell; use pyo3::prelude::*; use pyo3::types::{PyBytes, PyDict, PyList, PyString}; use std::net::{IpAddr, SocketAddr}; - const SCHEME_HTTPS: &str = "https"; const SCHEME_WS: &str = "ws"; const SCHEME_WSS: &str = "wss"; @@ -17,10 +16,10 @@ pub(crate) enum ASGIMessageType { HTTPBody, WSAccept, WSClose, - WSMessage + WSMessage, } -#[pyclass(module="granian._granian")] +#[pyclass(module = "granian._granian")] pub(crate) struct ASGIScope { http_version: Version, scheme: String, @@ -31,7 +30,7 @@ pub(crate) struct ASGIScope { client_ip: IpAddr, client_port: u16, headers: HeaderMap, - is_websocket: bool + is_websocket: bool, } impl ASGIScope { @@ -42,31 +41,31 @@ impl ASGIScope { method: &str, server: SocketAddr, client: SocketAddr, - headers: &HeaderMap + headers: &HeaderMap, ) -> Self { Self { - http_version: http_version, + http_version, scheme: scheme.to_string(), method: method.to_string(), - uri: uri, + uri, server_ip: server.ip(), server_port: server.port(), client_ip: client.ip(), client_port: client.port(), - headers: headers.to_owned(), - is_websocket: false + headers: headers.clone(), + is_websocket: false, } } pub fn set_websocket(&mut self) { - self.is_websocket = true + self.is_websocket = true; } #[inline(always)] fn py_proto(&self) -> &str { match self.is_websocket { false => "http", - true => "websocket" + true => "websocket", } } @@ -76,7 +75,7 @@ impl ASGIScope { Version::HTTP_10 => "1", Version::HTTP_11 => "1.1", Version::HTTP_2 => "2", - _ => "1" + _ => "1", } } @@ -85,22 +84,20 @@ impl ASGIScope { let scheme = &self.scheme[..]; match self.is_websocket { false => scheme, - true => { - match scheme { - SCHEME_HTTPS => SCHEME_WSS, - _ => SCHEME_WS - } - } + true => match scheme { + SCHEME_HTTPS => SCHEME_WSS, + _ => SCHEME_WS, + }, } } #[inline(always)] fn py_headers<'p>(&self, py: Python<'p>) -> PyResult<&'p PyList> { let rv = PyList::empty(py); - for (key, value) in self.headers.iter() { + for (key, value) in &self.headers { rv.append(( PyBytes::new(py, key.as_str().as_bytes()), - PyBytes::new(py, value.as_bytes()) + PyBytes::new(py, value.as_bytes()), ))?; } Ok(rv) @@ -110,17 +107,10 @@ impl ASGIScope { #[pymethods] impl ASGIScope { fn as_dict<'p>(&self, py: Python<'p>, url_path_prefix: &'p str) -> PyResult<&'p PyAny> { - let ( - path, - query_string, - proto, - http_version, - server, - client, - scheme, - method - ) = py.allow_threads(|| { - let (path, query_string) = self.uri.path_and_query() + let (path, query_string, proto, http_version, server, client, scheme, method) = py.allow_threads(|| { + let (path, query_string) = self + .uri + .path_and_query() .map_or_else(|| ("", ""), |pq| (pq.path(), pq.query().unwrap_or(""))); ( path, @@ -136,19 +126,23 @@ impl ASGIScope { let dict: &PyDict = PyDict::new(py); dict.set_item( pyo3::intern!(py, "asgi"), - ASGI_VERSION.get_or_try_init(|| { - let rv = PyDict::new(py); - rv.set_item("version", "3.0")?; - rv.set_item("spec_version", "2.3")?; - Ok::(rv.into()) - })?.as_ref(py) + ASGI_VERSION + .get_or_try_init(|| { + let rv = PyDict::new(py); + rv.set_item("version", "3.0")?; + rv.set_item("spec_version", "2.3")?; + Ok::(rv.into()) + })? + .as_ref(py), )?; dict.set_item( pyo3::intern!(py, "extensions"), - ASGI_EXTENSIONS.get_or_try_init(|| { - let rv = PyDict::new(py); - Ok::(rv.into()) - })?.as_ref(py) + ASGI_EXTENSIONS + .get_or_try_init(|| { + let rv = PyDict::new(py); + Ok::(rv.into()) + })? + .as_ref(py), )?; dict.set_item(pyo3::intern!(py, "type"), proto)?; dict.set_item(pyo3::intern!(py, "http_version"), http_version)?; @@ -160,17 +154,12 @@ impl ASGIScope { dict.set_item(pyo3::intern!(py, "path"), path)?; dict.set_item( pyo3::intern!(py, "raw_path"), - PyString::new(py, path) - .call_method1( - pyo3::intern!(py, "encode"), (pyo3::intern!(py, "ascii"),) - )? + PyString::new(py, path).call_method1(pyo3::intern!(py, "encode"), (pyo3::intern!(py, "ascii"),))?, )?; dict.set_item( pyo3::intern!(py, "query_string"), PyString::new(py, query_string) - .call_method1( - pyo3::intern!(py, "encode"), (pyo3::intern!(py, "latin-1"),) - )? + .call_method1(pyo3::intern!(py, "encode"), (pyo3::intern!(py, "latin-1"),))?, )?; dict.set_item(pyo3::intern!(py, "headers"), self.py_headers(py)?)?; Ok(dict) diff --git a/src/callbacks.rs b/src/callbacks.rs index c6a67ef1..9af137b1 100644 --- a/src/callbacks.rs +++ b/src/callbacks.rs @@ -2,32 +2,27 @@ use once_cell::sync::OnceCell; use pyo3::prelude::*; use pyo3::pyclass::IterNextOutput; - static CONTEXTVARS: OnceCell = OnceCell::new(); static CONTEXT: OnceCell = OnceCell::new(); #[derive(Clone)] pub(crate) struct CallbackWrapper { pub callback: PyObject, - pub context: pyo3_asyncio::TaskLocals + pub context: pyo3_asyncio::TaskLocals, } impl CallbackWrapper { - pub(crate) fn new( - callback: PyObject, - event_loop: &PyAny, - context: &PyAny - ) -> Self { + pub(crate) fn new(callback: PyObject, event_loop: &PyAny, context: &PyAny) -> Self { Self { callback, - context: pyo3_asyncio::TaskLocals::new(event_loop).with_context(context) + context: pyo3_asyncio::TaskLocals::new(event_loop).with_context(context), } } } #[pyclass] pub(crate) struct PyIterAwaitable { - result: Option> + result: Option>, } impl PyIterAwaitable { @@ -36,7 +31,7 @@ impl PyIterAwaitable { } pub(crate) fn set_result(&mut self, result: PyResult) { - self.result = Some(result) + self.result = Some(result); } } @@ -52,13 +47,11 @@ impl PyIterAwaitable { fn __next__(&mut self, py: Python) -> PyResult> { match self.result.take() { - Some(res) => { - match res { - Ok(v) => Ok(IterNextOutput::Return(v)), - Err(err) => Err(err) - } + Some(res) => match res { + Ok(v) => Ok(IterNextOutput::Return(v)), + Err(err) => Err(err), }, - _ => Ok(IterNextOutput::Yield(py.None())) + _ => Ok(IterNextOutput::Yield(py.None())), } } } @@ -68,7 +61,7 @@ pub(crate) struct PyFutureAwaitable { py_block: bool, event_loop: PyObject, result: Option>, - cb: Option<(PyObject, Py)> + cb: Option<(PyObject, Py)>, } impl PyFutureAwaitable { @@ -77,7 +70,7 @@ impl PyFutureAwaitable { event_loop, py_block: true, result: None, - cb: None + cb: None, } } @@ -85,12 +78,9 @@ impl PyFutureAwaitable { pyself.result = Some(result); if let Some((cb, ctx)) = pyself.cb.take() { let py = pyself.py(); - let _ = pyself.event_loop.call_method( - py, - "call_soon_threadsafe", - (cb, &pyself), - Some(ctx.as_ref(py)) - ); + let _ = pyself + .event_loop + .call_method(py, "call_soon_threadsafe", (cb, &pyself), Some(ctx.as_ref(py))); } } } @@ -108,25 +98,22 @@ impl PyFutureAwaitable { #[setter(_asyncio_future_blocking)] fn set_block(&mut self, val: bool) { - self.py_block = val + self.py_block = val; } fn get_loop(&mut self) -> PyObject { self.event_loop.clone() } - fn add_done_callback( - mut pyself: PyRefMut<'_, Self>, - py: Python, - cb: PyObject, - context: PyObject - ) -> PyResult<()> { + fn add_done_callback(mut pyself: PyRefMut<'_, Self>, py: Python, cb: PyObject, context: PyObject) -> PyResult<()> { let kwctx = pyo3::types::PyDict::new(py); kwctx.set_item("context", context)?; match pyself.result { Some(_) => { - pyself.event_loop.call_method(py, "call_soon", (cb, &pyself), Some(kwctx))?; - }, + pyself + .event_loop + .call_method(py, "call_soon", (cb, &pyself), Some(kwctx))?; + } _ => { pyself.cb = Some((cb, kwctx.into_py(py))); } @@ -136,9 +123,9 @@ impl PyFutureAwaitable { fn cancel(mut pyself: PyRefMut<'_, Self>, py: Python) -> bool { if let Some((cb, kwctx)) = pyself.cb.take() { - let _ = pyself.event_loop.call_method( - py, "call_soon", (cb, &pyself), Some(kwctx.as_ref(py)) - ); + let _ = pyself + .event_loop + .call_method(py, "call_soon", (cb, &pyself), Some(kwctx.as_ref(py))); } false } @@ -150,79 +137,69 @@ impl PyFutureAwaitable { pyself } - fn __next__( - mut pyself: PyRefMut<'_, Self> - ) -> PyResult, PyObject>> { + fn __next__(mut pyself: PyRefMut<'_, Self>) -> PyResult, PyObject>> { match pyself.result { - Some(_) => { - match pyself.result.take().unwrap() { - Ok(v) => Ok(IterNextOutput::Return(v)), - Err(err) => Err(err) - } + Some(_) => match pyself.result.take().unwrap() { + Ok(v) => Ok(IterNextOutput::Return(v)), + Err(err) => Err(err), }, - _ => Ok(IterNextOutput::Yield(pyself)) + _ => Ok(IterNextOutput::Yield(pyself)), } } } fn contextvars(py: Python) -> PyResult<&PyAny> { Ok(CONTEXTVARS - .get_or_try_init(|| py.import("contextvars").map(|m| m.into()))? + .get_or_try_init(|| py.import("contextvars").map(std::convert::Into::into))? .as_ref(py)) } pub fn empty_pycontext(py: Python) -> PyResult<&PyAny> { Ok(CONTEXT - .get_or_try_init(|| contextvars(py)?.getattr("Context")?.call0().map(|c| c.into()))? + .get_or_try_init(|| { + contextvars(py)? + .getattr("Context")? + .call0() + .map(std::convert::Into::into) + })? .as_ref(py)) } macro_rules! callback_impl_run { () => { - pub fn run<'p>(self, py: Python<'p>) -> PyResult<&'p PyAny> { + pub fn run(self, py: Python<'_>) -> PyResult<&PyAny> { let event_loop = self.context.event_loop(py); let target = self.into_py(py).getattr(py, pyo3::intern!(py, "_loop_task"))?; let kwctx = pyo3::types::PyDict::new(py); kwctx.set_item( pyo3::intern!(py, "context"), - crate::callbacks::empty_pycontext(py)? + crate::callbacks::empty_pycontext(py)?, )?; - event_loop.call_method( - pyo3::intern!(py, "call_soon_threadsafe"), - (target,), - Some(kwctx) - ) + event_loop.call_method(pyo3::intern!(py, "call_soon_threadsafe"), (target,), Some(kwctx)) } }; } macro_rules! callback_impl_run_pytask { () => { - pub fn run<'p>(self, py: Python<'p>) -> PyResult<&'p PyAny> { + pub fn run(self, py: Python<'_>) -> PyResult<&PyAny> { let event_loop = self.context.event_loop(py); let context = self.context.context(py); let target = self.into_py(py).getattr(py, pyo3::intern!(py, "_loop_task"))?; let kwctx = pyo3::types::PyDict::new(py); - kwctx.set_item( - pyo3::intern!(py, "context"), - context - )?; - event_loop.call_method( - pyo3::intern!(py, "call_soon_threadsafe"), - (target,), - Some(kwctx) - ) + kwctx.set_item(pyo3::intern!(py, "context"), context)?; + event_loop.call_method(pyo3::intern!(py, "call_soon_threadsafe"), (target,), Some(kwctx)) } }; } macro_rules! callback_impl_loop_run { () => { - pub fn run<'p>(self, py: Python<'p>) -> PyResult<&'p PyAny> { + pub fn run(self, py: Python<'_>) -> PyResult<&PyAny> { let context = self.pycontext.clone().into_ref(py); context.call_method1( pyo3::intern!(py, "run"), - (self.into_py(py).getattr(py, pyo3::intern!(py, "_loop_step"))?,) + (self.into_py(py).getattr(py, pyo3::intern!(py, "_loop_step"))?,), ) } }; @@ -232,7 +209,7 @@ macro_rules! callback_impl_loop_pytask { ($pyself:expr, $py:expr) => { $pyself.context.event_loop($py).call_method1( pyo3::intern!($py, "create_task"), - ($pyself.cb.clone().into_ref($py).call1(($pyself.into_py($py),))?,) + ($pyself.cb.clone().into_ref($py).call1(($pyself.into_py($py),))?,), ) }; } @@ -241,12 +218,9 @@ macro_rules! callback_impl_loop_step { ($pyself:expr, $py:expr) => { match $pyself.cb.call_method1($py, pyo3::intern!($py, "send"), ($py.None(),)) { Ok(res) => { - let blocking: bool = match res.getattr( - $py, - pyo3::intern!($py, "_asyncio_future_blocking") - ) { + let blocking: bool = match res.getattr($py, pyo3::intern!($py, "_asyncio_future_blocking")) { Ok(v) => v.extract($py)?, - _ => false + _ => false, }; let ctx = $pyself.pycontext.clone(); @@ -255,43 +229,30 @@ macro_rules! callback_impl_loop_step { match blocking { true => { - res.setattr( - $py, - pyo3::intern!($py, "_asyncio_future_blocking"), - false - )?; + res.setattr($py, pyo3::intern!($py, "_asyncio_future_blocking"), false)?; res.call_method( $py, pyo3::intern!($py, "add_done_callback"), - ( - $pyself - .into_py($py) - .getattr($py, pyo3::intern!($py, "_loop_wake"))?, - ), - Some(kwctx) + ($pyself.into_py($py).getattr($py, pyo3::intern!($py, "_loop_wake"))?,), + Some(kwctx), )?; Ok(()) - }, + } false => { let event_loop = $pyself.context.event_loop($py); event_loop.call_method( pyo3::intern!($py, "call_soon"), - ( - $pyself - .into_py($py) - .getattr($py, pyo3::intern!($py, "_loop_step"))?, - ), - Some(kwctx) + ($pyself.into_py($py).getattr($py, pyo3::intern!($py, "_loop_step"))?,), + Some(kwctx), )?; Ok(()) } } - }, + } Err(err) => { - if ( - err.is_instance_of::($py) || - err.is_instance_of::($py) - ) { + if (err.is_instance_of::($py) + || err.is_instance_of::($py)) + { $pyself.done($py); Ok(()) } else { @@ -307,7 +268,7 @@ macro_rules! callback_impl_loop_wake { ($pyself:expr, $py:expr, $fut:expr) => { match $fut.call_method0($py, pyo3::intern!($py, "result")) { Ok(_) => $pyself.into_py($py).call_method0($py, pyo3::intern!($py, "_loop_step")), - Err(err) => $pyself._loop_err($py, err) + Err(err) => $pyself._loop_err($py, err), } }; } @@ -322,10 +283,10 @@ macro_rules! callback_impl_loop_err { }; } -pub(crate) use callback_impl_run; -pub(crate) use callback_impl_run_pytask; -pub(crate) use callback_impl_loop_run; +pub(crate) use callback_impl_loop_err; pub(crate) use callback_impl_loop_pytask; +pub(crate) use callback_impl_loop_run; pub(crate) use callback_impl_loop_step; pub(crate) use callback_impl_loop_wake; -pub(crate) use callback_impl_loop_err; +pub(crate) use callback_impl_run; +pub(crate) use callback_impl_run_pytask; diff --git a/src/http.rs b/src/http.rs index f4ccef32..e0989d7d 100644 --- a/src/http.rs +++ b/src/http.rs @@ -1,4 +1,7 @@ -use hyper::{Body, Response, header::{HeaderValue, SERVER as HK_SERVER}}; +use hyper::{ + header::{HeaderValue, SERVER as HK_SERVER}, + Body, Response, +}; pub(crate) const HV_SERVER: HeaderValue = HeaderValue::from_static("granian"); diff --git a/src/lib.rs b/src/lib.rs index 327f9df1..97cec0f6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,8 +8,8 @@ mod callbacks; mod http; mod rsgi; mod runtime; -mod tls; mod tcp; +mod tls; mod utils; mod workers; mod ws; diff --git a/src/rsgi/callbacks.rs b/src/rsgi/callbacks.rs index 09872589..184848c5 100644 --- a/src/rsgi/callbacks.rs +++ b/src/rsgi/callbacks.rs @@ -2,45 +2,33 @@ use pyo3::prelude::*; use pyo3_asyncio::TaskLocals; use tokio::sync::oneshot; +use super::{ + io::{RSGIHTTPProtocol as HTTPProtocol, RSGIWebsocketProtocol as WebsocketProtocol}, + types::{PyResponse, PyResponseBody, RSGIScope as Scope}, +}; use crate::{ callbacks::{ - CallbackWrapper, - callback_impl_run, - callback_impl_run_pytask, - callback_impl_loop_run, - callback_impl_loop_pytask, - callback_impl_loop_step, - callback_impl_loop_wake, - callback_impl_loop_err + callback_impl_loop_err, callback_impl_loop_pytask, callback_impl_loop_run, callback_impl_loop_step, + callback_impl_loop_wake, callback_impl_run, callback_impl_run_pytask, CallbackWrapper, }, runtime::RuntimeRef, - ws::{HyperWebsocket, UpgradeData} + ws::{HyperWebsocket, UpgradeData}, }; -use super::{ - io::{RSGIHTTPProtocol as HTTPProtocol, RSGIWebsocketProtocol as WebsocketProtocol}, - types::{RSGIScope as Scope, PyResponse, PyResponseBody} -}; - #[pyclass] pub(crate) struct CallbackRunnerHTTP { proto: Py, context: TaskLocals, - cb: PyObject + cb: PyObject, } impl CallbackRunnerHTTP { - pub fn new( - py: Python, - cb: CallbackWrapper, - proto: HTTPProtocol, - scope: Scope - ) -> Self { + pub fn new(py: Python, cb: CallbackWrapper, proto: HTTPProtocol, scope: Scope) -> Self { let pyproto = Py::new(py, proto).unwrap(); Self { proto: pyproto.clone(), context: cb.context, - cb: cb.callback.call1(py, (scope, pyproto)).unwrap() + cb: cb.callback.call1(py, (scope, pyproto)).unwrap(), } } @@ -58,19 +46,17 @@ macro_rules! callback_impl_done_http { ($self:expr, $py:expr) => { if let Ok(mut proto) = $self.proto.as_ref($py).try_borrow_mut() { if let Some(tx) = proto.tx() { - let _ = tx.send( - PyResponse::Body(PyResponseBody::empty(500, Vec::new())) - ); + let _ = tx.send(PyResponse::Body(PyResponseBody::empty(500, Vec::new()))); } } - } + }; } macro_rules! callback_impl_done_err { ($self:expr, $py:expr) => { log::warn!("Application callable raised an exception"); $self.done($py) - } + }; } #[pyclass] @@ -78,18 +64,18 @@ pub(crate) struct CallbackTaskHTTP { proto: Py, context: TaskLocals, pycontext: PyObject, - cb: PyObject + cb: PyObject, } impl CallbackTaskHTTP { - pub fn new( - py: Python, - cb: PyObject, - proto: Py, - context: TaskLocals - ) -> PyResult { + pub fn new(py: Python, cb: PyObject, proto: Py, context: TaskLocals) -> PyResult { let pyctx = context.context(py); - Ok(Self { proto, context, pycontext: pyctx.call_method0(pyo3::intern!(py, "copy"))?.into(), cb }) + Ok(Self { + proto, + context, + pycontext: pyctx.call_method0(pyo3::intern!(py, "copy"))?.into(), + cb, + }) } fn done(&self, py: Python) { @@ -122,21 +108,16 @@ pub(crate) struct CallbackWrappedRunnerHTTP { context: TaskLocals, cb: PyObject, #[pyo3(get)] - scope: PyObject + scope: PyObject, } impl CallbackWrappedRunnerHTTP { - pub fn new( - py: Python, - cb: CallbackWrapper, - proto: HTTPProtocol, - scope: Scope - ) -> Self { + pub fn new(py: Python, cb: CallbackWrapper, proto: HTTPProtocol, scope: Scope) -> Self { Self { proto: Py::new(py, proto).unwrap(), context: cb.context, cb: cb.callback, - scope: scope.into_py(py) + scope: scope.into_py(py), } } @@ -162,21 +143,16 @@ impl CallbackWrappedRunnerHTTP { pub(crate) struct CallbackRunnerWebsocket { proto: Py, context: TaskLocals, - cb: PyObject + cb: PyObject, } impl CallbackRunnerWebsocket { - pub fn new( - py: Python, - cb: CallbackWrapper, - proto: WebsocketProtocol, - scope: Scope - ) -> Self { + pub fn new(py: Python, cb: CallbackWrapper, proto: WebsocketProtocol, scope: Scope) -> Self { let pyproto = Py::new(py, proto).unwrap(); Self { proto: pyproto.clone(), context: cb.context, - cb: cb.callback.call1(py, (scope, pyproto)).unwrap() + cb: cb.callback.call1(py, (scope, pyproto)).unwrap(), } } @@ -197,7 +173,7 @@ macro_rules! callback_impl_done_ws { let _ = tx.send(res); } } - } + }; } #[pyclass] @@ -205,18 +181,18 @@ pub(crate) struct CallbackTaskWebsocket { proto: Py, context: TaskLocals, pycontext: PyObject, - cb: PyObject + cb: PyObject, } impl CallbackTaskWebsocket { - pub fn new( - py: Python, - cb: PyObject, - proto: Py, - context: TaskLocals - ) -> PyResult { + pub fn new(py: Python, cb: PyObject, proto: Py, context: TaskLocals) -> PyResult { let pyctx = context.context(py); - Ok(Self { proto, context, pycontext: pyctx.call_method0(pyo3::intern!(py, "copy"))?.into(), cb }) + Ok(Self { + proto, + context, + pycontext: pyctx.call_method0(pyo3::intern!(py, "copy"))?.into(), + cb, + }) } fn done(&self, py: Python) { @@ -249,21 +225,16 @@ pub(crate) struct CallbackWrappedRunnerWebsocket { context: TaskLocals, cb: PyObject, #[pyo3(get)] - scope: PyObject + scope: PyObject, } impl CallbackWrappedRunnerWebsocket { - pub fn new( - py: Python, - cb: CallbackWrapper, - proto: WebsocketProtocol, - scope: Scope - ) -> Self { + pub fn new(py: Python, cb: CallbackWrapper, proto: WebsocketProtocol, scope: Scope) -> Self { Self { proto: Py::new(py, proto).unwrap(), context: cb.context, cb: cb.callback, - scope: scope.into_py(py) + scope: scope.into_py(py), } } @@ -291,7 +262,7 @@ macro_rules! call_impl_rtb_http { cb: CallbackWrapper, rt: RuntimeRef, req: hyper::Request, - scope: Scope + scope: Scope, ) -> oneshot::Receiver { let (tx, rx) = oneshot::channel(); let protocol = HTTPProtocol::new(rt, tx, req); @@ -311,7 +282,7 @@ macro_rules! call_impl_rtt_http { cb: CallbackWrapper, rt: RuntimeRef, req: hyper::Request, - scope: Scope + scope: Scope, ) -> oneshot::Receiver { let (tx, rx) = oneshot::channel(); let protocol = HTTPProtocol::new(rt, tx, req); @@ -334,7 +305,7 @@ macro_rules! call_impl_rtb_ws { rt: RuntimeRef, ws: HyperWebsocket, upgrade: UpgradeData, - scope: Scope + scope: Scope, ) -> oneshot::Receiver<(i32, bool)> { let (tx, rx) = oneshot::channel(); let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade); @@ -355,7 +326,7 @@ macro_rules! call_impl_rtt_ws { rt: RuntimeRef, ws: HyperWebsocket, upgrade: UpgradeData, - scope: Scope + scope: Scope, ) -> oneshot::Receiver<(i32, bool)> { let (tx, rx) = oneshot::channel(); let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade); diff --git a/src/rsgi/errors.rs b/src/rsgi/errors.rs index 4f7cbbec..fe27eea4 100644 --- a/src/rsgi/errors.rs +++ b/src/rsgi/errors.rs @@ -1,6 +1,5 @@ use pyo3::{create_exception, exceptions::PyRuntimeError}; - create_exception!(_granian, RSGIProtocolError, PyRuntimeError, "RSGIProtocolError"); create_exception!(_granian, RSGIProtocolClosed, PyRuntimeError, "RSGIProtocolClosed"); diff --git a/src/rsgi/http.rs b/src/rsgi/http.rs index f786be62..76b7d66e 100644 --- a/src/rsgi/http.rs +++ b/src/rsgi/http.rs @@ -1,34 +1,22 @@ use hyper::{ - Body, - Request, - Response, - StatusCode, - header::SERVER as HK_SERVER, - http::response::Builder as ResponseBuilder + header::SERVER as HK_SERVER, http::response::Builder as ResponseBuilder, Body, Request, Response, StatusCode, }; use std::net::SocketAddr; use tokio::sync::mpsc; -use crate::{ - callbacks::CallbackWrapper, - http::{HV_SERVER, response_500}, - runtime::RuntimeRef, - ws::{UpgradeData, is_upgrade_request as is_ws_upgrade, upgrade_intent as ws_upgrade} -}; use super::{ callbacks::{ - call_rtb_http, - call_rtb_http_pyw, - call_rtb_ws, - call_rtb_ws_pyw, - call_rtt_http, - call_rtt_http_pyw, - call_rtt_ws, - call_rtt_ws_pyw + call_rtb_http, call_rtb_http_pyw, call_rtb_ws, call_rtb_ws_pyw, call_rtt_http, call_rtt_http_pyw, call_rtt_ws, + call_rtt_ws_pyw, }, - types::{RSGIScope as Scope, PyResponse} + types::{PyResponse, RSGIScope as Scope}, +}; +use crate::{ + callbacks::CallbackWrapper, + http::{response_500, HV_SERVER}, + runtime::RuntimeRef, + ws::{is_upgrade_request as is_ws_upgrade, upgrade_intent as ws_upgrade, UpgradeData}, }; - macro_rules! default_scope { ($server_addr:expr, $client_addr:expr, $req:expr, $scheme:expr) => { @@ -40,7 +28,7 @@ macro_rules! default_scope { $req.method().as_ref(), $server_addr, $client_addr, - $req.headers() + $req.headers(), ) }; } @@ -48,12 +36,8 @@ macro_rules! default_scope { macro_rules! handle_http_response { ($handler:expr, $rt:expr, $callback:expr, $req:expr, $scope:expr) => { match $handler($callback, $rt, $req, $scope).await { - Ok(PyResponse::Body(pyres)) => { - pyres.to_response() - }, - Ok(PyResponse::File(pyres)) => { - pyres.to_response().await - }, + Ok(PyResponse::Body(pyres)) => pyres.to_response(), + Ok(PyResponse::File(pyres)) => pyres.to_response().await, _ => { log::error!("RSGI protocol failure"); response_500() @@ -70,7 +54,7 @@ macro_rules! handle_request { server_addr: SocketAddr, client_addr: SocketAddr, req: Request, - scheme: &str + scheme: &str, ) -> Response { let scope = default_scope!(server_addr, client_addr, &req, scheme); handle_http_response!($handler, rt, callback, req, scope) @@ -86,7 +70,7 @@ macro_rules! handle_request_with_ws { server_addr: SocketAddr, client_addr: SocketAddr, req: Request, - scheme: &str + scheme: &str, ) -> Response { let mut scope = default_scope!(server_addr, client_addr, &req, scheme); @@ -101,27 +85,23 @@ macro_rules! handle_request_with_ws { rt.inner.spawn(async move { let tx_ref = restx.clone(); - match $handler_ws( - callback, - rth, - ws, - UpgradeData::new(res, restx), - scope - ).await { + match $handler_ws(callback, rth, ws, UpgradeData::new(res, restx), scope).await { Ok((status, consumed)) => { if !consumed { - let _ = tx_ref.send( - ResponseBuilder::new() - .status( - StatusCode::from_u16(status as u16) - .unwrap_or(StatusCode::FORBIDDEN) - ) - .header(HK_SERVER, HV_SERVER) - .body(Body::from("")) - .unwrap() - ).await; + let _ = tx_ref + .send( + ResponseBuilder::new() + .status( + StatusCode::from_u16(status as u16) + .unwrap_or(StatusCode::FORBIDDEN), + ) + .header(HK_SERVER, HV_SERVER) + .body(Body::from("")) + .unwrap(), + ) + .await; } - }, + } _ => { log::error!("RSGI protocol failure"); let _ = tx_ref.send(response_500()).await; @@ -133,10 +113,10 @@ macro_rules! handle_request_with_ws { Some(res) => { resrx.close(); res - }, - _ => response_500() - } - }, + } + _ => response_500(), + }; + } Err(err) => { return ResponseBuilder::new() .status(StatusCode::BAD_REQUEST) @@ -149,7 +129,6 @@ macro_rules! handle_request_with_ws { handle_http_response!($handler_req, rt, callback, req, scope) } - }; } diff --git a/src/rsgi/io.rs b/src/rsgi/io.rs index 4f061484..5b629907 100644 --- a/src/rsgi/io.rs +++ b/src/rsgi/io.rs @@ -1,35 +1,40 @@ use bytes::Bytes; -use futures::{sink::SinkExt, stream::{SplitSink, SplitStream, StreamExt}}; -use hyper::{body::{Body, Sender as BodySender, HttpBody}, Request}; +use futures::{ + sink::SinkExt, + stream::{SplitSink, SplitStream, StreamExt}, +}; +use hyper::{ + body::{Body, HttpBody, Sender as BodySender}, + Request, +}; use pyo3::prelude::*; use pyo3::types::{PyBytes, PyString}; use std::sync::Arc; -use tokio_tungstenite::WebSocketStream; use tokio::sync::{oneshot, Mutex}; +use tokio_tungstenite::WebSocketStream; use tungstenite::Message; -use crate::{ - runtime::{Runtime, RuntimeRef, future_into_py_iter, future_into_py_futlike}, - ws::{HyperWebsocket, UpgradeData} -}; use super::{ errors::{error_proto, error_stream}, - types::{PyResponse, PyResponseBody, PyResponseFile} + types::{PyResponse, PyResponseBody, PyResponseFile}, +}; +use crate::{ + runtime::{future_into_py_futlike, future_into_py_iter, Runtime, RuntimeRef}, + ws::{HyperWebsocket, UpgradeData}, }; - -#[pyclass(module="granian._granian")] +#[pyclass(module = "granian._granian")] pub(crate) struct RSGIHTTPStreamTransport { rt: RuntimeRef, - tx: Arc> + tx: Arc>, } impl RSGIHTTPStreamTransport { - pub fn new( - rt: RuntimeRef, - transport: BodySender - ) -> Self { - Self { rt: rt, tx: Arc::new(Mutex::new(transport)) } + pub fn new(rt: RuntimeRef, transport: BodySender) -> Self { + Self { + rt, + tx: Arc::new(Mutex::new(transport)), + } } } @@ -41,8 +46,8 @@ impl RSGIHTTPStreamTransport { if let Ok(mut stream) = transport.try_lock() { return match stream.send_data(data.into()).await { Ok(_) => Ok(()), - _ => error_stream!() - } + _ => error_stream!(), + }; } error_proto!() }) @@ -54,31 +59,27 @@ impl RSGIHTTPStreamTransport { if let Ok(mut stream) = transport.try_lock() { return match stream.send_data(data.into()).await { Ok(_) => Ok(()), - _ => error_stream!() - } + _ => error_stream!(), + }; } error_proto!() }) } } -#[pyclass(module="granian._granian")] +#[pyclass(module = "granian._granian")] pub(crate) struct RSGIHTTPProtocol { rt: RuntimeRef, tx: Option>, - body: Arc> + body: Arc>, } impl RSGIHTTPProtocol { - pub fn new( - rt: RuntimeRef, - tx: oneshot::Sender, - request: Request - ) -> Self { + pub fn new(rt: RuntimeRef, tx: oneshot::Sender, request: Request) -> Self { Self { rt, tx: Some(tx), - body: Arc::new(Mutex::new(request.into_body())) + body: Arc::new(Mutex::new(request.into_body())), } } @@ -110,14 +111,13 @@ impl RSGIHTTPProtocol { let mut bodym = body_ref.lock().await; let body = &mut *bodym; if body.is_end_stream() { - return Err(pyo3::exceptions::PyStopAsyncIteration::new_err("stream exhausted")) + return Err(pyo3::exceptions::PyStopAsyncIteration::new_err("stream exhausted")); } - let chunk = body.data().await.map_or_else(|| Bytes::new(), |buf| { - buf.unwrap_or_else(|_| Bytes::new()) - }); - Ok(Python::with_gil(|py| { - PyBytes::new(py, &chunk[..]).to_object(py) - })) + let chunk = body + .data() + .await + .map_or_else(Bytes::new, |buf| buf.unwrap_or_else(|_| Bytes::new())); + Ok(Python::with_gil(|py| PyBytes::new(py, &chunk[..]).to_object(py))) })?; Ok(Some(fut)) } @@ -125,36 +125,28 @@ impl RSGIHTTPProtocol { #[pyo3(signature = (status=200, headers=vec![]))] fn response_empty(&mut self, status: u16, headers: Vec<(String, String)>) { if let Some(tx) = self.tx.take() { - let _ = tx.send( - PyResponse::Body(PyResponseBody::empty(status, headers)) - ); + let _ = tx.send(PyResponse::Body(PyResponseBody::empty(status, headers))); } } #[pyo3(signature = (status=200, headers=vec![], body=vec![]))] fn response_bytes(&mut self, status: u16, headers: Vec<(String, String)>, body: Vec) { if let Some(tx) = self.tx.take() { - let _ = tx.send( - PyResponse::Body(PyResponseBody::from_bytes(status, headers, body)) - ); + let _ = tx.send(PyResponse::Body(PyResponseBody::from_bytes(status, headers, body))); } } - #[pyo3(signature = (status=200, headers=vec![], body="".to_string()))] + #[pyo3(signature = (status=200, headers=vec![], body=String::new()))] fn response_str(&mut self, status: u16, headers: Vec<(String, String)>, body: String) { if let Some(tx) = self.tx.take() { - let _ = tx.send( - PyResponse::Body(PyResponseBody::from_string(status, headers, body)) - ); + let _ = tx.send(PyResponse::Body(PyResponseBody::from_string(status, headers, body))); } } #[pyo3(signature = (status, headers, file))] fn response_file(&mut self, status: u16, headers: Vec<(String, String)>, file: String) { if let Some(tx) = self.tx.take() { - let _ = tx.send( - PyResponse::File(PyResponseFile::new(status, headers, file)) - ); + let _ = tx.send(PyResponse::File(PyResponseFile::new(status, headers, file))); } } @@ -163,34 +155,33 @@ impl RSGIHTTPProtocol { &mut self, py: Python<'p>, status: u16, - headers: Vec<(String, String)> + headers: Vec<(String, String)>, ) -> PyResult<&'p PyAny> { if let Some(tx) = self.tx.take() { let (body_tx, body_stream) = Body::channel(); - let _ = tx.send( - PyResponse::Body(PyResponseBody::new(status, headers, body_stream)) - ); + let _ = tx.send(PyResponse::Body(PyResponseBody::new(status, headers, body_stream))); let trx = Py::new(py, RSGIHTTPStreamTransport::new(self.rt.clone(), body_tx))?; - return Ok(trx.into_ref(py)) + return Ok(trx.into_ref(py)); } error_proto!() } } -#[pyclass(module="granian._granian")] +#[pyclass(module = "granian._granian")] pub(crate) struct RSGIWebsocketTransport { rt: RuntimeRef, tx: Arc, Message>>>, - rx: Arc>>> + rx: Arc>>>, } impl RSGIWebsocketTransport { - pub fn new( - rt: RuntimeRef, - transport: WebSocketStream - ) -> Self { + pub fn new(rt: RuntimeRef, transport: WebSocketStream) -> Self { let (tx, rx) = transport.split(); - Self { rt: rt, tx: Arc::new(Mutex::new(tx)), rx: Arc::new(Mutex::new(rx)) } + Self { + rt, + tx: Arc::new(Mutex::new(tx)), + rx: Arc::new(Mutex::new(rx)), + } } pub fn close(&self) { @@ -209,27 +200,14 @@ impl RSGIWebsocketTransport { let transport = self.rx.clone(); future_into_py_futlike(self.rt.clone(), py, async move { if let Ok(mut stream) = transport.try_lock() { - loop { - match stream.next().await { - Some(recv) => { - match recv { - Ok(Message::Ping(_)) => { - continue - }, - Ok(message) => { - return message_into_py(message) - }, - _ => { - break - } - } - }, - _ => { - break - } + while let Some(recv) = stream.next().await { + match recv { + Ok(Message::Ping(_)) => continue, + Ok(message) => return message_into_py(message), + _ => break, } } - return error_stream!() + return error_stream!(); } error_proto!() }) @@ -241,8 +219,8 @@ impl RSGIWebsocketTransport { if let Ok(mut stream) = transport.try_lock() { return match stream.send(Message::Binary(data)).await { Ok(_) => Ok(()), - _ => error_stream!() - } + _ => error_stream!(), + }; } error_proto!() }) @@ -254,22 +232,22 @@ impl RSGIWebsocketTransport { if let Ok(mut stream) = transport.try_lock() { return match stream.send(Message::Text(data)).await { Ok(_) => Ok(()), - _ => error_stream!() - } + _ => error_stream!(), + }; } error_proto!() }) } } -#[pyclass(module="granian._granian")] +#[pyclass(module = "granian._granian")] pub(crate) struct RSGIWebsocketProtocol { rt: RuntimeRef, tx: Option>, websocket: Arc>, upgrade: Option, transport: Arc>>>, - status: i32 + status: i32, } impl RSGIWebsocketProtocol { @@ -277,7 +255,7 @@ impl RSGIWebsocketProtocol { rt: RuntimeRef, tx: oneshot::Sender<(i32, bool)>, websocket: HyperWebsocket, - upgrade: UpgradeData + upgrade: UpgradeData, ) -> Self { Self { rt, @@ -285,15 +263,12 @@ impl RSGIWebsocketProtocol { websocket: Arc::new(Mutex::new(websocket)), upgrade: Some(upgrade), transport: Arc::new(Mutex::new(None)), - status: 0 + status: 0, } } fn consumed(&self) -> bool { - match &self.upgrade { - Some(_) => false, - _ => true - } + self.upgrade.is_none() } pub fn tx(&mut self) -> (Option>, (i32, bool)) { @@ -304,18 +279,20 @@ impl RSGIWebsocketProtocol { enum WebsocketMessageType { Close = 0, Bytes = 1, - Text = 2 + Text = 2, } #[pyclass] struct WebsocketInboundCloseMessage { #[pyo3(get)] - kind: usize + kind: usize, } impl WebsocketInboundCloseMessage { pub fn new() -> Self { - Self { kind: WebsocketMessageType::Close as usize } + Self { + kind: WebsocketMessageType::Close as usize, + } } } @@ -324,12 +301,15 @@ struct WebsocketInboundBytesMessage { #[pyo3(get)] kind: usize, #[pyo3(get)] - data: Py + data: Py, } impl WebsocketInboundBytesMessage { - pub fn new(data:Py) -> Self { - Self { kind: WebsocketMessageType::Bytes as usize, data: data } + pub fn new(data: Py) -> Self { + Self { + kind: WebsocketMessageType::Bytes as usize, + data, + } } } @@ -338,12 +318,15 @@ struct WebsocketInboundTextMessage { #[pyo3(get)] kind: usize, #[pyo3(get)] - data: Py + data: Py, } impl WebsocketInboundTextMessage { pub fn new(data: Py) -> Self { - Self { kind: WebsocketMessageType::Text as usize, data: data } + Self { + kind: WebsocketMessageType::Text as usize, + data, + } } } @@ -374,23 +357,18 @@ impl RSGIWebsocketProtocol { future_into_py_iter(self.rt.clone(), py, async move { let mut ws = transport.lock().await; match upgrade.send().await { - Ok(_) => { - match (&mut *ws).await { - Ok(stream) => { - let mut trx = itransport.lock().await; - Ok(Python::with_gil(|py| { - let pytransport = Py::new( - py, - RSGIWebsocketTransport::new(rth, stream) - ).unwrap(); - *trx = Some(pytransport.clone()); - pytransport - })) - }, - _ => error_proto!() + Ok(_) => match (&mut *ws).await { + Ok(stream) => { + let mut trx = itransport.lock().await; + Ok(Python::with_gil(|py| { + let pytransport = Py::new(py, RSGIWebsocketTransport::new(rth, stream)).unwrap(); + *trx = Some(pytransport.clone()); + pytransport + })) } + _ => error_proto!(), }, - _ => error_proto!() + _ => error_proto!(), } }) } @@ -399,25 +377,13 @@ impl RSGIWebsocketProtocol { #[inline(always)] fn message_into_py(message: Message) -> PyResult { match message { - Message::Binary(message) => { - Ok(Python::with_gil(|py| { - WebsocketInboundBytesMessage::new( - PyBytes::new(py, &message).into() - ).into_py(py) - })) - }, - Message::Text(message) => { - Ok(Python::with_gil(|py| { - WebsocketInboundTextMessage::new( - PyString::new(py, &message).into() - ).into_py(py) - })) - }, - Message::Close(_) => { - Ok(Python::with_gil(|py| { - WebsocketInboundCloseMessage::new().into_py(py) - })) - } + Message::Binary(message) => Ok(Python::with_gil(|py| { + WebsocketInboundBytesMessage::new(PyBytes::new(py, &message).into()).into_py(py) + })), + Message::Text(message) => Ok(Python::with_gil(|py| { + WebsocketInboundTextMessage::new(PyString::new(py, &message).into()).into_py(py) + })), + Message::Close(_) => Ok(Python::with_gil(|py| WebsocketInboundCloseMessage::new().into_py(py))), v => { log::warn!("Unsupported websocket message received {:?}", v); error_proto!() diff --git a/src/rsgi/serve.rs b/src/rsgi/serve.rs index f361c095..fec2c5f2 100644 --- a/src/rsgi/serve.rs +++ b/src/rsgi/serve.rs @@ -1,28 +1,14 @@ use pyo3::prelude::*; -use crate::{ - workers::{ - WorkerConfig, - serve_rth, - serve_wth, - serve_rth_ssl, - serve_wth_ssl - } -}; use super::http::{ - handle_rtb, - handle_rtb_pyw, - handle_rtt, - handle_rtt_pyw, - handle_rtb_ws, - handle_rtb_ws_pyw, - handle_rtt_ws, - handle_rtt_ws_pyw + handle_rtb, handle_rtb_pyw, handle_rtb_ws, handle_rtb_ws_pyw, handle_rtt, handle_rtt_pyw, handle_rtt_ws, + handle_rtt_ws_pyw, }; +use crate::workers::{serve_rth, serve_rth_ssl, serve_wth, serve_wth_ssl, WorkerConfig}; -#[pyclass(module="granian._granian")] +#[pyclass(module = "granian._granian")] pub struct RSGIWorker { - config: WorkerConfig + config: WorkerConfig, } impl RSGIWorker { @@ -73,7 +59,7 @@ impl RSGIWorker { opt_enabled: bool, ssl_enabled: bool, ssl_cert: Option<&str>, - ssl_key: Option<&str> + ssl_key: Option<&str>, ) -> PyResult { Ok(Self { config: WorkerConfig::new( @@ -87,22 +73,16 @@ impl RSGIWorker { opt_enabled, ssl_enabled, ssl_cert, - ssl_key - ) + ssl_key, + ), }) } - fn serve_rth( - &self, - callback: PyObject, - event_loop: &PyAny, - context: &PyAny, - signal_rx: PyObject - ) { + fn serve_rth(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) { match ( self.config.websockets_enabled, self.config.ssl_enabled, - self.config.opt_enabled + self.config.opt_enabled, ) { (false, false, true) => self._serve_rth(callback, event_loop, context, signal_rx), (false, false, false) => self._serve_rth_pyw(callback, event_loop, context, signal_rx), @@ -111,21 +91,15 @@ impl RSGIWorker { (false, true, true) => self._serve_rth_ssl(callback, event_loop, context, signal_rx), (false, true, false) => self._serve_rth_ssl_pyw(callback, event_loop, context, signal_rx), (true, true, true) => self._serve_rth_ssl_ws(callback, event_loop, context, signal_rx), - (true, true, false) => self._serve_rth_ssl_ws_pyw(callback, event_loop, context, signal_rx) + (true, true, false) => self._serve_rth_ssl_ws_pyw(callback, event_loop, context, signal_rx), } } - fn serve_wth( - &self, - callback: PyObject, - event_loop: &PyAny, - context: &PyAny, - signal_rx: PyObject - ) { + fn serve_wth(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) { match ( self.config.websockets_enabled, self.config.ssl_enabled, - self.config.opt_enabled + self.config.opt_enabled, ) { (false, false, true) => self._serve_wth(callback, event_loop, context, signal_rx), (false, false, false) => self._serve_wth_pyw(callback, event_loop, context, signal_rx), @@ -134,7 +108,7 @@ impl RSGIWorker { (false, true, true) => self._serve_wth_ssl(callback, event_loop, context, signal_rx), (false, true, false) => self._serve_wth_ssl_pyw(callback, event_loop, context, signal_rx), (true, true, true) => self._serve_wth_ssl_ws(callback, event_loop, context, signal_rx), - (true, true, false) => self._serve_wth_ssl_ws_pyw(callback, event_loop, context, signal_rx) + (true, true, false) => self._serve_wth_ssl_ws_pyw(callback, event_loop, context, signal_rx), } } } diff --git a/src/rsgi/types.rs b/src/rsgi/types.rs index 37123235..66037012 100644 --- a/src/rsgi/types.rs +++ b/src/rsgi/types.rs @@ -1,6 +1,6 @@ use hyper::{ header::{HeaderMap, HeaderName, HeaderValue, SERVER as HK_SERVER}, - Body, Uri, Version + Body, Uri, Version, }; use pyo3::prelude::*; use pyo3::types::PyString; @@ -10,11 +10,10 @@ use tokio_util::codec::{BytesCodec, FramedRead}; use crate::http::HV_SERVER; - -#[pyclass(module="granian._granian")] +#[pyclass(module = "granian._granian")] #[derive(Clone)] pub(crate) struct RSGIHeaders { - inner: HeaderMap + inner: HeaderMap, } impl RSGIHeaders { @@ -29,7 +28,7 @@ impl RSGIHeaders { let mut ret = Vec::with_capacity(self.inner.keys_len()); for key in self.inner.keys() { ret.push(key.as_str()); - }; + } ret } @@ -37,15 +36,15 @@ impl RSGIHeaders { let mut ret = Vec::with_capacity(self.inner.keys_len()); for val in self.inner.values() { ret.push(val.to_str().unwrap()); - }; + } Ok(ret) } fn items(&self) -> PyResult> { let mut ret = Vec::with_capacity(self.inner.keys_len()); - for (key, val) in self.inner.iter() { + for (key, val) in &self.inner { ret.push((key.as_str(), val.to_str().unwrap())); - }; + } Ok(ret) } @@ -56,18 +55,16 @@ impl RSGIHeaders { #[pyo3(signature = (key, default=None))] fn get(&self, py: Python, key: &str, default: Option) -> Option { match self.inner.get(key) { - Some(val) => { - match val.to_str() { - Ok(string) => Some(PyString::new(py, string).into()), - _ => default - } + Some(val) => match val.to_str() { + Ok(string) => Some(PyString::new(py, string).into()), + _ => default, }, - _ => default + _ => default, } } } -#[pyclass(module="granian._granian")] +#[pyclass(module = "granian._granian")] pub(crate) struct RSGIScope { #[pyo3(get)] proto: String, @@ -84,7 +81,7 @@ pub(crate) struct RSGIScope { #[pyo3(get)] client: String, #[pyo3(get)] - headers: RSGIHeaders + headers: RSGIHeaders, } impl RSGIScope { @@ -96,23 +93,23 @@ impl RSGIScope { method: &str, server: SocketAddr, client: SocketAddr, - headers: &HeaderMap + headers: &HeaderMap, ) -> Self { Self { proto: proto.to_string(), - http_version: http_version, + http_version, rsgi_version: "1.2".to_string(), scheme: scheme.to_string(), method: method.to_string(), - uri: uri, + uri, server: server.to_string(), client: client.to_string(), - headers: RSGIHeaders::new(headers) + headers: RSGIHeaders::new(headers), } } pub fn set_proto(&mut self, value: &str) { - self.proto = value.to_string() + self.proto = value.to_string(); } } @@ -125,7 +122,7 @@ impl RSGIScope { Version::HTTP_11 => "1.1", Version::HTTP_2 => "2", Version::HTTP_3 => "3", - _ => "1" + _ => "1", } } @@ -142,38 +139,36 @@ impl RSGIScope { pub(crate) enum PyResponse { Body(PyResponseBody), - File(PyResponseFile) + File(PyResponseFile), } pub(crate) struct PyResponseBody { status: u16, headers: Vec<(String, String)>, - body: Body + body: Body, } pub(crate) struct PyResponseFile { status: u16, headers: Vec<(String, String)>, - file_path: String + file_path: String, } macro_rules! response_head_from_py { - ($status:expr, $headers:expr, $res:expr) => { - { - let mut rh = hyper::http::HeaderMap::new(); - - rh.insert(HK_SERVER, HV_SERVER); - for (key, value) in $headers { - rh.append( - HeaderName::from_bytes(key.as_bytes()).unwrap(), - HeaderValue::from_str(&value).unwrap() - ); - } - - *$res.status_mut() = $status.try_into().unwrap(); - *$res.headers_mut() = rh; + ($status:expr, $headers:expr, $res:expr) => {{ + let mut rh = hyper::http::HeaderMap::new(); + + rh.insert(HK_SERVER, HV_SERVER); + for (key, value) in $headers { + rh.append( + HeaderName::from_bytes(key.as_bytes()).unwrap(), + HeaderValue::from_str(&value).unwrap(), + ); } - } + + *$res.status_mut() = $status.try_into().unwrap(); + *$res.headers_mut() = rh; + }}; } impl PyResponseBody { @@ -182,18 +177,30 @@ impl PyResponseBody { } pub fn empty(status: u16, headers: Vec<(String, String)>) -> Self { - Self { status, headers, body: Body::empty() } + Self { + status, + headers, + body: Body::empty(), + } } pub fn from_bytes(status: u16, headers: Vec<(String, String)>, body: Vec) -> Self { - Self { status, headers, body: Body::from(body) } + Self { + status, + headers, + body: Body::from(body), + } } pub fn from_string(status: u16, headers: Vec<(String, String)>, body: String) -> Self { - Self { status, headers, body: Body::from(body) } + Self { + status, + headers, + body: Body::from(body), + } } - pub fn to_response(self) -> hyper::Response:: { + pub fn to_response(self) -> hyper::Response { let mut res = hyper::Response::::new(self.body); response_head_from_py!(self.status, &self.headers, res); res @@ -202,10 +209,14 @@ impl PyResponseBody { impl PyResponseFile { pub fn new(status: u16, headers: Vec<(String, String)>, file_path: String) -> Self { - Self { status, headers, file_path } + Self { + status, + headers, + file_path, + } } - pub async fn to_response(&self) -> hyper::Response:: { + pub async fn to_response(&self) -> hyper::Response { let file = File::open(&self.file_path).await.unwrap(); let stream = FramedRead::new(file, BytesCodec::new()); let mut res = hyper::Response::::new(Body::wrap_stream(stream)); diff --git a/src/runtime.rs b/src/runtime.rs index 5b5e2198..4ed03ebb 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -1,12 +1,19 @@ use once_cell::unsync::OnceCell as UnsyncOnceCell; -use pyo3_asyncio::TaskLocals; use pyo3::prelude::*; -use std::{future::Future, io, pin::Pin, sync::{Arc, Mutex}}; -use tokio::{runtime::Builder, task::{JoinHandle, LocalSet}}; +use pyo3_asyncio::TaskLocals; +use std::{ + future::Future, + io, + pin::Pin, + sync::{Arc, Mutex}, +}; +use tokio::{ + runtime::Builder, + task::{JoinHandle, LocalSet}, +}; use super::callbacks::{PyFutureAwaitable, PyIterAwaitable}; - tokio::task_local! { static TASK_LOCALS: UnsyncOnceCell; } @@ -27,11 +34,7 @@ pub trait Runtime: Send + 'static { } pub trait ContextExt: Runtime { - fn scope( - &self, - locals: TaskLocals, - fut: F - ) -> Pin + Send>> + fn scope(&self, locals: TaskLocals, fut: F) -> Pin + Send>> where F: Future + Send + 'static; @@ -45,36 +48,34 @@ pub trait SpawnLocalExt: Runtime { } pub trait LocalContextExt: Runtime { - fn scope_local( - &self, - locals: TaskLocals, - fut: F - ) -> Pin>> + fn scope_local(&self, locals: TaskLocals, fut: F) -> Pin>> where F: Future + 'static; } pub(crate) struct RuntimeWrapper { - rt: tokio::runtime::Runtime + rt: tokio::runtime::Runtime, } impl RuntimeWrapper { pub fn new(blocking_threads: usize) -> Self { - Self { rt: default_runtime(blocking_threads).unwrap() } + Self { + rt: default_runtime(blocking_threads).unwrap(), + } } pub fn with_runtime(rt: tokio::runtime::Runtime) -> Self { - Self { rt: rt } + Self { rt } } pub fn handler(&self) -> RuntimeRef { - RuntimeRef::new(self.rt.handle().to_owned()) + RuntimeRef::new(self.rt.handle().clone()) } } #[derive(Clone)] pub struct RuntimeRef { - pub inner: tokio::runtime::Handle + pub inner: tokio::runtime::Handle, } impl RuntimeRef { @@ -108,11 +109,7 @@ impl Runtime for RuntimeRef { } impl ContextExt for RuntimeRef { - fn scope( - &self, - locals: TaskLocals, - fut: F - ) -> Pin + Send>> + fn scope(&self, locals: TaskLocals, fut: F) -> Pin + Send>> where F: Future + Send + 'static, { @@ -123,7 +120,7 @@ impl ContextExt for RuntimeRef { } fn get_task_locals() -> Option { - match TASK_LOCALS.try_with(|c| c.get().map(|locals| locals.clone())) { + match TASK_LOCALS.try_with(|c| c.get().cloned()) { Ok(locals) => locals, Err(_) => None, } @@ -140,11 +137,7 @@ impl SpawnLocalExt for RuntimeRef { } impl LocalContextExt for RuntimeRef { - fn scope_local( - &self, - locals: TaskLocals, - fut: F - ) -> Pin>> + fn scope_local(&self, locals: TaskLocals, fut: F) -> Pin>> where F: Future + 'static, { @@ -169,7 +162,7 @@ pub(crate) fn init_runtime_mt(threads: usize, blocking_threads: usize) -> Runtim .max_blocking_threads(blocking_threads) .enable_all() .build() - .unwrap() + .unwrap(), ) } @@ -177,12 +170,8 @@ pub(crate) fn init_runtime_st(blocking_threads: usize) -> RuntimeWrapper { RuntimeWrapper::new(blocking_threads) } -pub(crate) fn into_future( - awaitable: &PyAny, -) -> PyResult> + Send> { - pyo3_asyncio::into_future_with_locals( - &get_current_locals::(awaitable.py())?, awaitable - ) +pub(crate) fn into_future(awaitable: &PyAny) -> PyResult> + Send> { + pyo3_asyncio::into_future_with_locals(&get_current_locals::(awaitable.py())?, awaitable) } #[inline] @@ -241,10 +230,7 @@ where rt.spawn(async move { let result = fut.await; Python::with_gil(move |py| { - PyFutureAwaitable::set_result( - py_aw.as_ref(py).borrow_mut(), - result.map(|v| v.into_py(py)) - ); + PyFutureAwaitable::set_result(py_aw.as_ref(py).borrow_mut(), result.map(|v| v.into_py(py))); }); }); @@ -269,11 +255,7 @@ where let rth = rt.handler(); rt.spawn(async move { - let val = rth.scope( - task_locals.clone(), - fut - ) - .await; + let val = rth.scope(task_locals.clone(), fut).await; if let Ok(mut result) = result_tx.lock() { *result = Some(val.unwrap()); } @@ -292,7 +274,7 @@ where pub(crate) fn block_on_local(rt: RuntimeWrapper, local: LocalSet, fut: F) where - F: Future + 'static + F: Future + 'static, { local.block_on(&rt.rt, fut); } diff --git a/src/tcp.rs b/src/tcp.rs index 4aa54eef..7b32fc95 100644 --- a/src/tcp.rs +++ b/src/tcp.rs @@ -9,10 +9,9 @@ use std::os::windows::io::{AsRawSocket, FromRawSocket}; use socket2::{Domain, Protocol, Socket, Type}; - -#[pyclass(module="granian._granian")] +#[pyclass(module = "granian._granian")] pub struct ListenerHolder { - socket: TcpListener + socket: TcpListener, } #[pymethods] @@ -20,28 +19,19 @@ impl ListenerHolder { #[cfg(unix)] #[new] pub fn new(fd: i32) -> PyResult { - let socket = unsafe { - TcpListener::from_raw_fd(fd) - }; - Ok(Self { socket: socket }) + let socket = unsafe { TcpListener::from_raw_fd(fd) }; + Ok(Self { socket }) } #[cfg(windows)] #[new] pub fn new(fd: u64) -> PyResult { - let socket = unsafe { - TcpListener::from_raw_socket(fd) - }; - Ok(Self { socket: socket }) + let socket = unsafe { TcpListener::from_raw_socket(fd) }; + Ok(Self { socket }) } #[classmethod] - pub fn from_address( - _cls: &PyType, - address: &str, - port: u16, - backlog: i32 - ) -> PyResult { + pub fn from_address(_cls: &PyType, address: &str, port: u16, backlog: i32) -> PyResult { let address: SocketAddr = (address.parse::()?, port).into(); let socket = Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP))?; socket.set_reuse_address(true)?; @@ -54,17 +44,13 @@ impl ListenerHolder { #[cfg(unix)] pub fn __getstate__(&self, py: Python) -> PyObject { let fd = self.socket.as_raw_fd(); - ( - fd.into_py(py), - ).to_object(py) + (fd.into_py(py),).to_object(py) } #[cfg(windows)] pub fn __getstate__(&self, py: Python) -> PyObject { let fd = self.socket.as_raw_socket(); - ( - fd.into_py(py), - ).to_object(py) + (fd.into_py(py),).to_object(py) } #[cfg(unix)] @@ -84,7 +70,6 @@ impl ListenerHolder { } } - pub(crate) fn init_pymodule(module: &PyModule) -> PyResult<()> { module.add_class::()?; diff --git a/src/tls.rs b/src/tls.rs index d97d39bf..9355ed2f 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -1,20 +1,22 @@ use futures::stream::StreamExt; -use hyper::server::{accept, conn::{AddrIncoming, AddrStream}}; +use hyper::server::{ + accept, + conn::{AddrIncoming, AddrStream}, +}; use std::{fs, future, io, iter::Iterator, net::TcpListener, sync::Arc}; use tls_listener::{Error as TlsError, TlsListener}; use tokio_rustls::{ - TlsAcceptor, rustls::{Certificate, PrivateKey, ServerConfig}, - server::TlsStream + server::TlsStream, + TlsAcceptor, }; - pub(crate) type TlsAddrStream = TlsStream; pub(crate) fn tls_listen( config: Arc, - tcp: TcpListener -) -> impl accept::Accept> { + tcp: TcpListener, +) -> impl accept::Accept> { tcp.set_nonblocking(true).unwrap(); let tcp_listener = tokio::net::TcpListener::from_std(tcp).unwrap(); let incoming = AddrIncoming::from_listener(tcp_listener).unwrap(); @@ -34,30 +36,26 @@ fn tls_error(err: String) -> io::Error { } pub(crate) fn load_certs(filename: &str) -> io::Result> { - let certfile = fs::File::open(filename) - .map_err(|e| tls_error(format!("failed to open {}: {}", filename, e)))?; + let certfile = fs::File::open(filename).map_err(|e| tls_error(format!("failed to open {filename}: {e}")))?; let mut reader = io::BufReader::new(certfile); - let certs = rustls_pemfile::certs(&mut reader) - .map_err(|_| tls_error("failed to load certificate".into()))?; + let certs = rustls_pemfile::certs(&mut reader).map_err(|_| tls_error("failed to load certificate".into()))?; Ok(certs.into_iter().map(Certificate).collect()) } pub(crate) fn load_private_key(filename: &str) -> io::Result { - let keyfile = fs::File::open(filename) - .map_err(|e| tls_error(format!("failed to open {}: {}", filename, e)))?; + let keyfile = fs::File::open(filename).map_err(|e| tls_error(format!("failed to open {filename}: {e}")))?; let mut reader = io::BufReader::new(keyfile); - let keys = rustls_pemfile::read_all(&mut reader) - .map_err(|_| tls_error("failed to load private key".into()))?; + let keys = rustls_pemfile::read_all(&mut reader).map_err(|_| tls_error("failed to load private key".into()))?; if keys.len() != 1 { return Err(tls_error("expected a single private key".into())); } let key = match &keys[0] { - rustls_pemfile::Item::RSAKey(key) => PrivateKey(key.to_vec()), - rustls_pemfile::Item::PKCS8Key(key) => PrivateKey(key.to_vec()), - rustls_pemfile::Item::ECKey(key) => PrivateKey(key.to_vec()), + rustls_pemfile::Item::RSAKey(key) => PrivateKey(key.clone()), + rustls_pemfile::Item::PKCS8Key(key) => PrivateKey(key.clone()), + rustls_pemfile::Item::ECKey(key) => PrivateKey(key.clone()), _ => { return Err(tls_error("failed to load private key".into())); } diff --git a/src/utils.rs b/src/utils.rs index 379ec656..93cf4c3d 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,13 +1,15 @@ pub(crate) fn header_contains_value( headers: &hyper::HeaderMap, header: impl hyper::header::AsHeaderName, - value: impl AsRef<[u8]> + value: impl AsRef<[u8]>, ) -> bool { let value = value.as_ref(); for header in headers.get_all(header) { - if header.as_bytes().split(|&c| c == b',').any( - |x| trim(x).eq_ignore_ascii_case(value) - ) { + if header + .as_bytes() + .split(|&c| c == b',') + .any(|x| trim(x).eq_ignore_ascii_case(value)) + { return true; } } @@ -30,7 +32,7 @@ fn trim_start(data: &[u8]) -> &[u8] { #[inline] fn trim_end(data: &[u8]) -> &[u8] { if let Some(last) = data.iter().rposition(|x| !x.is_ascii_whitespace()) { - &data[..last + 1] + &data[..=last] } else { b"" } diff --git a/src/workers.rs b/src/workers.rs index e49fa3b6..fc7750d2 100644 --- a/src/workers.rs +++ b/src/workers.rs @@ -8,8 +8,8 @@ use std::os::windows::io::FromRawSocket; use super::asgi::serve::ASGIWorker; use super::rsgi::serve::RSGIWorker; -use super::wsgi::serve::WSGIWorker; use super::tls::{load_certs as tls_load_certs, load_private_key as tls_load_pkey}; +use super::wsgi::serve::WSGIWorker; pub(crate) struct WorkerConfig { pub id: i32, @@ -22,7 +22,7 @@ pub(crate) struct WorkerConfig { pub opt_enabled: bool, pub ssl_enabled: bool, ssl_cert: Option, - ssl_key: Option + ssl_key: Option, } impl WorkerConfig { @@ -37,7 +37,7 @@ impl WorkerConfig { opt_enabled: bool, ssl_enabled: bool, ssl_cert: Option<&str>, - ssl_key: Option<&str> + ssl_key: Option<&str>, ) -> Self { Self { id, @@ -49,23 +49,19 @@ impl WorkerConfig { websockets_enabled, opt_enabled, ssl_enabled, - ssl_cert: ssl_cert.map_or(None, |v| Some(v.into())), - ssl_key: ssl_key.map_or(None, |v| Some(v.into())) + ssl_cert: ssl_cert.map(std::convert::Into::into), + ssl_key: ssl_key.map(std::convert::Into::into), } } #[cfg(unix)] pub fn tcp_listener(&self) -> TcpListener { - unsafe { - TcpListener::from_raw_fd(self.socket_fd) - } + unsafe { TcpListener::from_raw_fd(self.socket_fd) } } #[cfg(windows)] pub fn tcp_listener(&self) -> TcpListener { - unsafe { - TcpListener::from_raw_socket(self.socket_fd as u64) - } + unsafe { TcpListener::from_raw_socket(self.socket_fd as u64) } } pub fn tls_cfg(&self) -> tokio_rustls::rustls::ServerConfig { @@ -74,13 +70,13 @@ impl WorkerConfig { .with_no_client_auth() .with_single_cert( tls_load_certs(&self.ssl_cert.clone().unwrap()[..]).unwrap(), - tls_load_pkey(&self.ssl_key.clone().unwrap()[..]).unwrap() + tls_load_pkey(&self.ssl_key.clone().unwrap()[..]).unwrap(), ) .unwrap(); cfg.alpn_protocols = match &self.http_mode[..] { "1" => vec![b"http/1.1".to_vec()], "2" => vec![b"h2".to_vec()], - _ => vec![b"h2".to_vec(), b"http/1.1".to_vec()] + _ => vec![b"h2".to_vec(), b"http/1.1".to_vec()], }; cfg } @@ -102,7 +98,7 @@ pub(crate) struct WorkerExecutor; impl hyper::rt::Executor for WorkerExecutor where - F: std::future::Future + 'static + F: std::future::Future + 'static, { fn execute(&self, fut: F) { tokio::task::spawn_local(fut); @@ -123,14 +119,9 @@ macro_rules! build_service { let rth = rth.clone(); async move { - Ok::<_, std::convert::Infallible>($target( - rth, - callback_wrapper, - local_addr, - remote_addr, - req, - "http" - ).await) + Ok::<_, std::convert::Infallible>( + $target(rth, callback_wrapper, local_addr, remote_addr, req, "http").await, + ) } })) } @@ -153,14 +144,9 @@ macro_rules! build_service_ssl { let rth = rth.clone(); async move { - Ok::<_, std::convert::Infallible>($target( - rth, - callback_wrapper, - local_addr, - remote_addr, - req, - "https" - ).await) + Ok::<_, std::convert::Infallible>( + $target(rth, callback_wrapper, local_addr, remote_addr, req, "https").await, + ) } })) } @@ -170,13 +156,7 @@ macro_rules! build_service_ssl { macro_rules! serve_rth { ($func_name:ident, $target:expr) => { - fn $func_name( - &self, - callback: PyObject, - event_loop: &PyAny, - context: &PyAny, - signal_rx: PyObject - ) { + fn $func_name(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) { pyo3_log::init(); let rt = crate::runtime::init_runtime_mt(self.config.threads, self.config.pthreads); let rth = rt.handler(); @@ -184,34 +164,30 @@ macro_rules! serve_rth { let http1_only = self.config.http_mode == "1"; let http2_only = self.config.http_mode == "2"; let http1_buffer_max = self.config.http1_buffer_max.clone(); - let callback_wrapper = crate::callbacks::CallbackWrapper::new( - callback, event_loop, context - ); + let callback_wrapper = crate::callbacks::CallbackWrapper::new(callback, event_loop, context); let worker_id = self.config.id; log::info!("Started worker-{}", worker_id); - let svc_loop = crate::runtime::run_until_complete( - rt.handler(), - event_loop, - async move { - let service = crate::workers::build_service!( - callback_wrapper, rth, $target - ); - let server = hyper::Server::from_tcp(tcp_listener).unwrap() - .http1_only(http1_only) - .http2_only(http2_only) - .http1_max_buf_size(http1_buffer_max) - .serve(service); - server.with_graceful_shutdown(async move { - Python::with_gil(|py| { - crate::runtime::into_future(signal_rx.as_ref(py)).unwrap() - }).await.unwrap(); - }).await.unwrap(); - log::info!("Stopping worker-{}", worker_id); - Ok(()) - } - ); + let svc_loop = crate::runtime::run_until_complete(rt.handler(), event_loop, async move { + let service = crate::workers::build_service!(callback_wrapper, rth, $target); + let server = hyper::Server::from_tcp(tcp_listener) + .unwrap() + .http1_only(http1_only) + .http2_only(http2_only) + .http1_max_buf_size(http1_buffer_max) + .serve(service); + server + .with_graceful_shutdown(async move { + Python::with_gil(|py| crate::runtime::into_future(signal_rx.as_ref(py)).unwrap()) + .await + .unwrap(); + }) + .await + .unwrap(); + log::info!("Stopping worker-{}", worker_id); + Ok(()) + }); match svc_loop { Ok(_) => {} @@ -226,13 +202,7 @@ macro_rules! serve_rth { macro_rules! serve_rth_ssl { ($func_name:ident, $target:expr) => { - fn $func_name( - &self, - callback: PyObject, - event_loop: &PyAny, - context: &PyAny, - signal_rx: PyObject - ) { + fn $func_name(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) { pyo3_log::init(); let rt = crate::runtime::init_runtime_mt(self.config.threads, self.config.pthreads); let rth = rt.handler(); @@ -241,38 +211,29 @@ macro_rules! serve_rth_ssl { let http2_only = self.config.http_mode == "2"; let http1_buffer_max = self.config.http1_buffer_max.clone(); let tls_cfg = self.config.tls_cfg(); - let callback_wrapper = crate::callbacks::CallbackWrapper::new( - callback, event_loop, context - ); + let callback_wrapper = crate::callbacks::CallbackWrapper::new(callback, event_loop, context); let worker_id = self.config.id; log::info!("Started worker-{}", worker_id); - let svc_loop = crate::runtime::run_until_complete( - rt.handler(), - event_loop, - async move { - let service = crate::workers::build_service_ssl!( - callback_wrapper, rth, $target - ); - let server = hyper::Server::builder( - crate::tls::tls_listen( - std::sync::Arc::new(tls_cfg), tcp_listener - ) - ) - .http1_only(http1_only) - .http2_only(http2_only) - .http1_max_buf_size(http1_buffer_max) - .serve(service); - server.with_graceful_shutdown(async move { - Python::with_gil(|py| { - crate::runtime::into_future(signal_rx.as_ref(py)).unwrap() - }).await.unwrap(); - }).await.unwrap(); - log::info!("Stopping worker-{}", worker_id); - Ok(()) - } - ); + let svc_loop = crate::runtime::run_until_complete(rt.handler(), event_loop, async move { + let service = crate::workers::build_service_ssl!(callback_wrapper, rth, $target); + let server = hyper::Server::builder(crate::tls::tls_listen(std::sync::Arc::new(tls_cfg), tcp_listener)) + .http1_only(http1_only) + .http2_only(http2_only) + .http1_max_buf_size(http1_buffer_max) + .serve(service); + server + .with_graceful_shutdown(async move { + Python::with_gil(|py| crate::runtime::into_future(signal_rx.as_ref(py)).unwrap()) + .await + .unwrap(); + }) + .await + .unwrap(); + log::info!("Stopping worker-{}", worker_id); + Ok(()) + }); match svc_loop { Ok(_) => {} @@ -287,22 +248,14 @@ macro_rules! serve_rth_ssl { macro_rules! serve_wth { ($func_name: ident, $target:expr) => { - fn $func_name( - &self, - callback: PyObject, - event_loop: &PyAny, - context: &PyAny, - signal_rx: PyObject - ) { + fn $func_name(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) { pyo3_log::init(); let rtm = crate::runtime::init_runtime_mt(1, 1); let worker_id = self.config.id; log::info!("Started worker-{}", worker_id); - let callback_wrapper = crate::callbacks::CallbackWrapper::new( - callback, event_loop, context - ); + let callback_wrapper = crate::callbacks::CallbackWrapper::new(callback, event_loop, context); let mut workers = vec![]; let (stx, srx) = tokio::sync::watch::channel(false); @@ -323,38 +276,36 @@ macro_rules! serve_wth { let local = tokio::task::LocalSet::new(); crate::runtime::block_on_local(rt, local, async move { - let service = crate::workers::build_service!( - callback_wrapper, rth, $target - ); - let server = hyper::Server::from_tcp(tcp_listener).unwrap() + let service = crate::workers::build_service!(callback_wrapper, rth, $target); + let server = hyper::Server::from_tcp(tcp_listener) + .unwrap() .executor(crate::workers::WorkerExecutor) .http1_only(http1_only) .http2_only(http2_only) .http1_max_buf_size(http1_buffer_max) .serve(service); - server.with_graceful_shutdown(async move { - srx.changed().await.unwrap(); - }).await.unwrap(); + server + .with_graceful_shutdown(async move { + srx.changed().await.unwrap(); + }) + .await + .unwrap(); log::info!("Stopping worker-{} runtime-{}", worker_id, thread_id + 1); }); })); - }; + } - let main_loop = crate::runtime::run_until_complete( - rtm.handler(), - event_loop, - async move { - Python::with_gil(|py| { - crate::runtime::into_future(signal_rx.as_ref(py)).unwrap() - }).await.unwrap(); - stx.send(true).unwrap(); - log::info!("Stopping worker-{}", worker_id); - while let Some(worker) = workers.pop() { - worker.join().unwrap(); - } - Ok(()) + let main_loop = crate::runtime::run_until_complete(rtm.handler(), event_loop, async move { + Python::with_gil(|py| crate::runtime::into_future(signal_rx.as_ref(py)).unwrap()) + .await + .unwrap(); + stx.send(true).unwrap(); + log::info!("Stopping worker-{}", worker_id); + while let Some(worker) = workers.pop() { + worker.join().unwrap(); } - ); + Ok(()) + }); match main_loop { Ok(_) => {} @@ -369,22 +320,14 @@ macro_rules! serve_wth { macro_rules! serve_wth_ssl { ($func_name: ident, $target:expr) => { - fn $func_name( - &self, - callback: PyObject, - event_loop: &PyAny, - context: &PyAny, - signal_rx: PyObject - ) { + fn $func_name(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) { pyo3_log::init(); let rtm = crate::runtime::init_runtime_mt(1, 1); let worker_id = self.config.id; log::info!("Started worker-{}", worker_id); - let callback_wrapper = crate::callbacks::CallbackWrapper::new( - callback, event_loop, context - ); + let callback_wrapper = crate::callbacks::CallbackWrapper::new(callback, event_loop, context); let mut workers = vec![]; let (stx, srx) = tokio::sync::watch::channel(false); @@ -406,42 +349,36 @@ macro_rules! serve_wth_ssl { let local = tokio::task::LocalSet::new(); crate::runtime::block_on_local(rt, local, async move { - let service = crate::workers::build_service_ssl!( - callback_wrapper, rth, $target - ); - let server = hyper::Server::builder( - crate::tls::tls_listen( - std::sync::Arc::new(tls_cfg), tcp_listener - ) - ) - .executor(crate::workers::WorkerExecutor) - .http1_only(http1_only) - .http2_only(http2_only) - .http1_max_buf_size(http1_buffer_max) - .serve(service); - server.with_graceful_shutdown(async move { - srx.changed().await.unwrap(); - }).await.unwrap(); + let service = crate::workers::build_service_ssl!(callback_wrapper, rth, $target); + let server = + hyper::Server::builder(crate::tls::tls_listen(std::sync::Arc::new(tls_cfg), tcp_listener)) + .executor(crate::workers::WorkerExecutor) + .http1_only(http1_only) + .http2_only(http2_only) + .http1_max_buf_size(http1_buffer_max) + .serve(service); + server + .with_graceful_shutdown(async move { + srx.changed().await.unwrap(); + }) + .await + .unwrap(); log::info!("Stopping worker-{} runtime-{}", worker_id, thread_id + 1); }); })); - }; + } - let main_loop = crate::runtime::run_until_complete( - rtm.handler(), - event_loop, - async move { - Python::with_gil(|py| { - crate::runtime::into_future(signal_rx.as_ref(py)).unwrap() - }).await.unwrap(); - stx.send(true).unwrap(); - log::info!("Stopping worker-{}", worker_id); - while let Some(worker) = workers.pop() { - worker.join().unwrap(); - } - Ok(()) + let main_loop = crate::runtime::run_until_complete(rtm.handler(), event_loop, async move { + Python::with_gil(|py| crate::runtime::into_future(signal_rx.as_ref(py)).unwrap()) + .await + .unwrap(); + stx.send(true).unwrap(); + log::info!("Stopping worker-{}", worker_id); + while let Some(worker) = workers.pop() { + worker.join().unwrap(); } - ); + Ok(()) + }); match main_loop { Ok(_) => {} @@ -457,8 +394,8 @@ macro_rules! serve_wth_ssl { pub(crate) use build_service; pub(crate) use build_service_ssl; pub(crate) use serve_rth; -pub(crate) use serve_wth; pub(crate) use serve_rth_ssl; +pub(crate) use serve_wth; pub(crate) use serve_wth_ssl; pub(crate) fn init_pymodule(module: &PyModule) -> PyResult<()> { diff --git a/src/ws.rs b/src/ws.rs index bd44dded..3f892289 100644 --- a/src/ws.rs +++ b/src/ws.rs @@ -1,24 +1,24 @@ use hyper::{ - Body, - Request, - Response, - StatusCode, header::{CONNECTION, UPGRADE}, - http::response::Builder + http::response::Builder, + Body, Request, Response, StatusCode, }; +use pin_project::pin_project; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::sync::mpsc; +use tokio_tungstenite::WebSocketStream; use tungstenite::{ error::ProtocolError, handshake::derive_accept_key, - protocol::{Role, WebSocketConfig} + protocol::{Role, WebSocketConfig}, }; -use pin_project::pin_project; -use std::{future::Future, pin::Pin, task::{Context, Poll}}; -use tokio_tungstenite::WebSocketStream; -use tokio::sync::mpsc; use super::utils::header_contains_value; - #[pin_project] #[derive(Debug)] pub struct HyperWebsocket { @@ -37,15 +37,9 @@ impl Future for HyperWebsocket { Poll::Ready(x) => x, }; - let upgraded = upgraded.map_err(|_| - tungstenite::Error::Protocol(ProtocolError::HandshakeIncomplete) - )?; + let upgraded = upgraded.map_err(|_| tungstenite::Error::Protocol(ProtocolError::HandshakeIncomplete))?; - let stream = WebSocketStream::from_raw_socket( - upgraded, - Role::Server, - this.config.take(), - ); + let stream = WebSocketStream::from_raw_socket(upgraded, Role::Server, this.config.take()); tokio::pin!(stream); match stream.as_mut().poll(cx) { @@ -58,18 +52,15 @@ impl Future for HyperWebsocket { pub(crate) struct UpgradeData { response_builder: Option, response_tx: Option>>, - pub consumed: bool + pub consumed: bool, } impl UpgradeData { - pub fn new( - response_builder: Builder, - response_tx: mpsc::Sender>) - -> Self { + pub fn new(response_builder: Builder, response_tx: mpsc::Sender>) -> Self { Self { response_builder: Some(response_builder), response_tx: Some(response_tx), - consumed: false + consumed: false, } } @@ -79,19 +70,16 @@ impl UpgradeData { Ok(_) => { self.consumed = true; Ok(()) - }, - err => err + } + err => err, } } } #[inline] pub(crate) fn is_upgrade_request(request: &Request) -> bool { - header_contains_value( - request.headers(), CONNECTION, "Upgrade" - ) && header_contains_value( - request.headers(), UPGRADE, "websocket" - ) + header_contains_value(request.headers(), CONNECTION, "Upgrade") + && header_contains_value(request.headers(), UPGRADE, "websocket") } pub(crate) fn upgrade_intent( @@ -100,13 +88,17 @@ pub(crate) fn upgrade_intent( ) -> Result<(Builder, HyperWebsocket), ProtocolError> { let request = request.borrow_mut(); - let key = request.headers() + let key = request + .headers() .get("Sec-WebSocket-Key") .ok_or(ProtocolError::MissingSecWebSocketKey)?; - if request.headers().get("Sec-WebSocket-Version").map( - |v| v.as_bytes() - ) != Some(b"13") { + if request + .headers() + .get("Sec-WebSocket-Version") + .map(hyper::http::HeaderValue::as_bytes) + != Some(b"13") + { return Err(ProtocolError::MissingSecWebSocketVersionHeader); } diff --git a/src/wsgi/callbacks.rs b/src/wsgi/callbacks.rs index ae86cf01..e961827b 100644 --- a/src/wsgi/callbacks.rs +++ b/src/wsgi/callbacks.rs @@ -2,44 +2,35 @@ use hyper::Body; use pyo3::prelude::*; use tokio::task::JoinHandle; +use super::types::{WSGIResponseBodyIter, WSGIScope as Scope}; use crate::callbacks::CallbackWrapper; -use super::types::{WSGIScope as Scope, WSGIResponseBodyIter}; const WSGI_LIST_RESPONSE_BODY: i32 = 0; const WSGI_ITER_RESPONSE_BODY: i32 = 1; - #[inline(always)] -fn run_callback( - callback: PyObject, - scope: Scope -) -> PyResult<(i32, Vec<(String, String)>, Body)> { +fn run_callback(callback: PyObject, scope: Scope) -> PyResult<(i32, Vec<(String, String)>, Body)> { Python::with_gil(|py| { - let (status, headers, body_type, pybody) = callback.call1(py, (scope,))? - .extract::<(i32, Vec<(String, String)>, i32, PyObject)>(py)?; + let (status, headers, body_type, pybody) = + callback + .call1(py, (scope,))? + .extract::<(i32, Vec<(String, String)>, i32, PyObject)>(py)?; let body = match body_type { WSGI_LIST_RESPONSE_BODY => Body::from(pybody.extract::>(py)?), WSGI_ITER_RESPONSE_BODY => Body::wrap_stream(WSGIResponseBodyIter::new(pybody)), - _ => Body::empty() + _ => Body::empty(), }; Ok((status, headers, body)) }) } -pub(crate) fn call_rtb_http( - cb: CallbackWrapper, - scope: Scope -) -> PyResult<(i32, Vec<(String, String)>, Body)> { - run_callback(cb.callback.clone(), scope) +pub(crate) fn call_rtb_http(cb: CallbackWrapper, scope: Scope) -> PyResult<(i32, Vec<(String, String)>, Body)> { + run_callback(cb.callback, scope) } pub(crate) fn call_rtt_http( cb: CallbackWrapper, - scope: Scope + scope: Scope, ) -> JoinHandle, Body)>> { - let callback = cb.callback.clone(); - - tokio::task::spawn_blocking(move || { - run_callback(callback, scope) - }) + tokio::task::spawn_blocking(move || run_callback(cb.callback, scope)) } diff --git a/src/wsgi/http.rs b/src/wsgi/http.rs index 66cccd75..e74882a7 100644 --- a/src/wsgi/http.rs +++ b/src/wsgi/http.rs @@ -1,21 +1,18 @@ use hyper::{ - Body, - Request, - Response, - header::{SERVER as HK_SERVER, HeaderName, HeaderValue} + header::{HeaderName, HeaderValue, SERVER as HK_SERVER}, + Body, Request, Response, }; use std::net::SocketAddr; +use super::{ + callbacks::{call_rtb_http, call_rtt_http}, + types::WSGIScope as Scope, +}; use crate::{ callbacks::CallbackWrapper, - http::{HV_SERVER, response_500}, + http::{response_500, HV_SERVER}, runtime::RuntimeRef, }; -use super::{ - callbacks::{call_rtb_http, call_rtt_http}, - types::WSGIScope as Scope -}; - #[inline(always)] fn build_response(status: i32, pyheaders: Vec<(String, String)>, body: Body) -> Response { @@ -26,7 +23,7 @@ fn build_response(status: i32, pyheaders: Vec<(String, String)>, body: Body) -> for (key, val) in pyheaders { headers.append( HeaderName::from_bytes(key.as_bytes()).unwrap(), - HeaderValue::from_str(&val).unwrap() + HeaderValue::from_str(&val).unwrap(), ); } res @@ -38,14 +35,11 @@ pub(crate) async fn handle_rtt( server_addr: SocketAddr, client_addr: SocketAddr, req: Request, - scheme: &str + scheme: &str, ) -> Response { - if let Ok(res) = call_rtt_http( - callback, - Scope::new(scheme, server_addr, client_addr, req).await - ).await { + if let Ok(res) = call_rtt_http(callback, Scope::new(scheme, server_addr, client_addr, req).await).await { if let Ok((status, headers, body)) = res { - return build_response(status, headers, body) + return build_response(status, headers, body); } log::warn!("Application callable raised an exception"); } else { @@ -60,12 +54,9 @@ pub(crate) async fn handle_rtb( server_addr: SocketAddr, client_addr: SocketAddr, req: Request, - scheme: &str + scheme: &str, ) -> Response { - match call_rtb_http( - callback, - Scope::new(scheme, server_addr, client_addr, req).await - ) { + match call_rtb_http(callback, Scope::new(scheme, server_addr, client_addr, req).await) { Ok((status, headers, body)) => build_response(status, headers, body), _ => { log::warn!("Application callable raised an exception"); diff --git a/src/wsgi/serve.rs b/src/wsgi/serve.rs index 284a4cdf..0b60e8b3 100644 --- a/src/wsgi/serve.rs +++ b/src/wsgi/serve.rs @@ -1,17 +1,11 @@ use pyo3::prelude::*; -use crate::workers::{ - WorkerConfig, - serve_rth, - serve_wth, - serve_rth_ssl, - serve_wth_ssl -}; use super::http::{handle_rtb, handle_rtt}; +use crate::workers::{serve_rth, serve_rth_ssl, serve_wth, serve_wth_ssl, WorkerConfig}; -#[pyclass(module="granian._granian")] +#[pyclass(module = "granian._granian")] pub struct WSGIWorker { - config: WorkerConfig + config: WorkerConfig, } impl WSGIWorker { @@ -46,7 +40,7 @@ impl WSGIWorker { http1_buffer_max: usize, ssl_enabled: bool, ssl_cert: Option<&str>, - ssl_key: Option<&str> + ssl_key: Option<&str>, ) -> PyResult { Ok(Self { config: WorkerConfig::new( @@ -60,34 +54,22 @@ impl WSGIWorker { true, ssl_enabled, ssl_cert, - ssl_key - ) + ssl_key, + ), }) } - fn serve_rth( - &self, - callback: PyObject, - event_loop: &PyAny, - context: &PyAny, - signal_rx: PyObject - ) { + fn serve_rth(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) { match self.config.ssl_enabled { false => self._serve_rth(callback, event_loop, context, signal_rx), - true => self._serve_rth_ssl(callback, event_loop, context, signal_rx) + true => self._serve_rth_ssl(callback, event_loop, context, signal_rx), } } - fn serve_wth( - &self, - callback: PyObject, - event_loop: &PyAny, - context: &PyAny, - signal_rx: PyObject - ) { + fn serve_wth(&self, callback: PyObject, event_loop: &PyAny, context: &PyAny, signal_rx: PyObject) { match self.config.ssl_enabled { false => self._serve_wth(callback, event_loop, context, signal_rx), - true => self._serve_wth_ssl(callback, event_loop, context, signal_rx) + true => self._serve_wth_ssl(callback, event_loop, context, signal_rx), } } } diff --git a/src/wsgi/types.rs b/src/wsgi/types.rs index 02c24b26..7e5c71f8 100644 --- a/src/wsgi/types.rs +++ b/src/wsgi/types.rs @@ -1,23 +1,21 @@ use futures::Stream; use hyper::{ body::Bytes, - header::{CONTENT_TYPE, CONTENT_LENGTH, HeaderMap}, - Body, - Method, - Request, - Uri, - Version + header::{HeaderMap, CONTENT_LENGTH, CONTENT_TYPE}, + Body, Method, Request, Uri, Version, }; -use pyo3::{prelude::*, types::IntoPyDict}; use pyo3::types::{PyBytes, PyDict, PyList}; -use std::{net::{IpAddr, SocketAddr}, task::{Context, Poll}}; +use pyo3::{prelude::*, types::IntoPyDict}; +use std::{ + net::{IpAddr, SocketAddr}, + task::{Context, Poll}, +}; const LINE_SPLIT: u8 = u8::from_be_bytes(*b"\n"); - #[pyclass(module = "granian._granian")] pub(crate) struct WSGIBody { - inner: Bytes + inner: Bytes, } impl WSGIBody { @@ -36,9 +34,9 @@ impl WSGIBody { match self.inner.iter().position(|&c| c == LINE_SPLIT) { Some(next_split) => { let bytes = self.inner.split_to(next_split); - Some(PyBytes::new(py, &bytes[..])) - }, - _ => None + Some(PyBytes::new(py, &bytes)) + } + _ => None, } } @@ -48,18 +46,16 @@ impl WSGIBody { None => { let bytes = self.inner.split_to(self.inner.len()); PyBytes::new(py, &bytes[..]) - }, - Some(size) => { - match size { - 0 => PyBytes::new(py, b""), - size => { - let limit = self.inner.len(); - let rsize = if size > limit { limit } else { size }; - let bytes = self.inner.split_to(rsize); - PyBytes::new(py, &bytes[..]) - } - } } + Some(size) => match size { + 0 => PyBytes::new(py, b""), + size => { + let limit = self.inner.len(); + let rsize = if size > limit { limit } else { size }; + let bytes = self.inner.split_to(rsize); + PyBytes::new(py, &bytes[..]) + } + }, } } @@ -69,16 +65,17 @@ impl WSGIBody { let bytes = self.inner.split_to(next_split); self.inner = self.inner.slice(1..); PyBytes::new(py, &bytes[..]) - }, - _ => PyBytes::new(py, b"") + } + _ => PyBytes::new(py, b""), } } #[pyo3(signature = (_hint=None))] fn readlines<'p>(&mut self, py: Python<'p>, _hint: Option) -> &'p PyList { - let lines: Vec<&PyBytes> = self.inner + let lines: Vec<&PyBytes> = self + .inner .split(|&c| c == LINE_SPLIT) - .map(|item| PyBytes::new(py, &item[..])) + .map(|item| PyBytes::new(py, item)) .collect(); self.inner.clear(); PyList::new(py, lines) @@ -95,28 +92,19 @@ pub(crate) struct WSGIScope { server_port: u16, client: String, headers: HeaderMap, - body: Bytes + body: Bytes, } impl WSGIScope { - pub async fn new( - scheme: &str, - server: SocketAddr, - client: SocketAddr, - request: Request, - ) -> Self { + pub async fn new(scheme: &str, server: SocketAddr, client: SocketAddr, request: Request) -> Self { let http_version = request.version(); - let method = request.method().to_owned(); - let uri = request.uri().to_owned(); - let headers = request.headers().to_owned(); + let method = request.method().clone(); + let uri = request.uri().clone(); + let headers = request.headers().clone(); let body = match method { - Method::HEAD | Method::GET | Method::OPTIONS => { Bytes::new() }, - _ => { - hyper::body::to_bytes(request) - .await - .unwrap_or(Bytes::new()) - } + Method::HEAD | Method::GET | Method::OPTIONS => Bytes::new(), + _ => hyper::body::to_bytes(request).await.unwrap_or(Bytes::new()), }; Self { @@ -128,7 +116,7 @@ impl WSGIScope { server_port: server.port(), client: client.to_string(), headers, - body + body, } } @@ -138,7 +126,7 @@ impl WSGIScope { Version::HTTP_10 => "HTTP/1", Version::HTTP_11 => "HTTP/1.1", Version::HTTP_2 => "HTTP/2", - _ => "HTTP/1" + _ => "HTTP/1", } } } @@ -157,21 +145,21 @@ impl WSGIScope { content_type, content_len, headers, - body + body, ) = py.allow_threads(|| { - let (path, query_string) = self.uri.path_and_query() + let (path, query_string) = self + .uri + .path_and_query() .map_or_else(|| ("", ""), |pq| (pq.path(), pq.query().unwrap_or(""))); let content_type = self.headers.remove(CONTENT_TYPE); let content_len = self.headers.remove(CONTENT_LENGTH); let mut headers = Vec::with_capacity(self.headers.len()); - for (key, val) in self.headers.iter() { - headers.push( - ( - format!("HTTP_{}", key.as_str().replace("-", "_").to_uppercase()), - val.to_str().unwrap_or_default() - ) - ); + for (key, val) in &self.headers { + headers.push(( + format!("HTTP_{}", key.as_str().replace('-', "_").to_uppercase()), + val.to_str().unwrap_or_default(), + )); } ( @@ -185,7 +173,7 @@ impl WSGIScope { content_type, content_len, headers, - WSGIBody::new(self.body.to_owned()) + WSGIBody::new(self.body.clone()), ) }); @@ -202,13 +190,13 @@ impl WSGIScope { if let Some(content_type) = content_type { ret.set_item( pyo3::intern!(py, "CONTENT_TYPE"), - content_type.to_str().unwrap_or_default() + content_type.to_str().unwrap_or_default(), )?; } if let Some(content_len) = content_len { ret.set_item( pyo3::intern!(py, "CONTENT_LENGTH"), - content_len.to_str().unwrap_or_default() + content_len.to_str().unwrap_or_default(), )?; } @@ -219,7 +207,7 @@ impl WSGIScope { } pub(crate) struct WSGIResponseBodyIter { - inner: PyObject + inner: PyObject, } impl WSGIResponseBodyIter { @@ -235,27 +223,20 @@ impl WSGIResponseBodyIter { impl Stream for WSGIResponseBodyIter { type Item = PyResult>; - fn poll_next( - self: std::pin::Pin<&mut Self>, - _cx: &mut Context<'_> - ) -> Poll> { - Python::with_gil(|py| { - match self.inner.call_method0(py, pyo3::intern!(py, "__next__")) { - Ok(chunk_obj) => { - match chunk_obj.extract::>(py) { - Ok(chunk) => Poll::Ready(Some(Ok(chunk))), - _ => { - self.close_inner(py); - Poll::Ready(None) - } - } - }, - Err(err) => { - if err.is_instance_of::(py) { - self.close_inner(py); - } + fn poll_next(self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Python::with_gil(|py| match self.inner.call_method0(py, pyo3::intern!(py, "__next__")) { + Ok(chunk_obj) => match chunk_obj.extract::>(py) { + Ok(chunk) => Poll::Ready(Some(Ok(chunk))), + _ => { + self.close_inner(py); Poll::Ready(None) } + }, + Err(err) => { + if err.is_instance_of::(py) { + self.close_inner(py); + } + Poll::Ready(None) } }) } diff --git a/tests/apps/asgi.py b/tests/apps/asgi.py index c4d4b012..3595dd63 100644 --- a/tests/apps/asgi.py +++ b/tests/apps/asgi.py @@ -1,55 +1,45 @@ import json + PLAINTEXT_RESPONSE = { 'type': 'http.response.start', 'status': 200, - 'headers': [ - [b'content-type', b'text/plain; charset=utf-8'], - ] -} -JSON_RESPONSE = { - 'type': 'http.response.start', - 'status': 200, - 'headers': [ - [b'content-type', b'application/json'], - ] + 'headers': [[b'content-type', b'text/plain; charset=utf-8']], } +JSON_RESPONSE = {'type': 'http.response.start', 'status': 200, 'headers': [[b'content-type', b'application/json']]} async def info(scope, receive, send): await send(JSON_RESPONSE) - await send({ - 'type': 'http.response.body', - 'body': json.dumps({ - 'type': scope['type'], - 'asgi': scope['asgi'], - 'http_version': scope['http_version'], - 'scheme': scope['scheme'], - 'method': scope['method'], - 'path': scope['path'], - 'query_string': scope['query_string'].decode("latin-1"), - 'headers': { - k.decode("utf8"): v.decode("utf8") - for k, v in scope['headers'] - } - }).encode("utf8"), - 'more_body': False - }) + await send( + { + 'type': 'http.response.body', + 'body': json.dumps( + { + 'type': scope['type'], + 'asgi': scope['asgi'], + 'http_version': scope['http_version'], + 'scheme': scope['scheme'], + 'method': scope['method'], + 'path': scope['path'], + 'query_string': scope['query_string'].decode('latin-1'), + 'headers': {k.decode('utf8'): v.decode('utf8') for k, v in scope['headers']}, + } + ).encode('utf8'), + 'more_body': False, + } + ) async def echo(scope, receive, send): await send(PLAINTEXT_RESPONSE) more_body = True - body = b"" + body = b'' while more_body: msg = await receive() more_body = msg['more_body'] body += msg['body'] - await send({ - 'type': 'http.response.body', - 'body': body, - 'more_body': False - }) + await send({'type': 'http.response.body', 'body': body, 'more_body': False}) async def ws_reject(scope, receive, send): @@ -58,21 +48,22 @@ async def ws_reject(scope, receive, send): async def ws_info(scope, receive, send): await send({'type': 'websocket.accept'}) - await send({ - 'type': 'websocket.send', - 'text': json.dumps({ - 'type': scope['type'], - 'asgi': scope['asgi'], - 'http_version': scope['http_version'], - 'scheme': scope['scheme'], - 'path': scope['path'], - 'query_string': scope['query_string'].decode("latin-1"), - 'headers': { - k.decode("utf8"): v.decode("utf8") - for k, v in scope['headers'] - } - }) - }) + await send( + { + 'type': 'websocket.send', + 'text': json.dumps( + { + 'type': scope['type'], + 'asgi': scope['asgi'], + 'http_version': scope['http_version'], + 'scheme': scope['scheme'], + 'path': scope['path'], + 'query_string': scope['query_string'].decode('latin-1'), + 'headers': {k.decode('utf8'): v.decode('utf8') for k, v in scope['headers']}, + } + ), + } + ) await send({'type': 'websocket.close'}) @@ -98,10 +89,7 @@ async def ws_push(scope, receive, send): try: while True: - await send({ - 'type': 'websocket.send', - 'text': 'ping' - }) + await send({'type': 'websocket.send', 'text': 'ping'}) except Exception: pass @@ -116,12 +104,12 @@ async def err_proto(scope, receive, send): def app(scope, receive, send): return { - "/info": info, - "/echo": echo, - "/ws_reject": ws_reject, - "/ws_info": ws_info, - "/ws_echo": ws_echo, - "/ws_push": ws_push, - "/err_app": err_app, - "/err_proto": err_proto + '/info': info, + '/echo': echo, + '/ws_reject': ws_reject, + '/ws_info': ws_info, + '/ws_echo': ws_echo, + '/ws_push': ws_push, + '/err_app': err_app, + '/err_proto': err_proto, }[scope['path']](scope, receive, send) diff --git a/tests/apps/rsgi.py b/tests/apps/rsgi.py index 067787cd..0c16fa53 100644 --- a/tests/apps/rsgi.py +++ b/tests/apps/rsgi.py @@ -1,55 +1,42 @@ import json -from granian.rsgi import ( - HTTPProtocol, - Scope, - WebsocketMessageType, - WebsocketProtocol -) +from granian.rsgi import HTTPProtocol, Scope, WebsocketMessageType, WebsocketProtocol async def info(scope: Scope, protocol: HTTPProtocol): protocol.response_bytes( 200, [('content-type', 'application/json')], - json.dumps({ - 'proto': scope.proto, - 'http_version': scope.http_version, - 'rsgi_version': scope.rsgi_version, - 'scheme': scope.scheme, - 'method': scope.method, - 'path': scope.path, - 'query_string': scope.query_string, - 'headers': {k: v for k, v in scope.headers.items()} - }).encode("utf8") + json.dumps( + { + 'proto': scope.proto, + 'http_version': scope.http_version, + 'rsgi_version': scope.rsgi_version, + 'scheme': scope.scheme, + 'method': scope.method, + 'path': scope.path, + 'query_string': scope.query_string, + 'headers': dict(scope.headers.items()), + } + ).encode('utf8'), ) async def echo(_, protocol: HTTPProtocol): msg = await protocol() - protocol.response_bytes( - 200, - [('content-type', 'text/plain; charset=utf-8')], - msg - ) + protocol.response_bytes(200, [('content-type', 'text/plain; charset=utf-8')], msg) async def echo_stream(_, protocol: HTTPProtocol): - trx = protocol.response_stream( - 200, - [('content-type', 'text/plain; charset=utf-8')] - ) + trx = protocol.response_stream(200, [('content-type', 'text/plain; charset=utf-8')]) async for msg in protocol: await trx.send_bytes(msg) async def stream(_, protocol: HTTPProtocol): - trx = protocol.response_stream( - 200, - [('content-type', 'text/plain; charset=utf-8')] - ) + trx = protocol.response_stream(200, [('content-type', 'text/plain; charset=utf-8')]) for _ in range(0, 3): - await trx.send_bytes(b"test") + await trx.send_bytes(b'test') async def ws_reject(_, protocol: WebsocketProtocol): @@ -59,16 +46,20 @@ async def ws_reject(_, protocol: WebsocketProtocol): async def ws_info(scope: Scope, protocol: WebsocketProtocol): trx = await protocol.accept() - await trx.send_str(json.dumps({ - 'proto': scope.proto, - 'http_version': scope.http_version, - 'rsgi_version': scope.rsgi_version, - 'scheme': scope.scheme, - 'method': scope.method, - 'path': scope.path, - 'query_string': scope.query_string, - 'headers': {k: v for k, v in scope.headers.items()} - })) + await trx.send_str( + json.dumps( + { + 'proto': scope.proto, + 'http_version': scope.http_version, + 'rsgi_version': scope.rsgi_version, + 'scheme': scope.scheme, + 'method': scope.method, + 'path': scope.path, + 'query_string': scope.query_string, + 'headers': dict(scope.headers.items()), + } + ) + ) while True: message = await trx.receive() if message.kind == WebsocketMessageType.close: @@ -97,7 +88,7 @@ async def ws_push(_, protocol: WebsocketProtocol): try: while True: - await trx.send_str("ping") + await trx.send_str('ping') except Exception: pass @@ -110,13 +101,13 @@ async def err_app(scope: Scope, protocol: HTTPProtocol): def app(scope, protocol): return { - "/info": info, - "/echo": echo, - "/echos": echo_stream, - "/stream": stream, - "/ws_reject": ws_reject, - "/ws_info": ws_info, - "/ws_echo": ws_echo, - "/ws_push": ws_push, - "/err_app": err_app + '/info': info, + '/echo': echo, + '/echos': echo_stream, + '/stream': stream, + '/ws_reject': ws_reject, + '/ws_info': ws_info, + '/ws_echo': ws_echo, + '/ws_push': ws_push, + '/err_app': err_app, }[scope.path](scope, protocol) diff --git a/tests/apps/wsgi.py b/tests/apps/wsgi.py index a6eca9bb..c54d065b 100644 --- a/tests/apps/wsgi.py +++ b/tests/apps/wsgi.py @@ -2,37 +2,32 @@ def info(environ, protocol): - protocol( - "200 OK", - [('content-type', 'application/json')] - ) - return [json.dumps({ - 'scheme': environ['wsgi.url_scheme'], - 'method': environ['REQUEST_METHOD'], - 'path': environ["PATH_INFO"], - 'query_string': environ["QUERY_STRING"], - 'content_length': environ['CONTENT_LENGTH'], - 'headers': {k: v for k, v in environ.items() if k.startswith("HTTP_")} - }).encode("utf8")] + protocol('200 OK', [('content-type', 'application/json')]) + return [ + json.dumps( + { + 'scheme': environ['wsgi.url_scheme'], + 'method': environ['REQUEST_METHOD'], + 'path': environ['PATH_INFO'], + 'query_string': environ['QUERY_STRING'], + 'content_length': environ['CONTENT_LENGTH'], + 'headers': {k: v for k, v in environ.items() if k.startswith('HTTP_')}, + } + ).encode('utf8') + ] def echo(environ, protocol): - protocol( - '200 OK', - [('content-type', 'text/plain; charset=utf-8')] - ) + protocol('200 OK', [('content-type', 'text/plain; charset=utf-8')]) return [environ['wsgi.input'].read()] def iterbody(environ, protocol): def response(): for _ in range(0, 3): - yield b"test" + yield b'test' - protocol( - '200 OK', - [('content-type', 'text/plain; charset=utf-8')] - ) + protocol('200 OK', [('content-type', 'text/plain; charset=utf-8')]) return response() @@ -41,9 +36,6 @@ def err_app(environ, protocol): def app(environ, protocol): - return { - "/info": info, - "/echo": echo, - "/iterbody": iterbody, - "/err_app": err_app - }[environ["PATH_INFO"]](environ, protocol) + return {'/info': info, '/echo': echo, '/iterbody': iterbody, '/err_app': err_app}[environ['PATH_INFO']]( + environ, protocol + ) diff --git a/tests/conftest.py b/tests/conftest.py index 8679f388..0da2c979 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,6 @@ import asyncio import os import socket - from contextlib import asynccontextmanager, closing from functools import partial from pathlib import Path @@ -11,19 +10,20 @@ @asynccontextmanager async def _server(interface, port, threading_mode, tls=False): - certs_path = Path.cwd() / "tests" / "fixtures" / "tls" + certs_path = Path.cwd() / 'tests' / 'fixtures' / 'tls' tls_opts = ( - f"--ssl-certificate {certs_path / 'cert.pem'} " - f"--ssl-keyfile {certs_path / 'key.pem'} " - ) if tls else "" + (f"--ssl-certificate {certs_path / 'cert.pem'} " f"--ssl-keyfile {certs_path / 'key.pem'} ") if tls else '' + ) proc = await asyncio.create_subprocess_shell( - "".join([ - f"granian --interface {interface} --port {port} ", - f"--threads 1 --threading-mode {threading_mode} ", - tls_opts, - f"tests.apps.{interface}:app" - ]), - env=dict(os.environ) + ''.join( + [ + f'granian --interface {interface} --port {port} ', + f'--threads 1 --threading-mode {threading_mode} ', + tls_opts, + f'tests.apps.{interface}:app', + ] + ), + env=dict(os.environ), ) await asyncio.sleep(1) try: @@ -33,7 +33,7 @@ async def _server(interface, port, threading_mode, tls=False): await proc.wait() -@pytest.fixture(scope="function") +@pytest.fixture(scope='function') def server_port(): with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: sock.bind(('localhost', 0)) @@ -41,26 +41,26 @@ def server_port(): return sock.getsockname()[1] -@pytest.fixture(scope="function") +@pytest.fixture(scope='function') def asgi_server(server_port): - return partial(_server, "asgi", server_port) + return partial(_server, 'asgi', server_port) -@pytest.fixture(scope="function") +@pytest.fixture(scope='function') def rsgi_server(server_port): - return partial(_server, "rsgi", server_port) + return partial(_server, 'rsgi', server_port) -@pytest.fixture(scope="function") +@pytest.fixture(scope='function') def wsgi_server(server_port): - return partial(_server, "wsgi", server_port) + return partial(_server, 'wsgi', server_port) -@pytest.fixture(scope="function") +@pytest.fixture(scope='function') def server(server_port, request): return partial(_server, request.param, server_port) -@pytest.fixture(scope="function") +@pytest.fixture(scope='function') def server_tls(server_port, request): return partial(_server, request.param, server_port, tls=True) diff --git a/tests/test_asgi.py b/tests/test_asgi.py index 4b492719..a14cfde9 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -3,92 +3,59 @@ @pytest.mark.asyncio -@pytest.mark.parametrize( - "threading_mode", - [ - "runtime", - "workers" - ] -) +@pytest.mark.parametrize('threading_mode', ['runtime', 'workers']) async def test_scope(asgi_server, threading_mode): async with asgi_server(threading_mode) as port: - res = httpx.get(f"http://localhost:{port}/info?test=true") + res = httpx.get(f'http://localhost:{port}/info?test=true') assert res.status_code == 200 - assert res.headers["content-type"] == "application/json" + assert res.headers['content-type'] == 'application/json' data = res.json() - assert data['asgi'] == { - 'version': '3.0', - 'spec_version': '2.3' - } - assert data['type'] == "http" + assert data['asgi'] == {'version': '3.0', 'spec_version': '2.3'} + assert data['type'] == 'http' assert data['http_version'] == '1.1' assert data['scheme'] == 'http' - assert data['method'] == "GET" + assert data['method'] == 'GET' assert data['path'] == '/info' assert data['query_string'] == 'test=true' assert data['headers']['host'] == f'localhost:{port}' @pytest.mark.asyncio -@pytest.mark.parametrize( - "threading_mode", - [ - "runtime", - "workers" - ] -) +@pytest.mark.parametrize('threading_mode', ['runtime', 'workers']) async def test_body(asgi_server, threading_mode): async with asgi_server(threading_mode) as port: - res = httpx.post(f"http://localhost:{port}/echo", content="test") + res = httpx.post(f'http://localhost:{port}/echo', content='test') assert res.status_code == 200 - assert res.text == "test" + assert res.text == 'test' @pytest.mark.asyncio -@pytest.mark.parametrize( - "threading_mode", - [ - "runtime", - "workers" - ] -) +@pytest.mark.parametrize('threading_mode', ['runtime', 'workers']) async def test_body_large(asgi_server, threading_mode): - data = "".join([f"{idx}test".zfill(8) for idx in range(0, 5000)]) + data = ''.join([f'{idx}test'.zfill(8) for idx in range(0, 5000)]) async with asgi_server(threading_mode) as port: - res = httpx.post(f"http://localhost:{port}/echo", content=data) + res = httpx.post(f'http://localhost:{port}/echo', content=data) assert res.status_code == 200 assert res.text == data @pytest.mark.asyncio -@pytest.mark.parametrize( - "threading_mode", - [ - "runtime", - "workers" - ] -) +@pytest.mark.parametrize('threading_mode', ['runtime', 'workers']) async def test_app_error(asgi_server, threading_mode): async with asgi_server(threading_mode) as port: - res = httpx.get(f"http://localhost:{port}/err_app") + res = httpx.get(f'http://localhost:{port}/err_app') assert res.status_code == 500 @pytest.mark.asyncio -@pytest.mark.parametrize( - "threading_mode", - [ - "runtime", - "workers" - ] -) +@pytest.mark.parametrize('threading_mode', ['runtime', 'workers']) async def test_protocol_error(asgi_server, threading_mode): async with asgi_server(threading_mode) as port: - res = httpx.get(f"http://localhost:{port}/err_proto") + res = httpx.get(f'http://localhost:{port}/err_proto') assert res.status_code == 500 diff --git a/tests/test_https.py b/tests/test_https.py index 69be4761..aca05a9f 100644 --- a/tests/test_https.py +++ b/tests/test_https.py @@ -1,20 +1,18 @@ -import httpx import json import pathlib -import pytest import ssl + +import httpx +import pytest import websockets @pytest.mark.asyncio -@pytest.mark.parametrize("server_tls", ["asgi", "rsgi"], indirect=True) -@pytest.mark.parametrize("threading_mode", ["runtime", "workers"]) +@pytest.mark.parametrize('server_tls', ['asgi', 'rsgi'], indirect=True) +@pytest.mark.parametrize('threading_mode', ['runtime', 'workers']) async def test_http_scope(server_tls, threading_mode): async with server_tls(threading_mode) as port: - res = httpx.get( - f"https://localhost:{port}/info?test=true", - verify=False - ) + res = httpx.get(f'https://localhost:{port}/info?test=true', verify=False) assert res.status_code == 200 data = res.json() @@ -22,17 +20,14 @@ async def test_http_scope(server_tls, threading_mode): @pytest.mark.asyncio -@pytest.mark.parametrize("threading_mode", ["runtime", "workers"]) +@pytest.mark.parametrize('threading_mode', ['runtime', 'workers']) async def test_asgi_ws_scope(asgi_server, threading_mode): ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - localhost_pem = pathlib.Path.cwd() / "tests" / "fixtures" / "tls" / "cert.pem" + localhost_pem = pathlib.Path.cwd() / 'tests' / 'fixtures' / 'tls' / 'cert.pem' ssl_context.load_verify_locations(localhost_pem) async with asgi_server(threading_mode, tls=True) as port: - async with websockets.connect( - f"wss://localhost:{port}/ws_info?test=true", - ssl=ssl_context - ) as ws: + async with websockets.connect(f'wss://localhost:{port}/ws_info?test=true', ssl=ssl_context) as ws: res = await ws.recv() data = json.loads(res) @@ -40,17 +35,14 @@ async def test_asgi_ws_scope(asgi_server, threading_mode): @pytest.mark.asyncio -@pytest.mark.parametrize("threading_mode", ["runtime", "workers"]) +@pytest.mark.parametrize('threading_mode', ['runtime', 'workers']) async def test_rsgi_ws_scope(rsgi_server, threading_mode): ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - localhost_pem = pathlib.Path.cwd() / "tests" / "fixtures" / "tls" / "cert.pem" + localhost_pem = pathlib.Path.cwd() / 'tests' / 'fixtures' / 'tls' / 'cert.pem' ssl_context.load_verify_locations(localhost_pem) async with rsgi_server(threading_mode, tls=True) as port: - async with websockets.connect( - f"wss://localhost:{port}/ws_info?test=true", - ssl=ssl_context - ) as ws: + async with websockets.connect(f'wss://localhost:{port}/ws_info?test=true', ssl=ssl_context) as ws: res = await ws.recv() data = json.loads(res) diff --git a/tests/test_rsgi.py b/tests/test_rsgi.py index f356cafa..ddb3bee2 100644 --- a/tests/test_rsgi.py +++ b/tests/test_rsgi.py @@ -3,90 +3,60 @@ @pytest.mark.asyncio -@pytest.mark.parametrize( - "threading_mode", - [ - "runtime", - "workers" - ] -) +@pytest.mark.parametrize('threading_mode', ['runtime', 'workers']) async def test_scope(rsgi_server, threading_mode): async with rsgi_server(threading_mode) as port: - res = httpx.get(f"http://localhost:{port}/info?test=true") + res = httpx.get(f'http://localhost:{port}/info?test=true') assert res.status_code == 200 - assert res.headers["content-type"] == "application/json" + assert res.headers['content-type'] == 'application/json' data = res.json() - assert data['proto'] == "http" + assert data['proto'] == 'http' assert data['http_version'] == '1.1' assert data['rsgi_version'] == '1.2' assert data['scheme'] == 'http' - assert data['method'] == "GET" + assert data['method'] == 'GET' assert data['path'] == '/info' assert data['query_string'] == 'test=true' assert data['headers']['host'] == f'localhost:{port}' @pytest.mark.asyncio -@pytest.mark.parametrize( - "threading_mode", - [ - "runtime", - "workers" - ] -) +@pytest.mark.parametrize('threading_mode', ['runtime', 'workers']) async def test_body(rsgi_server, threading_mode): async with rsgi_server(threading_mode) as port: - res = httpx.post(f"http://localhost:{port}/echo", content="test") + res = httpx.post(f'http://localhost:{port}/echo', content='test') assert res.status_code == 200 - assert res.text == "test" + assert res.text == 'test' @pytest.mark.asyncio -@pytest.mark.parametrize( - "threading_mode", - [ - "runtime", - "workers" - ] -) +@pytest.mark.parametrize('threading_mode', ['runtime', 'workers']) async def test_body_stream_req(rsgi_server, threading_mode): - data = "".join([f"{idx}test".zfill(8) for idx in range(0, 5000)]) + data = ''.join([f'{idx}test'.zfill(8) for idx in range(0, 5000)]) async with rsgi_server(threading_mode) as port: - res = httpx.post(f"http://localhost:{port}/echos", content=data) + res = httpx.post(f'http://localhost:{port}/echos', content=data) assert res.status_code == 200 assert res.text == data @pytest.mark.asyncio -@pytest.mark.parametrize( - "threading_mode", - [ - "runtime", - "workers" - ] -) +@pytest.mark.parametrize('threading_mode', ['runtime', 'workers']) async def test_body_stream_res(rsgi_server, threading_mode): async with rsgi_server(threading_mode) as port: - res = httpx.get(f"http://localhost:{port}/stream") + res = httpx.get(f'http://localhost:{port}/stream') assert res.status_code == 200 - assert res.text == "test" * 3 + assert res.text == 'test' * 3 @pytest.mark.asyncio -@pytest.mark.parametrize( - "threading_mode", - [ - "runtime", - "workers" - ] -) +@pytest.mark.parametrize('threading_mode', ['runtime', 'workers']) async def test_app_error(rsgi_server, threading_mode): async with rsgi_server(threading_mode) as port: - res = httpx.get(f"http://localhost:{port}/err_app") + res = httpx.get(f'http://localhost:{port}/err_app') assert res.status_code == 500 diff --git a/tests/test_ws.py b/tests/test_ws.py index 698d1be8..49e715a4 100644 --- a/tests/test_ws.py +++ b/tests/test_ws.py @@ -1,56 +1,48 @@ import json -import pytest import sys + +import pytest import websockets -@pytest.mark.skipif(sys.platform == "win32", reason="skip on windows") +@pytest.mark.skipif(sys.platform == 'win32', reason='skip on windows') @pytest.mark.asyncio -@pytest.mark.parametrize("server", ["asgi", "rsgi"], indirect=True) -@pytest.mark.parametrize("threading_mode", ["runtime", "workers"]) +@pytest.mark.parametrize('server', ['asgi', 'rsgi'], indirect=True) +@pytest.mark.parametrize('threading_mode', ['runtime', 'workers']) async def test_messages(server, threading_mode): async with server(threading_mode) as port: - async with websockets.connect(f"ws://localhost:{port}/ws_echo") as ws: - await ws.send("foo") + async with websockets.connect(f'ws://localhost:{port}/ws_echo') as ws: + await ws.send('foo') res_text = await ws.recv() - await ws.send(b"foo") + await ws.send(b'foo') res_bytes = await ws.recv() - assert res_text == "foo" - assert res_bytes == b"foo" + assert res_text == 'foo' + assert res_bytes == b'foo' @pytest.mark.asyncio -@pytest.mark.parametrize("server", ["asgi", "rsgi"], indirect=True) -@pytest.mark.parametrize("threading_mode", ["runtime", "workers"]) +@pytest.mark.parametrize('server', ['asgi', 'rsgi'], indirect=True) +@pytest.mark.parametrize('threading_mode', ['runtime', 'workers']) async def test_reject(server, threading_mode): async with server(threading_mode) as port: with pytest.raises(websockets.InvalidStatusCode) as exc: - async with websockets.connect(f"ws://localhost:{port}/ws_reject") as ws: + async with websockets.connect(f'ws://localhost:{port}/ws_reject'): pass assert exc.value.status_code == 403 @pytest.mark.asyncio -@pytest.mark.parametrize( - "threading_mode", - [ - "runtime", - "workers" - ] -) +@pytest.mark.parametrize('threading_mode', ['runtime', 'workers']) async def test_asgi_scope(asgi_server, threading_mode): async with asgi_server(threading_mode) as port: - async with websockets.connect(f"ws://localhost:{port}/ws_info?test=true") as ws: + async with websockets.connect(f'ws://localhost:{port}/ws_info?test=true') as ws: res = await ws.recv() data = json.loads(res) - assert data['asgi'] == { - 'version': '3.0', - 'spec_version': '2.3' - } - assert data['type'] == "websocket" + assert data['asgi'] == {'version': '3.0', 'spec_version': '2.3'} + assert data['type'] == 'websocket' assert data['http_version'] == '1.1' assert data['scheme'] == 'ws' assert data['path'] == '/ws_info' @@ -59,16 +51,10 @@ async def test_asgi_scope(asgi_server, threading_mode): @pytest.mark.asyncio -@pytest.mark.parametrize( - "threading_mode", - [ - "runtime", - "workers" - ] -) +@pytest.mark.parametrize('threading_mode', ['runtime', 'workers']) async def test_rsgi_scope(rsgi_server, threading_mode): async with rsgi_server(threading_mode) as port: - async with websockets.connect(f"ws://localhost:{port}/ws_info?test=true") as ws: + async with websockets.connect(f'ws://localhost:{port}/ws_info?test=true') as ws: res = await ws.recv() data = json.loads(res) @@ -76,7 +62,7 @@ async def test_rsgi_scope(rsgi_server, threading_mode): assert data['http_version'] == '1.1' assert data['rsgi_version'] == '1.2' assert data['scheme'] == 'http' - assert data['method'] == "GET" + assert data['method'] == 'GET' assert data['path'] == '/ws_info' assert data['query_string'] == 'test=true' assert data['headers']['host'] == f'localhost:{port}' diff --git a/tests/test_wsgi.py b/tests/test_wsgi.py index a0122044..ab0ad999 100644 --- a/tests/test_wsgi.py +++ b/tests/test_wsgi.py @@ -3,24 +3,18 @@ @pytest.mark.asyncio -@pytest.mark.parametrize( - "threading_mode", - [ - "runtime", - "workers" - ] -) +@pytest.mark.parametrize('threading_mode', ['runtime', 'workers']) async def test_scope(wsgi_server, threading_mode): - payload = "body_payload" + payload = 'body_payload' async with wsgi_server(threading_mode) as port: - res = httpx.post(f"http://localhost:{port}/info?test=true", content=payload) + res = httpx.post(f'http://localhost:{port}/info?test=true', content=payload) assert res.status_code == 200 - assert res.headers["content-type"] == "application/json" + assert res.headers['content-type'] == 'application/json' data = res.json() assert data['scheme'] == 'http' - assert data['method'] == "POST" + assert data['method'] == 'POST' assert data['path'] == '/info' assert data['query_string'] == 'test=true' assert data['headers']['HTTP_HOST'] == f'localhost:{port}' @@ -28,47 +22,29 @@ async def test_scope(wsgi_server, threading_mode): @pytest.mark.asyncio -@pytest.mark.parametrize( - "threading_mode", - [ - "runtime", - "workers" - ] -) +@pytest.mark.parametrize('threading_mode', ['runtime', 'workers']) async def test_body(wsgi_server, threading_mode): async with wsgi_server(threading_mode) as port: - res = httpx.post(f"http://localhost:{port}/echo", content="test") + res = httpx.post(f'http://localhost:{port}/echo', content='test') assert res.status_code == 200 - assert res.text == "test" + assert res.text == 'test' @pytest.mark.asyncio -@pytest.mark.parametrize( - "threading_mode", - [ - "runtime", - "workers" - ] -) +@pytest.mark.parametrize('threading_mode', ['runtime', 'workers']) async def test_iterbody(wsgi_server, threading_mode): async with wsgi_server(threading_mode) as port: - res = httpx.get(f"http://localhost:{port}/iterbody") + res = httpx.get(f'http://localhost:{port}/iterbody') assert res.status_code == 200 - assert res.text == "test" * 3 + assert res.text == 'test' * 3 @pytest.mark.asyncio -@pytest.mark.parametrize( - "threading_mode", - [ - "runtime", - "workers" - ] -) +@pytest.mark.parametrize('threading_mode', ['runtime', 'workers']) async def test_app_error(wsgi_server, threading_mode): async with wsgi_server(threading_mode) as port: - res = httpx.get(f"http://localhost:{port}/err_app") + res = httpx.get(f'http://localhost:{port}/err_app') assert res.status_code == 500