diff --git a/diracx-core/src/diracx/core/config/__init__.py b/diracx-core/src/diracx/core/config/__init__.py index 4431d2f4..0283767c 100644 --- a/diracx-core/src/diracx/core/config/__init__.py +++ b/diracx-core/src/diracx/core/config/__init__.py @@ -21,7 +21,7 @@ from cachetools import Cache, LRUCache, TTLCache, cachedmethod from pydantic import AnyUrl, BeforeValidator, TypeAdapter, UrlConstraints -from ..exceptions import BadConfigurationVersion +from ..exceptions import BadConfigurationVersionError from ..extensions import select_from_extension from .schema import Config @@ -159,7 +159,9 @@ def latest_revision(self) -> tuple[str, datetime]: ).strip() modified = datetime.fromtimestamp(int(commit_info), tz=timezone.utc) except sh.ErrorReturnCode as e: - raise BadConfigurationVersion(f"Error parsing latest revision: {e}") from e + raise BadConfigurationVersionError( + f"Error parsing latest revision: {e}" + ) from e logger.debug("Latest revision for %s is %s with mtime %s", self, rev, modified) return rev, modified @@ -176,7 +178,9 @@ def read_raw(self, hexsha: str, modified: datetime) -> Config: ) raw_obj = yaml.safe_load(blob) except sh.ErrorReturnCode as e: - raise BadConfigurationVersion(f"Error reading configuration: {e}") from e + raise BadConfigurationVersionError( + f"Error reading configuration: {e}" + ) from e config_class: Config = select_from_extension(group="diracx", name="config")[ 0 diff --git a/diracx-core/src/diracx/core/config/schema.py b/diracx-core/src/diracx/core/config/schema.py index 92d623da..8da2837a 100644 --- a/diracx-core/src/diracx/core/config/schema.py +++ b/diracx-core/src/diracx/core/config/schema.py @@ -115,7 +115,6 @@ class DIRACConfig(BaseModel): class JobMonitoringConfig(BaseModel): GlobalJobsInfo: bool = True - useESForJobParametersFlag: bool = False class JobSchedulingConfig(BaseModel): diff --git a/diracx-core/src/diracx/core/exceptions.py b/diracx-core/src/diracx/core/exceptions.py index 3338f3b1..79834b1c 100644 --- a/diracx-core/src/diracx/core/exceptions.py +++ b/diracx-core/src/diracx/core/exceptions.py @@ -1,7 +1,7 @@ from http import HTTPStatus -class DiracHttpResponse(RuntimeError): +class DiracHttpResponseError(RuntimeError): def __init__(self, status_code: int, data): self.status_code = status_code self.data = data @@ -30,7 +30,7 @@ class ConfigurationError(DiracError): """Used whenever we encounter a problem with the configuration.""" -class BadConfigurationVersion(ConfigurationError): +class BadConfigurationVersionError(ConfigurationError): """The requested version is not known.""" @@ -38,7 +38,7 @@ class InvalidQueryError(DiracError): """It was not possible to build a valid database query from the given input.""" -class JobNotFound(Exception): +class JobNotFoundError(Exception): def __init__(self, job_id: int, detail: str | None = None): self.job_id: int = job_id super().__init__(f"Job {job_id} not found" + (" ({detail})" if detail else "")) diff --git a/diracx-db/src/diracx/db/exceptions.py b/diracx-db/src/diracx/db/exceptions.py index ca0cf0ec..0a163f92 100644 --- a/diracx-db/src/diracx/db/exceptions.py +++ b/diracx-db/src/diracx/db/exceptions.py @@ -1,2 +1,2 @@ -class DBUnavailable(Exception): +class DBUnavailableError(Exception): pass diff --git a/diracx-db/src/diracx/db/os/utils.py b/diracx-db/src/diracx/db/os/utils.py index 8b611c00..431cceaa 100644 --- a/diracx-db/src/diracx/db/os/utils.py +++ b/diracx-db/src/diracx/db/os/utils.py @@ -16,7 +16,7 @@ from diracx.core.exceptions import InvalidQueryError from diracx.core.extensions 
import select_from_extension -from diracx.db.exceptions import DBUnavailable +from diracx.db.exceptions import DBUnavailableError logger = logging.getLogger(__name__) @@ -25,7 +25,7 @@ class OpenSearchDBError(Exception): pass -class OpenSearchDBUnavailable(DBUnavailable, OpenSearchDBError): +class OpenSearchDBUnavailableError(DBUnavailableError, OpenSearchDBError): pass @@ -152,7 +152,7 @@ async def ping(self): be ran at every query. """ if not await self.client.ping(): - raise OpenSearchDBUnavailable( + raise OpenSearchDBUnavailableError( f"Failed to connect to {self.__class__.__qualname__}" ) diff --git a/diracx-db/src/diracx/db/sql/auth/db.py b/diracx-db/src/diracx/db/sql/auth/db.py index cd8d1d03..b587f869 100644 --- a/diracx-db/src/diracx/db/sql/auth/db.py +++ b/diracx-db/src/diracx/db/sql/auth/db.py @@ -58,7 +58,7 @@ async def get_device_flow(self, device_code: str, max_validity: int): stmt = select( DeviceFlows, (DeviceFlows.creation_time < substract_date(seconds=max_validity)).label( - "is_expired" + "IsExpired" ), ).with_for_update() stmt = stmt.where( @@ -66,10 +66,10 @@ async def get_device_flow(self, device_code: str, max_validity: int): ) res = dict((await self.conn.execute(stmt)).one()._mapping) - if res["is_expired"]: + if res["IsExpired"]: raise ExpiredFlowError() - if res["status"] == FlowStatus.READY: + if res["Status"] == FlowStatus.READY: # Update the status to Done before returning await self.conn.execute( update(DeviceFlows) @@ -81,10 +81,10 @@ async def get_device_flow(self, device_code: str, max_validity: int): ) return res - if res["status"] == FlowStatus.DONE: + if res["Status"] == FlowStatus.DONE: raise AuthorizationError("Code was already used") - if res["status"] == FlowStatus.PENDING: + if res["Status"] == FlowStatus.PENDING: raise PendingAuthorizationError() raise AuthorizationError("Bad state in device flow") @@ -190,7 +190,7 @@ async def authorization_flow_insert_id_token( stmt = select(AuthorizationFlows.code, AuthorizationFlows.redirect_uri) stmt = stmt.where(AuthorizationFlows.uuid == uuid) row = (await self.conn.execute(stmt)).one() - return code, row.redirect_uri + return code, row.RedirectURI async def get_authorization_flow(self, code: str, max_validity: int): hashed_code = hashlib.sha256(code.encode()).hexdigest() @@ -205,7 +205,7 @@ async def get_authorization_flow(self, code: str, max_validity: int): res = dict((await self.conn.execute(stmt)).one()._mapping) - if res["status"] == FlowStatus.READY: + if res["Status"] == FlowStatus.READY: # Update the status to Done before returning await self.conn.execute( update(AuthorizationFlows) @@ -215,7 +215,7 @@ async def get_authorization_flow(self, code: str, max_validity: int): return res - if res["status"] == FlowStatus.DONE: + if res["Status"] == FlowStatus.DONE: raise AuthorizationError("Code was already used") raise AuthorizationError("Bad state in authorization flow") @@ -247,7 +247,7 @@ async def insert_refresh_token( row = (await self.conn.execute(stmt)).one() # Return the JWT ID and the creation time - return jti, row.creation_time + return jti, row.CreationTime async def get_refresh_token(self, jti: str) -> dict: """Get refresh token details bound to a given JWT ID.""" diff --git a/diracx-db/src/diracx/db/sql/auth/schema.py b/diracx-db/src/diracx/db/sql/auth/schema.py index b6efbede..8d7dddc7 100644 --- a/diracx-db/src/diracx/db/sql/auth/schema.py +++ b/diracx-db/src/diracx/db/sql/auth/schema.py @@ -39,27 +39,27 @@ class FlowStatus(Enum): class DeviceFlows(Base): __tablename__ = "DeviceFlows" - 
user_code = Column(String(USER_CODE_LENGTH), primary_key=True) - status = EnumColumn(FlowStatus, server_default=FlowStatus.PENDING.name) - creation_time = DateNowColumn() - client_id = Column(String(255)) - scope = Column(String(1024)) - device_code = Column(String(128), unique=True) # Should be a hash - id_token = NullColumn(JSON()) + user_code = Column("UserCode", String(USER_CODE_LENGTH), primary_key=True) + status = EnumColumn("Status", FlowStatus, server_default=FlowStatus.PENDING.name) + creation_time = DateNowColumn("CreationTime") + client_id = Column("ClientID", String(255)) + scope = Column("Scope", String(1024)) + device_code = Column("DeviceCode", String(128), unique=True) # Should be a hash + id_token = NullColumn("IDToken", JSON()) class AuthorizationFlows(Base): __tablename__ = "AuthorizationFlows" - uuid = Column(Uuid(as_uuid=False), primary_key=True) - status = EnumColumn(FlowStatus, server_default=FlowStatus.PENDING.name) - client_id = Column(String(255)) - creation_time = DateNowColumn() - scope = Column(String(1024)) - code_challenge = Column(String(255)) - code_challenge_method = Column(String(8)) - redirect_uri = Column(String(255)) - code = NullColumn(String(255)) # Should be a hash - id_token = NullColumn(JSON()) + uuid = Column("UUID", Uuid(as_uuid=False), primary_key=True) + status = EnumColumn("Status", FlowStatus, server_default=FlowStatus.PENDING.name) + client_id = Column("ClientID", String(255)) + creation_time = DateNowColumn("CreationTime") + scope = Column("Scope", String(1024)) + code_challenge = Column("CodeChallenge", String(255)) + code_challenge_method = Column("CodeChallengeMethod", String(8)) + redirect_uri = Column("RedirectURI", String(255)) + code = NullColumn("Code", String(255)) # Should be a hash + id_token = NullColumn("IDToken", JSON()) class RefreshTokenStatus(Enum): @@ -85,13 +85,13 @@ class RefreshTokens(Base): __tablename__ = "RefreshTokens" # Refresh token attributes - jti = Column(Uuid(as_uuid=False), primary_key=True) + jti = Column("JTI", Uuid(as_uuid=False), primary_key=True) status = EnumColumn( - RefreshTokenStatus, server_default=RefreshTokenStatus.CREATED.name + "Status", RefreshTokenStatus, server_default=RefreshTokenStatus.CREATED.name ) - creation_time = DateNowColumn() - scope = Column(String(1024)) + creation_time = DateNowColumn("CreationTime") + scope = Column("Scope", String(1024)) # User attributes bound to the refresh token - sub = Column(String(1024)) - preferred_username = Column(String(255)) + sub = Column("Sub", String(1024)) + preferred_username = Column("PreferredUsername", String(255)) diff --git a/diracx-db/src/diracx/db/sql/dummy/db.py b/diracx-db/src/diracx/db/sql/dummy/db.py index 9a033163..fa6bd8f1 100644 --- a/diracx-db/src/diracx/db/sql/dummy/db.py +++ b/diracx-db/src/diracx/db/sql/dummy/db.py @@ -25,7 +25,7 @@ class DummyDB(BaseSQLDB): async def summary(self, group_by, search) -> list[dict[str, str | int]]: columns = [Cars.__table__.columns[x] for x in group_by] - stmt = select(*columns, func.count(Cars.licensePlate).label("count")) + stmt = select(*columns, func.count(Cars.license_plate).label("count")) stmt = apply_search_filters(Cars.__table__.columns.__getitem__, stmt, search) stmt = stmt.group_by(*columns) @@ -44,7 +44,7 @@ async def insert_owner(self, name: str) -> int: async def insert_car(self, license_plate: UUID, model: str, owner_id: int) -> int: stmt = insert(Cars).values( - licensePlate=license_plate, model=model, ownerID=owner_id + license_plate=license_plate, model=model, owner_id=owner_id
) result = await self.conn.execute(stmt) diff --git a/diracx-db/src/diracx/db/sql/dummy/schema.py b/diracx-db/src/diracx/db/sql/dummy/schema.py index ebb37b8d..a0c11c09 100644 --- a/diracx-db/src/diracx/db/sql/dummy/schema.py +++ b/diracx-db/src/diracx/db/sql/dummy/schema.py @@ -10,13 +10,13 @@ class Owners(Base): __tablename__ = "Owners" - ownerID = Column(Integer, primary_key=True, autoincrement=True) - creation_time = DateNowColumn() - name = Column(String(255)) + owner_id = Column("OwnerID", Integer, primary_key=True, autoincrement=True) + creation_time = DateNowColumn("CreationTime") + name = Column("Name", String(255)) class Cars(Base): __tablename__ = "Cars" - licensePlate = Column(Uuid(), primary_key=True) - model = Column(String(255)) - ownerID = Column(Integer, ForeignKey(Owners.ownerID)) + license_plate = Column("LicensePlate", Uuid(), primary_key=True) + model = Column("Model", String(255)) + owner_id = Column("OwnerID", Integer, ForeignKey(Owners.owner_id)) diff --git a/diracx-db/src/diracx/db/sql/job/db.py b/diracx-db/src/diracx/db/sql/job/db.py index 7817bb39..443372a1 100644 --- a/diracx-db/src/diracx/db/sql/job/db.py +++ b/diracx-db/src/diracx/db/sql/job/db.py @@ -8,7 +8,8 @@ if TYPE_CHECKING: from sqlalchemy.sql.elements import BindParameter -from diracx.core.exceptions import InvalidQueryError, JobNotFound + +from diracx.core.exceptions import InvalidQueryError, JobNotFoundError from diracx.core.models import ( LimitedJobStatusReturn, SearchSpec, @@ -42,12 +43,12 @@ class JobDB(BaseSQLDB): # TODO: this is copied from the DIRAC JobDB # but is overwriten in LHCbDIRAC, so we need # to find a way to make it dynamic - jdl2DBParameters = ["JobName", "JobType", "JobGroup"] + jdl_2_db_parameters = ["JobName", "JobType", "JobGroup"] async def summary(self, group_by, search) -> list[dict[str, str | int]]: columns = _get_columns(Jobs.__table__, group_by) - stmt = select(*columns, func.count(Jobs.JobID).label("count")) + stmt = select(*columns, func.count(Jobs.job_id).label("count")) stmt = apply_search_filters(Jobs.__table__.columns.__getitem__, stmt, search) stmt = stmt.group_by(*columns) @@ -110,11 +111,11 @@ async def insert_input_data(self, lfns: dict[int, list[str]]): ], ) - async def setJobAttributes(self, job_id, jobData): + async def set_job_attributes(self, job_id, job_data): """TODO: add myDate and force parameters.""" - if "Status" in jobData: - jobData = jobData | {"LastUpdateTime": datetime.now(tz=timezone.utc)} - stmt = update(Jobs).where(Jobs.JobID == job_id).values(jobData) + if "Status" in job_data: + job_data = job_data | {"LastUpdateTime": datetime.now(tz=timezone.utc)} + stmt = update(Jobs).where(Jobs.job_id == job_id).values(job_data) await self.conn.execute(stmt) async def create_job(self, original_jdl): @@ -159,9 +160,9 @@ async def update_job_jdls(self, jdls_to_update: dict[int, str]): ], ) - async def checkAndPrepareJob( + async def check_and_prepare_job( self, - jobID, + job_id, class_ad_job, class_ad_req, owner, @@ -178,8 +179,8 @@ async def checkAndPrepareJob( checkAndPrepareJob, ) - retVal = checkAndPrepareJob( - jobID, + ret_val = checkAndPrepareJob( + job_id, class_ad_job, class_ad_req, owner, @@ -188,21 +189,21 @@ async def checkAndPrepareJob( vo, ) - if not retVal["OK"]: - if cmpError(retVal, EWMSSUBM): - await self.setJobAttributes(jobID, job_attrs) + if not ret_val["OK"]: + if cmpError(ret_val, EWMSSUBM): + await self.set_job_attributes(job_id, job_attrs) - returnValueOrRaise(retVal) + returnValueOrRaise(ret_val) - async def setJobJDL(self, 
job_id, jdl): + async def set_job_jdl(self, job_id, jdl): from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL stmt = ( - update(JobJDLs).where(JobJDLs.JobID == job_id).values(JDL=compressJDL(jdl)) + update(JobJDLs).where(JobJDLs.job_id == job_id).values(JDL=compressJDL(jdl)) ) await self.conn.execute(stmt) - async def setJobJDLsBulk(self, jdls): + async def set_job_jdl_bulk(self, jdls): from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL await self.conn.execute( @@ -212,19 +213,19 @@ async def setJobJDLsBulk(self, jdls): [{"b_JobID": jid, "JDL": compressJDL(jdl)} for jid, jdl in jdls.items()], ) - async def setJobAttributesBulk(self, jobData): + async def set_job_attributes_bulk(self, job_data): """TODO: add myDate and force parameters.""" - for job_id in jobData.keys(): - if "Status" in jobData[job_id]: - jobData[job_id].update( + for job_id in job_data.keys(): + if "Status" in job_data[job_id]: + job_data[job_id].update( {"LastUpdateTime": datetime.now(tz=timezone.utc)} ) - columns = set(key for attrs in jobData.values() for key in attrs.keys()) + columns = set(key for attrs in job_data.values() for key in attrs.keys()) case_expressions = { column: case( *[ (Jobs.__table__.c.JobID == job_id, attrs[column]) - for job_id, attrs in jobData.items() + for job_id, attrs in job_data.items() if column in attrs ], else_=getattr(Jobs.__table__.c, column), # Retain original value @@ -235,33 +236,23 @@ async def setJobAttributesBulk(self, jobData): stmt = ( Jobs.__table__.update() .values(**case_expressions) - .where(Jobs.__table__.c.JobID.in_(jobData.keys())) + .where(Jobs.__table__.c.JobID.in_(job_data.keys())) ) await self.conn.execute(stmt) - async def getJobJDL(self, job_id: int, original: bool = False) -> str: - from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import extractJDL - - if original: - stmt = select(JobJDLs.OriginalJDL).where(JobJDLs.JobID == job_id) - else: - stmt = select(JobJDLs.JDL).where(JobJDLs.JobID == job_id) - - jdl = (await self.conn.execute(stmt)).scalar_one() - if jdl: - jdl = extractJDL(jdl) - - return jdl - - async def getJobJDLs(self, job_ids, original: bool = False) -> dict[int | str, str]: + async def get_job_jdls( + self, job_ids, original: bool = False + ) -> dict[int | str, str]: from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import extractJDL if original: - stmt = select(JobJDLs.JobID, JobJDLs.OriginalJDL).where( - JobJDLs.JobID.in_(job_ids) + stmt = select(JobJDLs.job_id, JobJDLs.original_jdl).where( + JobJDLs.job_id.in_(job_ids) ) else: - stmt = select(JobJDLs.JobID, JobJDLs.JDL).where(JobJDLs.JobID.in_(job_ids)) + stmt = select(JobJDLs.job_id, JobJDLs.jdl).where( + JobJDLs.job_id.in_(job_ids) + ) return { jobid: extractJDL(jdl) @@ -271,14 +262,14 @@ async def getJobJDLs(self, job_ids, original: bool = False) -> dict[int | str, s async def get_job_status(self, job_id: int) -> LimitedJobStatusReturn: try: - stmt = select(Jobs.Status, Jobs.MinorStatus, Jobs.ApplicationStatus).where( - Jobs.JobID == job_id - ) + stmt = select( + Jobs.status, Jobs.minor_status, Jobs.application_status + ).where(Jobs.job_id == job_id) return LimitedJobStatusReturn( **dict((await self.conn.execute(stmt)).one()._mapping) ) except NoResultFound as e: - raise JobNotFound(job_id) from e + raise JobNotFoundError(job_id) from e async def set_job_command(self, job_id: int, command: str, arguments: str = ""): """Store a command to be passed to the job together with the next heart beat.""" @@ -291,7 +282,7 @@ async def set_job_command(self, job_id: int, 
command: str, arguments: str = ""): ) await self.conn.execute(stmt) except IntegrityError as e: - raise JobNotFound(job_id) from e + raise JobNotFoundError(job_id) from e async def set_job_command_bulk(self, commands): """Store a command to be passed to the job together with the next heart beat.""" @@ -311,7 +302,7 @@ async def set_job_command_bulk(self, commands): async def delete_jobs(self, job_ids: list[int]): """Delete jobs from the database.""" - stmt = delete(JobJDLs).where(JobJDLs.JobID.in_(job_ids)) + stmt = delete(JobJDLs).where(JobJDLs.job_id.in_(job_ids)) await self.conn.execute(stmt) async def set_properties( @@ -344,7 +335,7 @@ async def set_properties( if update_timestamp: values["LastUpdateTime"] = datetime.now(tz=timezone.utc) - stmt = update(Jobs).where(Jobs.JobID == bindparam("job_id")).values(**values) + stmt = update(Jobs).where(Jobs.job_id == bindparam("job_id")).values(**values) rows = await self.conn.execute(stmt, update_parameters) return rows.rowcount diff --git a/diracx-db/src/diracx/db/sql/job/schema.py b/diracx-db/src/diracx/db/sql/job/schema.py index d17edf2d..eea1e3a1 100644 --- a/diracx-db/src/diracx/db/sql/job/schema.py +++ b/diracx-db/src/diracx/db/sql/job/schema.py @@ -17,34 +17,34 @@ class Jobs(JobDBBase): __tablename__ = "Jobs" - JobID = Column( + job_id = Column( "JobID", Integer, ForeignKey("JobJDLs.JobID", ondelete="CASCADE"), primary_key=True, default=0, ) - JobType = Column("JobType", String(32), default="user") - JobGroup = Column("JobGroup", String(32), default="00000000") - Site = Column("Site", String(100), default="ANY") - JobName = Column("JobName", String(128), default="Unknown") - Owner = Column("Owner", String(64), default="Unknown") - OwnerGroup = Column("OwnerGroup", String(128), default="Unknown") - VO = Column("VO", String(32)) - SubmissionTime = NullColumn("SubmissionTime", DateTime) - RescheduleTime = NullColumn("RescheduleTime", DateTime) - LastUpdateTime = NullColumn("LastUpdateTime", DateTime) - StartExecTime = NullColumn("StartExecTime", DateTime) - HeartBeatTime = NullColumn("HeartBeatTime", DateTime) - EndExecTime = NullColumn("EndExecTime", DateTime) - Status = Column("Status", String(32), default="Received") - MinorStatus = Column("MinorStatus", String(128), default="Unknown") - ApplicationStatus = Column("ApplicationStatus", String(255), default="Unknown") - UserPriority = Column("UserPriority", Integer, default=0) - RescheduleCounter = Column("RescheduleCounter", Integer, default=0) - VerifiedFlag = Column("VerifiedFlag", EnumBackedBool(), default=False) + job_type = Column("JobType", String(32), default="user") + job_group = Column("JobGroup", String(32), default="00000000") + site = Column("Site", String(100), default="ANY") + job_name = Column("JobName", String(128), default="Unknown") + owner = Column("Owner", String(64), default="Unknown") + owner_group = Column("OwnerGroup", String(128), default="Unknown") + vo = Column("VO", String(32)) + submission_time = NullColumn("SubmissionTime", DateTime) + reschedule_time = NullColumn("RescheduleTime", DateTime) + last_update_time = NullColumn("LastUpdateTime", DateTime) + start_exec_time = NullColumn("StartExecTime", DateTime) + heart_beat_time = NullColumn("HeartBeatTime", DateTime) + end_exec_time = NullColumn("EndExecTime", DateTime) + status = Column("Status", String(32), default="Received") + minor_status = Column("MinorStatus", String(128), default="Unknown") + application_status = Column("ApplicationStatus", String(255), default="Unknown") + user_priority = 
Column("UserPriority", Integer, default=0) + reschedule_counter = Column("RescheduleCounter", Integer, default=0) + verified_flag = Column("VerifiedFlag", EnumBackedBool(), default=False) # TODO: Should this be True/False/"Failed"? Or True/False/Null? - AccountedFlag = Column( + accounted_flag = Column( "AccountedFlag", Enum("True", "False", "Failed"), default="False" ) @@ -64,66 +64,66 @@ class Jobs(JobDBBase): class JobJDLs(JobDBBase): __tablename__ = "JobJDLs" - JobID = Column(Integer, autoincrement=True, primary_key=True) - JDL = Column(Text) - JobRequirements = Column(Text) - OriginalJDL = Column(Text) + job_id = Column("JobID", Integer, autoincrement=True, primary_key=True) + jdl = Column("JDL", Text) + job_requirements = Column("JobRequirements", Text) + original_jdl = Column("OriginalJDL", Text) class InputData(JobDBBase): __tablename__ = "InputData" - JobID = Column( - Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + job_id = Column( + "JobID", Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True ) - LFN = Column(String(255), default="", primary_key=True) - Status = Column(String(32), default="AprioriGood") + lfn = Column("LFN", String(255), default="", primary_key=True) + status = Column("Status", String(32), default="AprioriGood") class JobParameters(JobDBBase): __tablename__ = "JobParameters" - JobID = Column( - Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + job_id = Column( + "JobID", Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True ) - Name = Column(String(100), primary_key=True) - Value = Column(Text) + name = Column("Name", String(100), primary_key=True) + value = Column("Value", Text) class OptimizerParameters(JobDBBase): __tablename__ = "OptimizerParameters" - JobID = Column( - Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + job_id = Column( + "JobID", Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True ) - Name = Column(String(100), primary_key=True) - Value = Column(Text) + name = Column("Name", String(100), primary_key=True) + value = Column("Value", Text) class AtticJobParameters(JobDBBase): __tablename__ = "AtticJobParameters" - JobID = Column( - Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + job_id = Column( + "JobID", Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True ) - Name = Column(String(100), primary_key=True) - Value = Column(Text) - RescheduleCycle = Column(Integer) + name = Column("Name", String(100), primary_key=True) + value = Column("Value", Text) + reschedule_cycle = Column("RescheduleCycle", Integer) class HeartBeatLoggingInfo(JobDBBase): __tablename__ = "HeartBeatLoggingInfo" - JobID = Column( - Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + job_id = Column( + "JobID", Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True ) - Name = Column(String(100), primary_key=True) - Value = Column(Text) - HeartBeatTime = Column(DateTime, primary_key=True) + name = Column("Name", String(100), primary_key=True) + value = Column("Value", Text) + heart_beat_time = Column("HeartBeatTime", DateTime, primary_key=True) class JobCommands(JobDBBase): __tablename__ = "JobCommands" - JobID = Column( - Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + job_id = Column( + "JobID", Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True ) - Command = Column(String(100)) - Arguments = Column(String(100)) - 
Status = Column(String(64), default="Received") - ReceptionTime = Column(DateTime, primary_key=True) - ExecutionTime = NullColumn(DateTime) + command = Column("Command", String(100)) + arguments = Column("Arguments", String(100)) + status = Column("Status", String(64), default="Received") + reception_time = Column("ReceptionTime", DateTime, primary_key=True) + execution_time = NullColumn("ExecutionTime", DateTime) diff --git a/diracx-db/src/diracx/db/sql/job_logging/db.py b/diracx-db/src/diracx/db/sql/job_logging/db.py index 9774c523..154671e0 100644 --- a/diracx-db/src/diracx/db/sql/job_logging/db.py +++ b/diracx-db/src/diracx/db/sql/job_logging/db.py @@ -12,7 +12,7 @@ from collections import defaultdict -from diracx.core.exceptions import JobNotFound +from diracx.core.exceptions import JobNotFoundError from diracx.core.models import ( JobStatus, JobStatusReturn, @@ -57,9 +57,9 @@ async def insert_record( as datetime.datetime object. If the time stamp is not provided the current UTC time is used. """ - # First, fetch the maximum SeqNum for the given job_id - seqnum_stmt = select(func.coalesce(func.max(LoggingInfo.SeqNum) + 1, 1)).where( - LoggingInfo.JobID == job_id + # First, fetch the maximum seq_num for the given job_id + seqnum_stmt = select(func.coalesce(func.max(LoggingInfo.seq_num) + 1, 1)).where( + LoggingInfo.job_id == job_id ) seqnum = await self.conn.scalar(seqnum_stmt) @@ -70,14 +70,14 @@ async def insert_record( ) stmt = insert(LoggingInfo).values( - JobID=int(job_id), - SeqNum=seqnum, - Status=status, - MinorStatus=minor_status, - ApplicationStatus=application_status[:255], - StatusTime=date, - StatusTimeOrder=epoc, - Source=source[:32], + job_id=int(job_id), + seq_num=seqnum, + status=status, + minor_status=minor_status, + application_status=application_status[:255], + status_time=date, + status_time_order=epoc, + source=source[:32], ) await self.conn.execute(stmt) @@ -97,10 +97,10 @@ def get_epoc(date): # First, fetch the maximum SeqNums for the given job_ids seqnum_stmt = ( select( - LoggingInfo.JobID, func.coalesce(func.max(LoggingInfo.SeqNum) + 1, 1) + LoggingInfo.job_id, func.coalesce(func.max(LoggingInfo.seq_num) + 1, 1) ) - .where(LoggingInfo.JobID.in_([record.job_id for record in records])) - .group_by(LoggingInfo.JobID) + .where(LoggingInfo.job_id.in_([record.job_id for record in records])) + .group_by(LoggingInfo.job_id) ) seqnum = {jid: seqnum for jid, seqnum in (await self.conn.execute(seqnum_stmt))} @@ -131,15 +131,15 @@ async def get_records(self, job_ids: list[int]) -> dict[int, JobStatusReturn]: # results later. 
stmt = ( select( - LoggingInfo.JobID, - LoggingInfo.Status, - LoggingInfo.MinorStatus, - LoggingInfo.ApplicationStatus, - LoggingInfo.StatusTime, - LoggingInfo.Source, + LoggingInfo.job_id, + LoggingInfo.status, + LoggingInfo.minor_status, + LoggingInfo.application_status, + LoggingInfo.status_time, + LoggingInfo.source, ) - .where(LoggingInfo.JobID.in_(job_ids)) - .order_by(LoggingInfo.StatusTimeOrder, LoggingInfo.StatusTime) + .where(LoggingInfo.job_id.in_(job_ids)) + .order_by(LoggingInfo.status_time_order, LoggingInfo.status_time) ) rows = await self.conn.execute(stmt) @@ -198,7 +198,7 @@ async def get_records(self, job_ids: list[int]) -> dict[int, JobStatusReturn]: async def delete_records(self, job_ids: list[int]): """Delete logging records for given jobs.""" - stmt = delete(LoggingInfo).where(LoggingInfo.JobID.in_(job_ids)) + stmt = delete(LoggingInfo).where(LoggingInfo.job_id.in_(job_ids)) await self.conn.execute(stmt) async def get_wms_time_stamps(self, job_id): @@ -207,12 +207,12 @@ async def get_wms_time_stamps(self, job_id): """ result = {} stmt = select( - LoggingInfo.Status, - LoggingInfo.StatusTimeOrder, - ).where(LoggingInfo.JobID == job_id) + LoggingInfo.status, + LoggingInfo.status_time_order, + ).where(LoggingInfo.job_id == job_id) rows = await self.conn.execute(stmt) if not rows.rowcount: - raise JobNotFound(job_id) from None + raise JobNotFoundError(job_id) from None for event, etime in rows: result[event] = str(etime + MAGIC_EPOC_NUMBER) @@ -225,10 +225,10 @@ async def get_wms_time_stamps_bulk(self, job_ids): """ result = defaultdict(dict) stmt = select( - LoggingInfo.JobID, - LoggingInfo.Status, - LoggingInfo.StatusTimeOrder, - ).where(LoggingInfo.JobID.in_(job_ids)) + LoggingInfo.job_id, + LoggingInfo.status, + LoggingInfo.status_time_order, + ).where(LoggingInfo.job_id.in_(job_ids)) rows = await self.conn.execute(stmt) if not rows.rowcount: return {} diff --git a/diracx-db/src/diracx/db/sql/job_logging/schema.py b/diracx-db/src/diracx/db/sql/job_logging/schema.py index 6f459c48..1c229bb7 100644 --- a/diracx-db/src/diracx/db/sql/job_logging/schema.py +++ b/diracx-db/src/diracx/db/sql/job_logging/schema.py @@ -13,13 +13,15 @@ class LoggingInfo(JobLoggingDBBase): __tablename__ = "LoggingInfo" - JobID = Column(Integer) - SeqNum = Column(Integer) - Status = Column(String(32), default="") - MinorStatus = Column(String(128), default="") - ApplicationStatus = Column(String(255), default="") - StatusTime = DateNowColumn() + job_id = Column("JobID", Integer) + seq_num = Column("SeqNum", Integer) + status = Column("Status", String(32), default="") + minor_status = Column("MinorStatus", String(128), default="") + application_status = Column("ApplicationStatus", String(255), default="") + status_time = DateNowColumn("StatusTime") # TODO: Check that this corresponds to the DOUBLE(12,3) type in MySQL - StatusTimeOrder = Column(Numeric(precision=12, scale=3), default=0) - Source = Column(String(32), default="Unknown", name="StatusSource") + status_time_order = Column( + "StatusTimeOrder", Numeric(precision=12, scale=3), default=0 + ) + source = Column("StatusSource", String(32), default="Unknown") __table_args__ = (PrimaryKeyConstraint("JobID", "SeqNum"),) diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/schema.py b/diracx-db/src/diracx/db/sql/pilot_agents/schema.py index 7a2a0c5e..76cd5c89 100644 --- a/diracx-db/src/diracx/db/sql/pilot_agents/schema.py +++ b/diracx-db/src/diracx/db/sql/pilot_agents/schema.py @@ -16,22 +16,22 @@ class PilotAgents(PilotAgentsDBBase): 
__tablename__ = "PilotAgents" - PilotID = Column("PilotID", Integer, autoincrement=True, primary_key=True) - InitialJobID = Column("InitialJobID", Integer, default=0) - CurrentJobID = Column("CurrentJobID", Integer, default=0) - PilotJobReference = Column("PilotJobReference", String(255), default="Unknown") - PilotStamp = Column("PilotStamp", String(32), default="") - DestinationSite = Column("DestinationSite", String(128), default="NotAssigned") - Queue = Column("Queue", String(128), default="Unknown") - GridSite = Column("GridSite", String(128), default="Unknown") - VO = Column("VO", String(128)) - GridType = Column("GridType", String(32), default="LCG") - BenchMark = Column("BenchMark", Double, default=0.0) - SubmissionTime = NullColumn("SubmissionTime", DateTime) - LastUpdateTime = NullColumn("LastUpdateTime", DateTime) - Status = Column("Status", String(32), default="Unknown") - StatusReason = Column("StatusReason", String(255), default="Unknown") - AccountingSent = Column("AccountingSent", EnumBackedBool(), default=False) + pilot_id = Column("PilotID", Integer, autoincrement=True, primary_key=True) + initial_job_id = Column("InitialJobID", Integer, default=0) + current_job_id = Column("CurrentJobID", Integer, default=0) + pilot_job_reference = Column("PilotJobReference", String(255), default="Unknown") + pilot_stamp = Column("PilotStamp", String(32), default="") + destination_site = Column("DestinationSite", String(128), default="NotAssigned") + queue = Column("Queue", String(128), default="Unknown") + grid_site = Column("GridSite", String(128), default="Unknown") + vo = Column("VO", String(128)) + grid_type = Column("GridType", String(32), default="LCG") + benchmark = Column("BenchMark", Double, default=0.0) + submission_time = NullColumn("SubmissionTime", DateTime) + last_update_time = NullColumn("LastUpdateTime", DateTime) + status = Column("Status", String(32), default="Unknown") + status_reason = Column("StatusReason", String(255), default="Unknown") + accounting_sent = Column("AccountingSent", EnumBackedBool(), default=False) __table_args__ = ( Index("PilotJobReference", "PilotJobReference"), @@ -43,9 +43,9 @@ class PilotAgents(PilotAgentsDBBase): class JobToPilotMapping(PilotAgentsDBBase): __tablename__ = "JobToPilotMapping" - PilotID = Column("PilotID", Integer, primary_key=True) - JobID = Column("JobID", Integer, primary_key=True) - StartTime = Column("StartTime", DateTime) + pilot_id = Column("PilotID", Integer, primary_key=True) + job_id = Column("JobID", Integer, primary_key=True) + start_time = Column("StartTime", DateTime) __table_args__ = (Index("JobID", "JobID"), Index("PilotID", "PilotID")) @@ -53,6 +53,6 @@ class JobToPilotMapping(PilotAgentsDBBase): class PilotOutput(PilotAgentsDBBase): __tablename__ = "PilotOutput" - PilotID = Column("PilotID", Integer, primary_key=True) - StdOutput = Column("StdOutput", Text) - StdError = Column("StdError", Text) + pilot_id = Column("PilotID", Integer, primary_key=True) + std_output = Column("StdOutput", Text) + std_error = Column("StdError", Text) diff --git a/diracx-db/src/diracx/db/sql/sandbox_metadata/db.py b/diracx-db/src/diracx/db/sql/sandbox_metadata/db.py index db72a7f9..28462778 100644 --- a/diracx-db/src/diracx/db/sql/sandbox_metadata/db.py +++ b/diracx-db/src/diracx/db/sql/sandbox_metadata/db.py @@ -5,10 +5,10 @@ import sqlalchemy from diracx.core.models import SandboxInfo, SandboxType, UserInfo -from diracx.db.sql.utils import BaseSQLDB, utcnow +from diracx.db.sql.utils import BaseSQLDB, UTCNow from .schema import 
Base as SandboxMetadataDBBase -from .schema import sb_EntityMapping, sb_Owners, sb_SandBoxes +from .schema import SandBoxes, SBEntityMapping, SBOwners class SandboxMetadataDB(BaseSQLDB): @@ -17,16 +17,16 @@ class SandboxMetadataDB(BaseSQLDB): async def upsert_owner(self, user: UserInfo) -> int: """Get the id of the owner from the database.""" # TODO: Follow https://github.com/DIRACGrid/diracx/issues/49 - stmt = sqlalchemy.select(sb_Owners.OwnerID).where( - sb_Owners.Owner == user.preferred_username, - sb_Owners.OwnerGroup == user.dirac_group, - sb_Owners.VO == user.vo, + stmt = sqlalchemy.select(SBOwners.OwnerID).where( + SBOwners.Owner == user.preferred_username, + SBOwners.OwnerGroup == user.dirac_group, + SBOwners.VO == user.vo, ) result = await self.conn.execute(stmt) if owner_id := result.scalar_one_or_none(): return owner_id - stmt = sqlalchemy.insert(sb_Owners).values( + stmt = sqlalchemy.insert(SBOwners).values( Owner=user.preferred_username, OwnerGroup=user.dirac_group, VO=user.vo, @@ -53,13 +53,13 @@ async def insert_sandbox( """Add a new sandbox in SandboxMetadataDB.""" # TODO: Follow https://github.com/DIRACGrid/diracx/issues/49 owner_id = await self.upsert_owner(user) - stmt = sqlalchemy.insert(sb_SandBoxes).values( + stmt = sqlalchemy.insert(SandBoxes).values( OwnerId=owner_id, SEName=se_name, SEPFN=pfn, Bytes=size, - RegistrationTime=utcnow(), - LastAccessTime=utcnow(), + RegistrationTime=UTCNow(), + LastAccessTime=UTCNow(), ) try: result = await self.conn.execute(stmt) @@ -70,17 +70,17 @@ async def insert_sandbox( async def update_sandbox_last_access_time(self, se_name: str, pfn: str) -> None: stmt = ( - sqlalchemy.update(sb_SandBoxes) - .where(sb_SandBoxes.SEName == se_name, sb_SandBoxes.SEPFN == pfn) - .values(LastAccessTime=utcnow()) + sqlalchemy.update(SandBoxes) + .where(SandBoxes.SEName == se_name, SandBoxes.SEPFN == pfn) + .values(LastAccessTime=UTCNow()) ) result = await self.conn.execute(stmt) assert result.rowcount == 1 async def sandbox_is_assigned(self, pfn: str, se_name: str) -> bool: """Checks if a sandbox exists and has been assigned.""" - stmt: sqlalchemy.Executable = sqlalchemy.select(sb_SandBoxes.Assigned).where( - sb_SandBoxes.SEName == se_name, sb_SandBoxes.SEPFN == pfn + stmt: sqlalchemy.Executable = sqlalchemy.select(SandBoxes.Assigned).where( + SandBoxes.SEName == se_name, SandBoxes.SEPFN == pfn ) result = await self.conn.execute(stmt) is_assigned = result.scalar_one() @@ -97,11 +97,11 @@ async def get_sandbox_assigned_to_job( """Get the sandbox assign to job.""" entity_id = self.jobid_to_entity_id(job_id) stmt = ( - sqlalchemy.select(sb_SandBoxes.SEPFN) - .where(sb_SandBoxes.SBId == sb_EntityMapping.SBId) + sqlalchemy.select(SandBoxes.SEPFN) + .where(SandBoxes.SBId == SBEntityMapping.SBId) .where( - sb_EntityMapping.EntityId == entity_id, - sb_EntityMapping.Type == sb_type, + SBEntityMapping.EntityId == entity_id, + SBEntityMapping.Type == sb_type, ) ) result = await self.conn.execute(stmt) @@ -119,21 +119,21 @@ async def assign_sandbox_to_jobs( # Define the entity id as 'Entity:entity_id' due to the DB definition: entity_id = self.jobid_to_entity_id(job_id) select_sb_id = sqlalchemy.select( - sb_SandBoxes.SBId, + SandBoxes.SBId, sqlalchemy.literal(entity_id).label("EntityId"), sqlalchemy.literal(sb_type).label("Type"), ).where( - sb_SandBoxes.SEName == se_name, - sb_SandBoxes.SEPFN == pfn, + SandBoxes.SEName == se_name, + SandBoxes.SEPFN == pfn, ) - stmt = sqlalchemy.insert(sb_EntityMapping).from_select( + stmt = 
sqlalchemy.insert(SBEntityMapping).from_select( ["SBId", "EntityId", "Type"], select_sb_id ) await self.conn.execute(stmt) stmt = ( - sqlalchemy.update(sb_SandBoxes) - .where(sb_SandBoxes.SEPFN == pfn) + sqlalchemy.update(SandBoxes) + .where(SandBoxes.SEPFN == pfn) .values(Assigned=True) ) result = await self.conn.execute(stmt) @@ -143,29 +143,29 @@ async def unassign_sandboxes_to_jobs(self, jobs_ids: list[int]) -> None: """Delete mapping between jobs and sandboxes.""" for job_id in jobs_ids: entity_id = self.jobid_to_entity_id(job_id) - sb_sel_stmt = sqlalchemy.select(sb_SandBoxes.SBId) + sb_sel_stmt = sqlalchemy.select(SandBoxes.SBId) sb_sel_stmt = sb_sel_stmt.join( - sb_EntityMapping, sb_EntityMapping.SBId == sb_SandBoxes.SBId + SBEntityMapping, SBEntityMapping.SBId == SandBoxes.SBId ) - sb_sel_stmt = sb_sel_stmt.where(sb_EntityMapping.EntityId == entity_id) + sb_sel_stmt = sb_sel_stmt.where(SBEntityMapping.EntityId == entity_id) result = await self.conn.execute(sb_sel_stmt) sb_ids = [row.SBId for row in result] - del_stmt = sqlalchemy.delete(sb_EntityMapping).where( - sb_EntityMapping.EntityId == entity_id + del_stmt = sqlalchemy.delete(SBEntityMapping).where( + SBEntityMapping.EntityId == entity_id ) await self.conn.execute(del_stmt) - sb_entity_sel_stmt = sqlalchemy.select(sb_EntityMapping.SBId).where( - sb_EntityMapping.SBId.in_(sb_ids) + sb_entity_sel_stmt = sqlalchemy.select(SBEntityMapping.SBId).where( + SBEntityMapping.SBId.in_(sb_ids) ) result = await self.conn.execute(sb_entity_sel_stmt) remaining_sb_ids = [row.SBId for row in result] if not remaining_sb_ids: unassign_stmt = ( - sqlalchemy.update(sb_SandBoxes) - .where(sb_SandBoxes.SBId.in_(sb_ids)) + sqlalchemy.update(SandBoxes) + .where(SandBoxes.SBId.in_(sb_ids)) .values(Assigned=False) ) await self.conn.execute(unassign_stmt) diff --git a/diracx-db/src/diracx/db/sql/sandbox_metadata/schema.py b/diracx-db/src/diracx/db/sql/sandbox_metadata/schema.py index 8c849c67..5864ea42 100644 --- a/diracx-db/src/diracx/db/sql/sandbox_metadata/schema.py +++ b/diracx-db/src/diracx/db/sql/sandbox_metadata/schema.py @@ -14,7 +14,7 @@ Base = declarative_base() -class sb_Owners(Base): +class SBOwners(Base): __tablename__ = "sb_Owners" OwnerID = Column(Integer, autoincrement=True) Owner = Column(String(32)) @@ -23,7 +23,7 @@ class sb_Owners(Base): __table_args__ = (PrimaryKeyConstraint("OwnerID"),) -class sb_SandBoxes(Base): +class SandBoxes(Base): __tablename__ = "sb_SandBoxes" SBId = Column(Integer, autoincrement=True) OwnerId = Column(Integer) @@ -40,7 +40,7 @@ class sb_SandBoxes(Base): ) -class sb_EntityMapping(Base): +class SBEntityMapping(Base): __tablename__ = "sb_EntityMapping" SBId = Column(Integer) EntityId = Column(String(128)) diff --git a/diracx-db/src/diracx/db/sql/task_queue/db.py b/diracx-db/src/diracx/db/sql/task_queue/db.py index 537f128e..ff701509 100644 --- a/diracx-db/src/diracx/db/sql/task_queue/db.py +++ b/diracx-db/src/diracx/db/sql/task_queue/db.py @@ -121,12 +121,12 @@ async def recalculate_tq_shares_for_entity( # TODO: I guess the rows are already a list of tupes # maybe refactor data = [(r[0], r[1]) for r in rows if r] - numOwners = len(data) + num_owners = len(data) # If there are no owners do now - if numOwners == 0: + if num_owners == 0: return # Split the share amongst the number of owners - entities_shares = {row[0]: job_share / numOwners for row in data} + entities_shares = {row[0]: job_share / num_owners for row in data} # TODO: implement the following # If corrector is enabled let it work it's magic diff 
--git a/diracx-db/src/diracx/db/sql/utils/__init__.py b/diracx-db/src/diracx/db/sql/utils/__init__.py index 390588e6..eafc4d3b 100644 --- a/diracx-db/src/diracx/db/sql/utils/__init__.py +++ b/diracx-db/src/diracx/db/sql/utils/__init__.py @@ -26,7 +26,7 @@ from diracx.core.extensions import select_from_extension from diracx.core.models import SortDirection from diracx.core.settings import SqlalchemyDsn -from diracx.db.exceptions import DBUnavailable +from diracx.db.exceptions import DBUnavailableError if TYPE_CHECKING: from sqlalchemy.types import TypeEngine @@ -34,32 +34,32 @@ logger = logging.getLogger(__name__) -class utcnow(expression.FunctionElement): +class UTCNow(expression.FunctionElement): type: TypeEngine = DateTime() inherit_cache: bool = True -@compiles(utcnow, "postgresql") +@compiles(UTCNow, "postgresql") def pg_utcnow(element, compiler, **kw) -> str: return "TIMEZONE('utc', CURRENT_TIMESTAMP)" -@compiles(utcnow, "mssql") +@compiles(UTCNow, "mssql") def ms_utcnow(element, compiler, **kw) -> str: return "GETUTCDATE()" -@compiles(utcnow, "mysql") +@compiles(UTCNow, "mysql") def mysql_utcnow(element, compiler, **kw) -> str: return "(UTC_TIMESTAMP)" -@compiles(utcnow, "sqlite") +@compiles(UTCNow, "sqlite") def sqlite_utcnow(element, compiler, **kw) -> str: return "DATETIME('now')" -class date_trunc(expression.FunctionElement): +class DateTrunc(expression.FunctionElement): """Sqlalchemy function to truncate a date to a given resolution. Primarily used to be able to query for a specific resolution of a date e.g. @@ -77,7 +77,7 @@ def __init__(self, *args, time_resolution, **kwargs) -> None: self._time_resolution = time_resolution -@compiles(date_trunc, "postgresql") +@compiles(DateTrunc, "postgresql") def pg_date_trunc(element, compiler, **kw): res = { "SECOND": "second", @@ -90,7 +90,7 @@ def pg_date_trunc(element, compiler, **kw): return f"date_trunc('{res}', {compiler.process(element.clauses)})" -@compiles(date_trunc, "mysql") +@compiles(DateTrunc, "mysql") def mysql_date_trunc(element, compiler, **kw): pattern = { "SECOND": "%Y-%m-%d %H:%i:%S", @@ -105,7 +105,7 @@ def mysql_date_trunc(element, compiler, **kw): return compiler.process(func.date_format(dt_col, pattern)) -@compiles(date_trunc, "sqlite") +@compiles(DateTrunc, "sqlite") def sqlite_date_trunc(element, compiler, **kw): pattern = { "SECOND": "%Y-%m-%d %H:%M:%S", @@ -130,11 +130,11 @@ def substract_date(**kwargs: float) -> datetime: Column: partial[RawColumn] = partial(RawColumn, nullable=False) NullColumn: partial[RawColumn] = partial(RawColumn, nullable=True) -DateNowColumn = partial(Column, type_=DateTime(timezone=True), server_default=utcnow()) +DateNowColumn = partial(Column, type_=DateTime(timezone=True), server_default=UTCNow()) -def EnumColumn(enum_type, **kwargs): - return Column(Enum(enum_type, native_enum=False, length=16), **kwargs) +def EnumColumn(name, enum_type, **kwargs): # noqa: N802 + return Column(name, Enum(enum_type, native_enum=False, length=16), **kwargs) class EnumBackedBool(types.TypeDecorator): @@ -167,7 +167,7 @@ class SQLDBError(Exception): pass -class SQLDBUnavailable(DBUnavailable, SQLDBError): +class SQLDBUnavailableError(DBUnavailableError, SQLDBError): """Used whenever we encounter a problem with the B connection.""" @@ -324,7 +324,7 @@ async def __aenter__(self) -> Self: try: self._conn.set(await self.engine.connect().__aenter__()) except Exception as e: - raise SQLDBUnavailable( + raise SQLDBUnavailableError( f"Cannot connect to {self.__class__.__name__}" ) from e @@ -350,7 +350,7 @@ 
async def ping(self): try: await self.conn.scalar(select(1)) except OperationalError as e: - raise SQLDBUnavailable("Cannot ping the DB") from e + raise SQLDBUnavailableError("Cannot ping the DB") from e def find_time_resolution(value): @@ -394,7 +394,7 @@ def apply_search_filters(column_mapping, stmt, search): if "value" in query and isinstance(query["value"], str): resolution, value = find_time_resolution(query["value"]) if resolution: - column = date_trunc(column, time_resolution=resolution) + column = DateTrunc(column, time_resolution=resolution) query["value"] = value if query.get("values"): @@ -406,7 +406,7 @@ def apply_search_filters(column_mapping, stmt, search): f"Cannot mix different time resolutions in {query=}" ) if resolution := resolutions[0]: - column = date_trunc(column, time_resolution=resolution) + column = DateTrunc(column, time_resolution=resolution) query["values"] = values if query["operator"] == "eq": diff --git a/diracx-db/src/diracx/db/sql/utils/job.py b/diracx-db/src/diracx/db/sql/utils/job.py index 16ed5ba7..87763d45 100644 --- a/diracx-db/src/diracx/db/sql/utils/job.py +++ b/diracx-db/src/diracx/db/sql/utils/job.py @@ -49,7 +49,7 @@ async def submit_jobs_jdl(jobs: list[JobSubmissionSpec], job_db: JobDB): async with asyncio.TaskGroup() as tg: for job in jobs: original_jdl = deepcopy(job.jdl) - jobManifest = returnValueOrRaise( + job_manifest = returnValueOrRaise( checkAndAddOwner(original_jdl, job.owner, job.owner_group) ) @@ -60,13 +60,13 @@ async def submit_jobs_jdl(jobs: list[JobSubmissionSpec], job_db: JobDB): original_jdls.append( ( original_jdl, - jobManifest, + job_manifest, tg.create_task(job_db.create_job(original_jdl)), ) ) async with asyncio.TaskGroup() as tg: - for job, (original_jdl, jobManifest_, job_id_task) in zip(jobs, original_jdls): + for job, (original_jdl, job_manifest_, job_id_task) in zip(jobs, original_jdls): job_id = job_id_task.result() job_attrs = { "JobID": job_id, @@ -77,16 +77,16 @@ async def submit_jobs_jdl(jobs: list[JobSubmissionSpec], job_db: JobDB): "VO": job.vo, } - jobManifest_.setOption("JobID", job_id) + job_manifest_.setOption("JobID", job_id) # 2.- Check JDL and Prepare DIRAC JDL - jobJDL = jobManifest_.dumpAsJDL() + job_jdl = job_manifest_.dumpAsJDL() # Replace the JobID placeholder if any - if jobJDL.find("%j") != -1: - jobJDL = jobJDL.replace("%j", str(job_id)) + if job_jdl.find("%j") != -1: + job_jdl = job_jdl.replace("%j", str(job_id)) - class_ad_job = ClassAd(jobJDL) + class_ad_job = ClassAd(job_jdl) class_ad_req = ClassAd("[]") if not class_ad_job.isOK(): @@ -99,7 +99,7 @@ async def submit_jobs_jdl(jobs: list[JobSubmissionSpec], job_db: JobDB): # TODO is this even needed? 
class_ad_job.insertAttributeInt("JobID", job_id) - await job_db.checkAndPrepareJob( + await job_db.check_and_prepare_job( job_id, class_ad_job, class_ad_req, @@ -108,10 +108,10 @@ async def submit_jobs_jdl(jobs: list[JobSubmissionSpec], job_db: JobDB): job_attrs, job.vo, ) - jobJDL = createJDLWithInitialStatus( + job_jdl = createJDLWithInitialStatus( class_ad_job, class_ad_req, - job_db.jdl2DBParameters, + job_db.jdl_2_db_parameters, job_attrs, job.initial_status, job.initial_minor_status, @@ -119,11 +119,11 @@ async def submit_jobs_jdl(jobs: list[JobSubmissionSpec], job_db: JobDB): ) jobs_to_insert[job_id] = job_attrs - jdls_to_update[job_id] = jobJDL + jdls_to_update[job_id] = job_jdl if class_ad_job.lookupAttribute("InputData"): - inputData = class_ad_job.getListFromExpression("InputData") - inputdata_to_insert[job_id] = [lfn for lfn in inputData if lfn] + input_data = class_ad_job.getListFromExpression("InputData") + inputdata_to_insert[job_id] = [lfn for lfn in input_data if lfn] tg.create_task(job_db.update_job_jdls(jdls_to_update)) tg.create_task(job_db.insert_job_attributes(jobs_to_insert)) @@ -243,7 +243,7 @@ def parse_jdl(job_id, job_jdl): job_jdls = { jobid: parse_jdl(jobid, jdl) for jobid, jdl in ( - (await job_db.getJobJDLs(surviving_job_ids, original=True)).items() + (await job_db.get_job_jdls(surviving_job_ids, original=True)).items() ) } @@ -251,7 +251,7 @@ def parse_jdl(job_id, job_jdl): class_ad_job = job_jdls[job_id] class_ad_req = ClassAd("[]") try: - await job_db.checkAndPrepareJob( + await job_db.check_and_prepare_job( job_id, class_ad_job, class_ad_req, @@ -277,11 +277,11 @@ def parse_jdl(job_id, job_jdl): else: site = site_list[0] - reqJDL = class_ad_req.asJDL() - class_ad_job.insertAttributeInt("JobRequirements", reqJDL) - jobJDL = class_ad_job.asJDL() + req_jdl = class_ad_req.asJDL() + class_ad_job.insertAttributeInt("JobRequirements", req_jdl) + job_jdl = class_ad_job.asJDL() # Replace the JobID placeholder if any - jobJDL = jobJDL.replace("%j", str(job_id)) + job_jdl = job_jdl.replace("%j", str(job_id)) additional_attrs = { "Site": site, @@ -291,7 +291,7 @@ def parse_jdl(job_id, job_jdl): } # set new JDL - jdl_changes[job_id] = jobJDL + jdl_changes[job_id] = job_jdl # set new status status_changes[job_id] = { @@ -319,7 +319,7 @@ def parse_jdl(job_id, job_jdl): # BULK JDL UPDATE # DATABASE OPERATION - await job_db.setJobJDLsBulk(jdl_changes) + await job_db.set_job_jdl_bulk(jdl_changes) return { "failed": failed, @@ -412,40 +412,40 @@ async def set_job_status_bulk( for res in results: job_id = int(res["JobID"]) - currentStatus = res["Status"] - startTime = res["StartExecTime"] - endTime = res["EndExecTime"] + current_status = res["Status"] + start_time = res["StartExecTime"] + end_time = res["EndExecTime"] # If the current status is Stalled and we get an update, it should probably be "Running" - if currentStatus == JobStatus.STALLED: - currentStatus = JobStatus.RUNNING + if current_status == JobStatus.STALLED: + current_status = JobStatus.RUNNING ##################################################################################################### - statusDict = status_dicts[job_id] - # This is more precise than "LastTime". timeStamps is a sorted list of tuples... - timeStamps = sorted((float(t), s) for s, t in wms_time_stamps[job_id].items()) - lastTime = TimeUtilities.fromEpoch(timeStamps[-1][0]).replace( + status_dict = status_dicts[job_id] + # This is more precise than "LastTime". time_stamps is a sorted list of tuples... 
+ time_stamps = sorted((float(t), s) for s, t in wms_time_stamps[job_id].items()) + last_time = TimeUtilities.fromEpoch(time_stamps[-1][0]).replace( tzinfo=timezone.utc ) # Get chronological order of new updates - updateTimes = sorted(statusDict) + update_times = sorted(status_dict) - newStartTime, newEndTime = getStartAndEndTime( - startTime, endTime, updateTimes, timeStamps, statusDict + new_start_time, new_end_time = getStartAndEndTime( + start_time, end_time, update_times, time_stamps, status_dict ) job_data: dict[str, str] = {} new_status: str | None = None - if updateTimes[-1] >= lastTime: + if update_times[-1] >= last_time: new_status, new_minor, new_application = ( returnValueOrRaise( # TODO: Catch this getNewStatus( job_id, - updateTimes, - lastTime, - statusDict, - currentStatus, + update_times, + last_time, + status_dict, + current_status, force, MagicMock(), # FIXME ) @@ -467,15 +467,15 @@ async def set_job_status_bulk( # if not result["OK"]: # return result - for updTime in updateTimes: - if statusDict[updTime]["Source"].startswith("Job"): - job_data["HeartBeatTime"] = str(updTime) + for upd_time in update_times: + if status_dict[upd_time]["Source"].startswith("Job"): + job_data["HeartBeatTime"] = str(upd_time) - if not startTime and newStartTime: - job_data["StartExecTime"] = newStartTime + if not start_time and new_start_time: + job_data["StartExecTime"] = new_start_time - if not endTime and newEndTime: - job_data["EndExecTime"] = newEndTime + if not end_time and new_end_time: + job_data["EndExecTime"] = new_end_time ##################################################################################################### # delete or kill job, if we transition to DELETED or KILLED state @@ -486,20 +486,20 @@ async def set_job_status_bulk( if job_data: job_attribute_updates[job_id] = job_data - for updTime in updateTimes: - sDict = statusDict[updTime] + for upd_time in update_times: + s_dict = status_dict[upd_time] job_logging_updates.append( JobLoggingRecord( job_id=job_id, - status=sDict.get("Status", "idem"), - minor_status=sDict.get("MinorStatus", "idem"), - application_status=sDict.get("ApplicationStatus", "idem"), - date=updTime, - source=sDict.get("Source", "Unknown"), + status=s_dict.get("Status", "idem"), + minor_status=s_dict.get("MinorStatus", "idem"), + application_status=s_dict.get("ApplicationStatus", "idem"), + date=upd_time, + source=s_dict.get("Source", "Unknown"), ) ) - await job_db.setJobAttributesBulk(job_attribute_updates) + await job_db.set_job_attributes_bulk(job_attribute_updates) await remove_jobs_from_task_queue( list(deletable_killable_jobs), config, task_queue_db, background_task diff --git a/diracx-db/tests/auth/test_authorization_flow.py b/diracx-db/tests/auth/test_authorization_flow.py index 153896a9..240cd55e 100644 --- a/diracx-db/tests/auth/test_authorization_flow.py +++ b/diracx-db/tests/auth/test_authorization_flow.py @@ -49,7 +49,7 @@ async def test_insert_id_token(auth_db: AuthDB): with pytest.raises(NoResultFound): await auth_db.get_authorization_flow(code, EXPIRED) res = await auth_db.get_authorization_flow(code, MAX_VALIDITY) - assert res["id_token"] == id_token + assert res["IDToken"] == id_token # Cannot add a id_token after finishing the flow async with auth_db as auth_db: diff --git a/diracx-db/tests/auth/test_device_flow.py b/diracx-db/tests/auth/test_device_flow.py index e1cb0e6b..45093d2e 100644 --- a/diracx-db/tests/auth/test_device_flow.py +++ b/diracx-db/tests/auth/test_device_flow.py @@ -107,8 +107,8 @@ async def 
test_device_flow_lookup(auth_db: AuthDB, monkeypatch): await auth_db.get_device_flow(device_code1, EXPIRED) res = await auth_db.get_device_flow(device_code1, MAX_VALIDITY) - assert res["user_code"] == user_code1 - assert res["id_token"] == {"token": "mytoken"} + assert res["UserCode"] == user_code1 + assert res["IDToken"] == {"token": "mytoken"} # cannot get it a second time async with auth_db as auth_db: @@ -147,4 +147,4 @@ async def test_device_flow_insert_id_token(auth_db: AuthDB): async with auth_db as auth_db: res = await auth_db.get_device_flow(device_code, MAX_VALIDITY) - assert res["id_token"] == id_token + assert res["IDToken"] == id_token diff --git a/diracx-db/tests/auth/test_refresh_token.py b/diracx-db/tests/auth/test_refresh_token.py index 2b0cb4f0..2d72cef0 100644 --- a/diracx-db/tests/auth/test_refresh_token.py +++ b/diracx-db/tests/auth/test_refresh_token.py @@ -55,16 +55,21 @@ async def test_get(auth_db: AuthDB): ) # Enrich the dict with the generated refresh token attributes - refresh_token_details["jti"] = jti - refresh_token_details["status"] = RefreshTokenStatus.CREATED - refresh_token_details["creation_time"] = creation_time + expected_refresh_token = { + "Sub": refresh_token_details["sub"], + "PreferredUsername": refresh_token_details["preferred_username"], + "Scope": refresh_token_details["scope"], + "JTI": jti, + "Status": RefreshTokenStatus.CREATED, + "CreationTime": creation_time, + } # Get refresh token details async with auth_db as auth_db: result = await auth_db.get_refresh_token(jti) # Make sure they are identical - assert result == refresh_token_details + assert result == expected_refresh_token async def test_get_user_refresh_tokens(auth_db: AuthDB): @@ -96,11 +101,11 @@ async def test_get_user_refresh_tokens(auth_db: AuthDB): # And check that the subject value corresponds to the user's subject assert len(refresh_tokens_user1) == 2 for refresh_token in refresh_tokens_user1: - assert refresh_token["sub"] == sub1 + assert refresh_token["Sub"] == sub1 assert len(refresh_tokens_user2) == 1 for refresh_token in refresh_tokens_user2: - assert refresh_token["sub"] == sub2 + assert refresh_token["Sub"] == sub2 async def test_revoke(auth_db: AuthDB): @@ -121,7 +126,7 @@ async def test_revoke(auth_db: AuthDB): async with auth_db as auth_db: refresh_token_details = await auth_db.get_refresh_token(jti) - assert refresh_token_details["status"] == RefreshTokenStatus.REVOKED + assert refresh_token_details["Status"] == RefreshTokenStatus.REVOKED async def test_revoke_user_refresh_tokens(auth_db: AuthDB): @@ -194,7 +199,7 @@ async def test_revoke_and_get_user_refresh_tokens(auth_db: AuthDB): # And check that the subject value corresponds to the user's subject assert len(refresh_tokens_user) == nb_tokens for refresh_token in refresh_tokens_user: - assert refresh_token["sub"] == sub + assert refresh_token["Sub"] == sub # Revoke one of the tokens async with auth_db as auth_db: @@ -208,8 +213,8 @@ async def test_revoke_and_get_user_refresh_tokens(auth_db: AuthDB): # And check that the subject value corresponds to the user's subject assert len(refresh_tokens_user) == nb_tokens - 1 for refresh_token in refresh_tokens_user: - assert refresh_token["sub"] == sub - assert refresh_token["jti"] != jtis[0] + assert refresh_token["Sub"] == sub + assert refresh_token["JTI"] != jtis[0] async def test_get_refresh_tokens(auth_db: AuthDB): diff --git a/diracx-db/tests/jobs/test_jobDB.py b/diracx-db/tests/jobs/test_job_db.py similarity index 99% rename from diracx-db/tests/jobs/test_jobDB.py 
rename to diracx-db/tests/jobs/test_job_db.py index aa17035b..060bd7d8 100644 --- a/diracx-db/tests/jobs/test_jobDB.py +++ b/diracx-db/tests/jobs/test_job_db.py @@ -2,7 +2,7 @@ import pytest -from diracx.core.exceptions import InvalidQueryError, JobNotFound +from diracx.core.exceptions import InvalidQueryError, JobNotFoundError from diracx.core.models import ( ScalarSearchOperator, ScalarSearchSpec, @@ -333,5 +333,5 @@ async def test_search_pagination(job_db): async def test_set_job_command_invalid_job_id(job_db: JobDB): """Test that setting a command for a non-existent job raises JobNotFound.""" async with job_db as job_db: - with pytest.raises(JobNotFound): + with pytest.raises(JobNotFoundError): await job_db.set_job_command(123456, "test_command") diff --git a/diracx-db/tests/jobs/test_jobLoggingDB.py b/diracx-db/tests/jobs/test_job_logging_db.py similarity index 100% rename from diracx-db/tests/jobs/test_jobLoggingDB.py rename to diracx-db/tests/jobs/test_job_logging_db.py diff --git a/diracx-db/tests/jobs/test_sandbox_metadata.py b/diracx-db/tests/jobs/test_sandbox_metadata.py index 06149189..bcb1c2cc 100644 --- a/diracx-db/tests/jobs/test_sandbox_metadata.py +++ b/diracx-db/tests/jobs/test_sandbox_metadata.py @@ -9,7 +9,7 @@ from diracx.core.models import SandboxInfo, UserInfo from diracx.db.sql.sandbox_metadata.db import SandboxMetadataDB -from diracx.db.sql.sandbox_metadata.schema import sb_EntityMapping, sb_SandBoxes +from diracx.db.sql.sandbox_metadata.schema import SandBoxes, SBEntityMapping @pytest.fixture @@ -89,7 +89,7 @@ async def _dump_db( """ async with sandbox_metadata_db: stmt = sqlalchemy.select( - sb_SandBoxes.SEPFN, sb_SandBoxes.OwnerId, sb_SandBoxes.LastAccessTime + SandBoxes.SEPFN, SandBoxes.OwnerId, SandBoxes.LastAccessTime ) res = await sandbox_metadata_db.conn.execute(stmt) return {row.SEPFN: (row.OwnerId, row.LastAccessTime) for row in res} @@ -109,7 +109,7 @@ async def test_assign_and_unsassign_sandbox_to_jobs( await sandbox_metadata_db.insert_sandbox(sandbox_se, user_info, pfn, 100) async with sandbox_metadata_db: - stmt = sqlalchemy.select(sb_SandBoxes.SBId, sb_SandBoxes.SEPFN) + stmt = sqlalchemy.select(SandBoxes.SBId, SandBoxes.SEPFN) res = await sandbox_metadata_db.conn.execute(stmt) db_contents = {row.SEPFN: row.SBId for row in res} sb_id_1 = db_contents[pfn] @@ -120,7 +120,7 @@ async def test_assign_and_unsassign_sandbox_to_jobs( # Check there is no mapping async with sandbox_metadata_db: stmt = sqlalchemy.select( - sb_EntityMapping.SBId, sb_EntityMapping.EntityId, sb_EntityMapping.Type + SBEntityMapping.SBId, SBEntityMapping.EntityId, SBEntityMapping.Type ) res = await sandbox_metadata_db.conn.execute(stmt) db_contents = {row.SBId: (row.EntityId, row.Type) for row in res} @@ -134,7 +134,7 @@ async def test_assign_and_unsassign_sandbox_to_jobs( # Check if sandbox and job are mapped async with sandbox_metadata_db: stmt = sqlalchemy.select( - sb_EntityMapping.SBId, sb_EntityMapping.EntityId, sb_EntityMapping.Type + SBEntityMapping.SBId, SBEntityMapping.EntityId, SBEntityMapping.Type ) res = await sandbox_metadata_db.conn.execute(stmt) db_contents = {row.SBId: (row.EntityId, row.Type) for row in res} @@ -144,7 +144,7 @@ async def test_assign_and_unsassign_sandbox_to_jobs( assert sb_type == "Output" async with sandbox_metadata_db: - stmt = sqlalchemy.select(sb_SandBoxes.SBId, sb_SandBoxes.SEPFN) + stmt = sqlalchemy.select(SandBoxes.SBId, SandBoxes.SEPFN) res = await sandbox_metadata_db.conn.execute(stmt) db_contents = {row.SEPFN: row.SBId for row in res} 
sb_id_1 = db_contents[pfn] @@ -158,8 +158,8 @@ async def test_assign_and_unsassign_sandbox_to_jobs( # Entity should not exists anymore async with sandbox_metadata_db: - stmt = sqlalchemy.select(sb_EntityMapping.SBId).where( - sb_EntityMapping.EntityId == entity_id_1 + stmt = sqlalchemy.select(SBEntityMapping.SBId).where( + SBEntityMapping.EntityId == entity_id_1 ) res = await sandbox_metadata_db.conn.execute(stmt) entity_sb_id = [row.SBId for row in res] @@ -170,7 +170,7 @@ async def test_assign_and_unsassign_sandbox_to_jobs( assert await sandbox_metadata_db.sandbox_is_assigned(pfn, sandbox_se) is False # Check the mapping has been deleted async with sandbox_metadata_db: - stmt = sqlalchemy.select(sb_EntityMapping.SBId) + stmt = sqlalchemy.select(SBEntityMapping.SBId) res = await sandbox_metadata_db.conn.execute(stmt) res_sb_id = [row.SBId for row in res] assert sb_id_1 not in res_sb_id diff --git a/diracx-db/tests/opensearch/test_connection.py b/diracx-db/tests/opensearch/test_connection.py index 4b2e3877..1e61760f 100644 --- a/diracx-db/tests/opensearch/test_connection.py +++ b/diracx-db/tests/opensearch/test_connection.py @@ -2,7 +2,7 @@ import pytest -from diracx.db.os.utils import OpenSearchDBUnavailable +from diracx.db.os.utils import OpenSearchDBUnavailableError from diracx.testing.osdb import OPENSEARCH_PORT, DummyOSDB, require_port_availability @@ -10,7 +10,7 @@ async def _ensure_db_unavailable(db: DummyOSDB): """Helper function which raises an exception if we manage to connect to the DB.""" async with db.client_context(): async with db: - with pytest.raises(OpenSearchDBUnavailable): + with pytest.raises(OpenSearchDBUnavailableError): await db.ping() diff --git a/diracx-db/tests/pilot_agents/test_pilotAgentsDB.py b/diracx-db/tests/pilot_agents/test_pilot_agents_db.py similarity index 100% rename from diracx-db/tests/pilot_agents/test_pilotAgentsDB.py rename to diracx-db/tests/pilot_agents/test_pilot_agents_db.py diff --git a/diracx-db/tests/test_dummyDB.py b/diracx-db/tests/test_dummy_db.py similarity index 90% rename from diracx-db/tests/test_dummyDB.py rename to diracx-db/tests/test_dummy_db.py index 90ed15d0..023899ce 100644 --- a/diracx-db/tests/test_dummyDB.py +++ b/diracx-db/tests/test_dummy_db.py @@ -7,7 +7,7 @@ from diracx.core.exceptions import InvalidQueryError from diracx.db.sql.dummy.db import DummyDB -from diracx.db.sql.utils import SQLDBUnavailable +from diracx.db.sql.utils import SQLDBUnavailableError # Each DB test class must defined a fixture looking like this one # It allows to get an instance of an in memory DB, @@ -27,7 +27,7 @@ async def test_insert_and_summary(dummy_db: DummyDB): # So it is important to write test this way async with dummy_db as dummy_db: # First we check that the DB is empty - result = await dummy_db.summary(["model"], []) + result = await dummy_db.summary(["Model"], []) assert not result # Now we add some data in the DB @@ -44,14 +44,14 @@ async def test_insert_and_summary(dummy_db: DummyDB): # Check that there are now 10 cars assigned to a single driver async with dummy_db as dummy_db: - result = await dummy_db.summary(["ownerID"], []) + result = await dummy_db.summary(["OwnerID"], []) assert result[0]["count"] == 10 # Test the selection async with dummy_db as dummy_db: result = await dummy_db.summary( - ["ownerID"], [{"parameter": "model", "operator": "eq", "value": "model_1"}] + ["OwnerID"], [{"parameter": "Model", "operator": "eq", "value": "model_1"}] ) assert result[0]["count"] == 1 @@ -59,7 +59,7 @@ async def 
test_insert_and_summary(dummy_db: DummyDB): async with dummy_db as dummy_db: with pytest.raises(InvalidQueryError): result = await dummy_db.summary( - ["ownerID"], + ["OwnerID"], [ { "parameter": "model", @@ -73,7 +73,7 @@ async def test_insert_and_summary(dummy_db: DummyDB): async def test_bad_connection(): dummy_db = DummyDB("mysql+aiomysql://tata:yoyo@db.invalid:3306/name") async with dummy_db.engine_context(): - with pytest.raises(SQLDBUnavailable): + with pytest.raises(SQLDBUnavailableError): async with dummy_db: dummy_db.ping() @@ -93,7 +93,7 @@ async def test_successful_transaction(dummy_db): assert dummy_db.conn # First we check that the DB is empty - result = await dummy_db.summary(["ownerID"], []) + result = await dummy_db.summary(["OwnerID"], []) assert not result # Add data @@ -104,7 +104,7 @@ async def test_successful_transaction(dummy_db): ) assert result - result = await dummy_db.summary(["ownerID"], []) + result = await dummy_db.summary(["OwnerID"], []) assert result[0]["count"] == 10 # The connection is closed when the context manager is exited @@ -114,7 +114,7 @@ async def test_successful_transaction(dummy_db): # Start a new transaction # The previous data should still be there because the transaction was committed (successful) async with dummy_db as dummy_db: - result = await dummy_db.summary(["ownerID"], []) + result = await dummy_db.summary(["OwnerID"], []) assert result[0]["count"] == 10 @@ -134,7 +134,7 @@ async def test_failed_transaction(dummy_db): assert dummy_db.conn # First we check that the DB is empty - result = await dummy_db.summary(["ownerID"], []) + result = await dummy_db.summary(["OwnerID"], []) assert not result # Add data @@ -159,7 +159,7 @@ async def test_failed_transaction(dummy_db): # Start a new transaction # The previous data should not be there because the transaction was rolled back (failed) async with dummy_db as dummy_db: - result = await dummy_db.summary(["ownerID"], []) + result = await dummy_db.summary(["OwnerID"], []) assert not result @@ -203,7 +203,7 @@ async def test_successful_with_exception_transaction(dummy_db): assert dummy_db.conn # First we check that the DB is empty - result = await dummy_db.summary(["ownerID"], []) + result = await dummy_db.summary(["OwnerID"], []) assert not result # Add data @@ -217,7 +217,7 @@ async def test_successful_with_exception_transaction(dummy_db): ) assert result - result = await dummy_db.summary(["ownerID"], []) + result = await dummy_db.summary(["OwnerID"], []) assert result[0]["count"] == 10 # This will raise an exception but the transaction will be rolled back @@ -231,7 +231,7 @@ async def test_successful_with_exception_transaction(dummy_db): # Start a new transaction # The previous data should not be there because the transaction was rolled back (failed) async with dummy_db as dummy_db: - result = await dummy_db.summary(["ownerID"], []) + result = await dummy_db.summary(["OwnerID"], []) assert not result # Start a new transaction, this time we commit it manually @@ -240,7 +240,7 @@ async def test_successful_with_exception_transaction(dummy_db): assert dummy_db.conn # First we check that the DB is empty - result = await dummy_db.summary(["ownerID"], []) + result = await dummy_db.summary(["OwnerID"], []) assert not result # Add data @@ -254,7 +254,7 @@ async def test_successful_with_exception_transaction(dummy_db): ) assert result - result = await dummy_db.summary(["ownerID"], []) + result = await dummy_db.summary(["OwnerID"], []) assert result[0]["count"] == 10 # Manually commit the transaction, 
and then raise an exception @@ -271,5 +271,5 @@ async def test_successful_with_exception_transaction(dummy_db): # Start a new transaction # The previous data should be there because the transaction was committed before the exception async with dummy_db as dummy_db: - result = await dummy_db.summary(["ownerID"], []) + result = await dummy_db.summary(["OwnerID"], []) assert result[0]["count"] == 10 diff --git a/diracx-routers/src/diracx/routers/__init__.py b/diracx-routers/src/diracx/routers/__init__.py index b3725b2f..d17fbd8f 100644 --- a/diracx-routers/src/diracx/routers/__init__.py +++ b/diracx-routers/src/diracx/routers/__init__.py @@ -34,11 +34,11 @@ from uvicorn.logging import AccessFormatter, DefaultFormatter from diracx.core.config import ConfigSource -from diracx.core.exceptions import DiracError, DiracHttpResponse +from diracx.core.exceptions import DiracError, DiracHttpResponseError from diracx.core.extensions import select_from_extension from diracx.core.settings import ServiceSettingsBase from diracx.core.utils import dotenv_files_from_environment -from diracx.db.exceptions import DBUnavailable +from diracx.db.exceptions import DBUnavailableError from diracx.db.os.utils import BaseOSDB from diracx.db.sql.utils import BaseSQLDB from diracx.routers.access_policies import BaseAccessPolicy, check_permissions @@ -291,10 +291,10 @@ def create_app_inner( handler_signature = Callable[[Request, Exception], Response | Awaitable[Response]] app.add_exception_handler(DiracError, cast(handler_signature, dirac_error_handler)) app.add_exception_handler( - DiracHttpResponse, cast(handler_signature, http_response_handler) + DiracHttpResponseError, cast(handler_signature, http_response_handler) ) app.add_exception_handler( - DBUnavailable, cast(handler_signature, route_unavailable_error_hander) + DBUnavailableError, cast(handler_signature, route_unavailable_error_hander) ) # TODO: remove the CORSMiddleware once we figure out how to launch @@ -393,11 +393,11 @@ def dirac_error_handler(request: Request, exc: DiracError) -> Response: ) -def http_response_handler(request: Request, exc: DiracHttpResponse) -> Response: +def http_response_handler(request: Request, exc: DiracHttpResponseError) -> Response: return JSONResponse(status_code=exc.status_code, content=exc.data) -def route_unavailable_error_hander(request: Request, exc: DBUnavailable): +def route_unavailable_error_hander(request: Request, exc: DBUnavailableError): return JSONResponse( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, headers={"Retry-After": "10"}, @@ -435,7 +435,7 @@ async def is_db_unavailable(db: BaseSQLDB | BaseOSDB) -> str: await db.ping() _db_alive_cache[db] = "" - except DBUnavailable as e: + except DBUnavailableError as e: _db_alive_cache[db] = e.args[0] return _db_alive_cache[db] @@ -448,7 +448,7 @@ async def db_transaction(db: T2) -> AsyncGenerator[T2]: async with db: # Check whether the connection still works before executing the query if reason := await is_db_unavailable(db): - raise DBUnavailable(reason) + raise DBUnavailableError(reason) yield db diff --git a/diracx-routers/src/diracx/routers/auth/management.py b/diracx-routers/src/diracx/routers/auth/management.py index e8b59356..7bd7c1b9 100644 --- a/diracx-routers/src/diracx/routers/auth/management.py +++ b/diracx-routers/src/diracx/routers/auth/management.py @@ -66,7 +66,7 @@ async def revoke_refresh_token( detail="JTI provided does not exist", ) - if PROXY_MANAGEMENT not in user_info.properties and user_info.sub != res["sub"]: + if PROXY_MANAGEMENT not in 
user_info.properties and user_info.sub != res["Sub"]: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Cannot revoke a refresh token owned by someone else", diff --git a/diracx-routers/src/diracx/routers/auth/token.py b/diracx-routers/src/diracx/routers/auth/token.py index 14103add..d21416ea 100644 --- a/diracx-routers/src/diracx/routers/auth/token.py +++ b/diracx-routers/src/diracx/routers/auth/token.py @@ -12,7 +12,7 @@ from fastapi import Depends, Form, Header, HTTPException, status from diracx.core.exceptions import ( - DiracHttpResponse, + DiracHttpResponseError, ExpiredFlowError, PendingAuthorizationError, ) @@ -120,27 +120,27 @@ async def get_oidc_token_info_from_device_flow( device_code, settings.device_flow_expiration_seconds ) except PendingAuthorizationError as e: - raise DiracHttpResponse( + raise DiracHttpResponseError( status.HTTP_400_BAD_REQUEST, {"error": "authorization_pending"} ) from e except ExpiredFlowError as e: - raise DiracHttpResponse( + raise DiracHttpResponseError( status.HTTP_401_UNAUTHORIZED, {"error": "expired_token"} ) from e - # raise DiracHttpResponse(status.HTTP_400_BAD_REQUEST, {"error": "slow_down"}) - # raise DiracHttpResponse(status.HTTP_400_BAD_REQUEST, {"error": "expired_token"}) + # raise DiracHttpResponseError(status.HTTP_400_BAD_REQUEST, {"error": "slow_down"}) + # raise DiracHttpResponseError(status.HTTP_400_BAD_REQUEST, {"error": "expired_token"}) - if info["client_id"] != client_id: + if info["ClientID"] != client_id: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Bad client_id", ) - oidc_token_info = info["id_token"] - scope = info["scope"] + oidc_token_info = info["IDToken"] + scope = info["Scope"] # TODO: use HTTPException while still respecting the standard format # required by the RFC - if info["status"] != FlowStatus.READY: + if info["Status"] != FlowStatus.READY: # That should never ever happen raise NotImplementedError(f"Unexpected flow status {info['status']!r}") return (oidc_token_info, scope) @@ -159,12 +159,12 @@ async def get_oidc_token_info_from_authorization_flow( info = await auth_db.get_authorization_flow( code, settings.authorization_flow_expiration_seconds ) - if redirect_uri != info["redirect_uri"]: + if redirect_uri != info["RedirectURI"]: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid redirect_uri", ) - if client_id != info["client_id"]: + if client_id != info["ClientID"]: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Bad client_id", @@ -184,18 +184,18 @@ async def get_oidc_token_info_from_authorization_flow( detail="Malformed code_verifier", ) from e - if code_challenge != info["code_challenge"]: + if code_challenge != info["CodeChallenge"]: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid code_challenge", ) - oidc_token_info = info["id_token"] - scope = info["scope"] + oidc_token_info = info["IDToken"] + scope = info["Scope"] # TODO: use HTTPException while still respecting the standard format # required by the RFC - if info["status"] != FlowStatus.READY: + if info["Status"] != FlowStatus.READY: # That should never ever happen raise NotImplementedError(f"Unexpected flow status {info['status']!r}") @@ -214,7 +214,7 @@ async def get_oidc_token_info_from_refresh_flow( # Get some useful user information from the refresh token entry in the DB refresh_token_attributes = await auth_db.get_refresh_token(jti) - sub = refresh_token_attributes["sub"] + sub = refresh_token_attributes["Sub"] # Check if the 
refresh token was obtained from the legacy_exchange endpoint # If it is the case, we bypass the refresh token rotation mechanism @@ -224,7 +224,7 @@ async def get_oidc_token_info_from_refresh_flow( # This might indicate that a potential attacker try to impersonate someone # In such case, all the refresh tokens bound to a given user (subject) should be revoked # Forcing the user to reauthenticate interactively through an authorization/device flow (recommended practice) - if refresh_token_attributes["status"] == RefreshTokenStatus.REVOKED: + if refresh_token_attributes["Status"] == RefreshTokenStatus.REVOKED: # Revoke all the user tokens from the subject await auth_db.revoke_user_refresh_tokens(sub) @@ -246,9 +246,9 @@ async def get_oidc_token_info_from_refresh_flow( # The sub attribute coming from the DB contains the VO name # We need to remove it as if it were coming from an ID token from an external IdP "sub": sub.split(":", 1)[1], - "preferred_username": refresh_token_attributes["preferred_username"], + "preferred_username": refresh_token_attributes["PreferredUsername"], } - scope = refresh_token_attributes["scope"] + scope = refresh_token_attributes["Scope"] return (oidc_token_info, scope, legacy_exchange) diff --git a/diracx-routers/src/diracx/routers/auth/utils.py b/diracx-routers/src/diracx/routers/auth/utils.py index 3b881361..7ca8b523 100644 --- a/diracx-routers/src/diracx/routers/auth/utils.py +++ b/diracx-routers/src/diracx/routers/auth/utils.py @@ -262,7 +262,7 @@ async def initiate_authorization_flow_with_iam( state | {"vo": vo, "code_verifier": code_verifier}, cipher_suite ) - urlParams = [ + url_params = [ "response_type=code", f"code_challenge={code_challenge}", "code_challenge_method=S256", @@ -271,7 +271,7 @@ async def initiate_authorization_flow_with_iam( "scope=openid%20profile", f"state={encrypted_state}", ] - authorization_flow_url = f"{authorization_endpoint}?{'&'.join(urlParams)}" + authorization_flow_url = f"{authorization_endpoint}?{'&'.join(url_params)}" return authorization_flow_url diff --git a/diracx-routers/src/diracx/routers/jobs/submission.py b/diracx-routers/src/diracx/routers/jobs/submission.py index 5f953fa3..4cec9bed 100644 --- a/diracx-routers/src/diracx/routers/jobs/submission.py +++ b/diracx-routers/src/diracx/routers/jobs/submission.py @@ -94,25 +94,25 @@ async def submit_bulk_jdl_jobs( if len(job_definitions) == 1: # Check if the job is a parametric one - jobClassAd = ClassAd(job_definitions[0]) - result = getParameterVectorLength(jobClassAd) + job_class_ad = ClassAd(job_definitions[0]) + result = getParameterVectorLength(job_class_ad) if not result["OK"]: # FIXME dont do this print("Issue with getParameterVectorLength", result["Message"]) return result - nJobs = result["Value"] - parametricJob = False - if nJobs is not None and nJobs > 0: + n_jobs = result["Value"] + parametric_job = False + if n_jobs is not None and n_jobs > 0: # if we are here, then jobDesc was the description of a parametric job. So we start unpacking - parametricJob = True - result = generateParametricJobs(jobClassAd) + parametric_job = True + result = generateParametricJobs(job_class_ad) if not result["OK"]: # FIXME why? return result - jobDescList = result["Value"] + job_desc_list = result["Value"] else: # if we are here, then jobDesc was the description of a single job. 
- jobDescList = job_definitions + job_desc_list = job_definitions else: # if we are here, then jobDesc is a list of JDLs # we need to check that none of them is a parametric @@ -128,12 +128,12 @@ async def submit_bulk_jdl_jobs( detail="You cannot submit parametric jobs in a bulk fashion", ) - jobDescList = job_definitions - # parametricJob = True - parametricJob = False + job_desc_list = job_definitions + # parametric_job = True + parametric_job = False # TODO: make the max number of jobs configurable in the CS - if len(jobDescList) > MAX_PARAMETRIC_JOBS: + if len(job_desc_list) > MAX_PARAMETRIC_JOBS: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST, detail=f"Normal user cannot submit more than {MAX_PARAMETRIC_JOBS} jobs at once", @@ -141,12 +141,12 @@ async def submit_bulk_jdl_jobs( result = [] - if parametricJob: - initialStatus = JobStatus.SUBMITTING - initialMinorStatus = "Bulk transaction confirmation" + if parametric_job: + initial_status = JobStatus.SUBMITTING + initial_minor_status = "Bulk transaction confirmation" else: - initialStatus = JobStatus.RECEIVED - initialMinorStatus = "Job accepted" + initial_status = JobStatus.RECEIVED + initial_minor_status = "Job accepted" submitted_job_ids = await submit_jobs_jdl( [ @@ -154,11 +154,11 @@ async def submit_bulk_jdl_jobs( jdl=jdl, owner=user_info.preferred_username, owner_group=user_info.dirac_group, - initial_status=initialStatus, - initial_minor_status=initialMinorStatus, + initial_status=initial_status, + initial_minor_status=initial_minor_status, vo=user_info.vo, ) - for jdl in jobDescList + for jdl in job_desc_list ], job_db=job_db, ) @@ -172,8 +172,8 @@ async def submit_bulk_jdl_jobs( [ JobLoggingRecord( job_id=int(job_id), - status=initialStatus, - minor_status=initialMinorStatus, + status=initial_status, + minor_status=initial_minor_status, application_status="Unknown", date=job_created_time, source="JobManager", @@ -182,14 +182,14 @@ async def submit_bulk_jdl_jobs( ] ) - # if not parametricJob: + # if not parametric_job: # self.__sendJobsToOptimizationMind(submitted_job_ids) return [ InsertedJob( JobID=job_id, - Status=initialStatus, - MinorStatus=initialMinorStatus, + Status=initial_status, + MinorStatus=initial_minor_status, TimeStamp=job_created_time, ) for job_id in submitted_job_ids diff --git a/diracx-routers/tests/test_job_manager.py b/diracx-routers/tests/test_job_manager.py index 59b1f6c0..4a81d7e9 100644 --- a/diracx-routers/tests/test_job_manager.py +++ b/diracx-routers/tests/test_job_manager.py @@ -482,8 +482,8 @@ async def test_get_job_status_history( assert r.json()[0]["MinorStatus"] == "Job accepted" assert r.json()[0]["ApplicationStatus"] == "Unknown" - NEW_STATUS = JobStatus.CHECKING.value - NEW_MINOR_STATUS = "JobPath" + new_status = JobStatus.CHECKING.value + new_minor_status = "JobPath" before = datetime.now(timezone.utc) r = normal_user_client.patch( @@ -491,8 +491,8 @@ async def test_get_job_status_history( json={ valid_job_id: { datetime.now(tz=timezone.utc).isoformat(): { - "Status": NEW_STATUS, - "MinorStatus": NEW_MINOR_STATUS, + "Status": new_status, + "MinorStatus": new_minor_status, } } }, @@ -501,8 +501,8 @@ async def test_get_job_status_history( after = datetime.now(timezone.utc) assert r.status_code == 200, r.json() - assert r.json()["success"][str(valid_job_id)]["Status"] == NEW_STATUS - assert r.json()["success"][str(valid_job_id)]["MinorStatus"] == NEW_MINOR_STATUS + assert r.json()["success"][str(valid_job_id)]["Status"] == new_status + assert 
r.json()["success"][str(valid_job_id)]["MinorStatus"] == new_minor_status # Act r = normal_user_client.post( @@ -588,15 +588,15 @@ def test_set_job_status(normal_user_client: TestClient, valid_job_id: int): assert j["ApplicationStatus"] == "Unknown" # Act - NEW_STATUS = JobStatus.CHECKING.value - NEW_MINOR_STATUS = "JobPath" + new_status = JobStatus.CHECKING.value + new_minor_status = "JobPath" r = normal_user_client.patch( "/api/jobs/status", json={ valid_job_id: { datetime.now(tz=timezone.utc).isoformat(): { - "Status": NEW_STATUS, - "MinorStatus": NEW_MINOR_STATUS, + "Status": new_status, + "MinorStatus": new_minor_status, } } }, @@ -604,8 +604,8 @@ def test_set_job_status(normal_user_client: TestClient, valid_job_id: int): # Assert assert r.status_code == 200, r.json() - assert r.json()["success"][str(valid_job_id)]["Status"] == NEW_STATUS - assert r.json()["success"][str(valid_job_id)]["MinorStatus"] == NEW_MINOR_STATUS + assert r.json()["success"][str(valid_job_id)]["Status"] == new_status + assert r.json()["success"][str(valid_job_id)]["MinorStatus"] == new_minor_status r = normal_user_client.post( "/api/jobs/search", @@ -621,8 +621,8 @@ def test_set_job_status(normal_user_client: TestClient, valid_job_id: int): ) assert r.status_code == 200, r.json() assert r.json()[0]["JobID"] == valid_job_id - assert r.json()[0]["Status"] == NEW_STATUS - assert r.json()[0]["MinorStatus"] == NEW_MINOR_STATUS + assert r.json()[0]["Status"] == new_status + assert r.json()[0]["MinorStatus"] == new_minor_status assert r.json()[0]["ApplicationStatus"] == "Unknown" @@ -700,15 +700,15 @@ def test_set_job_status_cannot_make_impossible_transitions( assert r.json()[0]["ApplicationStatus"] == "Unknown" # Act - NEW_STATUS = JobStatus.RUNNING.value - NEW_MINOR_STATUS = "JobPath" + new_status = JobStatus.RUNNING.value + new_minor_status = "JobPath" r = normal_user_client.patch( "/api/jobs/status", json={ valid_job_id: { datetime.now(tz=timezone.utc).isoformat(): { - "Status": NEW_STATUS, - "MinorStatus": NEW_MINOR_STATUS, + "Status": new_status, + "MinorStatus": new_minor_status, } } }, @@ -718,8 +718,8 @@ def test_set_job_status_cannot_make_impossible_transitions( assert r.status_code == 200, r.json() success = r.json()["success"] assert len(success) == 1, r.json() - assert success[str(valid_job_id)]["Status"] != NEW_STATUS - assert success[str(valid_job_id)]["MinorStatus"] == NEW_MINOR_STATUS + assert success[str(valid_job_id)]["Status"] != new_status + assert success[str(valid_job_id)]["MinorStatus"] == new_minor_status r = normal_user_client.post( "/api/jobs/search", @@ -734,8 +734,8 @@ def test_set_job_status_cannot_make_impossible_transitions( }, ) assert r.status_code == 200, r.json() - assert r.json()[0]["Status"] != NEW_STATUS - assert r.json()[0]["MinorStatus"] == NEW_MINOR_STATUS + assert r.json()[0]["Status"] != new_status + assert r.json()[0]["MinorStatus"] == new_minor_status assert r.json()[0]["ApplicationStatus"] == "Unknown" @@ -760,15 +760,15 @@ def test_set_job_status_force(normal_user_client: TestClient, valid_job_id: int) assert r.json()[0]["ApplicationStatus"] == "Unknown" # Act - NEW_STATUS = JobStatus.RUNNING.value - NEW_MINOR_STATUS = "JobPath" + new_status = JobStatus.RUNNING.value + new_minor_status = "JobPath" r = normal_user_client.patch( "/api/jobs/status", json={ valid_job_id: { datetime.now(tz=timezone.utc).isoformat(): { - "Status": NEW_STATUS, - "MinorStatus": NEW_MINOR_STATUS, + "Status": new_status, + "MinorStatus": new_minor_status, } } }, @@ -779,8 +779,8 @@ def 
test_set_job_status_force(normal_user_client: TestClient, valid_job_id: int) # Assert assert r.status_code == 200, r.json() - assert success[str(valid_job_id)]["Status"] == NEW_STATUS - assert success[str(valid_job_id)]["MinorStatus"] == NEW_MINOR_STATUS + assert success[str(valid_job_id)]["Status"] == new_status + assert success[str(valid_job_id)]["MinorStatus"] == new_minor_status r = normal_user_client.post( "/api/jobs/search", @@ -796,8 +796,8 @@ def test_set_job_status_force(normal_user_client: TestClient, valid_job_id: int) ) assert r.status_code == 200, r.json() assert r.json()[0]["JobID"] == valid_job_id - assert r.json()[0]["Status"] == NEW_STATUS - assert r.json()[0]["MinorStatus"] == NEW_MINOR_STATUS + assert r.json()[0]["Status"] == new_status + assert r.json()[0]["MinorStatus"] == new_minor_status assert r.json()[0]["ApplicationStatus"] == "Unknown" @@ -822,15 +822,15 @@ def test_set_job_status_bulk(normal_user_client: TestClient, valid_job_ids): assert r.json()[0]["MinorStatus"] == "Bulk transaction confirmation" # Act - NEW_STATUS = JobStatus.CHECKING.value - NEW_MINOR_STATUS = "JobPath" + new_status = JobStatus.CHECKING.value + new_minor_status = "JobPath" r = normal_user_client.patch( "/api/jobs/status", json={ job_id: { datetime.now(timezone.utc).isoformat(): { - "Status": NEW_STATUS, - "MinorStatus": NEW_MINOR_STATUS, + "Status": new_status, + "MinorStatus": new_minor_status, } } for job_id in valid_job_ids @@ -842,8 +842,8 @@ def test_set_job_status_bulk(normal_user_client: TestClient, valid_job_ids): # Assert assert r.status_code == 200, r.json() for job_id in valid_job_ids: - assert success[str(job_id)]["Status"] == NEW_STATUS - assert success[str(job_id)]["MinorStatus"] == NEW_MINOR_STATUS + assert success[str(job_id)]["Status"] == new_status + assert success[str(job_id)]["MinorStatus"] == new_minor_status r_get = normal_user_client.post( "/api/jobs/search", @@ -859,8 +859,8 @@ def test_set_job_status_bulk(normal_user_client: TestClient, valid_job_ids): ) assert r_get.status_code == 200, r_get.json() assert r_get.json()[0]["JobID"] == job_id - assert r_get.json()[0]["Status"] == NEW_STATUS - assert r_get.json()[0]["MinorStatus"] == NEW_MINOR_STATUS + assert r_get.json()[0]["Status"] == new_status + assert r_get.json()[0]["MinorStatus"] == new_minor_status assert r_get.json()[0]["ApplicationStatus"] == "Unknown" diff --git a/diracx-testing/src/diracx/testing/__init__.py b/diracx-testing/src/diracx/testing/__init__.py index b73696b1..59ebca1d 100644 --- a/diracx-testing/src/diracx/testing/__init__.py +++ b/diracx-testing/src/diracx/testing/__init__.py @@ -170,11 +170,16 @@ class AlwaysAllowAccessPolicy(BaseAccessPolicy): """Dummy access policy.""" async def policy( - policy_name: str, user_info: AuthorizedUserInfo, /, **kwargs + policy_name: str, # noqa: N805 + user_info: AuthorizedUserInfo, + /, + **kwargs, ): pass - def enrich_tokens(access_payload: dict, refresh_payload: dict): + def enrich_tokens( + access_payload: dict, refresh_payload: dict # noqa: N805 + ): return {"PolicySpecific": "OpenAccessForTest"}, {} diff --git a/diracx-testing/src/diracx/testing/mock_osdb.py b/diracx-testing/src/diracx/testing/mock_osdb.py index 282128ac..6e181a79 100644 --- a/diracx-testing/src/diracx/testing/mock_osdb.py +++ b/diracx-testing/src/diracx/testing/mock_osdb.py @@ -42,8 +42,8 @@ def __init__(self, connection_kwargs: dict[str, Any]) -> None: from diracx.db.sql.utils import DateNowColumn # Dynamically create a subclass of BaseSQLDB so we get clearer errors - MockedDB = 
type(f"Mocked{self.__class__.__name__}", (sql_utils.BaseSQLDB,), {})
-        self._sql_db = MockedDB(connection_kwargs["sqlalchemy_dsn"])
+        mocked_db = type(f"Mocked{self.__class__.__name__}", (sql_utils.BaseSQLDB,), {})
+        self._sql_db = mocked_db(connection_kwargs["sqlalchemy_dsn"])
 
         # Dynamically create the table definition based on the fields
         columns = [
@@ -53,16 +53,16 @@ def __init__(self, connection_kwargs: dict[str, Any]) -> None:
         for field, field_type in self.fields.items():
             match field_type["type"]:
                 case "date":
-                    ColumnType = DateNowColumn
+                    column_type = DateNowColumn
                 case "long":
-                    ColumnType = partial(Column, type_=Integer)
+                    column_type = partial(Column, type_=Integer)
                 case "keyword":
-                    ColumnType = partial(Column, type_=String(255))
+                    column_type = partial(Column, type_=String(255))
                 case "text":
-                    ColumnType = partial(Column, type_=String(64 * 1024))
+                    column_type = partial(Column, type_=String(64 * 1024))
                 case _:
                     raise NotImplementedError(f"Unknown field type: {field_type=}")
-            columns.append(ColumnType(field, default=None))
+            columns.append(column_type(field, default=None))
 
         self._sql_db.metadata = MetaData()
         self._table = Table("dummy", self._sql_db.metadata, *columns)
@@ -158,6 +158,6 @@ def fake_available_osdb_implementations(name, *, real_available_implementations)
 
     # Dynamically generate a class that inherits from the first implementation
     # but that also has the MockOSDBMixin
-    MockParameterDB = type(name, (MockOSDBMixin, implementations[0]), {})
+    mock_parameter_db = type(name, (MockOSDBMixin, implementations[0]), {})
 
-    return [MockParameterDB] + implementations
+    return [mock_parameter_db] + implementations
diff --git a/docs/CODING_CONVENTION.md b/docs/CODING_CONVENTION.md
index 9ea27510..50363485 100644
--- a/docs/CODING_CONVENTION.md
+++ b/docs/CODING_CONVENTION.md
@@ -47,6 +47,28 @@ ALWAYS DO
 from __future__ import annotations
 ```
 
+# SQL Alchemy
+
+DO
+
+```python
+class Owners(Base):
+    __tablename__ = "Owners"
+    owner_id = Column("OwnerID", Integer, primary_key=True, autoincrement=True)
+    creation_time = DateNowColumn("CreationTime")
+    name = Column("Name", String(255))
+```
+
+DONT
+
+```python
+class Owners(Base):
+    __tablename__ = "Owners"
+    OwnerID = Column(Integer, primary_key=True, autoincrement=True)
+    CreationTime = DateNowColumn()
+    Name = Column(String(255))
+```
+
 # Structure
diff --git a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/jobs/db.py b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/jobs/db.py
index 8f56ce4e..414bc23d 100644
--- a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/jobs/db.py
+++ b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/jobs/db.py
@@ -17,12 +17,12 @@ async def insert_gubbins_info(self, job_id: int, info: str):
         """
         This is a new method that makes use of a new table.
         """
-        stmt = insert(GubbinsInfo).values(JobID=job_id, Info=info)
+        stmt = insert(GubbinsInfo).values(job_id=job_id, info=info)
         await self.conn.execute(stmt)
 
-    async def getJobJDL(  # type: ignore[override]
-        self, job_id: int, original: bool = False, with_info=False
-    ) -> str | dict[str, str]:
+    async def get_job_jdls(  # type: ignore[override]
+        self, job_ids, original: bool = False, with_info=False
+    ) -> dict:
         """
         This method modifes the one in the parent class:
         * adds an extra argument
@@ -31,16 +31,23 @@ async def getJobJDL(  # type: ignore[override]
 
         Note that this requires to disable mypy error with # type: ignore[override]
         """
-        jdl = await super().getJobJDL(job_id, original=original)
+        jdl = await super().get_job_jdls(job_ids, original=original)
         if not with_info:
             return jdl
 
-        stmt = select(GubbinsInfo.Info).where(GubbinsInfo.JobID == job_id)
+        stmt = select(GubbinsInfo.job_id, GubbinsInfo.info).where(
+            GubbinsInfo.job_id.in_(job_ids)
+        )
 
-        info = (await self.conn.execute(stmt)).scalar_one()
-        return {"JDL": jdl, "Info": info}
+        rows = await self.conn.execute(stmt)
+        info = {row[0]: row[1] for row in rows.fetchall()}
 
-    async def setJobAttributesBulk(self, jobData):
+        result = {}
+        for job_id, jdl_details in jdl.items():
+            result[job_id] = {"JDL": jdl_details, "Info": info.get(job_id, "")}
+        return result
+
+    async def set_job_attributes_bulk(self, job_data):
         """
         This method modified the one in the parent class, without changing
         the argument nor the return type
diff --git a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/jobs/schema.py b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/jobs/schema.py
index ac5cd039..eee922d4 100644
--- a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/jobs/schema.py
+++ b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/jobs/schema.py
@@ -13,7 +13,7 @@ class GubbinsInfo(JobDBBase):
     __tablename__ = "GubbinsJobs"
 
-    JobID = Column(
-        Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True
+    job_id = Column(
+        "JobID", Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True
     )
-    Info = Column(String(255), default="", primary_key=True)
+    info = Column("Info", String(255), default="", primary_key=True)
diff --git a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/db.py b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/db.py
index 5ce64edc..dc73d3b1 100644
--- a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/db.py
+++ b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/db.py
@@ -25,7 +25,7 @@ class LollygagDB(BaseSQLDB):
     async def summary(self, group_by, search) -> list[dict[str, str | int]]:
         columns = [Cars.__table__.columns[x] for x in group_by]
 
-        stmt = select(*columns, func.count(Cars.licensePlate).label("count"))
+        stmt = select(*columns, func.count(Cars.license_plate).label("count"))
         stmt = apply_search_filters(Cars.__table__.columns.__getitem__, stmt, search)
         stmt = stmt.group_by(*columns)
@@ -48,7 +48,7 @@ async def get_owner(self) -> list[str]:
 
     async def insert_car(self, license_plate: UUID, model: str, owner_id: int) -> int:
         stmt = insert(Cars).values(
-            licensePlate=license_plate, model=model, ownerID=owner_id
+            license_plate=license_plate, model=model, owner_id=owner_id
         )
 
         result = await self.conn.execute(stmt)
diff --git a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/schema.py b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/schema.py
index 9e7b4eba..9b80e513 100644
--- a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/schema.py
+++ b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/schema.py
@@ -9,13 +9,13 @@ class Owners(Base):
     __tablename__ = "Owners"
 
-    ownerID = Column(Integer, primary_key=True, autoincrement=True)
-    creation_time = DateNowColumn()
-    name = Column(String(255))
+    owner_id = Column("OwnerID", Integer, primary_key=True, autoincrement=True)
+    creation_time = DateNowColumn("CreationTime")
+    name = Column("Name", String(255))
 
 
 class Cars(Base):
     __tablename__ = "Cars"
-    licensePlate = Column(Uuid(), primary_key=True)
-    model = Column(String(255))
-    ownerID = Column(Integer, ForeignKey(Owners.ownerID))
+    license_plate = Column("LicensePlate", Uuid(), primary_key=True)
+    model = Column("Model", String(255))
+    owner_id = Column("OwnerID", Integer, ForeignKey(Owners.owner_id))
diff --git a/extensions/gubbins/gubbins-db/tests/test_gubbinsJobDB.py b/extensions/gubbins/gubbins-db/tests/test_gubbins_job_db.py
similarity index 79%
rename from extensions/gubbins/gubbins-db/tests/test_gubbinsJobDB.py
rename to extensions/gubbins/gubbins-db/tests/test_gubbins_job_db.py
index f98e3bdf..a9d21362 100644
--- a/extensions/gubbins/gubbins-db/tests/test_gubbinsJobDB.py
+++ b/extensions/gubbins/gubbins-db/tests/test_gubbins_job_db.py
@@ -41,14 +41,12 @@ async def test_gubbins_info(gubbins_db):
         ],
         gubbins_db,
     )
+    await gubbins_db.insert_gubbins_info(job_ids[0], "info")
 
-    job_id = job_ids[0]
+    result = await gubbins_db.get_job_jdls(job_ids, original=True)
+    assert result == {1: "[JDL]"}
 
-    await gubbins_db.insert_gubbins_info(job_id, "info")
-
-    result = await gubbins_db.getJobJDL(job_id, original=True)
-    assert result == "[JDL]"
-
-    result = await gubbins_db.getJobJDL(job_id, with_info=True)
-    assert "JDL" in result
-    assert result["Info"] == "info"
+    result = await gubbins_db.get_job_jdls(job_ids, with_info=True)
+    assert len(result) == 1
+    assert result[1].get("JDL")
+    assert result[1].get("Info") == "info"
diff --git a/extensions/gubbins/gubbins-db/tests/test_lollygagDB.py b/extensions/gubbins/gubbins-db/tests/test_lollygag_db.py
similarity index 86%
rename from extensions/gubbins/gubbins-db/tests/test_lollygagDB.py
rename to extensions/gubbins/gubbins-db/tests/test_lollygag_db.py
index f963ded1..e7c72931 100644
--- a/extensions/gubbins/gubbins-db/tests/test_lollygagDB.py
+++ b/extensions/gubbins/gubbins-db/tests/test_lollygag_db.py
@@ -6,7 +6,7 @@
 import pytest
 
 from diracx.core.exceptions import InvalidQueryError
-from diracx.db.sql.utils import SQLDBUnavailable
+from diracx.db.sql.utils import SQLDBUnavailableError
 
 from gubbins.db.sql.lollygag.db import LollygagDB
 
@@ -31,7 +31,7 @@ async def test_insert_and_summary(lollygag_db: LollygagDB):
     # So it is important to write test this way
     async with lollygag_db as lollygag_db:
         # First we check that the DB is empty
-        result = await lollygag_db.summary(["model"], [])
+        result = await lollygag_db.summary(["Model"], [])
         assert not result
 
         # Now we add some data in the DB
@@ -51,14 +51,14 @@ async def test_insert_and_summary(lollygag_db: LollygagDB):
 
     # Check that there are now 10 cars assigned to a single driver
     async with lollygag_db as lollygag_db:
-        result = await lollygag_db.summary(["ownerID"], [])
+        result = await lollygag_db.summary(["OwnerID"], [])
 
         assert result[0]["count"] == 10
 
     # Test the selection
     async with lollygag_db as lollygag_db:
         result = await lollygag_db.summary(
-            ["ownerID"], [{"parameter": "model", "operator": "eq", "value": "model_1"}]
+            ["OwnerID"], [{"parameter": "Model", "operator": "eq", "value": "model_1"}]
         )
 
         assert result[0]["count"] == 1
@@ -66,10 +66,10 @@ async def test_insert_and_summary(lollygag_db: LollygagDB):
     async with lollygag_db as lollygag_db:
         with pytest.raises(InvalidQueryError):
             result = await lollygag_db.summary(
-                ["ownerID"],
+                ["OwnerID"],
                 [
                     {
-                        "parameter": "model",
+                        "parameter": "Model",
                         "operator": "BADSELECTION",
                         "value": "model_1",
                     }
@@ -80,6 +80,6 @@ async def test_insert_and_summary(lollygag_db: LollygagDB):
 async def test_bad_connection():
     lollygag_db = LollygagDB("mysql+aiomysql://tata:yoyo@db.invalid:3306/name")
     async with lollygag_db.engine_context():
-        with pytest.raises(SQLDBUnavailable):
+        with pytest.raises(SQLDBUnavailableError):
             async with lollygag_db:
                 lollygag_db.ping()
diff --git a/extensions/gubbins/pyproject.toml b/extensions/gubbins/pyproject.toml
index a10370f5..c61127cb 100644
--- a/extensions/gubbins/pyproject.toml
+++ b/extensions/gubbins/pyproject.toml
@@ -52,6 +52,7 @@ select = [
     "FLY", # flynt
     "DTZ", # flake8-datetimez
     "S", # flake8-bandit
+    "N", # pep8-naming
 ]
 ignore = [
diff --git a/pyproject.toml b/pyproject.toml
index 06b78e3b..77998d3c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,6 +46,7 @@ select = [
     "FLY", # flynt
     "DTZ", # flake8-datetimez
     "S", # flake8-bandit
+    "N", # pep8-naming
 ]
 ignore = [
     "B905",
diff --git a/run_local.sh b/run_local.sh
index b632d248..83bfbcc4 100755
--- a/run_local.sh
+++ b/run_local.sh
@@ -70,7 +70,7 @@
 echo ""
 echo "1. Use the CLI:"
 echo ""
 echo " export DIRACX_URL=http://localhost:8000"
-echo " env DIRACX_SERVICE_AUTH_STATE_KEY='${state_key}' tests/make-token-local.py ${signing_key}"
+echo " env DIRACX_SERVICE_AUTH_STATE_KEY='${state_key}' tests/make_token_local.py ${signing_key}"
 echo ""
 echo "2. Using swagger: http://localhost:8000/api/docs"
diff --git a/tests/make-token-local.py b/tests/make_token_local.py
similarity index 100%
rename from tests/make-token-local.py
rename to tests/make_token_local.py