Skip to content

Commit

Permalink
Merge pull request #112 from chrisburr/legacy-exchange
Browse files Browse the repository at this point in the history
Add /auth/legacy-exchange
  • Loading branch information
chrisburr authored Oct 2, 2023
2 parents 34c630e + 263e0c7 commit ae3faf7
Show file tree
Hide file tree
Showing 9 changed files with 207 additions and 11 deletions.
14 changes: 12 additions & 2 deletions .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,22 @@ jobs:
run: cd /tmp/DIRACRepo && ./integration_tests.py test-server || touch server-tests-failed
- name: Client tests
run: cd /tmp/DIRACRepo && ./integration_tests.py test-client || touch client-tests-failed
- name: diracx logs
run: docker logs diracx
- name: Check test status
run: |
has_error=0
# TODO: set has_error=1 when we are ready to really run the tests
if [ -f server-tests-failed ]; then has_error=0; echo "Server tests failed"; fi
if [ -f client-tests-failed ]; then has_error=0; echo "Client tests failed"; fi
if [ ${has_error} = 1 ]; then exit 1; fi
- name: diracx logs
if: ${{ failure() }}
run: |
mkdir -p /tmp/service-logs
docker logs diracx 2>&1 | tee /tmp/service-logs/diracx.log
cd /tmp/DIRACRepo
./integration_tests.py logs --no-follow --lines 1000 2>&1 | tee /tmp/service-logs/dirac.log
- uses: actions/upload-artifact@v3
if: ${{ failure() }}
with:
name: serivce-logs
path: /tmp/service-logs/
12 changes: 9 additions & 3 deletions src/diracx/client/_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize
"""
from __future__ import annotations

from datetime import datetime
import json
import requests
Expand All @@ -19,7 +21,7 @@

from diracx.client.models import TokenResponse
from diracx.core.models import TokenResponse as CoreTokenResponse
from diracx.core.preferences import get_diracx_preferences
from diracx.core.preferences import DiracxPreferences, get_diracx_preferences

from ._client import Dirac as DiracGenerated

Expand Down Expand Up @@ -112,9 +114,13 @@ class DiracClient(DiracGenerated):
"""

def __init__(
self, endpoint: str | None = None, client_id: str | None = None, **kwargs: Any
self,
endpoint: str | None = None,
client_id: str | None = None,
diracx_preferences: DiracxPreferences | None = None,
**kwargs: Any,
) -> None:
diracx_preferences = get_diracx_preferences()
diracx_preferences = diracx_preferences or get_diracx_preferences()
self._endpoint = endpoint or diracx_preferences.url
self._client_id = client_id or "myDIRACClientID"

Expand Down
10 changes: 7 additions & 3 deletions src/diracx/client/aio/_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from azure.core.pipeline import PipelineRequest
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy

from diracx.core.preferences import get_diracx_preferences
from diracx.core.preferences import get_diracx_preferences, DiracxPreferences

from ._client import Dirac as DiracGenerated
from .._patch import get_openid_configuration, get_token, refresh_token
Expand Down Expand Up @@ -125,9 +125,13 @@ class DiracClient(DiracGenerated):
"""

def __init__(
self, endpoint: str | None = None, client_id: str | None = None, **kwargs: Any
self,
endpoint: str | None = None,
client_id: str | None = None,
diracx_preferences: DiracxPreferences | None = None,
**kwargs: Any,
) -> None:
diracx_preferences = get_diracx_preferences()
diracx_preferences = diracx_preferences or get_diracx_preferences()
self._endpoint = endpoint or diracx_preferences.url
self._client_id = client_id or "myDIRACClientID"

Expand Down
10 changes: 10 additions & 0 deletions src/diracx/core/config/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,16 @@ class RegistryConfig(BaseModel):
Users: dict[str, UserConfig]
Groups: dict[str, GroupConfig]

def sub_from_preferred_username(self, preferred_username: str) -> str:
"""Get the user sub from the preferred username.
TODO: This could easily be cached or optimised
"""
for sub, user in self.Users.items():
if user.PreferedUsername == preferred_username:
return sub
raise KeyError(f"User {preferred_username} not found in registry")


class DIRACConfig(BaseModel):
pass
Expand Down
5 changes: 3 additions & 2 deletions src/diracx/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import re
from datetime import datetime, timedelta, timezone
from pathlib import Path

from diracx.core.models import TokenResponse

Expand All @@ -19,7 +20,7 @@ def dotenv_files_from_environment(prefix: str) -> list[str]:
return [v for _, v in sorted(env_files.items())]


def write_credentials(token_response: TokenResponse):
def write_credentials(token_response: TokenResponse, location: Path | None = None):
"""Write credentials received in dirax_preferences.credentials_path"""
from diracx.core.preferences import get_diracx_preferences

Expand All @@ -31,6 +32,6 @@ def write_credentials(token_response: TokenResponse):
"refresh_token": token_response.refresh_token,
"expires_on": int(datetime.timestamp(expires)),
}
credentials_path = get_diracx_preferences().credentials_path
credentials_path = location or get_diracx_preferences().credentials_path
credentials_path.parent.mkdir(parents=True, exist_ok=True)
credentials_path.write_text(json.dumps(credential_data))
61 changes: 61 additions & 0 deletions src/diracx/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import base64
import hashlib
import json
import os
import re
import secrets
from datetime import timedelta
Expand All @@ -18,6 +19,7 @@
from fastapi import (
Depends,
Form,
Header,
HTTPException,
Request,
Response,
Expand Down Expand Up @@ -1011,3 +1013,62 @@ async def userinfo(
"properties": user_info.properties,
"preferred_username": user_info.preferred_username,
}


BASE_64_URL_SAFE_PATTERN = (
r"(?:[A-Za-z0-9\-_]{4})*(?:[A-Za-z0-9\-_]{2}==|[A-Za-z0-9\-_]{3}=)?"
)
LEGACY_EXCHANGE_PATTERN = rf"Bearer diracx:legacy:({BASE_64_URL_SAFE_PATTERN})"


@router.get("/legacy-exchange", include_in_schema=False)
async def legacy_exchange(
preferred_username: str,
scope: str,
authorization: Annotated[str, Header()],
auth_db: AuthDB,
available_properties: AvailableSecurityProperties,
settings: AuthSettings,
config: Config,
):
"""Endpoint used by legacy DIRAC to mint tokens for proxy -> token exchange."""
if not (
expected_api_key := os.environ.get("DIRACX_LEGACY_EXCHANGE_HASHED_API_KEY")
):
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Legacy exchange is not enabled",
)

if match := re.fullmatch(LEGACY_EXCHANGE_PATTERN, authorization):
raw_token = base64.urlsafe_b64decode(match.group(1))
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid authorization header",
)

if hashlib.sha256(raw_token).hexdigest() != expected_api_key:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid credentials",
)

try:
parsed_scope = parse_and_validate_scope(scope, config, available_properties)
vo_users = config.Registry[parsed_scope["vo"]]
sub = vo_users.sub_from_preferred_username(preferred_username)
except (KeyError, ValueError) as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid scope or preferred_username",
) from e

return await exchange_token(
auth_db,
scope,
{"sub": sub, "preferred_username": preferred_username},
config,
settings,
available_properties,
)
Empty file added tests/routers/auth/__init__.py
Empty file.
104 changes: 104 additions & 0 deletions tests/routers/auth/test_legacy_exchange.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import base64
import hashlib
import secrets

import pytest


@pytest.fixture
def legacy_credentials(monkeypatch):
secret = secrets.token_bytes()
valid_token = f"diracx:legacy:{base64.urlsafe_b64encode(secret).decode()}"
monkeypatch.setenv(
"DIRACX_LEGACY_EXCHANGE_HASHED_API_KEY", hashlib.sha256(secret).hexdigest()
)
yield {"Authorization": f"Bearer {valid_token}"}


async def test_valid(test_client, legacy_credentials):
r = test_client.get(
"/auth/legacy-exchange",
params={"preferred_username": "chaen", "scope": "vo:lhcb group:lhcb_user"},
headers=legacy_credentials,
)
assert r.status_code == 200
access_token = r.json()["access_token"]

r = test_client.get(
"/auth/userinfo", headers={"Authorization": f"Bearer {access_token}"}
)
assert r.status_code == 200
user_info = r.json()
assert user_info["sub"] == "lhcb:b824d4dc-1f9d-4ee8-8df5-c0ae55d46041"
assert user_info["vo"] == "lhcb"
assert user_info["dirac_group"] == "lhcb_user"
assert user_info["properties"] == ["NormalUser", "PrivateLimitedDelegation"]


async def test_disabled(test_client):
r = test_client.get(
"/auth/legacy-exchange",
params={"preferred_username": "chaen", "scope": "vo:lhcb group:lhcb_user"},
headers={"Authorization": "Bearer diracx:legacy:ChangeME"},
)
assert r.status_code == 503


async def test_no_credentials(test_client, legacy_credentials):
r = test_client.get(
"/auth/legacy-exchange",
params={"preferred_username": "chaen", "scope": "vo:lhcb group:lhcb_user"},
headers={"Authorization": "Bearer invalid"},
)
assert r.status_code == 400
assert r.json()["detail"] == "Invalid authorization header"


async def test_invalid_credentials(test_client, legacy_credentials):
r = test_client.get(
"/auth/legacy-exchange",
params={"preferred_username": "chaen", "scope": "vo:lhcb group:lhcb_user"},
headers={"Authorization": "Bearer invalid"},
)
assert r.status_code == 400
assert r.json()["detail"] == "Invalid authorization header"


async def test_wrong_credentials(test_client, legacy_credentials):
r = test_client.get(
"/auth/legacy-exchange",
params={"preferred_username": "chaen", "scope": "vo:lhcb group:lhcb_user"},
headers={"Authorization": "Bearer diracx:legacy:ChangeME"},
)
assert r.status_code == 403
assert r.json()["detail"] == "Invalid credentials"


async def test_unknown_vo(test_client, legacy_credentials):
r = test_client.get(
"/auth/legacy-exchange",
params={"preferred_username": "chaen", "scope": "vo:unknown group:lhcb_user"},
headers=legacy_credentials,
)
assert r.status_code == 400
assert r.json()["detail"] == "Invalid scope or preferred_username"


async def test_unknown_group(test_client, legacy_credentials):
r = test_client.get(
"/auth/legacy-exchange",
params={"preferred_username": "chaen", "scope": "vo:lhcb group:unknown"},
headers=legacy_credentials,
)
assert r.status_code == 400
assert r.json()["detail"] == "Invalid scope or preferred_username"


async def test_unknown_user(test_client, legacy_credentials):
r = test_client.get(
"/auth/legacy-exchange",
params={"preferred_username": "unknown", "scope": "vo:lhcb group:lhcb_user"},
headers=legacy_credentials,
)
assert r.status_code == 400
assert r.json()["detail"] == "Invalid scope or preferred_username"
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def non_mocked_hosts(test_client) -> list[str]:

@pytest.fixture
async def auth_httpx_mock(httpx_mock: HTTPXMock, monkeypatch):
data_dir = Path(__file__).parent.parent / "data"
data_dir = Path(__file__).parent.parent.parent / "data"
path = "lhcb-auth.web.cern.ch/.well-known/openid-configuration"
httpx_mock.add_response(url=f"https://{path}", text=(data_dir / path).read_text())

Expand Down

0 comments on commit ae3faf7

Please sign in to comment.