Skip to content
This repository has been archived by the owner on Jul 13, 2023. It is now read-only.

Commit

Permalink
feat: Add multiple cert handlers for APNs
Browse files Browse the repository at this point in the history
This patch updates APNs handlers to accept platform based cert
configurations. See `configs/autopush_shared.ini.sample`. In addition,
this patch clarifies some argument references for routers (e.g. less
than useful `result` is now slightly more descriptive `uaid_data`)

Custom item names have been normalized to match gcm/fcm.
Document errors cleaned up a bit as well.

BREAKING CHANGE: the APNS configuration options have been altered, see
`configs/autopush_shared.ini.sample` for new APNS configuration
settings.

Closes #655
  • Loading branch information
jrconlin committed Sep 14, 2016
1 parent 054a9f1 commit 712e3d3
Show file tree
Hide file tree
Showing 14 changed files with 196 additions and 138 deletions.
20 changes: 11 additions & 9 deletions autopush/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,10 +489,11 @@ def _token_valid(self, result):
d.addErrback(self._uaid_not_found_err)
self._db_error_handling(d)

def _uaid_lookup_results(self, result):
def _uaid_lookup_results(self, uaid_data):
"""Process the result of the AWS UAID lookup"""
# Save the whole record
router_key = self.router_key = result.get("router_type", "simplepush")
router_key = self.router_key = uaid_data.get("router_type",
"simplepush")
self._client_info["router_key"] = router_key

try:
Expand Down Expand Up @@ -558,7 +559,7 @@ def _uaid_lookup_results(self, result):
return

if use_simplepush:
self._route_notification(self.version, result, data)
self._route_notification(self.version, uaid_data, data)
return

# Web Push and bridged messages are encrypted binary blobs. We store
Expand All @@ -568,10 +569,10 @@ def _uaid_lookup_results(self, result):
# Generate a message ID, then route the notification.
d = deferToThread(self.ap_settings.fernet.encrypt, ':'.join([
'm', self.uaid, self.chid]).encode('utf8'))
d.addCallback(self._route_notification, result, data, ttl)
d.addCallback(self._route_notification, uaid_data, data, ttl)
return d

def _route_notification(self, version, result, data, ttl=None):
def _route_notification(self, version, uaid_data, data, ttl=None):
self.version = self._client_info['message_id'] = version
warning = ""
# Clean up the header values (remove padding)
Expand All @@ -587,8 +588,8 @@ def _route_notification(self, version, result, data, ttl=None):
ttl=ttl)

d = Deferred()
d.addCallback(self.router.route_notification, result)
d.addCallback(self._router_completed, result, warning)
d.addCallback(self.router.route_notification, uaid_data)
d.addCallback(self._router_completed, uaid_data, warning)
d.addErrback(self._router_fail_err)
d.addErrback(self._response_err)

Expand Down Expand Up @@ -698,7 +699,7 @@ def post(self, router_type="", router_token="", uaid="", chid=""):
self.app_server_key = params.get("key")
if new_uaid:
d = Deferred()
d.addCallback(router.register, params, router_token,
d.addCallback(router.register, params, router_token=router_token,
uri=self.request.uri)
d.addCallback(self._save_router_data, router_type)
d.addCallback(self._create_endpoint)
Expand Down Expand Up @@ -736,7 +737,8 @@ def put(self, router_type="", router_token="", uaid="", chid=""):

self.add_header("Content-Type", "application/json")
d = Deferred()
d.addCallback(router.register, router_data, uri=self.request.uri)
d.addCallback(router.register, router_data, router_token=router_token,
uri=self.request.uri)
d.addCallback(self._save_router_data, router_type)
d.addCallback(self._success)
d.addErrback(self._router_fail_err)
Expand Down
37 changes: 20 additions & 17 deletions autopush/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,11 @@ def obsolete_args(parser):
parser.add_argument('--max_message_size', type=int, help="OBSOLETE")
parser.add_argument('--s3_bucket', help='OBSOLETE')
parser.add_argument('--senderid_expry', help='OBSOLETE')
# old APNs args
parser.add_argument('--apns_enabled', help="OBSOLETE")
parser.add_argument('--apns_sandbox', help="OBSOLETE")
parser.add_argument('--apns_cert_file', help="OBSOLETE")
parser.add_argument('--apns_key_file', help="OBSOLETE")


def add_external_router_args(parser):
Expand Down Expand Up @@ -187,18 +192,14 @@ def add_external_router_args(parser):
parser.add_argument('--fcm_senderid', help='SenderID for FCM',
type=str, default="")
# Apple Push Notification system (APNs) for iOS
parser.add_argument('--apns_enabled', help="Enable APNS Bridge",
action="store_true", default=False,
env_var="APNS_ENABLED")
label = "APNS Router:"
parser.add_argument('--apns_sandbox', help="%s Use Dev Sandbox" % label,
action="store_true", default=False,
env_var="APNS_SANDBOX")
parser.add_argument('--apns_cert_file',
help="%s Certificate PEM file" % label,
type=str, env_var="APNS_CERT_FILE")
parser.add_argument('--apns_key_file', help="%s Key PEM file" % label,
type=str, env_var="APNS_KEY_FILE")
# credentials consist of JSON struct containing a channel type
# followed by the settings,
# e.g. {'firefox':{'cert': 'path.cert', 'key': 'path.key',
# 'sandbox': false}, ... }
parser.add_argument('--apns_creds', help="JSON dictionary of "
"APNS settings",
type=str, default="",
env_var="APNS_CREDS")
# UDP
parser.add_argument('--wake_timeout',
help="UDP: idle timeout before closing socket",
Expand Down Expand Up @@ -313,12 +314,14 @@ def make_settings(args, **kwargs):
router_conf["simplepush"] = {"idle": args.wake_timeout,
"server": args.wake_server,
"cert": args.wake_pem}
if args.apns_enabled:
if args.apns_creds:
# if you have the critical elements for each external router, create it
if args.apns_cert_file is not None and args.apns_key_file is not None:
router_conf["apns"] = {"sandbox": args.apns_sandbox,
"cert_file": args.apns_cert_file,
"key_file": args.apns_key_file}
try:
router_conf["apns"] = json.loads(args.apns_creds)
except (ValueError, TypeError):
log.critical(format="Invalid JSON specified for APNS config "
"options")
return
if args.gcm_enabled:
# Create a common gcmclient
try:
Expand Down
127 changes: 89 additions & 38 deletions autopush/router/apnsrouter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,78 +27,129 @@ class APNSRouter(object):
255: 'Unknown',
}

def _connect(self):
"""Connect to APNS"""
self.apns = apns.APNs(use_sandbox=self.config.get("sandbox", False),
cert_file=self.config.get("cert_file"),
key_file=self.config.get("key_file"),
enhanced=True)
def _connect(self, cert_info):
"""Connect to APNS
:param cert_info: APNS certificate configuration info
:type cert_info: dict
returns an instance of APNs that can be stored under the proper
channel.
"""
# Do I still need to call this in _error?
return apns.APNs(
use_sandbox=cert_info.get("sandbox", False),
cert_file=cert_info.get("cert"),
key_file=cert_info.get("key"),
enhanced=True)

def __init__(self, ap_settings, router_conf):
"""Create a new APNS router and connect to APNS"""
self.ap_settings = ap_settings
self._base_tags = []
self.config = router_conf
self.default_title = router_conf.get("default_title", "SimplePush")
self.default_body = router_conf.get("default_body", "New Alert")
self._connect()
self.apns = dict()
self._config = router_conf
for channel in self._config.keys():
self.apns[channel] = self._connect(self._config[channel])
self._connect(router_conf)
self.log.debug("Starting APNS router...")
self.ap_settings = ap_settings

def register(self, uaid, router_data, *args, **kwargs):
"""Validate that an APNs instance token is in the ``router_data``"""
def register(self, uaid, router_data, router_token="firefox",
*args, **kwargs):
"""Register an endpoint for APNS, on the `router_token` channel.
This will validate that an APNs instance token is in the
``router_data``,
:param uaid: User Agent Identifier
:type uaid: str
:param router_data: Dict containing router specific configuration info
:type router_data: dict
:param router_token: The channel identifier for cert info lookup
:type router_token: str
returns a modified router_data dict to be stored
in the user agent record.
"""
if router_token not in self.apns:
raise RouterException("Unknown channel specified",
status_code=400,
response_body="Unknown channel")
if not router_data.get("token"):
raise RouterException("No token registered", status_code=500,
response_body="No token registered")
router_data["channel"] = router_token
return router_data

def amend_msg(self, msg, router_data=None):
"""This function is stubbed out for this router"""
return msg

def route_notification(self, notification, uaid_data):
"""Start the APNS notification routing, returns a deferred"""
"""Start the APNS notification routing, returns a deferred
:param notification: Notification data to send
:type notification: dict
:param uaid_data: User Agent specific data
:type uaid_data: dict
"""
router_data = uaid_data["router_data"]
# Kick the entire notification routing off to a thread
return deferToThread(self._route, notification, router_data)

def _route(self, notification, router_data):
"""Blocking APNS call to route the notification"""
token = router_data["token"]
"""Blocking APNS call to route the notification
:param notification: Notification data to send
:type notification: dict
:param router_data: Pre-initialized data for this connection
:type router_data: dict
"""
router_token = router_data["token"]
self._channel = router_data["channel"]
config = self._config[self._channel]
apns_client = self.apns[self._channel]
custom = {
"Chid": notification.channel_id,
"Ver": notification.version,
"chid": notification.channel_id,
"ver": notification.version,
}
if notification.data:
custom["Msg"] = notification.data
custom["Con"] = notification.headers["content-encoding"]
custom["Enc"] = notification.headers["encryption"]
custom["body"] = notification.data
custom["con"] = notification.headers["content-encoding"]
custom["enc"] = notification.headers["encryption"]

if "crypto-key" in notification.headers:
custom["Cryptokey"] = notification.headers["crypto-key"]
custom["cryptokey"] = notification.headers["crypto-key"]
elif "encryption-key" in notification.headers:
custom["Enckey"] = notification.headers["encryption-key"]
custom["enckey"] = notification.headers["encryption-key"]

payload = apns.Payload(alert=router_data.get("title",
self.default_title),
content_available=1,
custom=custom)
payload = apns.Payload(
alert=router_data.get("title", config.get('default_title',
'Mozilla Push')),
content_available=1,
custom=custom)
now = int(time.time())
self.messages[now] = {"token": token, "payload": payload}
# TODO: Add listener for error handling.
self.apns.gateway_server.register_response_listener(self._error)
self.messages[now] = {"token": router_token, "payload": payload}
apns_client.gateway_server.register_response_listener(self._error)
self.ap_settings.metrics.increment(
"updates.client.bridge.apns.attempted",
"updates.client.bridge.apns.%s.attempted" % router_data["channel"],
self._base_tags)

self.apns.gateway_server.send_notification(token, payload, now)
apns_client.gateway_server.send_notification(router_token, payload,
now)

# cleanup sent messages
if self.messages:
for time_sent in self.messages.keys():
if time_sent < now - self.config.get("expry", 10):
if time_sent < now - config.get("expry", 10):
del self.messages[time_sent]
self.ap_settings.metrics.increment(
"updates.client.bridge.apns.succeed",
"updates.client.bridge.apns.%s.succeed" % router_data["channel"],
self._base_tags)
location = "%s/m/%s" % (self.ap_settings.endpoint_url,
notification.version)
Expand All @@ -117,11 +168,11 @@ def _error(self, err):
status=self.errors[err['status']])
if err['status'] in [1, 255]:
self.log.debug("Retrying...")
self._connect()
resend = self.messages.get(err.get('identifier'))
if resend is None:
return
self.apns.gateway_server.send_notification(resend['token'],
resend['payload'],
err['identifier'],
)
apns_client = self.apns[self._channel]
apns_client.gateway_server.send_notification(resend['token'],
resend['payload'],
err['identifier'],
)
3 changes: 2 additions & 1 deletion autopush/router/fcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,9 @@ def amend_msg(self, msg, data=None):
msg["senderid"] = data.get('creds', {}).get('senderID')
return msg

def register(self, uaid, router_data, senderid=None, *args, **kwargs):
def register(self, uaid, router_data, router_token=None, *args, **kwargs):
"""Validate that the FCM Instance Token is in the ``router_data``"""
senderid = router_token
# "token" is the GCM registration id token generated by the client.
if "token" not in router_data:
raise self._error("connect info missing FCM Instance 'token'",
Expand Down
3 changes: 2 additions & 1 deletion autopush/router/gcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def amend_msg(self, msg, data=None):
msg["senderid"] = data.get('creds', {}).get('senderID')
return msg

def register(self, uaid, router_data, senderid=None, *args, **kwargs):
def register(self, uaid, router_data, router_token=None, *args, **kwargs):
"""Validate that the GCM Instance Token is in the ``router_data``"""
# "token" is the GCM registration id token generated by the client.
if "token" not in router_data:
Expand All @@ -56,6 +56,7 @@ def register(self, uaid, router_data, senderid=None, *args, **kwargs):
# be able to match senderID to it's corresponding auth key.
# If the client has an unexpected or invalid SenderID,
# it is impossible for us to reach them.
senderid = router_token
if senderid not in self.senderIDs:
raise self._error("Invalid SenderID", status=410, errno=105,
uri=kwargs.get('uri'),
Expand Down
2 changes: 1 addition & 1 deletion autopush/tests/test_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1830,7 +1830,7 @@ def test_put(self, *args):
def handle_finish(value):
self.reg.write.assert_called_with({})
frouter.register.assert_called_with(
dummy_uaid, data, uri=self.reg.request.uri
dummy_uaid, data, router_token="", uri=self.reg.request.uri
)

self.finish_deferred.addCallback(handle_finish)
Expand Down
16 changes: 10 additions & 6 deletions autopush/tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest
import datetime
import json

from mock import Mock, patch
from moto import mock_dynamodb2
Expand Down Expand Up @@ -203,12 +204,10 @@ def test_skip_logging(self):
class EndpointMainTestCase(unittest.TestCase):
class TestArg:
# important stuff
apns_enabled = True
apns_cert_file = "cert.file"
apns_key_file = "key.file"
apns_creds = json.dumps({"firefox": {"cert": "cert.file",
"key": "key.file"}})
gcm_enabled = True
# less important stuff
apns_sandbox = False
gcm_ttl = 999
gcm_dryrun = False
gcm_collapsekey = "collapse"
Expand Down Expand Up @@ -278,14 +277,19 @@ def test_bad_senderidlist(self):
"--senderid_list='[Invalid'"
], False)

def test_bad_apnsconf(self):
assert endpoint_main([
"--apns_creds='[Invalid'"
], False)

def test_ping_settings(self):
ap = make_settings(self.TestArg)
# verify that the hostname is what we said.
eq_(ap.hostname, self.TestArg.hostname)
# gcm isn't created until later since we may have to pull
# config info from s3
eq_(ap.routers["apns"].apns.cert_file, self.TestArg.apns_cert_file)
eq_(ap.routers["apns"].apns.key_file, self.TestArg.apns_key_file)
eq_(ap.routers["apns"].apns["firefox"].cert_file, "cert.file")
eq_(ap.routers["apns"].apns["firefox"].key_file, "key.file")
eq_(ap.wake_timeout, 10)

def test_bad_senders(self):
Expand Down
Loading

0 comments on commit 712e3d3

Please sign in to comment.