Skip to content

Commit

Permalink
More changes
Browse files Browse the repository at this point in the history
  • Loading branch information
DiamondJoseph committed Nov 8, 2024
1 parent 3d4a9a5 commit 7c9779b
Show file tree
Hide file tree
Showing 15 changed files with 320 additions and 345 deletions.
4 changes: 1 addition & 3 deletions src/blueapi/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
ConfigLoader,
)
from blueapi.core import OTLP_EXPORT_ENABLED, DataEvent
from blueapi.service.authentication import CliTokenManager, SessionManager
from blueapi.service.authentication import SessionManager
from blueapi.worker import ProgressEvent, Task, WorkerEvent

from .scratch import setup_scratch
Expand Down Expand Up @@ -362,7 +362,6 @@ def login(obj: dict) -> None:
auth: SessionManager = SessionManager(
server_config=config.oauth_server,
client_config=config.oauth_client,
token_manager=CliTokenManager(Path(config.oauth_client.token_file_path)),
)
auth.start_device_flow()
else:
Expand All @@ -377,7 +376,6 @@ def logout(obj: dict) -> None:
auth: SessionManager = SessionManager(
server_config=config.oauth_server,
client_config=config.oauth_client,
token_manager=CliTokenManager(Path(config.oauth_client.token_file_path)),
)
auth.logout()
print("Logged out")
Expand Down
4 changes: 3 additions & 1 deletion src/blueapi/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def __init__(
def from_config(cls, config: ApplicationConfig) -> "BlueapiClient":
rest: BlueapiRestClient = BlueapiRestClient(
config.api,
SessionManager.from_config(config.oauth_server, config.oauth_client),
SessionManager(config.oauth_server, config.oauth_client)
if config.oauth_server and config.oauth_client
else None,
)
if config.stomp is not None:
template = StompClient.for_broker(
Expand Down
6 changes: 4 additions & 2 deletions src/blueapi/client/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(
session_manager: SessionManager | None = None,
) -> None:
self._config = config or RestConfig()
self._session_manager: SessionManager | None = session_manager
self._session_manager = session_manager

def get_plans(self) -> PlanResponse:
return self._request_and_deserialize("/plans", PlanResponse)
Expand Down Expand Up @@ -149,7 +149,9 @@ def _request_and_deserialize(
if self._session_manager and (token := self._session_manager.get_token()):
try:
# Check token is not expired
self._session_manager.authenticator.decode_jwt(token["access_token"])
self._session_manager.authenticator.decode_jwt(
token["access_token"], self._session_manager._client_audience
)
except jwt.ExpiredSignatureError:
token = self._session_manager.refresh_auth_token()
assert token # This must be present
Expand Down
9 changes: 5 additions & 4 deletions src/blueapi/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import Mapping
from collections.abc import Iterable, Mapping
from enum import Enum
from functools import cached_property
from pathlib import Path
Expand Down Expand Up @@ -88,7 +88,6 @@ class OAuthServerConfig(BlueapiBaseModel):
oidc_config_url: str = Field(
description="URL to fetch OIDC config from the provider"
)
audience: str = Field(description="Valid audience")

@cached_property
def _config_from_oidc_url(self) -> dict[str, Any]:
Expand All @@ -103,7 +102,7 @@ def device_auth_url(self) -> str:
)

@cached_property
def pkce_auth_url(self) -> str:
def auth_url(self) -> str:
return cast(str, self._config_from_oidc_url.get("authorization_endpoint"))

@cached_property
Expand Down Expand Up @@ -132,7 +131,9 @@ def signing_algos(self) -> list[str]:

class OAuthClientConfig(BlueapiBaseModel):
client_id: str = Field(description="Client ID")
client_audience: str = Field(description="Client Audience")
client_audience: str | Iterable[str] | None = Field(
description="Client Audience(s)"
)


class CLIClientConfig(OAuthClientConfig):
Expand Down
130 changes: 63 additions & 67 deletions src/blueapi/service/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
import time
from abc import ABC, abstractmethod
from enum import Enum
from collections.abc import Iterable
from http import HTTPStatus
from pathlib import Path
from typing import Any, cast
Expand All @@ -20,32 +20,24 @@
)


class AuthenticationType(Enum):
DEVICE = "device"
PKCE = "pkce"


class Authenticator:
def __init__(self, server_config: OAuthServerConfig):
self._server_config: OAuthServerConfig = server_config

def decode_jwt(self, token: str) -> dict[str, str]:
def decode_jwt(
self, token: str, audience: str | Iterable[str] | None = None
) -> dict[str, str]:
signing_key = jwt.PyJWKClient(
self._server_config.jwks_uri
).get_signing_key_from_jwt(token)
decode: dict[str, str] = jwt.decode(
return jwt.decode(
token,
signing_key.key,
algorithms=self._server_config.signing_algos,
verify=True,
audience=self._server_config.audience,
audience=audience,
issuer=self._server_config.issuer,
)
return decode

def print_user_info(self, token: str) -> None:
decode: dict[str, str] = self.decode_jwt(token)
print(f'Logged in as {decode.get("name")} with fed-id {decode.get("fedid")}')


class TokenManager(ABC):
Expand All @@ -66,8 +58,7 @@ def _file_path(self) -> str:

def save_token(self, token: dict[str, Any]) -> None:
token_json: str = json.dumps(token)
token_bytes: bytes = token_json.encode("utf-8")
token_base64: bytes = base64.b64encode(token_bytes)
token_base64: bytes = base64.b64encode(token_json.encode("utf-8"))
with open(self._file_path(), "wb") as token_file:
token_file.write(token_base64)

Expand All @@ -77,8 +68,7 @@ def load_token(self) -> dict[str, Any] | None:
return None
with open(file_path, "rb") as token_file:
token_base64: bytes = token_file.read()
token_bytes: bytes = base64.b64decode(token_base64)
token_json: str = token_bytes.decode("utf-8")
token_json: bytes = base64.b64decode(token_base64).decode("utf-8")
return json.loads(token_json)

def delete_token(self) -> None:
Expand All @@ -90,50 +80,46 @@ def __init__(
self,
server_config: OAuthServerConfig,
client_config: OAuthClientConfig,
token_manager: TokenManager,
) -> None:
self._server_config: OAuthServerConfig = server_config
self._client_config: OAuthClientConfig = client_config
self.authenticator: Authenticator = Authenticator(server_config, client_config)
self._token_manager = token_manager

@classmethod
def from_config(
cls,
server_config: OAuthServerConfig | None,
client_config: OAuthClientConfig | None,
) -> SessionManager | None:
if server_config and client_config:
if isinstance(client_config, CLIClientConfig):
return SessionManager(
server_config,
client_config,
CliTokenManager(Path(client_config.token_file_path)),
)
return None
self._server_config = server_config
self._client_id = client_config.client_id
self._client_audience = client_config.client_audience
self.authenticator: Authenticator = Authenticator(server_config)
self._token_manager: TokenManager | None = (
CliTokenManager(client_config.token_file_path)
if isinstance(client_config, CLIClientConfig)
else None
)

def get_token(self) -> dict[str, Any] | None:
return self._token_manager.load_token()
if self._token_manager:
return self._token_manager.load_token()
return None

def logout(self) -> None:
self._token_manager.delete_token()
if self._token_manager:
self._token_manager.delete_token()

def refresh_auth_token(self) -> dict[str, Any] | None:
if token := self._token_manager.load_token():
response = requests.post(
self._server_config.token_url,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": self._client_config.client_id,
"grant_type": "refresh_token",
"refresh_token": token["refresh_token"],
},
)
if response.status_code == HTTPStatus.OK:
token = response.json()
if token:
self._token_manager.save_token(token)
return token
if not self._token_manager:
print("Session not configured to persist, no token to refresh")
return None
token = self._token_manager.load_token()
if not token:
print("No current Session, no token to refresh")
return None
response = requests.post(
self._server_config.token_url,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": self._client_id,
"grant_type": "refresh_token",
"refresh_token": token["refresh_token"],
},
)
if response.status_code == HTTPStatus.OK and (token := response.json()):
self._token_manager.save_token(token)
return token
return None

def poll_for_token(
Expand All @@ -147,7 +133,7 @@ def poll_for_token(
data={
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
"device_code": device_code,
"client_id": self._client_config.client_id,
"client_id": self._client_id,
},
)
if response.status_code == HTTPStatus.OK:
Expand All @@ -157,26 +143,32 @@ def poll_for_token(
raise TimeoutError("Polling timed out")

def start_device_flow(self) -> None:
if not self._token_manager:
print("Session not configured to persist, no token to refresh")
return None

if token := self._token_manager.load_token():
try:
access_token_info: dict[str, Any] = self.authenticator.decode_jwt(
token["access_token"]
self.authenticator.decode_jwt(
token["access_token"], self._client_audience
)
if access_token_info:
self.authenticator.print_user_info(token["access_token"])
return
print("Cached token still valid, skipping flow")
return
except jwt.ExpiredSignatureError:
if token := self.refresh_auth_token():
self.authenticator.print_user_info(token["access_token"])
print("Refreshed cached token, skipping flow")
return
except Exception:
print("Problem with cached token, starting new session")
self._token_manager.delete_token()

response: requests.Response = requests.post(
self._server_config.device_auth_url,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": self._client_config.client_id,
"client_id": self._client_id,
"scope": "openid profile offline_access",
"audience": self._client_config.client_audience,
"audience": self._client_audience,
},
)

Expand All @@ -193,10 +185,14 @@ def start_device_flow(self) -> None:
device_code, interval, expires_in
)
decoded_token: dict[str, Any] = self.authenticator.decode_jwt(
auth_token_json["access_token"]
auth_token_json["access_token"], self._client_audience
)
if decoded_token:
self._token_manager.save_token(auth_token_json)
self.authenticator.print_user_info(auth_token_json["access_token"])
if self._token_manager:
self._token_manager.save_token(auth_token_json)
print("Logged in and cached new token")
else:
print("Logged in but not configured to persist session")

else:
print("Failed to login")
10 changes: 5 additions & 5 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import Awaitable, Callable
from contextlib import asynccontextmanager
from functools import lru_cache
from typing import Any

from fastapi import (
Expand Down Expand Up @@ -27,7 +28,7 @@
from starlette.responses import JSONResponse
from super_state_machine.errors import TransitionError

from blueapi.config import ApplicationConfig, OAuthClientConfig, OAuthServerConfig
from blueapi.config import ApplicationConfig, OAuthServerConfig
from blueapi.service import interface
from blueapi.service.authentication import Authenticator
from blueapi.service.runner import WorkerDispatcher
Expand All @@ -49,8 +50,6 @@
REST_API_VERSION = "0.0.5"

RUNNER: WorkerDispatcher | None = None
SWAGGER_CONFIG: dict[str, Any] | None = None

CONTEXT_HEADER = "traceparent"


Expand Down Expand Up @@ -90,7 +89,8 @@ async def inner(app: FastAPI):
router = APIRouter()


def get_app(config: ApplicationConfig | None):
@lru_cache(1)
def get_app(config: ApplicationConfig | None = None):
app = FastAPI(
docs_url="/docs",
title="BlueAPI Control",
Expand All @@ -108,7 +108,7 @@ def get_app(config: ApplicationConfig | None):

def verify_access_token(config: OAuthServerConfig):
oauth_scheme = OAuth2AuthorizationCodeBearer(
authorizationUrl=config.pkce_auth_url,
authorizationUrl=config.auth_url,
tokenUrl=config.token_url,
refreshUrl=config.token_url,
)
Expand Down
6 changes: 6 additions & 0 deletions src/script/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,9 @@ services:
- 61613:61613
volumes:
- ./rabbitmq_setup/enabled_plugins:/etc/rabbitmq/enabled_plugins
mock-oauth2-server:
image: ghcr.io/navikt/mock-oauth2-server:2.1.10
ports:
- 8080:8080
environment:
JSON_CONFIG: '{ "interactiveLogin": false, "httpServer": "MockWebServerWrapper", "tokenCallbacks": [ { "issuerId": "auth", "tokenExpiry": 120, "requestMappings": [] } ] }'
11 changes: 11 additions & 0 deletions src/script/oidc_setup/oidc_config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
{
"interactiveLogin": false,
"httpServer": "MockWebServerWrapper",
"tokenCallbacks": [
{
"issuerId": "auth",
"tokenExpiry": 120,
"requestMappings": []
}
]
}
7 changes: 6 additions & 1 deletion src/script/stomp_config.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
---
stomp:
host: "localhost"
port: 61613
auth:
username: "guest"
password: "guest"
oauth_server:
oidc_config_url: "http://localhost:8080/auth/.well-known/openid-configuration"
oauth_client:
client_id: "blueapi"
client_audience: "blueapi-cli"
token_file_path: "/tmp/auth_token"
Loading

0 comments on commit 7c9779b

Please sign in to comment.