Skip to content

Commit

Permalink
fix: Contents of the TheImpersonator cache key
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisburr committed Oct 12, 2023
1 parent 6d8f0cd commit 5f8ab98
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 15 deletions.
7 changes: 6 additions & 1 deletion src/DIRAC/FrameworkSystem/Service/ProxyManagerHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,12 @@ def export_exchangeProxyForToken(self):
"""Exchange a proxy for an equivalent token to be used with diracx"""
from DIRAC.FrameworkSystem.Utilities.diracx import get_token

return get_token(self.getRemoteCredentials())
credDict = self.getRemoteCredentials()
return get_token(
credDict["username"],
credDict["group"],
set(credDict.get("groupProperties", []) + credDict.get("properties", [])),
)


class ProxyManagerHandler(ProxyManagerHandlerMixin, RequestHandler):
Expand Down
31 changes: 17 additions & 14 deletions src/DIRAC/FrameworkSystem/Utilities/diracx.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# pylint: disable=import-error

import requests

from cachetools import TTLCache, cached
from cachetools.keys import hashkey
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any
Expand All @@ -19,12 +19,10 @@

# How long tokens are kept
DEFAULT_TOKEN_CACHE_TTL = 5 * 60

# Add a cache not to query the token all the time
_token_cache = TTLCache(maxsize=100, ttl=DEFAULT_TOKEN_CACHE_TTL)
DEFAULT_TOKEN_CACHE_SIZE = 1024


def get_token(credDict, *, expires_minutes=None):
def get_token(username: str, group: str, dirac_properties: set[str], *, expires_minutes: int | None = None):
"""Do a legacy exchange to get a DiracX access_token+refresh_token"""
diracxUrl = gConfig.getValue("/DiracX/URL")
if not diracxUrl:
Expand All @@ -33,16 +31,13 @@ def get_token(credDict, *, expires_minutes=None):
if not apiKey:
raise ValueError("Missing mandatory /DiracX/LegacyExchangeApiKey configuration")

vo = Registry.getVOForGroup(credDict["group"])
dirac_properties = list(set(credDict.get("groupProperties", [])) | set(credDict.get("properties", [])))
group = credDict["group"]

vo = Registry.getVOForGroup(group)
scopes = [f"vo:{vo}", f"group:{group}"] + [f"property:{prop}" for prop in dirac_properties]

r = requests.get(
f"{diracxUrl}/api/auth/legacy-exchange",
params={
"preferred_username": credDict["username"],
"preferred_username": username,
"scope": " ".join(scopes),
"expires_minutes": expires_minutes,
},
Expand All @@ -55,10 +50,13 @@ def get_token(credDict, *, expires_minutes=None):
return r.json()


@cached(_token_cache, key=lambda x, y: repr(x))
def _get_token_file(credDict) -> Path:
@cached(
TTLCache(maxsize=DEFAULT_TOKEN_CACHE_SIZE, ttl=DEFAULT_TOKEN_CACHE_TTL),
key=lambda a, b, c: hashkey(a, b, *sorted(c)),
)
def _get_token_file(username: str, group: str, dirac_properties: set[str]) -> Path:
"""Write token to a temporary file and return the path to that file"""
data = get_token(credDict)
data = get_token(username, group, dirac_properties)
token_location = Path(NamedTemporaryFile().name)
write_credentials(TokenResponse(**data), location=token_location)
return token_location
Expand All @@ -76,7 +74,12 @@ def TheImpersonator(credDict: dict[str, Any]) -> DiracClient:
diracxUrl = gConfig.getValue("/DiracX/URL")
if not diracxUrl:
raise ValueError("Missing mandatory /DiracX/URL configuration")
token_location = _get_token_file(credDict)

token_location = _get_token_file(
credDict["username"],
credDict["group"],
set(credDict.get("groupProperties", []) + credDict.get("properties", [])),
)
pref = DiracxPreferences(url=diracxUrl, credentials_path=token_location)

return DiracClient(diracx_preferences=pref)

0 comments on commit 5f8ab98

Please sign in to comment.