diff --git a/autopush/crypto_key.py b/autopush/crypto_key.py new file mode 100644 index 00000000..40ef6844 --- /dev/null +++ b/autopush/crypto_key.py @@ -0,0 +1,93 @@ +"""Crypto-Key header parser and manager""" + + +class CryptoKeyException(Exception): + """Invalid CryptoKey""" + + +class CryptoKey(object): + """Parse the Crypto-Key header per +http://tools.ietf.org/html/draft-ietf-httpbis-encryption-encoding-00#section-4 + + The Crypto-Key header is a data store that has it's own set of rules, + This class manages access to the Crypto-Key data. There are two ways that + data in the Crypto-Key can be used, one by the key-id (an optional + identifier that can be used to associate key data to other fields) and + the data label name. + + There are a few additional functions that may be implemented. For + instance, removing a Crypto-Key component that may not need to be + passed on. + + """ + _values = [] + + def __init__(self, header): + """Parse the Crypto-Key header + + :param header: Header content string. + + """ + chunks = header.split(",") + if self._values: + self._values = [] + for chunk in chunks: + bits = chunk.split(";") + hash = {} + for bit in bits: + try: + (key, value) = bit.split("=", 1) + except ValueError: + raise CryptoKeyException("Invalid Crypto Key value") + hash[key.strip()] = value.strip(' "') + self._values.append(hash) + + def get_keyid(self, keyid): + """Return the Crypto-Key hash referred to by a given keyid. + + For example, for a CryptoKey specified as: + Crypto-Key: keyid="apple";foo="fruit",keyid="gorp";bar="snake" + + get_keyid("apple") would return a hash of + {"keyid": "apple", "foo":"fruit"} + + :param keyid: The keyid to reference + :returns: hash of the matching key data or None + + """ + for val in self._values: + if keyid == val.get('keyid'): + return val + return None + + def get_label(self, label): + """Return the Crypto-Key value referred to by a given label. + + For example, for a CryptoKey specified as: + Crypto-Key: keyid="apple";foo="fruit",keyid="gorp";bar="snake" + + get_label("foo") + would return a value of "apple" + + ..note:: This presumes that "label" is unique. Otherwise it will + only return the FIRST instance of "label". Use get_keyid() if + you know of multiple sections containing similar labels. + + :param label: The label to reference. + :returns: Value associated with the key data or None + + """ + for val in self._values: + if label in val: + return val.get(label) + return None + + def to_string(self): + """Return a reformulated Crypto-Key header string""" + chunks = [] + for val in self._values: + bits = [] + for key in val: + bits.append("{}=\"{}\"".format(key, val[key])) + chunks.append(';'.join(bits)) + return ','.join(chunks) diff --git a/autopush/endpoint.py b/autopush/endpoint.py index 60d77fa9..c1fd9c33 100644 --- a/autopush/endpoint.py +++ b/autopush/endpoint.py @@ -51,7 +51,6 @@ generate_hash, validate_uaid, ErrorLogger, - parse_header, extract_jwt, ) @@ -311,7 +310,7 @@ def _invalid_auth(self, fail): log.msg("Invalid bearer token: " + message, **self._client_info) raise VapidAuthException("Invalid bearer token: " + message) - def _process_auth(self, result, keys): + def _process_auth(self, result): """Process the optional VAPID auth token. VAPID requires two headers to be present; @@ -324,22 +323,16 @@ def _process_auth(self, result, keys): # No auth present, so it's not a VAPID call. if not authorization: return result - if not keys: - raise VapidAuthException("Missing Crypto-Key") - if not isinstance(keys, dict): - raise VapidAuthException("Invalid Crypto-Key") - crypto_key = keys.get('p256ecdsa') - if not crypto_key: - raise VapidAuthException("Invalid bearer token: " - "improperly specified crypto-key") + + public_key = result.get("public_key") try: (auth_type, token) = authorization.split(' ', 1) except ValueError: raise VapidAuthException("Invalid Authorization header") # if it's a bearer token containing what may be a JWT if auth_type.lower() == AUTH_SCHEME and '.' in token: - d = deferToThread(extract_jwt, token, crypto_key) - d.addCallback(self._store_auth, crypto_key, token, result) + d = deferToThread(extract_jwt, token, public_key) + d.addCallback(self._store_auth, public_key, token, result) d.addErrback(self._invalid_auth) return d # otherwise, it's not, so ignore the VAPID data. @@ -412,20 +405,13 @@ def put(self, api_ver="v0", token=None): """ api_ver = api_ver or "v0" self.start_time = time.time() - public_key = None - keys = {} - crypto_key = self.request.headers.get('crypto-key') - if crypto_key: - header_info = parse_header(crypto_key) - keys = header_info[-1] - if isinstance(keys, dict): - public_key = keys.get('p256ecdsa') + crypto_key_header = self.request.headers.get('crypto-key') d = deferToThread(self.ap_settings.parse_endpoint, token, api_ver, - public_key) - d.addCallback(self._process_auth, keys) + crypto_key_header) + d.addCallback(self._process_auth) d.addCallback(self._token_valid) d.addErrback(self._auth_err) d.addErrback(self._token_err) @@ -437,10 +423,8 @@ def put(self, api_ver="v0", token=None): ############################################################# def _token_valid(self, result): """Called after the token is decrypted successfully""" - if len(result) != 2: - raise InvalidTokenException("Wrong subscription token components") - - self.uaid, self.chid = result + self.uaid = result.get("uaid") + self.chid = result.get("chid") d = deferToThread(self.ap_settings.router.get_uaid, self.uaid) d.addCallback(self._uaid_lookup_results) d.addErrback(self._uaid_not_found_err) diff --git a/autopush/settings.py b/autopush/settings.py index 1a5a8bba..669cfe7e 100644 --- a/autopush/settings.py +++ b/autopush/settings.py @@ -39,6 +39,7 @@ ) from autopush.utils import canonical_url, resolve_ip from autopush.senderids import SENDERID_EXPRY, DEFAULT_BUCKET +from autopush.crypto_key import (CryptoKey, CryptoKeyException) VALID_V0_TOKEN = re.compile(r'[0-9A-Za-z-]{32,36}:[0-9A-Za-z-]{32,36}') @@ -280,26 +281,34 @@ def make_endpoint(self, uaid, chid, key=None): return root + 'v2/' + self.fernet.encrypt(base + sha256(key).digest()) - def parse_endpoint(self, token, version="v0", public_key=None): + def parse_endpoint(self, token, version="v0", ckey_header=None): """Parse an endpoint into component elements of UAID, CHID and optional key hash if v2 :param token: The obscured subscription data. :param version: This is the API version of the token. - :param public_key: the public key (from Encryption-Key: p256ecdsa=) + :param ckey_header: the Crypto-Key header bearing the public key + (from Crypto-Key: p256ecdsa=) :raises ValueError: In the case of a malformed endpoint. - :returns: a tuple containing the (UAID, CHID) + :returns: a dict containing (uaid=UAID, chid=CHID, public_key=KEY) """ token = self.fernet.decrypt(token.encode('utf8')) + public_key = None + if ckey_header: + try: + public_key = CryptoKey(ckey_header).get_label('p256ecdsa') + except CryptoKeyException: + raise InvalidTokenException("Invalid key data") if version == 'v0': if not VALID_V0_TOKEN.match(token): raise InvalidTokenException("Corrupted push token") - return tuple(token.split(':')) + items = token.split(':') + return dict(uaid=items[0], chid=items[1], public_key=public_key) if version == 'v1' and len(token) != 32: raise InvalidTokenException("Corrupted push token") if version == 'v2': @@ -310,4 +319,6 @@ def parse_endpoint(self, token, version="v0", public_key=None): if not constant_time.bytes_eq(sha256(public_key).digest(), token[32:]): raise InvalidTokenException("Key mismatch") - return (token[:16].encode('hex'), token[16:32].encode('hex')) + return dict(uaid=token[:16].encode('hex'), + chid=token[16:32].encode('hex'), + public_key=public_key) diff --git a/autopush/tests/test_cryptokey.py b/autopush/tests/test_cryptokey.py new file mode 100644 index 00000000..f423f0d3 --- /dev/null +++ b/autopush/tests/test_cryptokey.py @@ -0,0 +1,42 @@ +import unittest + +from nose.tools import (eq_, ok_) + +from autopush.crypto_key import CryptoKey + + +class CryptoKeyTestCase(unittest.TestCase): + + valid_key = ( + 'keyid="p256dh";dh="BDw9T0eImd4ax818VcYqDK_DOhcuDswKero' + 'YyNkdhYmygoLSDlSiWpuoWYUSSFxi25cyyNTR5k9Ny93DzZc0UI4",' + 'p256ecdsa="BF92zdI_AKcH5Q31_Rr-04bPqOHU_Qg6lAawHbvfQrY' + 'xV_vIsAsHSyaiuyfofvxT8ZVIXccykd4V2Z7iJVfreT8"') + + def test_parse(self): + ckey = CryptoKey(self.valid_key) + eq_(ckey.get_keyid("p256dh"), + {"keyid": "p256dh", + "dh": "BDw9T0eImd4ax818VcYqDK_DOhcuDswKero" + "YyNkdhYmygoLSDlSiWpuoWYUSSFxi25cyyNTR5k9Ny93DzZc0UI4"}) + eq_(ckey.get_label("p256ecdsa"), + "BF92zdI_AKcH5Q31_Rr-04bPqOHU_Qg6lAawHbvfQrY" + "xV_vIsAsHSyaiuyfofvxT8ZVIXccykd4V2Z7iJVfreT8") + ok_(ckey.get_keyid("missing") is None) + ok_(ckey.get_label("missing") is None) + + def test_parse_lenient(self): + ckey = CryptoKey(self.valid_key.replace('"', '')) + str = ckey.to_string() + ckey2 = CryptoKey(str) + ok_(ckey.get_keyid("p256dh"), ckey2.get_keyid("p256dh")) + ok_(ckey.get_label("p256ecdsa") is not None) + ok_(ckey.get_label("p256ecdsa"), ckey2.get_label("p256ecdsa")) + + def test_string(self): + ckey = CryptoKey(self.valid_key) + str = ckey.to_string() + ckey2 = CryptoKey(str) + ok_(ckey.get_keyid("p256dh"), ckey2.get_keyid("p256dh")) + ok_(ckey.get_label("p256ecdsa") is not None) + ok_(ckey.get_label("p256ecdsa"), ckey2.get_label("p256ecdsa")) diff --git a/autopush/tests/test_endpoint.py b/autopush/tests/test_endpoint.py index 49c9f99e..183ca9e0 100644 --- a/autopush/tests/test_endpoint.py +++ b/autopush/tests/test_endpoint.py @@ -231,12 +231,6 @@ def _check_error(self, code, errno, error): eq_(d.get("errno"), errno) eq_(d.get("error"), error) - def test_parse_header(self): - result = utils.parse_header('fee; fie="foe"; fum=""foobar="five""";') - eq_(result[0], 'fee') - eq_(result[1], {'fum': 'foobar="five"', 'fie': 'foe'}) - eq_([], utils.parse_header("")) - def test_uaid_lookup_results(self): fresult = dict(router_type="test") frouter = Mock(spec=Router) @@ -590,12 +584,7 @@ def handle_finish(result): self.endpoint.version, self.endpoint.data = 789, None - exc = self.assertRaises(InvalidTokenException, - self.endpoint._token_valid, - ['invalid']) - eq_(exc.message, "Wrong subscription token components") - - self.endpoint._token_valid(['123', dummy_chid]) + self.endpoint._token_valid(dict(uaid='123', chid=dummy_chid)) return self.finish_deferred def test_put_default_router(self): @@ -716,7 +705,7 @@ def test_put_invalid_vapid_crypto_header(self): def handle_finish(result): self.assertTrue(result) - self.endpoint.set_status.assert_called_with(401) + self.endpoint.set_status.assert_called_with(400) self.finish_deferred.addCallback(handle_finish) self.endpoint.put(None, dummy_uaid) @@ -770,29 +759,6 @@ def handle_finish(result): self.endpoint.put(None, dummy_uaid) return self.finish_deferred - def test_put_missing_vapid_crypto_header(self): - self.request_mock.headers["encryption"] = "ignored" - self.request_mock.headers["content-encoding"] = 'text' - self.request_mock.headers["authorization"] = "some auth" - self.request_mock.body = b' ' - self.fernet_mock.decrypt.return_value = dummy_token - self.router_mock.get_uaid.return_value = dict( - router_type="webpush", - router_data=dict(), - ) - self.wp_router_mock.route_notification.return_value = RouterResponse( - status_code=200, - router_data={}, - ) - - def handle_finish(result): - self.assertTrue(result) - self.endpoint.set_status.assert_called_with(401) - - self.finish_deferred.addCallback(handle_finish) - self.endpoint.put(None, dummy_uaid) - return self.finish_deferred - def test_post_webpush_with_headers_in_response(self): self.fernet_mock.decrypt.return_value = dummy_token self.endpoint.set_header = Mock() @@ -1286,12 +1252,13 @@ def test_parse_endpoint(self): chid_dec = chid_strip.decode('hex') v1_valid = uaid_dec + chid_dec pub_key = uuid.uuid4().hex + crypto_key = "p256ecdsa=" + pub_key v2_valid = sha256(pub_key).digest() v2_invalid = sha256(uuid.uuid4().hex).digest() # v0 good self.fernet_mock.decrypt.return_value = v0_valid tokens = self.settings.parse_endpoint('/valid') - eq_(tokens, (dummy_uaid, dummy_chid)) + eq_(tokens, dict(uaid=dummy_uaid, chid=dummy_chid, public_key=None)) # v0 bad self.fernet_mock.decrypt.return_value = v1_valid @@ -1308,22 +1275,24 @@ def test_parse_endpoint(self): self.fernet_mock.decrypt.return_value = v1_valid tokens = self.settings.parse_endpoint('valid', 'v1') - eq_(tokens, (uaid_strip, chid_strip)) + eq_(tokens, dict(uaid=uaid_strip, chid=chid_strip, public_key=None)) self.fernet_mock.decrypt.return_value = v1_valid + v2_valid - tokens = self.settings.parse_endpoint('valid', 'v2', pub_key) - eq_(tokens, (uaid_strip, chid_strip)) + tokens = self.settings.parse_endpoint('valid', 'v2', crypto_key) + eq_(tokens, + dict(uaid=uaid_strip, chid=chid_strip, public_key=pub_key)) self.fernet_mock.decrypt.return_value = v1_valid + "invalid" exc = self.assertRaises(InvalidTokenException, self.settings.parse_endpoint, - 'invalid', 'v2', pub_key) + 'invalid', 'v2', crypto_key) eq_(exc.message, "Corrupted push token") self.fernet_mock.decrypt.return_value = v1_valid + v2_valid exc = self.assertRaises(InvalidTokenException, self.settings.parse_endpoint, - 'invalid', 'v2', pub_key[:30]) + 'invalid', 'v2', + "p256ecdsa="+pub_key[:30]) eq_(exc.message, "Key mismatch") self.fernet_mock.decrypt.return_value = v1_valid + v2_invalid @@ -1335,7 +1304,7 @@ def test_parse_endpoint(self): self.fernet_mock.decrypt.return_value = v1_valid + v2_invalid exc = self.assertRaises(InvalidTokenException, self.settings.parse_endpoint, - 'invalid', 'v2', pub_key) + 'invalid', 'v2', crypto_key) eq_(exc.message, "Key mismatch") def test_make_endpoint(self): diff --git a/autopush/utils.py b/autopush/utils.py index d8d0b26d..2f03597f 100644 --- a/autopush/utils.py +++ b/autopush/utils.py @@ -58,49 +58,6 @@ def generate_hash(key, payload): return h.hexdigest() -def parse_header(header, major=";", minor="="): - """Convert a multi-component header line (e.g. "a;b=c;d=e;...") to - a list. - - For example, if the header content line is - - `"a;c=1;b;d=2+3=5"` - - and presuming default values for major and minor, then the - response would be: - - `['a', 'b', {'c': '1', 'd': '2+3=5'}]` - - items defined with values will always appear as a dictionary at the - end of the list. If no items are assigned values, then no dictionary - is appended. - - :param header: Header content line to parse. - :param major: Major item separator. - :param minor: Minor item separator. - - """ - vals = dict() - items = [] - if not header: - return items - for v in map(lambda x: x.strip().split(minor, 1), - header.split(major)): - try: - val = v[1] - # Trim quotes equally off of start and end - # because ""this is "quoted""" is a thing. - while val[0] == val[-1] == '"': - val = val[1:-1] - vals[v[0].lower()] = val - except IndexError: - if len(v[0]): - items.append(v[0].strip('"')) - if vals: - items.append(vals) - return items - - def fix_padding(string): """ Some JWT fields may strip the end padding from base64 strings """ if len(string) % 4: