Skip to content
This repository has been archived by the owner on Jul 13, 2023. It is now read-only.

Commit

Permalink
feat: Use cryptography based JWT parser for increased speed
Browse files Browse the repository at this point in the history
Switches to using python cryptography library instead of jose/jwt
(which relied on python ecdsa library). Also discovered lots of fun
discrepancies between how ecdsa and libssl sign things.

closes #785
  • Loading branch information
jrconlin committed Mar 28, 2017
1 parent 4128d94 commit 333df85
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 17 deletions.
98 changes: 98 additions & 0 deletions autopush/jwt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import base64
import binascii
import json
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives import hashes

from pyasn1.type import univ, namedtype
from pyasn1.codec.der.encoder import encode


"""Why hand roll?
Most Python JWT libraries either use a Python elliptic curve library directly,
call one that does, are abandoned, or have a dozen other problems.
After spending half a day looking for reasonable replacements, I decided to
just write the functions we need directly.
THIS IS NOT A FULL JWT REPLACEMENT.
"""


def repad(string):
    # type: (str) -> str
    """Add the '=' padding that base64 decoding requires.

    URL-safe base64 values in JWTs are normally transmitted with their
    trailing padding stripped; the decoder wants the length to be a
    multiple of four again.

    :param string: base64 text, either a text string or a byte string.
    :returns: ``string`` extended with '=' characters up to the next
        multiple of four, in the same type as the input. Input whose
        length is already a multiple of four is returned unchanged.

    """
    missing = -len(string) % 4
    if not missing:
        return string
    # Pick a pad of the same type as the input so that byte strings can be
    # padded on Python 3, where bytes and text cannot be concatenated.
    pad = b"=" if isinstance(string, bytes) else "="
    return string + pad * missing


class VerifyJWT(object):
    """Minimal JWT signature verifier for ECDSA P-256 / SHA-256 tokens.

    Only the pieces needed to validate a token against a raw public key are
    implemented; this is not a general purpose JWT library. All methods are
    class methods, so the class is used without being instantiated.

    """

    def __init__(self):  # pragma: no cover
        pass

    @classmethod
    def raw_sig_to_der(cls, auth):
        # type: (str) -> str
        """Return the token with its signature in DER form.

        JWT (and the python `ecdsa` library) carry an ECDSA signature as the
        raw concatenation ``r || s`` of two 32 octet big-endian integers.
        Cryptography, which wraps OpenSSL, expects a DER encoded ASN.1
        SEQUENCE holding the same two INTEGERs, in the same (r, s) order.

        :param auth: the complete dotted token; the signature is whatever
            follows the final ".".
        :returns: the token with a 64 octet raw signature replaced by its
            DER encoding. A signature of any other length is assumed to be
            DER already and the token is returned unchanged.
        :raises ValueError: if the token contains no "." separator.

        """
        payload, asig = auth.rsplit(".", 1)
        sig = base64.urlsafe_b64decode(repad(asig).encode('utf8'))
        if len(sig) != 64:
            return auth

        # The first 32 octets are "r" and the last 32 are "s". The DER
        # SEQUENCE has no field names, so it is built from an empty
        # NamedTypes() and filled positionally: r first, then s.
        ds = univ.Sequence(
            componentType=namedtype.NamedTypes()
        ).setComponents(
            univ.Integer(int(binascii.hexlify(sig[:32]), 16)),
            univ.Integer(int(binascii.hexlify(sig[32:]), 16)),
        )
        encoded = encode(ds)
        new_sig = base64.urlsafe_b64encode(encoded)
        # On Python 3 the encoder hands back bytes; formatting those directly
        # would embed a "b'...'" repr in the token instead of the signature.
        if not isinstance(new_sig, str):
            new_sig = new_sig.decode('ascii')
        return "{}.{}".format(payload, new_sig)

    @classmethod
    def decode(cls, token, key=None, *args, **kwargs):
        # type: (...) -> dict
        """Verify a web token and return its assertion dictionary.

        Both raw (r || s) and DER encoded signatures are accepted. We use
        the cryptography library since it wraps libssl and is faster than
        the python only approach.

        :param token: the dotted ``header.claims.signature`` token.
        :param key: the public key as an encoded elliptic curve point on
            SECP256R1. Although it defaults to None, a key is required for
            verification.
        :param args: accepted and ignored.
        :param kwargs: accepted and ignored.
        :returns: the JSON decoded claims (second) segment of the token.
        :raises InvalidSignature: if the signature does not match. This is
            expected to be captured by the caller.

        """
        # convert the signature if needed.
        token = cls.raw_sig_to_der(token)
        sig_material, signature = token.rsplit('.', 1)
        pkey = ec.EllipticCurvePublicNumbers.from_encoded_point(
            ec.SECP256R1(),
            key
        ).public_key(default_backend())
        # eventually this would be a map of algs to signing algorithms.

        # Pad first, then encode, so the padding helper always receives the
        # same string type as the token it came from.
        # NOTE(review): newer cryptography releases deprecate verifier() in
        # favour of public_key.verify(signature, data, algorithm); confirm
        # the pinned version before switching.
        verifier = pkey.verifier(
            base64.urlsafe_b64decode(repad(signature).encode()),
            ec.ECDSA(hashes.SHA256()))
        # The signed material is everything before the final ".", i.e. the
        # encoded header and claims exactly as they were transmitted.
        verifier.update(sig_material.encode())
        # This will raise an InvalidSignature exception if failure.
        # It will be captured externally.
        verifier.verify()
        return json.loads(
            base64.urlsafe_b64decode(repad(sig_material.split('.')[1])))
22 changes: 5 additions & 17 deletions autopush/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import time
import uuid

import ecdsa
import requests
from attr import (
Factory,
Expand All @@ -16,7 +15,6 @@
)
from boto.dynamodb2.items import Item # noqa
from cryptography.fernet import Fernet # noqa
from jose import jwt
from typing import ( # noqa
Any,
Dict,
Expand All @@ -27,6 +25,7 @@
from ua_parser import user_agent_parser

from autopush.exceptions import InvalidTokenException
from autopush.jwt import VerifyJWT as jwt


# Remove trailing padding characters from complex header items like
Expand Down Expand Up @@ -171,11 +170,11 @@ def decipher_public_key(key_data):
# key data is actually a raw coordinate pair
key_data = base64url_decode(key_data)
key_len = len(key_data)
if key_len == 64:
if key_len == 65 and key_data[0] == '\x04':
return key_data
# Key format is "raw"
if key_len == 65 and key_data[0] == '\x04':
return key_data[-64:]
if key_len == 64:
return '\04' + key_data
# key format is "spki"
if key_len == 88 and key_data[:3] == '0V0':
return key_data[-64:]
Expand All @@ -188,18 +187,7 @@ def extract_jwt(token, crypto_key):
# first split and convert the jwt.
if not token or not crypto_key:
return {}

key = decipher_public_key(crypto_key)
vk = ecdsa.VerifyingKey.from_string(key, curve=ecdsa.NIST256p)
# jose offers jwt.decode(token, vk, ...) which does a full check
# on the JWT object. Vapid is a bit more creative in how it
# stores data into a JWT and breaks expectations. We would have to
# turn off most of the validation in order for it to be useful.
return jwt.decode(token, dict(keys=[vk]), options=dict(
verify_aud=False,
verify_sub=False,
verify_exp=False,
))
return jwt.decode(token, decipher_public_key(crypto_key))


def parse_user_agent(agent_string):
Expand Down

0 comments on commit 333df85

Please sign in to comment.