Skip to content
This repository has been archived by the owner on Jul 13, 2023. It is now read-only.

feat: Use cryptography based JWT parser for increased speed #854

Merged
merged 1 commit into from
Apr 4, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions autopush/jwt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import base64
import binascii
import json

from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import ec, utils
from cryptography.hazmat.primitives import hashes
from pyasn1.error import PyAsn1Error
from twisted.logger import Logger


def repad(string):
# type: (str) -> str
"""Adds padding to strings for base64 decoding"""
if len(string) % 4:
string += '===='[len(string) % 4:]
return string


class VerifyJWT(object):
"""Minimally verify a Vapid JWT object.

Why hand roll? Most python JWT libraries either use a python elliptic
curve library directly, or call one that does, or is abandoned, or a
dozen other reasons.

After spending half a day looking for reasonable replacements, I
decided to just write the functions we need directly.

THIS IS NOT A FULL JWT REPLACEMENT.

"""

@staticmethod
def extract_signature(auth):
# type: (str) -> tuple()
"""Fix the JWT auth token.

The JWA spec defines the signature to be a pair of 32octet encoded
longs.
The `ecdsa` library signs using a raw, 32octet pair of values (s, r).
Cryptography, which uses OpenSSL, uses a DER sequence of (s, r).
This function converts the raw ecdsa to DER.

:param auth: A JWT authorization token.
:type auth: str

:return tuple containing the signature material and signature

"""
payload, asig = auth.encode('utf8').rsplit(".", 1)
sig = base64.urlsafe_b64decode(repad(asig))
if len(sig) != 64:
return payload, sig

encoded = utils.encode_dss_signature(
s=int(binascii.hexlify(sig[32:]), 16),
r=int(binascii.hexlify(sig[:32]), 16)
)
return payload, encoded

@staticmethod
def decode(token, key):
# type (str, str) -> dict()
"""Decode a web token into a assertion dictionary.

This attempts to rectify both ecdsa and openssl generated
signatures. We use the built-in cryptography library since it wraps
libssl and is faster than the python only approach.

:param token: VAPID auth token
:type token: str
:param key: bitarray containing public key
:type key: str or bitarray

:return dict of the VAPID claims

:raise InvalidSignature

"""
# convert the signature if needed.
try:
sig_material, signature = VerifyJWT.extract_signature(token)
pkey = ec.EllipticCurvePublicNumbers.from_encoded_point(
ec.SECP256R1(),
key
).public_key(default_backend())
# NOTE: verify() will take any string as the signature. It appears
# to be doing lazy verification and matching strings rather than
# comparing content values. If the signatures start failing for
# some unknown reason in the future, decode the signature and
# make sure it matches how we're reconstructing it.
# This will raise an InvalidSignature exception if failure.
# It will be captured externally.
pkey.verify(
signature,
sig_material.encode('utf8'),
ec.ECDSA(hashes.SHA256()))
return json.loads(
base64.urlsafe_b64decode(
repad(sig_material.split('.')[1]).encode('utf8')))
except (ValueError, TypeError, binascii.Error, PyAsn1Error):
raise InvalidSignature()
except Exception: # pragma: no cover
Logger().failure("Unexpected error processing JWT")
raise InvalidSignature()
25 changes: 1 addition & 24 deletions autopush/tests/test_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import json
import uuid

import jose
import twisted.internet.base
from cryptography.fernet import Fernet, InvalidToken
from cyclone.web import Application
from mock import Mock, patch
from mock import Mock
from moto import mock_dynamodb2
from nose.tools import eq_, ok_
from twisted.internet.defer import Deferred
Expand Down Expand Up @@ -538,28 +537,6 @@ def restore(*args, **kwargs):
chid=str(dummy_chid)))
return self.finish_deferred

@patch('jose.jws.verify', side_effect=jose.exceptions.JWTError)
def test_post_bad_jwt(self, *args):
self.reg.request.body = json.dumps(dict(
channelID=str(dummy_chid),
))

def handle_finish(value):
self._check_error(401, 109, 'Unauthorized')

def restore(*args, **kwargs):
uuid.uuid4 = old_func

old_func = uuid.uuid4
uuid.uuid4 = lambda: dummy_uaid
self.finish_deferred.addBoth(restore)
self.finish_deferred.addCallback(handle_finish)
self.reg.request.headers["Authorization"] = "WebPush Dummy"
self.reg.post(self._make_req(router_type="webpush",
uaid=dummy_uaid.hex,
chid=str(dummy_chid)))
return self.finish_deferred

def test_post_uaid_chid(self, *args):
self.reg.request.body = json.dumps(dict(
type="simplepush",
Expand Down
4 changes: 2 additions & 2 deletions autopush/tests/test_web_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
ItemNotFound,
)
from cryptography.fernet import InvalidToken
from cryptography.exceptions import InvalidSignature
from jose import jws
from jose.exceptions import JWTClaimsError
from marshmallow import Schema, fields
from mock import Mock, patch
from moto import mock_dynamodb2
Expand Down Expand Up @@ -1005,7 +1005,7 @@ def test_invalid_encryption_header(self, mock_jwt):
def test_invalid_encryption_jwt(self, mock_jwt):
schema = self._make_fut()
# use a deeply superclassed error to make sure that it gets picked up.
mock_jwt.side_effect = JWTClaimsError("invalid claim")
mock_jwt.side_effect = InvalidSignature("invalid signature")

header = {"typ": "JWT", "alg": "ES256"}
payload = {"aud": "https://push.example.com",
Expand Down
30 changes: 5 additions & 25 deletions autopush/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import time
import uuid

import ecdsa
import requests
from attr import (
Factory,
Expand All @@ -16,7 +15,6 @@
)
from boto.dynamodb2.items import Item # noqa
from cryptography.fernet import Fernet # noqa
from jose import jwt
from typing import ( # noqa
Any,
Dict,
Expand All @@ -27,6 +25,7 @@
from ua_parser import user_agent_parser

from autopush.exceptions import InvalidTokenException
from autopush.jwt import repad, VerifyJWT as jwt


# Remove trailing padding characters from complex header items like
Expand Down Expand Up @@ -122,14 +121,6 @@ def base64url_encode(string):
return base64.urlsafe_b64encode(string).strip('=')


def repad(string):
# type: (str) -> str
"""Adds padding to strings for base64 decoding"""
if len(string) % 4:
string += '===='[len(string) % 4:]
return string


def base64url_decode(string):
# type: (str) -> str
"""Decodes a Base64 URL-encoded string per RFC 7515.
Expand Down Expand Up @@ -171,11 +162,11 @@ def decipher_public_key(key_data):
# key data is actually a raw coordinate pair
key_data = base64url_decode(key_data)
key_len = len(key_data)
if key_len == 64:
if key_len == 65 and key_data[0] == '\x04':
return key_data
# Key format is "raw"
if key_len == 65 and key_data[0] == '\x04':
return key_data[-64:]
if key_len == 64:
return '\04' + key_data
# key format is "spki"
if key_len == 88 and key_data[:3] == '0V0':
return key_data[-64:]
Expand All @@ -188,18 +179,7 @@ def extract_jwt(token, crypto_key):
# first split and convert the jwt.
if not token or not crypto_key:
return {}

key = decipher_public_key(crypto_key)
vk = ecdsa.VerifyingKey.from_string(key, curve=ecdsa.NIST256p)
# jose offers jwt.decode(token, vk, ...) which does a full check
# on the JWT object. Vapid is a bit more creative in how it
# stores data into a JWT and breaks expectations. We would have to
# turn off most of the validation in order for it to be useful.
return jwt.decode(token, dict(keys=[vk]), options=dict(
verify_aud=False,
verify_sub=False,
verify_exp=False,
))
return jwt.decode(token, decipher_public_key(crypto_key))


def parse_user_agent(agent_string):
Expand Down
4 changes: 2 additions & 2 deletions autopush/web/webpush.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from boto.dynamodb2.exceptions import ItemNotFound
from cryptography.fernet import InvalidToken
from jose import JOSEError
from cryptography.exceptions import InvalidSignature
from marshmallow import (
Schema,
fields,
Expand Down Expand Up @@ -321,7 +321,7 @@ def validate_auth(self, d):

try:
jwt = extract_jwt(token, public_key)
except (AssertionError, ValueError, JOSEError):
except (ValueError, InvalidSignature, Exception):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exception here definitely not needed

raise InvalidRequest("Invalid Authorization Header",
status_code=401, errno=109,
headers={"www-authenticate": PREF_SCHEME})
Expand Down
1 change: 1 addition & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@ documentation is organized alphabetically by module name.
api/ssl
api/utils
api/websocket
api/jwt
10 changes: 10 additions & 0 deletions docs/api/jwt.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
.. _jwt_module:

:mod:`autopush.jwt`
------------------------

.. automodule:: autopush.jwt

.. autoclass:: VerifyJWT
:members:
:member-order: bysource
2 changes: 0 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ cryptography==1.7.2
cyclone==1.1
datadog==0.13.0
decorator==4.0.10
ecdsa==0.13
enum34==1.1.6
future==0.15.2
futures==3.0.5
Expand All @@ -32,7 +31,6 @@ pycparser==2.17
pycrypto==2.6.1
pyfcm==1.0.4
python-dateutil==2.5.3
python-jose==1.3.2
raven==5.25.0
requests==2.11.0
service-identity==16.0.0
Expand Down
2 changes: 2 additions & 0 deletions test-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
-r requirements.txt
nose
coverage
ecdsa==0.13
python-jose==1.3.2
mock>=1.0.1
funcsigs==1.0.2
pbr==1.10.0
Expand Down