Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement PilotAgents schema #292

Merged
merged 13 commits into from
Oct 15, 2024
1 change: 1 addition & 0 deletions diracx-db/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ testing = [
AuthDB = "diracx.db.sql:AuthDB"
JobDB = "diracx.db.sql:JobDB"
JobLoggingDB = "diracx.db.sql:JobLoggingDB"
PilotAgentsDB = "diracx.db.sql:PilotAgentsDB"
SandboxMetadataDB = "diracx.db.sql:SandboxMetadataDB"
TaskQueueDB = "diracx.db.sql:TaskQueueDB"

Expand Down
10 changes: 9 additions & 1 deletion diracx-db/src/diracx/db/sql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
from __future__ import annotations

__all__ = ("AuthDB", "JobDB", "JobLoggingDB", "SandboxMetadataDB", "TaskQueueDB")
__all__ = (
"AuthDB",
"JobDB",
"JobLoggingDB",
"PilotAgentsDB",
"SandboxMetadataDB",
"TaskQueueDB",
)

from .auth.db import AuthDB
from .job.db import JobDB
from .job_logging.db import JobLoggingDB
from .pilot_agents.db import PilotAgentsDB
from .sandbox_metadata.db import SandboxMetadataDB
from .task_queue.db import TaskQueueDB
29 changes: 1 addition & 28 deletions diracx-db/src/diracx/db/sql/job/schema.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import sqlalchemy.types as types
from sqlalchemy import (
DateTime,
Enum,
Expand All @@ -10,37 +9,11 @@
)
from sqlalchemy.orm import declarative_base

from ..utils import Column, NullColumn
from ..utils import Column, EnumBackedBool, NullColumn

JobDBBase = declarative_base()


class EnumBackedBool(types.TypeDecorator):
"""Maps a ``EnumBackedBool()`` column to True/False in Python."""

impl = types.Enum
cache_ok: bool = True

def __init__(self) -> None:
super().__init__("True", "False")

def process_bind_param(self, value, dialect) -> str:
if value is True:
return "True"
elif value is False:
return "False"
else:
raise NotImplementedError(value, dialect)

def process_result_value(self, value, dialect) -> bool:
if value == "True":
return True
elif value == "False":
return False
else:
raise NotImplementedError(f"Unknown {value=}")


class Jobs(JobDBBase):
__tablename__ = "Jobs"

Expand Down
Empty file.
46 changes: 46 additions & 0 deletions diracx-db/src/diracx/db/sql/pilot_agents/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from __future__ import annotations

from datetime import datetime, timezone

from sqlalchemy import insert

from ..utils import BaseSQLDB
from .schema import PilotAgents, PilotAgentsDBBase


class PilotAgentsDB(BaseSQLDB):
"""PilotAgentsDB class is a front-end to the PilotAgents Database."""

metadata = PilotAgentsDBBase.metadata

async def add_pilot_references(
self,
pilot_ref: list[str],
vo: str,
grid_type: str = "DIRAC",
pilot_stamps: dict | None = None,
) -> None:

if pilot_stamps is None:
pilot_stamps = {}

now = datetime.now(tz=timezone.utc)

# Prepare the list of dictionaries for bulk insertion
values = [
{
"PilotJobReference": ref,
"VO": vo,
"GridType": grid_type,
"SubmissionTime": now,
"LastUpdateTime": now,
"Status": "Submitted",
"PilotStamp": pilot_stamps.get(ref, ""),
}
for ref in pilot_ref
]

# Insert multiple rows in a single execute call
stmt = insert(PilotAgents).values(values)
await self.conn.execute(stmt)
return
58 changes: 58 additions & 0 deletions diracx-db/src/diracx/db/sql/pilot_agents/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from sqlalchemy import (
DateTime,
Double,
Index,
Integer,
String,
Text,
)
from sqlalchemy.orm import declarative_base

from ..utils import Column, EnumBackedBool, NullColumn

PilotAgentsDBBase = declarative_base()


class PilotAgents(PilotAgentsDBBase):
__tablename__ = "PilotAgents"

PilotID = Column("PilotID", Integer, autoincrement=True, primary_key=True)
InitialJobID = Column("InitialJobID", Integer, default=0)
CurrentJobID = Column("CurrentJobID", Integer, default=0)
PilotJobReference = Column("PilotJobReference", String(255), default="Unknown")
PilotStamp = Column("PilotStamp", String(32), default="")
DestinationSite = Column("DestinationSite", String(128), default="NotAssigned")
Queue = Column("Queue", String(128), default="Unknown")
GridSite = Column("GridSite", String(128), default="Unknown")
VO = Column("VO", String(128))
GridType = Column("GridType", String(32), default="LCG")
BenchMark = Column("BenchMark", Double, default=0.0)
SubmissionTime = NullColumn("SubmissionTime", DateTime)
LastUpdateTime = NullColumn("LastUpdateTime", DateTime)
Status = Column("Status", String(32), default="Unknown")
StatusReason = Column("StatusReason", String(255), default="Unknown")
AccountingSent = Column("AccountingSent", EnumBackedBool(), default=False)

__table_args__ = (
Index("PilotJobReference", "PilotJobReference"),
Index("Status", "Status"),
Index("Statuskey", "GridSite", "DestinationSite", "Status"),
)


class JobToPilotMapping(PilotAgentsDBBase):
__tablename__ = "JobToPilotMapping"

PilotID = Column("PilotID", Integer, primary_key=True)
JobID = Column("JobID", Integer, primary_key=True)
StartTime = Column("StartTime", DateTime)

__table_args__ = (Index("JobID", "JobID"), Index("PilotID", "PilotID"))


class PilotOutput(PilotAgentsDBBase):
__tablename__ = "PilotOutput"

PilotID = Column("PilotID", Integer, primary_key=True)
StdOutput = Column("StdOutput", Text)
StdError = Column("StdError", Text)
27 changes: 27 additions & 0 deletions diracx-db/src/diracx/db/sql/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from functools import partial
from typing import TYPE_CHECKING, Self, cast

import sqlalchemy.types as types
from pydantic import TypeAdapter
from sqlalchemy import Column as RawColumn
from sqlalchemy import DateTime, Enum, MetaData, select
Expand Down Expand Up @@ -128,6 +129,32 @@ def EnumColumn(enum_type, **kwargs):
return Column(Enum(enum_type, native_enum=False, length=16), **kwargs)


class EnumBackedBool(types.TypeDecorator):
"""Maps a ``EnumBackedBool()`` column to True/False in Python."""

impl = types.Enum
cache_ok: bool = True

def __init__(self) -> None:
super().__init__("True", "False")

def process_bind_param(self, value, dialect) -> str:
if value is True:
return "True"
elif value is False:
return "False"
else:
raise NotImplementedError(value, dialect)

def process_result_value(self, value, dialect) -> bool:
if value == "True":
return True
elif value == "False":
return False
else:
raise NotImplementedError(f"Unknown {value=}")


class SQLDBError(Exception):
pass

Expand Down
Empty file.
31 changes: 31 additions & 0 deletions diracx-db/tests/pilot_agents/test_pilotAgentsDB.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from __future__ import annotations

import pytest

from diracx.db.sql.pilot_agents.db import PilotAgentsDB


@pytest.fixture
async def pilot_agents_db(tmp_path) -> PilotAgentsDB:
agents_db = PilotAgentsDB("sqlite+aiosqlite:///:memory:")
async with agents_db.engine_context():
async with agents_db.engine.begin() as conn:
await conn.run_sync(agents_db.metadata.create_all)
yield agents_db


async def test_insert_and_select(pilot_agents_db: PilotAgentsDB):

async with pilot_agents_db as pilot_agents_db:
# Add a pilot reference
refs = [f"ref_{i}" for i in range(10)]
stamps = [f"stamp_{i}" for i in range(10)]
stamp_dict = dict(zip(refs, stamps))

await pilot_agents_db.add_pilot_references(
refs, "test_vo", grid_type="DIRAC", pilot_stamps=stamp_dict
)

await pilot_agents_db.add_pilot_references(
refs, "test_vo", grid_type="DIRAC", pilot_stamps=None
)
3 changes: 3 additions & 0 deletions diracx-routers/src/diracx/routers/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"JobLoggingDB",
"SandboxMetadataDB",
"TaskQueueDB",
"PilotAgentsDB",
"add_settings_annotation",
"AvailableSecurityProperties",
)
Expand All @@ -23,6 +24,7 @@
from diracx.db.sql import AuthDB as _AuthDB
from diracx.db.sql import JobDB as _JobDB
from diracx.db.sql import JobLoggingDB as _JobLoggingDB
from diracx.db.sql import PilotAgentsDB as _PilotAgentsDB
from diracx.db.sql import SandboxMetadataDB as _SandboxMetadataDB
from diracx.db.sql import TaskQueueDB as _TaskQueueDB

Expand All @@ -38,6 +40,7 @@ def add_settings_annotation(cls: T) -> T:
AuthDB = Annotated[_AuthDB, Depends(_AuthDB.transaction)]
JobDB = Annotated[_JobDB, Depends(_JobDB.transaction)]
JobLoggingDB = Annotated[_JobLoggingDB, Depends(_JobLoggingDB.transaction)]
PilotAgentsDB = Annotated[_PilotAgentsDB, Depends(_PilotAgentsDB.transaction)]
SandboxMetadataDB = Annotated[
_SandboxMetadataDB, Depends(_SandboxMetadataDB.transaction)
]
Expand Down