From fc5b94eb3575254caba599218246616c75fecdc7 Mon Sep 17 00:00:00 2001 From: "Haoyu(Jerry) Wu" Date: Mon, 1 Aug 2022 09:37:48 -0700 Subject: [PATCH] Add cacheing functionality for JWK set (#781) * Initial implementation of ttl jwk set cache (cherry picked from commit 479a7c124d63113a2190bd48972cc19172215096) * Add unit test for jwk set cache * Fix failed unit test * Disable cache signing key by default * Add a negative unit test for get_jwk_set * Add functionality to force refresh the jwk set cache when no matching signing key can be found from the cache * Add unit test for refresh cache * Add unit test to unset cache when the network call throws error * fix naming typo * Update unit test naming * Update comment * Add check for lifespan * Update comments for get_signing_key * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix ci error * Add type declaration to fix CI error * Add more unit tests to improve coverage * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Try to increase test coverage to 100% Co-authored-by: Jerry Wu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- jwt/api_jwk.py | 13 +++ jwt/jwk_set_cache.py | 32 ++++++++ jwt/jwks_client.py | 82 ++++++++++++++---- tests/test_jwks_client.py | 169 ++++++++++++++++++++++++++++++++++---- 4 files changed, 263 insertions(+), 33 deletions(-) create mode 100644 jwt/jwk_set_cache.py diff --git a/jwt/api_jwk.py b/jwt/api_jwk.py index bbd8a9e1..aa3dd321 100644 --- a/jwt/api_jwk.py +++ b/jwt/api_jwk.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +import time from .algorithms import get_default_algorithms from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError @@ -110,3 +111,15 @@ def __getitem__(self, kid): if key.key_id == kid: return key raise KeyError(f"keyset has no key for kid: {kid}") + + +class PyJWTSetWithTimestamp: + def __init__(self, jwk_set: PyJWKSet): + self.jwk_set = jwk_set + self.timestamp = time.monotonic() + + def get_jwk_set(self): + return self.jwk_set + + def get_timestamp(self): + return self.timestamp diff --git a/jwt/jwk_set_cache.py b/jwt/jwk_set_cache.py new file mode 100644 index 00000000..e8c2a7e0 --- /dev/null +++ b/jwt/jwk_set_cache.py @@ -0,0 +1,32 @@ +import time +from typing import Optional + +from .api_jwk import PyJWKSet, PyJWTSetWithTimestamp + + +class JWKSetCache: + def __init__(self, lifespan: int): + self.jwk_set_with_timestamp: Optional[PyJWTSetWithTimestamp] = None + self.lifespan = lifespan + + def put(self, jwk_set: PyJWKSet): + if jwk_set is not None: + self.jwk_set_with_timestamp = PyJWTSetWithTimestamp(jwk_set) + else: + # clear cache + self.jwk_set_with_timestamp = None + + def get(self) -> Optional[PyJWKSet]: + if self.jwk_set_with_timestamp is None or self.is_expired(): + return None + + return self.jwk_set_with_timestamp.get_jwk_set() + + def is_expired(self) -> bool: + + return ( + self.jwk_set_with_timestamp is not None + and self.lifespan > -1 + and time.monotonic() + > self.jwk_set_with_timestamp.get_timestamp() + self.lifespan + ) diff --git a/jwt/jwks_client.py b/jwt/jwks_client.py index 767b7179..b4e98007 100644 --- a/jwt/jwks_client.py +++ b/jwt/jwks_client.py @@ -1,31 +1,68 @@ import json import urllib.request from functools import lru_cache -from typing import Any, List +from typing import Any, List, Optional +from urllib.error import URLError from .api_jwk import PyJWK, PyJWKSet from .api_jwt import decode_complete as decode_token from .exceptions import PyJWKClientError +from .jwk_set_cache import JWKSetCache class PyJWKClient: - def __init__(self, uri: str, cache_keys: bool = True, max_cached_keys: int = 16): + def __init__( + self, + uri: str, + cache_keys: bool = False, + max_cached_keys: int = 16, + cache_jwk_set: bool = True, + lifespan: int = 300, + ): self.uri = uri + self.jwk_set_cache: Optional[JWKSetCache] = None + + if cache_jwk_set: + # Init jwt set cache with default or given lifespan. + # Default lifespan is 300 seconds (5 minutes). + if lifespan <= 0: + raise PyJWKClientError( + f'Lifespan must be greater than 0, the input is "{lifespan}"' + ) + self.jwk_set_cache = JWKSetCache(lifespan) + else: + self.jwk_set_cache = None + if cache_keys: # Cache signing keys # Ignore mypy (https://github.com/python/mypy/issues/2427) self.get_signing_key = lru_cache(maxsize=max_cached_keys)(self.get_signing_key) # type: ignore def fetch_data(self) -> Any: - with urllib.request.urlopen(self.uri) as response: - return json.load(response) + jwk_set: Any = None + try: + with urllib.request.urlopen(self.uri) as response: + jwk_set = json.load(response) + except URLError as e: + raise PyJWKClientError(f'Fail to fetch data from the url, err: "{e}"') + else: + return jwk_set + finally: + if self.jwk_set_cache is not None: + self.jwk_set_cache.put(jwk_set) + + def get_jwk_set(self, refresh: bool = False) -> PyJWKSet: + data = None + if self.jwk_set_cache is not None and not refresh: + data = self.jwk_set_cache.get() + + if data is None: + data = self.fetch_data() - def get_jwk_set(self) -> PyJWKSet: - data = self.fetch_data() return PyJWKSet.from_dict(data) - def get_signing_keys(self) -> List[PyJWK]: - jwk_set = self.get_jwk_set() + def get_signing_keys(self, refresh: bool = False) -> List[PyJWK]: + jwk_set = self.get_jwk_set(refresh) signing_keys = [ jwk_set_key for jwk_set_key in jwk_set.keys @@ -39,17 +76,17 @@ def get_signing_keys(self) -> List[PyJWK]: def get_signing_key(self, kid: str) -> PyJWK: signing_keys = self.get_signing_keys() - signing_key = None - - for key in signing_keys: - if key.key_id == kid: - signing_key = key - break + signing_key = self.match_kid(signing_keys, kid) if not signing_key: - raise PyJWKClientError( - f'Unable to find a signing key that matches: "{kid}"' - ) + # If no matching signing key from the jwk set, refresh the jwk set and try again. + signing_keys = self.get_signing_keys(refresh=True) + signing_key = self.match_kid(signing_keys, kid) + + if not signing_key: + raise PyJWKClientError( + f'Unable to find a signing key that matches: "{kid}"' + ) return signing_key @@ -57,3 +94,14 @@ def get_signing_key_from_jwt(self, token: str) -> PyJWK: unverified = decode_token(token, options={"verify_signature": False}) header = unverified["header"] return self.get_signing_key(header.get("kid")) + + @staticmethod + def match_kid(signing_keys: List[PyJWK], kid: str) -> Optional[PyJWK]: + signing_key = None + + for key in signing_keys: + if key.key_id == kid: + signing_key = key + break + + return signing_key diff --git a/tests/test_jwks_client.py b/tests/test_jwks_client.py index 3e42da17..c95dfcc0 100644 --- a/tests/test_jwks_client.py +++ b/tests/test_jwks_client.py @@ -1,6 +1,8 @@ import contextlib import json +import time from unittest import mock +from urllib.error import URLError import pytest @@ -11,7 +13,7 @@ from .utils import crypto_required -RESPONSE_DATA = { +RESPONSE_DATA_WITH_MATCHING_KID = { "keys": [ { "alg": "RS256", @@ -28,9 +30,22 @@ ] } +RESPONSE_DATA_NO_MATCHING_KID = { + "keys": [ + { + "alg": "RS256", + "kty": "RSA", + "use": "sig", + "n": "39SJ39VgrQ0qMNK74CaueUBlyYsUyuA7yWlHYZ-jAj6tlFKugEVUTBUVbhGF44uOr99iL_cwmr-srqQDEi-jFHdkS6WFkYyZ03oyyx5dtBMtzrXPieFipSGfQ5EGUGloaKDjL-Ry9tiLnysH2VVWZ5WDDN-DGHxuCOWWjiBNcTmGfnj5_NvRHNUh2iTLuiJpHbGcPzWc5-lc4r-_ehw9EFfp2XsxE9xvtbMZ4SouJCiv9xnrnhe2bdpWuu34hXZCrQwE8DjRY3UR8LjyMxHHPLzX2LWNMHjfN3nAZMteS-Ok11VYDFI-4qCCVGo_WesBCAeqCjPLRyZoV27x1YGsUQ", + "e": "AQAB", + "kid": "MLYHNMMhwCNXw9roHIILFsK4nLs=", + } + ] +} + @contextlib.contextmanager -def mocked_response(data): +def mocked_success_response(data): with mock.patch("urllib.request.urlopen") as urlopen_mock: response = mock.Mock() response.__enter__ = mock.Mock(return_value=response) @@ -40,12 +55,35 @@ def mocked_response(data): yield urlopen_mock +@contextlib.contextmanager +def mocked_failed_response(): + with mock.patch("urllib.request.urlopen") as urlopen_mock: + urlopen_mock.side_effect = URLError("Fail to process the request.") + yield urlopen_mock + + +@contextlib.contextmanager +def mocked_first_call_wrong_kid_second_call_correct_kid( + response_data_one, response_data_two +): + with mock.patch("urllib.request.urlopen") as urlopen_mock: + response = mock.Mock() + response.__enter__ = mock.Mock(return_value=response) + response.__exit__ = mock.Mock() + response.read.side_effect = [ + json.dumps(response_data_one), + json.dumps(response_data_two), + ] + urlopen_mock.return_value = response + yield urlopen_mock + + @crypto_required class TestPyJWKClient: def test_get_jwk_set(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" - with mocked_response(RESPONSE_DATA): + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): jwks_client = PyJWKClient(url) jwk_set = jwks_client.get_jwk_set() @@ -54,7 +92,7 @@ def test_get_jwk_set(self): def test_get_signing_keys(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" - with mocked_response(RESPONSE_DATA): + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): jwks_client = PyJWKClient(url) signing_keys = jwks_client.get_signing_keys() @@ -64,11 +102,11 @@ def test_get_signing_keys(self): def test_get_signing_keys_if_no_use_provided(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" - mocked_key = RESPONSE_DATA["keys"][0].copy() + mocked_key = RESPONSE_DATA_WITH_MATCHING_KID["keys"][0].copy() del mocked_key["use"] response = {"keys": [mocked_key]} - with mocked_response(response): + with mocked_success_response(response): jwks_client = PyJWKClient(url) signing_keys = jwks_client.get_signing_keys() @@ -78,10 +116,10 @@ def test_get_signing_keys_if_no_use_provided(self): def test_get_signing_keys_raises_if_none_found(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" - mocked_key = RESPONSE_DATA["keys"][0].copy() + mocked_key = RESPONSE_DATA_WITH_MATCHING_KID["keys"][0].copy() mocked_key["use"] = "enc" response = {"keys": [mocked_key]} - with mocked_response(response): + with mocked_success_response(response): jwks_client = PyJWKClient(url) with pytest.raises(PyJWKClientError) as exc: @@ -93,7 +131,7 @@ def test_get_signing_key(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" kid = "NEE1QURBOTM4MzI5RkFDNTYxOTU1MDg2ODgwQ0UzMTk1QjYyRkRFQw" - with mocked_response(RESPONSE_DATA): + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): jwks_client = PyJWKClient(url) signing_key = jwks_client.get_signing_key(kid) @@ -106,14 +144,14 @@ def test_get_signing_key_caches_result(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" kid = "NEE1QURBOTM4MzI5RkFDNTYxOTU1MDg2ODgwQ0UzMTk1QjYyRkRFQw" - jwks_client = PyJWKClient(url) + jwks_client = PyJWKClient(url, cache_keys=True) - with mocked_response(RESPONSE_DATA): + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): jwks_client.get_signing_key(kid) # mocked_response does not allow urllib.request.urlopen to be called twice # so a second mock is needed - with mocked_response(RESPONSE_DATA) as repeated_call: + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID) as repeated_call: jwks_client.get_signing_key(kid) assert repeated_call.call_count == 0 @@ -122,14 +160,14 @@ def test_get_signing_key_does_not_cache_opt_out(self): url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" kid = "NEE1QURBOTM4MzI5RkFDNTYxOTU1MDg2ODgwQ0UzMTk1QjYyRkRFQw" - jwks_client = PyJWKClient(url, cache_keys=False) + jwks_client = PyJWKClient(url, cache_jwk_set=False) - with mocked_response(RESPONSE_DATA): + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): jwks_client.get_signing_key(kid) # mocked_response does not allow urllib.request.urlopen to be called twice # so a second mock is needed - with mocked_response(RESPONSE_DATA) as repeated_call: + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID) as repeated_call: jwks_client.get_signing_key(kid) assert repeated_call.call_count == 1 @@ -138,7 +176,7 @@ def test_get_signing_key_from_jwt(self): token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6Ik5FRTFRVVJCT1RNNE16STVSa0ZETlRZeE9UVTFNRGcyT0Rnd1EwVXpNVGsxUWpZeVJrUkZRdyJ9.eyJpc3MiOiJodHRwczovL2Rldi04N2V2eDlydS5hdXRoMC5jb20vIiwic3ViIjoiYVc0Q2NhNzl4UmVMV1V6MGFFMkg2a0QwTzNjWEJWdENAY2xpZW50cyIsImF1ZCI6Imh0dHBzOi8vZXhwZW5zZXMtYXBpIiwiaWF0IjoxNTcyMDA2OTU0LCJleHAiOjE1NzIwMDY5NjQsImF6cCI6ImFXNENjYTc5eFJlTFdVejBhRTJINmtEME8zY1hCVnRDIiwiZ3R5IjoiY2xpZW50LWNyZWRlbnRpYWxzIn0.PUxE7xn52aTCohGiWoSdMBZGiYAHwE5FYie0Y1qUT68IHSTXwXVd6hn02HTah6epvHHVKA2FqcFZ4GGv5VTHEvYpeggiiZMgbxFrmTEY0csL6VNkX1eaJGcuehwQCRBKRLL3zKmA5IKGy5GeUnIbpPHLHDxr-GXvgFzsdsyWlVQvPX2xjeaQ217r2PtxDeqjlf66UYl6oY6AqNS8DH3iryCvIfCcybRZkc_hdy-6ZMoKT6Piijvk_aXdm7-QQqKJFHLuEqrVSOuBqqiNfVrG27QzAPuPOxvfXTVLXL2jek5meH6n-VWgrBdoMFH93QEszEDowDAEhQPHVs0xj7SIzA" url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" - with mocked_response(RESPONSE_DATA): + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): jwks_client = PyJWKClient(url) signing_key = jwks_client.get_signing_key_from_jwt(token) @@ -159,3 +197,102 @@ def test_get_signing_key_from_jwt(self): "azp": "aW4Cca79xReLWUz0aE2H6kD0O3cXBVtC", "gty": "client-credentials", } + + def test_get_jwk_set_caches_result(self): + url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" + + jwks_client = PyJWKClient(url) + assert jwks_client.jwk_set_cache is not None + + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): + jwks_client.get_jwk_set() + + # mocked_response does not allow urllib.request.urlopen to be called twice + # so a second mock is needed + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID) as repeated_call: + jwks_client.get_jwk_set() + + assert repeated_call.call_count == 0 + + def test_get_jwt_set_cache_expired_result(self): + url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" + + jwks_client = PyJWKClient(url, lifespan=1) + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): + jwks_client.get_jwk_set() + + time.sleep(2) + + # mocked_response does not allow urllib.request.urlopen to be called twice + # so a second mock is needed + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID) as repeated_call: + jwks_client.get_jwk_set() + + assert repeated_call.call_count == 1 + + def test_get_jwt_set_cache_disabled(self): + url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" + + jwks_client = PyJWKClient(url, cache_jwk_set=False) + assert jwks_client.jwk_set_cache is None + + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): + jwks_client.get_jwk_set() + + assert jwks_client.jwk_set_cache is None + + time.sleep(2) + + # mocked_response does not allow urllib.request.urlopen to be called twice + # so a second mock is needed + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID) as repeated_call: + jwks_client.get_jwk_set() + + assert repeated_call.call_count == 1 + + def test_get_jwt_set_failed_request_should_clear_cache(self): + url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" + + jwks_client = PyJWKClient(url) + with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID): + jwks_client.get_jwk_set() + + with pytest.raises(PyJWKClientError): + with mocked_failed_response(): + jwks_client.get_jwk_set(refresh=True) + + assert jwks_client.jwk_set_cache is None + + def test_get_jwt_set_refresh_cache(self): + url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" + jwks_client = PyJWKClient(url) + + kid = "NEE1QURBOTM4MzI5RkFDNTYxOTU1MDg2ODgwQ0UzMTk1QjYyRkRFQw" + + # The first call will return response with no matching kid, + # the function should make another call to try to refresh the cache. + with mocked_first_call_wrong_kid_second_call_correct_kid( + RESPONSE_DATA_NO_MATCHING_KID, RESPONSE_DATA_WITH_MATCHING_KID + ) as call_data: + jwks_client.get_signing_key(kid) + + assert call_data.call_count == 2 + + def test_get_jwt_set_no_matching_kid_after_second_attempt(self): + url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" + jwks_client = PyJWKClient(url) + + kid = "NEE1QURBOTM4MzI5RkFDNTYxOTU1MDg2ODgwQ0UzMTk1QjYyRkRFQw" + + with pytest.raises(PyJWKClientError): + with mocked_first_call_wrong_kid_second_call_correct_kid( + RESPONSE_DATA_NO_MATCHING_KID, RESPONSE_DATA_NO_MATCHING_KID + ): + jwks_client.get_signing_key(kid) + + def test_get_jwt_set_invalid_lifespan(self): + url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json" + + with pytest.raises(PyJWKClientError): + jwks_client = PyJWKClient(url, lifespan=-1) + assert jwks_client is None