From 333df8559099cfaee9197fc2387e02a12f7184ac Mon Sep 17 00:00:00 2001 From: jrconlin Date: Tue, 28 Mar 2017 15:31:48 -0700 Subject: [PATCH] feat: Use cryptography based JWT parser for increased speed Switches to using python cryptography library instead of jose/jwt (which relied on python ecdsa library). Also discovered lots of fun discrepancies between how ecdsa and libssl sign things. closes #785 --- autopush/jwt.py | 98 +++++++++++++++++++++++++++++++++++++++++++++++ autopush/utils.py | 22 +++-------- 2 files changed, 103 insertions(+), 17 deletions(-) create mode 100644 autopush/jwt.py diff --git a/autopush/jwt.py b/autopush/jwt.py new file mode 100644 index 00000000..2f9aed87 --- /dev/null +++ b/autopush/jwt.py @@ -0,0 +1,98 @@ +import base64 +import binascii +import json +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives import hashes + +from pyasn1.type import univ, namedtype +from pyasn1.codec.der.encoder import encode + + +"""Why hand roll? +Most python JWT libraries either use a python elliptic curve library directly, +or call one that does, or are abandoned, or a dozen other reasons. + +After spending half a day looking for reasonable replacements, I decided to +just write the functions we need directly. + +THIS IS NOT A FULL JWT REPLACEMENT. + +""" + + +def repad(string): + # type: (str) -> str + """Adds padding to strings for base64 decoding""" + if len(string) % 4: + string += '===='[len(string) % 4:] + return string + + +class VerifyJWT: + + def __init__(self): # pragma: no cover + pass + + @classmethod + def raw_sig_to_der(cls, auth): + # type: (cls, str) -> str + """Fix the JWT auth token. + + The `ecdsa` library signs using a raw, 32-octet pair of values (r,s). + Cryptography, which uses OpenSSL, uses a DER sequence of (s, r). + This function converts the raw ecdsa to DER. 
+ + """ + payload, asig = auth.rsplit(".", 1) + sig = base64.urlsafe_b64decode(repad(asig).encode('utf8')) + if len(sig) != 64: + return auth + + # ecdsa and openssl transpose the "r" and "s" values of the signatures. + # for ecdsa signature is (r, s) + # for openssl, signature is (s, r) + # It's ok, though, because neither labels them, even though openssl + # uses namedtypes, with the names set to "" + # + # Oh frabjous day! + ds = univ.Sequence( + componentType=namedtype.NamedTypes() + ).setComponents( + univ.Integer(int(binascii.hexlify(sig[:32]), 16)), + univ.Integer(int(binascii.hexlify(sig[32:]), 16)), + ) + encoded = encode(ds) + new_sig = base64.urlsafe_b64encode(encoded) + return "{}.{}".format(payload, new_sig) + + @classmethod + def decode(cls, token, key=None, *args, **kwargs): + # type (cls, str, str) -> dict() + """Decode a web token into an assertion dictionary. + + This attempts to rectify both ecdsa and openssl generated + signatures. We use the built-in cryptography library since it wraps + libssl and is faster than the python only approach. + + This raises an InvalidSignature exception if the signature fails. + + """ + # convert the signature if needed. + token = cls.raw_sig_to_der(token) + sig_material, signature = token.rsplit('.', 1) + pkey = ec.EllipticCurvePublicNumbers.from_encoded_point( + ec.SECP256R1(), + key + ).public_key(default_backend()) + # eventually this would be a map of algs to signing algorithms. + + verifier = pkey.verifier( + base64.urlsafe_b64decode(repad(signature.encode())), + ec.ECDSA(hashes.SHA256())) + verifier.update(sig_material.encode()) + # This will raise an InvalidSignature exception on failure. + # It will be captured externally. 
+ verifier.verify() + return json.loads( + base64.urlsafe_b64decode(repad(sig_material.split('.')[1]))) diff --git a/autopush/utils.py b/autopush/utils.py index c808b6ca..e2dc4cd5 100644 --- a/autopush/utils.py +++ b/autopush/utils.py @@ -7,7 +7,6 @@ import time import uuid -import ecdsa import requests from attr import ( Factory, @@ -16,7 +15,6 @@ ) from boto.dynamodb2.items import Item # noqa from cryptography.fernet import Fernet # noqa -from jose import jwt from typing import ( # noqa Any, Dict, @@ -27,6 +25,7 @@ from ua_parser import user_agent_parser from autopush.exceptions import InvalidTokenException +from autopush.jwt import VerifyJWT as jwt # Remove trailing padding characters from complex header items like @@ -171,11 +170,11 @@ def decipher_public_key(key_data): # key data is actually a raw coordinate pair key_data = base64url_decode(key_data) key_len = len(key_data) - if key_len == 64: + if key_len == 65 and key_data[0] == '\x04': return key_data # Key format is "raw" - if key_len == 65 and key_data[0] == '\x04': - return key_data[-64:] + if key_len == 64: + return '\04' + key_data # key format is "spki" if key_len == 88 and key_data[:3] == '0V0': return key_data[-64:] @@ -188,18 +187,7 @@ def extract_jwt(token, crypto_key): # first split and convert the jwt. if not token or not crypto_key: return {} - - key = decipher_public_key(crypto_key) - vk = ecdsa.VerifyingKey.from_string(key, curve=ecdsa.NIST256p) - # jose offers jwt.decode(token, vk, ...) which does a full check - # on the JWT object. Vapid is a bit more creative in how it - # stores data into a JWT and breaks expectations. We would have to - # turn off most of the validation in order for it to be useful. - return jwt.decode(token, dict(keys=[vk]), options=dict( - verify_aud=False, - verify_sub=False, - verify_exp=False, - )) + return jwt.decode(token, decipher_public_key(crypto_key)) def parse_user_agent(agent_string):