From 8587d6c7ebd62823423f9dc0c79774639622c024 Mon Sep 17 00:00:00 2001 From: martynia Date: Wed, 11 Sep 2024 10:25:10 +0200 Subject: [PATCH 01/13] feat: Implement PilotAgents schema --- diracx-db/pyproject.toml | 1 + diracx-db/src/diracx/db/sql/__init__.py | 10 ++- .../diracx/db/sql/pilot_agents/__init__.py | 0 .../src/diracx/db/sql/pilot_agents/db.py | 8 +++ .../src/diracx/db/sql/pilot_agents/schema.py | 65 +++++++++++++++++++ .../src/diracx/routers/dependencies.py | 3 + 6 files changed, 86 insertions(+), 1 deletion(-) create mode 100644 diracx-db/src/diracx/db/sql/pilot_agents/__init__.py create mode 100644 diracx-db/src/diracx/db/sql/pilot_agents/db.py create mode 100644 diracx-db/src/diracx/db/sql/pilot_agents/schema.py diff --git a/diracx-db/pyproject.toml b/diracx-db/pyproject.toml index 37d728be..e4e211ec 100644 --- a/diracx-db/pyproject.toml +++ b/diracx-db/pyproject.toml @@ -31,6 +31,7 @@ testing = [ AuthDB = "diracx.db.sql:AuthDB" JobDB = "diracx.db.sql:JobDB" JobLoggingDB = "diracx.db.sql:JobLoggingDB" +PilotAgentsDB = "diracx.db.sql:PilotAgentsDB" SandboxMetadataDB = "diracx.db.sql:SandboxMetadataDB" TaskQueueDB = "diracx.db.sql:TaskQueueDB" diff --git a/diracx-db/src/diracx/db/sql/__init__.py b/diracx-db/src/diracx/db/sql/__init__.py index f98785e2..3be3af8a 100644 --- a/diracx-db/src/diracx/db/sql/__init__.py +++ b/diracx-db/src/diracx/db/sql/__init__.py @@ -1,9 +1,17 @@ from __future__ import annotations -__all__ = ("AuthDB", "JobDB", "JobLoggingDB", "SandboxMetadataDB", "TaskQueueDB") +__all__ = ( + "AuthDB", + "JobDB", + "JobLoggingDB", + "PilotAgentsDB", + "SandboxMetadataDB", + "TaskQueueDB", +) from .auth.db import AuthDB from .job.db import JobDB from .job_logging.db import JobLoggingDB +from .pilot_agents.db import PilotAgentsDB from .sandbox_metadata.db import SandboxMetadataDB from .task_queue.db import TaskQueueDB diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/__init__.py b/diracx-db/src/diracx/db/sql/pilot_agents/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/db.py b/diracx-db/src/diracx/db/sql/pilot_agents/db.py new file mode 100644 index 00000000..a5c9b6bc --- /dev/null +++ b/diracx-db/src/diracx/db/sql/pilot_agents/db.py @@ -0,0 +1,8 @@ +from __future__ import annotations + +from ..utils import BaseSQLDB +from .schema import PilotAgentsDBBase + + +class PilotAgentsDB(BaseSQLDB): + metadata = PilotAgentsDBBase.metadata diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/schema.py b/diracx-db/src/diracx/db/sql/pilot_agents/schema.py new file mode 100644 index 00000000..366e0774 --- /dev/null +++ b/diracx-db/src/diracx/db/sql/pilot_agents/schema.py @@ -0,0 +1,65 @@ +from sqlalchemy import ( + DateTime, + Double, + Index, + Integer, + String, + Text, +) +from sqlalchemy.orm import declarative_base + +from ..job.schema import EnumBackedBool +from ..utils import Column, NullColumn + +PilotAgentsDBBase = declarative_base() + + +class PilotAgents(PilotAgentsDBBase): + __tablename__ = "PilotAgents" + + PilotID = Column("PilotID", Integer, autoincrement=True, primary_key=True) + InitialJobID = Column("InitialJobID", Integer, default=0) + CurrentJobID = Column("CurrentJobID", Integer, default=0) + TaskQueueID = Column("TaskQueueID", Integer, default=0) + PilotJobReference = Column("PilotJobReference", String(255), default="Unknown") + PilotStamp = Column("PilotStamp", String(32), default="") + DestinationSite = Column("DestinationSite", String(128), default="NotAssigned") + Queue = Column("Queue", String(128), default="Unknown") + GridSite = Column("GridSite", String(128), default="Unknown") + Broker = Column("Broker", String(128), default="Unknown") + OwnerDN = Column("OwnerDN", String(255)) + OwnerGroup = Column("OwnerGroup", String(128)) + GridType = Column("GridType", String(32), default="LCG") + GridRequirements = Column("GridRequirements", Text) + BenchMark = Column("BenchMark", Double, default=0.0) + SubmissionTime = NullColumn("SubmissionTime", DateTime) + LastUpdateTime = NullColumn("LastUpdateTime", DateTime) + Status = Column("Status", String(32), default="Unknown") + StatusReason = Column("StatusReason", String(255), default="Unknown") + ParentID = Column("ParentID", Integer, default=0) + OutputReady = Column("OutputReady", EnumBackedBool(), default=False) + AccountingSent = Column("AccountingSent", EnumBackedBool(), default=False) + + __table_args__ = ( + Index("PilotJobReference", "PilotJobReference"), + Index("Status", "Status"), + Index("Statuskey", "GridSite", "DestinationSite", "Status"), + ) + + +class JobToPilotMapping(PilotAgentsDBBase): + __tablename__ = "JobToPilotMapping" + + PilotID = Column("PilotID", Integer, primary_key=True) + JobID = Column("JobID", Integer, primary_key=True) + StartTime = Column("StartTime", DateTime) + + __table_args__ = (Index("JobID", "JobID"), Index("PilotID", "PilotID")) + + +class PilotOutput(PilotAgentsDBBase): + __tablename__ = "PilotOutput" + + PilotID = Column("PilotID", Integer, primary_key=True) + StdOutput = Column("StdOutput", Text) + StdError = Column("StdError", Text) diff --git a/diracx-routers/src/diracx/routers/dependencies.py b/diracx-routers/src/diracx/routers/dependencies.py index c342da83..ab40190b 100644 --- a/diracx-routers/src/diracx/routers/dependencies.py +++ b/diracx-routers/src/diracx/routers/dependencies.py @@ -7,6 +7,7 @@ "JobLoggingDB", "SandboxMetadataDB", "TaskQueueDB", + "PilotAgentsDB", "add_settings_annotation", "AvailableSecurityProperties", ) @@ -23,6 +24,7 @@ from diracx.db.sql import AuthDB as _AuthDB from diracx.db.sql import JobDB as _JobDB from diracx.db.sql import JobLoggingDB as _JobLoggingDB +from diracx.db.sql import PilotAgentsDB as _PilotAgentsDB from diracx.db.sql import SandboxMetadataDB as _SandboxMetadataDB from diracx.db.sql import TaskQueueDB as _TaskQueueDB @@ -38,6 +40,7 @@ def add_settings_annotation(cls: T) -> T: AuthDB = Annotated[_AuthDB, Depends(_AuthDB.transaction)] JobDB = Annotated[_JobDB, Depends(_JobDB.transaction)] JobLoggingDB = Annotated[_JobLoggingDB, Depends(_JobLoggingDB.transaction)] +PilotAgentsDB = Annotated[_PilotAgentsDB, Depends(_PilotAgentsDB.transaction)] SandboxMetadataDB = Annotated[ _SandboxMetadataDB, Depends(_SandboxMetadataDB.transaction) ] From 92811584f6386a81ad5af4c28e45a7fbd616610a Mon Sep 17 00:00:00 2001 From: martynia Date: Thu, 19 Sep 2024 16:33:31 +0200 Subject: [PATCH 02/13] fix: add a default value for GridRequirements column --- diracx-db/src/diracx/db/sql/pilot_agents/schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/schema.py b/diracx-db/src/diracx/db/sql/pilot_agents/schema.py index 366e0774..32e4c88b 100644 --- a/diracx-db/src/diracx/db/sql/pilot_agents/schema.py +++ b/diracx-db/src/diracx/db/sql/pilot_agents/schema.py @@ -30,7 +30,7 @@ class PilotAgents(PilotAgentsDBBase): OwnerDN = Column("OwnerDN", String(255)) OwnerGroup = Column("OwnerGroup", String(128)) GridType = Column("GridType", String(32), default="LCG") - GridRequirements = Column("GridRequirements", Text) + GridRequirements = Column("GridRequirements", Text, default="") BenchMark = Column("BenchMark", Double, default=0.0) SubmissionTime = NullColumn("SubmissionTime", DateTime) LastUpdateTime = NullColumn("LastUpdateTime", DateTime) From f55241786b8b70c50eb3d68960e6eb7ff082b412 Mon Sep 17 00:00:00 2001 From: martynia Date: Thu, 19 Sep 2024 16:37:28 +0200 Subject: [PATCH 03/13] feat: addPilotReferences --- .../src/diracx/db/sql/pilot_agents/db.py | 36 ++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/db.py b/diracx-db/src/diracx/db/sql/pilot_agents/db.py index a5c9b6bc..73e15c60 100644 --- a/diracx-db/src/diracx/db/sql/pilot_agents/db.py +++ b/diracx-db/src/diracx/db/sql/pilot_agents/db.py @@ -1,8 +1,42 @@ from __future__ import annotations +from datetime import datetime, timezone + +from sqlalchemy import insert + from ..utils import BaseSQLDB -from .schema import PilotAgentsDBBase +from .schema import PilotAgents, PilotAgentsDBBase class PilotAgentsDB(BaseSQLDB): metadata = PilotAgentsDBBase.metadata + + async def addPilotReferences( + self, + pilotRef: list[str], + ownerGroup: str, + gridType: str = "DIRAC", + pilotStampDict: dict = {}, + ) -> list[int]: + + row_ids = [] + for ref in pilotRef: + stamp = "" + if ref in pilotStampDict: + stamp = pilotStampDict[ref] + now = datetime.now(tz=timezone.utc) + stmt = insert(PilotAgents).values( + PilotJobReference=ref, + TaskQueueID=0, + OwnerDN="Unknown", + OwnerGroup=ownerGroup, + GridType=gridType, + SubmissionTime=now, + LastUpdateTime=now, + Status="submitted", + PilotStamp=stamp, + ) + result = await self.conn.execute(stmt) + row_ids.append(result.lastrowid) + + return row_ids From ce8dfc0f773fdd9ca61d7dedf5eeaaf15e9f14fc Mon Sep 17 00:00:00 2001 From: martynia Date: Thu, 19 Sep 2024 17:12:08 +0200 Subject: [PATCH 04/13] fix: modify schema to match Dirac 9.0 --- diracx-db/src/diracx/db/sql/pilot_agents/schema.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/schema.py b/diracx-db/src/diracx/db/sql/pilot_agents/schema.py index 32e4c88b..277513fb 100644 --- a/diracx-db/src/diracx/db/sql/pilot_agents/schema.py +++ b/diracx-db/src/diracx/db/sql/pilot_agents/schema.py @@ -20,24 +20,18 @@ class PilotAgents(PilotAgentsDBBase): PilotID = Column("PilotID", Integer, autoincrement=True, primary_key=True) InitialJobID = Column("InitialJobID", Integer, default=0) CurrentJobID = Column("CurrentJobID", Integer, default=0) - TaskQueueID = Column("TaskQueueID", Integer, default=0) PilotJobReference = Column("PilotJobReference", String(255), default="Unknown") PilotStamp = Column("PilotStamp", String(32), default="") DestinationSite = Column("DestinationSite", String(128), default="NotAssigned") Queue = Column("Queue", String(128), default="Unknown") GridSite = Column("GridSite", String(128), default="Unknown") - Broker = Column("Broker", String(128), default="Unknown") - OwnerDN = Column("OwnerDN", String(255)) - OwnerGroup = Column("OwnerGroup", String(128)) + VO = Column("VO", String(128)) GridType = Column("GridType", String(32), default="LCG") - GridRequirements = Column("GridRequirements", Text, default="") BenchMark = Column("BenchMark", Double, default=0.0) SubmissionTime = NullColumn("SubmissionTime", DateTime) LastUpdateTime = NullColumn("LastUpdateTime", DateTime) Status = Column("Status", String(32), default="Unknown") StatusReason = Column("StatusReason", String(255), default="Unknown") - ParentID = Column("ParentID", Integer, default=0) - OutputReady = Column("OutputReady", EnumBackedBool(), default=False) AccountingSent = Column("AccountingSent", EnumBackedBool(), default=False) __table_args__ = ( From 86d0f4c5291e1100f9d8cca4ec9cd82daaeda961 Mon Sep 17 00:00:00 2001 From: martynia Date: Fri, 20 Sep 2024 09:45:04 +0100 Subject: [PATCH 05/13] fix: modify db.py to match Dirac 9.0 --- diracx-db/src/diracx/db/sql/pilot_agents/db.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/db.py b/diracx-db/src/diracx/db/sql/pilot_agents/db.py index 73e15c60..0b9e56d5 100644 --- a/diracx-db/src/diracx/db/sql/pilot_agents/db.py +++ b/diracx-db/src/diracx/db/sql/pilot_agents/db.py @@ -14,7 +14,7 @@ class PilotAgentsDB(BaseSQLDB): async def addPilotReferences( self, pilotRef: list[str], - ownerGroup: str, + VO: str, gridType: str = "DIRAC", pilotStampDict: dict = {}, ) -> list[int]: @@ -27,13 +27,11 @@ async def addPilotReferences( now = datetime.now(tz=timezone.utc) stmt = insert(PilotAgents).values( PilotJobReference=ref, - TaskQueueID=0, - OwnerDN="Unknown", - OwnerGroup=ownerGroup, + VO=VO, GridType=gridType, SubmissionTime=now, LastUpdateTime=now, - Status="submitted", + Status="Submitted", PilotStamp=stamp, ) result = await self.conn.execute(stmt) From b2b0232e219b0ae7295f41ab05f405a85fc2c589 Mon Sep 17 00:00:00 2001 From: martynia Date: Thu, 3 Oct 2024 21:07:05 +0200 Subject: [PATCH 06/13] fix: avoid mutable default values as arguments + unit test --- .../src/diracx/db/sql/pilot_agents/db.py | 4 ++- diracx-db/tests/pilot_agents/__init__.py | 0 .../tests/pilot_agents/test_pilotAgentsDB.py | 29 +++++++++++++++++++ 3 files changed, 32 insertions(+), 1 deletion(-) create mode 100644 diracx-db/tests/pilot_agents/__init__.py create mode 100644 diracx-db/tests/pilot_agents/test_pilotAgentsDB.py diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/db.py b/diracx-db/src/diracx/db/sql/pilot_agents/db.py index 0b9e56d5..04b1aa13 100644 --- a/diracx-db/src/diracx/db/sql/pilot_agents/db.py +++ b/diracx-db/src/diracx/db/sql/pilot_agents/db.py @@ -16,9 +16,11 @@ async def addPilotReferences( pilotRef: list[str], VO: str, gridType: str = "DIRAC", - pilotStampDict: dict = {}, + pilotStampDict: dict | None = None, ) -> list[int]: + if pilotStampDict is None: + pilotStampDict = {} row_ids = [] for ref in pilotRef: stamp = "" diff --git a/diracx-db/tests/pilot_agents/__init__.py b/diracx-db/tests/pilot_agents/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/diracx-db/tests/pilot_agents/test_pilotAgentsDB.py b/diracx-db/tests/pilot_agents/test_pilotAgentsDB.py new file mode 100644 index 00000000..e27bc3d9 --- /dev/null +++ b/diracx-db/tests/pilot_agents/test_pilotAgentsDB.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +import pytest + +from diracx.db.sql.pilot_agents.db import PilotAgentsDB + + +@pytest.fixture +async def pilot_agents_db(tmp_path) -> PilotAgentsDB: + agents_db = PilotAgentsDB("sqlite+aiosqlite:///:memory:") + async with agents_db.engine_context(): + async with agents_db.engine.begin() as conn: + await conn.run_sync(agents_db.metadata.create_all) + yield agents_db + + +async def test_insert_and_select(pilot_agents_db: PilotAgentsDB): + + async with pilot_agents_db as pilot_agents_db: + # Add a pilot reference + refs = [f"ref_{i}" for i in range(10)] + stamps = [f"stamp_{i}" for i in range(10)] + stamp_dict = dict(zip(refs, stamps)) + + pilot_id = await pilot_agents_db.addPilotReferences( + refs, "test_vo", gridType="DIRAC", pilotStampDict=stamp_dict + ) + assert pilot_id + assert pilot_id == list(range(1, 11)) From 71395f4a7806f0286afaa393bf3037be903a2ddc Mon Sep 17 00:00:00 2001 From: martynia Date: Tue, 8 Oct 2024 14:23:09 +0200 Subject: [PATCH 07/13] fix: Apply suggestions from code review Co-authored-by: aldbr --- diracx-db/src/diracx/db/sql/pilot_agents/db.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/db.py b/diracx-db/src/diracx/db/sql/pilot_agents/db.py index 04b1aa13..4bab8e60 100644 --- a/diracx-db/src/diracx/db/sql/pilot_agents/db.py +++ b/diracx-db/src/diracx/db/sql/pilot_agents/db.py @@ -13,23 +13,21 @@ class PilotAgentsDB(BaseSQLDB): async def addPilotReferences( self, - pilotRef: list[str], - VO: str, - gridType: str = "DIRAC", - pilotStampDict: dict | None = None, + pilot_ref: list[str], + vo: str, + grid_type: str = "DIRAC", + pilot_stamps: dict | None = None, ) -> list[int]: if pilotStampDict is None: pilotStampDict = {} row_ids = [] for ref in pilotRef: - stamp = "" - if ref in pilotStampDict: - stamp = pilotStampDict[ref] + stamp = pilotStampDict.get(ref, "") now = datetime.now(tz=timezone.utc) stmt = insert(PilotAgents).values( PilotJobReference=ref, - VO=VO, + VO=vo, GridType=gridType, SubmissionTime=now, LastUpdateTime=now, From 0adb6b3bc76c3693df72b5230a0d5a2b6313d9b7 Mon Sep 17 00:00:00 2001 From: martynia Date: Tue, 8 Oct 2024 16:58:11 +0200 Subject: [PATCH 08/13] fix: Apply suggestions from code review (2) --- diracx-db/src/diracx/db/sql/pilot_agents/db.py | 10 +++++----- diracx-db/tests/pilot_agents/test_pilotAgentsDB.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/db.py b/diracx-db/src/diracx/db/sql/pilot_agents/db.py index 4bab8e60..92172637 100644 --- a/diracx-db/src/diracx/db/sql/pilot_agents/db.py +++ b/diracx-db/src/diracx/db/sql/pilot_agents/db.py @@ -19,16 +19,16 @@ async def addPilotReferences( pilot_stamps: dict | None = None, ) -> list[int]: - if pilotStampDict is None: - pilotStampDict = {} + if pilot_stamps is None: + pilot_stamps = {} row_ids = [] - for ref in pilotRef: - stamp = pilotStampDict.get(ref, "") + for ref in pilot_ref: + stamp = pilot_stamps.get(ref, "") now = datetime.now(tz=timezone.utc) stmt = insert(PilotAgents).values( PilotJobReference=ref, VO=vo, - GridType=gridType, + GridType=grid_type, SubmissionTime=now, LastUpdateTime=now, Status="Submitted", diff --git a/diracx-db/tests/pilot_agents/test_pilotAgentsDB.py b/diracx-db/tests/pilot_agents/test_pilotAgentsDB.py index e27bc3d9..55d3628e 100644 --- a/diracx-db/tests/pilot_agents/test_pilotAgentsDB.py +++ b/diracx-db/tests/pilot_agents/test_pilotAgentsDB.py @@ -23,7 +23,7 @@ async def test_insert_and_select(pilot_agents_db: PilotAgentsDB): stamp_dict = dict(zip(refs, stamps)) pilot_id = await pilot_agents_db.addPilotReferences( - refs, "test_vo", gridType="DIRAC", pilotStampDict=stamp_dict + refs, "test_vo", grid_type="DIRAC", pilot_stamps=stamp_dict ) assert pilot_id assert pilot_id == list(range(1, 11)) From 34a33c6dd519c0e20c64a74141826bc7f00defc2 Mon Sep 17 00:00:00 2001 From: martynia Date: Tue, 8 Oct 2024 17:15:53 +0200 Subject: [PATCH 09/13] fix: move EnumBackedBool to utils/__init__ --- diracx-db/src/diracx/db/sql/job/schema.py | 29 +------------------ .../src/diracx/db/sql/pilot_agents/schema.py | 3 +- diracx-db/src/diracx/db/sql/utils/__init__.py | 27 +++++++++++++++++ 3 files changed, 29 insertions(+), 30 deletions(-) diff --git a/diracx-db/src/diracx/db/sql/job/schema.py b/diracx-db/src/diracx/db/sql/job/schema.py index dded7be8..d17edf2d 100644 --- a/diracx-db/src/diracx/db/sql/job/schema.py +++ b/diracx-db/src/diracx/db/sql/job/schema.py @@ -1,4 +1,3 @@ -import sqlalchemy.types as types from sqlalchemy import ( DateTime, Enum, @@ -10,37 +9,11 @@ ) from sqlalchemy.orm import declarative_base -from ..utils import Column, NullColumn +from ..utils import Column, EnumBackedBool, NullColumn JobDBBase = declarative_base() -class EnumBackedBool(types.TypeDecorator): - """Maps a ``EnumBackedBool()`` column to True/False in Python.""" - - impl = types.Enum - cache_ok: bool = True - - def __init__(self) -> None: - super().__init__("True", "False") - - def process_bind_param(self, value, dialect) -> str: - if value is True: - return "True" - elif value is False: - return "False" - else: - raise NotImplementedError(value, dialect) - - def process_result_value(self, value, dialect) -> bool: - if value == "True": - return True - elif value == "False": - return False - else: - raise NotImplementedError(f"Unknown {value=}") - - class Jobs(JobDBBase): __tablename__ = "Jobs" diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/schema.py b/diracx-db/src/diracx/db/sql/pilot_agents/schema.py index 277513fb..7a2a0c5e 100644 --- a/diracx-db/src/diracx/db/sql/pilot_agents/schema.py +++ b/diracx-db/src/diracx/db/sql/pilot_agents/schema.py @@ -8,8 +8,7 @@ ) from sqlalchemy.orm import declarative_base -from ..job.schema import EnumBackedBool -from ..utils import Column, NullColumn +from ..utils import Column, EnumBackedBool, NullColumn PilotAgentsDBBase = declarative_base() diff --git a/diracx-db/src/diracx/db/sql/utils/__init__.py b/diracx-db/src/diracx/db/sql/utils/__init__.py index c514499b..3f3011a0 100644 --- a/diracx-db/src/diracx/db/sql/utils/__init__.py +++ b/diracx-db/src/diracx/db/sql/utils/__init__.py @@ -13,6 +13,7 @@ from functools import partial from typing import TYPE_CHECKING, Self, cast +import sqlalchemy.types as types from pydantic import TypeAdapter from sqlalchemy import Column as RawColumn from sqlalchemy import DateTime, Enum, MetaData, select @@ -128,6 +129,32 @@ def EnumColumn(enum_type, **kwargs): return Column(Enum(enum_type, native_enum=False, length=16), **kwargs) +class EnumBackedBool(types.TypeDecorator): + """Maps a ``EnumBackedBool()`` column to True/False in Python.""" + + impl = types.Enum + cache_ok: bool = True + + def __init__(self) -> None: + super().__init__("True", "False") + + def process_bind_param(self, value, dialect) -> str: + if value is True: + return "True" + elif value is False: + return "False" + else: + raise NotImplementedError(value, dialect) + + def process_result_value(self, value, dialect) -> bool: + if value == "True": + return True + elif value == "False": + return False + else: + raise NotImplementedError(f"Unknown {value=}") + + class SQLDBError(Exception): pass From 69cf73a358af6d16e75c408067fa111f8e12507b Mon Sep 17 00:00:00 2001 From: martynia Date: Wed, 9 Oct 2024 14:43:26 +0200 Subject: [PATCH 10/13] test: extend tests for empty pilot stamps --- diracx-db/tests/pilot_agents/test_pilotAgentsDB.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/diracx-db/tests/pilot_agents/test_pilotAgentsDB.py b/diracx-db/tests/pilot_agents/test_pilotAgentsDB.py index 55d3628e..1f34bd26 100644 --- a/diracx-db/tests/pilot_agents/test_pilotAgentsDB.py +++ b/diracx-db/tests/pilot_agents/test_pilotAgentsDB.py @@ -27,3 +27,8 @@ async def test_insert_and_select(pilot_agents_db: PilotAgentsDB): ) assert pilot_id assert pilot_id == list(range(1, 11)) + + pilot_id = await pilot_agents_db.addPilotReferences( + refs, "test_vo", grid_type="DIRAC", pilot_stamps=None + ) + assert pilot_id From 4d4fbe2a02821e9d61477e4fb6d4bdef771218e0 Mon Sep 17 00:00:00 2001 From: martynia Date: Wed, 9 Oct 2024 15:40:19 +0200 Subject: [PATCH 11/13] feat: use bulk insert --- .../src/diracx/db/sql/pilot_agents/db.py | 40 ++++++++++--------- .../tests/pilot_agents/test_pilotAgentsDB.py | 7 +--- 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/db.py b/diracx-db/src/diracx/db/sql/pilot_agents/db.py index 92172637..02d46a5f 100644 --- a/diracx-db/src/diracx/db/sql/pilot_agents/db.py +++ b/diracx-db/src/diracx/db/sql/pilot_agents/db.py @@ -17,24 +17,28 @@ async def addPilotReferences( vo: str, grid_type: str = "DIRAC", pilot_stamps: dict | None = None, - ) -> list[int]: + ) -> None: if pilot_stamps is None: pilot_stamps = {} - row_ids = [] - for ref in pilot_ref: - stamp = pilot_stamps.get(ref, "") - now = datetime.now(tz=timezone.utc) - stmt = insert(PilotAgents).values( - PilotJobReference=ref, - VO=vo, - GridType=grid_type, - SubmissionTime=now, - LastUpdateTime=now, - Status="Submitted", - PilotStamp=stamp, - ) - result = await self.conn.execute(stmt) - row_ids.append(result.lastrowid) - - return row_ids + + now = datetime.now(tz=timezone.utc) + + # Prepare the list of dictionaries for bulk insertion + values = [ + { + "PilotJobReference": ref, + "VO": vo, + "GridType": grid_type, + "SubmissionTime": now, + "LastUpdateTime": now, + "Status": "Submitted", + "PilotStamp": pilot_stamps.get(ref, ""), + } + for ref in pilot_ref + ] + + # Insert multiple rows in a single execute call + stmt = insert(PilotAgents) + await self.conn.execute(stmt, values) + return diff --git a/diracx-db/tests/pilot_agents/test_pilotAgentsDB.py b/diracx-db/tests/pilot_agents/test_pilotAgentsDB.py index 1f34bd26..e6661835 100644 --- a/diracx-db/tests/pilot_agents/test_pilotAgentsDB.py +++ b/diracx-db/tests/pilot_agents/test_pilotAgentsDB.py @@ -22,13 +22,10 @@ async def test_insert_and_select(pilot_agents_db: PilotAgentsDB): stamps = [f"stamp_{i}" for i in range(10)] stamp_dict = dict(zip(refs, stamps)) - pilot_id = await pilot_agents_db.addPilotReferences( + await pilot_agents_db.addPilotReferences( refs, "test_vo", grid_type="DIRAC", pilot_stamps=stamp_dict ) - assert pilot_id - assert pilot_id == list(range(1, 11)) - pilot_id = await pilot_agents_db.addPilotReferences( + await pilot_agents_db.addPilotReferences( refs, "test_vo", grid_type="DIRAC", pilot_stamps=None ) - assert pilot_id From 9953ddbd2a8156f02f7f49d35f6c1d296274d4c7 Mon Sep 17 00:00:00 2001 From: martynia Date: Wed, 9 Oct 2024 17:28:45 +0200 Subject: [PATCH 12/13] feat: use bulk insert with a stmt compile step --- diracx-db/src/diracx/db/sql/pilot_agents/db.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/db.py b/diracx-db/src/diracx/db/sql/pilot_agents/db.py index 02d46a5f..e32d834b 100644 --- a/diracx-db/src/diracx/db/sql/pilot_agents/db.py +++ b/diracx-db/src/diracx/db/sql/pilot_agents/db.py @@ -39,6 +39,6 @@ async def addPilotReferences( ] # Insert multiple rows in a single execute call - stmt = insert(PilotAgents) - await self.conn.execute(stmt, values) + stmt = insert(PilotAgents).values(values) + await self.conn.execute(stmt) return From 09475e7757c09c4f1b60f7cfebe7e810e63f4412 Mon Sep 17 00:00:00 2001 From: martynia Date: Thu, 10 Oct 2024 14:41:57 +0200 Subject: [PATCH 13/13] fix: change method name --- diracx-db/src/diracx/db/sql/pilot_agents/db.py | 4 +++- diracx-db/tests/pilot_agents/test_pilotAgentsDB.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/db.py b/diracx-db/src/diracx/db/sql/pilot_agents/db.py index e32d834b..b4f801b7 100644 --- a/diracx-db/src/diracx/db/sql/pilot_agents/db.py +++ b/diracx-db/src/diracx/db/sql/pilot_agents/db.py @@ -9,9 +9,11 @@ class PilotAgentsDB(BaseSQLDB): + """PilotAgentsDB class is a front-end to the PilotAgents Database.""" + metadata = PilotAgentsDBBase.metadata - async def addPilotReferences( + async def add_pilot_references( self, pilot_ref: list[str], vo: str, diff --git a/diracx-db/tests/pilot_agents/test_pilotAgentsDB.py b/diracx-db/tests/pilot_agents/test_pilotAgentsDB.py index e6661835..50829ecc 100644 --- a/diracx-db/tests/pilot_agents/test_pilotAgentsDB.py +++ b/diracx-db/tests/pilot_agents/test_pilotAgentsDB.py @@ -22,10 +22,10 @@ async def test_insert_and_select(pilot_agents_db: PilotAgentsDB): stamps = [f"stamp_{i}" for i in range(10)] stamp_dict = dict(zip(refs, stamps)) - await pilot_agents_db.addPilotReferences( + await pilot_agents_db.add_pilot_references( refs, "test_vo", grid_type="DIRAC", pilot_stamps=stamp_dict ) - await pilot_agents_db.addPilotReferences( + await pilot_agents_db.add_pilot_references( refs, "test_vo", grid_type="DIRAC", pilot_stamps=None )