diff --git a/diracx-db/pyproject.toml b/diracx-db/pyproject.toml
index 37d728be..e4e211ec 100644
--- a/diracx-db/pyproject.toml
+++ b/diracx-db/pyproject.toml
@@ -31,6 +31,7 @@ testing = [
 AuthDB = "diracx.db.sql:AuthDB"
 JobDB = "diracx.db.sql:JobDB"
 JobLoggingDB = "diracx.db.sql:JobLoggingDB"
+PilotAgentsDB = "diracx.db.sql:PilotAgentsDB"
 SandboxMetadataDB = "diracx.db.sql:SandboxMetadataDB"
 TaskQueueDB = "diracx.db.sql:TaskQueueDB"
 
diff --git a/diracx-db/src/diracx/db/sql/__init__.py b/diracx-db/src/diracx/db/sql/__init__.py
index f98785e2..3be3af8a 100644
--- a/diracx-db/src/diracx/db/sql/__init__.py
+++ b/diracx-db/src/diracx/db/sql/__init__.py
@@ -1,9 +1,17 @@
 from __future__ import annotations
 
-__all__ = ("AuthDB", "JobDB", "JobLoggingDB", "SandboxMetadataDB", "TaskQueueDB")
+__all__ = (
+    "AuthDB",
+    "JobDB",
+    "JobLoggingDB",
+    "PilotAgentsDB",
+    "SandboxMetadataDB",
+    "TaskQueueDB",
+)
 
 from .auth.db import AuthDB
 from .job.db import JobDB
 from .job_logging.db import JobLoggingDB
+from .pilot_agents.db import PilotAgentsDB
 from .sandbox_metadata.db import SandboxMetadataDB
 from .task_queue.db import TaskQueueDB
diff --git a/diracx-db/src/diracx/db/sql/job/schema.py b/diracx-db/src/diracx/db/sql/job/schema.py
index dded7be8..d17edf2d 100644
--- a/diracx-db/src/diracx/db/sql/job/schema.py
+++ b/diracx-db/src/diracx/db/sql/job/schema.py
@@ -1,4 +1,3 @@
-import sqlalchemy.types as types
 from sqlalchemy import (
     DateTime,
     Enum,
@@ -10,37 +9,11 @@
 )
 from sqlalchemy.orm import declarative_base
 
-from ..utils import Column, NullColumn
+from ..utils import Column, EnumBackedBool, NullColumn
 
 JobDBBase = declarative_base()
 
 
-class EnumBackedBool(types.TypeDecorator):
-    """Maps a ``EnumBackedBool()`` column to True/False in Python."""
-
-    impl = types.Enum
-    cache_ok: bool = True
-
-    def __init__(self) -> None:
-        super().__init__("True", "False")
-
-    def process_bind_param(self, value, dialect) -> str:
-        if value is True:
-            return "True"
-        elif value is False:
-            return "False"
-        else:
-            raise NotImplementedError(value, dialect)
-
-    def process_result_value(self, value, dialect) -> bool:
-        if value == "True":
-            return True
-        elif value == "False":
-            return False
-        else:
-            raise NotImplementedError(f"Unknown {value=}")
-
-
 class Jobs(JobDBBase):
     __tablename__ = "Jobs"
 
diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/__init__.py b/diracx-db/src/diracx/db/sql/pilot_agents/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/db.py b/diracx-db/src/diracx/db/sql/pilot_agents/db.py
new file mode 100644
index 00000000..b4f801b7
--- /dev/null
+++ b/diracx-db/src/diracx/db/sql/pilot_agents/db.py
@@ -0,0 +1,55 @@
+from __future__ import annotations
+
+from datetime import datetime, timezone
+
+from sqlalchemy import insert
+
+from ..utils import BaseSQLDB
+from .schema import PilotAgents, PilotAgentsDBBase
+
+
+class PilotAgentsDB(BaseSQLDB):
+    """PilotAgentsDB class is a front-end to the PilotAgents Database."""
+
+    metadata = PilotAgentsDBBase.metadata
+
+    async def add_pilot_references(
+        self,
+        pilot_ref: list[str],
+        vo: str,
+        grid_type: str = "DIRAC",
+        pilot_stamps: dict | None = None,
+    ) -> None:
+        """Insert one ``PilotAgents`` row per reference in ``pilot_ref``.
+
+        Every row gets the given ``vo`` and ``grid_type``, the status
+        ``"Submitted"`` and the current UTC time as both submission and
+        last-update time. ``pilot_stamps`` optionally maps a reference to
+        its stamp; references without an entry get an empty stamp.
+        """
+        # Nothing to register: do not emit an INSERT with an empty VALUES list
+        if not pilot_ref:
+            return
+
+        if pilot_stamps is None:
+            pilot_stamps = {}
+
+        now = datetime.now(tz=timezone.utc)
+
+        # Prepare the list of dictionaries for bulk insertion
+        values = [
+            {
+                "PilotJobReference": ref,
+                "VO": vo,
+                "GridType": grid_type,
+                "SubmissionTime": now,
+                "LastUpdateTime": now,
+                "Status": "Submitted",
+                "PilotStamp": pilot_stamps.get(ref, ""),
+            }
+            for ref in pilot_ref
+        ]
+
+        # Insert multiple rows in a single execute call
+        stmt = insert(PilotAgents).values(values)
+        await self.conn.execute(stmt)
diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/schema.py b/diracx-db/src/diracx/db/sql/pilot_agents/schema.py
new file mode 100644
index 00000000..7a2a0c5e
--- /dev/null
+++ b/diracx-db/src/diracx/db/sql/pilot_agents/schema.py
@@ -0,0 +1,61 @@
+from sqlalchemy import (
+    DateTime,
+    Double,
+    Index,
+    Integer,
+    String,
+    Text,
+)
+from sqlalchemy.orm import declarative_base
+
+from ..utils import Column, EnumBackedBool, NullColumn
+
+PilotAgentsDBBase = declarative_base()
+
+
+class PilotAgents(PilotAgentsDBBase):
+    """Declarative mapping of the ``PilotAgents`` table."""
+    __tablename__ = "PilotAgents"
+
+    PilotID = Column("PilotID", Integer, autoincrement=True, primary_key=True)
+    InitialJobID = Column("InitialJobID", Integer, default=0)
+    CurrentJobID = Column("CurrentJobID", Integer, default=0)
+    PilotJobReference = Column("PilotJobReference", String(255), default="Unknown")
+    PilotStamp = Column("PilotStamp", String(32), default="")
+    DestinationSite = Column("DestinationSite", String(128), default="NotAssigned")
+    Queue = Column("Queue", String(128), default="Unknown")
+    GridSite = Column("GridSite", String(128), default="Unknown")
+    VO = Column("VO", String(128))
+    GridType = Column("GridType", String(32), default="LCG")
+    BenchMark = Column("BenchMark", Double, default=0.0)
+    SubmissionTime = NullColumn("SubmissionTime", DateTime)
+    LastUpdateTime = NullColumn("LastUpdateTime", DateTime)
+    Status = Column("Status", String(32), default="Unknown")
+    StatusReason = Column("StatusReason", String(255), default="Unknown")
+    AccountingSent = Column("AccountingSent", EnumBackedBool(), default=False)
+
+    __table_args__ = (
+        Index("PilotJobReference", "PilotJobReference"),
+        Index("Status", "Status"),
+        Index("Statuskey", "GridSite", "DestinationSite", "Status"),
+    )
+
+
+class JobToPilotMapping(PilotAgentsDBBase):
+    """Declarative mapping of the ``JobToPilotMapping`` table."""
+    __tablename__ = "JobToPilotMapping"
+
+    PilotID = Column("PilotID", Integer, primary_key=True)
+    JobID = Column("JobID", Integer, primary_key=True)
+    StartTime = Column("StartTime", DateTime)
+
+    __table_args__ = (Index("JobID", "JobID"), Index("PilotID", "PilotID"))
+
+
+class PilotOutput(PilotAgentsDBBase):
+    """Declarative mapping of the ``PilotOutput`` table."""
+    __tablename__ = "PilotOutput"
+
+    PilotID = Column("PilotID", Integer, primary_key=True)
+    StdOutput = Column("StdOutput", Text)
+    StdError = Column("StdError", Text)
diff --git a/diracx-db/src/diracx/db/sql/utils/__init__.py b/diracx-db/src/diracx/db/sql/utils/__init__.py
index c514499b..3f3011a0 100644
--- a/diracx-db/src/diracx/db/sql/utils/__init__.py
+++ b/diracx-db/src/diracx/db/sql/utils/__init__.py
@@ -13,6 +13,7 @@
 from functools import partial
 from typing import TYPE_CHECKING, Self, cast
 
+import sqlalchemy.types as types
 from pydantic import TypeAdapter
 from sqlalchemy import Column as RawColumn
 from sqlalchemy import DateTime, Enum, MetaData, select
@@ -128,6 +129,32 @@ def EnumColumn(enum_type, **kwargs):
     return Column(Enum(enum_type, native_enum=False, length=16), **kwargs)
 
 
+class EnumBackedBool(types.TypeDecorator):
+    """Maps a ``EnumBackedBool()`` column to True/False in Python."""
+
+    impl = types.Enum
+    cache_ok: bool = True
+
+    def __init__(self) -> None:
+        super().__init__("True", "False")
+
+    def process_bind_param(self, value, dialect) -> str:
+        if value is True:
+            return "True"
+        elif value is False:
+            return "False"
+        else:
+            raise NotImplementedError(value, dialect)
+
+    def process_result_value(self, value, dialect) -> bool:
+        if value == "True":
+            return True
+        elif value == "False":
+            return False
+        else:
+            raise NotImplementedError(f"Unknown {value=}")
+
+
 class SQLDBError(Exception):
     pass
 
diff --git a/diracx-db/tests/pilot_agents/__init__.py b/diracx-db/tests/pilot_agents/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/diracx-db/tests/pilot_agents/test_pilotAgentsDB.py b/diracx-db/tests/pilot_agents/test_pilotAgentsDB.py
new file mode 100644
index 00000000..50829ecc
--- /dev/null
+++ b/diracx-db/tests/pilot_agents/test_pilotAgentsDB.py
@@ -0,0 +1,35 @@
+from __future__ import annotations
+
+import pytest
+
+from diracx.db.sql.pilot_agents.db import PilotAgentsDB
+
+
+@pytest.fixture
+async def pilot_agents_db(tmp_path) -> PilotAgentsDB:
+    """Yield a PilotAgentsDB on in-memory SQLite with its tables created."""
+    agents_db = PilotAgentsDB("sqlite+aiosqlite:///:memory:")
+    async with agents_db.engine_context():
+        async with agents_db.engine.begin() as conn:
+            await conn.run_sync(agents_db.metadata.create_all)
+        yield agents_db
+
+
+async def test_insert_and_select(pilot_agents_db: PilotAgentsDB):
+    """Insert pilot references with stamps, without stamps and with no refs."""
+    async with pilot_agents_db as pilot_agents_db:
+        # Add a pilot reference
+        refs = [f"ref_{i}" for i in range(10)]
+        stamps = [f"stamp_{i}" for i in range(10)]
+        stamp_dict = dict(zip(refs, stamps))
+
+        await pilot_agents_db.add_pilot_references(
+            refs, "test_vo", grid_type="DIRAC", pilot_stamps=stamp_dict
+        )
+
+        await pilot_agents_db.add_pilot_references(
+            refs, "test_vo", grid_type="DIRAC", pilot_stamps=None
+        )
+
+        # An empty reference list must be accepted
+        await pilot_agents_db.add_pilot_references([], "test_vo")
diff --git a/diracx-routers/src/diracx/routers/dependencies.py b/diracx-routers/src/diracx/routers/dependencies.py
index c342da83..ab40190b 100644
--- a/diracx-routers/src/diracx/routers/dependencies.py
+++ b/diracx-routers/src/diracx/routers/dependencies.py
@@ -7,6 +7,7 @@
     "JobLoggingDB",
     "SandboxMetadataDB",
     "TaskQueueDB",
+    "PilotAgentsDB",
     "add_settings_annotation",
     "AvailableSecurityProperties",
 )
@@ -23,6 +24,7 @@
 from diracx.db.sql import AuthDB as _AuthDB
 from diracx.db.sql import JobDB as _JobDB
 from diracx.db.sql import JobLoggingDB as _JobLoggingDB
+from diracx.db.sql import PilotAgentsDB as _PilotAgentsDB
 from diracx.db.sql import SandboxMetadataDB as _SandboxMetadataDB
 from diracx.db.sql import TaskQueueDB as _TaskQueueDB
 
@@ -38,6 +40,7 @@ def add_settings_annotation(cls: T) -> T:
 AuthDB = Annotated[_AuthDB, Depends(_AuthDB.transaction)]
 JobDB = Annotated[_JobDB, Depends(_JobDB.transaction)]
 JobLoggingDB = Annotated[_JobLoggingDB, Depends(_JobLoggingDB.transaction)]
+PilotAgentsDB = Annotated[_PilotAgentsDB, Depends(_PilotAgentsDB.transaction)]
 SandboxMetadataDB = Annotated[
     _SandboxMetadataDB, Depends(_SandboxMetadataDB.transaction)
 ]