Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints to sygnal/http.py #273

Merged
merged 5 commits into from
Nov 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/273.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to sygnal/http.py.
Empty file added stubs/twisted/__init__.pyi
Empty file.
Empty file.
66 changes: 66 additions & 0 deletions stubs/twisted/web/http.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import typing
from typing import AnyStr, Dict, List, Optional

from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IAddress, ITCPTransport
from twisted.logger import Logger
from twisted.web.http_headers import Headers
from twisted.web.iweb import IRequest, IAccessLogFormatter
from zope.interface import implementer, provider

class HTTPChannel: ...

# Type ignore: I don't want to respecify the methods on the interface that we
# don't use.
@implementer(IRequest) # type: ignore[misc]
class Request:
code = 200
# Instance attributes mentioned in the docstring
method: bytes
uri: bytes
path: bytes
args: Dict[bytes, List[bytes]]
content: typing.BinaryIO
cookies: List[bytes]
requestHeaders: Headers
responseHeaders: Headers
notifications: List[Deferred[None]]
_disconnected: bool
_log: Logger

# Other instance attributes set in __init__
channel: HTTPChannel
client: IAddress
# This was hard to derive.
# - `transport` is `self.channel.transport`
# - `self.channel` is set in the constructor, and looks like it's always
# an `HTTPChannel`.
# - `HTTPChannel` is a `LineReceiver` is a `Protocol` is a `BaseProtocol`.
# - `BaseProtocol` sets `self.transport` to initially `None`.
#
# Note that `transport` is set to an ITransport in makeConnection,
# so is almost certainly not None by the time it reaches our code.
#
# I've narrowed this to ITCPTransport because
# - we use `self.transport.abortConnection`, which belongs to that interface
# - twisted does too! in its implementation of HTTPChannel.forceAbortClient
transport: Optional[ITCPTransport]
def __init__(self, channel: HTTPChannel): ...
def getHeader(self, key: AnyStr) -> Optional[AnyStr]: ...
def handleContentChunk(self, data: bytes) -> None: ...
def setResponseCode(self, code: int, message: Optional[bytes] = ...) -> None: ...
def setHeader(self, k: AnyStr, v: AnyStr) -> None: ...
def write(self, data: bytes) -> None: ...
def finish(self) -> None: ...
def getClientAddress(self) -> IAddress: ...

@provider(IAccessLogFormatter)
def proxiedLogFormatter(timestamp: str, request: Request) -> str:
...

@provider(IAccessLogFormatter)
def combinedLogFormatter(timestamp: str, request: Request) -> str:
...

def datetimeToLogString(msSinceEpoch: Optional[int]=None) -> str:
...
55 changes: 35 additions & 20 deletions sygnal/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@
import sys
import time
import traceback
from typing import TYPE_CHECKING, Callable, List, Union
from uuid import uuid4

from opentracing import Format, logs, tags
from opentracing import Format, Span, logs, tags
from prometheus_client import Counter, Gauge, Histogram
from twisted.internet.defer import ensureDeferred
from twisted.web import server
from twisted.web.http import (
Request,
combinedLogFormatter,
datetimeToLogString,
proxiedLogFormatter,
Expand All @@ -38,9 +40,12 @@
InvalidNotificationException,
NotificationDispatchException,
)
from sygnal.notifications import Notification, NotificationContext
from sygnal.notifications import Notification, NotificationContext, Pushkin
from sygnal.utils import NotificationLoggerAdapter, json_decoder

if TYPE_CHECKING:
from sygnal.sygnal import Sygnal

logger = logging.getLogger(__name__)

NOTIFS_RECEIVED_COUNTER = Counter(
Expand Down Expand Up @@ -77,31 +82,31 @@


class V1NotifyHandler(Resource):
def __init__(self, sygnal):
def __init__(self, sygnal: "Sygnal"):
super().__init__()
self.sygnal = sygnal

isLeaf = True

def _make_request_id(self):
def _make_request_id(self) -> str:
"""
Generates a request ID, intended to be unique, for a request so it can
be followed through logging.
Returns: a request ID for the request.
"""
return str(uuid4())

def render_POST(self, request):
def render_POST(self, request: Request) -> Union[int, bytes]:
response = self._handle_request(request)
if response != NOT_DONE_YET:
PUSHGATEWAY_HTTP_RESPONSES_COUNTER.labels(code=request.code).inc()
return response

def _handle_request(self, request):
def _handle_request(self, request: Request) -> Union[int, bytes]:
"""
Actually handle the request.
Args:
request (Request): The request, corresponding to a POST request.
request: The request, corresponding to a POST request.

Returns:
Either a str instance or NOT_DONE_YET.
Expand Down Expand Up @@ -203,12 +208,12 @@ async def cb():
if not root_span_accounted_for:
root_span.finish()

def find_pushkins(self, appid):
def find_pushkins(self, appid: str) -> List[Pushkin]:
"""Finds matching pushkins in self.sygnal.pushkins according to the appid.


Args:
appid (str): app identifier to search in self.sygnal.pushkins.
appid: app identifier to search in self.sygnal.pushkins.

Returns:
list of `Pushkin`: If it finds a specific pushkin with
Expand All @@ -227,16 +232,23 @@ def find_pushkins(self, appid):
result.append(value)
return result

async def _handle_dispatch(self, root_span, request, log, notif, context):
async def _handle_dispatch(
self,
root_span: Span,
request: Request,
log: NotificationLoggerAdapter,
notif: Notification,
context: NotificationContext,
) -> None:
"""
Actually handle the dispatch of notifications to devices, sequentially
for simplicity.

root_span: the OpenTracing span
request: the Twisted Web Request
log: the logger to use
notif (Notification): the notification to dispatch
context (NotificationContext): the context of the notification
notif: the notification to dispatch
context: the context of the notification
"""
try:
rejected = []
Expand All @@ -252,7 +264,7 @@ async def _handle_dispatch(self, root_span, request, log, notif, context):
continue

if len(found_pushkins) > 1:
log.warning("Got notification for an ambigious app ID %s", appid)
log.warning("Got notification for an ambiguous app ID %s", appid)
rejected.append(d.pushkey)
continue

Expand Down Expand Up @@ -299,7 +311,7 @@ async def _handle_dispatch(self, root_span, request, log, notif, context):


class HealthHandler(Resource):
def render_GET(self, request):
def render_GET(self, request: Request) -> bytes:
"""
`/health` is used for automatic checking of whether the service is up.
It should just return a blank 200 OK response.
Expand All @@ -311,14 +323,15 @@ class SizeLimitingRequest(server.Request):
# Arbitrarily limited to 512 KiB.
MAX_REQUEST_SIZE = 512 * 1024

def handleContentChunk(self, data):
def handleContentChunk(self, data: bytes) -> None:
# we should have a content by now
assert self.content, "handleContentChunk() called before gotLength()"
if self.content.tell() + len(data) > self.MAX_REQUEST_SIZE:
logger.info(
"Aborting connection from %s because the request exceeds maximum size",
self.client.host,
self.client,
)
assert self.transport is not None
self.transport.abortConnection()
return

Expand All @@ -331,12 +344,14 @@ class SygnalLoggedSite(server.Site):
Sygnal.
"""

def __init__(self, *args, log_formatter, **kwargs):
def __init__(
self, *args, log_formatter: Callable[[str, server.Request], str], **kwargs
):
super().__init__(*args, **kwargs)
self.log_formatter = log_formatter
self.logger = logging.getLogger("sygnal.access")

def log(self, request):
def log(self, request: server.Request) -> None:
"""Log this request. Called by request.finish."""
# this also works around a bug in twisted.web.http.HTTPFactory which uses a
# monotonic time as an epoch time.
Expand All @@ -345,8 +360,8 @@ def log(self, request):
self.logger.info("Handled request: %s", line)


class PushGatewayApiServer:
def __init__(self, sygnal):
class PushGatewayApiServer(object):
def __init__(self, sygnal: "Sygnal"):
"""
Initialises the /_matrix/push/* (Push Gateway API) server.
Args:
Expand Down