diff --git a/autopush/jwt.py b/autopush/jwt.py new file mode 100644 index 00000000..1fb711c1 --- /dev/null +++ b/autopush/jwt.py @@ -0,0 +1,107 @@ +import base64 +import binascii +import json + +from cryptography.exceptions import InvalidSignature +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.asymmetric import ec, utils +from cryptography.hazmat.primitives import hashes +from pyasn1.error import PyAsn1Error +from twisted.logger import Logger + + +def repad(string): + # type: (str) -> str + """Adds padding to strings for base64 decoding""" + if len(string) % 4: + string += '===='[len(string) % 4:] + return string + + +class VerifyJWT(object): + """Minimally verify a Vapid JWT object. + + Why hand roll? Most python JWT libraries either use a python elliptic + curve library directly, or call one that does, or is abandoned, or a + dozen other reasons. + + After spending half a day looking for reasonable replacements, I + decided to just write the functions we need directly. + + THIS IS NOT A FULL JWT REPLACEMENT. + + """ + + @staticmethod + def extract_signature(auth): + # type: (str) -> tuple() + """Fix the JWT auth token. + + The JWA spec defines the signature to be a pair of 32octet encoded + longs. + The `ecdsa` library signs using a raw, 32octet pair of values (s, r). + Cryptography, which uses OpenSSL, uses a DER sequence of (s, r). + This function converts the raw ecdsa to DER. + + :param auth: A JWT authorization token. + :type auth: str + + :return tuple containing the signature material and signature + + """ + payload, asig = auth.encode('utf8').rsplit(".", 1) + sig = base64.urlsafe_b64decode(repad(asig)) + if len(sig) != 64: + return payload, sig + + encoded = utils.encode_dss_signature( + s=int(binascii.hexlify(sig[32:]), 16), + r=int(binascii.hexlify(sig[:32]), 16) + ) + return payload, encoded + + @staticmethod + def decode(token, key): + # type (str, str) -> dict() + """Decode a web token into a assertion dictionary. + + This attempts to rectify both ecdsa and openssl generated + signatures. We use the built-in cryptography library since it wraps + libssl and is faster than the python only approach. + + :param token: VAPID auth token + :type token: str + :param key: bitarray containing public key + :type key: str or bitarray + + :return dict of the VAPID claims + + :raise InvalidSignature + + """ + # convert the signature if needed. + try: + sig_material, signature = VerifyJWT.extract_signature(token) + pkey = ec.EllipticCurvePublicNumbers.from_encoded_point( + ec.SECP256R1(), + key + ).public_key(default_backend()) + # NOTE: verify() will take any string as the signature. It appears + # to be doing lazy verification and matching strings rather than + # comparing content values. If the signatures start failing for + # some unknown reason in the future, decode the signature and + # make sure it matches how we're reconstructing it. + # This will raise an InvalidSignature exception if failure. + # It will be captured externally. + pkey.verify( + signature, + sig_material.encode('utf8'), + ec.ECDSA(hashes.SHA256())) + return json.loads( + base64.urlsafe_b64decode( + repad(sig_material.split('.')[1]).encode('utf8'))) + except (ValueError, TypeError, binascii.Error, PyAsn1Error): + raise InvalidSignature() + except Exception: # pragma: no cover + Logger().failure("Unexpected error processing JWT") + raise InvalidSignature() diff --git a/autopush/tests/test_endpoint.py b/autopush/tests/test_endpoint.py index 69b35fe0..7af20c91 100644 --- a/autopush/tests/test_endpoint.py +++ b/autopush/tests/test_endpoint.py @@ -1,11 +1,10 @@ import json import uuid -import jose import twisted.internet.base from cryptography.fernet import Fernet, InvalidToken from cyclone.web import Application -from mock import Mock, patch +from mock import Mock from moto import mock_dynamodb2 from nose.tools import eq_, ok_ from twisted.internet.defer import Deferred @@ -538,28 +537,6 @@ def restore(*args, **kwargs): chid=str(dummy_chid))) return self.finish_deferred - @patch('jose.jws.verify', side_effect=jose.exceptions.JWTError) - def test_post_bad_jwt(self, *args): - self.reg.request.body = json.dumps(dict( - channelID=str(dummy_chid), - )) - - def handle_finish(value): - self._check_error(401, 109, 'Unauthorized') - - def restore(*args, **kwargs): - uuid.uuid4 = old_func - - old_func = uuid.uuid4 - uuid.uuid4 = lambda: dummy_uaid - self.finish_deferred.addBoth(restore) - self.finish_deferred.addCallback(handle_finish) - self.reg.request.headers["Authorization"] = "WebPush Dummy" - self.reg.post(self._make_req(router_type="webpush", - uaid=dummy_uaid.hex, - chid=str(dummy_chid))) - return self.finish_deferred - def test_post_uaid_chid(self, *args): self.reg.request.body = json.dumps(dict( type="simplepush", diff --git a/autopush/tests/test_web_validation.py b/autopush/tests/test_web_validation.py index 5e2634f1..99b09ee6 100644 --- a/autopush/tests/test_web_validation.py +++ b/autopush/tests/test_web_validation.py @@ -8,8 +8,8 @@ ItemNotFound, ) from cryptography.fernet import InvalidToken +from cryptography.exceptions import InvalidSignature from jose import jws -from jose.exceptions import JWTClaimsError from marshmallow import Schema, fields from mock import Mock, patch from moto import mock_dynamodb2 @@ -1005,7 +1005,7 @@ def test_invalid_encryption_header(self, mock_jwt): def test_invalid_encryption_jwt(self, mock_jwt): schema = self._make_fut() # use a deeply superclassed error to make sure that it gets picked up. - mock_jwt.side_effect = JWTClaimsError("invalid claim") + mock_jwt.side_effect = InvalidSignature("invalid signature") header = {"typ": "JWT", "alg": "ES256"} payload = {"aud": "https://push.example.com", diff --git a/autopush/utils.py b/autopush/utils.py index c808b6ca..150bf281 100644 --- a/autopush/utils.py +++ b/autopush/utils.py @@ -7,7 +7,6 @@ import time import uuid -import ecdsa import requests from attr import ( Factory, @@ -16,7 +15,6 @@ ) from boto.dynamodb2.items import Item # noqa from cryptography.fernet import Fernet # noqa -from jose import jwt from typing import ( # noqa Any, Dict, @@ -27,6 +25,7 @@ from ua_parser import user_agent_parser from autopush.exceptions import InvalidTokenException +from autopush.jwt import repad, VerifyJWT as jwt # Remove trailing padding characters from complex header items like @@ -122,14 +121,6 @@ def base64url_encode(string): return base64.urlsafe_b64encode(string).strip('=') -def repad(string): - # type: (str) -> str - """Adds padding to strings for base64 decoding""" - if len(string) % 4: - string += '===='[len(string) % 4:] - return string - - def base64url_decode(string): # type: (str) -> str """Decodes a Base64 URL-encoded string per RFC 7515. @@ -171,11 +162,11 @@ def decipher_public_key(key_data): # key data is actually a raw coordinate pair key_data = base64url_decode(key_data) key_len = len(key_data) - if key_len == 64: + if key_len == 65 and key_data[0] == '\x04': return key_data # Key format is "raw" - if key_len == 65 and key_data[0] == '\x04': - return key_data[-64:] + if key_len == 64: + return '\04' + key_data # key format is "spki" if key_len == 88 and key_data[:3] == '0V0': return key_data[-64:] @@ -188,18 +179,7 @@ def extract_jwt(token, crypto_key): # first split and convert the jwt. if not token or not crypto_key: return {} - - key = decipher_public_key(crypto_key) - vk = ecdsa.VerifyingKey.from_string(key, curve=ecdsa.NIST256p) - # jose offers jwt.decode(token, vk, ...) which does a full check - # on the JWT object. Vapid is a bit more creative in how it - # stores data into a JWT and breaks expectations. We would have to - # turn off most of the validation in order for it to be useful. - return jwt.decode(token, dict(keys=[vk]), options=dict( - verify_aud=False, - verify_sub=False, - verify_exp=False, - )) + return jwt.decode(token, decipher_public_key(crypto_key)) def parse_user_agent(agent_string): diff --git a/autopush/web/webpush.py b/autopush/web/webpush.py index 090318bb..2cb6d4ba 100644 --- a/autopush/web/webpush.py +++ b/autopush/web/webpush.py @@ -3,7 +3,7 @@ from boto.dynamodb2.exceptions import ItemNotFound from cryptography.fernet import InvalidToken -from jose import JOSEError +from cryptography.exceptions import InvalidSignature from marshmallow import ( Schema, fields, @@ -321,7 +321,7 @@ def validate_auth(self, d): try: jwt = extract_jwt(token, public_key) - except (AssertionError, ValueError, JOSEError): + except (ValueError, InvalidSignature, Exception): raise InvalidRequest("Invalid Authorization Header", status_code=401, errno=109, headers={"www-authenticate": PREF_SCHEME}) diff --git a/docs/api.rst b/docs/api.rst index 271481a9..1051a208 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -31,3 +31,4 @@ documentation is organized alphabetically by module name. api/ssl api/utils api/websocket + api/jwt diff --git a/docs/api/jwt.rst b/docs/api/jwt.rst new file mode 100644 index 00000000..6da6c6ec --- /dev/null +++ b/docs/api/jwt.rst @@ -0,0 +1,10 @@ +.. _jwt_module: + +:mod:`autopush.jwt` +------------------------ + +.. automodule:: autopush.jwt + +.. autoclass:: VerifyJWT + :members: + :member-order: bysource diff --git a/requirements.txt b/requirements.txt index 77c90a21..600d605c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,6 @@ cryptography==1.7.2 cyclone==1.1 datadog==0.13.0 decorator==4.0.10 -ecdsa==0.13 enum34==1.1.6 future==0.15.2 futures==3.0.5 @@ -32,7 +31,6 @@ pycparser==2.17 pycrypto==2.6.1 pyfcm==1.0.4 python-dateutil==2.5.3 -python-jose==1.3.2 raven==5.25.0 requests==2.11.0 service-identity==16.0.0 diff --git a/test-requirements.txt b/test-requirements.txt index c546e86a..924ec702 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,6 +1,8 @@ -r requirements.txt nose coverage +ecdsa==0.13 +python-jose==1.3.2 mock>=1.0.1 funcsigs==1.0.2 pbr==1.10.0