Skip to content
This repository has been archived by the owner on Jul 13, 2023. It is now read-only.

Commit

Permalink
feat: add ADM support
Browse files Browse the repository at this point in the history
Closes #1275
  • Loading branch information
jrconlin committed Aug 29, 2018
1 parent 6e0452c commit 0d84de6
Show file tree
Hide file tree
Showing 12 changed files with 667 additions and 5 deletions.
9 changes: 8 additions & 1 deletion autopush/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,6 @@ def from_argparse(cls, ns, **kwargs):
"max_data": ns.max_data,
"collapsekey": ns.gcm_collapsekey,
"senderIDs": sender_ids}

client_certs = None
# endpoint only
if getattr(ns, 'client_certs', None):
Expand Down Expand Up @@ -296,6 +295,14 @@ def from_argparse(cls, ns, **kwargs):
"auth": ns.fcm_auth,
"senderid": ns.fcm_senderid}

if ns.adm_creds:
# Create a common admclient
try:
router_conf["adm"] = json.loads(ns.adm_creds)
except (ValueError, TypeError):
raise InvalidConfig(
"Invalid JSON specified for ADM config options")

ami_id = None
# Not a fan of double negatives, but this makes more
# understandable args
Expand Down
6 changes: 6 additions & 0 deletions autopush/main_argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,12 @@ def _add_external_router_args(parser):
"APNS settings",
type=str, default="",
env_var="APNS_CREDS")
# Amazon Device Messaging client credentials
parser.add_argument('--adm_creds', help="JSON dictionary of "
"Amazon Device Message "
"credentials",
type=str, default="",
env_var="ADM_CREDS")


def parse_connection(config_files, args):
Expand Down
6 changes: 5 additions & 1 deletion autopush/router/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
from autopush.router.interface import IRouter # noqa
from autopush.router.webpush import WebPushRouter
from autopush.router.fcm import FCMRouter
from autopush.router.adm import ADMRouter

__all__ = ["APNSRouter", "FCMRouter", "GCMRouter", "WebPushRouter"]
__all__ = ["APNSRouter", "FCMRouter", "GCMRouter", "WebPushRouter",
"ADMRouter"]


def routers_from_config(conf, db, agent):
Expand All @@ -30,4 +32,6 @@ def routers_from_config(conf, db, agent):
routers["apns"] = APNSRouter(conf, router_conf["apns"], db.metrics)
if 'gcm' in router_conf:
routers["gcm"] = GCMRouter(conf, router_conf["gcm"], db.metrics)
if 'adm' in router_conf:
routers["adm"] = ADMRouter(conf, router_conf["adm"], db.metrics)
return routers
220 changes: 220 additions & 0 deletions autopush/router/adm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
"""ADM Router"""
import time
import requests

from typing import Any # noqa

from requests.exceptions import ConnectionError, Timeout
from twisted.internet.threads import deferToThread
from twisted.logger import Logger

from autopush.exceptions import RouterException
from autopush.metrics import make_tags
from autopush.router.interface import RouterResponse
from autopush.types import JSONDict # noqa


class ADMAuthError(Exception):
pass


class ADMClient(object):
def __init__(self,
credentials=None,
logger=None,
metrics=None,
endpoint="api.amazon.com",
timeout=2,
**options
):

self._client_id = credentials["client_id"]
self._client_secret = credentials["client_secret"]
self._token_exp = 0
self._auth_token = None
self._aws_host = endpoint
self._logger = logger
self._metrics = metrics
self._request = requests
self._timeout = timeout

def refresh_key(self):
url = "https://{}/auth/O2/token".format(self._aws_host)
if self._auth_token is None or self._token_exp < time.time():
body = dict(
grant_type="client_credentials",
scope="messaging:push",
client_id=self._client_id,
client_secret=self._client_secret
)
headers = {
"content-type": "application/x-www-form-urlencoded"
}
resp = self._request.post(url, data=body, headers=headers,
timeout=self._timeout)
if resp.status_code != 200:
self._logger.error("Could not get ADM Auth token {}".format(
resp.text
))
raise ADMAuthError("Could not fetch auth token")
reply = resp.json()
self._auth_token = reply['access_token']
self._token_exp = time.time() + reply.get('expires_in', 0)

def send(self, reg_id, payload, ttl=None, collapseKey=None):
self.refresh_key()
headers = {
"Authorization": "Bearer {}".format(self._auth_token),
"Content-Type": "application/json",
"X-Amzn-Type-Version":
"[email protected]",
"X-Amzn-Accept-Type":
"[email protected]",
"Accept": "application/json",
}
data = {}
if ttl:
data["expiresAfter"] = ttl
if collapseKey:
data["consolidationKey"] = collapseKey
data["data"] = payload
url = ("https://api.amazon.com/messaging/registrations"
"/{}/messages".format(reg_id))
resp = self._request.post(
url,
json=data,
headers=headers,
timeout=self._timeout,
)
# in fine tradition, the response message can sometimes be less than
# helpful. Still, good idea to include it anyway.
if resp.status_code != 200:
self._logger.error("Could not send ADM message: " + resp.text)
raise RouterException(resp.content)


class ADMRouter(object):
"""Amazon Device Messaging Router Implementation"""
log = Logger()
dryRun = 0
collapseKey = None
MAX_TTL = 2419200

def __init__(self, conf, router_conf, metrics):
"""Create a new ADM router and connect to ADM"""
self.conf = conf
self.router_conf = router_conf
self.metrics = metrics
self.min_ttl = router_conf.get("ttl", 60)
timeout = router_conf.get("timeout", 10)
self.profiles = dict()
for profile in router_conf:
config = router_conf[profile]
if "client_id" not in config or "client_secret" not in config:
raise IOError("Profile info incomplete, missing id or secret")
self.profiles[profile] = ADMClient(
credentials=config,
logger=self.log,
metrics=self.metrics,
timeout=timeout)
self._base_tags = ["platform:adm"]
self.log.debug("Starting ADM router...")

def amend_endpoint_response(self, response, router_data):
# type: (JSONDict, JSONDict) -> None
pass

def register(self, uaid, router_data, app_id, *args, **kwargs):
# type: (str, JSONDict, str, *Any, **Any) -> None
"""Validate that the ADM Registration ID is in the ``router_data``"""
if "token" not in router_data:
raise self._error("connect info missing ADM Instance 'token'",
status=401)
profile_id = app_id
if profile_id not in self.profiles:
raise self._error("Invalid ADM Profile",
status=410, errno=105,
uri=kwargs.get('uri'),
profile_id=profile_id)
# Assign a profile
router_data["creds"] = {"profile": profile_id}

def route_notification(self, notification, uaid_data):
"""Start the ADM notification routing, returns a deferred"""
# Kick the entire notification routing off to a thread
return deferToThread(self._route, notification, uaid_data)

def _route(self, notification, uaid_data):
"""Blocking ADM call to route the notification"""
router_data = uaid_data["router_data"]
# THIS MUST MATCH THE CHANNELID GENERATED BY THE REGISTRATION SERVICE
# Currently this value is in hex form.
data = {"chid": notification.channel_id.hex}
# Payload data is optional. The endpoint handler validates that the
# correct encryption headers are included with the data.
if notification.data:
data['body'] = notification.data
data['con'] = notification.headers['encoding']

if 'encryption' in notification.headers:
data['enc'] = notification.headers.get('encryption')
if 'crypto_key' in notification.headers:
data['cryptokey'] = notification.headers['crypto_key']

# registration_ids are the ADM instance tokens (specified during
# registration.
ttl = min(self.MAX_TTL,
max(notification.ttl or 0, self.min_ttl))

try:
adm = self.profiles[router_data['creds']['profile']]
adm.send(
reg_id=router_data.get("token"),
payload=data,
collapseKey=notification.topic,
ttl=ttl
)
except RouterException:
raise # pragma nocover
except Timeout as e:
self.log.warn("ADM Timeout: %s" % e)
self.metrics.increment("notification.bridge.error",
tags=make_tags(
self._base_tags,
reason="timeout"))
raise RouterException("Server error", status_code=502,
errno=902,
log_exception=False)
except ConnectionError as e:
self.log.warn("ADM Unavailable: %s" % e)
self.metrics.increment("notification.bridge.error",
tags=make_tags(
self._base_tags,
reason="connection_unavailable"))
raise RouterException("Server error", status_code=502,
errno=902,
log_exception=False)
except ADMAuthError as e:
self.log.error("ADM unable to authorize: %s" % e)
self.metrics.increment("notification.bridge.error",
tags=make_tags(
self._base_tags,
reason="auth failure"
))
raise RouterException("Server error", status_code=500,
errno=902,
log_exception=False)
except Exception as e:
self.log.error("Unhandled exception in ADM Routing: %s" % e)
raise RouterException("Server error", status_code=500)
location = "%s/m/%s" % (self.conf.endpoint_url, notification.version)
return RouterResponse(status_code=201, response_body="",
headers={"TTL": ttl,
"Location": location},
logged_status=200)

def _error(self, err, status, **kwargs):
"""Error handler that raises the RouterException"""
self.log.debug(err, **kwargs)
return RouterException(err, status_code=status, response_body=err,
**kwargs)
Loading

0 comments on commit 0d84de6

Please sign in to comment.