Skip to content

Commit

Permalink
providers/oauth2: fix amr claim not set due to login event not associ…
Browse files Browse the repository at this point in the history
…ated (#11780)

* providers/oauth2: fix amr claim not set due to login event not associated

Signed-off-by: Jens Langhammer <[email protected]>

* add sid claim

Signed-off-by: Jens Langhammer <[email protected]>

* import engine only once

Signed-off-by: Jens Langhammer <[email protected]>

* remove manual sid extraction from proxy, add test, make session key hashing more obvious

Signed-off-by: Jens Langhammer <[email protected]>

* unrelated string fix

Signed-off-by: Jens Langhammer <[email protected]>

* fix format

Signed-off-by: Jens Langhammer <[email protected]>

* fix tests

Signed-off-by: Jens Langhammer <[email protected]>

---------

Signed-off-by: Jens Langhammer <[email protected]>
  • Loading branch information
BeryJu authored Oct 23, 2024
1 parent da73d4f commit 3bdb287
Show file tree
Hide file tree
Showing 15 changed files with 185 additions and 33 deletions.
19 changes: 16 additions & 3 deletions authentik/events/signals.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
"""authentik events signal listener"""

from importlib import import_module
from typing import Any

from django.conf import settings
from django.contrib.auth.signals import user_logged_in, user_logged_out
from django.db.models.signals import post_save, pre_delete
from django.dispatch import receiver
from django.http import HttpRequest
from rest_framework.request import Request

from authentik.core.models import User
from authentik.core.models import AuthenticatedSession, User
from authentik.core.signals import login_failed, password_changed
from authentik.events.apps import SYSTEM_TASK_STATUS
from authentik.events.models import Event, EventAction, SystemTask
Expand All @@ -23,6 +26,7 @@
from authentik.tenants.utils import get_current_tenant

SESSION_LOGIN_EVENT = "login_event"
_session_engine = import_module(settings.SESSION_ENGINE)


@receiver(user_logged_in)
Expand All @@ -43,11 +47,20 @@ def on_user_logged_in(sender, request: HttpRequest, user: User, **_):
kwargs[PLAN_CONTEXT_OUTPOST] = flow_plan.context[PLAN_CONTEXT_OUTPOST]
event = Event.new(EventAction.LOGIN, **kwargs).from_http(request, user=user)
request.session[SESSION_LOGIN_EVENT] = event
request.session.save()


def get_login_event(request: HttpRequest) -> Event | None:
def get_login_event(request_or_session: HttpRequest | AuthenticatedSession | None) -> Event | None:
"""Wrapper to get login event that can be mocked in tests"""
return request.session.get(SESSION_LOGIN_EVENT, None)
session = None
if not request_or_session:
return None
if isinstance(request_or_session, HttpRequest | Request):
session = request_or_session.session
if isinstance(request_or_session, AuthenticatedSession):
SessionStore = _session_engine.SessionStore
session = SessionStore(request_or_session.session_key)
return session.get(SESSION_LOGIN_EVENT, None)


@receiver(user_logged_out)
Expand Down
17 changes: 14 additions & 3 deletions authentik/providers/oauth2/id_token.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""id_token utils"""

from dataclasses import asdict, dataclass, field
from hashlib import sha256
from typing import TYPE_CHECKING, Any

from django.db import models
Expand All @@ -23,8 +24,13 @@
from authentik.providers.oauth2.models import BaseGrantModel, OAuth2Provider


def hash_session_key(session_key: str) -> str:
"""Hash the session key for inclusion in JWTs as `sid`"""
return sha256(session_key.encode("ascii")).hexdigest()


class SubModes(models.TextChoices):
"""Mode after which 'sub' attribute is generateed, for compatibility reasons"""
"""Mode after which 'sub' attribute is generated, for compatibility reasons"""

HASHED_USER_ID = "hashed_user_id", _("Based on the Hashed User ID")
USER_ID = "user_id", _("Based on user ID")
Expand All @@ -51,7 +57,8 @@ class IDToken:
and potentially other requested Claims. The ID Token is represented as a
JSON Web Token (JWT) [JWT].
https://openid.net/specs/openid-connect-core-1_0.html#IDToken"""
https://openid.net/specs/openid-connect-core-1_0.html#IDToken
https://www.iana.org/assignments/jwt/jwt.xhtml"""

# Issuer, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.1
iss: str | None = None
Expand Down Expand Up @@ -79,6 +86,8 @@ class IDToken:
nonce: str | None = None
# Access Token hash value, http://openid.net/specs/openid-connect-core-1_0.html
at_hash: str | None = None
# Session ID, https://openid.net/specs/openid-connect-frontchannel-1_0.html#ClaimsContents
sid: str | None = None

claims: dict[str, Any] = field(default_factory=dict)

Expand Down Expand Up @@ -116,9 +125,11 @@ def new(
now = timezone.now()
id_token.iat = int(now.timestamp())
id_token.auth_time = int(token.auth_time.timestamp())
if token.session:
id_token.sid = hash_session_key(token.session.session_key)

# We use the timestamp of the user's last successful login (EventAction.LOGIN) for auth_time
auth_event = get_login_event(request)
auth_event = get_login_event(token.session)
if auth_event:
# Also check which method was used for authentication
method = auth_event.context.get(PLAN_CONTEXT_METHOD, "")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import django.db.models.deletion
from django.apps.registry import Apps
from django.db import migrations, models
from django.db.backends.base.schema import BaseDatabaseSchemaEditor

import authentik.lib.utils.time

Expand All @@ -14,7 +15,7 @@
}


def set_managed_flag(apps: Apps, schema_editor):
def set_managed_flag(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
ScopeMapping = apps.get_model("authentik_providers_oauth2", "ScopeMapping")
db_alias = schema_editor.connection.alias
for mapping in ScopeMapping.objects.using(db_alias).filter(name__startswith="Autogenerated "):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Generated by Django 5.0.9 on 2024-10-23 13:38

from hashlib import sha256
import django.db.models.deletion
from django.db import migrations, models
from django.apps.registry import Apps
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from authentik.lib.migrations import progress_bar


def migrate_session(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
AuthenticatedSession = apps.get_model("authentik_core", "authenticatedsession")
AuthorizationCode = apps.get_model("authentik_providers_oauth2", "authorizationcode")
AccessToken = apps.get_model("authentik_providers_oauth2", "accesstoken")
RefreshToken = apps.get_model("authentik_providers_oauth2", "refreshtoken")
db_alias = schema_editor.connection.alias

print(f"\nFetching session keys, this might take a couple of minutes...")
session_ids = {}
for session in progress_bar(AuthenticatedSession.objects.using(db_alias).all()):
session_ids[sha256(session.session_key.encode("ascii")).hexdigest()] = session.session_key
for model in [AuthorizationCode, AccessToken, RefreshToken]:
print(
f"\nAdding session to {model._meta.verbose_name}, this might take a couple of minutes..."
)
for code in progress_bar(model.objects.using(db_alias).all()):
if code.session_id_old not in session_ids:
continue
code.session = (
AuthenticatedSession.objects.using(db_alias)
.filter(session_key=session_ids[code.session_id_old])
.first()
)
code.save()


class Migration(migrations.Migration):

dependencies = [
("authentik_core", "0040_provider_invalidation_flow"),
("authentik_providers_oauth2", "0021_oauth2provider_encryption_key_and_more"),
]

operations = [
migrations.RenameField(
model_name="accesstoken",
old_name="session_id",
new_name="session_id_old",
),
migrations.RenameField(
model_name="authorizationcode",
old_name="session_id",
new_name="session_id_old",
),
migrations.RenameField(
model_name="refreshtoken",
old_name="session_id",
new_name="session_id_old",
),
migrations.AddField(
model_name="accesstoken",
name="session",
field=models.ForeignKey(
default=None,
null=True,
on_delete=django.db.models.deletion.SET_DEFAULT,
to="authentik_core.authenticatedsession",
),
),
migrations.AddField(
model_name="authorizationcode",
name="session",
field=models.ForeignKey(
default=None,
null=True,
on_delete=django.db.models.deletion.SET_DEFAULT,
to="authentik_core.authenticatedsession",
),
),
migrations.AddField(
model_name="devicetoken",
name="session",
field=models.ForeignKey(
default=None,
null=True,
on_delete=django.db.models.deletion.SET_DEFAULT,
to="authentik_core.authenticatedsession",
),
),
migrations.AddField(
model_name="refreshtoken",
name="session",
field=models.ForeignKey(
default=None,
null=True,
on_delete=django.db.models.deletion.SET_DEFAULT,
to="authentik_core.authenticatedsession",
),
),
migrations.RunPython(migrate_session),
migrations.RemoveField(
model_name="accesstoken",
name="session_id_old",
),
migrations.RemoveField(
model_name="authorizationcode",
name="session_id_old",
),
migrations.RemoveField(
model_name="refreshtoken",
name="session_id_old",
),
]
15 changes: 13 additions & 2 deletions authentik/providers/oauth2/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,13 @@
from structlog.stdlib import get_logger

from authentik.brands.models import WebfingerProvider
from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User
from authentik.core.models import (
AuthenticatedSession,
ExpiringModel,
PropertyMapping,
Provider,
User,
)
from authentik.crypto.models import CertificateKeyPair
from authentik.lib.generators import generate_code_fixed_length, generate_id, generate_key
from authentik.lib.models import SerializerModel
Expand Down Expand Up @@ -353,7 +359,9 @@ class BaseGrantModel(models.Model):
revoked = models.BooleanField(default=False)
_scope = models.TextField(default="", verbose_name=_("Scopes"))
auth_time = models.DateTimeField(verbose_name="Authentication time")
session_id = models.CharField(default="", blank=True)
session = models.ForeignKey(
AuthenticatedSession, null=True, on_delete=models.SET_DEFAULT, default=None
)

class Meta:
abstract = True
Expand Down Expand Up @@ -491,6 +499,9 @@ class DeviceToken(ExpiringModel):
device_code = models.TextField(default=generate_key)
user_code = models.TextField(default=generate_code_fixed_length)
_scope = models.TextField(default="", verbose_name=_("Scopes"))
session = models.ForeignKey(
AuthenticatedSession, null=True, on_delete=models.SET_DEFAULT, default=None
)

@property
def scope(self) -> list[str]:
Expand Down
5 changes: 1 addition & 4 deletions authentik/providers/oauth2/signals.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from hashlib import sha256

from django.contrib.auth.signals import user_logged_out
from django.dispatch import receiver
from django.http import HttpRequest
Expand All @@ -13,5 +11,4 @@ def user_logged_out_oauth_access_token(sender, request: HttpRequest, user: User,
"""Revoke access tokens upon user logout"""
if not request.session or not request.session.session_key:
return
hashed_session_key = sha256(request.session.session_key.encode("ascii")).hexdigest()
AccessToken.objects.filter(user=user, session_id=hashed_session_key).delete()
AccessToken.objects.filter(user=user, session__session_key=request.session.session_key).delete()
11 changes: 7 additions & 4 deletions authentik/providers/oauth2/views/authorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from dataclasses import InitVar, dataclass, field
from datetime import timedelta
from hashlib import sha256
from json import dumps
from re import error as RegexError
from re import fullmatch
Expand All @@ -16,7 +15,7 @@
from django.utils.translation import gettext as _
from structlog.stdlib import get_logger

from authentik.core.models import Application
from authentik.core.models import Application, AuthenticatedSession
from authentik.events.models import Event, EventAction
from authentik.events.signals import get_login_event
from authentik.flows.challenge import (
Expand Down Expand Up @@ -318,7 +317,9 @@ def create_code(self, request: HttpRequest) -> AuthorizationCode:
expires=now + timedelta_from_string(self.provider.access_code_validity),
scope=self.scope,
nonce=self.nonce,
session_id=sha256(request.session.session_key.encode("ascii")).hexdigest(),
session=AuthenticatedSession.objects.filter(
session_key=request.session.session_key
).first(),
)

if self.code_challenge and self.code_challenge_method:
Expand Down Expand Up @@ -610,7 +611,9 @@ def create_implicit_response(self, code: AuthorizationCode | None) -> dict:
expires=access_token_expiry,
provider=self.provider,
auth_time=auth_event.created if auth_event else now,
session_id=sha256(self.request.session.session_key.encode("ascii")).hexdigest(),
session=AuthenticatedSession.objects.filter(
session_key=self.request.session.session_key
).first(),
)

id_token = IDToken.new(self.provider, token, self.request)
Expand Down
11 changes: 6 additions & 5 deletions authentik/providers/oauth2/views/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ def create_code_response(self) -> dict[str, Any]:
# Keep same scopes as previous token
scope=self.params.authorization_code.scope,
auth_time=self.params.authorization_code.auth_time,
session_id=self.params.authorization_code.session_id,
session=self.params.authorization_code.session,
)
access_id_token = IDToken.new(
self.provider,
Expand Down Expand Up @@ -578,7 +578,7 @@ def create_code_response(self) -> dict[str, Any]:
expires=refresh_token_expiry,
provider=self.provider,
auth_time=self.params.authorization_code.auth_time,
session_id=self.params.authorization_code.session_id,
session=self.params.authorization_code.session,
)
id_token = IDToken.new(
self.provider,
Expand Down Expand Up @@ -611,7 +611,7 @@ def create_refresh_response(self) -> dict[str, Any]:
# Keep same scopes as previous token
scope=self.params.refresh_token.scope,
auth_time=self.params.refresh_token.auth_time,
session_id=self.params.refresh_token.session_id,
session=self.params.refresh_token.session,
)
access_token.id_token = IDToken.new(
self.provider,
Expand All @@ -627,7 +627,7 @@ def create_refresh_response(self) -> dict[str, Any]:
expires=refresh_token_expiry,
provider=self.provider,
auth_time=self.params.refresh_token.auth_time,
session_id=self.params.refresh_token.session_id,
session=self.params.refresh_token.session,
)
id_token = IDToken.new(
self.provider,
Expand Down Expand Up @@ -685,13 +685,14 @@ def create_device_code_response(self) -> dict[str, Any]:
raise DeviceCodeError("authorization_pending")
now = timezone.now()
access_token_expiry = now + timedelta_from_string(self.provider.access_token_validity)
auth_event = get_login_event(self.request)
auth_event = get_login_event(self.params.device_code.session)
access_token = AccessToken(
provider=self.provider,
user=self.params.device_code.user,
expires=access_token_expiry,
scope=self.params.device_code.scope,
auth_time=auth_event.created if auth_event else now,
session=self.params.device_code.session,
)
access_token.id_token = IDToken.new(
self.provider,
Expand Down
5 changes: 2 additions & 3 deletions authentik/providers/proxy/tasks.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
"""proxy provider tasks"""

from hashlib import sha256

from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from django.db import DatabaseError, InternalError, ProgrammingError

from authentik.outposts.consumer import OUTPOST_GROUP
from authentik.outposts.models import Outpost, OutpostType
from authentik.providers.oauth2.id_token import hash_session_key
from authentik.providers.proxy.models import ProxyProvider
from authentik.root.celery import CELERY_APP

Expand All @@ -26,7 +25,7 @@ def proxy_set_defaults():
def proxy_on_logout(session_id: str):
"""Update outpost instances connected to a single outpost"""
layer = get_channel_layer()
hashed_session_id = sha256(session_id.encode("ascii")).hexdigest()
hashed_session_id = hash_session_key(session_id)
for outpost in Outpost.objects.filter(type=OutpostType.PROXY):
group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
async_to_sync(layer.group_send)(
Expand Down
Loading

0 comments on commit 3bdb287

Please sign in to comment.