Skip to content
This repository has been archived by the owner on Jul 13, 2023. It is now read-only.

Commit

Permalink
bug: decode and process crypto-key header correctly
Browse files Browse the repository at this point in the history
The Crypto-Key header has a different format than what was first
understood. A valid Crypto-Key header may look like `keyid="foo";
key1="data",key2="data"`. This patch specialized Crypto-Key parsing
to a class.

Closes #410
  • Loading branch information
jrconlin committed Mar 17, 2016
1 parent a2af7ff commit f546ed7
Show file tree
Hide file tree
Showing 6 changed files with 173 additions and 117 deletions.
93 changes: 93 additions & 0 deletions autopush/crypto_key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""Crypto-Key header parser and manager"""


class CryptoKeyException(Exception):
"""Invalid CryptoKey"""


class CryptoKey(object):
"""Parse the Crypto-Key header per
http://tools.ietf.org/html/draft-ietf-httpbis-encryption-encoding-00#section-4
The Crypto-Key header is a data store that has it's own set of rules,
This class manages access to the Crypto-Key data. There are two ways that
data in the Crypto-Key can be used, one by the key-id (an optional
identifier that can be used to associate key data to other fields) and
the data label name.
There are a few additional functions that may be implemented. For
instance, removing a Crypto-Key component that may not need to be
passed on.
"""
_values = []

def __init__(self, header):
"""Parse the Crypto-Key header
:param header: Header content string.
"""
chunks = header.split(",")
if self._values:
self._values = []
for chunk in chunks:
bits = chunk.split(";")
hash = {}
for bit in bits:
try:
(key, value) = bit.split("=", 1)
except ValueError:
raise CryptoKeyException("Invalid Crypto Key value")
hash[key.strip()] = value.strip(' "')
self._values.append(hash)

def get_keyid(self, keyid):
"""Return the Crypto-Key hash referred to by a given keyid.
For example, for a CryptoKey specified as:
Crypto-Key: keyid="apple";foo="fruit",keyid="gorp";bar="snake"
get_keyid("apple") would return a hash of
{"keyid": "apple", "foo":"fruit"}
:param keyid: The keyid to reference
:returns: hash of the matching key data or None
"""
for val in self._values:
if keyid == val.get('keyid'):
return val
return None

def get_label(self, label):
"""Return the Crypto-Key value referred to by a given label.
For example, for a CryptoKey specified as:
Crypto-Key: keyid="apple";foo="fruit",keyid="gorp";bar="snake"
get_label("foo")
would return a value of "apple"
..note:: This presumes that "label" is unique. Otherwise it will
only return the FIRST instance of "label". Use get_keyid() if
you know of multiple sections containing similar labels.
:param label: The label to reference.
:returns: Value associated with the key data or None
"""
for val in self._values:
if label in val:
return val.get(label)
return None

def to_string(self):
"""Return a reformulated Crypto-Key header string"""
chunks = []
for val in self._values:
bits = []
for key in val:
bits.append("{}=\"{}\"".format(key, val[key]))
chunks.append(';'.join(bits))
return ','.join(chunks)
36 changes: 10 additions & 26 deletions autopush/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
generate_hash,
validate_uaid,
ErrorLogger,
parse_header,
extract_jwt,
)

Expand Down Expand Up @@ -311,7 +310,7 @@ def _invalid_auth(self, fail):
log.msg("Invalid bearer token: " + message, **self._client_info)
raise VapidAuthException("Invalid bearer token: " + message)

def _process_auth(self, result, keys):
def _process_auth(self, result):
"""Process the optional VAPID auth token.
VAPID requires two headers to be present;
Expand All @@ -324,22 +323,16 @@ def _process_auth(self, result, keys):
# No auth present, so it's not a VAPID call.
if not authorization:
return result
if not keys:
raise VapidAuthException("Missing Crypto-Key")
if not isinstance(keys, dict):
raise VapidAuthException("Invalid Crypto-Key")
crypto_key = keys.get('p256ecdsa')
if not crypto_key:
raise VapidAuthException("Invalid bearer token: "
"improperly specified crypto-key")

public_key = result.get("public_key")
try:
(auth_type, token) = authorization.split(' ', 1)
except ValueError:
raise VapidAuthException("Invalid Authorization header")
# if it's a bearer token containing what may be a JWT
if auth_type.lower() == AUTH_SCHEME and '.' in token:
d = deferToThread(extract_jwt, token, crypto_key)
d.addCallback(self._store_auth, crypto_key, token, result)
d = deferToThread(extract_jwt, token, public_key)
d.addCallback(self._store_auth, public_key, token, result)
d.addErrback(self._invalid_auth)
return d
# otherwise, it's not, so ignore the VAPID data.
Expand Down Expand Up @@ -412,20 +405,13 @@ def put(self, api_ver="v0", token=None):
"""
api_ver = api_ver or "v0"
self.start_time = time.time()
public_key = None
keys = {}
crypto_key = self.request.headers.get('crypto-key')
if crypto_key:
header_info = parse_header(crypto_key)
keys = header_info[-1]
if isinstance(keys, dict):
public_key = keys.get('p256ecdsa')
crypto_key_header = self.request.headers.get('crypto-key')

d = deferToThread(self.ap_settings.parse_endpoint,
token,
api_ver,
public_key)
d.addCallback(self._process_auth, keys)
crypto_key_header)
d.addCallback(self._process_auth)
d.addCallback(self._token_valid)
d.addErrback(self._auth_err)
d.addErrback(self._token_err)
Expand All @@ -437,10 +423,8 @@ def put(self, api_ver="v0", token=None):
#############################################################
def _token_valid(self, result):
"""Called after the token is decrypted successfully"""
if len(result) != 2:
raise InvalidTokenException("Wrong subscription token components")

self.uaid, self.chid = result
self.uaid = result.get("uaid")
self.chid = result.get("chid")
d = deferToThread(self.ap_settings.router.get_uaid, self.uaid)
d.addCallback(self._uaid_lookup_results)
d.addErrback(self._uaid_not_found_err)
Expand Down
21 changes: 16 additions & 5 deletions autopush/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
)
from autopush.utils import canonical_url, resolve_ip
from autopush.senderids import SENDERID_EXPRY, DEFAULT_BUCKET
from autopush.crypto_key import (CryptoKey, CryptoKeyException)


VALID_V0_TOKEN = re.compile(r'[0-9A-Za-z-]{32,36}:[0-9A-Za-z-]{32,36}')
Expand Down Expand Up @@ -280,26 +281,34 @@ def make_endpoint(self, uaid, chid, key=None):

return root + 'v2/' + self.fernet.encrypt(base + sha256(key).digest())

def parse_endpoint(self, token, version="v0", public_key=None):
def parse_endpoint(self, token, version="v0", ckey_header=None):
"""Parse an endpoint into component elements of UAID, CHID and optional
key hash if v2
:param token: The obscured subscription data.
:param version: This is the API version of the token.
:param public_key: the public key (from Encryption-Key: p256ecdsa=)
:param ckey_header: the Crypto-Key header bearing the public key
(from Crypto-Key: p256ecdsa=)
:raises ValueError: In the case of a malformed endpoint.
:returns: a tuple containing the (UAID, CHID)
:returns: a dict containing (uaid=UAID, chid=CHID, public_key=KEY)
"""

token = self.fernet.decrypt(token.encode('utf8'))
public_key = None
if ckey_header:
try:
public_key = CryptoKey(ckey_header).get_label('p256ecdsa')
except CryptoKeyException:
raise InvalidTokenException("Invalid key data")

if version == 'v0':
if not VALID_V0_TOKEN.match(token):
raise InvalidTokenException("Corrupted push token")
return tuple(token.split(':'))
items = token.split(':')
return dict(uaid=items[0], chid=items[1], public_key=public_key)
if version == 'v1' and len(token) != 32:
raise InvalidTokenException("Corrupted push token")
if version == 'v2':
Expand All @@ -310,4 +319,6 @@ def parse_endpoint(self, token, version="v0", public_key=None):
if not constant_time.bytes_eq(sha256(public_key).digest(),
token[32:]):
raise InvalidTokenException("Key mismatch")
return (token[:16].encode('hex'), token[16:32].encode('hex'))
return dict(uaid=token[:16].encode('hex'),
chid=token[16:32].encode('hex'),
public_key=public_key)
42 changes: 42 additions & 0 deletions autopush/tests/test_cryptokey.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import unittest

from nose.tools import (eq_, ok_)

from autopush.crypto_key import CryptoKey


class CryptoKeyTestCase(unittest.TestCase):

valid_key = (
'keyid="p256dh";dh="BDw9T0eImd4ax818VcYqDK_DOhcuDswKero'
'YyNkdhYmygoLSDlSiWpuoWYUSSFxi25cyyNTR5k9Ny93DzZc0UI4",'
'p256ecdsa="BF92zdI_AKcH5Q31_Rr-04bPqOHU_Qg6lAawHbvfQrY'
'xV_vIsAsHSyaiuyfofvxT8ZVIXccykd4V2Z7iJVfreT8"')

def test_parse(self):
ckey = CryptoKey(self.valid_key)
eq_(ckey.get_keyid("p256dh"),
{"keyid": "p256dh",
"dh": "BDw9T0eImd4ax818VcYqDK_DOhcuDswKero"
"YyNkdhYmygoLSDlSiWpuoWYUSSFxi25cyyNTR5k9Ny93DzZc0UI4"})
eq_(ckey.get_label("p256ecdsa"),
"BF92zdI_AKcH5Q31_Rr-04bPqOHU_Qg6lAawHbvfQrY"
"xV_vIsAsHSyaiuyfofvxT8ZVIXccykd4V2Z7iJVfreT8")
ok_(ckey.get_keyid("missing") is None)
ok_(ckey.get_label("missing") is None)

def test_parse_lenient(self):
ckey = CryptoKey(self.valid_key.replace('"', ''))
str = ckey.to_string()
ckey2 = CryptoKey(str)
ok_(ckey.get_keyid("p256dh"), ckey2.get_keyid("p256dh"))
ok_(ckey.get_label("p256ecdsa") is not None)
ok_(ckey.get_label("p256ecdsa"), ckey2.get_label("p256ecdsa"))

def test_string(self):
ckey = CryptoKey(self.valid_key)
str = ckey.to_string()
ckey2 = CryptoKey(str)
ok_(ckey.get_keyid("p256dh"), ckey2.get_keyid("p256dh"))
ok_(ckey.get_label("p256ecdsa") is not None)
ok_(ckey.get_label("p256ecdsa"), ckey2.get_label("p256ecdsa"))
55 changes: 12 additions & 43 deletions autopush/tests/test_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,12 +231,6 @@ def _check_error(self, code, errno, error):
eq_(d.get("errno"), errno)
eq_(d.get("error"), error)

def test_parse_header(self):
result = utils.parse_header('fee; fie="foe"; fum=""foobar="five""";')
eq_(result[0], 'fee')
eq_(result[1], {'fum': 'foobar="five"', 'fie': 'foe'})
eq_([], utils.parse_header(""))

def test_uaid_lookup_results(self):
fresult = dict(router_type="test")
frouter = Mock(spec=Router)
Expand Down Expand Up @@ -590,12 +584,7 @@ def handle_finish(result):

self.endpoint.version, self.endpoint.data = 789, None

exc = self.assertRaises(InvalidTokenException,
self.endpoint._token_valid,
['invalid'])
eq_(exc.message, "Wrong subscription token components")

self.endpoint._token_valid(['123', dummy_chid])
self.endpoint._token_valid(dict(uaid='123', chid=dummy_chid))
return self.finish_deferred

def test_put_default_router(self):
Expand Down Expand Up @@ -716,7 +705,7 @@ def test_put_invalid_vapid_crypto_header(self):

def handle_finish(result):
self.assertTrue(result)
self.endpoint.set_status.assert_called_with(401)
self.endpoint.set_status.assert_called_with(400)

self.finish_deferred.addCallback(handle_finish)
self.endpoint.put(None, dummy_uaid)
Expand Down Expand Up @@ -770,29 +759,6 @@ def handle_finish(result):
self.endpoint.put(None, dummy_uaid)
return self.finish_deferred

def test_put_missing_vapid_crypto_header(self):
self.request_mock.headers["encryption"] = "ignored"
self.request_mock.headers["content-encoding"] = 'text'
self.request_mock.headers["authorization"] = "some auth"
self.request_mock.body = b' '
self.fernet_mock.decrypt.return_value = dummy_token
self.router_mock.get_uaid.return_value = dict(
router_type="webpush",
router_data=dict(),
)
self.wp_router_mock.route_notification.return_value = RouterResponse(
status_code=200,
router_data={},
)

def handle_finish(result):
self.assertTrue(result)
self.endpoint.set_status.assert_called_with(401)

self.finish_deferred.addCallback(handle_finish)
self.endpoint.put(None, dummy_uaid)
return self.finish_deferred

def test_post_webpush_with_headers_in_response(self):
self.fernet_mock.decrypt.return_value = dummy_token
self.endpoint.set_header = Mock()
Expand Down Expand Up @@ -1286,12 +1252,13 @@ def test_parse_endpoint(self):
chid_dec = chid_strip.decode('hex')
v1_valid = uaid_dec + chid_dec
pub_key = uuid.uuid4().hex
crypto_key = "p256ecdsa=" + pub_key
v2_valid = sha256(pub_key).digest()
v2_invalid = sha256(uuid.uuid4().hex).digest()
# v0 good
self.fernet_mock.decrypt.return_value = v0_valid
tokens = self.settings.parse_endpoint('/valid')
eq_(tokens, (dummy_uaid, dummy_chid))
eq_(tokens, dict(uaid=dummy_uaid, chid=dummy_chid, public_key=None))

# v0 bad
self.fernet_mock.decrypt.return_value = v1_valid
Expand All @@ -1308,22 +1275,24 @@ def test_parse_endpoint(self):

self.fernet_mock.decrypt.return_value = v1_valid
tokens = self.settings.parse_endpoint('valid', 'v1')
eq_(tokens, (uaid_strip, chid_strip))
eq_(tokens, dict(uaid=uaid_strip, chid=chid_strip, public_key=None))

self.fernet_mock.decrypt.return_value = v1_valid + v2_valid
tokens = self.settings.parse_endpoint('valid', 'v2', pub_key)
eq_(tokens, (uaid_strip, chid_strip))
tokens = self.settings.parse_endpoint('valid', 'v2', crypto_key)
eq_(tokens,
dict(uaid=uaid_strip, chid=chid_strip, public_key=pub_key))

self.fernet_mock.decrypt.return_value = v1_valid + "invalid"
exc = self.assertRaises(InvalidTokenException,
self.settings.parse_endpoint,
'invalid', 'v2', pub_key)
'invalid', 'v2', crypto_key)
eq_(exc.message, "Corrupted push token")

self.fernet_mock.decrypt.return_value = v1_valid + v2_valid
exc = self.assertRaises(InvalidTokenException,
self.settings.parse_endpoint,
'invalid', 'v2', pub_key[:30])
'invalid', 'v2',
"p256ecdsa="+pub_key[:30])
eq_(exc.message, "Key mismatch")

self.fernet_mock.decrypt.return_value = v1_valid + v2_invalid
Expand All @@ -1335,7 +1304,7 @@ def test_parse_endpoint(self):
self.fernet_mock.decrypt.return_value = v1_valid + v2_invalid
exc = self.assertRaises(InvalidTokenException,
self.settings.parse_endpoint,
'invalid', 'v2', pub_key)
'invalid', 'v2', crypto_key)
eq_(exc.message, "Key mismatch")

def test_make_endpoint(self):
Expand Down
Loading

0 comments on commit f546ed7

Please sign in to comment.