Skip to content
This repository has been archived by the owner on Jul 13, 2023. It is now read-only.

Commit

Permalink
Merge pull request #647 from mozilla-services/bug/641
Browse files Browse the repository at this point in the history
feat: update jwt library
  • Loading branch information
jrconlin authored Sep 12, 2016
2 parents d5c05df + 68fccf9 commit 054a9f1
Show file tree
Hide file tree
Showing 9 changed files with 179 additions and 67 deletions.
10 changes: 2 additions & 8 deletions autopush/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
hasher,
normalize_id,
)
from autopush.exceptions import InvalidTokenException
from autopush.exceptions import InvalidTokenException, VapidAuthException
from autopush.router.interface import RouterException
from autopush.utils import (
generate_hash,
Expand Down Expand Up @@ -79,11 +79,6 @@
}


class VapidAuthException(Exception):
"""Exception if the VAPID Auth token fails"""
pass


class Notification(namedtuple("Notification",
"version data channel_id headers ttl")):
"""Parsed notification from the request"""
Expand Down Expand Up @@ -339,8 +334,7 @@ def _db_error_handling(self, d):
def _store_auth(self, jwt, crypto_key, token, result):
if jwt.get('exp', 0) < time.time():
raise VapidAuthException("Invalid bearer token: Auth expired")
jwt_crypto_key = base64url_encode(crypto_key)
self._client_info["jwt_crypto_key"] = jwt_crypto_key
self._client_info["jwt_crypto_key"] = crypto_key
for i in jwt:
self._client_info["jwt_" + i] = jwt[i]
return result
Expand Down
5 changes: 5 additions & 0 deletions autopush/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,8 @@ def __init__(self, message, status_code=400, errno=None, headers=None):
self.status_code = status_code
self.errno = errno
self.headers = {} if headers is None else headers


class VapidAuthException(Exception):
"""Exception if the VAPID Auth token fails"""
pass
21 changes: 12 additions & 9 deletions autopush/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
Router,
Message,
)
from autopush.exceptions import InvalidTokenException
from autopush.exceptions import InvalidTokenException, VapidAuthException
from autopush.metrics import (
DatadogMetrics,
TwistedMetrics,
Expand Down Expand Up @@ -324,7 +324,8 @@ def make_endpoint(self, uaid, chid, key=None):
ep = self.fernet.encrypt(base + sha256(raw_key).digest()).strip('=')
return root + 'v2/' + ep

def parse_endpoint(self, token, version="v0", ckey_header=None):
def parse_endpoint(self, token, version="v0", ckey_header=None,
auth_header=None):
"""Parse an endpoint into component elements of UAID, CHID and optional
key hash if v2
Expand All @@ -345,12 +346,7 @@ def parse_endpoint(self, token, version="v0", ckey_header=None):
crypto_key = CryptoKey(ckey_header)
except CryptoKeyException:
raise InvalidTokenException("Invalid key data")
label = crypto_key.get_label('p256ecdsa')
try:
public_key = base64url_decode(label)
except TypeError:
# Ignore missing and malformed app server keys.
pass
public_key = crypto_key.get_label('p256ecdsa')

if version == 'v0':
if not VALID_V0_TOKEN.match(token):
Expand All @@ -360,13 +356,20 @@ def parse_endpoint(self, token, version="v0", ckey_header=None):
if version == 'v1' and len(token) != 32:
raise InvalidTokenException("Corrupted push token")
if version == 'v2':
if not auth_header:
raise VapidAuthException("Missing Authorization Header")
if len(token) != 64:
raise InvalidTokenException("Corrupted push token")
if not public_key:
raise InvalidTokenException("Invalid key data")
if not constant_time.bytes_eq(sha256(public_key).digest(),
try:
decoded_key = base64url_decode(public_key)
except TypeError:
raise InvalidTokenException("Invalid key data")
if not constant_time.bytes_eq(sha256(decoded_key).digest(),
token[32:]):
raise InvalidTokenException("Key mismatch")
return dict(uaid=token[:16].encode('hex'),
chid=token[16:32].encode('hex'),
version=version,
public_key=public_key)
74 changes: 48 additions & 26 deletions autopush/tests/test_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import time
import uuid
import random
import base64

from hashlib import sha256

Expand Down Expand Up @@ -35,10 +34,15 @@
has_connected_this_month,
hasher
)
from autopush.exceptions import InvalidTokenException
from autopush.exceptions import InvalidTokenException, VapidAuthException
from autopush.settings import AutopushSettings
from autopush.router.interface import IRouter, RouterResponse
from autopush.utils import (generate_hash, decipher_public_key)
from autopush.utils import (
generate_hash,
decipher_public_key,
base64url_decode,
base64url_encode,
)

mock_dynamodb2 = mock_dynamodb2()
dummy_uaid = uuid.UUID("abad1dea00000000aabbccdd00000000").hex
Expand Down Expand Up @@ -817,7 +821,7 @@ def _gen_jwt(self, header, payload):
sk256p = ecdsa.SigningKey.generate(curve=ecdsa.NIST256p)
vk = sk256p.get_verifying_key()
sig = jws.sign(payload, sk256p, algorithm="ES256").strip('=')
crypto_key = utils.base64url_encode(vk.to_string()).strip('=')
crypto_key = utils.base64url_encode('\x04' + vk.to_string()).strip('=')
return (sig, crypto_key)

def test_post_webpush_with_vapid_auth(self):
Expand Down Expand Up @@ -869,19 +873,18 @@ def test_decipher_public_key(self):
payload = {"aud": "https://pusher_origin.example.com",
"exp": int(time.time()) + 86400,
"sub": "mailto:[email protected]"}
(_, crypto) = self._gen_jwt(header, payload)
crypto_key = utils.base64url_decode(crypto)

spki_header = ('0V0\x10\x06\x04+\x81\x04p\x06\x08*'
'\x86H\xce=\x03\x01\x07\x03B\x00\x04')
eq_(decipher_public_key(crypto_key), crypto_key)
eq_(decipher_public_key('\x04' + crypto_key), crypto_key)
eq_(decipher_public_key(spki_header + crypto_key), crypto_key)
(_, crypto_key) = self._gen_jwt(header, payload)
base_key = base64url_decode(crypto_key)[-64:]
s_head = ('0V0\x10\x06\x04+\x81\x04p\x06\x08*'
'\x86H\xce=\x03\x01\x07\x03B\x00\x04')
s_key = base64url_encode(s_head + base_key)
eq_(decipher_public_key(crypto_key), base_key)
eq_(decipher_public_key(s_key), base_key)
assert_raises(ValueError, decipher_public_key, "banana")
crap = ''.join([random.choice('012345abcdef') for i in range(0, 64)])
assert_raises(ValueError, decipher_public_key, '\x05' + crap)
assert_raises(ValueError, decipher_public_key,
crap[:len(spki_header)] + crap)
crap[:len(s_head)] + crap)
assert_raises(ValueError, decipher_public_key, crap[:60])

def test_post_webpush_with_other_than_vapid_auth(self):
Expand Down Expand Up @@ -993,8 +996,7 @@ def test_util_extract_jwt(self):
"sub": "mailto:[email protected]"}

(sig, crypto_key) = self._gen_jwt(header, payload)
eq_(utils.extract_jwt(sig, utils.base64url_decode(crypto_key)),
payload)
eq_(utils.extract_jwt(sig, crypto_key), payload)

def test_post_webpush_bad_sig(self):
self.fernet_mock.decrypt.return_value = dummy_token
Expand Down Expand Up @@ -1333,11 +1335,12 @@ def test_v2_padded_fernet_decode(self):
if len(token) % 4:
break
reply = self.settings.parse_endpoint(token, "v2",
"p256ecdsa=" + dummy_key)
eq_(reply, {'public_key': base64.urlsafe_b64decode(
utils.repad(dummy_key)),
"p256ecdsa=" + dummy_key,
auth_header="webpush header")
eq_(reply, {'public_key': dummy_key,
'chid': dummy_chid.replace('-', ''),
'uaid': dummy_uaid.replace('-', '')})
'uaid': dummy_uaid.replace('-', ''),
'version': 'v2'})

def test_parse_endpoint(self):
v0_valid = dummy_uaid + ":" + dummy_chid
Expand Down Expand Up @@ -1369,34 +1372,53 @@ def test_parse_endpoint(self):

self.fernet_mock.decrypt.return_value = v1_valid
tokens = self.settings.parse_endpoint('valid', 'v1')
eq_(tokens, dict(uaid=uaid_strip, chid=chid_strip, public_key=None))
eq_(tokens, dict(uaid=uaid_strip, chid=chid_strip, public_key=None,
version="v1"))

self.fernet_mock.decrypt.return_value = v1_valid + v2_valid
tokens = self.settings.parse_endpoint('valid', 'v2', crypto_key)
tokens = self.settings.parse_endpoint('valid', 'v2', crypto_key,
auth_header="webpush auth")
eq_(tokens,
dict(uaid=uaid_strip, chid=chid_strip, public_key=raw_pub_key))
dict(uaid=uaid_strip, chid=chid_strip, public_key=pub_key,
version="v2"))

self.fernet_mock.decrypt.return_value = v1_valid + v2_valid
with assert_raises(VapidAuthException) as cx:
self.settings.parse_endpoint('valid', 'v2', crypto_key)
eq_(cx.exception.message, "Missing Authorization Header")

self.fernet_mock.decrypt.return_value = v1_valid + "invalid"
with assert_raises(InvalidTokenException) as cx:
self.settings.parse_endpoint('invalid', 'v2', crypto_key)
self.settings.parse_endpoint('invalid', 'v2', crypto_key,
auth_header="webpush header")
eq_(cx.exception.message, "Corrupted push token")

self.fernet_mock.decrypt.return_value = v1_valid + v2_valid
with assert_raises(InvalidTokenException) as cx:
self.settings.parse_endpoint('invalid', 'v2',
"p256ecdsa="+pub_key[:12])
"p256ecdsa="+pub_key[:12],
auth_header="webpush header")
eq_(cx.exception.message, "Key mismatch")

self.fernet_mock.decrypt.return_value = v1_valid + v2_invalid
with assert_raises(InvalidTokenException) as cx:
self.settings.parse_endpoint('invalid', 'v2')
self.settings.parse_endpoint('invalid', 'v2', None,
auth_header="webpush header")
eq_(cx.exception.message, "Invalid key data")

self.fernet_mock.decrypt.return_value = v1_valid + v2_invalid
with assert_raises(InvalidTokenException) as cx:
self.settings.parse_endpoint('invalid', 'v2', crypto_key)
self.settings.parse_endpoint('invalid', 'v2', crypto_key,
auth_header="webpush header")
eq_(cx.exception.message, "Key mismatch")

with assert_raises(InvalidTokenException) as cx:
ck = crypto_key[:30]+'z.'
self.settings.parse_endpoint('invalid', 'v2',
"p256ecdsa="+ck,
auth_header="webpush header")
eq_(cx.exception.message, "Invalid key data")

def test_make_endpoint(self):

def echo(val):
Expand Down
Loading

0 comments on commit 054a9f1

Please sign in to comment.