diff --git a/.coveragerc b/.coveragerc index de5541b4..4623533e 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,2 +1,3 @@ [report] omit = *noseplugin* +show_missing = true diff --git a/autopush/exceptions.py b/autopush/exceptions.py index 259445cf..de261b26 100644 --- a/autopush/exceptions.py +++ b/autopush/exceptions.py @@ -7,3 +7,12 @@ class AutopushException(Exception): class InvalidTokenException(Exception): """Invalid URL token Exception""" + + +class InvalidRequest(AutopushException): + """Invalid request exception, may include custom status_code and message + to write for the error""" + def __init__(self, message, status_code=400, errno=None): + super(AutopushException, self).__init__(message) + self.status_code = status_code + self.errno = errno diff --git a/autopush/main.py b/autopush/main.py index 1f45d23a..438d58fd 100644 --- a/autopush/main.py +++ b/autopush/main.py @@ -29,6 +29,7 @@ DefaultResource, StatusResource, ) +from autopush.web.simplepush import SimplePushHandler from autopush.senderids import SenderIDs, SENDERID_EXPRY, DEFAULT_BUCKET @@ -501,8 +502,10 @@ def endpoint_main(sysargs=None, use_files=True): # Endpoint HTTP router site = cyclone.web.Application([ - (r"/push/(?:(v\d+)\/)?([^\/]+)", EndpointHandler, + (r"/push/(?:(?Pv\d+)\/)?(?P[^\/]+)", EndpointHandler, dict(ap_settings=settings)), + (r"/spush/(?:(?Pv\d+)\/)?(?P[^\/]+)", + SimplePushHandler, dict(ap_settings=settings)), (r"/m/([^\/]+)", MessageHandler, dict(ap_settings=settings)), # PUT /register/ => connect info # GET /register/uaid => chid + endpoint diff --git a/autopush/settings.py b/autopush/settings.py index 74a63315..429aede1 100644 --- a/autopush/settings.py +++ b/autopush/settings.py @@ -286,8 +286,15 @@ def update(self, **kwargs): else: setattr(self, key, val) + def make_simplepush_endpoint(self, uaid, chid): + """Create a simplepush endpoint""" + root = self.endpoint_url + "/spush/" + base = (uaid.replace('-', '').decode("hex") + + chid.replace('-', '').decode("hex")) + return root + 'v1/' + self.fernet.encrypt(base).strip('=') + def make_endpoint(self, uaid, chid, key=None): - """Create an v1 or v2 endpoint from the indentifiers. + """Create an v1 or v2 WebPush endpoint from the identifiers. Both endpoints use bytes instead of hex to reduce ID length. v0 is uaid.hex + ':' + chid.hex and is deprecated. diff --git a/autopush/tests/test_integration.py b/autopush/tests/test_integration.py index 0958691e..f4050dcf 100644 --- a/autopush/tests/test_integration.py +++ b/autopush/tests/test_integration.py @@ -324,6 +324,7 @@ def setUp(self): DefaultResource, StatusResource, ) + from autopush.web.simplepush import SimplePushHandler from twisted.web.server import Site router_table = os.environ.get("ROUTER_TABLE", "router_int_test") @@ -374,6 +375,8 @@ def setUp(self): site = cyclone.web.Application([ (r"/push/(v\d+)?/?([^\/]+)", EndpointHandler, dict(ap_settings=settings)), + (r"/spush/(?:(?Pv\d+)\/)?(?P[^\/]+)", + SimplePushHandler, dict(ap_settings=settings)), (r"/m/([^\/]+)", MessageHandler, dict(ap_settings=settings)), # PUT /register/ => connect info # GET /register/uaid => chid + endpoint diff --git a/autopush/tests/test_main.py b/autopush/tests/test_main.py index 89841d6e..77dea8ac 100644 --- a/autopush/tests/test_main.py +++ b/autopush/tests/test_main.py @@ -4,6 +4,7 @@ from mock import Mock, patch from moto import mock_dynamodb2, mock_s3 from nose.tools import eq_ +from twisted.internet.defer import Deferred from twisted.trial import unittest as trialtest from autopush.main import ( @@ -73,13 +74,15 @@ def test_update_rotating_tables(self): settings.message_tables = {} # Get the deferred back + e = Deferred() d = settings.update_rotating_tables() def check_tables(result): eq_(len(settings.message_tables), 1) d.addCallback(check_tables) - return d + d.addBoth(lambda x: e.callback(True)) + return e def test_update_rotating_tables_month_end(self): today = datetime.date.today() @@ -118,13 +121,15 @@ def test_update_not_needed(self): settings.message_tables = {} # Get the deferred back + e = Deferred() d = settings.update_rotating_tables() def check_tables(result): eq_(len(settings.message_tables), 0) d.addCallback(check_tables) - return d + d.addBoth(lambda x: e.callback(True)) + return e class ConnectionMainTestCase(unittest.TestCase): diff --git a/autopush/tests/test_web_base.py b/autopush/tests/test_web_base.py new file mode 100644 index 00000000..086e75e0 --- /dev/null +++ b/autopush/tests/test_web_base.py @@ -0,0 +1,254 @@ +import sys +import uuid + +from cyclone.web import Application +from mock import Mock, patch +from moto import mock_dynamodb2 +from nose.tools import eq_ +from twisted.internet.defer import Deferred +from twisted.logger import Logger +from twisted.python.failure import Failure +from twisted.trial import unittest + +from autopush.db import ( + create_rotating_message_table, + hasher, + ProvisionedThroughputExceededException, +) +from autopush.exceptions import InvalidRequest +from autopush.settings import AutopushSettings + +dummy_request_id = "11111111-1234-1234-1234-567812345678" +dummy_uaid = str(uuid.UUID("abad1dea00000000aabbccdd00000000")) +dummy_chid = str(uuid.UUID("deadbeef00000000decafbad00000000")) +mock_dynamodb2 = mock_dynamodb2() + + +def setUp(): + mock_dynamodb2.start() + create_rotating_message_table() + + +def tearDown(): + mock_dynamodb2.stop() + + +class TestBase(unittest.TestCase): + CORS_METHODS = "POST,PUT" + CORS_HEADERS = ','.join( + ["content-encoding", "encryption", + "crypto-key", "ttl", + "encryption-key", "content-type", + "authorization"] + ) + CORS_RESPONSE_HEADERS = ','.join( + ["location", "www-authenticate"] + ) + + @patch('uuid.uuid4', return_value=uuid.UUID(dummy_request_id)) + def setUp(self, t): + from autopush.web.base import BaseHandler + + settings = AutopushSettings( + hostname="localhost", + statsd_host=None, + ) + + self.request_mock = Mock(body=b'', arguments={}, + headers={"ttl": "0"}, + host='example.com:8080') + + self.base = BaseHandler(Application(), + self.request_mock, + ap_settings=settings) + self.status_mock = self.base.set_status = Mock() + self.write_mock = self.base.write = Mock() + self.base.log = Mock(spec=Logger) + d = self.finish_deferred = Deferred() + self.base.finish = lambda: d.callback(True) + + # Attach some common cors stuff for testing + self.base.cors_methods = "POST,PUT" + self.base.cors_request_headers = ["content-encoding", "encryption", + "crypto-key", "ttl", + "encryption-key", "content-type", + "authorization"] + self.base.cors_response_headers = ["location", "www-authenticate"] + + def test_cors(self): + ch1 = "Access-Control-Allow-Origin" + ch2 = "Access-Control-Allow-Methods" + ch3 = "Access-Control-Allow-Headers" + ch4 = "Access-Control-Expose-Headers" + base = self.base + base.ap_settings.cors = False + assert base._headers.get(ch1) != "*" + assert base._headers.get(ch2) != self.CORS_METHODS + assert base._headers.get(ch3) != self.CORS_HEADERS + assert base._headers.get(ch4) != self.CORS_RESPONSE_HEADERS + + base.clear_header(ch1) + base.clear_header(ch2) + base.ap_settings.cors = True + self.base.prepare() + eq_(base._headers[ch1], "*") + eq_(base._headers[ch2], self.CORS_METHODS) + eq_(base._headers[ch3], self.CORS_HEADERS) + eq_(base._headers[ch4], self.CORS_RESPONSE_HEADERS) + + def test_cors_head(self): + ch1 = "Access-Control-Allow-Origin" + ch2 = "Access-Control-Allow-Methods" + ch3 = "Access-Control-Allow-Headers" + ch4 = "Access-Control-Expose-Headers" + base = self.base + base.ap_settings.cors = True + base.prepare() + base.head(None) + eq_(base._headers[ch1], "*") + eq_(base._headers[ch2], self.CORS_METHODS) + eq_(base._headers[ch3], self.CORS_HEADERS) + eq_(base._headers[ch4], self.CORS_RESPONSE_HEADERS) + + def test_cors_options(self): + ch1 = "Access-Control-Allow-Origin" + ch2 = "Access-Control-Allow-Methods" + ch3 = "Access-Control-Allow-Headers" + ch4 = "Access-Control-Expose-Headers" + base = self.base + base.ap_settings.cors = True + base.prepare() + base.options(None) + eq_(base._headers[ch1], "*") + eq_(base._headers[ch2], self.CORS_METHODS) + eq_(base._headers[ch3], self.CORS_HEADERS) + eq_(base._headers[ch4], self.CORS_RESPONSE_HEADERS) + + def test_write_error(self): + """ Write error is triggered by sending the app a request + with an invalid method (e.g. "put" instead of "PUT"). + This is not code that is triggered within normal flow, but + by the cyclone wrapper. + """ + class testX(Exception): + pass + + try: + raise testX() + except: + exc_info = sys.exc_info() + + self.base.write_error(999, exc_info=exc_info) + self.status_mock.assert_called_with(999) + eq_(self.base.log.failure.called, True) + + def test_write_error_no_exc(self): + """ Write error is triggered by sending the app a request + with an invalid method (e.g. "put" instead of "PUT"). + This is not code that is triggered within normal flow, but + by the cyclone wrapper. + """ + self.base.write_error(999) + self.status_mock.assert_called_with(999) + eq_(self.base.log.failure.called, True) + + def test_init_info(self): + h = self.request_mock.headers + h["user-agent"] = "myself" + self.request_mock.remote_ip = "local1" + self.request_mock.headers["ttl"] = "0" + self.request_mock.headers["authorization"] = "bearer token fred" + d = self.base._init_info() + eq_(d["request_id"], dummy_request_id) + eq_(d["user_agent"], "myself") + eq_(d["remote_ip"], "local1") + eq_(d["message_ttl"], "0") + eq_(d["authorization"], "bearer token fred") + self.request_mock.headers["x-forwarded-for"] = "local2" + d = self.base._init_info() + eq_(d["remote_ip"], "local2") + + def test_properties(self): + eq_(self.base.uaid, "") + eq_(self.base.chid, "") + self.base.uaid = dummy_uaid + eq_(self.base._client_info["uaid_hash"], hasher(dummy_uaid)) + self.base.chid = dummy_chid + eq_(self.base._client_info['channelID'], dummy_chid) + + def test_write_response(self): + self.base._write_response(400, 103, message="Fail", + headers=dict(Location="http://a.com/")) + self.status_mock.assert_called_with(400) + + def test_validation_error(self): + try: + raise InvalidRequest("oops", errno=110) + except: + fail = Failure() + self.base._validation_err(fail) + self.status_mock.assert_called_with(400) + + def test_response_err(self): + try: + raise Exception("oops") + except: + fail = Failure() + self.base._response_err(fail) + self.status_mock.assert_called_with(500) + + def test_overload_err(self): + try: + raise ProvisionedThroughputExceededException("error", None, None) + except: + fail = Failure() + self.base._overload_err(fail) + self.status_mock.assert_called_with(503) + + def test_router_response(self): + from autopush.router.interface import RouterResponse + response = RouterResponse(headers=dict(Location="http://a.com/")) + self.base._router_response(response) + self.status_mock.assert_called_with(200) + + def test_router_response_client_error(self): + from autopush.router.interface import RouterResponse + response = RouterResponse(headers=dict(Location="http://a.com/"), + status_code=400) + self.base._router_response(response) + self.status_mock.assert_called_with(400) + + def test_router_fail_err(self): + from autopush.router.interface import RouterException + + try: + raise RouterException("error") + except: + fail = Failure() + self.base._router_fail_err(fail) + self.status_mock.assert_called_with(500) + + def test_router_fail_err_200_status(self): + from autopush.router.interface import RouterException + + try: + raise RouterException("Abort Ok", status_code=200) + except: + fail = Failure() + self.base._router_fail_err(fail) + self.status_mock.assert_called_with(200) + + def test_router_fail_err_400_status(self): + from autopush.router.interface import RouterException + + try: + raise RouterException("Abort Ok", status_code=400) + except: + fail = Failure() + self.base._router_fail_err(fail) + self.status_mock.assert_called_with(400) + + def test_write_validation_err(self): + errors = dict(data="Value too large") + self.base._write_validation_err(errors) + self.status_mock.assert_called_with(400) diff --git a/autopush/tests/test_web_validation.py b/autopush/tests/test_web_validation.py new file mode 100644 index 00000000..168c2172 --- /dev/null +++ b/autopush/tests/test_web_validation.py @@ -0,0 +1,263 @@ +import uuid + +from boto.dynamodb2.exceptions import ( + ItemNotFound, +) +from marshmallow import Schema, fields +from mock import Mock +from nose.tools import eq_, ok_, assert_raises +from twisted.internet.defer import Deferred +from twisted.trial import unittest + +from autopush.exceptions import ( + InvalidRequest, + InvalidTokenException, +) + + +dummy_uaid = str(uuid.UUID("abad1dea00000000aabbccdd00000000")) +dummy_chid = str(uuid.UUID("deadbeef00000000decafbad00000000")) +dummy_token = dummy_uaid + ":" + dummy_chid + + +class InvalidSchema(Schema): + afield = fields.Integer(required=True) + + +class TestThreadedValidate(unittest.TestCase): + def _makeFUT(self, schema): + from autopush.web.validation import ThreadedValidate + return ThreadedValidate(schema) + + def _makeBasicSchema(self): + + class Basic(Schema): + pass + + return Basic() + + def _makeDummyRequest(self, method="GET", uri="/", **kwargs): + from cyclone.httpserver import HTTPRequest + req = HTTPRequest(method, uri, **kwargs) + req.connection = Mock() + return req + + def _makeReqHandler(self, request): + self._mock_errors = Mock() + from cyclone.web import RequestHandler + + class ValidateRequest(RequestHandler): + def _write_validation_err(rh, errors): + self._mock_errors(errors) + + # Minimal mocks needed for a cyclone app to work + app = Mock() + app.ui_modules = dict() + app.ui_methods = dict() + vr = ValidateRequest(app, request) + vr.ap_settings = Mock() + return vr + + def _makeFull(self, schema=None): + req = self._makeDummyRequest() + if not schema: + schema = self._makeBasicSchema() + tv = self._makeFUT(schema) + rh = self._makeReqHandler(req) + + return tv, rh + + def test_validate_load(self): + tv, rh = self._makeFull() + d, errors = tv._validate_request(rh) + eq_(errors, {}) + eq_(d, {}) + + def test_validate_invalid_schema(self): + tv, rh = self._makeFull(schema=InvalidSchema()) + d, errors = tv._validate_request(rh) + ok_("afield" in errors) + eq_(d, {}) + + def test_call_func_no_error(self): + mock_func = Mock() + tv, rh = self._makeFull() + result = tv._validate_request(rh) + tv._call_func(result, mock_func, rh) + mock_func.assert_called() + + def test_call_func_error(self): + mock_func = Mock() + tv, rh = self._makeFull(schema=InvalidSchema()) + result = tv._validate_request(rh) + tv._call_func(result, mock_func, rh) + self._mock_errors.assert_called() + eq_(len(mock_func.mock_calls), 0) + + def test_decorator(self): + from cyclone.web import RequestHandler + from autopush.web.validation import threaded_validate + schema = self._makeBasicSchema() + + class AHandler(RequestHandler): + @threaded_validate(schema) + def get(self): + self.write("done") + self.finish() + + req = self._makeDummyRequest() + app = Mock() + app.ui_modules = dict() + app.ui_methods = dict() + vr = AHandler(app, req) + d = Deferred() + vr.finish = lambda: d.callback(True) + vr.write = Mock() + vr._overload_err = Mock() + vr._validation_err = Mock() + vr._response_err = Mock() + vr.ap_settings = Mock() + + e = Deferred() + + def check_result(result): + vr.write.assert_called_with("done") + e.callback(True) + + d.addCallback(check_result) + + vr.get() + return e + + +class TestSimplePushRequestSchema(unittest.TestCase): + def _makeFUT(self): + from autopush.web.validation import SimplePushRequestSchema + schema = SimplePushRequestSchema() + schema.context["settings"] = Mock() + schema.context["log"] = Mock() + return schema + + def _make_test_data(self, headers=None, body="", path_args=None, + path_kwargs=None, arguments=None): + return dict( + headers=headers or {}, + body=body, + path_args=path_args or [], + path_kwargs=path_kwargs or {}, + arguments=arguments or {}, + ) + + def test_valid_data(self): + schema = self._makeFUT() + schema.context["settings"].parse_endpoint.return_value = dict( + uaid=dummy_uaid, + chid=dummy_chid, + public_key="", + ) + schema.context["settings"].router.get_uaid.return_value = dict( + router_type="simplepush", + ) + result, errors = schema.load(self._make_test_data()) + eq_(errors, {}) + eq_(result["data"], None) + eq_(str(result["subscription"]["uaid"]), dummy_uaid) + + def test_valid_data_in_body(self): + schema = self._makeFUT() + schema.context["settings"].parse_endpoint.return_value = dict( + uaid=dummy_uaid, + chid=dummy_chid, + public_key="", + ) + schema.context["settings"].router.get_uaid.return_value = dict( + router_type="simplepush", + ) + result, errors = schema.load( + self._make_test_data(body="version=&data=asdfasdf") + ) + eq_(errors, {}) + eq_(result["data"], "asdfasdf") + eq_(str(result["subscription"]["uaid"]), dummy_uaid) + + def test_valid_version(self): + schema = self._makeFUT() + schema.context["settings"].parse_endpoint.return_value = dict( + uaid=dummy_uaid, + chid=dummy_chid, + public_key="", + ) + schema.context["settings"].router.get_uaid.return_value = dict( + router_type="simplepush", + ) + result, errors = schema.load( + self._make_test_data(body="version=3&data=asdfasdf") + ) + eq_(errors, {}) + eq_(result["data"], "asdfasdf") + eq_(result["version"], 3) + eq_(str(result["subscription"]["uaid"]), dummy_uaid) + + def test_invalid_router_type(self): + schema = self._makeFUT() + schema.context["settings"].parse_endpoint.return_value = dict( + uaid=dummy_uaid, + chid=dummy_chid, + public_key="", + ) + schema.context["settings"].router.get_uaid.return_value = dict( + router_type="webpush", + ) + + with assert_raises(InvalidRequest) as cm: + schema.load(self._make_test_data()) + + eq_(cm.exception.errno, 108) + + def test_invalid_uaid_not_found(self): + schema = self._makeFUT() + schema.context["settings"].parse_endpoint.return_value = dict( + uaid=dummy_uaid, + chid=dummy_chid, + public_key="", + ) + + def throw_item(*args, **kwargs): + raise ItemNotFound("Not found") + + schema.context["settings"].router.get_uaid.side_effect = throw_item + + with assert_raises(InvalidRequest) as cm: + schema.load(self._make_test_data()) + + eq_(cm.exception.errno, 103) + + def test_invalid_token(self): + schema = self._makeFUT() + + def throw_item(*args, **kwargs): + raise InvalidTokenException("Not found") + + schema.context["settings"].parse_endpoint.side_effect = throw_item + + with assert_raises(InvalidRequest) as cm: + schema.load(self._make_test_data()) + + eq_(cm.exception.errno, 102) + + def test_invalid_data_size(self): + schema = self._makeFUT() + schema.context["settings"].parse_endpoint.return_value = dict( + uaid=dummy_uaid, + chid=dummy_chid, + public_key="", + ) + schema.context["settings"].router.get_uaid.return_value = dict( + router_type="simplepush", + ) + schema.context["settings"].max_data = 1 + + with assert_raises(InvalidRequest) as cm: + schema.load(self._make_test_data(body="version=&data=asdfasdf")) + + eq_(cm.exception.errno, 104) diff --git a/autopush/web/__init__.py b/autopush/web/__init__.py new file mode 100644 index 00000000..792d6005 --- /dev/null +++ b/autopush/web/__init__.py @@ -0,0 +1 @@ +# diff --git a/autopush/web/base.py b/autopush/web/base.py new file mode 100644 index 00000000..60c0873f --- /dev/null +++ b/autopush/web/base.py @@ -0,0 +1,234 @@ +import json +import time +import uuid +from collections import namedtuple + +import cyclone.web +from boto.dynamodb2.exceptions import ( + ProvisionedThroughputExceededException, +) +from twisted.logger import Logger +from twisted.python import failure + +from autopush.db import ( + hasher, + normalize_id, +) +from autopush.exceptions import InvalidRequest +from autopush.router.interface import RouterException + +status_codes = { + 200: "OK", + 201: "Created", + 202: "Accepted", + 400: "Bad Request", + 401: "Unauthorized", + 404: "Not Found", + 413: "Payload Too Large", + 500: "Internal Server Error", + 503: "Service Unavailable", +} + + +class Notification(namedtuple("Notification", + "version data channel_id headers ttl")): + """Parsed notification from the request""" + + +class VapidAuthException(Exception): + """Exception if the VAPID Auth token fails""" + pass + + +class BaseHandler(cyclone.web.RequestHandler): + """Common overrides for Push web API's""" + cors_methods = "" + cors_request_headers = [] + cors_response_headers = [] + + log = Logger() + + ############################################################# + # Cyclone API Methods + ############################################################# + def initialize(self, ap_settings): + """Setup basic aliases and attributes""" + self.uaid_hash = "" + self._uaid = "" + self._chid = "" + self.start_time = time.time() + self.ap_settings = ap_settings + self.metrics = ap_settings.metrics + self.request_id = str(uuid.uuid4()) + self._client_info = self._init_info() + + def prepare(self): + """Common request preparation""" + if self.ap_settings.cors: + self.set_header("Access-Control-Allow-Origin", "*") + self.set_header("Access-Control-Allow-Methods", + self.cors_methods) + self.set_header("Access-Control-Allow-Headers", + ",".join(self.cors_request_headers)) + self.set_header("Access-Control-Expose-Headers", + ",".join(self.cors_response_headers)) + + def write_error(self, code, **kwargs): + """Write the error (otherwise unhandled exception when dealing with + unknown method specifications.) + + This is a Cyclone API Override. + + """ + self.set_status(code) + if "exc_info" in kwargs: + fmt = kwargs.get("format", "Exception") + self.log.failure( + format=fmt, + failure=failure.Failure(*kwargs["exc_info"]), + **self._client_info) + else: + self.log.failure("Error in handler: %s" % code, + **self._client_info) + self.finish() + + ############################################################# + # Cyclone HTTP Methods + ############################################################# + def options(self, *args): + """HTTP OPTIONS Handler""" + + def head(self, *args): + """HTTP HEAD Handler""" + + ############################################################# + # Utility Methods + ############################################################# + @property + def uaid(self): + """Return the UAID that was set""" + return self._uaid + + @uaid.setter + def uaid(self, value): + """Set the UAID and update the uaid hash""" + self._uaid = value + self.uaid_hash = hasher(value) + self._client_info["uaid_hash"] = self.uaid_hash + + @property + def chid(self): + """Return the ChannelID""" + return self._chid + + @chid.setter + def chid(self, value): + """Set the ChannelID and record to _client_info""" + self._chid = normalize_id(value) + self._client_info["channelID"] = self._chid + + def _init_info(self): + """Returns a dict of additional client data""" + return { + "request_id": self.request_id, + "user_agent": self.request.headers.get("user-agent", ""), + "remote_ip": self.request.headers.get("x-forwarded-for", + self.request.remote_ip), + "authorization": self.request.headers.get("authorization", ""), + "message_ttl": self.request.headers.get("ttl", ""), + } + + ############################################################# + # Error Callbacks + ############################################################# + def _write_response(self, status_code, errno, message=None, headers=None): + """Writes out a full JSON error and sets the appropriate status""" + self.set_status(status_code) + error_data = dict( + code=status_code, + errno=errno, + error=status_codes.get(status_code, "") + ) + if message: + error_data["message"] = message + self.write(json.dumps(error_data)) + self.set_header("Content-Type", "application/json") + if headers: + for header in headers.keys(): + self.set_header(header, headers.get(header)) + self.finish() + + def _validation_err(self, fail): + """errBack for validation errors""" + fail.trap(InvalidRequest) + exc = fail.value + self.log.info(format="Request validation error", + status_code=exc.status_code, + errno=exc.errno) + self._write_response(exc.status_code, exc.errno) + + def _response_err(self, fail): + """errBack for all exceptions that should be logged + + This traps all exceptions to prevent any further callbacks from + running. + + """ + fmt = fail.value.message or 'Exception' + self.log.failure(format=fmt, failure=fail, + status_code=500, errno=999, **self._client_info) + self._write_response(500, 999) + + def _overload_err(self, fail): + """errBack for throughput provisioned exceptions""" + fail.trap(ProvisionedThroughputExceededException) + self.log.info(format="Throughput Exceeded", status_code=503, + errno=201, **self._client_info) + self._write_response(503, 201) + + def _router_response(self, response): + for name, val in response.headers.items(): + self.set_header(name, val) + + if 200 <= response.status_code < 300: + self.set_status(response.status_code) + self.write(response.response_body) + self.finish() + else: + return self._write_response( + response.status_code, + errno=response.errno or 999, + message=response.response_body) + + def _router_fail_err(self, fail): + """errBack for router failures""" + fail.trap(RouterException) + exc = fail.value + if exc.log_exception or exc.status_code >= 500: + fmt = fail.value.message or 'Exception' + self.log.failure(format=fmt, + failure=fail, status_code=exc.status_code, + errno=exc.errno or "", + **self._client_info) # pragma nocover + if 200 <= exc.status_code < 300: + self.log.info(format="Success", status_code=exc.status_code, + logged_status=exc.logged_status or "", + **self._client_info) + elif 400 <= exc.status_code < 500: + self.log.info(format="Client error", + status_code=exc.status_code, + logged_status=exc.logged_status or "", + errno=exc.errno or "", + **self._client_info) + self._router_response(exc) + + def _write_validation_err(self, errors): + """Writes a set of validation errors out with details about what + went wrong""" + self.set_status(400) + error_data = dict( + code=400, + errors=errors + ) + self.write(json.dumps(error_data)) + self.finish() diff --git a/autopush/web/simplepush.py b/autopush/web/simplepush.py new file mode 100644 index 00000000..a4923cb4 --- /dev/null +++ b/autopush/web/simplepush.py @@ -0,0 +1,53 @@ +import time + +from twisted.internet.defer import Deferred + +from autopush.web.base import ( + BaseHandler, + Notification, +) +from autopush.web.validation import ( + threaded_validate, + SimplePushRequestSchema, +) + + +class SimplePushHandler(BaseHandler): + cors_methods = "PUT" + + @threaded_validate(SimplePushRequestSchema()) + def put(self, api_ver="v1", token=None): + sub = self.valid_input["subscription"] + user_data = sub["user_data"] + router = self.ap_settings.routers[user_data["router_type"]] + self._client_info["message_id"] = self.valid_input["version"] + + notification = Notification( + version=self.valid_input["version"], + data=self.valid_input["data"], + channel_id=str(sub["chid"]), + headers=self.request.headers, + ttl=None) + + d = Deferred() + d.addCallback(router.route_notification, user_data) + d.addCallback(self._router_completed, user_data, "") + d.addErrback(self._router_fail_err) + d.addErrback(self._response_err) + + # Call the prepared router + d.callback(notification) + + def _router_completed(self, response, uaid_data, warning=""): + """Called after router has completed successfully""" + if response.status_code == 200 or response.logged_status == 200: + self.log.info(format="Successful delivery", + **self._client_info) + elif response.status_code == 202 or response.logged_status == 202: + self.log.info(format="Router miss, message stored.", + **self._client_info) + time_diff = time.time() - self.start_time + self.metrics.timing("updates.handled", duration=time_diff) + response.response_body = ( + response.response_body + " " + warning).strip() + self._router_response(response) diff --git a/autopush/web/validation.py b/autopush/web/validation.py new file mode 100644 index 00000000..3f66de06 --- /dev/null +++ b/autopush/web/validation.py @@ -0,0 +1,166 @@ +"""Validation handler and Schemas""" +import time +import urlparse +from functools import wraps + +from boto.dynamodb2.exceptions import ( + ItemNotFound, +) +from marshmallow import ( + Schema, + fields, + pre_load, + validates, + validates_schema, +) +from twisted.internet.threads import deferToThread +from twisted.logger import Logger + +from autopush.exceptions import ( + InvalidRequest, + InvalidTokenException, +) + + +class ThreadedValidate(object): + """A cyclone request validation decorator + + Exposed as a classmethod for running a marshmallow-based validation schema + in a separate thread for a cyclone request handler. + + """ + log = Logger() + + def __init__(self, schema): + self.schema = schema + + def _validate_request(self, request_handler): + """Validates a schema_class against a cyclone request""" + data = { + "headers": request_handler.request.headers, + "body": request_handler.request.body, + "path_args": request_handler.path_args, + "path_kwargs": request_handler.path_kwargs, + "arguments": request_handler.request.arguments, + } + self.schema.context["settings"] = request_handler.ap_settings + self.schema.context["log"] = self.log + return self.schema.load(data) + + def _call_func(self, result, func, request_handler, *args, **kwargs): + output, errors = result + if errors: + return request_handler._write_validation_err(errors) + request_handler.valid_input = output + return func(request_handler, *args, **kwargs) + + def _decorator(self, func): + @wraps(func) + def wrapper(request_handler, *args, **kwargs): + # Wrap the handler in @cyclone.web.synchronous + request_handler._auto_finish = False + + d = deferToThread(self._validate_request, request_handler) + d.addErrback(request_handler._overload_err) + d.addErrback(request_handler._validation_err) + d.addErrback(request_handler._response_err) + d.addCallback(self._call_func, func, request_handler, *args, + **kwargs) + return wrapper + + @classmethod + def validate(cls, schema): + """Validate a request schema in a separate thread before calling the + request handler + + An alias `threaded_validate` should be used from this module. + + Using `cyclone.web.asynchronous` is not needed as this function + will attach equivilant functionality to the method handler. Calling + `self.finish()` is needed on decorated handlers. + + .. code-block:: + + class MyHandler(cyclone.web.RequestHandler): + @threaded_validate(MySchema()) + def post(self): + ... + + """ + return cls(schema)._decorator + + +# Alias to the validation classmethod decorator +threaded_validate = ThreadedValidate.validate + + +class SimplePushSubscriptionSchema(Schema): + uaid = fields.UUID(required=True) + chid = fields.UUID(required=True) + + @pre_load + def extract_subscription(self, d): + try: + result = self.context["settings"].parse_endpoint( + token=d["token"], + version=d["api_ver"], + ) + except InvalidTokenException: + raise InvalidRequest("invalid token", errno=102) + return result + + @validates_schema + def validate_uaid_chid(self, d): + try: + result = self.context["settings"].router.get_uaid(d["uaid"].hex) + except ItemNotFound: + raise InvalidRequest("UAID not found", status_code=410, errno=103) + + if result.get("router_type") != "simplepush": + raise InvalidRequest("Wrong URL for user", errno=108) + + # Propagate the looked up user data back out + d["user_data"] = result + + +class SimplePushRequestSchema(Schema): + subscription = fields.Nested(SimplePushSubscriptionSchema, + load_from="token_info") + version = fields.Integer(missing=time.time) + data = fields.String(missing=None) + + @validates('data') + def validate_data(self, value): + max_data = self.context["settings"].max_data + if value and len(value) > max_data: + raise InvalidRequest( + "Data payload must be smaller than {}".format(max_data), + errno=104, + ) + + @pre_load + def token_prep(self, d): + d["token_info"] = dict( + api_ver=d["path_kwargs"].get("api_ver"), + token=d["path_kwargs"].get("token"), + ) + return d + + @pre_load + def extract_fields(self, d): + body_string = d["body"] + version = data = None + if len(body_string) > 0: + body_args = urlparse.parse_qs(body_string, keep_blank_values=True) + version = body_args.get("version") + data = body_args.get("data") + else: + version = d["arguments"].get("version") + data = d["arguments"].get("data") + version = version[0] if version is not None else version + data = data[0] if data is not None else data + if version and version >= "1": + d["version"] = version + if data: + d["data"] = data + return d diff --git a/autopush/web/webpush.py b/autopush/web/webpush.py new file mode 100644 index 00000000..792d6005 --- /dev/null +++ b/autopush/web/webpush.py @@ -0,0 +1 @@ +# diff --git a/autopush/websocket.py b/autopush/websocket.py index 080bc4e8..511ad7d0 100644 --- a/autopush/websocket.py +++ b/autopush/websocket.py @@ -69,6 +69,9 @@ from autopush.noseplugin import track_object +USER_RECORD_VERSION = 1 + + def extract_code(data): """Extracts and converts a code key if found in data dict""" code = data.get("code", None) @@ -647,6 +650,10 @@ def _register_user(self, existing_user=True): # users if not existing_user: user_item["last_connect"] = generate_last_connect() + + # New users get a record_version so we can track changes that + # may require old user records to be expired on the fly + user_item["record_version"] = USER_RECORD_VERSION if self.ps.use_webpush: user_item["current_month"] = self.ps.message_month @@ -1009,8 +1016,12 @@ def process_register(self, data): return self.bad_message("register", "Invalid UUID specified") self.transport.pauseProducing() - d = self.deferToThread(self.ap_settings.make_endpoint, self.ps.uaid, - chid, data.get("key")) + if self.ps.use_webpush: + d = self.deferToThread(self.ap_settings.make_endpoint, + self.ps.uaid, chid, data.get("key")) + else: + d = self.deferToThread(self.ap_settings.make_simplepush_endpoint, + self.ps.uaid, chid) d.addCallback(self.finish_register, chid) d.addErrback(self.trap_cancel) d.addErrback(self.error_register) diff --git a/pypy-requirements.txt b/pypy-requirements.txt index 377431c5..b12939a5 100644 --- a/pypy-requirements.txt +++ b/pypy-requirements.txt @@ -31,6 +31,7 @@ idna==2.1 ipaddress==1.0.16 itsdangerous==0.24 jmespath==0.9.0 +marshmallow==2.7.3 mccabe==0.4.0 pbr==1.9.1 pluggy==0.3.1 diff --git a/requirements.txt b/requirements.txt index f5db4457..459da2c2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -32,6 +32,7 @@ idna==2.1 ipaddress==1.0.16 itsdangerous==0.24 jmespath==0.9.0 +marshmallow==2.7.3 mccabe==0.4.0 pbr==1.9.1 pluggy==0.3.1 diff --git a/test-requirements.txt b/test-requirements.txt index b2a64516..7f0f4d88 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -40,6 +40,7 @@ idna==2.1 ipaddress==1.0.16 itsdangerous==0.24 jmespath==0.9.0 +marshmallow==2.7.3 mccabe==0.4.0 pbr==1.9.1 pluggy==0.3.1