diff --git a/dirac.cfg b/dirac.cfg index 4734379bd71..d67f02c2959 100644 --- a/dirac.cfg +++ b/dirac.cfg @@ -134,7 +134,8 @@ DiracX URL = https://diracx.invalid:8000 # A key used to have priviledged interactions with diracx. see LegacyExchangeApiKey = diracx:legacy:InsecureChangeMe - + # List of VOs which should use DiracX via the legacy compatibility mechanism + EnabledVOs = gridpp,cta } ### Registry section: # Sections to register VOs, groups, users and hosts diff --git a/src/DIRAC/FrameworkSystem/Service/ProxyManagerHandler.py b/src/DIRAC/FrameworkSystem/Service/ProxyManagerHandler.py index c68f1f97839..a8cf1130a17 100644 --- a/src/DIRAC/FrameworkSystem/Service/ProxyManagerHandler.py +++ b/src/DIRAC/FrameworkSystem/Service/ProxyManagerHandler.py @@ -414,7 +414,12 @@ def export_exchangeProxyForToken(self): """Exchange a proxy for an equivalent token to be used with diracx""" from DIRAC.FrameworkSystem.Utilities.diracx import get_token - return get_token(self.getRemoteCredentials()) + credDict = self.getRemoteCredentials() + return get_token( + credDict["username"], + credDict["group"], + set(credDict.get("groupProperties", []) + credDict.get("properties", [])), + ) class ProxyManagerHandler(ProxyManagerHandlerMixin, RequestHandler): diff --git a/src/DIRAC/FrameworkSystem/Utilities/diracx.py b/src/DIRAC/FrameworkSystem/Utilities/diracx.py index 826f0f3d589..946ee47e423 100644 --- a/src/DIRAC/FrameworkSystem/Utilities/diracx.py +++ b/src/DIRAC/FrameworkSystem/Utilities/diracx.py @@ -1,8 +1,8 @@ # pylint: disable=import-error - import requests from cachetools import TTLCache, cached +from cachetools.keys import hashkey from pathlib import Path from tempfile import NamedTemporaryFile from typing import Any @@ -19,12 +19,10 @@ # How long tokens are kept DEFAULT_TOKEN_CACHE_TTL = 5 * 60 - -# Add a cache not to query the token all the time -_token_cache = TTLCache(maxsize=100, ttl=DEFAULT_TOKEN_CACHE_TTL) +DEFAULT_TOKEN_CACHE_SIZE = 1024 -def get_token(credDict, *, expires_minutes=None): +def get_token(username: str, group: str, dirac_properties: set[str], *, expires_minutes: int | None = None): """Do a legacy exchange to get a DiracX access_token+refresh_token""" diracxUrl = gConfig.getValue("/DiracX/URL") if not diracxUrl: @@ -33,16 +31,13 @@ def get_token(credDict, *, expires_minutes=None): if not apiKey: raise ValueError("Missing mandatory /DiracX/LegacyExchangeApiKey configuration") - vo = Registry.getVOForGroup(credDict["group"]) - dirac_properties = list(set(credDict.get("groupProperties", [])) | set(credDict.get("properties", []))) - group = credDict["group"] - + vo = Registry.getVOForGroup(group) scopes = [f"vo:{vo}", f"group:{group}"] + [f"property:{prop}" for prop in dirac_properties] r = requests.get( f"{diracxUrl}/api/auth/legacy-exchange", params={ - "preferred_username": credDict["username"], + "preferred_username": username, "scope": " ".join(scopes), "expires_minutes": expires_minutes, }, @@ -55,10 +50,13 @@ def get_token(credDict, *, expires_minutes=None): return r.json() -@cached(_token_cache, key=lambda x, y: repr(x)) -def _get_token_file(credDict) -> Path: +@cached( + TTLCache(maxsize=DEFAULT_TOKEN_CACHE_SIZE, ttl=DEFAULT_TOKEN_CACHE_TTL), + key=lambda a, b, c: hashkey(a, b, *sorted(c)), +) +def _get_token_file(username: str, group: str, dirac_properties: set[str]) -> Path: """Write token to a temporary file and return the path to that file""" - data = get_token(credDict) + data = get_token(username, group, dirac_properties) token_location = Path(NamedTemporaryFile().name) write_credentials(TokenResponse(**data), location=token_location) return token_location @@ -76,7 +74,12 @@ def TheImpersonator(credDict: dict[str, Any]) -> DiracClient: diracxUrl = gConfig.getValue("/DiracX/URL") if not diracxUrl: raise ValueError("Missing mandatory /DiracX/URL configuration") - token_location = _get_token_file(credDict) + + token_location = _get_token_file( + credDict["username"], + credDict["group"], + set(credDict.get("groupProperties", []) + credDict.get("properties", [])), + ) pref = DiracxPreferences(url=diracxUrl, credentials_path=token_location) return DiracClient(diracx_preferences=pref) diff --git a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py index 576ade8f2d0..55ff8ba528a 100644 --- a/src/DIRAC/FrameworkSystem/scripts/dirac_login.py +++ b/src/DIRAC/FrameworkSystem/scripts/dirac_login.py @@ -319,18 +319,26 @@ def loginWithCertificate(self): if vo in enabledVOs: from diracx.core.utils import write_credentials # pylint: disable=import-error from diracx.core.models import TokenResponse # pylint: disable=import-error + from diracx.core.preferences import DiracxPreferences # pylint: disable=import-error res = Client(url="Framework/ProxyManager").exchangeProxyForToken() if not res["OK"]: return res token_content = res["Value"] + + diracxUrl = gConfig.getValue("/DiracX/URL") + if not diracxUrl: + return S_ERROR("Missing mandatory /DiracX/URL configuration") + + preferences = DiracxPreferences(url=diracxUrl) write_credentials( TokenResponse( access_token=token_content["access_token"], expires_in=token_content["expires_in"], token_type=token_content.get("token_type"), refresh_token=token_content.get("refresh_token"), - ) + ), + location=preferences.credentials_path, ) return S_OK()