Skip to content

Commit

Permalink
Merge pull request #72 from natthan-pigoux/feature/npigoux-reschedule
Browse files Browse the repository at this point in the history
Feature rescheduleJob method
  • Loading branch information
chrisburr authored Oct 19, 2023
2 parents 162e37d + 1d21729 commit 2cb5a85
Show file tree
Hide file tree
Showing 4 changed files with 250 additions and 1 deletion.
5 changes: 5 additions & 0 deletions src/diracx/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ class JobStatus(StrEnum):
RESCHEDULED = "Rescheduled"


class JobMinorStatus(str, Enum):
MAX_RESCHEDULING = "Maximum of reschedulings reached"
RESCHEDULED = "Job Rescheduled"


class JobStatusUpdate(BaseModel):
status: JobStatus | None = Field(
default=None,
Expand Down
166 changes: 165 additions & 1 deletion src/diracx/db/sql/jobs/db.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import logging
import time
from datetime import datetime, timezone
from typing import Any
Expand All @@ -8,7 +9,14 @@
from sqlalchemy.exc import NoResultFound

from diracx.core.exceptions import InvalidQueryError
from diracx.core.models import JobStatus, JobStatusReturn, LimitedJobStatusReturn
from diracx.core.models import (
JobMinorStatus,
JobStatus,
JobStatusReturn,
LimitedJobStatusReturn,
ScalarSearchOperator,
ScalarSearchSpec,
)

from ..utils import BaseSQLDB, apply_search_filters
from .schema import (
Expand Down Expand Up @@ -40,6 +48,11 @@ class JobDB(BaseSQLDB):
# to find a way to make it dynamic
jdl2DBParameters = ["JobName", "JobType", "JobGroup"]

# Upper bound on a job's reschedule counter: rescheduleJob refuses to
# reschedule once the incremented RescheduleCounter would exceed this value.
# TODO: set maxRescheduling value from CS
# maxRescheduling = self.getCSOption("MaxRescheduling", 3)
# For now:
maxRescheduling = 3

async def summary(self, group_by, search) -> list[dict[str, str | int]]:
columns = _get_columns(Jobs.__table__, group_by)

Expand Down Expand Up @@ -155,6 +168,20 @@ async def setJobJDL(self, job_id, jdl):
)
await self.conn.execute(stmt)

async def getJobJDL(self, job_id: int, original: bool = False) -> str:
    """Fetch the JDL stored for a job.

    :param job_id: ID of the job whose JDL is requested.
    :param original: when true read the ``OriginalJDL`` column, otherwise
        read the ``JDL`` column.
    :return: the stored value passed through DIRAC's ``extractJDL``; a falsy
        stored value is returned unchanged.

    ``scalar_one()`` is used, so the query must match exactly one row.
    """
    from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import extractJDL

    # Pick the column first so that a single statement covers both cases.
    jdl_column = JobJDLs.OriginalJDL if original else JobJDLs.JDL
    query = select(jdl_column).where(JobJDLs.JobID == job_id)

    rows = await self.conn.execute(query)
    stored_jdl = rows.scalar_one()

    return extractJDL(stored_jdl) if stored_jdl else stored_jdl

async def insert(
self,
jdl,
Expand Down Expand Up @@ -258,6 +285,143 @@ async def insert(
"TimeStamp": datetime.now(tz=timezone.utc),
}

async def rescheduleJob(self, job_id) -> dict[str, Any]:
    """Reschedule given job.

    Looks the job up, checks that it is verified and has not exceeded
    ``maxRescheduling``, rebuilds its JDL from the original JDL and writes the
    reset attributes (status ``Received`` / minor status ``Job Rescheduled``)
    back to the database.

    :param job_id: ID of the job to reschedule.
    :return: dict with the keys ``JobID``, ``InputData``,
        ``RescheduleCounter``, ``Status`` and ``MinorStatus``.
    :raises ValueError: if the job is not found, is not verified, has reached
        the maximum number of reschedulings, or if preparing the job raises
        ``SErrorException``.
    """
    from DIRAC.Core.Utilities.ClassAd.ClassAdLight import ClassAd
    from DIRAC.Core.Utilities.ReturnValues import SErrorException

    result = await self.search(
        parameters=[
            "Status",
            "MinorStatus",
            "VerifiedFlag",
            "RescheduleCounter",
            "Owner",
            "OwnerGroup",
        ],
        search=[
            ScalarSearchSpec(
                parameter="JobID", operator=ScalarSearchOperator.EQUAL, value=job_id
            )
        ],
        sorts=[],
    )
    if not result:
        raise ValueError(f"Job {job_id} not found.")

    jobAttrs = result[0]

    if "VerifiedFlag" not in jobAttrs:
        raise ValueError(f"Job {job_id} not found in the system")

    if not jobAttrs["VerifiedFlag"]:
        raise ValueError(
            f"Job {job_id} not Verified: Status {jobAttrs['Status']}, Minor Status: {jobAttrs['MinorStatus']}"
        )

    reschedule_counter = int(jobAttrs["RescheduleCounter"]) + 1

    # TODO: update maxRescheduling:
    # self.maxRescheduling = self.getCSOption("MaxRescheduling", self.maxRescheduling)

    if reschedule_counter > self.maxRescheduling:
        logging.warning(f"Job {job_id}: Maximum number of reschedulings is reached.")
        # Must be awaited: without it the coroutine is created but never run,
        # and the job is never marked as failed.
        # NOTE(review): setJobAttributes and _checkAndPrepareJob are assumed to
        # be coroutines like setJobJDL - confirm against their definitions.
        await self.setJobAttributes(
            job_id,
            {
                "Status": JobStatus.FAILED,
                "MinorStatus": JobMinorStatus.MAX_RESCHEDULING,
            },
        )
        raise ValueError(
            f"Maximum number of reschedulings is reached: {self.maxRescheduling}"
        )

    new_job_attributes = {"RescheduleCounter": reschedule_counter}

    # TODO: get the job parameters from JobMonitoringClient
    # result = JobMonitoringClient().getJobParameters(jobID)
    # if result["OK"]:
    #     parDict = result["Value"]
    #     for key, value in parDict.get(jobID, {}).items():
    #         result = self.setAtticJobParameter(jobID, key, value, rescheduleCounter - 1)
    #         if not result["OK"]:
    #             break

    # TODO: IF we keep JobParameters and OptimizerParameters: Delete job in those tables.
    # await self.delete_job_parameters(job_id)
    # await self.delete_job_optimizer_parameters(job_id)

    job_jdl = await self.getJobJDL(job_id, original=True)
    # ClassAd expects the JDL to be wrapped in square brackets.
    if not job_jdl.strip().startswith("["):
        job_jdl = f"[{job_jdl}]"

    classAdJob = ClassAd(job_jdl)
    classAdReq = ClassAd("[]")
    retVal = {}
    retVal["JobID"] = job_id

    classAdJob.insertAttributeInt("JobID", job_id)

    try:
        await self._checkAndPrepareJob(
            job_id,
            classAdJob,
            classAdReq,
            jobAttrs["Owner"],
            jobAttrs["OwnerGroup"],
            new_job_attributes,
            classAdJob.getAttributeString("VirtualOrganization"),
        )
    except SErrorException as e:
        raise ValueError(e) from e

    # Persist the incremented counter: jobAttrs still holds the value read
    # from the database, so without this the stored counter never grows and
    # the maxRescheduling limit can never be reached.
    jobAttrs["RescheduleCounter"] = reschedule_counter

    priority = classAdJob.getAttributeInt("Priority")
    if priority is None:
        priority = 0
    jobAttrs["UserPriority"] = priority

    siteList = classAdJob.getListFromExpression("Site")
    if not siteList:
        site = "ANY"
    elif len(siteList) > 1:
        site = "Multiple"
    else:
        site = siteList[0]

    jobAttrs["Site"] = site

    jobAttrs["Status"] = JobStatus.RECEIVED

    jobAttrs["MinorStatus"] = JobMinorStatus.RESCHEDULED

    jobAttrs["ApplicationStatus"] = "Unknown"

    jobAttrs["ApplicationNumStatus"] = 0

    # Single UTC timestamp for both fields; tzinfo is stripped so the string
    # keeps the same naive format that datetime.utcnow() produced.
    now = str(datetime.now(tz=timezone.utc).replace(tzinfo=None))

    jobAttrs["LastUpdateTime"] = now

    jobAttrs["RescheduleTime"] = now

    reqJDL = classAdReq.asJDL()
    classAdJob.insertAttributeInt("JobRequirements", reqJDL)

    jobJDL = classAdJob.asJDL()

    # Replace the JobID placeholder if any
    jobJDL = jobJDL.replace("%j", str(job_id))

    # setJobJDL is a coroutine: it has to be awaited for the JDL to be stored.
    await self.setJobJDL(job_id, jobJDL)

    await self.setJobAttributes(job_id, jobAttrs)

    retVal["InputData"] = classAdJob.lookupAttribute("InputData")
    retVal["RescheduleCounter"] = reschedule_counter
    retVal["Status"] = JobStatus.RECEIVED
    retVal["MinorStatus"] = JobMinorStatus.RESCHEDULED

    return retVal

async def get_job_status(self, job_id: int) -> LimitedJobStatusReturn:
stmt = select(Jobs.Status, Jobs.MinorStatus, Jobs.ApplicationStatus).where(
Jobs.JobID == job_id
Expand Down
64 changes: 64 additions & 0 deletions src/diracx/routers/job_manager/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,70 @@ async def get_job_status_history_bulk(
return {job_id: status for job_id, status in zip(job_ids, result)}


@router.post("/reschedule")
async def reschedule_bulk_jobs(
    job_ids: Annotated[list[int], Query()],
    job_db: JobDB,
    job_logging_db: JobLoggingDB,
    user_info: Annotated[UserInfo, Depends(verify_dirac_access_token)],
):
    """Reschedule several jobs and record the new status of each one.

    For every requested job the job is rescheduled, its status is read back
    and a record is inserted in the job logging DB.

    :return: the list of job IDs for which ``rescheduleJob`` returned a
        truthy result.
    :raises HTTPException: 404 if rescheduling a job raises ``ValueError``
        (same mapping as ``reschedule_single_job``) or if the job status
        cannot be found afterwards.
    """
    rescheduled_jobs = []
    # TODO: Joblist Policy:
    # validJobList, invalidJobList, nonauthJobList, ownerJobList = self.jobPolicy.evaluateJobRights(
    #    jobList, RIGHT_RESCHEDULE
    # )
    # For the moment all jobs are valid:
    valid_job_list = job_ids
    for job_id in valid_job_list:
        # TODO: delete job in TaskQueueDB
        # self.taskQueueDB.deleteJob(jobID)
        try:
            # rescheduleJob is a coroutine: it must be awaited, otherwise it
            # never runs and `result` is an always-truthy coroutine object.
            result = await job_db.rescheduleJob(job_id)
        except ValueError as e:
            raise HTTPException(
                status_code=HTTPStatus.NOT_FOUND, detail=str(e)
            ) from e

        try:
            res_status = await job_db.get_job_status(job_id)
        except NoResultFound as e:
            raise HTTPException(
                status_code=HTTPStatus.NOT_FOUND, detail=f"Job {job_id} not found"
            ) from e

        initial_status = res_status.status
        initial_minor_status = res_status.minor_status

        await job_logging_db.insert_record(
            int(job_id),
            initial_status,
            initial_minor_status,
            "Unknown",
            datetime.now(timezone.utc),
            "JobManager",
        )
        if result:
            rescheduled_jobs.append(job_id)
    # To uncomment when jobPolicy is setup:
    # if invalid_job_list or non_auth_job_list:
    #     logging.error("Some jobs failed to reschedule")
    #     if invalid_job_list:
    #         logging.info(f"Invalid jobs: {invalid_job_list}")
    #     if non_auth_job_list:
    #         logging.info(f"Non authorized jobs: {nonauthJobList}")

    # TODO: send jobs to OptimizationMind
    # self.__sendJobsToOptimizationMind(validJobList)
    return rescheduled_jobs


@router.post("/{job_id}/reschedule")
async def reschedule_single_job(
    job_id: int,
    job_db: JobDB,
    user_info: Annotated[UserInfo, Depends(verify_dirac_access_token)],
):
    """Reschedule one job and return what ``JobDB.rescheduleJob`` returns.

    :raises HTTPException: 404 carrying the error text whenever
        ``rescheduleJob`` raises ``ValueError``.
    """
    try:
        return await job_db.rescheduleJob(job_id)
    except ValueError as err:
        raise HTTPException(
            status_code=HTTPStatus.NOT_FOUND, detail=str(err)
        ) from err


EXAMPLE_SEARCHES = {
"Show all": {
"summary": "Show all",
Expand Down
16 changes: 16 additions & 0 deletions tests/routers/test_job_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,3 +591,19 @@ def test_set_job_status_with_invalid_job_id(normal_user_client: TestClient):
# Assert
assert r.status_code == 404, r.json()
assert r.json() == {"detail": "Job 999999999 not found"}


def test_insert_and_reschedule(normal_user_client: TestClient):
    """Submit one job, then reschedule it through the bulk endpoint."""
    jdls = [TEST_JDL]

    # Submit the job(s) and collect the IDs assigned to them.
    submission = normal_user_client.post("/jobs/", json=jdls)
    assert submission.status_code == 200, submission.json()
    submitted = submission.json()
    assert len(submitted) == len(jdls)

    submitted_job_ids = sorted(entry["JobID"] for entry in submitted)

    # Test /jobs/reschedule
    reschedule = normal_user_client.post(
        "/jobs/reschedule",
        params={"job_ids": submitted_job_ids},
    )
    assert reschedule.status_code == 200, reschedule.json()

0 comments on commit 2cb5a85

Please sign in to comment.