Skip to content
This repository has been archived by the owner on Jul 13, 2023. It is now read-only.

Commit

Permalink
Merge pull request #854 from mozilla-services/feat/785
Browse files Browse the repository at this point in the history
feat: Use cryptography based JWT parser for increased speed
  • Loading branch information
bbangert authored Apr 4, 2017
2 parents 76f772e + fe9b776 commit a03910d
Show file tree
Hide file tree
Showing 9 changed files with 130 additions and 55 deletions.
107 changes: 107 additions & 0 deletions autopush/jwt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import base64
import binascii
import json

from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import ec, utils
from cryptography.hazmat.primitives import hashes
from pyasn1.error import PyAsn1Error
from twisted.logger import Logger


def repad(string):
# type: (str) -> str
"""Adds padding to strings for base64 decoding"""
if len(string) % 4:
string += '===='[len(string) % 4:]
return string


class VerifyJWT(object):
"""Minimally verify a Vapid JWT object.
Why hand roll? Most python JWT libraries either use a python elliptic
curve library directly, or call one that does, or is abandoned, or a
dozen other reasons.
After spending half a day looking for reasonable replacements, I
decided to just write the functions we need directly.
THIS IS NOT A FULL JWT REPLACEMENT.
"""

@staticmethod
def extract_signature(auth):
# type: (str) -> tuple()
"""Fix the JWT auth token.
The JWA spec defines the signature to be a pair of 32octet encoded
longs.
The `ecdsa` library signs using a raw, 32octet pair of values (s, r).
Cryptography, which uses OpenSSL, uses a DER sequence of (s, r).
This function converts the raw ecdsa to DER.
:param auth: A JWT authorization token.
:type auth: str
:return tuple containing the signature material and signature
"""
payload, asig = auth.encode('utf8').rsplit(".", 1)
sig = base64.urlsafe_b64decode(repad(asig))
if len(sig) != 64:
return payload, sig

encoded = utils.encode_dss_signature(
s=int(binascii.hexlify(sig[32:]), 16),
r=int(binascii.hexlify(sig[:32]), 16)
)
return payload, encoded

@staticmethod
def decode(token, key):
# type (str, str) -> dict()
"""Decode a web token into a assertion dictionary.
This attempts to rectify both ecdsa and openssl generated
signatures. We use the built-in cryptography library since it wraps
libssl and is faster than the python only approach.
:param token: VAPID auth token
:type token: str
:param key: bitarray containing public key
:type key: str or bitarray
:return dict of the VAPID claims
:raise InvalidSignature
"""
# convert the signature if needed.
try:
sig_material, signature = VerifyJWT.extract_signature(token)
pkey = ec.EllipticCurvePublicNumbers.from_encoded_point(
ec.SECP256R1(),
key
).public_key(default_backend())
# NOTE: verify() will take any string as the signature. It appears
# to be doing lazy verification and matching strings rather than
# comparing content values. If the signatures start failing for
# some unknown reason in the future, decode the signature and
# make sure it matches how we're reconstructing it.
# This will raise an InvalidSignature exception if failure.
# It will be captured externally.
pkey.verify(
signature,
sig_material.encode('utf8'),
ec.ECDSA(hashes.SHA256()))
return json.loads(
base64.urlsafe_b64decode(
repad(sig_material.split('.')[1]).encode('utf8')))
except (ValueError, TypeError, binascii.Error, PyAsn1Error):
raise InvalidSignature()
except Exception: # pragma: no cover
Logger().failure("Unexpected error processing JWT")
raise InvalidSignature()
25 changes: 1 addition & 24 deletions autopush/tests/test_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import json
import uuid

import jose
import twisted.internet.base
from cryptography.fernet import Fernet, InvalidToken
from cyclone.web import Application
from mock import Mock, patch
from mock import Mock
from moto import mock_dynamodb2
from nose.tools import eq_, ok_
from twisted.internet.defer import Deferred
Expand Down Expand Up @@ -538,28 +537,6 @@ def restore(*args, **kwargs):
chid=str(dummy_chid)))
return self.finish_deferred

@patch('jose.jws.verify', side_effect=jose.exceptions.JWTError)
def test_post_bad_jwt(self, *args):
self.reg.request.body = json.dumps(dict(
channelID=str(dummy_chid),
))

def handle_finish(value):
self._check_error(401, 109, 'Unauthorized')

def restore(*args, **kwargs):
uuid.uuid4 = old_func

old_func = uuid.uuid4
uuid.uuid4 = lambda: dummy_uaid
self.finish_deferred.addBoth(restore)
self.finish_deferred.addCallback(handle_finish)
self.reg.request.headers["Authorization"] = "WebPush Dummy"
self.reg.post(self._make_req(router_type="webpush",
uaid=dummy_uaid.hex,
chid=str(dummy_chid)))
return self.finish_deferred

def test_post_uaid_chid(self, *args):
self.reg.request.body = json.dumps(dict(
type="simplepush",
Expand Down
4 changes: 2 additions & 2 deletions autopush/tests/test_web_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
ItemNotFound,
)
from cryptography.fernet import InvalidToken
from cryptography.exceptions import InvalidSignature
from jose import jws
from jose.exceptions import JWTClaimsError
from marshmallow import Schema, fields
from mock import Mock, patch
from moto import mock_dynamodb2
Expand Down Expand Up @@ -1005,7 +1005,7 @@ def test_invalid_encryption_header(self, mock_jwt):
def test_invalid_encryption_jwt(self, mock_jwt):
schema = self._make_fut()
# use a deeply superclassed error to make sure that it gets picked up.
mock_jwt.side_effect = JWTClaimsError("invalid claim")
mock_jwt.side_effect = InvalidSignature("invalid signature")

header = {"typ": "JWT", "alg": "ES256"}
payload = {"aud": "https://push.example.com",
Expand Down
30 changes: 5 additions & 25 deletions autopush/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import time
import uuid

import ecdsa
import requests
from attr import (
Factory,
Expand All @@ -16,7 +15,6 @@
)
from boto.dynamodb2.items import Item # noqa
from cryptography.fernet import Fernet # noqa
from jose import jwt
from typing import ( # noqa
Any,
Dict,
Expand All @@ -27,6 +25,7 @@
from ua_parser import user_agent_parser

from autopush.exceptions import InvalidTokenException
from autopush.jwt import repad, VerifyJWT as jwt


# Remove trailing padding characters from complex header items like
Expand Down Expand Up @@ -122,14 +121,6 @@ def base64url_encode(string):
return base64.urlsafe_b64encode(string).strip('=')


def repad(string):
# type: (str) -> str
"""Adds padding to strings for base64 decoding"""
if len(string) % 4:
string += '===='[len(string) % 4:]
return string


def base64url_decode(string):
# type: (str) -> str
"""Decodes a Base64 URL-encoded string per RFC 7515.
Expand Down Expand Up @@ -171,11 +162,11 @@ def decipher_public_key(key_data):
# key data is actually a raw coordinate pair
key_data = base64url_decode(key_data)
key_len = len(key_data)
if key_len == 64:
if key_len == 65 and key_data[0] == '\x04':
return key_data
# Key format is "raw"
if key_len == 65 and key_data[0] == '\x04':
return key_data[-64:]
if key_len == 64:
return '\04' + key_data
# key format is "spki"
if key_len == 88 and key_data[:3] == '0V0':
return key_data[-64:]
Expand All @@ -188,18 +179,7 @@ def extract_jwt(token, crypto_key):
# first split and convert the jwt.
if not token or not crypto_key:
return {}

key = decipher_public_key(crypto_key)
vk = ecdsa.VerifyingKey.from_string(key, curve=ecdsa.NIST256p)
# jose offers jwt.decode(token, vk, ...) which does a full check
# on the JWT object. Vapid is a bit more creative in how it
# stores data into a JWT and breaks expectations. We would have to
# turn off most of the validation in order for it to be useful.
return jwt.decode(token, dict(keys=[vk]), options=dict(
verify_aud=False,
verify_sub=False,
verify_exp=False,
))
return jwt.decode(token, decipher_public_key(crypto_key))


def parse_user_agent(agent_string):
Expand Down
4 changes: 2 additions & 2 deletions autopush/web/webpush.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from boto.dynamodb2.exceptions import ItemNotFound
from cryptography.fernet import InvalidToken
from jose import JOSEError
from cryptography.exceptions import InvalidSignature
from marshmallow import (
Schema,
fields,
Expand Down Expand Up @@ -321,7 +321,7 @@ def validate_auth(self, d):

try:
jwt = extract_jwt(token, public_key)
except (AssertionError, ValueError, JOSEError):
except (ValueError, InvalidSignature, Exception):
raise InvalidRequest("Invalid Authorization Header",
status_code=401, errno=109,
headers={"www-authenticate": PREF_SCHEME})
Expand Down
1 change: 1 addition & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@ documentation is organized alphabetically by module name.
api/ssl
api/utils
api/websocket
api/jwt
10 changes: 10 additions & 0 deletions docs/api/jwt.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
.. _jwt_module:

:mod:`autopush.jwt`
------------------------

.. automodule:: autopush.jwt

.. autoclass:: VerifyJWT
:members:
:member-order: bysource
2 changes: 0 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ cryptography==1.7.2
cyclone==1.1
datadog==0.13.0
decorator==4.0.10
ecdsa==0.13
enum34==1.1.6
future==0.15.2
futures==3.0.5
Expand All @@ -32,7 +31,6 @@ pycparser==2.17
pycrypto==2.6.1
pyfcm==1.0.4
python-dateutil==2.5.3
python-jose==1.3.2
raven==5.25.0
requests==2.11.0
service-identity==16.0.0
Expand Down
2 changes: 2 additions & 0 deletions test-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
-r requirements.txt
nose
coverage
ecdsa==0.13
python-jose==1.3.2
mock>=1.0.1
funcsigs==1.0.2
pbr==1.10.0
Expand Down

0 comments on commit a03910d

Please sign in to comment.