From a857078cc1e0a7ba9281a39867ff88fcaae24a93 Mon Sep 17 00:00:00 2001 From: Tasko Olevski Date: Thu, 25 Jan 2024 15:31:45 +0100 Subject: [PATCH] fix: prefer RS256 for JWT validation --- renku/ui/service/entrypoint.py | 3 ++ renku/ui/service/serializers/headers.py | 57 ++++++++----------------- renku/ui/service/utils/__init__.py | 44 ++++++++++++++++++- 3 files changed, 63 insertions(+), 41 deletions(-) diff --git a/renku/ui/service/entrypoint.py b/renku/ui/service/entrypoint.py index 3f511eab6e..7449516497 100644 --- a/renku/ui/service/entrypoint.py +++ b/renku/ui/service/entrypoint.py @@ -41,6 +41,7 @@ from renku.ui.service.logger import service_log from renku.ui.service.serializers.headers import JWT_TOKEN_SECRET from renku.ui.service.utils.json_encoder import SvcJSONProvider +from renku.ui.service.utils import jwk_client from renku.ui.service.views import error_response from renku.ui.service.views.apispec import apispec_blueprint from renku.ui.service.views.cache import cache_blueprint @@ -76,6 +77,8 @@ def create_app(custom_exceptions=True): app.config["cache"] = cache + app.config["KEYCLOAK_JWK_CLIENT"] = jwk_client() + if not is_test_session_running(): GunicornPrometheusMetrics(app) diff --git a/renku/ui/service/serializers/headers.py b/renku/ui/service/serializers/headers.py index 2eca365afd..51ef6c3851 100644 --- a/renku/ui/service/serializers/headers.py +++ b/renku/ui/service/serializers/headers.py @@ -17,9 +17,11 @@ import base64 import binascii import os +from typing import cast import jwt -from marshmallow import Schema, ValidationError, fields, post_load, pre_load +from flask import app +from marshmallow import Schema, ValidationError, fields, post_load from werkzeug.utils import secure_filename JWT_TOKEN_SECRET = os.getenv("RENKU_JWT_TOKEN_SECRET", "bW9menZ3cnh6cWpkcHVuZ3F5aWJycmJn") @@ -79,7 +81,7 @@ class RenkuHeaders: @staticmethod def decode_token(token): - """Extract authorization token.""" + """Extract the Gitlab access token form a bearer authorization header value.""" components = token.split(" ") rfc_compliant = token.lower().startswith("bearer") @@ -92,45 +94,22 @@ def decode_token(token): @staticmethod def decode_user(data): - """Extract renku user from a JWT.""" - decoded = jwt.decode(data, JWT_TOKEN_SECRET, algorithms=["HS256"], audience="renku") + """Extract renku user from the Keycloak ID token which is a JWT.""" + try: + jwk = cast(jwt.PyJWKClient, app.config["KEYCLOAK_JWK_CLIENT"]) + key = jwk.get_signing_key_from_jwt(data) + decoded = jwt.decode(data, key=key, algorithms=["RS256"], audience="renku") + except jwt.PyJWTError: + # NOTE: older tokens used to be signed with HS256 so use this as a backup if the validation with RS256 + # above fails. We used to need HS256 because a step that is now removed was generating an ID token and + # signing it from data passed in individual header fields. + decoded = jwt.decode(data, JWT_TOKEN_SECRET, algorithms=["HS256"], audience="renku") return UserIdentityToken().load(decoded) - @staticmethod - def reset_old_headers(data): - """Process old version of old headers.""" - # TODO: This should be removed once support for them is phased out. - if "renku-user-id" in data: - data.pop("renku-user-id") - - if "renku-user-fullname" in data and "renku-user-email" in data: - renku_user = { - "aud": ["renku"], - "name": decode_b64(data.pop("renku-user-fullname")), - "email": decode_b64(data.pop("renku-user-email")), - } - renku_user["sub"] = renku_user["email"] - data["renku-user"] = jwt.encode(renku_user, JWT_TOKEN_SECRET, algorithm="HS256") - - return data - class IdentityHeaders(Schema): """User identity schema.""" - @pre_load - def set_fields(self, data, **kwargs): - """Set fields for serialization.""" - # NOTE: We don't process headers which are not meant for determining identity. - # TODO: Remove old headers support once support for them is phased out. - old_keys = ["renku-user-id", "renku-user-fullname", "renku-user-email"] - expected_keys = old_keys + [field.data_key for field in self.fields.values()] - - data = {key.lower(): value for key, value in data.items() if key.lower() in expected_keys} - data = RenkuHeaders.reset_old_headers(data) - - return data - @post_load def set_user(self, data, **kwargs): """Extract user object from a JWT.""" @@ -151,12 +130,12 @@ def set_user(self, data, **kwargs): class RequiredIdentityHeaders(IdentityHeaders): """Identity schema for required headers.""" - user_token = fields.String(required=True, data_key="renku-user") - auth_token = fields.String(required=True, data_key="authorization") + user_token = fields.String(required=True, data_key="renku-user") # Keycloak ID token + auth_token = fields.String(required=True, data_key="authorization") # Gitlab access token class OptionalIdentityHeaders(IdentityHeaders): """Identity schema for optional headers.""" - user_token = fields.String(data_key="renku-user") - auth_token = fields.String(data_key="authorization") + user_token = fields.String(data_key="renku-user") # Keycloak ID token + auth_token = fields.String(data_key="authorization") # Gitlab access token diff --git a/renku/ui/service/utils/__init__.py b/renku/ui/service/utils/__init__.py index 390f55c490..b315701c54 100644 --- a/renku/ui/service/utils/__init__.py +++ b/renku/ui/service/utils/__init__.py @@ -14,9 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. """Renku service utility functions.""" -from typing import Optional, overload +from time import sleep +from typing import Any, Dict, Optional, overload -from renku.ui.service.config import CACHE_PROJECTS_PATH, CACHE_UPLOADS_PATH +import requests +import urllib +from jwt import PyJWKClient + +from renku.ui.service.config import CACHE_PROJECTS_PATH, CACHE_UPLOADS_PATH, OIDC_URL +from renku.ui.service.errors import ProgramInternalError +from renku.ui.service.logger import service_log +from renku.core.util.requests import get def make_project_path(user, project): @@ -86,3 +94,35 @@ def normalize_git_url(git_url: Optional[str]) -> Optional[str]: git_url = git_url[: -len(".git")] return git_url + + +def oidc_discovery() -> Dict[str, Any]: + """Query the OIDC discovery endpoint from Keycloak with retries, parse the result with JSON and it.""" + retries = 0 + max_retries = 30 + sleep_seconds = 2 + while True: + retries += 1 + try: + res: requests.Response = get(OIDC_URL) + except (requests.exceptions.HTTPError, urllib.error.HTTPError) as e: + if not retries < max_retries: + service_log.error("Failed to get OIDC discovery data after all retries - the server cannot start.") + raise e + service_log.info( + f"Failed to get OIDC discovery data from {OIDC_URL}, sleeping for {sleep_seconds} seconds and retrying" + ) + sleep(sleep_seconds) + else: + service_log.info(f"Successfully fetched OIDC discovery data from {OIDC_URL}") + return res.json() + + +def jwk_client() -> PyJWKClient: + """Return a JWK client for Keycloak that can be used to provide JWT keys for JWT signature validation""" + oidc_data = oidc_discovery() + jwks_uri = oidc_data.get("jwks_uri") + if not jwks_uri: + raise ProgramInternalError(error_message="Could not find JWK URI in the OIDC discovery data") + jwk = PyJWKClient(jwks_uri) + return jwk