Skip to content

Commit

Permalink
feat: add min client version, pass client version in header and add m…
Browse files Browse the repository at this point in the history
…iddleware to check client version
  • Loading branch information
natthan-pigoux committed Jul 12, 2024
1 parent 2e18084 commit d0ff92f
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 1 deletion.
15 changes: 15 additions & 0 deletions diracx-client/src/diracx/client/_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from __future__ import annotations

from datetime import datetime, timezone
from importlib.metadata import PackageNotFoundError, distribution
import json
import jwt
import requests
Expand Down Expand Up @@ -148,6 +149,20 @@ def __init__(
# Get .well-known configuration
openid_configuration = get_openid_configuration(self._endpoint, verify=verify)

try:
self.client_version = distribution("diracx").version
except PackageNotFoundError:
try:
self.client_version = distribution("diracx-client").version
except PackageNotFoundError:
print("Error while getting client version")
self.client_version = "Unknown"

# Setting default headers
kwargs.setdefault("base_headers", {})[
"DiracX-Client-Version"
] = self.client_version

# Initialize Dirac with a Dirac-specific token credential policy
super().__init__(
endpoint=self._endpoint,
Expand Down
15 changes: 15 additions & 0 deletions diracx-client/src/diracx/client/aio/_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize
"""
from importlib.metadata import PackageNotFoundError, distribution
import json
from types import TracebackType
from pathlib import Path
Expand Down Expand Up @@ -163,6 +164,20 @@ def __init__(
# Get .well-known configuration
openid_configuration = get_openid_configuration(self._endpoint, verify=verify)

try:
self.client_version = distribution("diracx").version
except PackageNotFoundError:
try:
self.client_version = distribution("diracx-client").version
except PackageNotFoundError:
print("Error while getting client version")
self.client_version = "Unknown"

# Setting default headers
kwargs.setdefault("base_headers", {})[
"DiracX-Client-Version"
] = self.client_version

# Initialize Dirac with a Dirac-specific token credential policy
super().__init__(
endpoint=self._endpoint,
Expand Down
2 changes: 2 additions & 0 deletions diracx-routers/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ auth = "diracx.routers.auth:router"
WMSAccessPolicy = "diracx.routers.job_manager.access_policies:WMSAccessPolicy"
SandboxAccessPolicy = "diracx.routers.job_manager.access_policies:SandboxAccessPolicy"

[project.entry-points."diracx.min_client_version"]
diracx = "diracx.routers:DIRACX_MIN_CLIENT_VERSION"

[tool.setuptools.packages.find]
where = ["src"]
Expand Down
62 changes: 61 additions & 1 deletion diracx-routers/src/diracx/routers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,23 @@
import os
from collections.abc import AsyncGenerator
from functools import partial
from http import HTTPStatus
from importlib.metadata import EntryPoint, EntryPoints, entry_points
from logging import Formatter, StreamHandler
from typing import Any, Awaitable, Callable, Iterable, Sequence, TypeVar, cast

import dotenv
from cachetools import TTLCache
from fastapi import APIRouter, Depends, Request, status
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, status
from fastapi.dependencies.models import Dependant
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response
from fastapi.routing import APIRoute
from packaging.version import parse
from pydantic import TypeAdapter

# from starlette.types import ASGIApp
from starlette.middleware.base import BaseHTTPMiddleware
from uvicorn.logging import AccessFormatter, DefaultFormatter

from diracx.core.config import ConfigSource
Expand Down Expand Up @@ -290,6 +294,8 @@ def create_app_inner(
"http://localhost:8000",
]

app.add_middleware(ClientMinVersionCheckMiddleware)

app.add_middleware(
CORSMiddleware,
allow_origins=origins,
Expand Down Expand Up @@ -437,3 +443,57 @@ async def db_transaction(db: T2) -> AsyncGenerator[T2, None]:
if reason := await is_db_unavailable(db):
raise DBUnavailable(reason)
yield db


class ClientMinVersionCheckMiddleware(BaseHTTPMiddleware):
"""Custom FastAPI middleware to verify that
the client has the required minimum version.
"""

def __init__(self, app: FastAPI):
super().__init__(app)
self.min_client_version = get_min_client_version()

async def dispatch(self, request: Request, call_next) -> Response:
client_version = request.headers.get("DiracX-Client-Version")
if not client_version:
logger.info("DiracX-Client-Version header is missing.")
# TODO: if the request comes from web or swagger (other?),
# the header will be missing > how to manage that?
# raise HTTPException(
# status_code=HTTPStatus.BAD_REQUEST,
# detail="Client version header is missing.",
# )

elif self.is_version_too_old(client_version):
raise HTTPException(
status_code=HTTPStatus.UPGRADE_REQUIRED,
detail=f"Client version ({client_version}) not recent enough (>= {self.min_client_version}). Upgrade.",
)

response = await call_next(request)
return response

def is_version_too_old(self, client_version: str) -> bool:
"""Verify that client version is ge than min."""
return parse(client_version) < parse(self.min_client_version)


# I'm not sure if this has to be define here:
DIRACX_MIN_CLIENT_VERSION = "0.0.1"


def get_min_client_version():
"""Extracting min client version from entry_points and seraching for extension."""
matched_entry_points: EntryPoints = entry_points(group="diracx.min_client_version")
# Searching for an extension:
entry_points_dict: dict[str, EntryPoint] = {
ep.name: ep for ep in matched_entry_points
}
for ep_name, ep in entry_points_dict.items():
if ep_name != "diracx":
return ep.load()

# Taking diracx if no extension:
if "diracx" in entry_points_dict:
return entry_points_dict["diracx"].load()
10 changes: 10 additions & 0 deletions diracx-routers/tests/test_generic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from http import HTTPStatus

import pytest
from fastapi import HTTPException

pytestmark = pytest.mark.enabled_dependencies(
["ConfigSource", "AuthSettings", "OpenAccessPolicy"]
Expand Down Expand Up @@ -41,3 +44,10 @@ def test_unavailable_db(monkeypatch, test_client):
r = test_client.get("/api/job/123")
assert r.status_code == 503
assert r.json()


def test_min_client_version(test_client):
with pytest.raises(HTTPException) as response:
test_client.get("/", headers={"DiracX-Client-Version": "0.1.0"})
assert response.value.status_code == HTTPStatus.UPGRADE_REQUIRED
assert "not recent enough" in response.value.detail

0 comments on commit d0ff92f

Please sign in to comment.