Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints to sygnal.py #275

Merged
merged 14 commits into from
Nov 22, 2021
1 change: 1 addition & 0 deletions changelog.d/275.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to sygnal/sygnal.py.
71 changes: 51 additions & 20 deletions sygnal/sygnal.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,32 @@
import logging.config
import os
import sys
from typing import Any, Dict, Set, Union, cast

import jaeger_client
H-Shay marked this conversation as resolved.
Show resolved Hide resolved
import opentracing
import prometheus_client
import yaml
from opentracing import Tracer
from opentracing.scope_managers.asyncio import AsyncioScopeManager
from twisted.internet import asyncioreactor, defer
from twisted.internet.defer import ensureDeferred
from twisted.internet.interfaces import (
IReactorCore,
IReactorFDSet,
IReactorPluggableNameResolver,
IReactorTCP,
)
from twisted.python import log as twisted_log
from twisted.python.failure import Failure
from zope.interface import Interface

from sygnal.http import PushGatewayApiServer
from sygnal.notifications import Pushkin

logger = logging.getLogger(__name__)

CONFIG_DEFAULTS: dict = {
CONFIG_DEFAULTS: Dict[str, Any] = {
"http": {"port": 5000, "bind_addresses": ["127.0.0.1"]},
"log": {"setup": {}, "access": {"x_forwarded_for": False}},
"metrics": {
Expand All @@ -52,18 +63,33 @@
}


class SygnalReactor(
IReactorFDSet,
IReactorPluggableNameResolver,
IReactorTCP,
IReactorCore,
Interface,
):
pass


class Sygnal:
def __init__(self, config, custom_reactor, tracer=opentracing.tracer):
def __init__(
self,
config: Dict[str, Any],
custom_reactor: SygnalReactor,
tracer: Union[Tracer, jaeger_client.Tracer] = opentracing.tracer,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you should be able to just use Tracer, because jaeger_client.Tracer inherits from opentracing.Tracer.

):
"""
Object that holds state for the entirety of a Sygnal instance.
Args:
config (dict): Configuration for this Sygnal
config: Configuration for this Sygnal
custom_reactor: a Twisted Reactor to use.
tracer (optional): an OpenTracing tracer. The default is the no-op tracer.
"""
self.config = config
self.reactor = custom_reactor
self.pushkins = {}
self.pushkins: Dict[str, Pushkin] = {}
self.tracer = tracer

logging_dict_config = config["log"]["setup"]
Expand Down Expand Up @@ -140,14 +166,14 @@ def __init__(self, config, custom_reactor, tracer=opentracing.tracer):
"Unknown OpenTracing implementation: %s.", tracecfg["impl"]
)

async def _make_pushkin(self, app_name, app_config):
async def _make_pushkin(self, app_name: str, app_config: Dict[str, Any]) -> Pushkin:
"""
Load and instantiate a pushkin.
Args:
app_name (str): The pushkin's app_id
app_config (dict): The pushkin's configuration
app_name: The pushkin's app_id
app_config: The pushkin's configuration

Returns (Pushkin):
Returns:
A pushkin of the desired type.
"""
app_type = app_config["type"]
Expand All @@ -165,7 +191,7 @@ async def _make_pushkin(self, app_name, app_config):
clarse = getattr(pushkin_module, to_construct)
return await clarse.create(app_name, self, app_config)

async def make_pushkins_then_start(self):
async def make_pushkins_then_start(self) -> None:
for app_id, app_cfg in self.config["apps"].items():
try:
self.pushkins[app_id] = await self._make_pushkin(app_id, app_cfg)
Expand All @@ -186,9 +212,9 @@ async def make_pushkins_then_start(self):
port = int(self.config["http"]["port"])
for interface in self.config["http"]["bind_addresses"]:
logger.info("Starting listening on %s port %d", interface, port)
self.reactor.listenTCP(port, pushgateway_api.site, interface=interface)
self.reactor.listenTCP(port, pushgateway_api.site, 50, interface=interface)
H-Shay marked this conversation as resolved.
Show resolved Hide resolved

def run(self):
def run(self) -> None:
"""
Attempt to run Sygnal and then exit the application.
"""
Expand All @@ -211,10 +237,10 @@ def start():
self.reactor.run()


def parse_config():
def parse_config() -> Dict[str, Any]:
"""
Find and load Sygnal's configuration file.
Returns (dict):
Returns:
A loaded configuration.
"""
config_path = os.getenv("SYGNAL_CONF", "sygnal.yaml")
Expand All @@ -231,15 +257,17 @@ def parse_config():
raise


def check_config(config):
def check_config(config: Dict[str, Any]) -> None:
"""
Lightly check the configuration and issue warnings as appropriate.
Args:
config: The loaded configuration.
"""
UNDERSTOOD_CONFIG_FIELDS = CONFIG_DEFAULTS.keys()

def check_section(section_name, known_keys, cfgpart=config):
def check_section(
section_name: str, known_keys: Set[str], cfgpart: Dict[str, Any] = config
) -> None:
nonunderstood = set(cfgpart[section_name].keys()).difference(known_keys)
if len(nonunderstood) > 0:
logger.warning(
Expand Down Expand Up @@ -271,14 +299,16 @@ def check_section(section_name, known_keys, cfgpart=config):
check_section("sentry", {"enabled", "dsn"}, cfgpart=config["metrics"])


def merge_left_with_defaults(defaults, loaded_config):
def merge_left_with_defaults(
H-Shay marked this conversation as resolved.
Show resolved Hide resolved
defaults: Dict[str, Any], loaded_config: Dict[str, Any]
) -> Dict[str, Any]:
"""
Merge two configurations, with one of them overriding the other.
Args:
defaults (dict): A configuration of defaults
loaded_config (dict): A configuration, as loaded from disk.
defaults: A configuration of defaults
loaded_config: A configuration, as loaded from disk.

Returns (dict):
Returns:
A merged configuration, with loaded_config preferred over defaults.
"""
result = defaults.copy()
Expand Down Expand Up @@ -320,5 +350,6 @@ def merge_left_with_defaults(defaults, loaded_config):
config = parse_config()
config = merge_left_with_defaults(CONFIG_DEFAULTS, config)
check_config(config)
sygnal = Sygnal(config, custom_reactor=asyncioreactor.AsyncioSelectorReactor())
custom_reactor = cast(SygnalReactor, asyncioreactor.AsyncioSelectorReactor())
sygnal = Sygnal(config, custom_reactor)
sygnal.run()
17 changes: 14 additions & 3 deletions tests/test_apns.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from aioapns.common import NotificationResult

from sygnal import apnstruncate
from sygnal.apnspushkin import ApnsPushkin

from tests import testutils

Expand Down Expand Up @@ -55,7 +56,15 @@ def setUp(self):
super().setUp()

self.apns_pushkin_snotif = MagicMock()
self.sygnal.pushkins[PUSHKIN_ID]._send_notification = self.apns_pushkin_snotif
test_pushkin = self.get_test_pushkin(PUSHKIN_ID)
# type safety: using ignore here due to mypy not handling monkeypatching,
# see https://github.com/python/mypy/issues/2427
test_pushkin._send_notification = self.apns_pushkin_snotif # type: ignore

def get_test_pushkin(self, name: str) -> ApnsPushkin:
test_pushkin = self.sygnal.pushkins[name]
assert isinstance(test_pushkin, ApnsPushkin)
return test_pushkin

def config_setup(self, config):
super().config_setup(config)
Expand All @@ -71,7 +80,8 @@ def test_payload_truncation(self):
method.side_effect = testutils.make_async_magic_mock(
NotificationResult("notID", "200")
)
self.sygnal.pushkins[PUSHKIN_ID].MAX_JSON_BODY_SIZE = 240
test_pushkin = self.get_test_pushkin(PUSHKIN_ID)
test_pushkin.MAX_JSON_BODY_SIZE = 240

# Act
self._request(self._make_dummy_notification([DEVICE_EXAMPLE]))
Expand All @@ -94,7 +104,8 @@ def test_payload_truncation_test_validity(self):
method.side_effect = testutils.make_async_magic_mock(
NotificationResult("notID", "200")
)
self.sygnal.pushkins[PUSHKIN_ID].MAX_JSON_BODY_SIZE = 4096
test_pushkin = self.get_test_pushkin(PUSHKIN_ID)
test_pushkin.MAX_JSON_BODY_SIZE = 4096

# Act
self._request(self._make_dummy_notification([DEVICE_EXAMPLE]))
Expand Down
20 changes: 14 additions & 6 deletions tests/test_gcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,17 @@ def config_setup(self, config):
"fcm_options": {"content_available": True, "mutable_content": True},
}

def get_test_pushkin(self, name: str) -> TestGcmPushkin:
pushkin = self.sygnal.pushkins[name]
assert isinstance(pushkin, TestGcmPushkin)
return pushkin

def test_expected(self):
"""
Tests the expected case: a good response from GCM leads to a good
response from Sygnal.
"""
gcm = self.sygnal.pushkins["com.example.gcm"]
gcm = self.get_test_pushkin("com.example.gcm")
gcm.preload_with_response(
200, {"results": [{"message_id": "msg42", "registration_id": "spqr"}]}
)
Expand All @@ -101,7 +106,7 @@ def test_expected_with_default_payload(self):
Tests the expected case: a good response from GCM leads to a good
response from Sygnal.
"""
gcm = self.sygnal.pushkins["com.example.gcm"]
gcm = self.get_test_pushkin("com.example.gcm")
gcm.preload_with_response(
200, {"results": [{"message_id": "msg42", "registration_id": "spqr"}]}
)
Expand All @@ -118,7 +123,7 @@ def test_rejected(self):
Tests the rejected case: a pushkey rejected to GCM leads to Sygnal
informing the homeserver of the rejection.
"""
gcm = self.sygnal.pushkins["com.example.gcm"]
gcm = self.get_test_pushkin("com.example.gcm")
gcm.preload_with_response(
200, {"results": [{"registration_id": "spqr", "error": "NotRegistered"}]}
)
Expand All @@ -133,7 +138,7 @@ def test_batching(self):
Tests that multiple GCM devices have their notification delivered to GCM
together, instead of being delivered separately.
"""
gcm = self.sygnal.pushkins["com.example.gcm"]
gcm = self.get_test_pushkin("com.example.gcm")
gcm.preload_with_response(
200,
{
Expand All @@ -149,6 +154,7 @@ def test_batching(self):
)

self.assertEqual(resp, {"rejected": []})
assert gcm.last_request_body is not None
self.assertEqual(gcm.last_request_body["registration_ids"], ["spqr", "spqr2"])
self.assertEqual(gcm.num_requests, 1)

Expand All @@ -159,7 +165,7 @@ def test_batching_individual_failure(self):
and that if only one device ID is rejected, then only that device is
reported to the homeserver as rejected.
"""
gcm = self.sygnal.pushkins["com.example.gcm"]
gcm = self.get_test_pushkin("com.example.gcm")
gcm.preload_with_response(
200,
{
Expand All @@ -175,6 +181,7 @@ def test_batching_individual_failure(self):
)

self.assertEqual(resp, {"rejected": ["spqr2"]})
assert gcm.last_request_body is not None
self.assertEqual(gcm.last_request_body["registration_ids"], ["spqr", "spqr2"])
self.assertEqual(gcm.num_requests, 1)

Expand All @@ -183,13 +190,14 @@ def test_fcm_options(self):
Tests that the config option `fcm_options` allows setting a base layer
of options to pass to FCM, for example ones that would be needed for iOS.
"""
gcm = self.sygnal.pushkins["com.example.gcm.ios"]
gcm = self.get_test_pushkin("com.example.gcm.ios")
gcm.preload_with_response(
200, {"results": [{"registration_id": "spqr_new", "message_id": "msg42"}]}
)

resp = self._request(self._make_dummy_notification([DEVICE_EXAMPLE_IOS]))

self.assertEqual(resp, {"rejected": []})
assert gcm.last_request_body is not None
H-Shay marked this conversation as resolved.
Show resolved Hide resolved
self.assertEqual(gcm.last_request_body["mutable_content"], True)
self.assertEqual(gcm.last_request_body["content_available"], True)
7 changes: 6 additions & 1 deletion tests/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

from aioapns.common import NotificationResult

from sygnal.apnspushkin import ApnsPushkin

from tests import testutils

PUSHKIN_ID_1 = "com.example.apns"
Expand Down Expand Up @@ -62,7 +64,10 @@ def setUp(self):

self.apns_pushkin_snotif = MagicMock()
for key, value in self.sygnal.pushkins.items():
value._send_notification = self.apns_pushkin_snotif
assert isinstance(value, ApnsPushkin)
# type safety: ignore is used here due to mypy not handling monkeypatching,
# see https://github.com/python/mypy/issues/2427
value._send_notification = self.apns_pushkin_snotif # type: ignore
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where possible, prefer #type: ignore[code] where code is one of mypy's error codes. See https://mypy.readthedocs.io/en/stable/error_codes.html and the two pages following it for gory details.

(but easier to remove the comment and see what the error code you get in the terminal is.)


def config_setup(self, config):
super().config_setup(config)
Expand Down
5 changes: 4 additions & 1 deletion tests/testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,14 @@ def setUp(self):

config = merge_left_with_defaults(CONFIG_DEFAULTS, config)

self.sygnal = Sygnal(config, reactor)
self.sygnal = Sygnal(config, reactor) # type: ignore
self.reactor = reactor

start_deferred = ensureDeferred(self.sygnal.make_pushkins_then_start())

while not start_deferred.called:
# we need to advance until the pushkins have started up
assert isinstance(self.sygnal.reactor, ExtendedMemoryReactorClock)
self.sygnal.reactor.advance(1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we already have the reactor, maybe use self.reactor.advance etc. here and in the two cases below. I think that will let us avoid the assertions.

(Nothing wrong with as it is, just trying to keep things concise where possible)

self.sygnal.reactor.wait_for_work(lambda: start_deferred.called)

Expand Down Expand Up @@ -154,6 +155,7 @@ def _request(self, payload: Union[str, dict]) -> Union[dict, int]:

while not channel.done:
# we need to advance until the request has been finished
assert isinstance(self.sygnal.reactor, ExtendedMemoryReactorClock)
self.sygnal.reactor.advance(1)
self.sygnal.reactor.wait_for_work(lambda: channel.done)

Expand Down Expand Up @@ -199,6 +201,7 @@ def all_channels_done():

while not all_channels_done():
# we need to advance until the request has been finished
assert isinstance(self.sygnal.reactor, ExtendedMemoryReactorClock)
self.sygnal.reactor.advance(1)
self.sygnal.reactor.wait_for_work(all_channels_done)

Expand Down