From d0ff92f1cabae2c1ceb9e1101be2fee5109faba6 Mon Sep 17 00:00:00 2001 From: natthan-pigoux Date: Fri, 12 Jul 2024 17:02:59 +0200 Subject: [PATCH] feat: add min client version, pass client version in header and add middleware to check client version --- diracx-client/src/diracx/client/_patch.py | 15 +++++ diracx-client/src/diracx/client/aio/_patch.py | 15 +++++ diracx-routers/pyproject.toml | 2 + diracx-routers/src/diracx/routers/__init__.py | 62 ++++++++++++++++++- diracx-routers/tests/test_generic.py | 10 +++ 5 files changed, 103 insertions(+), 1 deletion(-) diff --git a/diracx-client/src/diracx/client/_patch.py b/diracx-client/src/diracx/client/_patch.py index d8772d2e7..4122dd885 100644 --- a/diracx-client/src/diracx/client/_patch.py +++ b/diracx-client/src/diracx/client/_patch.py @@ -9,6 +9,7 @@ from __future__ import annotations from datetime import datetime, timezone +from importlib.metadata import PackageNotFoundError, distribution import json import jwt import requests @@ -148,6 +149,20 @@ def __init__( # Get .well-known configuration openid_configuration = get_openid_configuration(self._endpoint, verify=verify) + try: + self.client_version = distribution("diracx").version + except PackageNotFoundError: + try: + self.client_version = distribution("diracx-client").version + except PackageNotFoundError: + print("Error while getting client version") + self.client_version = "Unknown" + + # Setting default headers + kwargs.setdefault("base_headers", {})[ + "DiracX-Client-Version" + ] = self.client_version + # Initialize Dirac with a Dirac-specific token credential policy super().__init__( endpoint=self._endpoint, diff --git a/diracx-client/src/diracx/client/aio/_patch.py b/diracx-client/src/diracx/client/aio/_patch.py index 1a70f1c49..3438055fd 100644 --- a/diracx-client/src/diracx/client/aio/_patch.py +++ b/diracx-client/src/diracx/client/aio/_patch.py @@ -6,6 +6,7 @@ Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize """ +from importlib.metadata import PackageNotFoundError, distribution import json from types import TracebackType from pathlib import Path @@ -163,6 +164,20 @@ def __init__( # Get .well-known configuration openid_configuration = get_openid_configuration(self._endpoint, verify=verify) + try: + self.client_version = distribution("diracx").version + except PackageNotFoundError: + try: + self.client_version = distribution("diracx-client").version + except PackageNotFoundError: + print("Error while getting client version") + self.client_version = "Unknown" + + # Setting default headers + kwargs.setdefault("base_headers", {})[ + "DiracX-Client-Version" + ] = self.client_version + # Initialize Dirac with a Dirac-specific token credential policy super().__init__( endpoint=self._endpoint, diff --git a/diracx-routers/pyproject.toml b/diracx-routers/pyproject.toml index 97fdccb31..886218fbb 100644 --- a/diracx-routers/pyproject.toml +++ b/diracx-routers/pyproject.toml @@ -56,6 +56,8 @@ auth = "diracx.routers.auth:router" WMSAccessPolicy = "diracx.routers.job_manager.access_policies:WMSAccessPolicy" SandboxAccessPolicy = "diracx.routers.job_manager.access_policies:SandboxAccessPolicy" +[project.entry-points."diracx.min_client_version"] +diracx = "diracx.routers:DIRACX_MIN_CLIENT_VERSION" [tool.setuptools.packages.find] where = ["src"] diff --git a/diracx-routers/src/diracx/routers/__init__.py b/diracx-routers/src/diracx/routers/__init__.py index 2ff15cded..64b7943d3 100644 --- a/diracx-routers/src/diracx/routers/__init__.py +++ b/diracx-routers/src/diracx/routers/__init__.py @@ -12,19 +12,23 @@ import os from collections.abc import AsyncGenerator from functools import partial +from http import HTTPStatus +from importlib.metadata import EntryPoint, EntryPoints, entry_points from logging import Formatter, StreamHandler from typing import Any, Awaitable, Callable, Iterable, Sequence, TypeVar, cast import dotenv from cachetools import TTLCache -from fastapi import APIRouter, Depends, Request, status +from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, status from fastapi.dependencies.models import Dependant from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response from fastapi.routing import APIRoute +from packaging.version import parse from pydantic import TypeAdapter # from starlette.types import ASGIApp +from starlette.middleware.base import BaseHTTPMiddleware from uvicorn.logging import AccessFormatter, DefaultFormatter from diracx.core.config import ConfigSource @@ -290,6 +294,8 @@ def create_app_inner( "http://localhost:8000", ] + app.add_middleware(ClientMinVersionCheckMiddleware) + app.add_middleware( CORSMiddleware, allow_origins=origins, @@ -437,3 +443,57 @@ async def db_transaction(db: T2) -> AsyncGenerator[T2, None]: if reason := await is_db_unavailable(db): raise DBUnavailable(reason) yield db + + +class ClientMinVersionCheckMiddleware(BaseHTTPMiddleware): + """Custom FastAPI middleware to verify that + the client has the required minimum version. + """ + + def __init__(self, app: FastAPI): + super().__init__(app) + self.min_client_version = get_min_client_version() + + async def dispatch(self, request: Request, call_next) -> Response: + client_version = request.headers.get("DiracX-Client-Version") + if not client_version: + logger.info("DiracX-Client-Version header is missing.") + # TODO: if the request comes from web or swagger (other?), + # the header will be missing > how to manage that? + # raise HTTPException( + # status_code=HTTPStatus.BAD_REQUEST, + # detail="Client version header is missing.", + # ) + + elif self.is_version_too_old(client_version): + raise HTTPException( + status_code=HTTPStatus.UPGRADE_REQUIRED, + detail=f"Client version ({client_version}) not recent enough (>= {self.min_client_version}). Upgrade.", + ) + + response = await call_next(request) + return response + + def is_version_too_old(self, client_version: str) -> bool: + """Verify that client version is ge than min.""" + return parse(client_version) < parse(self.min_client_version) + + +# I'm not sure if this has to be define here: +DIRACX_MIN_CLIENT_VERSION = "0.0.1" + + +def get_min_client_version(): + """Extracting min client version from entry_points and seraching for extension.""" + matched_entry_points: EntryPoints = entry_points(group="diracx.min_client_version") + # Searching for an extension: + entry_points_dict: dict[str, EntryPoint] = { + ep.name: ep for ep in matched_entry_points + } + for ep_name, ep in entry_points_dict.items(): + if ep_name != "diracx": + return ep.load() + + # Taking diracx if no extension: + if "diracx" in entry_points_dict: + return entry_points_dict["diracx"].load() diff --git a/diracx-routers/tests/test_generic.py b/diracx-routers/tests/test_generic.py index 9f31064d0..49e617100 100644 --- a/diracx-routers/tests/test_generic.py +++ b/diracx-routers/tests/test_generic.py @@ -1,4 +1,7 @@ +from http import HTTPStatus + import pytest +from fastapi import HTTPException pytestmark = pytest.mark.enabled_dependencies( ["ConfigSource", "AuthSettings", "OpenAccessPolicy"] @@ -41,3 +44,10 @@ def test_unavailable_db(monkeypatch, test_client): r = test_client.get("/api/job/123") assert r.status_code == 503 assert r.json() + + +def test_min_client_version(test_client): + with pytest.raises(HTTPException) as response: + test_client.get("/", headers={"DiracX-Client-Version": "0.1.0"}) + assert response.value.status_code == HTTPStatus.UPGRADE_REQUIRED + assert "not recent enough" in response.value.detail