Skip to content

Commit

Permalink
Merge pull request #356 from ryuwd/roneil-jobattr-update-fix
Browse files Browse the repository at this point in the history
Fix job attribute update to account for mismatching columns between rows
  • Loading branch information
aldbr authored Dec 19, 2024
2 parents 97b3c58 + ed019f1 commit 15a1926
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 7 deletions.
24 changes: 18 additions & 6 deletions diracx-db/src/diracx/db/sql/job/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any

from sqlalchemy import bindparam, delete, func, insert, select, update
from sqlalchemy import bindparam, case, delete, func, insert, select, update
from sqlalchemy.exc import IntegrityError, NoResultFound

if TYPE_CHECKING:
Expand Down Expand Up @@ -219,13 +219,25 @@ async def setJobAttributesBulk(self, jobData):
jobData[job_id].update(
{"LastUpdateTime": datetime.now(tz=timezone.utc)}
)
columns = set(key for attrs in jobData.values() for key in attrs.keys())
case_expressions = {
column: case(
*[
(Jobs.__table__.c.JobID == job_id, attrs[column])
for job_id, attrs in jobData.items()
if column in attrs
],
else_=getattr(Jobs.__table__.c, column), # Retain original value
)
for column in columns
}

await self.conn.execute(
Jobs.__table__.update().where(
Jobs.__table__.c.JobID == bindparam("b_JobID")
),
[{"b_JobID": job_id, **attrs} for job_id, attrs in jobData.items()],
stmt = (
Jobs.__table__.update()
.values(**case_expressions)
.where(Jobs.__table__.c.JobID.in_(jobData.keys()))
)
await self.conn.execute(stmt)

async def getJobJDL(self, job_id: int, original: bool = False) -> str:
from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import extractJDL
Expand Down
3 changes: 2 additions & 1 deletion diracx-db/src/diracx/db/sql/utils/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,11 +325,12 @@ def parse_jdl(job_id, job_jdl):
"failed": failed,
"success": {
job_id: {
"InputData": job_jdls[job_id],
"InputData": job_jdls.get(job_id, None),
**attribute_changes[job_id],
**set_status_result.model_dump(),
}
for job_id, set_status_result in set_job_status_result.success.items()
if job_id not in failed
},
}

Expand Down
74 changes: 74 additions & 0 deletions diracx-routers/tests/test_job_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,6 +934,80 @@ def test_insert_and_reschedule(normal_user_client: TestClient):
}


## test edge case for rescheduling


def test_reschedule_job_attr_update(normal_user_client: TestClient):
job_definitions = [TEST_JDL] * 15

r = normal_user_client.post("/api/jobs/jdl", json=job_definitions)
assert r.status_code == 200, r.json()
assert len(r.json()) == len(job_definitions)

submitted_job_ids = sorted([job_dict["JobID"] for job_dict in r.json()])

# Test /jobs/reschedule and
# test max_reschedule

max_resched = 3

fail_resched_ids = submitted_job_ids[0:5]
good_resched_ids = list(set(submitted_job_ids) - set(fail_resched_ids))

for i in range(max_resched):
r = normal_user_client.post(
"/api/jobs/reschedule",
params={"job_ids": fail_resched_ids},
)
assert r.status_code == 200, r.json()
result = r.json()
successful_results = result["success"]
for jid in fail_resched_ids:
assert str(jid) in successful_results, result
assert successful_results[str(jid)]["Status"] == JobStatus.RECEIVED
assert successful_results[str(jid)]["MinorStatus"] == "Job Rescheduled"
assert successful_results[str(jid)]["RescheduleCounter"] == i + 1

for i in range(max_resched):
r = normal_user_client.post(
"/api/jobs/reschedule",
params={"job_ids": submitted_job_ids},
)
assert r.status_code == 200, r.json()
result = r.json()
successful_results = result["success"]
failed_results = result["failed"]
for jid in good_resched_ids:
assert str(jid) in successful_results, result
assert successful_results[str(jid)]["Status"] == JobStatus.RECEIVED
assert successful_results[str(jid)]["MinorStatus"] == "Job Rescheduled"
assert successful_results[str(jid)]["RescheduleCounter"] == i + 1
for jid in fail_resched_ids:
assert str(jid) in failed_results, result
# assert successful_results[jid]["Status"] == JobStatus.RECEIVED
# assert successful_results[jid]["MinorStatus"] == "Job Rescheduled"
# assert successful_results[jid]["RescheduleCounter"] == i + 1

r = normal_user_client.post(
"/api/jobs/reschedule",
params={"job_ids": submitted_job_ids},
)
assert (
r.status_code != 200
), f"Rescheduling more than {max_resched} times should have failed by now {r.json()}"
assert r.json() == {
"detail": {
"success": [],
"failed": {
str(i): {
"detail": f"Maximum number of reschedules exceeded ({max_resched})"
}
for i in good_resched_ids + fail_resched_ids
},
}
}


# Test delete job


Expand Down

0 comments on commit 15a1926

Please sign in to comment.