Skip to content

Commit

Permalink
Merge pull request #7079 from DIRACGridBot/cherry-pick-2-4454fcd14-in…
Browse files Browse the repository at this point in the history
…tegration

[sweep:integration] fix: interacting with CEs using tokens
  • Loading branch information
fstagni authored Jun 27, 2023
2 parents 1228f15 + e724cd6 commit 292f478
Show file tree
Hide file tree
Showing 16 changed files with 355 additions and 87 deletions.
18 changes: 1 addition & 17 deletions src/DIRAC/Core/Utilities/Grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from DIRAC.Core.Utilities.Subprocess import systemCall, shellCall


def executeGridCommand(proxy, cmd, gridEnvScript=None, gridEnvDict=None):
def executeGridCommand(cmd, gridEnvScript=None, gridEnvDict=None):
"""
Execute cmd tuple after sourcing GridEnv
"""
Expand All @@ -37,22 +37,6 @@ def executeGridCommand(proxy, cmd, gridEnvScript=None, gridEnvDict=None):
else:
gridEnv = currentEnv

if not proxy:
res = getProxyInfo()
if not res["OK"]:
return res
gridEnv["X509_USER_PROXY"] = res["Value"]["path"]
elif isinstance(proxy, str):
if os.path.exists(proxy):
gridEnv["X509_USER_PROXY"] = proxy
else:
return S_ERROR("Can not treat proxy passed as a string")
else:
ret = gProxyManager.dumpProxyToFile(proxy)
if not ret["OK"]:
return ret
gridEnv["X509_USER_PROXY"] = ret["Value"]

if gridEnvDict:
gridEnv.update(gridEnvDict)

Expand Down
2 changes: 1 addition & 1 deletion src/DIRAC/FrameworkSystem/Client/TokenManagerClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def getToken(
self,
username: str = None,
userGroup: str = None,
scope: str = None,
scope: list[str] = None,
audience: str = None,
identityProvider: str = None,
requiredTimeLeft: int = 0,
Expand Down
15 changes: 8 additions & 7 deletions src/DIRAC/FrameworkSystem/Service/TornadoTokenManagerHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,19 +61,20 @@ def initializeHandler(cls, *args):
:return: S_OK()/S_ERROR()
"""
# Cache containing tokens from scope requested by the client
cls.__tokensCache = DictCache()

# The service plays an important OAuth 2.0 role, namely it is an Identity Provider client.
# This allows you to manage tokens without the involvement of their owners.
cls.idps = IdProviderFactory()

# Let's try to connect to the database
try:
cls.__tokenDB = TokenDB(parentLogger=cls.log)
except Exception as e:
cls.log.exception(e)
return S_ERROR(f"Could not connect to the database {repr(e)}")

# Cache containing tokens from scope requested by the client
cls.__tokensCache = DictCache()

# The service plays an important OAuth 2.0 role, namely it is an Identity Provider client.
# This allows you to manage tokens without the involvement of their owners.
cls.idps = IdProviderFactory()
return S_OK()

def export_getUserTokensInfo(self):
Expand Down Expand Up @@ -185,7 +186,7 @@ def export_getToken(
self,
username: str = None,
userGroup: str = None,
scope: str = None,
scope: list[str] = None,
audience: str = None,
identityProvider: str = None,
requiredTimeLeft: int = 0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def getCachedKey(
idProviderClient,
username: str = None,
userGroup: str = None,
scope: str = None,
scope: list[str] = None,
audience: str = None,
):
"""Build the key to potentially retrieve a cached token given the provided parameters.
Expand All @@ -53,7 +53,9 @@ def getCachedKey(
if userGroup and (result := idProviderClient.getGroupScopes(userGroup)):
# What scope correspond to the requested group?
scope = list(set((scope or []) + result))
scope = " ".join(scope)

if scope:
scope = " ".join(sorted(scope))

return (subject, scope, audience, idProviderClient.name, idProviderClient.issuer)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
""" Test IdProvider Factory"""
import pytest
import time

from DIRAC import S_ERROR, S_OK
from DIRAC.Core.Utilities.DictCache import DictCache
from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import OAuth2Token
from DIRAC.FrameworkSystem.Utilities.TokenManagementUtilities import getCachedKey, getCachedToken
from DIRAC.Resources.IdProvider.OAuth2IdProvider import OAuth2IdProvider


@pytest.mark.parametrize(
"idProviderType, idProviderName, issuer, username, group, scope, audience, expectedValue",
[
# Only a client name: this is mandatory
(OAuth2IdProvider, "IdPTest", "Issuer1", None, None, None, None, ("IdPTest", None, None, "IdPTest", "Issuer1")),
(
OAuth2IdProvider,
"IdPTest2",
"Issuer1",
None,
None,
None,
None,
("IdPTest2", None, None, "IdPTest2", "Issuer1"),
),
(
OAuth2IdProvider,
"IdPTest2",
"Issuer2",
None,
None,
None,
None,
("IdPTest2", None, None, "IdPTest2", "Issuer2"),
),
# Client name and username
(OAuth2IdProvider, "IdPTest", "Issuer1", "user", None, None, None, ("user", None, None, "IdPTest", "Issuer1")),
# Client name and group (should not add any permission in scope)
(
OAuth2IdProvider,
"IdPTest",
"Issuer1",
None,
"group",
None,
None,
("IdPTest", None, None, "IdPTest", "Issuer1"),
),
# Client name and scope
(
OAuth2IdProvider,
"IdPTest",
"Issuer1",
None,
None,
["permission:1", "permission:2"],
None,
("IdPTest", "permission:1 permission:2", None, "IdPTest", "Issuer1"),
),
(
OAuth2IdProvider,
"IdPTest",
"Issuer1",
None,
None,
["permission:2", "permission:1"],
None,
("IdPTest", "permission:1 permission:2", None, "IdPTest", "Issuer1"),
),
# Client name and audience
(
OAuth2IdProvider,
"IdPTest",
"Issuer1",
None,
None,
None,
"CE1",
("IdPTest", None, "CE1", "IdPTest", "Issuer1"),
),
# Client name, username, group
(
OAuth2IdProvider,
"IdPTest",
"Issuer1",
"user",
"group1",
None,
None,
("user", None, None, "IdPTest", "Issuer1"),
),
# Client name, username, scope
(
OAuth2IdProvider,
"IdPTest",
"Issuer1",
"user",
None,
["permission:1", "permission:2"],
None,
("user", "permission:1 permission:2", None, "IdPTest", "Issuer1"),
),
# Client name, username, audience
(
OAuth2IdProvider,
"IdPTest",
"Issuer1",
"user",
None,
None,
"CE1",
("user", None, "CE1", "IdPTest", "Issuer1"),
),
# Client name, username, group, scope
(
OAuth2IdProvider,
"IdPTest",
"Issuer1",
"user",
"group1",
["permission:1", "permission:2"],
None,
("user", "permission:1 permission:2", None, "IdPTest", "Issuer1"),
),
# Client name, username, group, audience
(
OAuth2IdProvider,
"IdPTest",
"Issuer1",
"user",
"group1",
None,
"CE1",
("user", None, "CE1", "IdPTest", "Issuer1"),
),
# Client name, usergroup, scope, audience
(
OAuth2IdProvider,
"IdPTest",
"Issuer1",
"user",
"group1",
["permission:1", "permission:2"],
"CE1",
("user", "permission:1 permission:2", "CE1", "IdPTest", "Issuer1"),
),
],
)
def test_getCachedKey(idProviderType, idProviderName, issuer, username, group, scope, audience, expectedValue):
"""Test getCachedKey"""
# Prepare IdP
idProviderClient = idProviderType()
idProviderClient.name = idProviderName
idProviderClient.issuer = issuer

result = getCachedKey(idProviderClient, username, group, scope, audience)
assert result == expectedValue


@pytest.mark.parametrize(
"cachedKey, requiredTimeLeft, expectedValue",
[
# Normal case
(("IdPTest", "permission:1 permission:2", "CE1", "IdPTest", "Issuer1"), 0, S_OK()),
# Empty cachedKey
((), 0, S_ERROR("The key does not exist")),
# Wrong cachedKey
(("IdPTest", "permission:1", "CE1", "IdPTest", "Issuer1"), 0, S_ERROR("The key does not exist")),
# Expired token (650 > 150)
(
("IdPTest", "permission:1 permission:2", "CE1", "IdPTest", "Issuer1"),
650,
S_ERROR("Token found but expired"),
),
# Expired cachedKey (1500 > 1200)
(
("IdPTest", "permission:1 permission:2", "CE1", "IdPTest", "Issuer1"),
1500,
S_ERROR("The key does not exist"),
),
],
)
def test_getCachedToken(cachedKey, requiredTimeLeft, expectedValue):
"""Test getCachedToken"""
# Prepare cachedToken dictionary
cachedTokens = DictCache()
currentTime = time.time()
token = {
"sub": "0001234",
"aud": "CE1",
"nbf": currentTime - 150,
"scope": "permission:1 permission:2",
"iss": "Issuer1",
"exp": currentTime + 150,
"iat": currentTime - 150,
"jti": "000001234",
"client_id": "0001234",
}
tokenKey = ("IdPTest", "permission:1 permission:2", token["aud"], "IdPTest", token["iss"])
cachedTokens.add(tokenKey, 1200, OAuth2Token(token))

# Try to get the token from the cache
result = getCachedToken(cachedTokens, cachedKey, requiredTimeLeft)
assert result["OK"] == expectedValue["OK"]
if result["OK"]:
resultToken = result["Value"]
assert resultToken["sub"] == token["sub"]
assert resultToken["scope"] == token["scope"]
else:
assert result["Message"] == expectedValue["Message"]
2 changes: 2 additions & 0 deletions src/DIRAC/Resources/Computing/AREXComputingElement.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def _reset(self):

# Get options from the ceParameters dictionary
self.port = self.ceParameters.get("Port", self.port)
self.audienceName = f"https://{self.ceName}:{self.port}"

self.restVersion = self.ceParameters.get("RESTVersion", self.restVersion)
self.proxyTimeLeftBeforeRenewal = self.ceParameters.get(
"ProxyTimeLeftBeforeRenewal", self.proxyTimeLeftBeforeRenewal
Expand Down
16 changes: 11 additions & 5 deletions src/DIRAC/Resources/Computing/ComputingElement.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,21 @@ def __init__(self, ceName):
self.log = gLogger.getSubLogger(ceName)
self.ceName = ceName
self.ceParameters = {}
self.proxy = ""
self.token = None
self.valid = None
self.mandatoryParameters = []
self.batchSystem = None
self.taskResults = {}

# Token audience
# None by default, it needs to be redefined in subclasses
self.audienceName = None
self.token = None

self.proxy = ""
self.minProxyTime = gConfig.getValue("/Registry/MinProxyLifeTime", 10800) # secs
self.defaultProxyTime = gConfig.getValue("/Registry/DefaultProxyLifeTime", 43200) # secs
self.proxyCheckPeriod = gConfig.getValue("/Registry/ProxyCheckingPeriod", 3600) # secs
self.valid = None

self.batchSystem = None
self.taskResults = {}

clsName = self.__class__.__name__
if clsName.endswith("ComputingElement"):
Expand Down
Loading

0 comments on commit 292f478

Please sign in to comment.