diff --git a/autopush/jwt.py b/autopush/jwt.py index 1fb711c1..cde23e96 100644 --- a/autopush/jwt.py +++ b/autopush/jwt.py @@ -8,6 +8,9 @@ from cryptography.hazmat.primitives import hashes from pyasn1.error import PyAsn1Error from twisted.logger import Logger +from typing import Tuple # noqa + +from autopush.types import JSONDict # noqa def repad(string): @@ -34,7 +37,7 @@ class VerifyJWT(object): @staticmethod def extract_signature(auth): - # type: (str) -> tuple() + # type: (str) -> Tuple[str, str] """Fix the JWT auth token. The JWA spec defines the signature to be a pair of 32octet encoded @@ -62,7 +65,7 @@ def extract_signature(auth): @staticmethod def decode(token, key): - # type (str, str) -> dict() + # type (str, str) -> JSONDict """Decode a web token into a assertion dictionary. This attempts to rectify both ecdsa and openssl generated diff --git a/autopush/tests/test_endpoint.py b/autopush/tests/test_endpoint.py index a8663f19..6c686dcd 100644 --- a/autopush/tests/test_endpoint.py +++ b/autopush/tests/test_endpoint.py @@ -80,18 +80,6 @@ def setUp(self): d = self.finish_deferred = Deferred() self.message.finish = lambda: d.callback(True) - def _make_req(self, id=None, headers=None, body='', - rargs=None, *args, **kwargs): - if headers is None: - headers = {} - self.request_mock.body = body - self.request_mock.headers.update(headers) - self.message.path_kwargs = {} - self.message.path_args = rargs or args or [] - if id is not None: - self.message.path_kwargs = {"message_id": id} - return dict() - def test_delete_token_invalid(self): self.fernet_mock.configure_mock(**{ "decrypt.side_effect": InvalidToken}) @@ -100,7 +88,7 @@ def handle_finish(result): self.status_mock.assert_called_with(400, reason=None) self.finish_deferred.addCallback(handle_finish) - self.message.delete(self._make_req(id='')) + self.message.delete(message_id='') return self.finish_deferred def test_delete_token_wrong_components(self): @@ -110,7 +98,7 @@ def handle_finish(result): self.status_mock.assert_called_with(400, reason=None) self.finish_deferred.addCallback(handle_finish) - self.message.delete(self._make_req("ignored")) + self.message.delete(message_id="ignored") return self.finish_deferred def test_delete_token_wrong_kind(self): @@ -121,7 +109,7 @@ def handle_finish(result): self.status_mock.assert_called_with(400, reason=None) self.finish_deferred.addCallback(handle_finish) - self.message.delete(self._make_req('ignored')) + self.message.delete(message_id='ignored') return self.finish_deferred def test_delete_invalid_timestamp_token(self): @@ -132,7 +120,7 @@ def handle_finish(result): self.status_mock.assert_called_with(400, reason=None) self.finish_deferred.addCallback(handle_finish) - self.message.delete(self._make_req('ignored')) + self.message.delete(message_id='ignored') return self.finish_deferred def test_delete_success(self): @@ -146,7 +134,7 @@ def handle_finish(result): self.status_mock.assert_called_with(204) self.finish_deferred.addCallback(handle_finish) - self.message.delete(self._make_req("123-456")) + self.message.delete(message_id="123-456") return self.finish_deferred def test_delete_topic_success(self): @@ -160,22 +148,7 @@ def handle_finish(result): self.status_mock.assert_called_with(204) self.finish_deferred.addCallback(handle_finish) - self.message.delete(self._make_req("123-456")) - return self.finish_deferred - - def test_delete_topic_success2(self): - tok = ":".join(["01", dummy_uaid.hex, str(dummy_chid), "Inbox"]) - self.fernet_mock.decrypt.return_value = tok - self.message_mock.configure_mock(**{ - "delete_message.return_value": True}) - - def handle_finish(result): - self.message_mock.delete_message.assert_called() - self.status_mock.assert_called_with(204) - self.finish_deferred.addCallback(handle_finish) - - self.message.delete(self._make_req(id=None, - rargs=["123-456"])) + self.message.delete(message_id="123-456") return self.finish_deferred def test_delete_topic_error_parts(self): @@ -188,7 +161,7 @@ def handle_finish(result): self.status_mock.assert_called_with(400, reason=None) self.finish_deferred.addCallback(handle_finish) - self.message.delete(self._make_req("123-456")) + self.message.delete(message_id="123-456") return self.finish_deferred def test_delete_db_error(self): @@ -203,7 +176,7 @@ def handle_finish(result): self.status_mock.assert_called_with(503, reason=None) self.finish_deferred.addCallback(handle_finish) - self.message.delete(self._make_req("ignored")) + self.message.delete(message_id="ignored") return self.finish_deferred @@ -246,17 +219,25 @@ def setUp(self): self.reg.finish = lambda: d.callback(True) self.settings = settings - def _make_req(self, router_type="", router_token="", uaid=None, - chid=None, body="", headers=None): - if headers is None: - headers = {} - self.request_mock.body = body or self.request_mock.body - self.request_mock.headers.update(headers) - self.reg.path_kwargs = {"router_type": router_type, - "router_token": router_token, - "uaid": uaid, - "chid": chid} - return dict() + def _req(self, meth, router_type="", router_token="", uaid=None, + chid=None): + return meth( + router_type=router_type, + router_token=router_token, + uaid=uaid, + chid=chid) + + def _post(self, **kwargs): + return self._req(self.reg.post, **kwargs) + + def _put(self, **kwargs): + return self._req(self.reg.put, **kwargs) + + def _delete(self, **kwargs): + return self._req(self.reg.delete, **kwargs) + + def _get(self, **kwargs): + return self._req(self.reg.get, **kwargs) def test_base_tags(self): self.reg._base_tags = [] @@ -356,8 +337,7 @@ def handle_finish(value): ok_("secret" in call_arg) self.finish_deferred.addBoth(handle_finish) - self.reg.post(self._make_req("simplepush", "", - body=self.reg.request.body)) + self._post(router_type="simplepush") return self.finish_deferred def test_post_gcm(self, *args): @@ -396,7 +376,7 @@ def restore(*args, **kwargs): old_func = uuid.uuid4 ids = [dummy_uaid, dummy_chid] uuid.uuid4 = lambda: ids.pop() - self.reg.post(self._make_req("gcm", "182931248179192")) + self._post(router_type="gcm", router_token="182931248179192") return self.finish_deferred def test_post_invalid_args(self, *args): @@ -410,7 +390,7 @@ def handle_finish(value): self.finish_deferred.addCallback(handle_finish) self.reg.request.headers["Authorization"] = self.auth - self.reg.post(self._make_req()) + self._post() return self.finish_deferred def test_post_bad_router_type(self, *args): @@ -431,7 +411,7 @@ def restore(*args, **kwargs): self.finish_deferred.addBoth(restore) self.finish_deferred.addCallback(handle_finish) self.reg.request.headers["Authorization"] = self.auth - self.reg.post(self._make_req()) + self._post() return self.finish_deferred def test_post_bad_router_register(self, *args): @@ -452,8 +432,7 @@ def handle_finish(value): self._check_error(rexc.status_code, rexc.errno, "") self.finish_deferred.addBoth(handle_finish) - self.reg.post(self._make_req("simplepush", "", - body=self.reg.request.body)) + self._post(router_type="simplepush") return self.finish_deferred def test_post_existing_uaid(self, *args): @@ -480,7 +459,7 @@ def restore(*args, **kwargs): self.finish_deferred.addBoth(restore) self.finish_deferred.addCallback(handle_finish) self.reg.request.headers["Authorization"] = self.auth - self.reg.post(self._make_req(router_type="test", uaid=dummy_uaid.hex)) + self._post(router_type="test", uaid=dummy_uaid.hex) return self.finish_deferred def test_post_bad_uaid(self, *args): @@ -501,8 +480,7 @@ def restore(*args, **kwargs): self.finish_deferred.addBoth(restore) self.finish_deferred.addCallback(handle_finish) self.reg.request.headers["Authorization"] = self.auth - self.reg.post(self._make_req(router_type="simplepush", - uaid='invalid')) + self._post(router_type="simplepush", uaid='invalid') return self.finish_deferred def test_no_uaid(self): @@ -512,9 +490,9 @@ def handle_finish(value): self.finish_deferred.addCallback(handle_finish) self.settings.router.get_uaid = Mock() self.settings.router.get_uaid.side_effect = ItemNotFound - self.reg.post(self._make_req(router_type="webpush", - uaid=dummy_uaid.hex, - chid=str(dummy_chid))) + self._post(router_type="webpush", + uaid=dummy_uaid.hex, + chid=str(dummy_chid)) return self.finish_deferred def test_no_auth(self): @@ -522,21 +500,22 @@ def handle_finish(value): self._check_error(401, 109, "Unauthorized") self.finish_deferred.addCallback(handle_finish) - self.reg.post(self._make_req(router_type="webpush", - uaid=dummy_uaid.hex, - chid=str(dummy_chid))) + self._post(router_type="webpush", + uaid=dummy_uaid.hex, + chid=str(dummy_chid)) return self.finish_deferred def test_bad_body(self): + self.reg.request.body = "{invalid" + def handle_finish(value): self._check_error(401, 108, "Unauthorized") self.finish_deferred.addCallback(handle_finish) - self.reg.post(self._make_req(router_type="webpush", - uaid=dummy_uaid.hex, - chid=str(dummy_chid), - body="{invalid")) + self._post(router_type="webpush", + uaid=dummy_uaid.hex, + chid=str(dummy_chid)) return self.finish_deferred def test_post_bad_params(self, *args): @@ -555,9 +534,9 @@ def restore(*args, **kwargs): self.finish_deferred.addBoth(restore) self.finish_deferred.addCallback(handle_finish) self.reg.request.headers["Authorization"] = "WebPush Invalid" - self.reg.post(self._make_req(router_type="simplepush", - uaid=dummy_uaid.hex, - chid=str(dummy_chid))) + self._post(router_type="simplepush", + uaid=dummy_uaid.hex, + chid=str(dummy_chid)) return self.finish_deferred def test_post_uaid_chid(self, *args): @@ -586,9 +565,9 @@ def restore(*args, **kwargs): self.finish_deferred.addBoth(restore) self.finish_deferred.addCallback(handle_finish) self.reg.request.headers["Authorization"] = self.auth - self.reg.post(self._make_req(router_type="simplepush", - uaid=dummy_uaid.hex, - chid=str(dummy_chid))) + self._post(router_type="simplepush", + uaid=dummy_uaid.hex, + chid=str(dummy_chid)) return self.finish_deferred def test_post_uaid_critical_failure(self, *args): @@ -611,9 +590,9 @@ def handle_finish(value): self.finish_deferred.addCallback(handle_finish) self.reg.request.headers["Authorization"] = self.auth - self.reg.post(self._make_req(router_type="simplepush", - uaid=dummy_uaid.hex, - chid=str(dummy_chid))) + self._post(router_type="simplepush", + uaid=dummy_uaid.hex, + chid=str(dummy_chid)) return self.finish_deferred def test_post_nochid(self): @@ -642,8 +621,7 @@ def restore(*args, **kwargs): self.finish_deferred.addBoth(restore) self.finish_deferred.addCallback(handle_finish) self.reg.request.headers["Authorization"] = self.auth - self.reg.post(self._make_req(router_type="simplepush", - uaid=dummy_uaid.hex)) + self._post(router_type="simplepush", uaid=dummy_uaid.hex) return self.finish_deferred def test_post_with_app_server_key(self): @@ -688,8 +666,7 @@ def restore(*args, **kwargs): self.finish_deferred.addBoth(restore) self.finish_deferred.addCallback(handle_finish) self.reg.request.headers["Authorization"] = self.auth - self.reg.post(self._make_req(router_type="simplepush", - uaid=dummy_uaid.hex)) + self._post(router_type="simplepush", uaid=dummy_uaid.hex) return self.finish_deferred def test_put(self): @@ -716,7 +693,7 @@ def restore(*args, **kwargs): self.finish_deferred.addBoth(restore) self.finish_deferred.addCallback(handle_finish) self.reg.request.headers["Authorization"] = self.auth - self.reg.put(self._make_req(router_type='test', uaid=dummy_uaid.hex)) + self._put(router_type='test', uaid=dummy_uaid.hex) return self.finish_deferred def test_put_bad_auth(self): @@ -732,8 +709,7 @@ def restore(*args, **kwargs): uuid.uuid4 = lambda: dummy_uaid self.finish_deferred.addBoth(restore) self.finish_deferred.addCallback(handle_finish) - self.reg.put(self._make_req(router_type="test", - uaid=dummy_uaid.hex)) + self._put(router_type="test", uaid=dummy_uaid.hex) return self.finish_deferred def test_put_bad_arguments(self, *args): @@ -754,7 +730,7 @@ def restore(*args, **kwargs): uuid.uuid4 = lambda: dummy_chid self.finish_deferred.addBoth(restore) self.finish_deferred.addCallback(handle_finish) - self.reg.put(self._make_req(uaid=dummy_uaid.hex)) + self._put(uaid=dummy_uaid.hex) return self.finish_deferred def test_put_bad_router_register(self): @@ -767,7 +743,7 @@ def handle_finish(value): self.finish_deferred.addCallback(handle_finish) self.reg.request.headers["Authorization"] = self.auth - self.reg.put(self._make_req(router_type='test', uaid=dummy_uaid.hex)) + self._put(router_type='test', uaid=dummy_uaid.hex) return self.finish_deferred def test_delete_bad_chid_value(self): @@ -781,8 +757,10 @@ def handle_finish(value): self._check_error(410, 106, "") self.finish_deferred.addCallback(handle_finish) - self.reg.delete(self._make_req("test", "test", dummy_uaid.hex, - "invalid")) + self._delete(router_type="test", + router_token="test", + uaid=dummy_uaid.hex, + chid="invalid") return self.finish_deferred def test_delete_no_such_chid(self): @@ -805,8 +783,10 @@ def fixup_messages(result): self.finish_deferred.addCallback(handle_finish) self.finish_deferred.addBoth(fixup_messages) - self.reg.delete(self._make_req("test", "test", - dummy_uaid.hex, str(uuid.uuid4()))) + self.reg.delete(router_type="test", + router_token="test", + uaid=dummy_uaid.hex, + chid=str(uuid.uuid4())) return self.finish_deferred def test_delete_uaid(self): @@ -828,7 +808,9 @@ def handle_finish(value, chid2): self.finish_deferred.addCallback(handle_finish, chid2) self.reg.request.headers["Authorization"] = self.auth - self.reg.delete(self._make_req("simplepush", "test", dummy_uaid.hex)) + self._delete(router_type="simplepush", + router_token="test", + uaid=dummy_uaid.hex) return self.finish_deferred def test_delete_bad_uaid(self): @@ -838,7 +820,9 @@ def handle_finish(value): self.status_mock.assert_called_with(401, reason=None) self.finish_deferred.addCallback(handle_finish) - self.reg.delete(self._make_req("test", "test", "invalid")) + self._delete(router_type="test", + router_token="test", + uaid="invalid") return self.finish_deferred def test_delete_orphans(self): @@ -850,7 +834,9 @@ def handle_finish(value): self.router_mock.drop_user = Mock() self.router_mock.drop_user.return_value = False self.finish_deferred.addCallback(handle_finish) - self.reg.delete(self._make_req("test", "test", dummy_uaid.hex)) + self._delete(router_type="test", + router_token="test", + uaid=dummy_uaid.hex) return self.finish_deferred def test_delete_bad_auth(self, *args): @@ -860,7 +846,9 @@ def handle_finish(value): self.status_mock.assert_called_with(401, reason=None) self.finish_deferred.addCallback(handle_finish) - self.reg.delete(self._make_req("test", "test", dummy_uaid.hex)) + self._delete(router_type="test", + router_token="test", + uaid=dummy_uaid.hex) return self.finish_deferred def test_delete_bad_router(self): @@ -870,7 +858,9 @@ def handle_finish(value): self.status_mock.assert_called_with(400, reason=None) self.finish_deferred.addCallback(handle_finish) - self.reg.delete(self._make_req("invalid", "test", dummy_uaid.hex)) + self._delete(router_type="invalid", + router_token="test", + uaid=dummy_uaid.hex) return self.finish_deferred def test_get(self): @@ -889,10 +879,10 @@ def handle_finish(value): self.finish_deferred.addCallback(handle_finish) self.settings.message.all_channels = Mock() self.settings.message.all_channels.return_value = (True, chids) - self.reg.get(self._make_req( + self._get( router_type="test", router_token="test", - uaid=dummy_uaid.hex)) + uaid=dummy_uaid.hex) return self.finish_deferred def test_get_no_uaid(self): @@ -902,7 +892,7 @@ def handle_finish(value): self.status_mock.assert_called_with(410, reason=None) self.finish_deferred.addCallback(handle_finish) - self.reg.get(self._make_req( + self._get( router_type="test", - router_token="test")) + router_token="test") return self.finish_deferred diff --git a/autopush/tests/test_log_check.py b/autopush/tests/test_log_check.py index 069866bb..2c50c0a6 100644 --- a/autopush/tests/test_log_check.py +++ b/autopush/tests/test_log_check.py @@ -71,5 +71,5 @@ def handle_finish(value): eq_(write_args.get('error'), 'Test Failure') self.finish_deferred.addCallback(handle_finish) - self.lch.get('CRIT') + self.lch.get(err_type='CRIT') return self.finish_deferred diff --git a/autopush/tests/test_web_webpush.py b/autopush/tests/test_web_webpush.py index f17c3578..c483c51f 100644 --- a/autopush/tests/test_web_webpush.py +++ b/autopush/tests/test_web_webpush.py @@ -48,7 +48,6 @@ def setUp(self): self.wp = WebPushHandler(Application(), self.request_mock, ap_settings=settings) - self.wp.path_kwargs = {} self.status_mock = self.wp.set_status = Mock() self.write_mock = self.wp.write = Mock() self.wp.log = Mock(spec=Logger) @@ -86,7 +85,7 @@ def handle_finish(result): self.finish_deferred.addCallback(handle_finish) - self.wp.post("v1", dummy_token) + self.wp.post(api_ver="v1", token=dummy_token) return self.finish_deferred def test_router_returns_data_without_detail(self): @@ -124,8 +123,7 @@ def handle_finish(result): self.finish_deferred.addCallback(handle_finish) self.fernet_mock.decrypt.return_value = 'invalid key' self.request_mock.headers['crypto-key'] = 'dummy_key' - self.wp.path_kwargs = dict(token='ignored', api_ver='v1') - self.wp.post() + self.wp.post(token='ignored', api_ver='v1') return self.finish_deferred def test_request_bad_v1_id(self): @@ -134,8 +132,7 @@ def handle_finish(result): self.finish_deferred.addCallback(handle_finish) self.fernet_mock.decrypt.return_value = 'tooshort' - self.wp.path_kwargs = dict(token='ignored', api_ver='v1') - self.wp.post() + self.wp.post(token='ignored', api_ver='v1') return self.finish_deferred def test_request_bad_v2_id_short(self): @@ -145,8 +142,7 @@ def handle_finish(result): self.finish_deferred.addCallback(handle_finish) self.fernet_mock.decrypt.return_value = 'tooshort' self.request_mock.headers['authorization'] = 'dummy_key' - self.wp.path_kwargs = dict(token='ignored', api_ver='v2') - self.wp.post() + self.wp.post(token='ignored', api_ver='v2') return self.finish_deferred def test_request_bad_v2_id_missing_pubkey(self): @@ -157,8 +153,7 @@ def handle_finish(result): self.fernet_mock.decrypt.return_value = 'a' * 64 self.request_mock.headers['crypto-key'] = 'key_id=dummy_key' self.request_mock.headers['authorization'] = 'dummy_key' - self.wp.path_kwargs = dict(token='ignored', api_ver='v2') - self.wp.post() + self.wp.post(token='ignored', api_ver='v2') return self.finish_deferred def test_request_v2_id_variant_pubkey(self): @@ -170,14 +165,13 @@ def handle_finish(result): variant_key = base64.urlsafe_b64encode("0V0" + ('a' * 85)) self.request_mock.headers['crypto-key'] = 'p256ecdsa=' + variant_key self.request_mock.headers['authorization'] = 'webpush dummy.key' - self.wp.path_kwargs = dict(token='ignored', api_ver='v1') self.ap_settings.router.get_uaid = Mock() self.ap_settings.router.get_uaid.return_value = dict( uaid=dummy_uaid, chid=dummy_chid, router_type="gcm" ) - self.wp.post() + self.wp.post(token='ignored', api_ver='v1') return self.finish_deferred def test_request_v2_id_no_crypt_auth(self): @@ -187,14 +181,13 @@ def handle_finish(result): self.finish_deferred.addCallback(handle_finish) self.fernet_mock.decrypt.return_value = 'a' * 32 self.request_mock.headers['authorization'] = 'webpush dummy.key' - self.wp.path_kwargs = dict(token='ignored', api_ver='v1') self.ap_settings.router.get_uaid = Mock() self.ap_settings.router.get_uaid.return_value = dict( uaid=dummy_uaid, chid=dummy_chid, router_type="gcm" ) - self.wp.post() + self.wp.post(token='ignored', api_ver='v1') return self.finish_deferred def test_request_bad_v2_id_bad_pubkey(self): @@ -205,6 +198,5 @@ def handle_finish(result): self.fernet_mock.decrypt.return_value = 'a' * 64 self.request_mock.headers['crypto-key'] = 'p256ecdsa=Invalid!' self.request_mock.headers['authorization'] = 'dummy_key' - self.wp.path_kwargs = dict(token='ignored', api_ver='v2') - self.wp.post() + self.wp.post(token='ignored', api_ver='v2') return self.finish_deferred diff --git a/autopush/web/base.py b/autopush/web/base.py index 545fcfc4..2cf40308 100644 --- a/autopush/web/base.py +++ b/autopush/web/base.py @@ -6,7 +6,7 @@ from boto.dynamodb2.exceptions import ProvisionedThroughputExceededException from boto.exception import BotoServerError from marshmallow.schema import UnmarshalResult # noqa -from typing import Any # noqa +from typing import Any, Callable # noqa from twisted.internet.threads import deferToThread from twisted.logger import Logger @@ -44,14 +44,14 @@ class ThreadedValidate(object): def __init__(self, schema): self.schema = schema - def _validate_request(self, request_handler): - # type: (BaseWebHandler) -> UnmarshalResult + def _validate_request(self, request_handler, *args, **kwargs): + # type: (BaseWebHandler, *Any, **Any) -> UnmarshalResult """Validates a schema_class against a cyclone request""" data = { "headers": request_handler.request.headers, "body": request_handler.request.body, - "path_args": request_handler.path_args, - "path_kwargs": request_handler.path_kwargs, + "path_args": args, + "path_kwargs": kwargs, "arguments": request_handler.request.arguments, } schema = self.schema() @@ -59,13 +59,13 @@ def _validate_request(self, request_handler): schema.context["log"] = self.log return schema.load(data) - def _call_func(self, result, func, request_handler, *args, **kwargs): - output, errors = result + def _call_func(self, result, func, request_handler): + # type: (UnmarshalResult, Callable, BaseWebHandler) -> Any + output_kwargs, errors = result if errors: request_handler._write_validation_err(errors) else: - request_handler.valid_input = output - return func(request_handler, *args, **kwargs) + return func(request_handler, **output_kwargs) def _track_validation_timing(self, result, request_handler, start_time): # type: (Any, BaseWebHandler, float) -> Any @@ -79,11 +79,11 @@ def wrapper(request_handler, *args, **kwargs): start_time = time.time() # Wrap the handler in @cyclone.web.synchronous request_handler._auto_finish = False - d = deferToThread(self._validate_request, request_handler) + d = deferToThread( + self._validate_request, request_handler, *args, **kwargs) d.addBoth(self._track_validation_timing, request_handler, start_time) - d.addCallback(self._call_func, func, request_handler, *args, - **kwargs) + d.addCallback(self._call_func, func, request_handler) d.addErrback(request_handler._overload_err) d.addErrback(request_handler._boto_err) d.addErrback(request_handler._validation_err) @@ -101,11 +101,17 @@ def validate(cls, schema): will attach equivilant functionality to the method handler. Calling `self.finish()` is needed on decorated handlers. + Validated requests are deserialized into the **kwargs of the wrapped + request handler method. + .. code-block:: python + class MySchema(Schema): + uaid = fields.UUID(allow_none=True) + class MyHandler(cyclone.web.RequestHandler): @threaded_validate(MySchema()) - def post(self): + def post(self, uaid=None): ... """ diff --git a/autopush/web/health.py b/autopush/web/health.py index 0ce120e3..863a532e 100644 --- a/autopush/web/health.py +++ b/autopush/web/health.py @@ -13,8 +13,12 @@ class HealthHandler(BaseWebHandler): """HTTP Health Handler""" + def authenticate_peer_cert(self): + """Skip authentication checks""" + pass + @cyclone.web.asynchronous - def get(self, *args, **kwargs): + def get(self): """HTTP Get Returns basic information about the version and how many clients are @@ -33,10 +37,6 @@ def get(self, *args, **kwargs): ]) dl.addBoth(self._finish_response) - def authenticate_peer_cert(self): - """Skip authentication checks""" - pass - def _check_table(self, table): """Checks the tables known about in DynamoDB""" d = deferToThread(table.connection.list_tables) @@ -83,7 +83,7 @@ def authenticate_peer_cert(self): """skip authentication checks""" pass - def get(self, *args, **kwargs): + def get(self): """HTTP Get Returns that this node is alive, and the version. @@ -106,7 +106,7 @@ def authenticate_peer_cert(self): """skip authentication checks""" pass # pragma: nocover - def get(self, *args, **kwargs): + def get(self): """HTTP Get Returns that this node is alive, and the version. diff --git a/autopush/web/log_check.py b/autopush/web/log_check.py index 4576554c..46e13dec 100644 --- a/autopush/web/log_check.py +++ b/autopush/web/log_check.py @@ -1,4 +1,5 @@ from marshmallow import Schema, fields, pre_load +from typing import Optional # noqa from autopush.exceptions import LogCheckError from autopush.web.base import threaded_validate, BaseWebHandler @@ -6,12 +7,11 @@ class LogCheckSchema(Schema): """Empty schema for log check""" - fields.err_type = fields.Str(allow_none=True) + err_type = fields.Str(allow_none=True) @pre_load def extract_data(self, req): - # req['path_kwargs'] could be set to None, which would be returned - return dict(err_type=(req.get('path_kwargs') or {}).get('err_type')) + return dict(err_type=req['path_kwargs'].get('err_type')) class LogCheckHandler(BaseWebHandler): @@ -21,7 +21,8 @@ def authenticate_peer_cert(self): pass @threaded_validate(LogCheckSchema) - def get(self, err_type=None, *args, **kwargs): + def get(self, err_type=None): + # type: (Optional[str]) -> None """HTTP GET Generate a dummy error message for logging diff --git a/autopush/web/message.py b/autopush/web/message.py index 749679e1..76ae7ef2 100644 --- a/autopush/web/message.py +++ b/autopush/web/message.py @@ -1,6 +1,7 @@ from cryptography.fernet import InvalidToken from marshmallow import Schema, fields, pre_load from twisted.internet.threads import deferToThread +from twisted.internet.defer import Deferred # noqa from autopush.exceptions import InvalidRequest, InvalidTokenException from autopush.utils import WebPushNotification @@ -12,12 +13,7 @@ class MessageSchema(Schema): @pre_load def extract_data(self, req): - message_id = None - if req['path_args']: - message_id = req['path_args'][0] - message_id = req['path_kwargs'].get( - 'message_id', - message_id) + message_id = req['path_kwargs'].get('message_id') if not message_id: raise InvalidRequest("Missing Token", status_code=400) @@ -37,7 +33,8 @@ class MessageHandler(BaseWebHandler): cors_response_headers = ("location",) @threaded_validate(MessageSchema) - def delete(self, *args, **kwargs): + def delete(self, notification): + # type: (WebPushNotification) -> Deferred """Drops a pending message. The message will only be removed from DynamoDB. Messages that were @@ -46,9 +43,8 @@ def delete(self, *args, **kwargs): """ - notif = self.valid_input['notification'] d = deferToThread(self.ap_settings.message.delete_message, - notif) + notification) d.addCallback(self._delete_completed) self._db_error_handling(d) return d diff --git a/autopush/web/registration.py b/autopush/web/registration.py index f4e5cbbc..0a05a563 100644 --- a/autopush/web/registration.py +++ b/autopush/web/registration.py @@ -8,15 +8,18 @@ Tuple ) +from attr import attrs, attrib from boto.dynamodb2.exceptions import ItemNotFound from cryptography.hazmat.primitives import constant_time from marshmallow import ( Schema, fields, pre_load, + post_load, validates_schema ) from twisted.internet import defer +from twisted.internet.defer import Deferred # noqa from twisted.internet.threads import deferToThread from autopush.db import generate_last_connect, hasher @@ -32,11 +35,11 @@ class RegistrationSchema(Schema): - uaid = fields.UUID(allow_none=True) - chid = fields.Str(allow_none=True) router_type = fields.Str() router_token = fields.Str() router_data = fields.Dict() + uaid = fields.UUID(allow_none=True) + chid = fields.Str(allow_none=True) auth = fields.Str(allow_none=True) @pre_load @@ -49,6 +52,7 @@ def extract_data(self, req): raise InvalidRequest("Invalid Request body", status_code=401, errno=108) + # UAID and CHID may be empty. This can trigger different behaviors # in the handlers, so we can't set default values here. uaid = req['path_kwargs'].get('uaid') @@ -77,23 +81,23 @@ def extract_data(self, req): status_code=410, errno=106) return dict( - uaid=uaid, - chid=chid, router_type=req['path_kwargs'].get('router_type'), router_token=req['path_kwargs'].get('router_token'), router_data=router_data, + uaid=uaid, + chid=chid, auth=req.get('headers', {}).get("Authorization"), ) @validates_schema(skip_on_field_errors=True) def validate_data(self, data): settings = self.context['settings'] - try: - data['router'] = settings.routers[data['router_type']] - except KeyError: + + if data['router_type'] not in settings.routers: raise InvalidRequest("Invalid router", status_code=400, errno=108) + if data.get('uaid'): request_pref_header = {'www-authenticate': PREF_SCHEME} try: @@ -133,6 +137,41 @@ def validate_data(self, data): errno=109, headers=request_pref_header) + @post_load + def handler_kwargs(self, data): + # not used + data.pop('auth') + router_type = data.pop('router_type') + data['rinfo'] = RouterInfo( + router=self.context['settings'].routers[router_type], + type_=router_type, + token=data.pop('router_token'), + data=data.pop('router_data') + ) + + +@attrs(slots=True) +class RouterInfo(object): + """Bundle of Router registration information""" + + router = attrib() # type: Any + type_ = attrib() # type: str + token = attrib() # type: str + data = attrib() # type: JSONDict + + def register(self, uaid, **kwargs): + # type: (uuid.UUID, **Any) -> None + self.router.register( + uaid.hex, router_data=self.data, app_id=self.token, **kwargs) + + def amend_endpoint_response(self, response): + # type: (JSONDict) -> None + self.router.amend_endpoint_response(response, self.data) + + @property + def app_server_key(self): + return self.data.get('key') + class RegistrationHandler(BaseWebHandler): """Handle the Bridge services endpoints""" @@ -142,101 +181,70 @@ class RegistrationHandler(BaseWebHandler): # Cyclone HTTP Methods ############################################################# @threaded_validate(RegistrationSchema) - def post(self, *args, **kwargs): + def post(self, rinfo, uaid=None, chid=None): + # type: (RouterInfo, Optional[uuid.UUID], Optional[str]) -> Deferred """HTTP POST Endpoint generation and optionally router type/data registration. - """ self.add_header("Content-Type", "application/json") - - uaid = self.valid_input['uaid'] - router = self.valid_input["router"] - router_type = self.valid_input["router_type"] - router_token = self.valid_input.get("router_token") - router_data = self.valid_input['router_data'] + self.ap_settings.metrics.increment("updates.client.register", + tags=self.base_tags()) # If the client didn't provide a CHID, make one up. - # Note, valid_input may explicitly set "chid" to None + # Note, RegistrationSchema may explicitly set "chid" to None # THIS VALUE MUST MATCH WHAT'S SPECIFIED IN THE BRIDGE CONNECTIONS. # currently hex formatted. - chid = router_data["channelID"] = (self.valid_input["chid"] or - uuid.uuid4().hex) - self.ap_settings.metrics.increment("updates.client.register", - tags=self.base_tags()) + if not chid: + chid = uuid.uuid4().hex + rinfo.data["channelID"] = chid if not uaid: uaid = uuid.uuid4() - d = defer.execute( - router.register, - uaid.hex, router_data=router_data, app_id=router_token, - uri=self.request.uri) + d = defer.execute(rinfo.register, uaid, uri=self.request.uri) d.addCallback( lambda _: deferToThread(self._register_user_and_channel, - uaid, chid, router, router_type, router_data) + uaid, chid, rinfo) ) - d.addCallback(self._write_endpoint, - uaid, chid, router, router_data) + d.addCallback( + self._write_endpoint, uaid, chid, rinfo, new_uaid=True) d.addErrback(self._router_fail_err) d.addErrback(self._response_err) else: d = deferToThread(self._register_channel, - uaid, chid, router_data.get("key")) - d.addCallback(self._write_endpoint, uaid, chid) + uaid, chid, rinfo.app_server_key) + d.addCallback(self._write_endpoint, uaid, chid, rinfo) d.addErrback(self._response_err) return d @threaded_validate(RegistrationSchema) - def put(self, *args, **kwargs): + def put(self, rinfo, uaid=None, chid=None): + # type: (RouterInfo, Optional[uuid.UUID], Optional[str]) -> Deferred """HTTP PUT Update router type/data for a UAID. """ - uaid = self.valid_input['uaid'] - router = self.valid_input['router'] - router_type = self.valid_input['router_type'] - router_token = self.valid_input['router_token'] - router_data = self.valid_input['router_data'] self.add_header("Content-Type", "application/json") - d = defer.execute( - router.register, - uaid.hex, router_data=router_data, app_id=router_token, - uri=self.request.uri) + d = defer.execute(rinfo.register, uaid, uri=self.request.uri) d.addCallback( - lambda _: - deferToThread(self._register_user, uaid, router_data, router_type) + lambda _: deferToThread(self._register_user, uaid, rinfo) ) d.addCallback(self._success) d.addErrback(self._router_fail_err) d.addErrback(self._response_err) return d - def _delete_channel(self, uaid, chid): - message = self.ap_settings.message - if not message.unregister_channel(uaid.hex, chid): - raise ItemNotFound("ChannelID not found") - - def _delete_uaid(self, uaid, router): - self.log.info(format="Dropping User", code=101, - uaid_hash=hasher(uaid.hex)) - if not router.drop_user(uaid.hex): - raise ItemNotFound("UAID not found") - - def _check_uaid(self, uaid): - if not uaid: - raise ItemNotFound("UAID not found") - @threaded_validate(RegistrationSchema) - def get(self, *args, **kwargs): + def get(self, uaid=None, **kwargs): + # type: (Optional[uuid.UUID], **Any) -> Deferred """HTTP GET Return a list of known channelIDs for a given UAID """ - uaid = self.valid_input['uaid'] self.add_header("Content-Type", "application/json") d = defer.execute(self._check_uaid, uaid) d.addCallback( @@ -249,26 +257,24 @@ def get(self, *args, **kwargs): return d @threaded_validate(RegistrationSchema) - def delete(self, *args, **kwargs): + def delete(self, uaid=None, chid=None, **kwargs): + # type: (Optional[uuid.UUID], Optional[str], **Any) -> Deferred """HTTP DELETE Delete all pending records for the given channel or UAID """ - if self.valid_input['chid']: + if chid: # mark channel as dead self.ap_settings.metrics.increment("updates.client.unregister", tags=self.base_tags()) - d = deferToThread(self._delete_channel, - self.valid_input['uaid'], - self.valid_input['chid']) + d = deferToThread(self._delete_channel, uaid, chid) d.addCallback(self._success) d.addErrback(self._chid_not_found_err) d.addErrback(self._response_err) return d # nuke all records for the UAID - d = deferToThread(self._delete_uaid, self.valid_input['uaid'], - self.ap_settings.router) + d = deferToThread(self._delete_uaid, uaid) d.addCallback(self._success) d.addErrback(self._uaid_not_found_err) d.addErrback(self._response_err) @@ -295,25 +301,33 @@ def _chid_not_found_err(self, fail): ############################################################# # Callbacks ############################################################# - def _register_user_and_channel(self, - uaid, # type: uuid.UUID - chid, # type: str - router, # type: Any - router_type, # type: str - router_data # type: JSONDict - ): - # type: (...) -> str + def _delete_channel(self, uaid, chid): + if not self.ap_settings.message.unregister_channel(uaid.hex, chid): + raise ItemNotFound("ChannelID not found") + + def _delete_uaid(self, uaid): + self.log.info(format="Dropping User", code=101, + uaid_hash=hasher(uaid.hex)) + if not self.ap_settings.router.drop_user(uaid.hex): + raise ItemNotFound("UAID not found") + + def _check_uaid(self, uaid): + if not uaid: + raise ItemNotFound("UAID not found") + + def _register_user_and_channel(self, uaid, chid, rinfo): + # type: (uuid.UUID, str, RouterInfo) -> str """Register a new user/channel, return its endpoint""" - self._register_user(uaid, router_type, router_data) - return self._register_channel(uaid, chid, router_data.get("key")) + self._register_user(uaid, rinfo) + return self._register_channel(uaid, chid, rinfo.app_server_key) - def _register_user(self, uaid, router_type, router_data): - # type: (uuid.UUID, str, JSONDict) -> None + def _register_user(self, uaid, rinfo): + # type: (uuid.UUID, RouterInfo) -> None """Save a new user record""" self.ap_settings.router.register_user(dict( uaid=uaid.hex, - router_type=router_type, - router_data=router_data, + router_type=rinfo.type_, + router_data=rinfo.data, connected_at=ms_time(), last_connect=generate_last_connect(), )) @@ -324,25 +338,18 @@ def _register_channel(self, uaid, chid, app_server_key): self.ap_settings.message.register_channel(uaid.hex, chid) return self.ap_settings.make_endpoint(uaid.hex, chid, app_server_key) - def _write_endpoint(self, - endpoint, # type: str - uaid, # type: uuid.UUID - chid, # type: str - router=None, # type: Optional[Any] - router_data=None # type: Optional[JSONDict] - ): - # type: (...) -> None + def _write_endpoint(self, endpoint, uaid, chid, rinfo, new_uaid=False): + # type: (str, uuid.UUID, str, RouterInfo, bool) -> None """Write the JSON response of the created endpoint""" response = dict(channelID=chid, endpoint=endpoint) - if router_data is not None: - # a new uaid + if new_uaid: secret = None if self.ap_settings.bear_hash_key: secret = generate_hash( self.ap_settings.bear_hash_key[0], uaid.hex) response.update(uaid=uaid.hex, secret=secret) # Apply any router specific fixes to the outbound response. - router.amend_endpoint_response(response, router_data) + rinfo.amend_endpoint_response(response) self.write(json.dumps(response)) self.log.debug("Endpoint registered via HTTP", client_info=self._client_info) diff --git a/autopush/web/simplepush.py b/autopush/web/simplepush.py index e25a5d85..86b8c77b 100644 --- a/autopush/web/simplepush.py +++ b/autopush/web/simplepush.py @@ -11,7 +11,9 @@ validates_schema, ) -from twisted.internet.defer import Deferred +from twisted.internet.defer import Deferred # noqa +from twisted.internet.defer import maybeDeferred +from typing import Any, Dict # noqa from autopush.exceptions import ( InvalidRequest, @@ -101,29 +103,27 @@ class SimplePushHandler(BaseWebHandler): cors_methods = "PUT" @threaded_validate(SimplePushRequestSchema) - def put(self, *args, **kwargs): - sub = self.valid_input["subscription"] - user_data = sub["user_data"] - router = self.ap_settings.routers[user_data["router_type"]] - self._client_info["uaid"] = hasher(user_data.get("uaid")) - self._client_info["channel_id"] = user_data.get("chid") - self._client_info["message_id"] = self.valid_input["version"] - self._client_info["router_key"] = user_data["router_type"] - + def put(self, subscription, version, data): + # type: (Dict[str, Any], str, str) -> Deferred + user_data = subscription["user_data"] + self._client_info.update( + uaid=hasher(user_data.get("uaid")), + channel_id=user_data.get("chid"), + message_id=version, + router_key=user_data["router_type"] + ) notification = Notification( - version=self.valid_input["version"], - data=self.valid_input["data"], - channel_id=str(sub["chid"]), + version=version, + data=data, + channel_id=str(subscription["chid"]), ) - d = Deferred() - d.addCallback(router.route_notification, user_data) + router = self.ap_settings.routers[user_data["router_type"]] + d = maybeDeferred(router.route_notification, notification, user_data) d.addCallback(self._router_completed, user_data, "") d.addErrback(self._router_fail_err) d.addErrback(self._response_err) - - # Call the prepared router - d.callback(notification) + return d def _router_completed(self, response, uaid_data, warning=""): """Called after router has completed successfully""" diff --git a/autopush/web/webpush.py b/autopush/web/webpush.py index 7503ccb4..a2c3b43b 100644 --- a/autopush/web/webpush.py +++ b/autopush/web/webpush.py @@ -15,8 +15,14 @@ from marshmallow_polyfield import PolyField from marshmallow.validate import OneOf from twisted.logger import Logger # noqa -from twisted.internet.defer import Deferred +from twisted.internet.defer import Deferred # noqa +from twisted.internet.defer import maybeDeferred from twisted.internet.threads import deferToThread +from typing import ( # noqa + Any, + Dict, + Optional +) from autopush.crypto_key import CryptoKey from autopush.db import dump_uaid, hasher @@ -388,33 +394,37 @@ class WebPushHandler(BaseWebHandler): cors_response_headers = ("location", "www-authenticate") @threaded_validate(WebPushRequestSchema) - def post(self, *args, **kwargs): + def post(self, + subscription, # type: Dict[str, Any] + notification, # type: WebPushNotification + jwt=None, # type: Optional[Dict[str, str]] + **kwargs # type: Any + ): + # type: (...) -> Deferred # Store Vapid info if present - jwt = self.valid_input.get("jwt") if jwt: self._client_info["jwt_crypto_key"] = jwt["jwt_crypto_key"] for i in jwt["jwt_data"]: self._client_info["jwt_" + i] = jwt["jwt_data"][i] - user_data = self.valid_input["subscription"]["user_data"] + user_data = subscription["user_data"] + self._client_info.update( + message_id=notification.message_id, + uaid=hasher(user_data.get("uaid")), + channel_id=user_data.get("chid"), + router_key=user_data["router_type"], + message_size=len(notification.data or ""), + ttl=notification.ttl, + version=notification.version + ) + router = self.ap_settings.routers[user_data["router_type"]] - notification = self.valid_input["notification"] - self._client_info["message_id"] = notification.message_id - self._client_info["uaid"] = hasher(user_data.get("uaid")) - self._client_info["channel_id"] = user_data.get("chid") - self._client_info["router_key"] = user_data["router_type"] - self._client_info["message_size"] = len(notification.data or "") - self._client_info["ttl"] = notification.ttl - self._client_info["version"] = notification.version self._router_time = time.time() - d = Deferred() - d.addCallback(router.route_notification, user_data) + d = maybeDeferred(router.route_notification, notification, user_data) d.addCallback(self._router_completed, user_data, "") d.addErrback(self._router_fail_err) d.addErrback(self._response_err) - - # Call the prepared router - d.callback(notification) + return d def _router_completed(self, response, uaid_data, warning=""): """Called after router has completed successfully"""