diff --git a/setup.cfg b/setup.cfg
index ab177585786..7c05027ac99 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -44,7 +44,7 @@ install_requires =
     psutil
     pyasn1
     pyasn1-modules
-    pydantic
+    pydantic >=2.4
     pyparsing
     python-dateutil
     pytz
diff --git a/src/DIRAC/Core/Utilities/test/Test_JDL.py b/src/DIRAC/Core/Utilities/test/Test_JDL.py
index dcdfa4861f4..b918bf3bb2e 100644
--- a/src/DIRAC/Core/Utilities/test/Test_JDL.py
+++ b/src/DIRAC/Core/Utilities/test/Test_JDL.py
@@ -79,7 +79,7 @@ def test_jdlToBaseJobDescriptionModel_valid(jdl_monkey_business):
     res = jdlToBaseJobDescriptionModel(ClassAd(jdl))
 
     assert res["OK"], res["Message"]
 
-    data = res["Value"].dict()
+    data = res["Value"].model_dump()
 
     assert JobDescriptionModel(owner="owner", ownerGroup="ownerGroup", vo="lhcb", **data)
diff --git a/src/DIRAC/Resources/Computing/BatchSystems/SLURM.py b/src/DIRAC/Resources/Computing/BatchSystems/SLURM.py
index 70fe9be45d5..41f38a492ee 100644
--- a/src/DIRAC/Resources/Computing/BatchSystems/SLURM.py
+++ b/src/DIRAC/Resources/Computing/BatchSystems/SLURM.py
@@ -125,7 +125,7 @@ def _generateSrunWrapper(self, executableFile):
             content = f.read()
 
         # Need to escape environment variables of the executable file
-        content = re.sub("\$", "\\$", content)
+        content = re.sub(r"\$", r"\\$", content)
 
         # Build the script to run the executable in parallel multiple times
         # - Embed the content of executableFile inside the parallel library wrapper script
diff --git a/src/DIRAC/WorkloadManagementSystem/Utilities/JobModel.py b/src/DIRAC/WorkloadManagementSystem/Utilities/JobModel.py
index c9bcf3c7075..3ea951d6e9b 100644
--- a/src/DIRAC/WorkloadManagementSystem/Utilities/JobModel.py
+++ b/src/DIRAC/WorkloadManagementSystem/Utilities/JobModel.py
@@ -3,11 +3,9 @@
 # pylint: disable=no-self-argument, no-self-use, invalid-name, missing-function-docstring
 
 from collections.abc import Iterable
-from typing import Any, Annotated
+from typing import Any, Annotated, TypeAlias, Self
 
-import pydantic
-from packaging.version import Version
-from pydantic import BaseModel, root_validator, validator
+from pydantic import BaseModel, BeforeValidator, model_validator, field_validator, ConfigDict
 
 from DIRAC import gLogger
 from DIRAC.ConfigurationSystem.Client.Helpers.Operations import Operations
@@ -16,7 +14,9 @@
 
 # HACK: Convert appropriate iterables into sets
 def default_set_validator(value):
-    if not isinstance(value, Iterable):
+    if value is None:
+        return set()
+    elif not isinstance(value, Iterable):
         return value
     elif isinstance(value, (str, bytes, bytearray)):
         return value
@@ -24,49 +24,47 @@
         return set(value)
 
 
-if Version(pydantic.__version__) > Version("2.0.0a0"):
-    CoercibleSetStr = Annotated[
-        set[str] | None, pydantic.BeforeValidator(default_set_validator)  # pylint: disable=no-member
-    ]
-else:
-    CoercibleSetStr = set[str]
+CoercibleSetStr: TypeAlias = Annotated[set[str], BeforeValidator(default_set_validator)]
 
 
 class BaseJobDescriptionModel(BaseModel):
     """Base model for the job description (not parametric)"""
 
-    class Config:
-        validate_assignment = True
+    model_config = ConfigDict(validate_assignment=True)
 
-    arguments: str = None
-    bannedSites: CoercibleSetStr = None
+    arguments: str = ""
+    bannedSites: CoercibleSetStr = set()
+    # TODO: This should use a field factory
     cpuTime: int = Operations().getValue("JobDescription/DefaultCPUTime", 86400)
     executable: str
     executionEnvironment: dict = None
-    gridCE: str = None
-    inputSandbox: CoercibleSetStr = None
-    inputData: CoercibleSetStr = None
-    inputDataPolicy: str = None
-    jobConfigArgs: str = None
-    jobGroup: str = None
+    gridCE: str = ""
+    inputSandbox: CoercibleSetStr = set()
+    inputData: CoercibleSetStr = set()
+    inputDataPolicy: str = ""
+    jobConfigArgs: str = ""
+    jobGroup: str = ""
     jobType: str = "User"
     jobName: str = "Name"
+    # TODO: This should be an StrEnum
     logLevel: str = "INFO"
+    # TODO: This can't be None with this type hint
     maxNumberOfProcessors: int = None
     minNumberOfProcessors: int = 1
-    outputData: CoercibleSetStr = None
-    outputPath: str = None
-    outputSandbox: CoercibleSetStr = None
-    outputSE: str = None
-    platform: str = None
+    outputData: CoercibleSetStr = set()
+    outputPath: str = ""
+    outputSandbox: CoercibleSetStr = set()
+    outputSE: str = ""
+    platform: str = ""
+    # TODO: This should use a field factory
     priority: int = Operations().getValue("JobDescription/DefaultPriority", 1)
-    sites: CoercibleSetStr = None
+    sites: CoercibleSetStr = set()
     stderr: str = "std.err"
     stdout: str = "std.out"
-    tags: CoercibleSetStr = None
-    extraFields: dict[str, Any] = None
+    tags: CoercibleSetStr = set()
+    extraFields: dict[str, Any] = {}
 
-    @validator("cpuTime")
+    @field_validator("cpuTime")
     def checkCPUTimeBounds(cls, v):
         minCPUTime = Operations().getValue("JobDescription/MinCPUTime", 100)
         maxCPUTime = Operations().getValue("JobDescription/MaxCPUTime", 500000)
@@ -74,13 +72,13 @@ def checkCPUTimeBounds(cls, v):
             raise ValueError(f"cpuTime out of bounds (must be between {minCPUTime} and {maxCPUTime})")
         return v
 
-    @validator("executable")
+    @field_validator("executable")
     def checkExecutableIsNotAnEmptyString(cls, v: str):
         if not v:
             raise ValueError("executable must not be an empty string")
         return v
 
-    @validator("jobType")
+    @field_validator("jobType")
     def checkJobTypeIsAllowed(cls, v: str):
         jobTypes = Operations().getValue("JobDescription/AllowedJobTypes", ["User", "Test", "Hospital"])
         transformationTypes = Operations().getValue("Transformations/DataProcessing", [])
@@ -89,7 +87,7 @@ def checkJobTypeIsAllowed(cls, v: str):
             raise ValueError(f"jobType '{v}' is not allowed for this kind of user (must be in {allowedTypes})")
         return v
 
-    @validator("inputData")
+    @field_validator("inputData")
     def checkInputDataDoesntContainDoubleSlashes(cls, v):
         if v:
             for lfn in v:
@@ -97,7 +95,7 @@ def checkInputDataDoesntContainDoubleSlashes(cls, v):
                     raise ValueError("Input data contains //")
         return v
 
-    @validator("inputData")
+    @field_validator("inputData")
     def addLFNPrefixIfStringStartsWithASlash(cls, v: set[str]):
         if v:
             v = {lfn.strip() for lfn in v if lfn.strip()}
@@ -108,22 +106,22 @@ def addLFNPrefixIfStringStartsWithASlash(cls, v: set[str]):
                     raise ValueError("Input data files must start with LFN:/")
         return v
 
-    @root_validator(skip_on_failure=True)
-    def checkNumberOfInputDataFiles(cls, values):
-        if "inputData" in values and values["inputData"]:
+    @model_validator(mode="after")
+    def checkNumberOfInputDataFiles(self) -> Self:
+        if self.inputData:
             maxInputDataFiles = Operations().getValue("JobDescription/MaxInputData", 500)
-            if values["jobType"] == "User" and len(values["inputData"]) >= maxInputDataFiles:
+            if self.jobType == "User" and len(self.inputData) >= maxInputDataFiles:
                 raise ValueError(f"inputData contains too many files (must contain at most {maxInputDataFiles})")
-        return values
+        return self
 
-    @validator("inputSandbox")
+    @field_validator("inputSandbox")
     def checkLFNSandboxesAreWellFormated(cls, v: set[str]):
         for inputSandbox in v:
             if inputSandbox.startswith("LFN:") and not inputSandbox.startswith("LFN:/"):
                 raise ValueError("LFN files must start by LFN:/")
         return v
 
-    @validator("logLevel")
+    @field_validator("logLevel")
     def checkLogLevelIsValid(cls, v: str):
         v = v.upper()
         possibleLogLevels = gLogger.getAllPossibleLevels()
@@ -131,7 +129,7 @@ def checkLogLevelIsValid(cls, v: str):
             raise ValueError(f"Log level {v} not in {possibleLogLevels}")
         return v
 
-    @validator("minNumberOfProcessors")
+    @field_validator("minNumberOfProcessors")
     def checkMinNumberOfProcessorsBounds(cls, v):
         minNumberOfProcessors = Operations().getValue("JobDescription/MinNumberOfProcessors", 1)
         maxNumberOfProcessors = Operations().getValue("JobDescription/MaxNumberOfProcessors", 1024)
@@ -141,7 +139,7 @@ def checkMinNumberOfProcessorsBounds(cls, v):
             )
         return v
 
-    @validator("maxNumberOfProcessors")
+    @field_validator("maxNumberOfProcessors")
     def checkMaxNumberOfProcessorsBounds(cls, v):
         minNumberOfProcessors = Operations().getValue("JobDescription/MinNumberOfProcessors", 1)
         maxNumberOfProcessors = Operations().getValue("JobDescription/MaxNumberOfProcessors", 1024)
@@ -151,27 +149,22 @@ def checkMaxNumberOfProcessorsBounds(cls, v):
             )
         return v
 
-    @root_validator(skip_on_failure=True)
-    def checkThatMaxNumberOfProcessorsIsGreaterThanMinNumberOfProcessors(cls, values):
-        if "maxNumberOfProcessors" in values and values["maxNumberOfProcessors"]:
-            if values["maxNumberOfProcessors"] < values["minNumberOfProcessors"]:
+    @model_validator(mode="after")
+    def checkThatMaxNumberOfProcessorsIsGreaterThanMinNumberOfProcessors(self) -> Self:
+        if self.maxNumberOfProcessors:
+            if self.maxNumberOfProcessors < self.minNumberOfProcessors:
                 raise ValueError("maxNumberOfProcessors must be greater than minNumberOfProcessors")
-        return values
-
-    @root_validator(skip_on_failure=True)
-    def addTagsDependingOnNumberOfProcessors(cls, values):
-        if "maxNumberOfProcessors" in values and values["minNumberOfProcessors"] == values["maxNumberOfProcessors"]:
-            if values["tags"] is None:
-                values["tags"] = set()
-            values["tags"].add(f"{values['minNumberOfProcessors']}Processors")
-        if values["minNumberOfProcessors"] > 1:
-            if values["tags"] is None:
-                values["tags"] = set()
-            values["tags"].add("MultiProcessor")
-
-        return values
-
-    @validator("sites")
+        return self
+
+    @model_validator(mode="after")
+    def addTagsDependingOnNumberOfProcessors(self) -> Self:
+        if self.minNumberOfProcessors == self.maxNumberOfProcessors:
+            self.tags.add(f"{self.minNumberOfProcessors}Processors")
+        if self.minNumberOfProcessors > 1:
+            self.tags.add("MultiProcessor")
+        return self
+
+    @field_validator("sites")
     def checkSites(cls, v: set[str]):
         if v:
             res = getSites()
@@ -182,16 +175,16 @@ def checkSites(cls, v: set[str]):
                 raise ValueError(f"Invalid sites: {' '.join(invalidSites)}")
         return v
 
-    @root_validator(skip_on_failure=True)
-    def checkThatSitesAndBannedSitesAreNotMutuallyExclusive(cls, values):
-        if "sites" in values and values["sites"] and "bannedSites" in values and values["bannedSites"]:
-            values["sites"] -= values["bannedSites"]
-            values["bannedSites"] = None
-            if not values["sites"]:
+    @model_validator(mode="after")
+    def checkThatSitesAndBannedSitesAreNotMutuallyExclusive(self) -> Self:
+        if self.sites and self.bannedSites:
+            while self.bannedSites:
+                self.sites.discard(self.bannedSites.pop())
+            if not self.sites:
                 raise ValueError("sites and bannedSites are mutually exclusive")
-        return values
+        return self
 
-    @validator("platform")
+    @field_validator("platform")
     def checkPlatform(cls, v: str):
         if v:
             res = getDIRACPlatforms()
@@ -201,7 +194,7 @@ def checkPlatform(cls, v: str):
                 raise ValueError("Invalid platform")
         return v
 
-    @validator("priority")
+    @field_validator("priority")
     def checkPriorityBounds(cls, v):
         minPriority = Operations().getValue("JobDescription/MinPriority", 0)
         maxPriority = Operations().getValue("JobDescription/MaxPriority", 10)
@@ -217,10 +210,10 @@ class JobDescriptionModel(BaseJobDescriptionModel):
     ownerGroup: str
     vo: str
 
-    @root_validator(skip_on_failure=True)
-    def checkLFNMatchesREGEX(cls, values):
-        if "inputData" in values and values["inputData"]:
-            for lfn in values["inputData"]:
-                if not lfn.startswith(f"LFN:/{values['vo']}/"):
-                    raise ValueError(f"Input data not correctly specified (must start with LFN:/{values['vo']}/)")
-        return values
+    @model_validator(mode="after")
+    def checkLFNMatchesREGEX(self) -> Self:
+        if self.inputData:
+            for lfn in self.inputData:
+                if not lfn.startswith(f"LFN:/{self.vo}/"):
+                    raise ValueError(f"Input data not correctly specified (must start with LFN:/{self.vo}/)")
+        return self
diff --git a/src/DIRAC/WorkloadManagementSystem/Utilities/test/Test_JobModel.py b/src/DIRAC/WorkloadManagementSystem/Utilities/test/Test_JobModel.py
index b8f6c15e0d3..63ff1703942 100644
--- a/src/DIRAC/WorkloadManagementSystem/Utilities/test/Test_JobModel.py
+++ b/src/DIRAC/WorkloadManagementSystem/Utilities/test/Test_JobModel.py
@@ -231,9 +231,9 @@ def test_sitesValidator_invalid(validSites, selectedSites):
 @pytest.mark.parametrize(
     "sites, bannedSites, parsedSites, parsedBannedSites",
     [
-        ({"LCG.PIC.es", "LCG.CNAF.it", "LCG.IN2P3.fr"}, None, {"LCG.PIC.es", "LCG.CNAF.it", "LCG.IN2P3.fr"}, None),
-        (None, {"LCG.PIC.es", "LCG.CNAF.it", "LCG.IN2P3.fr"}, None, {"LCG.PIC.es", "LCG.CNAF.it", "LCG.IN2P3.fr"}),
-        ({"LCG.PIC.es", "LCG.CNAF.it", "LCG.IN2P3.fr"}, {"LCG.PIC.es", "LCG.CNAF.it"}, {"LCG.IN2P3.fr"}, None),
+        ({"LCG.PIC.es", "LCG.CNAF.it", "LCG.IN2P3.fr"}, None, {"LCG.PIC.es", "LCG.CNAF.it", "LCG.IN2P3.fr"}, set()),
+        (None, {"LCG.PIC.es", "LCG.CNAF.it", "LCG.IN2P3.fr"}, set(), {"LCG.PIC.es", "LCG.CNAF.it", "LCG.IN2P3.fr"}),
+        ({"LCG.PIC.es", "LCG.CNAF.it", "LCG.IN2P3.fr"}, {"LCG.PIC.es", "LCG.CNAF.it"}, {"LCG.IN2P3.fr"}, set()),
     ],
 )
 def test_checkThatSitesAndBannedSitesAreNotMutuallyExclusive_valid(
diff --git a/src/DIRAC/__init__.py b/src/DIRAC/__init__.py
index 04301701856..d24c7988da5 100755
--- a/src/DIRAC/__init__.py
+++ b/src/DIRAC/__init__.py
@@ -55,12 +55,12 @@
 """
 
 import os
+import importlib.metadata
 import re
 import sys
 import warnings
 from pkgutil import extend_path
 from typing import Any, Optional, Union
-from pkg_resources import get_distribution, DistributionNotFound
 
 __path__ = extend_path(__path__, __name__)
 
@@ -81,9 +81,9 @@
 
 # Define Version
 try:
-    __version__ = get_distribution(__name__).version
+    __version__ = importlib.metadata.version(__name__)
     version = __version__
-except DistributionNotFound:
+except importlib.metadata.PackageNotFoundError:
     # package is not installed
     version = "Unknown"