Skip to content

Commit

Permalink
fix: prefer RS256 for JWT validation
Browse files Browse the repository at this point in the history
  • Loading branch information
olevski committed Jan 25, 2024
1 parent ae3618c commit a857078
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 41 deletions.
3 changes: 3 additions & 0 deletions renku/ui/service/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from renku.ui.service.logger import service_log
from renku.ui.service.serializers.headers import JWT_TOKEN_SECRET
from renku.ui.service.utils.json_encoder import SvcJSONProvider
from renku.ui.service.utils import jwk_client
from renku.ui.service.views import error_response
from renku.ui.service.views.apispec import apispec_blueprint
from renku.ui.service.views.cache import cache_blueprint
Expand Down Expand Up @@ -76,6 +77,8 @@ def create_app(custom_exceptions=True):

app.config["cache"] = cache

app.config["KEYCLOAK_JWK_CLIENT"] = jwk_client()

if not is_test_session_running():
GunicornPrometheusMetrics(app)

Expand Down
57 changes: 18 additions & 39 deletions renku/ui/service/serializers/headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
import base64
import binascii
import os
from typing import cast

import jwt
from marshmallow import Schema, ValidationError, fields, post_load, pre_load
from flask import app
from marshmallow import Schema, ValidationError, fields, post_load
from werkzeug.utils import secure_filename

JWT_TOKEN_SECRET = os.getenv("RENKU_JWT_TOKEN_SECRET", "bW9menZ3cnh6cWpkcHVuZ3F5aWJycmJn")
Expand Down Expand Up @@ -79,7 +81,7 @@ class RenkuHeaders:

@staticmethod
def decode_token(token):
"""Extract authorization token."""
"""Extract the Gitlab access token form a bearer authorization header value."""
components = token.split(" ")

rfc_compliant = token.lower().startswith("bearer")
Expand All @@ -92,45 +94,22 @@ def decode_token(token):

@staticmethod
def decode_user(data):
"""Extract renku user from a JWT."""
decoded = jwt.decode(data, JWT_TOKEN_SECRET, algorithms=["HS256"], audience="renku")
"""Extract renku user from the Keycloak ID token which is a JWT."""
try:
jwk = cast(jwt.PyJWKClient, app.config["KEYCLOAK_JWK_CLIENT"])
key = jwk.get_signing_key_from_jwt(data)
decoded = jwt.decode(data, key=key, algorithms=["RS256"], audience="renku")
except jwt.PyJWTError:
# NOTE: older tokens used to be signed with HS256 so use this as a backup if the validation with RS256
# above fails. We used to need HS256 because a step that is now removed was generating an ID token and
# signing it from data passed in individual header fields.
decoded = jwt.decode(data, JWT_TOKEN_SECRET, algorithms=["HS256"], audience="renku")
return UserIdentityToken().load(decoded)

@staticmethod
def reset_old_headers(data):
"""Process old version of old headers."""
# TODO: This should be removed once support for them is phased out.
if "renku-user-id" in data:
data.pop("renku-user-id")

if "renku-user-fullname" in data and "renku-user-email" in data:
renku_user = {
"aud": ["renku"],
"name": decode_b64(data.pop("renku-user-fullname")),
"email": decode_b64(data.pop("renku-user-email")),
}
renku_user["sub"] = renku_user["email"]
data["renku-user"] = jwt.encode(renku_user, JWT_TOKEN_SECRET, algorithm="HS256")

return data


class IdentityHeaders(Schema):
"""User identity schema."""

@pre_load
def set_fields(self, data, **kwargs):
"""Set fields for serialization."""
# NOTE: We don't process headers which are not meant for determining identity.
# TODO: Remove old headers support once support for them is phased out.
old_keys = ["renku-user-id", "renku-user-fullname", "renku-user-email"]
expected_keys = old_keys + [field.data_key for field in self.fields.values()]

data = {key.lower(): value for key, value in data.items() if key.lower() in expected_keys}
data = RenkuHeaders.reset_old_headers(data)

return data

@post_load
def set_user(self, data, **kwargs):
"""Extract user object from a JWT."""
Expand All @@ -151,12 +130,12 @@ def set_user(self, data, **kwargs):
class RequiredIdentityHeaders(IdentityHeaders):
"""Identity schema for required headers."""

user_token = fields.String(required=True, data_key="renku-user")
auth_token = fields.String(required=True, data_key="authorization")
user_token = fields.String(required=True, data_key="renku-user") # Keycloak ID token
auth_token = fields.String(required=True, data_key="authorization") # Gitlab access token


class OptionalIdentityHeaders(IdentityHeaders):
"""Identity schema for optional headers."""

user_token = fields.String(data_key="renku-user")
auth_token = fields.String(data_key="authorization")
user_token = fields.String(data_key="renku-user") # Keycloak ID token
auth_token = fields.String(data_key="authorization") # Gitlab access token
44 changes: 42 additions & 2 deletions renku/ui/service/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Renku service utility functions."""
from typing import Optional, overload
from time import sleep
from typing import Any, Dict, Optional, overload

from renku.ui.service.config import CACHE_PROJECTS_PATH, CACHE_UPLOADS_PATH
import requests
import urllib
from jwt import PyJWKClient

from renku.ui.service.config import CACHE_PROJECTS_PATH, CACHE_UPLOADS_PATH, OIDC_URL
from renku.ui.service.errors import ProgramInternalError
from renku.ui.service.logger import service_log
from renku.core.util.requests import get


def make_project_path(user, project):
Expand Down Expand Up @@ -86,3 +94,35 @@ def normalize_git_url(git_url: Optional[str]) -> Optional[str]:
git_url = git_url[: -len(".git")]

return git_url


def oidc_discovery() -> Dict[str, Any]:
"""Query the OIDC discovery endpoint from Keycloak with retries, parse the result with JSON and it."""
retries = 0
max_retries = 30
sleep_seconds = 2
while True:
retries += 1
try:
res: requests.Response = get(OIDC_URL)
except (requests.exceptions.HTTPError, urllib.error.HTTPError) as e:
if not retries < max_retries:
service_log.error("Failed to get OIDC discovery data after all retries - the server cannot start.")
raise e
service_log.info(
f"Failed to get OIDC discovery data from {OIDC_URL}, sleeping for {sleep_seconds} seconds and retrying"
)
sleep(sleep_seconds)
else:
service_log.info(f"Successfully fetched OIDC discovery data from {OIDC_URL}")
return res.json()


def jwk_client() -> PyJWKClient:
"""Return a JWK client for Keycloak that can be used to provide JWT keys for JWT signature validation"""
oidc_data = oidc_discovery()
jwks_uri = oidc_data.get("jwks_uri")
if not jwks_uri:
raise ProgramInternalError(error_message="Could not find JWK URI in the OIDC discovery data")
jwk = PyJWKClient(jwks_uri)
return jwk

0 comments on commit a857078

Please sign in to comment.