Skip to content

Commit

Permalink
[DOP-23122] Use async methods of Keycloak client
Browse files Browse the repository at this point in the history
  • Loading branch information
dolfinus committed Dec 23, 2024
1 parent 162b0a0 commit e3bca7e
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 51 deletions.
23 changes: 10 additions & 13 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ onetl = {extras = ["spark", "s3", "hdfs"], version = "^0.12.0"}
faker = ">=28.4.1,<34.0.0"
coverage = "^7.6.1"
gevent = "^24.2.1"
responses = "*"
respx = "*"

[tool.poetry.group.dev.dependencies]
mypy = "^1.11.2"
Expand Down
2 changes: 1 addition & 1 deletion syncmaster/server/providers/auth/base_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(
...

@abstractmethod
async def get_current_user(self, access_token: Any, *args, **kwargs) -> User:
async def get_current_user(self, access_token: str | None, **kwargs) -> User:
"""
This method should return currently logged in user.
Expand Down
2 changes: 1 addition & 1 deletion syncmaster/server/providers/auth/dummy_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def setup(cls, app: FastAPI) -> FastAPI:
app.dependency_overrides[DummyAuthProviderSettings] = lambda: settings
return app

async def get_current_user(self, access_token: str, *args, **kwargs) -> User:
async def get_current_user(self, access_token: str | None, **kwargs) -> User:
if not access_token:
raise AuthorizationError("Missing auth credentials")

Expand Down
30 changes: 14 additions & 16 deletions syncmaster/server/providers/auth/keycloak_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from fastapi import Depends, FastAPI, Request
from keycloak import KeycloakOpenID

from syncmaster.db.models import User
from syncmaster.exceptions import EntityNotFoundError
from syncmaster.exceptions.auth import AuthorizationError
from syncmaster.exceptions.redirect import RedirectException
Expand Down Expand Up @@ -63,7 +64,7 @@ async def get_token_authorization_code_grant(
) -> dict[str, Any]:
try:
redirect_uri = redirect_uri or self.settings.keycloak.redirect_uri
token = self.keycloak_openid.token(
token = await self.keycloak_openid.a_token(
grant_type="authorization_code",
code=code,
redirect_uri=redirect_uri,
Expand All @@ -72,10 +73,8 @@ async def get_token_authorization_code_grant(
except Exception as e:
raise AuthorizationError("Failed to get token") from e

async def get_current_user(self, access_token: str, *args, **kwargs) -> Any:
async def get_current_user(self, access_token: str | None, **kwargs) -> User:
request: Request = kwargs["request"]
refresh_token = request.session.get("refresh_token")

if not access_token:
log.debug("No access token found in session.")
self.redirect_to_auth(request.url.path)
Expand All @@ -86,8 +85,9 @@ async def get_current_user(self, access_token: str, *args, **kwargs) -> Any:
token_info = self.keycloak_openid.decode_token(token=access_token)
except Exception as e:
log.info("Access token is invalid or expired: %s", e)
token_info = None
token_info = {}

refresh_token = request.session.get("refresh_token")
if not token_info and refresh_token:
log.debug("Access token invalid. Attempting to refresh.")

Expand All @@ -99,9 +99,7 @@ async def get_current_user(self, access_token: str, *args, **kwargs) -> Any:
request.session["access_token"] = new_access_token
request.session["refresh_token"] = new_refresh_token

token_info = self.keycloak_openid.decode_token(
token=new_access_token,
)
token_info = self.keycloak_openid.decode_token(token=new_access_token)
log.debug("Access token refreshed and decoded successfully.")
except Exception as e:
log.debug("Failed to refresh access token: %s", e)
Expand All @@ -110,19 +108,19 @@ async def get_current_user(self, access_token: str, *args, **kwargs) -> Any:
# these names are hardcoded in keycloak:
# https://github.com/keycloak/keycloak/blob/3ca3a4ad349b4d457f6829eaf2ae05f1e01408be/core/src/main/java/org/keycloak/representations/IDToken.java
user_id = token_info.get("sub")
if not user_id:
raise AuthorizationError("Invalid token payload")

login = token_info.get("preferred_username")
email = token_info.get("email")
first_name = token_info.get("given_name")
middle_name = token_info.get("middle_name")
last_name = token_info.get("family_name")

if not user_id:
raise AuthorizationError("Invalid token payload")

async with self._uow:
try:
user = await self._uow.user.read_by_username(login)
except EntityNotFoundError:
try:
user = await self._uow.user.read_by_username(login)
except EntityNotFoundError:
async with self._uow:
user = await self._uow.user.create(
username=login,
email=email,
Expand All @@ -134,7 +132,7 @@ async def get_current_user(self, access_token: str, *args, **kwargs) -> Any:
return user

async def refresh_access_token(self, refresh_token: str) -> dict[str, Any]:
new_tokens = self.keycloak_openid.refresh_token(refresh_token)
new_tokens = await self.keycloak_openid.a_refresh_token(refresh_token)
return new_tokens

def redirect_to_auth(self, path: str) -> None:
Expand Down
27 changes: 27 additions & 0 deletions syncmaster/server/settings/auth/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# SPDX-FileCopyrightText: 2023-2024 MTS PJSC
# SPDX-License-Identifier: Apache-2.0

from pydantic import BaseModel, Field, ImportString


class AuthSettings(BaseModel):
"""Authorization-related settings.
Here you can set auth provider class.
Examples
--------
.. code-block:: bash
SYNCMASTER__AUTH__PROVIDER=syncmaster.server.providers.auth.dummy_provider.DummyAuthProvider
"""

provider: ImportString = Field( # type: ignore[assignment]
default="syncmaster.server.providers.auth.dummy_provider.DummyAuthProvider",
description="Full name of auth provider class",
validate_default=True,
)

class Config:
extra = "allow"
20 changes: 7 additions & 13 deletions tests/test_unit/test_auth/auth_fixtures/keycloak_fixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from base64 import b64encode

import pytest
import responses
import respx
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
Expand Down Expand Up @@ -85,9 +85,8 @@ def mock_keycloak_well_known(settings):
realm_name = settings.auth.dict()["keycloak"]["client_id"]
well_known_url = f"{server_url}/realms/{realm_name}/.well-known/openid-configuration"

responses.add(
responses.GET,
well_known_url,
respx.get(well_known_url).respond(
status_code=200,
json={
"authorization_endpoint": f"{server_url}/realms/{realm_name}/protocol/openid-connect/auth",
"token_endpoint": f"{server_url}/realms/{realm_name}/protocol/openid-connect/token",
Expand All @@ -96,7 +95,6 @@ def mock_keycloak_well_known(settings):
"jwks_uri": f"{server_url}/realms/{realm_name}/protocol/openid-connect/certs",
"issuer": f"{server_url}/realms/{realm_name}",
},
status=200,
content_type="application/json",
)

Expand All @@ -108,16 +106,14 @@ def mock_keycloak_realm(settings, rsa_keys):
realm_url = f"{server_url}/realms/{realm_name}"
public_pem_str = get_public_key_pem(rsa_keys["public_key"])

responses.add(
responses.GET,
realm_url,
respx.get(realm_url).respond(
status_code=200,
json={
"realm": realm_name,
"public_key": public_pem_str,
"token-service": f"{server_url}/realms/{realm_name}/protocol/openid-connect/token",
"account-service": f"{server_url}/realms/{realm_name}/account",
},
status=200,
content_type="application/json",
)

Expand All @@ -144,15 +140,13 @@ def mock_keycloak_token_refresh(settings, rsa_keys):
new_access_token = jwt.encode(payload, private_pem, algorithm="RS256")
new_refresh_token = "mock_new_refresh_token"

responses.add(
responses.POST,
token_url,
respx.post(token_url).respond(
status_code=200,
json={
"access_token": new_access_token,
"refresh_token": new_refresh_token,
"token_type": "bearer",
"expires_in": expires_in,
},
status=200,
content_type="application/json",
)
12 changes: 6 additions & 6 deletions tests/test_unit/test_auth/test_auth_keycloak.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging

import pytest
import responses
import respx
from httpx import AsyncClient

from syncmaster.server.settings import ServerAppSettings as Settings
Expand All @@ -11,7 +11,7 @@
pytestmark = [pytest.mark.asyncio, pytest.mark.server]


@responses.activate
@respx.mock
@pytest.mark.parametrize(
"settings",
[
Expand All @@ -33,7 +33,7 @@ async def test_get_keycloak_user_unauthorized(client: AsyncClient, mock_keycloak
)


@responses.activate
@respx.mock
@pytest.mark.parametrize(
"settings",
[
Expand Down Expand Up @@ -71,7 +71,7 @@ async def test_get_keycloak_user_authorized(
}


@responses.activate
@respx.mock
@pytest.mark.parametrize(
"settings",
[
Expand Down Expand Up @@ -116,7 +116,7 @@ async def test_get_keycloak_user_expired_access_token(
}


@responses.activate
@respx.mock
@pytest.mark.parametrize(
"settings",
[
Expand Down Expand Up @@ -155,7 +155,7 @@ async def test_get_keycloak_deleted_user(
}


@responses.activate
@respx.mock
@pytest.mark.parametrize(
"settings",
[
Expand Down

0 comments on commit e3bca7e

Please sign in to comment.