From e724cd6db78347fcc1fd268c71de6016cdf69867 Mon Sep 17 00:00:00 2001 From: aldbr Date: Tue, 27 Jun 2023 10:17:18 +0200 Subject: [PATCH] sweep: #7062 fix: interacting with CEs using tokens --- src/DIRAC/Core/Utilities/Grid.py | 18 +- .../Client/TokenManagerClient.py | 2 +- .../Service/TornadoTokenManagerHandler.py | 15 +- .../Utilities/TokenManagementUtilities.py | 6 +- .../test/Test_TokenManagementUtilities.py | 211 ++++++++++++++++++ .../Computing/AREXComputingElement.py | 2 + .../Resources/Computing/ComputingElement.py | 16 +- .../Computing/HTCondorCEComputingElement.py | 36 ++- .../test/Test_HTCondorCEComputingElement.py | 14 +- .../Resources/IdProvider/IdProviderFactory.py | 15 +- .../Resources/IdProvider/OAuth2IdProvider.py | 2 + .../Agent/SiteDirector.py | 22 +- .../Client/PilotScopes.py | 17 ++ .../Service/PilotManagerHandler.py | 1 + .../Service/WMSUtilities.py | 31 +-- .../Resources/IdProvider/Test_IdProvider.py | 34 ++- 16 files changed, 355 insertions(+), 87 deletions(-) create mode 100644 src/DIRAC/FrameworkSystem/Utilities/test/Test_TokenManagementUtilities.py create mode 100644 src/DIRAC/WorkloadManagementSystem/Client/PilotScopes.py diff --git a/src/DIRAC/Core/Utilities/Grid.py b/src/DIRAC/Core/Utilities/Grid.py index caa77c353fa..0eba42b7440 100644 --- a/src/DIRAC/Core/Utilities/Grid.py +++ b/src/DIRAC/Core/Utilities/Grid.py @@ -11,7 +11,7 @@ from DIRAC.Core.Utilities.Subprocess import systemCall, shellCall -def executeGridCommand(proxy, cmd, gridEnvScript=None, gridEnvDict=None): +def executeGridCommand(cmd, gridEnvScript=None, gridEnvDict=None): """ Execute cmd tuple after sourcing GridEnv """ @@ -37,22 +37,6 @@ def executeGridCommand(proxy, cmd, gridEnvScript=None, gridEnvDict=None): else: gridEnv = currentEnv - if not proxy: - res = getProxyInfo() - if not res["OK"]: - return res - gridEnv["X509_USER_PROXY"] = res["Value"]["path"] - elif isinstance(proxy, str): - if os.path.exists(proxy): - gridEnv["X509_USER_PROXY"] = proxy - else: - return S_ERROR("Can not treat proxy passed as a string") - else: - ret = gProxyManager.dumpProxyToFile(proxy) - if not ret["OK"]: - return ret - gridEnv["X509_USER_PROXY"] = ret["Value"] - if gridEnvDict: gridEnv.update(gridEnvDict) diff --git a/src/DIRAC/FrameworkSystem/Client/TokenManagerClient.py b/src/DIRAC/FrameworkSystem/Client/TokenManagerClient.py index 1081b822cb1..2157ef970ea 100644 --- a/src/DIRAC/FrameworkSystem/Client/TokenManagerClient.py +++ b/src/DIRAC/FrameworkSystem/Client/TokenManagerClient.py @@ -34,7 +34,7 @@ def getToken( self, username: str = None, userGroup: str = None, - scope: str = None, + scope: list[str] = None, audience: str = None, identityProvider: str = None, requiredTimeLeft: int = 0, diff --git a/src/DIRAC/FrameworkSystem/Service/TornadoTokenManagerHandler.py b/src/DIRAC/FrameworkSystem/Service/TornadoTokenManagerHandler.py index 8138869b0de..f15d2415136 100644 --- a/src/DIRAC/FrameworkSystem/Service/TornadoTokenManagerHandler.py +++ b/src/DIRAC/FrameworkSystem/Service/TornadoTokenManagerHandler.py @@ -61,6 +61,13 @@ def initializeHandler(cls, *args): :return: S_OK()/S_ERROR() """ + # Cache containing tokens from scope requested by the client + cls.__tokensCache = DictCache() + + # The service plays an important OAuth 2.0 role, namely it is an Identity Provider client. + # This allows you to manage tokens without the involvement of their owners. + cls.idps = IdProviderFactory() + # Let's try to connect to the database try: cls.__tokenDB = TokenDB(parentLogger=cls.log) @@ -68,12 +75,6 @@ def initializeHandler(cls, *args): cls.log.exception(e) return S_ERROR(f"Could not connect to the database {repr(e)}") - # Cache containing tokens from scope requested by the client - cls.__tokensCache = DictCache() - - # The service plays an important OAuth 2.0 role, namely it is an Identity Provider client. - # This allows you to manage tokens without the involvement of their owners. - cls.idps = IdProviderFactory() return S_OK() def export_getUserTokensInfo(self): @@ -185,7 +186,7 @@ def export_getToken( self, username: str = None, userGroup: str = None, - scope: str = None, + scope: list[str] = None, audience: str = None, identityProvider: str = None, requiredTimeLeft: int = 0, diff --git a/src/DIRAC/FrameworkSystem/Utilities/TokenManagementUtilities.py b/src/DIRAC/FrameworkSystem/Utilities/TokenManagementUtilities.py index 1e857274dde..d2572c1d3f7 100644 --- a/src/DIRAC/FrameworkSystem/Utilities/TokenManagementUtilities.py +++ b/src/DIRAC/FrameworkSystem/Utilities/TokenManagementUtilities.py @@ -30,7 +30,7 @@ def getCachedKey( idProviderClient, username: str = None, userGroup: str = None, - scope: str = None, + scope: list[str] = None, audience: str = None, ): """Build the key to potentially retrieve a cached token given the provided parameters. @@ -53,7 +53,9 @@ def getCachedKey( if userGroup and (result := idProviderClient.getGroupScopes(userGroup)): # What scope correspond to the requested group? scope = list(set((scope or []) + result)) - scope = " ".join(scope) + + if scope: + scope = " ".join(sorted(scope)) return (subject, scope, audience, idProviderClient.name, idProviderClient.issuer) diff --git a/src/DIRAC/FrameworkSystem/Utilities/test/Test_TokenManagementUtilities.py b/src/DIRAC/FrameworkSystem/Utilities/test/Test_TokenManagementUtilities.py new file mode 100644 index 00000000000..3e97f702968 --- /dev/null +++ b/src/DIRAC/FrameworkSystem/Utilities/test/Test_TokenManagementUtilities.py @@ -0,0 +1,211 @@ +""" Test IdProvider Factory""" +import pytest +import time + +from DIRAC import S_ERROR, S_OK +from DIRAC.Core.Utilities.DictCache import DictCache +from DIRAC.FrameworkSystem.private.authorization.utils.Tokens import OAuth2Token +from DIRAC.FrameworkSystem.Utilities.TokenManagementUtilities import getCachedKey, getCachedToken +from DIRAC.Resources.IdProvider.OAuth2IdProvider import OAuth2IdProvider + + +@pytest.mark.parametrize( + "idProviderType, idProviderName, issuer, username, group, scope, audience, expectedValue", + [ + # Only a client name: this is mandatory + (OAuth2IdProvider, "IdPTest", "Issuer1", None, None, None, None, ("IdPTest", None, None, "IdPTest", "Issuer1")), + ( + OAuth2IdProvider, + "IdPTest2", + "Issuer1", + None, + None, + None, + None, + ("IdPTest2", None, None, "IdPTest2", "Issuer1"), + ), + ( + OAuth2IdProvider, + "IdPTest2", + "Issuer2", + None, + None, + None, + None, + ("IdPTest2", None, None, "IdPTest2", "Issuer2"), + ), + # Client name and username + (OAuth2IdProvider, "IdPTest", "Issuer1", "user", None, None, None, ("user", None, None, "IdPTest", "Issuer1")), + # Client name and group (should not add any permission in scope) + ( + OAuth2IdProvider, + "IdPTest", + "Issuer1", + None, + "group", + None, + None, + ("IdPTest", None, None, "IdPTest", "Issuer1"), + ), + # Client name and scope + ( + OAuth2IdProvider, + "IdPTest", + "Issuer1", + None, + None, + ["permission:1", "permission:2"], + None, + ("IdPTest", "permission:1 permission:2", None, "IdPTest", "Issuer1"), + ), + ( + OAuth2IdProvider, + "IdPTest", + "Issuer1", + None, + None, + ["permission:2", "permission:1"], + None, + ("IdPTest", "permission:1 permission:2", None, "IdPTest", "Issuer1"), + ), + # Client name and audience + ( + OAuth2IdProvider, + "IdPTest", + "Issuer1", + None, + None, + None, + "CE1", + ("IdPTest", None, "CE1", "IdPTest", "Issuer1"), + ), + # Client name, username, group + ( + OAuth2IdProvider, + "IdPTest", + "Issuer1", + "user", + "group1", + None, + None, + ("user", None, None, "IdPTest", "Issuer1"), + ), + # Client name, username, scope + ( + OAuth2IdProvider, + "IdPTest", + "Issuer1", + "user", + None, + ["permission:1", "permission:2"], + None, + ("user", "permission:1 permission:2", None, "IdPTest", "Issuer1"), + ), + # Client name, username, audience + ( + OAuth2IdProvider, + "IdPTest", + "Issuer1", + "user", + None, + None, + "CE1", + ("user", None, "CE1", "IdPTest", "Issuer1"), + ), + # Client name, username, group, scope + ( + OAuth2IdProvider, + "IdPTest", + "Issuer1", + "user", + "group1", + ["permission:1", "permission:2"], + None, + ("user", "permission:1 permission:2", None, "IdPTest", "Issuer1"), + ), + # Client name, username, group, audience + ( + OAuth2IdProvider, + "IdPTest", + "Issuer1", + "user", + "group1", + None, + "CE1", + ("user", None, "CE1", "IdPTest", "Issuer1"), + ), + # Client name, usergroup, scope, audience + ( + OAuth2IdProvider, + "IdPTest", + "Issuer1", + "user", + "group1", + ["permission:1", "permission:2"], + "CE1", + ("user", "permission:1 permission:2", "CE1", "IdPTest", "Issuer1"), + ), + ], +) +def test_getCachedKey(idProviderType, idProviderName, issuer, username, group, scope, audience, expectedValue): + """Test getCachedKey""" + # Prepare IdP + idProviderClient = idProviderType() + idProviderClient.name = idProviderName + idProviderClient.issuer = issuer + + result = getCachedKey(idProviderClient, username, group, scope, audience) + assert result == expectedValue + + +@pytest.mark.parametrize( + "cachedKey, requiredTimeLeft, expectedValue", + [ + # Normal case + (("IdPTest", "permission:1 permission:2", "CE1", "IdPTest", "Issuer1"), 0, S_OK()), + # Empty cachedKey + ((), 0, S_ERROR("The key does not exist")), + # Wrong cachedKey + (("IdPTest", "permission:1", "CE1", "IdPTest", "Issuer1"), 0, S_ERROR("The key does not exist")), + # Expired token (650 > 150) + ( + ("IdPTest", "permission:1 permission:2", "CE1", "IdPTest", "Issuer1"), + 650, + S_ERROR("Token found but expired"), + ), + # Expired cachedKey (1500 > 1200) + ( + ("IdPTest", "permission:1 permission:2", "CE1", "IdPTest", "Issuer1"), + 1500, + S_ERROR("The key does not exist"), + ), + ], +) +def test_getCachedToken(cachedKey, requiredTimeLeft, expectedValue): + """Test getCachedToken""" + # Prepare cachedToken dictionary + cachedTokens = DictCache() + currentTime = time.time() + token = { + "sub": "0001234", + "aud": "CE1", + "nbf": currentTime - 150, + "scope": "permission:1 permission:2", + "iss": "Issuer1", + "exp": currentTime + 150, + "iat": currentTime - 150, + "jti": "000001234", + "client_id": "0001234", + } + tokenKey = ("IdPTest", "permission:1 permission:2", token["aud"], "IdPTest", token["iss"]) + cachedTokens.add(tokenKey, 1200, OAuth2Token(token)) + + # Try to get the token from the cache + result = getCachedToken(cachedTokens, cachedKey, requiredTimeLeft) + assert result["OK"] == expectedValue["OK"] + if result["OK"]: + resultToken = result["Value"] + assert resultToken["sub"] == token["sub"] + assert resultToken["scope"] == token["scope"] + else: + assert result["Message"] == expectedValue["Message"] diff --git a/src/DIRAC/Resources/Computing/AREXComputingElement.py b/src/DIRAC/Resources/Computing/AREXComputingElement.py index 4a5262b67f8..d09f3216e22 100755 --- a/src/DIRAC/Resources/Computing/AREXComputingElement.py +++ b/src/DIRAC/Resources/Computing/AREXComputingElement.py @@ -65,6 +65,8 @@ def _reset(self): # Get options from the ceParameters dictionary self.port = self.ceParameters.get("Port", self.port) + self.audienceName = f"https://{self.ceName}:{self.port}" + self.restVersion = self.ceParameters.get("RESTVersion", self.restVersion) self.proxyTimeLeftBeforeRenewal = self.ceParameters.get( "ProxyTimeLeftBeforeRenewal", self.proxyTimeLeftBeforeRenewal diff --git a/src/DIRAC/Resources/Computing/ComputingElement.py b/src/DIRAC/Resources/Computing/ComputingElement.py index e29897607b4..bef3878c016 100755 --- a/src/DIRAC/Resources/Computing/ComputingElement.py +++ b/src/DIRAC/Resources/Computing/ComputingElement.py @@ -71,15 +71,21 @@ def __init__(self, ceName): self.log = gLogger.getSubLogger(ceName) self.ceName = ceName self.ceParameters = {} - self.proxy = "" - self.token = None - self.valid = None self.mandatoryParameters = [] - self.batchSystem = None - self.taskResults = {} + + # Token audience + # None by default, it needs to be redefined in subclasses + self.audienceName = None + self.token = None + + self.proxy = "" self.minProxyTime = gConfig.getValue("/Registry/MinProxyLifeTime", 10800) # secs self.defaultProxyTime = gConfig.getValue("/Registry/DefaultProxyLifeTime", 43200) # secs self.proxyCheckPeriod = gConfig.getValue("/Registry/ProxyCheckingPeriod", 3600) # secs + self.valid = None + + self.batchSystem = None + self.taskResults = {} clsName = self.__class__.__name__ if clsName.endswith("ComputingElement"): diff --git a/src/DIRAC/Resources/Computing/HTCondorCEComputingElement.py b/src/DIRAC/Resources/Computing/HTCondorCEComputingElement.py index a8f44843401..357ed85fecb 100644 --- a/src/DIRAC/Resources/Computing/HTCondorCEComputingElement.py +++ b/src/DIRAC/Resources/Computing/HTCondorCEComputingElement.py @@ -127,9 +127,11 @@ class HTCondorCEComputingElement(ComputingElement): def __init__(self, ceUniqueID): """Standard constructor.""" super().__init__(ceUniqueID) + self.mandatoryParameters = MANDATORY_PARAMETERS + self.port = "9619" + self.audienceName = f"{self.ceName}:{self.port}" self.submittedJobs = 0 - self.mandatoryParameters = MANDATORY_PARAMETERS self.pilotProxy = "" self.queue = "" self.outputURL = "gsiftp://localhost" @@ -171,12 +173,12 @@ def __writeSub(self, executable, nJobs, location, processors, pilotStamps, token executable = os.path.join(self.workingDirectory, executable) - useCredentials = "" + useCredentials = "use_x509userproxy = true" if tokenFile: - useCredentials = textwrap.dedent( + useCredentials += textwrap.dedent( f""" use_scitokens = true - scitokens_file = {tokenFile} + scitokens_file = {tokenFile.name} """ ) @@ -200,14 +202,13 @@ def __writeSub(self, executable, nJobs, location, processors, pilotStamps, token sub = """ executable = %(executable)s universe = %(targetUniverse)s -use_x509userproxy = true %(useCredentials)s output = $(Cluster).$(Process).out error = $(Cluster).$(Process).err log = $(Cluster).$(Process).log environment = "HTCONDOR_JOBID=$(Cluster).$(Process) DIRAC_PILOT_STAMP=$(stamp)" initialdir = %(initialDir)s -grid_resource = condor %(ceName)s %(ceName)s:9619 +grid_resource = condor %(ceName)s %(ceName)s:%(port)s transfer_output_files = "" request_cpus = %(processors)s %(localScheddOptions)s @@ -223,6 +224,7 @@ def __writeSub(self, executable, nJobs, location, processors, pilotStamps, token nJobs=nJobs, processors=processors, ceName=self.ceName, + port=self.port, extraString=self.extraSubmitString, initialDir=os.path.join(self.workingDirectory, location), localScheddOptions=localScheddOptions, @@ -246,7 +248,9 @@ def _reset(self): if self.useLocalSchedd == "False": self.useLocalSchedd = False - self.remoteScheddOptions = "" if self.useLocalSchedd else f"-pool {self.ceName}:9619 -name {self.ceName} " + self.remoteScheddOptions = ( + "" if self.useLocalSchedd else f"-pool {self.ceName}:{self.port} -name {self.ceName} " + ) self.log.debug("Using local schedd:", self.useLocalSchedd) self.log.debug("Remote scheduler option:", self.remoteScheddOptions) @@ -259,7 +263,18 @@ def _executeCondorCommand(self, cmd, keepTokenFile=False): :param bool keepTokenFile: flag to reuse or not the previously created token file :return: S_OK/S_ERROR - the result of the executeGridCommand() call """ + if not self.token and not self.proxy: + return S_ERROR(f"Cannot execute the command, token and proxy not found: {cmd}") + # Prepare proxy + result = self._prepareProxy() + if not result["OK"]: + return result + + htcEnv = { + "_CONDOR_SEC_CLIENT_AUTHENTICATION_METHODS": "GSI", + } + # If a token is present, then we use it (overriding htcEnv) if self.token: # Create a new token file if we do not keep it across several calls if not self.tokenFile or not keepTokenFile: @@ -274,11 +289,8 @@ def _executeCondorCommand(self, cmd, keepTokenFile=False): } if cas := getCAsLocation(): htcEnv["_CONDOR_AUTH_SSL_CLIENT_CADIR"] = cas - else: - htcEnv = {"_CONDOR_SEC_CLIENT_AUTHENTICATION_METHODS": "GSI"} result = executeGridCommand( - self.proxy, cmd, gridEnvScript=self.gridEnv, gridEnvDict=htcEnv, @@ -319,7 +331,7 @@ def submitJob(self, executableFile, proxy, numberOfJobs=1): cmd = ["condor_submit", "-terse", subName] # the options for submit to remote are different than the other remoteScheddOptions # -remote: submit to a remote condor_schedd and spool all the required inputs - scheddOptions = [] if self.useLocalSchedd else ["-pool", f"{self.ceName}:9619", "-remote", self.ceName] + scheddOptions = [] if self.useLocalSchedd else ["-pool", f"{self.ceName}:{self.port}", "-remote", self.ceName] for op in scheddOptions: cmd.insert(-1, op) @@ -544,7 +556,7 @@ def __getJobOutput(self, jobID, outTypes): return S_ERROR(e.errno, f"{errorMessage} ({iwd})") if not self.useLocalSchedd: - cmd = ["condor_transfer_data", "-pool", f"{self.ceName}:9619", "-name", self.ceName, condorID] + cmd = ["condor_transfer_data", "-pool", f"{self.ceName}:{self.port}", "-name", self.ceName, condorID] result = self._executeCondorCommand(cmd) self.log.verbose(result) diff --git a/src/DIRAC/Resources/Computing/test/Test_HTCondorCEComputingElement.py b/src/DIRAC/Resources/Computing/test/Test_HTCondorCEComputingElement.py index 70988dd109f..b5acb6211b6 100644 --- a/src/DIRAC/Resources/Computing/test/Test_HTCondorCEComputingElement.py +++ b/src/DIRAC/Resources/Computing/test/Test_HTCondorCEComputingElement.py @@ -61,7 +61,7 @@ def test_parseCondorStatus(): def test_getJobStatus(mocker): """Test HTCondorCE getJobStatus""" mocker.patch( - MODNAME + ".executeGridCommand", + MODNAME + ".HTCondorCEComputingElement._executeCondorCommand", side_effect=[ S_OK((0, "\n".join(STATUS_LINES), "")), S_OK((0, "\n".join(HISTORY_LINES), "")), @@ -170,7 +170,9 @@ def test_submitJob(setUp, mocker, localSchedd, expected): ceName = "condorce.cern.ch" htce.ceName = ceName - execMock = mocker.patch(MODNAME + ".executeGridCommand", return_value=S_OK((0, "123.0 - 123.0", ""))) + execMock = mocker.patch( + MODNAME + ".HTCondorCEComputingElement._executeCondorCommand", return_value=S_OK((0, "123.0 - 123.0", "")) + ) mocker.patch( MODNAME + ".HTCondorCEComputingElement._HTCondorCEComputingElement__writeSub", return_value="dirac_pilot" ) @@ -179,7 +181,7 @@ def test_submitJob(setUp, mocker, localSchedd, expected): result = htce.submitJob("pilot", "proxy", 1) assert result["OK"] is True - assert " ".join(execMock.call_args_list[0][0][1]) == expected + assert " ".join(execMock.call_args_list[0][0][0]) == expected @pytest.mark.parametrize( @@ -202,10 +204,12 @@ def test_killJob(setUp, mocker, jobIDList, jobID, ret, success, local): htce.ceParameters = ceParameters htce._reset() - execMock = mocker.patch(MODNAME + ".executeGridCommand", return_value=S_OK((ret, "", ""))) + execMock = mocker.patch( + MODNAME + ".HTCondorCEComputingElement._executeCondorCommand", return_value=S_OK((ret, "", "")) + ) ret = htce.killJob(jobIDList=jobIDList) assert ret["OK"] == success if jobID: expected = f"condor_rm {htce.remoteScheddOptions.strip()} {jobID}" - assert " ".join(execMock.call_args_list[0][0][1]) == expected + assert " ".join(execMock.call_args_list[0][0][0]) == expected diff --git a/src/DIRAC/Resources/IdProvider/IdProviderFactory.py b/src/DIRAC/Resources/IdProvider/IdProviderFactory.py index cb7670ae471..6dd34c79c59 100644 --- a/src/DIRAC/Resources/IdProvider/IdProviderFactory.py +++ b/src/DIRAC/Resources/IdProvider/IdProviderFactory.py @@ -50,26 +50,25 @@ def getIdProvider(self, name, **kwargs): """ if not name: return S_ERROR("Identity Provider client name must be not None.") - # Get Authorization Server metadata - try: - asMetaDict = collectMetadata(kwargs.get("issuer"), ignoreErrors=True) - except Exception as e: - return S_ERROR(str(e)) self.log.debug("Search configuration for", name) + clients = getDIRACClients() if name in clients: # If it is a DIRAC default pre-registered client + # Get Authorization Server metadata + try: + asMetaDict = collectMetadata(kwargs.get("issuer"), ignoreErrors=True) + except Exception as e: + return S_ERROR(str(e)) pDict = asMetaDict pDict.update(clients[name]) else: - # if it is external identity provider client + # If it is external identity provider client result = gConfig.getOptionsDict(f"/Resources/IdProviders/{name}") if not result["OK"]: self.log.error("Failed to read configuration", f"{name}: {result['Message']}") return result pDict = result["Value"] - # Set default redirect_uri - pDict["redirect_uri"] = pDict.get("redirect_uri", asMetaDict["redirect_uri"]) pDict.update(kwargs) pDict["ProviderName"] = name diff --git a/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py index 445c6630de3..792f2d5e226 100644 --- a/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py +++ b/src/DIRAC/Resources/IdProvider/OAuth2IdProvider.py @@ -429,6 +429,8 @@ def fetchToken(self, **kwargs): """ try: self.fetch_access_token(self.get_metadata("token_endpoint"), **kwargs) + except OAuthError as e: + return S_ERROR(f"Cannot fetch access token: {e}") except Exception as e: self.log.exception(e) return S_ERROR(repr(e)) diff --git a/src/DIRAC/WorkloadManagementSystem/Agent/SiteDirector.py b/src/DIRAC/WorkloadManagementSystem/Agent/SiteDirector.py index 0d35c5e91eb..b8947afb4b1 100644 --- a/src/DIRAC/WorkloadManagementSystem/Agent/SiteDirector.py +++ b/src/DIRAC/WorkloadManagementSystem/Agent/SiteDirector.py @@ -32,6 +32,8 @@ from DIRAC.ResourceStatusSystem.Client.ResourceStatus import ResourceStatus from DIRAC.ResourceStatusSystem.Client.SiteStatus import SiteStatus from DIRAC.WorkloadManagementSystem.Client import PilotStatus +from DIRAC.WorkloadManagementSystem.Client.PilotScopes import PILOT_SCOPES + from DIRAC.WorkloadManagementSystem.Client.MatcherClient import MatcherClient from DIRAC.WorkloadManagementSystem.Client.ServerUtils import getPilotAgentsDB from DIRAC.WorkloadManagementSystem.private.ConfigHelper import findGenericPilotCredentials @@ -445,7 +447,7 @@ def submitPilots(self): # Get valid token if needed if "Token" in ce.ceParameters.get("Tag", []): - result = self.__getPilotToken() + result = self.__getPilotToken(audience=ce.audienceName) if not result["OK"]: return result ce.setToken(result["Value"], 3500) @@ -465,13 +467,20 @@ def submitPilots(self): return S_OK() - def __getPilotToken(self): + def __getPilotToken(self, audience: str, scope: list[str] = None): """Get the token corresponding to the pilot user identity + :param audience: Token audience, targeting a single CE + :param scope: list of permissions needed to interact with a CE :return: S_OK/S_ERROR, Token object as Value """ - result = gTokenManager.getToken(userGroup=self.pilotGroup, requiredTimeLeft=600) - return result + if not audience: + return S_ERROR("Audience is not defined") + + if not scope: + scope = PILOT_SCOPES + + return gTokenManager.getToken(userGroup=self.pilotGroup, requiredTimeLeft=600, scope=scope, audience=audience) def _ifAndWhereToSubmit(self): """Return a tuple that says if and where to submit pilots: @@ -1230,9 +1239,10 @@ def _updatePilotStatusPerQueue(self, queue, proxy): # Get valid token if needed if "Token" in ce.ceParameters.get("Tag", []): - result = self.__getPilotToken() + result = self.__getPilotToken(audience=ce.audienceName) if not result["OK"]: - return result + self.log.error("Failed to get token", f"{ceName}: {result['Message']}") + return ce.setToken(result["Value"], 3500) result = ce.getJobStatus(stampedPilotRefs) diff --git a/src/DIRAC/WorkloadManagementSystem/Client/PilotScopes.py b/src/DIRAC/WorkloadManagementSystem/Client/PilotScopes.py new file mode 100644 index 00000000000..9e953bb7b91 --- /dev/null +++ b/src/DIRAC/WorkloadManagementSystem/Client/PilotScopes.py @@ -0,0 +1,17 @@ +""" +This module contains constants and lists for the possible scopes to interact with pilots on CEs. +""" + +# Based on: https://github.com/WLCG-AuthZ-WG/common-jwt-profile/blob/master/profile.md#capability-based-authorization-scope + +#: To submit pilots: +CREATE = "compute.create" +#: To cancel pilots: +CANCEL = "compute.cancel" +#: To modify attributes of submitted pilots: +MODIFY = "compute.modify" +#: To read information about submitted pilots: +READ = "compute.read" + +#: Possible pilot scopes: +PILOT_SCOPES = [CANCEL, CREATE, MODIFY, READ] diff --git a/src/DIRAC/WorkloadManagementSystem/Service/PilotManagerHandler.py b/src/DIRAC/WorkloadManagementSystem/Service/PilotManagerHandler.py index 3862b0ed613..f89f931e77d 100644 --- a/src/DIRAC/WorkloadManagementSystem/Service/PilotManagerHandler.py +++ b/src/DIRAC/WorkloadManagementSystem/Service/PilotManagerHandler.py @@ -123,6 +123,7 @@ def export_getPilotOutput(self, pilotReference): if not hasattr(ce, "getJobOutput"): return S_ERROR(f"Pilot output not available for {pilotDict['GridType']} CEs") + # Set proxy or token for the CE result = setPilotCredentials(ce, pilotDict) if not result["OK"]: return result diff --git a/src/DIRAC/WorkloadManagementSystem/Service/WMSUtilities.py b/src/DIRAC/WorkloadManagementSystem/Service/WMSUtilities.py index 2c9f5856cf8..dde68b3552d 100644 --- a/src/DIRAC/WorkloadManagementSystem/Service/WMSUtilities.py +++ b/src/DIRAC/WorkloadManagementSystem/Service/WMSUtilities.py @@ -10,6 +10,7 @@ from DIRAC.FrameworkSystem.Client.ProxyManagerClient import gProxyManager from DIRAC.FrameworkSystem.Client.TokenManagerClient import gTokenManager from DIRAC.Resources.Computing.ComputingElementFactory import ComputingElementFactory +from DIRAC.WorkloadManagementSystem.Client.PilotScopes import PILOT_SCOPES # List of files to be inserted/retrieved into/from pilot Output Sandbox @@ -67,27 +68,6 @@ def getPilotProxy(pilotDict): return S_OK(proxy) -def getPilotToken(pilotDict): - """Get a token corresponding to the pilot - - :param dict pilotDict: pilot parameters - :return: S_OK/S_ERROR with token as Value - """ - ownerDN = pilotDict["OwnerDN"] - group = pilotDict["OwnerGroup"] - - result = getUsernameForDN(ownerDN) - if not result["OK"]: - return result - username = result["Value"] - result = gTokenManager.getToken( - username=username, - userGroup=group, - requiredTimeLeft=3600, - ) - return result - - def setPilotCredentials(ce, pilotDict): """Instrument the given CE with proxy or token @@ -96,10 +76,15 @@ def setPilotCredentials(ce, pilotDict): :return: S_OK/S_ERROR """ if "Token" in ce.ceParameters.get("Tag", []): - result = getPilotToken(pilotDict) + result = gTokenManager.getToken( + userGroup=pilotDict["OwnerGroup"], + scope=PILOT_SCOPES, + audience=ce.audienceName, + requiredTimeLeft=150, + ) if not result["OK"]: return result - ce.setToken(result["Value"], 3500) + ce.setToken(result["Value"]) else: result = getPilotProxy(pilotDict) if not result["OK"]: diff --git a/tests/Integration/Resources/IdProvider/Test_IdProvider.py b/tests/Integration/Resources/IdProvider/Test_IdProvider.py index 748357bc451..581c35d19a1 100644 --- a/tests/Integration/Resources/IdProvider/Test_IdProvider.py +++ b/tests/Integration/Resources/IdProvider/Test_IdProvider.py @@ -448,7 +448,7 @@ def test_refreshToken(iam_connection, token, update, expectedValue): assert result["OK"] == expectedValue["OK"] if result["OK"]: resultToken = result["Value"] - assert resultToken["scope"].split(" ").sort() == baseParams["scope"].split("+").sort() + assert sorted(resultToken["scope"].split(" ")) == sorted(baseParams["scope"].split("+")) # Update the valid access token for next tests if update: token["access_token"] = resultToken["access_token"] @@ -456,6 +456,38 @@ def test_refreshToken(iam_connection, token, update, expectedValue): assert expectedValue["Message"] in result["Message"] +@pytest.mark.parametrize( + "grantType, scope, audience, expectedValue", + [ + # Client credentials + # No scope, no audience + ("client_credentials", None, None, {"OK": True}), + # Scope, no audience + ("client_credentials", ["openid"], None, {"OK": True}), + ("client_credentials", ["openid", "profile"], None, {"OK": True}), + # Scope, audience + ("client_credentials", ["openid"], "ce1.test.ch", {"OK": True}), + # Invalid scope + ("client_credentials", ["compute.read"], None, {"OK": False, "Message": "Cannot fetch access token"}), + ], +) +def test_fetchToken(iam_connection, grantType, scope, audience, expectedValue): + """Test fetchToken""" + idProvider = IAMIdProvider(**baseParams) + + result = idProvider.fetchToken(grant_type=grantType, scope=scope, audience=audience) + assert result["OK"] == expectedValue["OK"] + if result["OK"]: + resultToken = result["Value"] + + # Default scope + if not scope: + scope = baseParams["scope"].split("+") + assert sorted(resultToken["scope"].split(" ")) == sorted(scope) + else: + assert expectedValue["Message"] in result["Message"] + + @pytest.mark.parametrize( "token, expectedValue", [