Skip to content
This repository has been archived by the owner on Jul 13, 2023. It is now read-only.

Commit

Permalink
feature: allow channels to register with public key
Browse files Browse the repository at this point in the history
Channels can include a public key when registering a new subscription
channel. This public key should match the public key used to send
subscription updates later.

NOTE: this patch changes the format of the endpoint URLs, & the content
of the endpoint URL token. This change also requires that ChannelIDs be
normalized to dashed format, (e.g. a lower case, dash delimited string
"deadbeef-0000-0000-deca-fbad11112222") This is the default mechanism
used by Firefox for UUID generation. It is STRONGLY urged that clients
normalize UUIDs used for ChannelIDs and User Agent IDs. While this
should not break existing clients, additional testing may be required.

Closes #326
  • Loading branch information
jrconlin committed Mar 3, 2016
1 parent 2a2c95f commit baf7e67
Show file tree
Hide file tree
Showing 11 changed files with 415 additions and 176 deletions.
39 changes: 29 additions & 10 deletions autopush/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,18 @@ def hasher(uaid):
return uaid


def normalize_id(id):
if not id:
return id
if (len(id) == 36 and
id[8] == id[13] == id[18] == id[23] == '-'):
return id.lower()
raw = filter(lambda x: x in '0123456789abcdef', id.lower())
if len(raw) != 32:
raise ValueError("Invalid UUID")
return '-'.join((raw[:8], raw[8:12], raw[12:16], raw[16:20], raw[20:]))


def make_rotating_tablename(prefix, delta=0):
"""Creates a tablename for table rotation based on a prefix with a given
month delta."""
Expand Down Expand Up @@ -261,7 +273,8 @@ def save_notification(self, uaid, chid, version):
cond = "attribute_not_exists(version) or version < :ver"
conn.put_item(
self.table.table_name,
item=self.encode(dict(uaid=hasher(uaid), chid=chid,
item=self.encode(dict(uaid=hasher(uaid),
chid=normalize_id(chid),
version=version)),
condition_expression=cond,
expression_attribute_values={
Expand All @@ -281,10 +294,12 @@ def delete_notification(self, uaid, chid, version=None):
"""
try:
if version:
self.table.delete_item(uaid=hasher(uaid), chid=chid,
self.table.delete_item(uaid=hasher(uaid),
chid=normalize_id(chid),
expected={"version__eq": version})
else:
self.table.delete_item(uaid=hasher(uaid), chid=chid)
self.table.delete_item(uaid=hasher(uaid),
chid=normalize_id(chid))
return True
except ProvisionedThroughputExceededException:
self.metrics.increment("error.provisioned.delete_notification")
Expand Down Expand Up @@ -312,7 +327,8 @@ def register_channel(self, uaid, channel_id):
db_key = self.encode({"uaid": hasher(uaid), "chidmessageid": " "})
# Generate our update expression
expr = "ADD chids :channel_id"
expr_values = self.encode({":channel_id": set([channel_id])})
expr_values = self.encode({":channel_id":
set([normalize_id(channel_id)])})
conn.update_item(
self.table.table_name,
db_key,
Expand All @@ -327,7 +343,8 @@ def unregister_channel(self, uaid, channel_id, **kwargs):
conn = self.table.connection
db_key = self.encode({"uaid": hasher(uaid), "chidmessageid": " "})
expr = "DELETE chids :channel_id"
expr_values = self.encode({":channel_id": set([channel_id])})
expr_values = self.encode({":channel_id":
set([normalize_id(channel_id)])})

result = conn.update_item(
self.table.table_name,
Expand Down Expand Up @@ -375,7 +392,7 @@ def store_message(self, uaid, channel_id, message_id, ttl, data=None,
with the message id"""
item = dict(
uaid=hasher(uaid),
chidmessageid="%s:%s" % (channel_id, message_id),
chidmessageid="%s:%s" % (normalize_id(channel_id), message_id),
data=data,
headers=headers,
ttl=ttl,
Expand Down Expand Up @@ -407,7 +424,7 @@ def update_message(self, uaid, channel_id, message_id, ttl, data=None,
item["headers"] = headers
item["data"] = data
try:
chidmessageid = "%s:%s" % (channel_id, message_id)
chidmessageid = "%s:%s" % (normalize_id(channel_id), message_id)
db_key = self.encode({"uaid": hasher(uaid),
"chidmessageid": chidmessageid})
expr = ("SET #tl=:ttl, #ts=:timestamp,"
Expand Down Expand Up @@ -439,14 +456,16 @@ def delete_message(self, uaid, channel_id, message_id, updateid=None):
try:
self.table.delete_item(
uaid=hasher(uaid),
chidmessageid="%s:%s" % (channel_id, message_id),
chidmessageid="%s:%s" % (normalize_id(channel_id),
message_id),
expected={'updateid__eq': updateid})
except ConditionalCheckFailedException:
return False
else:
self.table.delete_item(
uaid=hasher(uaid),
chidmessageid="%s:%s" % (channel_id, message_id))
chidmessageid="%s:%s" % (normalize_id(channel_id),
message_id))
return True

def delete_messages(self, uaid, chidmessageids):
Expand All @@ -463,7 +482,7 @@ def delete_messages_for_channel(self, uaid, channel_id):
"""Deletes all messages for a uaid/channel_id"""
results = self.table.query_2(
uaid__eq=hasher(uaid),
chidmessageid__beginswith="%s:" % channel_id,
chidmessageid__beginswith="%s:" % normalize_id(channel_id),
consistent=True,
attributes=("chidmessageid",),
)
Expand Down
70 changes: 39 additions & 31 deletions autopush/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@

from autopush.db import (
generate_last_connect,
hasher
hasher,
normalize_id,
)
from autopush.router.interface import RouterException
from autopush.utils import (
Expand Down Expand Up @@ -174,8 +175,8 @@ def chid(self):
@chid.setter
def chid(self, value):
"""Set the ChannelID and record to _client_info"""
self._chid = value
self._client_info["channelID"] = value
self._chid = normalize_id(value)
self._client_info["channelID"] = self._chid

def _init_info(self):
"""Returns a dict of additional client data"""
Expand Down Expand Up @@ -309,7 +310,7 @@ def _invalid_auth(self, fail):
log.msg("Invalid bearer token: " + message, **self._client_info)
raise VapidAuthException("Invalid bearer token: " + message)

def _process_auth(self, result):
def _process_auth(self, result, keys):
"""Process the optional VAPID auth token.
VAPID requires two headers to be present;
Expand All @@ -322,28 +323,26 @@ def _process_auth(self, result):
# No auth present, so it's not a VAPID call.
if not authorization:
return result

header_info = parse_header(self.request.headers.get('crypto-key'))
if not header_info:
if not keys:
raise VapidAuthException("Missing Crypto-Key")
values = header_info[-1]
if isinstance(values, dict):
crypto_key = values.get('p256ecdsa')
try:
(auth_type, token) = authorization.split(' ', 1)
except ValueError:
raise VapidAuthException("Invalid Authorization header")
# if it's a bearer token containing what may be a JWT
if auth_type.lower() == AUTH_SCHEME and '.' in token:
d = deferToThread(extract_jwt, token, crypto_key)
d.addCallback(self._store_auth, crypto_key, token, result)
d.addErrback(self._invalid_auth)
return d
# otherwise, it's not, so ignore the VAPID data.
return result
else:
if not isinstance(keys, dict):
raise VapidAuthException("Invalid Crypto-Key")
crypto_key = keys.get('p256ecdsa')
if not crypto_key:
raise VapidAuthException("Invalid bearer token: "
"improperly specified crypto-key")
try:
(auth_type, token) = authorization.split(' ', 1)
except ValueError:
raise VapidAuthException("Invalid Authorization header")
# if it's a bearer token containing what may be a JWT
if auth_type.lower() == AUTH_SCHEME and '.' in token:
d = deferToThread(extract_jwt, token, crypto_key)
d.addCallback(self._store_auth, crypto_key, token, result)
d.addErrback(self._invalid_auth)
return d
# otherwise, it's not, so ignore the VAPID data.
return result


class MessageHandler(AutoendpointHandler):
Expand Down Expand Up @@ -404,17 +403,27 @@ class EndpointHandler(AutoendpointHandler):
# Cyclone HTTP Methods
#############################################################
@cyclone.web.asynchronous
def put(self, token):
def put(self, api_ver="v0", token=None):
"""HTTP PUT Handler
Primary entry-point to handling a notification for a push client.
"""
self.start_time = time.time()
fernet = self.ap_settings.fernet

d = deferToThread(fernet.decrypt, token.encode('utf8'))
d.addCallback(self._process_auth)
public_key = None
keys = {}
crypto_key = self.request.headers.get('crypto-key')
if crypto_key:
header_info = parse_header(crypto_key)
keys = header_info[-1]
if isinstance(keys, dict):
public_key = keys.get('p256ecdsa')

d = deferToThread(self.ap_settings.parse_endpoint,
token,
api_ver,
public_key)
d.addCallback(self._process_auth, keys)
d.addCallback(self._token_valid)
d.addErrback(self._auth_err)
d.addErrback(self._token_err)
Expand All @@ -426,11 +435,10 @@ def put(self, token):
#############################################################
def _token_valid(self, result):
"""Called after the token is decrypted successfully"""
info = result.split(":")
if len(info) != 2:
if len(result) != 2:
raise ValueError("Wrong subscription token components")

self.uaid, self.chid = info
self.uaid, self.chid = result
d = deferToThread(self.ap_settings.router.get_uaid, self.uaid)
d.addCallback(self._uaid_lookup_results)
d.addErrback(self._uaid_not_found_err)
Expand Down
3 changes: 2 additions & 1 deletion autopush/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
import sys

import raven
from eliot import add_destination, fields, Logger, MessageType
from eliot import (add_destination, fields,
Logger, MessageType)
from twisted.python.log import textFromEventDict, startLoggingWithObserver

HOSTNAME = socket.getfqdn()
Expand Down
3 changes: 2 additions & 1 deletion autopush/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,8 @@ def endpoint_main(sysargs=None, use_files=True):

# Endpoint HTTP router
site = cyclone.web.Application([
(r"/push/([^\/]+)", EndpointHandler, dict(ap_settings=settings)),
(r"/push/(?:(v\d+)\/)?([^\/]+)", EndpointHandler,
dict(ap_settings=settings)),
(r"/m/([^\/]+)", MessageHandler, dict(ap_settings=settings)),
# PUT /register/ => connect info
# GET /register/uaid => chid + endpoint
Expand Down
7 changes: 5 additions & 2 deletions autopush/router/webpush.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from autopush.protocol import IgnoreBody
from autopush.router.interface import RouterException, RouterResponse
from autopush.router.simple import SimpleRouter
from autopush.db import normalize_id

TTL_URL = "https://webpush-wg.github.io/webpush-protocol/#rfc.section.6.2"

Expand Down Expand Up @@ -72,7 +73,8 @@ def preflight_check(self, uaid, channel_id):
self.ap_settings.message_tables[month_table].all_channels,
uaid=uaid)

if not exists or channel_id not in chans:
if (not exists or channel_id.lower() not
in map(lambda x: normalize_id(x), chans)):
raise RouterException("No such subscription", status_code=404,
log_exception=False, errno=106)
returnValue(month_table)
Expand All @@ -84,7 +86,8 @@ def _send_notification(self, uaid, node_id, notification):
headers for the notification.
"""
payload = {"channelID": notification.channel_id,
# Firefox currently requires channelIDs to be '-' formatted.
payload = {"channelID": normalize_id(notification.channel_id),
"version": notification.version,
"ttl": notification.ttl or 0,
"timestamp": int(time.time()),
Expand Down
71 changes: 67 additions & 4 deletions autopush/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import datetime
import socket

from hashlib import sha256

from cryptography.fernet import Fernet, MultiFernet
from twisted.internet import reactor
from twisted.internet.defer import (
Expand Down Expand Up @@ -232,6 +234,7 @@ def update_rotating_tables(self):
self.current_msg_month = message_table.table_name
self.message_tables[self.current_msg_month] = \
Message(message_table, self.metrics)

returnValue(True)

def update(self, **kwargs):
Expand All @@ -248,7 +251,67 @@ def update(self, **kwargs):
else:
setattr(self, key, val)

def make_endpoint(self, uaid, chid):
""" Create an endpoint from the identifiers"""
return self.endpoint_url + '/push/' + \
self.fernet.encrypt((uaid + ':' + chid).encode('utf8'))
def make_endpoint(self, uaid, chid, key=None):
"""Create an v1 or v2 endpoint from the indentifiers.
Both endpoints use bytes instead of hex to reduce ID length.
v0 is uaid.hex + ':' + chid.hex and is deprecated.
v1 is the uaid + chid
v2 is the uaid + chid + sha256(key).bytes
:param uaid: User Agent Identifier
:param chid: Channel or Subscription ID
:param key: Optional provided Public Key
:returns: Push endpoint
"""
root = self.endpoint_url + '/push/'
base = (uaid.replace('-', '').decode("hex") +
chid.replace('-', '').decode("hex"))

if key is None:
return root + 'v1/' + self.fernet.encrypt(base).strip('=')

return root + 'v2/' + self.fernet.encrypt(base + sha256(key).digest())

def parse_endpoint(self, token, version="v0", public_key=None):
"""Parse an endpoint into component elements of UAID, CHID and optional
key hash if v2
:param token: The obscured subscription data.
:param version: This is the API version of the token.
:param public_key: the public key (from Encryption-Key: p256ecdsa=)
:raises ValueError: In the case of a malformed endpoint.
:returns: a tuple containing the (UAID, CHID)
"""

def comp_digest(a, b):
"""Constant time comparison function, this is not always
backported."""
if len(a) != len(b):
return False
test = 0
for x, y in zip(a, b):
test |= (ord(x) ^ ord(y))
return test == 0

token = self.fernet.decrypt(token.encode('utf8'))

if version == 'v0':
if ':' not in token:
raise ValueError("Corrupted push token")
return tuple(token.split(':'))
if version == 'v1' and len(token) != 32:
raise ValueError("Corrupted push token")
if version == 'v2':
if len(token) != 64:
raise ValueError("Corrupted push token")
if not public_key:
raise ValueError("Invalid key data")
if not comp_digest(sha256(public_key).digest(),
token[32:]):
raise ValueError("Key mismatch")
return (token[:16].encode('hex'), token[16:32].encode('hex'))
Loading

0 comments on commit baf7e67

Please sign in to comment.