Skip to content

Commit

Permalink
Support raw json web key set
Browse files Browse the repository at this point in the history
  • Loading branch information
coopfeathy committed Oct 14, 2020
1 parent 7422150 commit 058986e
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 117 deletions.
3 changes: 1 addition & 2 deletions authlib/jose/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .rfc7516 import (
JsonWebEncryption, JWEAlgorithm, JWEEncAlgorithm, JWEZipAlgorithm,
)
from .rfc7517 import Key, KeySet
from .rfc7517 import Key, KeySet, JsonWebKey
from .rfc7518 import (
register_jws_rfc7518,
register_jwe_rfc7518,
Expand All @@ -25,7 +25,6 @@
from .drafts import register_jwe_draft

from .errors import JoseError
from .jwk import JsonWebKey

# register algorithms
register_jws_rfc7518()
Expand Down
74 changes: 1 addition & 73 deletions authlib/jose/jwk.py
Original file line number Diff line number Diff line change
@@ -1,76 +1,4 @@
from authlib.common.encoding import text_types, json_loads
from .rfc7517 import KeySet
from .rfc7518 import (
OctKey,
RSAKey,
ECKey,
load_pem_key,
)
from .rfc8037 import OKPKey


class JsonWebKey(object):
JWK_KEY_CLS = {
OctKey.kty: OctKey,
RSAKey.kty: RSAKey,
ECKey.kty: ECKey,
OKPKey.kty: OKPKey,
}

@classmethod
def generate_key(cls, kty, crv_or_size, options=None, is_private=False):
"""Generate a Key with the given key type, curve name or bit size.
:param kty: string of ``oct``, ``RSA``, ``EC``, ``OKP``
:param crv_or_size: curve name or bit size
:param options: a dict of other options for Key
:param is_private: create a private key or public key
:return: Key instance
"""
key_cls = cls.JWK_KEY_CLS[kty]
return key_cls.generate_key(crv_or_size, options, is_private)

@classmethod
def import_key(cls, raw, options=None):
"""Import a Key from bytes, string, PEM or dict.
:return: Key instance
"""
kty = None
if options is not None:
kty = options.get('kty')

if kty is None and isinstance(raw, dict):
kty = raw.get('kty')

if kty is None:
raw_key = load_pem_key(raw)
for _kty in cls.JWK_KEY_CLS:
key_cls = cls.JWK_KEY_CLS[_kty]
if isinstance(raw_key, key_cls.RAW_KEY_CLS):
return key_cls.import_key(raw_key, options)

key_cls = cls.JWK_KEY_CLS[kty]
return key_cls.import_key(raw, options)

@classmethod
def import_key_set(cls, raw):
"""Import KeySet from string, dict or a list of keys.
:return: KeySet instance
"""
if isinstance(raw, text_types) and \
raw.startswith('{') and raw.endswith('}'):
raw = json_loads(raw)
keys = raw.get('keys')
elif isinstance(raw, dict) and 'keys' in raw:
keys = raw.get('keys')
elif isinstance(raw, (tuple, list)):
keys = raw
else:
return None

return KeySet([cls.import_key(k) for k in keys])
from .rfc7517 import JsonWebKey


def loads(obj, kid=None):
Expand Down
4 changes: 3 additions & 1 deletion authlib/jose/rfc7517/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
https://tools.ietf.org/html/rfc7517
"""
from .models import Key, KeySet
from ._cryptography_key import load_pem_key
from .jwk import JsonWebKey


__all__ = ['Key', 'KeySet']
__all__ = ['Key', 'KeySet', 'JsonWebKey', 'load_pem_key']
34 changes: 34 additions & 0 deletions authlib/jose/rfc7517/_cryptography_key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from cryptography.x509 import load_pem_x509_certificate
from cryptography.hazmat.primitives.serialization import (
load_pem_private_key, load_pem_public_key, load_ssh_public_key,
)
from cryptography.hazmat.backends import default_backend
from authlib.common.encoding import to_bytes


def load_pem_key(raw, ssh_type=None, key_type=None, password=None):
raw = to_bytes(raw)

if ssh_type and raw.startswith(ssh_type):
return load_ssh_public_key(raw, backend=default_backend())

if key_type == 'public':
return load_pem_public_key(raw, backend=default_backend())

if key_type == 'private' or password is not None:
return load_pem_private_key(raw, password=password, backend=default_backend())

if b'PUBLIC' in raw:
return load_pem_public_key(raw, backend=default_backend())

if b'PRIVATE' in raw:
return load_pem_private_key(raw, password=password, backend=default_backend())

if b'CERTIFICATE' in raw:
cert = load_pem_x509_certificate(raw, default_backend())
return cert.public_key()

try:
return load_pem_private_key(raw, password=password, backend=default_backend())
except ValueError:
return load_pem_public_key(raw, backend=default_backend())
63 changes: 63 additions & 0 deletions authlib/jose/rfc7517/jwk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from authlib.common.encoding import text_types, json_loads
from ._cryptography_key import load_pem_key
from .models import KeySet


class JsonWebKey(object):
JWK_KEY_CLS = {}

@classmethod
def generate_key(cls, kty, crv_or_size, options=None, is_private=False):
"""Generate a Key with the given key type, curve name or bit size.
:param kty: string of ``oct``, ``RSA``, ``EC``, ``OKP``
:param crv_or_size: curve name or bit size
:param options: a dict of other options for Key
:param is_private: create a private key or public key
:return: Key instance
"""
key_cls = cls.JWK_KEY_CLS[kty]
return key_cls.generate_key(crv_or_size, options, is_private)

@classmethod
def import_key(cls, raw, options=None):
"""Import a Key from bytes, string, PEM or dict.
:return: Key instance
"""
kty = None
if options is not None:
kty = options.get('kty')

if kty is None and isinstance(raw, dict):
kty = raw.get('kty')

if kty is None:
raw_key = load_pem_key(raw)
for _kty in cls.JWK_KEY_CLS:
key_cls = cls.JWK_KEY_CLS[_kty]
if isinstance(raw_key, key_cls.RAW_KEY_CLS):
return key_cls.import_key(raw_key, options)

key_cls = cls.JWK_KEY_CLS[kty]
return key_cls.import_key(raw, options)

@classmethod
def import_key_set(cls, raw):
"""Import KeySet from string, dict or a list of keys.
:return: KeySet instance
"""
raw = _transform_raw_key(raw)
if isinstance(raw, dict) and 'keys' in raw:
keys = raw.get('keys')
return KeySet([cls.import_key(k) for k in keys])


def _transform_raw_key(raw):
if isinstance(raw, text_types) and \
raw.startswith('{') and raw.endswith('}'):
return json_loads(raw)
elif isinstance(raw, (tuple, list)):
return {'keys': raw}
return raw
32 changes: 1 addition & 31 deletions authlib/jose/rfc7518/_cryptography_backends/_keys.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from cryptography.x509 import load_pem_x509_certificate
from cryptography.hazmat.primitives.serialization import (
load_pem_private_key, load_pem_public_key, load_ssh_public_key,
Encoding, PrivateFormat, PublicFormat,
BestAvailableEncryption, NoEncryption,
)
Expand All @@ -17,7 +15,7 @@
SECP256R1, SECP384R1, SECP521R1,
)
from cryptography.hazmat.backends import default_backend
from authlib.jose.rfc7517 import Key
from authlib.jose.rfc7517 import Key, load_pem_key
from authlib.common.encoding import to_bytes
from authlib.common.encoding import base64_to_int, int_to_base64

Expand Down Expand Up @@ -236,34 +234,6 @@ def generate_key(cls, crv='P-256', options=None, is_private=False):
return cls.import_key(raw_key, options=options)


def load_pem_key(raw, ssh_type=None, key_type=None, password=None):
raw = to_bytes(raw)

if ssh_type and raw.startswith(ssh_type):
return load_ssh_public_key(raw, backend=default_backend())

if key_type == 'public':
return load_pem_public_key(raw, backend=default_backend())

if key_type == 'private' or password is not None:
return load_pem_private_key(raw, password=password, backend=default_backend())

if b'PUBLIC' in raw:
return load_pem_public_key(raw, backend=default_backend())

if b'PRIVATE' in raw:
return load_pem_private_key(raw, password=password, backend=default_backend())

if b'CERTIFICATE' in raw:
cert = load_pem_x509_certificate(raw, default_backend())
return cert.public_key()

try:
return load_pem_private_key(raw, password=password, backend=default_backend())
except ValueError:
return load_pem_public_key(raw, backend=default_backend())


def import_key(cls, raw, public_key_cls, private_key_cls, ssh_type=None, options=None):
if isinstance(raw, cls):
if options is not None:
Expand Down
32 changes: 23 additions & 9 deletions authlib/jose/rfc7519/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from ..errors import DecodeError, InsecureClaimError
from ..rfc7515 import JsonWebSignature
from ..rfc7516 import JsonWebEncryption
from ..rfc7517 import KeySet


class JsonWebToken(object):
Expand Down Expand Up @@ -60,9 +59,7 @@ def encode(self, header, payload, key, check=True):
if check:
self.check_sensitive_data(payload)

if isinstance(key, KeySet):
key = key.find_by_kid(header.get('kid'))

key = prepare_raw_key(key, header)
text = to_bytes(json_dumps(payload))
if 'enc' in header:
return self._jwe.serialize_compact(header, text, key)
Expand All @@ -86,11 +83,8 @@ def decode(self, s, key, claims_cls=None,
if claims_cls is None:
claims_cls = JWTClaims

if isinstance(key, KeySet):
def load_key(header, payload):
return key.find_by_kid(header.get('kid'))
else:
load_key = key
def load_key(header, payload):
return prepare_raw_key(key, header)

s = to_bytes(s)
dot_count = s.count(b'.')
Expand All @@ -115,3 +109,23 @@ def decode_payload(bytes_payload):
if not isinstance(payload, dict):
raise DecodeError('Invalid payload type')
return payload


def prepare_raw_key(raw, headers=None):
if isinstance(raw, text_types) and \
raw.startswith('{') and raw.endswith('}'):
raw = json_loads(raw)
elif isinstance(raw, (tuple, list)):
raw = {'keys': raw}

if isinstance(raw, dict) and 'keys' in raw:
keys = raw['keys']
if headers is not None:
kid = headers.get('kid')
else:
kid = None
for k in keys:
if k.get('kid') == kid:
return k
raise ValueError('Invalid JSON Web Key Set')
return raw
12 changes: 11 additions & 1 deletion tests/core/test_jose/test_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import datetime
from authlib.jose import errors
from authlib.jose import JsonWebToken, JWTClaims, jwt
from authlib.jose.errors import UnsupportedAlgorithmError, InvalidUseError
from authlib.jose.errors import UnsupportedAlgorithmError
from tests.util import read_file_path


Expand Down Expand Up @@ -177,6 +177,16 @@ def test_use_jwe(self):
claims = jwt.decode(data, private_key)
self.assertEqual(claims['name'], 'hi')

def test_use_jwks(self):
header = {'alg': 'RS256', 'kid': 'abc'}
payload = {'name': 'hi'}
private_key = read_file_path('jwks_private.json')
pub_key = read_file_path('jwks_public.json')
data = jwt.encode(header, payload, private_key)
self.assertEqual(data.count(b'.'), 2)
claims = jwt.decode(data, pub_key)
self.assertEqual(claims['name'], 'hi')

def test_with_ec(self):
payload = {'name': 'hi'}
private_key = read_file_path('ec_private.json')
Expand Down

0 comments on commit 058986e

Please sign in to comment.