Skip to content

Commit

Permalink
Merge pull request #275 from matrix-org/type_hints_sygnal.py
Browse files Browse the repository at this point in the history
Add type hints to `sygnal.py`
  • Loading branch information
H-Shay authored Nov 22, 2021
2 parents d9fb1a3 + 1e2fe1f commit c5bb2dc
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 35 deletions.
1 change: 1 addition & 0 deletions changelog.d/275.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to sygnal/sygnal.py.
70 changes: 50 additions & 20 deletions sygnal/sygnal.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,31 @@
import logging.config
import os
import sys
from typing import Any, Dict, Set, cast

import opentracing
import prometheus_client
import yaml
from opentracing import Tracer
from opentracing.scope_managers.asyncio import AsyncioScopeManager
from twisted.internet import asyncioreactor, defer
from twisted.internet.defer import ensureDeferred
from twisted.internet.interfaces import (
IReactorCore,
IReactorFDSet,
IReactorPluggableNameResolver,
IReactorTCP,
)
from twisted.python import log as twisted_log
from twisted.python.failure import Failure
from zope.interface import Interface

from sygnal.http import PushGatewayApiServer
from sygnal.notifications import Pushkin

logger = logging.getLogger(__name__)

CONFIG_DEFAULTS: dict = {
CONFIG_DEFAULTS: Dict[str, Any] = {
"http": {"port": 5000, "bind_addresses": ["127.0.0.1"]},
"log": {"setup": {}, "access": {"x_forwarded_for": False}},
"metrics": {
Expand All @@ -52,18 +62,33 @@
}


class SygnalReactor(
IReactorFDSet,
IReactorPluggableNameResolver,
IReactorTCP,
IReactorCore,
Interface,
):
pass


class Sygnal:
def __init__(self, config, custom_reactor, tracer=opentracing.tracer):
def __init__(
self,
config: Dict[str, Any],
custom_reactor: SygnalReactor,
tracer: Tracer = opentracing.tracer,
):
"""
Object that holds state for the entirety of a Sygnal instance.
Args:
config (dict): Configuration for this Sygnal
config: Configuration for this Sygnal
custom_reactor: a Twisted Reactor to use.
tracer (optional): an OpenTracing tracer. The default is the no-op tracer.
"""
self.config = config
self.reactor = custom_reactor
self.pushkins = {}
self.pushkins: Dict[str, Pushkin] = {}
self.tracer = tracer

logging_dict_config = config["log"]["setup"]
Expand Down Expand Up @@ -140,14 +165,14 @@ def __init__(self, config, custom_reactor, tracer=opentracing.tracer):
"Unknown OpenTracing implementation: %s.", tracecfg["impl"]
)

async def _make_pushkin(self, app_name, app_config):
async def _make_pushkin(self, app_name: str, app_config: Dict[str, Any]) -> Pushkin:
"""
Load and instantiate a pushkin.
Args:
app_name (str): The pushkin's app_id
app_config (dict): The pushkin's configuration
app_name: The pushkin's app_id
app_config: The pushkin's configuration
Returns (Pushkin):
Returns:
A pushkin of the desired type.
"""
app_type = app_config["type"]
Expand All @@ -165,7 +190,7 @@ async def _make_pushkin(self, app_name, app_config):
clarse = getattr(pushkin_module, to_construct)
return await clarse.create(app_name, self, app_config)

async def make_pushkins_then_start(self):
async def make_pushkins_then_start(self) -> None:
for app_id, app_cfg in self.config["apps"].items():
try:
self.pushkins[app_id] = await self._make_pushkin(app_id, app_cfg)
Expand All @@ -186,9 +211,9 @@ async def make_pushkins_then_start(self):
port = int(self.config["http"]["port"])
for interface in self.config["http"]["bind_addresses"]:
logger.info("Starting listening on %s port %d", interface, port)
self.reactor.listenTCP(port, pushgateway_api.site, interface=interface)
self.reactor.listenTCP(port, pushgateway_api.site, 50, interface=interface)

def run(self):
def run(self) -> None:
"""
Attempt to run Sygnal and then exit the application.
"""
Expand All @@ -211,10 +236,10 @@ def start():
self.reactor.run()


def parse_config():
def parse_config() -> Dict[str, Any]:
"""
Find and load Sygnal's configuration file.
Returns (dict):
Returns:
A loaded configuration.
"""
config_path = os.getenv("SYGNAL_CONF", "sygnal.yaml")
Expand All @@ -231,15 +256,17 @@ def parse_config():
raise


def check_config(config):
def check_config(config: Dict[str, Any]) -> None:
"""
Lightly check the configuration and issue warnings as appropriate.
Args:
config: The loaded configuration.
"""
UNDERSTOOD_CONFIG_FIELDS = CONFIG_DEFAULTS.keys()

def check_section(section_name, known_keys, cfgpart=config):
def check_section(
section_name: str, known_keys: Set[str], cfgpart: Dict[str, Any] = config
) -> None:
nonunderstood = set(cfgpart[section_name].keys()).difference(known_keys)
if len(nonunderstood) > 0:
logger.warning(
Expand Down Expand Up @@ -271,14 +298,16 @@ def check_section(section_name, known_keys, cfgpart=config):
check_section("sentry", {"enabled", "dsn"}, cfgpart=config["metrics"])


def merge_left_with_defaults(defaults, loaded_config):
def merge_left_with_defaults(
defaults: Dict[str, Any], loaded_config: Dict[str, Any]
) -> Dict[str, Any]:
"""
Merge two configurations, with one of them overriding the other.
Args:
defaults (dict): A configuration of defaults
loaded_config (dict): A configuration, as loaded from disk.
defaults: A configuration of defaults
loaded_config: A configuration, as loaded from disk.
Returns (dict):
Returns:
A merged configuration, with loaded_config preferred over defaults.
"""
result = defaults.copy()
Expand Down Expand Up @@ -320,5 +349,6 @@ def merge_left_with_defaults(defaults, loaded_config):
config = parse_config()
config = merge_left_with_defaults(CONFIG_DEFAULTS, config)
check_config(config)
sygnal = Sygnal(config, custom_reactor=asyncioreactor.AsyncioSelectorReactor())
custom_reactor = cast(SygnalReactor, asyncioreactor.AsyncioSelectorReactor())
sygnal = Sygnal(config, custom_reactor)
sygnal.run()
17 changes: 14 additions & 3 deletions tests/test_apns.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from aioapns.common import NotificationResult

from sygnal import apnstruncate
from sygnal.apnspushkin import ApnsPushkin

from tests import testutils

Expand Down Expand Up @@ -55,7 +56,15 @@ def setUp(self):
super().setUp()

self.apns_pushkin_snotif = MagicMock()
self.sygnal.pushkins[PUSHKIN_ID]._send_notification = self.apns_pushkin_snotif
test_pushkin = self.get_test_pushkin(PUSHKIN_ID)
# type safety: using ignore here due to mypy not handling monkeypatching,
# see https://github.com/python/mypy/issues/2427
test_pushkin._send_notification = self.apns_pushkin_snotif # type: ignore[assignment] # noqa: E501

def get_test_pushkin(self, name: str) -> ApnsPushkin:
test_pushkin = self.sygnal.pushkins[name]
assert isinstance(test_pushkin, ApnsPushkin)
return test_pushkin

def config_setup(self, config):
super().config_setup(config)
Expand All @@ -71,7 +80,8 @@ def test_payload_truncation(self):
method.side_effect = testutils.make_async_magic_mock(
NotificationResult("notID", "200")
)
self.sygnal.pushkins[PUSHKIN_ID].MAX_JSON_BODY_SIZE = 240
test_pushkin = self.get_test_pushkin(PUSHKIN_ID)
test_pushkin.MAX_JSON_BODY_SIZE = 240

# Act
self._request(self._make_dummy_notification([DEVICE_EXAMPLE]))
Expand All @@ -94,7 +104,8 @@ def test_payload_truncation_test_validity(self):
method.side_effect = testutils.make_async_magic_mock(
NotificationResult("notID", "200")
)
self.sygnal.pushkins[PUSHKIN_ID].MAX_JSON_BODY_SIZE = 4096
test_pushkin = self.get_test_pushkin(PUSHKIN_ID)
test_pushkin.MAX_JSON_BODY_SIZE = 4096

# Act
self._request(self._make_dummy_notification([DEVICE_EXAMPLE]))
Expand Down
20 changes: 14 additions & 6 deletions tests/test_gcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,17 @@ def config_setup(self, config):
"fcm_options": {"content_available": True, "mutable_content": True},
}

def get_test_pushkin(self, name: str) -> TestGcmPushkin:
pushkin = self.sygnal.pushkins[name]
assert isinstance(pushkin, TestGcmPushkin)
return pushkin

def test_expected(self):
"""
Tests the expected case: a good response from GCM leads to a good
response from Sygnal.
"""
gcm = self.sygnal.pushkins["com.example.gcm"]
gcm = self.get_test_pushkin("com.example.gcm")
gcm.preload_with_response(
200, {"results": [{"message_id": "msg42", "registration_id": "spqr"}]}
)
Expand All @@ -101,7 +106,7 @@ def test_expected_with_default_payload(self):
Tests the expected case: a good response from GCM leads to a good
response from Sygnal.
"""
gcm = self.sygnal.pushkins["com.example.gcm"]
gcm = self.get_test_pushkin("com.example.gcm")
gcm.preload_with_response(
200, {"results": [{"message_id": "msg42", "registration_id": "spqr"}]}
)
Expand All @@ -118,7 +123,7 @@ def test_rejected(self):
Tests the rejected case: a pushkey rejected to GCM leads to Sygnal
informing the homeserver of the rejection.
"""
gcm = self.sygnal.pushkins["com.example.gcm"]
gcm = self.get_test_pushkin("com.example.gcm")
gcm.preload_with_response(
200, {"results": [{"registration_id": "spqr", "error": "NotRegistered"}]}
)
Expand All @@ -133,7 +138,7 @@ def test_batching(self):
Tests that multiple GCM devices have their notification delivered to GCM
together, instead of being delivered separately.
"""
gcm = self.sygnal.pushkins["com.example.gcm"]
gcm = self.get_test_pushkin("com.example.gcm")
gcm.preload_with_response(
200,
{
Expand All @@ -149,6 +154,7 @@ def test_batching(self):
)

self.assertEqual(resp, {"rejected": []})
assert gcm.last_request_body is not None
self.assertEqual(gcm.last_request_body["registration_ids"], ["spqr", "spqr2"])
self.assertEqual(gcm.num_requests, 1)

Expand All @@ -159,7 +165,7 @@ def test_batching_individual_failure(self):
and that if only one device ID is rejected, then only that device is
reported to the homeserver as rejected.
"""
gcm = self.sygnal.pushkins["com.example.gcm"]
gcm = self.get_test_pushkin("com.example.gcm")
gcm.preload_with_response(
200,
{
Expand All @@ -175,6 +181,7 @@ def test_batching_individual_failure(self):
)

self.assertEqual(resp, {"rejected": ["spqr2"]})
assert gcm.last_request_body is not None
self.assertEqual(gcm.last_request_body["registration_ids"], ["spqr", "spqr2"])
self.assertEqual(gcm.num_requests, 1)

Expand All @@ -183,13 +190,14 @@ def test_fcm_options(self):
Tests that the config option `fcm_options` allows setting a base layer
of options to pass to FCM, for example ones that would be needed for iOS.
"""
gcm = self.sygnal.pushkins["com.example.gcm.ios"]
gcm = self.get_test_pushkin("com.example.gcm.ios")
gcm.preload_with_response(
200, {"results": [{"registration_id": "spqr_new", "message_id": "msg42"}]}
)

resp = self._request(self._make_dummy_notification([DEVICE_EXAMPLE_IOS]))

self.assertEqual(resp, {"rejected": []})
assert gcm.last_request_body is not None
self.assertEqual(gcm.last_request_body["mutable_content"], True)
self.assertEqual(gcm.last_request_body["content_available"], True)
7 changes: 6 additions & 1 deletion tests/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

from aioapns.common import NotificationResult

from sygnal.apnspushkin import ApnsPushkin

from tests import testutils

PUSHKIN_ID_1 = "com.example.apns"
Expand Down Expand Up @@ -62,7 +64,10 @@ def setUp(self):

self.apns_pushkin_snotif = MagicMock()
for key, value in self.sygnal.pushkins.items():
value._send_notification = self.apns_pushkin_snotif
assert isinstance(value, ApnsPushkin)
# type safety: ignore is used here due to mypy not handling monkeypatching,
# see https://github.com/python/mypy/issues/2427
value._send_notification = self.apns_pushkin_snotif # type: ignore[assignment] # noqa: E501

def config_setup(self, config):
super().config_setup(config)
Expand Down
11 changes: 6 additions & 5 deletions tests/testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,15 @@ def setUp(self):

config = merge_left_with_defaults(CONFIG_DEFAULTS, config)

self.sygnal = Sygnal(config, reactor)
self.sygnal = Sygnal(config, reactor) # type: ignore[arg-type]
self.reactor = reactor

start_deferred = ensureDeferred(self.sygnal.make_pushkins_then_start())

while not start_deferred.called:
# we need to advance until the pushkins have started up
self.sygnal.reactor.advance(1)
self.sygnal.reactor.wait_for_work(lambda: start_deferred.called)
self.reactor.advance(1)
self.reactor.wait_for_work(lambda: start_deferred.called)

# sygnal should have started a single (fake) tcp listener
listeners = self.reactor.tcpServers
Expand Down Expand Up @@ -154,8 +154,8 @@ def _request(self, payload: Union[str, dict]) -> Union[dict, int]:

while not channel.done:
# we need to advance until the request has been finished
self.sygnal.reactor.advance(1)
self.sygnal.reactor.wait_for_work(lambda: channel.done)
self.reactor.advance(1)
self.reactor.wait_for_work(lambda: channel.done)

assert channel.done
assert channel.result is not None
Expand Down Expand Up @@ -199,6 +199,7 @@ def all_channels_done():

while not all_channels_done():
# we need to advance until the request has been finished
assert isinstance(self.sygnal.reactor, ExtendedMemoryReactorClock)
self.sygnal.reactor.advance(1)
self.sygnal.reactor.wait_for_work(all_channels_done)

Expand Down

0 comments on commit c5bb2dc

Please sign in to comment.