Skip to content
This repository has been archived by the owner on Jul 13, 2023. It is now read-only.

feature: allow channels to register with public key #382

Merged
merged 1 commit into from
Mar 4, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 29 additions & 10 deletions autopush/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,18 @@ def hasher(uaid):
return uaid


def normalize_id(id):
if not id:
return id
if (len(id) == 36 and
id[8] == id[13] == id[18] == id[23] == '-'):
return id.lower()
raw = filter(lambda x: x in '0123456789abcdef', id.lower())
if len(raw) != 32:
raise ValueError("Invalid UUID")
return '-'.join((raw[:8], raw[8:12], raw[12:16], raw[16:20], raw[20:]))


def make_rotating_tablename(prefix, delta=0):
"""Creates a tablename for table rotation based on a prefix with a given
month delta."""
Expand Down Expand Up @@ -261,7 +273,8 @@ def save_notification(self, uaid, chid, version):
cond = "attribute_not_exists(version) or version < :ver"
conn.put_item(
self.table.table_name,
item=self.encode(dict(uaid=hasher(uaid), chid=chid,
item=self.encode(dict(uaid=hasher(uaid),
chid=normalize_id(chid),
version=version)),
condition_expression=cond,
expression_attribute_values={
Expand All @@ -281,10 +294,12 @@ def delete_notification(self, uaid, chid, version=None):
"""
try:
if version:
self.table.delete_item(uaid=hasher(uaid), chid=chid,
self.table.delete_item(uaid=hasher(uaid),
chid=normalize_id(chid),
expected={"version__eq": version})
else:
self.table.delete_item(uaid=hasher(uaid), chid=chid)
self.table.delete_item(uaid=hasher(uaid),
chid=normalize_id(chid))
return True
except ProvisionedThroughputExceededException:
self.metrics.increment("error.provisioned.delete_notification")
Expand Down Expand Up @@ -312,7 +327,8 @@ def register_channel(self, uaid, channel_id):
db_key = self.encode({"uaid": hasher(uaid), "chidmessageid": " "})
# Generate our update expression
expr = "ADD chids :channel_id"
expr_values = self.encode({":channel_id": set([channel_id])})
expr_values = self.encode({":channel_id":
set([normalize_id(channel_id)])})
conn.update_item(
self.table.table_name,
db_key,
Expand All @@ -327,7 +343,8 @@ def unregister_channel(self, uaid, channel_id, **kwargs):
conn = self.table.connection
db_key = self.encode({"uaid": hasher(uaid), "chidmessageid": " "})
expr = "DELETE chids :channel_id"
expr_values = self.encode({":channel_id": set([channel_id])})
expr_values = self.encode({":channel_id":
set([normalize_id(channel_id)])})

result = conn.update_item(
self.table.table_name,
Expand Down Expand Up @@ -375,7 +392,7 @@ def store_message(self, uaid, channel_id, message_id, ttl, data=None,
with the message id"""
item = dict(
uaid=hasher(uaid),
chidmessageid="%s:%s" % (channel_id, message_id),
chidmessageid="%s:%s" % (normalize_id(channel_id), message_id),
data=data,
headers=headers,
ttl=ttl,
Expand Down Expand Up @@ -407,7 +424,7 @@ def update_message(self, uaid, channel_id, message_id, ttl, data=None,
item["headers"] = headers
item["data"] = data
try:
chidmessageid = "%s:%s" % (channel_id, message_id)
chidmessageid = "%s:%s" % (normalize_id(channel_id), message_id)
db_key = self.encode({"uaid": hasher(uaid),
"chidmessageid": chidmessageid})
expr = ("SET #tl=:ttl, #ts=:timestamp,"
Expand Down Expand Up @@ -439,14 +456,16 @@ def delete_message(self, uaid, channel_id, message_id, updateid=None):
try:
self.table.delete_item(
uaid=hasher(uaid),
chidmessageid="%s:%s" % (channel_id, message_id),
chidmessageid="%s:%s" % (normalize_id(channel_id),
message_id),
expected={'updateid__eq': updateid})
except ConditionalCheckFailedException:
return False
else:
self.table.delete_item(
uaid=hasher(uaid),
chidmessageid="%s:%s" % (channel_id, message_id))
chidmessageid="%s:%s" % (normalize_id(channel_id),
message_id))
return True

def delete_messages(self, uaid, chidmessageids):
Expand All @@ -463,7 +482,7 @@ def delete_messages_for_channel(self, uaid, channel_id):
"""Deletes all messages for a uaid/channel_id"""
results = self.table.query_2(
uaid__eq=hasher(uaid),
chidmessageid__beginswith="%s:" % channel_id,
chidmessageid__beginswith="%s:" % normalize_id(channel_id),
consistent=True,
attributes=("chidmessageid",),
)
Expand Down
70 changes: 39 additions & 31 deletions autopush/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@

from autopush.db import (
generate_last_connect,
hasher
hasher,
normalize_id,
)
from autopush.router.interface import RouterException
from autopush.utils import (
Expand Down Expand Up @@ -174,8 +175,8 @@ def chid(self):
@chid.setter
def chid(self, value):
"""Set the ChannelID and record to _client_info"""
self._chid = value
self._client_info["channelID"] = value
self._chid = normalize_id(value)
self._client_info["channelID"] = self._chid

def _init_info(self):
"""Returns a dict of additional client data"""
Expand Down Expand Up @@ -309,7 +310,7 @@ def _invalid_auth(self, fail):
log.msg("Invalid bearer token: " + message, **self._client_info)
raise VapidAuthException("Invalid bearer token: " + message)

def _process_auth(self, result):
def _process_auth(self, result, keys):
"""Process the optional VAPID auth token.

VAPID requires two headers to be present;
Expand All @@ -322,28 +323,26 @@ def _process_auth(self, result):
# No auth present, so it's not a VAPID call.
if not authorization:
return result

header_info = parse_header(self.request.headers.get('crypto-key'))
if not header_info:
if not keys:
raise VapidAuthException("Missing Crypto-Key")
values = header_info[-1]
if isinstance(values, dict):
crypto_key = values.get('p256ecdsa')
try:
(auth_type, token) = authorization.split(' ', 1)
except ValueError:
raise VapidAuthException("Invalid Authorization header")
# if it's a bearer token containing what may be a JWT
if auth_type.lower() == AUTH_SCHEME and '.' in token:
d = deferToThread(extract_jwt, token, crypto_key)
d.addCallback(self._store_auth, crypto_key, token, result)
d.addErrback(self._invalid_auth)
return d
# otherwise, it's not, so ignore the VAPID data.
return result
else:
if not isinstance(keys, dict):
raise VapidAuthException("Invalid Crypto-Key")
crypto_key = keys.get('p256ecdsa')
if not crypto_key:
raise VapidAuthException("Invalid bearer token: "
"improperly specified crypto-key")
try:
(auth_type, token) = authorization.split(' ', 1)
except ValueError:
raise VapidAuthException("Invalid Authorization header")
# if it's a bearer token containing what may be a JWT
if auth_type.lower() == AUTH_SCHEME and '.' in token:
d = deferToThread(extract_jwt, token, crypto_key)
d.addCallback(self._store_auth, crypto_key, token, result)
d.addErrback(self._invalid_auth)
return d
# otherwise, it's not, so ignore the VAPID data.
return result


class MessageHandler(AutoendpointHandler):
Expand Down Expand Up @@ -404,17 +403,27 @@ class EndpointHandler(AutoendpointHandler):
# Cyclone HTTP Methods
#############################################################
@cyclone.web.asynchronous
def put(self, token):
def put(self, api_ver="v0", token=None):
"""HTTP PUT Handler

Primary entry-point to handling a notification for a push client.

"""
self.start_time = time.time()
fernet = self.ap_settings.fernet

d = deferToThread(fernet.decrypt, token.encode('utf8'))
d.addCallback(self._process_auth)
public_key = None
keys = {}
crypto_key = self.request.headers.get('crypto-key')
if crypto_key:
header_info = parse_header(crypto_key)
keys = header_info[-1]
if isinstance(keys, dict):
public_key = keys.get('p256ecdsa')

d = deferToThread(self.ap_settings.parse_endpoint,
token,
api_ver,
public_key)
d.addCallback(self._process_auth, keys)
d.addCallback(self._token_valid)
d.addErrback(self._auth_err)
d.addErrback(self._token_err)
Expand All @@ -426,11 +435,10 @@ def put(self, token):
#############################################################
def _token_valid(self, result):
"""Called after the token is decrypted successfully"""
info = result.split(":")
if len(info) != 2:
if len(result) != 2:
raise ValueError("Wrong subscription token components")

self.uaid, self.chid = info
self.uaid, self.chid = result
d = deferToThread(self.ap_settings.router.get_uaid, self.uaid)
d.addCallback(self._uaid_lookup_results)
d.addErrback(self._uaid_not_found_err)
Expand Down
3 changes: 2 additions & 1 deletion autopush/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
import sys

import raven
from eliot import add_destination, fields, Logger, MessageType
from eliot import (add_destination, fields,
Logger, MessageType)
from twisted.python.log import textFromEventDict, startLoggingWithObserver

HOSTNAME = socket.getfqdn()
Expand Down
3 changes: 2 additions & 1 deletion autopush/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,8 @@ def endpoint_main(sysargs=None, use_files=True):

# Endpoint HTTP router
site = cyclone.web.Application([
(r"/push/([^\/]+)", EndpointHandler, dict(ap_settings=settings)),
(r"/push/(?:(v\d+)\/)?([^\/]+)", EndpointHandler,
dict(ap_settings=settings)),
(r"/m/([^\/]+)", MessageHandler, dict(ap_settings=settings)),
# PUT /register/ => connect info
# GET /register/uaid => chid + endpoint
Expand Down
7 changes: 5 additions & 2 deletions autopush/router/webpush.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from autopush.protocol import IgnoreBody
from autopush.router.interface import RouterException, RouterResponse
from autopush.router.simple import SimpleRouter
from autopush.db import normalize_id

TTL_URL = "https://webpush-wg.github.io/webpush-protocol/#rfc.section.6.2"

Expand Down Expand Up @@ -72,7 +73,8 @@ def preflight_check(self, uaid, channel_id):
self.ap_settings.message_tables[month_table].all_channels,
uaid=uaid)

if not exists or channel_id not in chans:
if (not exists or channel_id.lower() not
in map(lambda x: normalize_id(x), chans)):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not going to be great to do at scale....

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're prolly going to have to file a separate issue to replace the uaid lookup with a lookup that does a set membership test in dynamodb, so that we're not pulling the entire list of channels then checking them.

We can ensure the channels are normalized, by updating the migration code to normalize them on monthly rotation.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep. We're kinda stuck at the moment because of legacy records. Eventually, they should expire out, but for now, we have to account for them.

raise RouterException("No such subscription", status_code=404,
log_exception=False, errno=106)
returnValue(month_table)
Expand All @@ -84,7 +86,8 @@ def _send_notification(self, uaid, node_id, notification):
headers for the notification.

"""
payload = {"channelID": notification.channel_id,
# Firefox currently requires channelIDs to be '-' formatted.
payload = {"channelID": normalize_id(notification.channel_id),
"version": notification.version,
"ttl": notification.ttl or 0,
"timestamp": int(time.time()),
Expand Down
62 changes: 58 additions & 4 deletions autopush/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
import datetime
import socket

from hashlib import sha256

from cryptography.fernet import Fernet, MultiFernet
from cryptography.hazmat.primitives import constant_time
from twisted.internet import reactor
from twisted.internet.defer import (
inlineCallbacks,
Expand Down Expand Up @@ -232,6 +235,7 @@ def update_rotating_tables(self):
self.current_msg_month = message_table.table_name
self.message_tables[self.current_msg_month] = \
Message(message_table, self.metrics)

returnValue(True)

def update(self, **kwargs):
Expand All @@ -248,7 +252,57 @@ def update(self, **kwargs):
else:
setattr(self, key, val)

def make_endpoint(self, uaid, chid):
""" Create an endpoint from the identifiers"""
return self.endpoint_url + '/push/' + \
self.fernet.encrypt((uaid + ':' + chid).encode('utf8'))
def make_endpoint(self, uaid, chid, key=None):
"""Create an v1 or v2 endpoint from the indentifiers.

Both endpoints use bytes instead of hex to reduce ID length.
v0 is uaid.hex + ':' + chid.hex and is deprecated.
v1 is the uaid + chid
v2 is the uaid + chid + sha256(key).bytes

:param uaid: User Agent Identifier
:param chid: Channel or Subscription ID
:param key: Optional provided Public Key
:returns: Push endpoint

"""
root = self.endpoint_url + '/push/'
base = (uaid.replace('-', '').decode("hex") +
chid.replace('-', '').decode("hex"))

if key is None:
return root + 'v1/' + self.fernet.encrypt(base).strip('=')

return root + 'v2/' + self.fernet.encrypt(base + sha256(key).digest())

def parse_endpoint(self, token, version="v0", public_key=None):
"""Parse an endpoint into component elements of UAID, CHID and optional
key hash if v2

:param token: The obscured subscription data.
:param version: This is the API version of the token.
:param public_key: the public key (from Encryption-Key: p256ecdsa=)

:raises ValueError: In the case of a malformed endpoint.

:returns: a tuple containing the (UAID, CHID)

"""

token = self.fernet.decrypt(token.encode('utf8'))

if version == 'v0':
if ':' not in token:
raise ValueError("Corrupted push token")
return tuple(token.split(':'))
if version == 'v1' and len(token) != 32:
raise ValueError("Corrupted push token")
if version == 'v2':
if len(token) != 64:
raise ValueError("Corrupted push token")
if not public_key:
raise ValueError("Invalid key data")
if not constant_time.bytes_eq(sha256(public_key).digest(),
token[32:]):
raise ValueError("Key mismatch")
return (token[:16].encode('hex'), token[16:32].encode('hex'))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😮 🏆

Loading