From 058986e7e8dbdcf3c157bca0c6f75ec2dd308c90 Mon Sep 17 00:00:00 2001 From: coopfeathy Date: Wed, 14 Oct 2020 16:36:56 +0900 Subject: [PATCH] Support raw json web key set Fixes https://github.com/lepture/authlib/issues/280 --- authlib/jose/__init__.py | 3 +- authlib/jose/jwk.py | 74 +------------------ authlib/jose/rfc7517/__init__.py | 4 +- authlib/jose/rfc7517/_cryptography_key.py | 34 +++++++++ authlib/jose/rfc7517/jwk.py | 63 ++++++++++++++++ .../rfc7518/_cryptography_backends/_keys.py | 32 +------- authlib/jose/rfc7519/jwt.py | 32 +++++--- tests/core/test_jose/test_jwt.py | 12 ++- 8 files changed, 137 insertions(+), 117 deletions(-) create mode 100644 authlib/jose/rfc7517/_cryptography_key.py create mode 100644 authlib/jose/rfc7517/jwk.py diff --git a/authlib/jose/__init__.py b/authlib/jose/__init__.py index 86db6a7..c023ae2 100644 --- a/authlib/jose/__init__.py +++ b/authlib/jose/__init__.py @@ -11,7 +11,7 @@ from .rfc7516 import ( JsonWebEncryption, JWEAlgorithm, JWEEncAlgorithm, JWEZipAlgorithm, ) -from .rfc7517 import Key, KeySet +from .rfc7517 import Key, KeySet, JsonWebKey from .rfc7518 import ( register_jws_rfc7518, register_jwe_rfc7518, @@ -25,7 +25,6 @@ from .drafts import register_jwe_draft from .errors import JoseError -from .jwk import JsonWebKey # register algorithms register_jws_rfc7518() diff --git a/authlib/jose/jwk.py b/authlib/jose/jwk.py index c78ef70..02dbbab 100644 --- a/authlib/jose/jwk.py +++ b/authlib/jose/jwk.py @@ -1,76 +1,4 @@ -from authlib.common.encoding import text_types, json_loads -from .rfc7517 import KeySet -from .rfc7518 import ( - OctKey, - RSAKey, - ECKey, - load_pem_key, -) -from .rfc8037 import OKPKey - - -class JsonWebKey(object): - JWK_KEY_CLS = { - OctKey.kty: OctKey, - RSAKey.kty: RSAKey, - ECKey.kty: ECKey, - OKPKey.kty: OKPKey, - } - - @classmethod - def generate_key(cls, kty, crv_or_size, options=None, is_private=False): - """Generate a Key with the given key type, curve name or bit size. - - :param kty: string of ``oct``, ``RSA``, ``EC``, ``OKP`` - :param crv_or_size: curve name or bit size - :param options: a dict of other options for Key - :param is_private: create a private key or public key - :return: Key instance - """ - key_cls = cls.JWK_KEY_CLS[kty] - return key_cls.generate_key(crv_or_size, options, is_private) - - @classmethod - def import_key(cls, raw, options=None): - """Import a Key from bytes, string, PEM or dict. - - :return: Key instance - """ - kty = None - if options is not None: - kty = options.get('kty') - - if kty is None and isinstance(raw, dict): - kty = raw.get('kty') - - if kty is None: - raw_key = load_pem_key(raw) - for _kty in cls.JWK_KEY_CLS: - key_cls = cls.JWK_KEY_CLS[_kty] - if isinstance(raw_key, key_cls.RAW_KEY_CLS): - return key_cls.import_key(raw_key, options) - - key_cls = cls.JWK_KEY_CLS[kty] - return key_cls.import_key(raw, options) - - @classmethod - def import_key_set(cls, raw): - """Import KeySet from string, dict or a list of keys. - - :return: KeySet instance - """ - if isinstance(raw, text_types) and \ - raw.startswith('{') and raw.endswith('}'): - raw = json_loads(raw) - keys = raw.get('keys') - elif isinstance(raw, dict) and 'keys' in raw: - keys = raw.get('keys') - elif isinstance(raw, (tuple, list)): - keys = raw - else: - return None - - return KeySet([cls.import_key(k) for k in keys]) +from .rfc7517 import JsonWebKey def loads(obj, kid=None): diff --git a/authlib/jose/rfc7517/__init__.py b/authlib/jose/rfc7517/__init__.py index 079a7cc..e2f1595 100644 --- a/authlib/jose/rfc7517/__init__.py +++ b/authlib/jose/rfc7517/__init__.py @@ -8,6 +8,8 @@ https://tools.ietf.org/html/rfc7517 """ from .models import Key, KeySet +from ._cryptography_key import load_pem_key +from .jwk import JsonWebKey -__all__ = ['Key', 'KeySet'] +__all__ = ['Key', 'KeySet', 'JsonWebKey', 'load_pem_key'] diff --git a/authlib/jose/rfc7517/_cryptography_key.py b/authlib/jose/rfc7517/_cryptography_key.py new file mode 100644 index 0000000..f7194a3 --- /dev/null +++ b/authlib/jose/rfc7517/_cryptography_key.py @@ -0,0 +1,34 @@ +from cryptography.x509 import load_pem_x509_certificate +from cryptography.hazmat.primitives.serialization import ( + load_pem_private_key, load_pem_public_key, load_ssh_public_key, +) +from cryptography.hazmat.backends import default_backend +from authlib.common.encoding import to_bytes + + +def load_pem_key(raw, ssh_type=None, key_type=None, password=None): + raw = to_bytes(raw) + + if ssh_type and raw.startswith(ssh_type): + return load_ssh_public_key(raw, backend=default_backend()) + + if key_type == 'public': + return load_pem_public_key(raw, backend=default_backend()) + + if key_type == 'private' or password is not None: + return load_pem_private_key(raw, password=password, backend=default_backend()) + + if b'PUBLIC' in raw: + return load_pem_public_key(raw, backend=default_backend()) + + if b'PRIVATE' in raw: + return load_pem_private_key(raw, password=password, backend=default_backend()) + + if b'CERTIFICATE' in raw: + cert = load_pem_x509_certificate(raw, default_backend()) + return cert.public_key() + + try: + return load_pem_private_key(raw, password=password, backend=default_backend()) + except ValueError: + return load_pem_public_key(raw, backend=default_backend()) diff --git a/authlib/jose/rfc7517/jwk.py b/authlib/jose/rfc7517/jwk.py new file mode 100644 index 0000000..b52d319 --- /dev/null +++ b/authlib/jose/rfc7517/jwk.py @@ -0,0 +1,63 @@ +from authlib.common.encoding import text_types, json_loads +from ._cryptography_key import load_pem_key +from .models import KeySet + + +class JsonWebKey(object): + JWK_KEY_CLS = {} + + @classmethod + def generate_key(cls, kty, crv_or_size, options=None, is_private=False): + """Generate a Key with the given key type, curve name or bit size. + + :param kty: string of ``oct``, ``RSA``, ``EC``, ``OKP`` + :param crv_or_size: curve name or bit size + :param options: a dict of other options for Key + :param is_private: create a private key or public key + :return: Key instance + """ + key_cls = cls.JWK_KEY_CLS[kty] + return key_cls.generate_key(crv_or_size, options, is_private) + + @classmethod + def import_key(cls, raw, options=None): + """Import a Key from bytes, string, PEM or dict. + + :return: Key instance + """ + kty = None + if options is not None: + kty = options.get('kty') + + if kty is None and isinstance(raw, dict): + kty = raw.get('kty') + + if kty is None: + raw_key = load_pem_key(raw) + for _kty in cls.JWK_KEY_CLS: + key_cls = cls.JWK_KEY_CLS[_kty] + if isinstance(raw_key, key_cls.RAW_KEY_CLS): + return key_cls.import_key(raw_key, options) + + key_cls = cls.JWK_KEY_CLS[kty] + return key_cls.import_key(raw, options) + + @classmethod + def import_key_set(cls, raw): + """Import KeySet from string, dict or a list of keys. + + :return: KeySet instance + """ + raw = _transform_raw_key(raw) + if isinstance(raw, dict) and 'keys' in raw: + keys = raw.get('keys') + return KeySet([cls.import_key(k) for k in keys]) + + +def _transform_raw_key(raw): + if isinstance(raw, text_types) and \ + raw.startswith('{') and raw.endswith('}'): + return json_loads(raw) + elif isinstance(raw, (tuple, list)): + return {'keys': raw} + return raw diff --git a/authlib/jose/rfc7518/_cryptography_backends/_keys.py b/authlib/jose/rfc7518/_cryptography_backends/_keys.py index 9ca4389..786a49d 100644 --- a/authlib/jose/rfc7518/_cryptography_backends/_keys.py +++ b/authlib/jose/rfc7518/_cryptography_backends/_keys.py @@ -1,6 +1,4 @@ -from cryptography.x509 import load_pem_x509_certificate from cryptography.hazmat.primitives.serialization import ( - load_pem_private_key, load_pem_public_key, load_ssh_public_key, Encoding, PrivateFormat, PublicFormat, BestAvailableEncryption, NoEncryption, ) @@ -17,7 +15,7 @@ SECP256R1, SECP384R1, SECP521R1, ) from cryptography.hazmat.backends import default_backend -from authlib.jose.rfc7517 import Key +from authlib.jose.rfc7517 import Key, load_pem_key from authlib.common.encoding import to_bytes from authlib.common.encoding import base64_to_int, int_to_base64 @@ -236,34 +234,6 @@ def generate_key(cls, crv='P-256', options=None, is_private=False): return cls.import_key(raw_key, options=options) -def load_pem_key(raw, ssh_type=None, key_type=None, password=None): - raw = to_bytes(raw) - - if ssh_type and raw.startswith(ssh_type): - return load_ssh_public_key(raw, backend=default_backend()) - - if key_type == 'public': - return load_pem_public_key(raw, backend=default_backend()) - - if key_type == 'private' or password is not None: - return load_pem_private_key(raw, password=password, backend=default_backend()) - - if b'PUBLIC' in raw: - return load_pem_public_key(raw, backend=default_backend()) - - if b'PRIVATE' in raw: - return load_pem_private_key(raw, password=password, backend=default_backend()) - - if b'CERTIFICATE' in raw: - cert = load_pem_x509_certificate(raw, default_backend()) - return cert.public_key() - - try: - return load_pem_private_key(raw, password=password, backend=default_backend()) - except ValueError: - return load_pem_public_key(raw, backend=default_backend()) - - def import_key(cls, raw, public_key_cls, private_key_cls, ssh_type=None, options=None): if isinstance(raw, cls): if options is not None: diff --git a/authlib/jose/rfc7519/jwt.py b/authlib/jose/rfc7519/jwt.py index 7ffdebc..c7339ef 100644 --- a/authlib/jose/rfc7519/jwt.py +++ b/authlib/jose/rfc7519/jwt.py @@ -9,7 +9,6 @@ from ..errors import DecodeError, InsecureClaimError from ..rfc7515 import JsonWebSignature from ..rfc7516 import JsonWebEncryption -from ..rfc7517 import KeySet class JsonWebToken(object): @@ -60,9 +59,7 @@ def encode(self, header, payload, key, check=True): if check: self.check_sensitive_data(payload) - if isinstance(key, KeySet): - key = key.find_by_kid(header.get('kid')) - + key = prepare_raw_key(key, header) text = to_bytes(json_dumps(payload)) if 'enc' in header: return self._jwe.serialize_compact(header, text, key) @@ -86,11 +83,8 @@ def decode(self, s, key, claims_cls=None, if claims_cls is None: claims_cls = JWTClaims - if isinstance(key, KeySet): - def load_key(header, payload): - return key.find_by_kid(header.get('kid')) - else: - load_key = key + def load_key(header, payload): + return prepare_raw_key(key, header) s = to_bytes(s) dot_count = s.count(b'.') @@ -115,3 +109,23 @@ def decode_payload(bytes_payload): if not isinstance(payload, dict): raise DecodeError('Invalid payload type') return payload + + +def prepare_raw_key(raw, headers=None): + if isinstance(raw, text_types) and \ + raw.startswith('{') and raw.endswith('}'): + raw = json_loads(raw) + elif isinstance(raw, (tuple, list)): + raw = {'keys': raw} + + if isinstance(raw, dict) and 'keys' in raw: + keys = raw['keys'] + if headers is not None: + kid = headers.get('kid') + else: + kid = None + for k in keys: + if k.get('kid') == kid: + return k + raise ValueError('Invalid JSON Web Key Set') + return raw diff --git a/tests/core/test_jose/test_jwt.py b/tests/core/test_jose/test_jwt.py index 106149e..df73245 100644 --- a/tests/core/test_jose/test_jwt.py +++ b/tests/core/test_jose/test_jwt.py @@ -2,7 +2,7 @@ import datetime from authlib.jose import errors from authlib.jose import JsonWebToken, JWTClaims, jwt -from authlib.jose.errors import UnsupportedAlgorithmError, InvalidUseError +from authlib.jose.errors import UnsupportedAlgorithmError from tests.util import read_file_path @@ -177,6 +177,16 @@ def test_use_jwe(self): claims = jwt.decode(data, private_key) self.assertEqual(claims['name'], 'hi') + def test_use_jwks(self): + header = {'alg': 'RS256', 'kid': 'abc'} + payload = {'name': 'hi'} + private_key = read_file_path('jwks_private.json') + pub_key = read_file_path('jwks_public.json') + data = jwt.encode(header, payload, private_key) + self.assertEqual(data.count(b'.'), 2) + claims = jwt.decode(data, pub_key) + self.assertEqual(claims['name'], 'hi') + def test_with_ec(self): payload = {'name': 'hi'} private_key = read_file_path('ec_private.json')