From fb55adc81be7c25a93a261b06a80ad0976895dd9 Mon Sep 17 00:00:00 2001 From: Ryunosuke O'Neil Date: Thu, 19 Dec 2024 15:05:03 +0100 Subject: [PATCH 1/3] Fix job attribute update to account for mismatching columns between rows to be updated --- diracx-db/src/diracx/db/sql/job/db.py | 21 +++++++++++++++------ diracx-db/src/diracx/db/sql/utils/job.py | 4 ++-- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/diracx-db/src/diracx/db/sql/job/db.py b/diracx-db/src/diracx/db/sql/job/db.py index 145b4eb6..fb677fe9 100644 --- a/diracx-db/src/diracx/db/sql/job/db.py +++ b/diracx-db/src/diracx/db/sql/job/db.py @@ -3,7 +3,7 @@ from datetime import datetime, timezone from typing import TYPE_CHECKING, Any -from sqlalchemy import bindparam, delete, func, insert, select, update +from sqlalchemy import bindparam, delete, func, insert, select, update, case from sqlalchemy.exc import IntegrityError, NoResultFound if TYPE_CHECKING: @@ -219,13 +219,22 @@ async def setJobAttributesBulk(self, jobData): jobData[job_id].update( {"LastUpdateTime": datetime.now(tz=timezone.utc)} ) + columns = set(key for attrs in jobData.values() for key in attrs.keys()) + case_expressions = { + column: case( + *[ + (Jobs.__table__.c.JobID == job_id, attrs[column]) + for job_id, attrs in jobData.items() if column in attrs + ], + else_=getattr(Jobs.__table__.c, column) # Retain original value + ) + for column in columns + } - await self.conn.execute( - Jobs.__table__.update().where( - Jobs.__table__.c.JobID == bindparam("b_JobID") - ), - [{"b_JobID": job_id, **attrs} for job_id, attrs in jobData.items()], + stmt = Jobs.__table__.update().values(**case_expressions).where( + Jobs.__table__.c.JobID.in_(jobData.keys()) ) + await self.conn.execute(stmt) async def getJobJDL(self, job_id: int, original: bool = False) -> str: from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import extractJDL diff --git a/diracx-db/src/diracx/db/sql/utils/job.py b/diracx-db/src/diracx/db/sql/utils/job.py index 0032ff1d..544e3773 100644 --- a/diracx-db/src/diracx/db/sql/utils/job.py +++ b/diracx-db/src/diracx/db/sql/utils/job.py @@ -325,11 +325,11 @@ def parse_jdl(job_id, job_jdl): "failed": failed, "success": { job_id: { - "InputData": job_jdls[job_id], + "InputData": job_jdls.get(job_id, None), **attribute_changes[job_id], **set_status_result.model_dump(), } - for job_id, set_status_result in set_job_status_result.success.items() + for job_id, set_status_result in set_job_status_result.success.items() if job_id not in failed }, } From 78220f1485a625780df255ad876a0b985ea9a791 Mon Sep 17 00:00:00 2001 From: Ryunosuke O'Neil Date: Thu, 19 Dec 2024 15:06:52 +0100 Subject: [PATCH 2/3] Pre-commit --- diracx-db/src/diracx/db/sql/job/db.py | 13 ++++++++----- diracx-db/src/diracx/db/sql/utils/job.py | 3 ++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/diracx-db/src/diracx/db/sql/job/db.py b/diracx-db/src/diracx/db/sql/job/db.py index fb677fe9..7817bb39 100644 --- a/diracx-db/src/diracx/db/sql/job/db.py +++ b/diracx-db/src/diracx/db/sql/job/db.py @@ -3,7 +3,7 @@ from datetime import datetime, timezone from typing import TYPE_CHECKING, Any -from sqlalchemy import bindparam, delete, func, insert, select, update, case +from sqlalchemy import bindparam, case, delete, func, insert, select, update from sqlalchemy.exc import IntegrityError, NoResultFound if TYPE_CHECKING: @@ -224,15 +224,18 @@ async def setJobAttributesBulk(self, jobData): column: case( *[ (Jobs.__table__.c.JobID == job_id, attrs[column]) - for job_id, attrs in jobData.items() if column in attrs + for job_id, attrs in jobData.items() + if column in attrs ], - else_=getattr(Jobs.__table__.c, column) # Retain original value + else_=getattr(Jobs.__table__.c, column), # Retain original value ) for column in columns } - stmt = Jobs.__table__.update().values(**case_expressions).where( - Jobs.__table__.c.JobID.in_(jobData.keys()) + stmt = ( + Jobs.__table__.update() + .values(**case_expressions) + .where(Jobs.__table__.c.JobID.in_(jobData.keys())) ) await self.conn.execute(stmt) diff --git a/diracx-db/src/diracx/db/sql/utils/job.py b/diracx-db/src/diracx/db/sql/utils/job.py index 544e3773..16ed5ba7 100644 --- a/diracx-db/src/diracx/db/sql/utils/job.py +++ b/diracx-db/src/diracx/db/sql/utils/job.py @@ -329,7 +329,8 @@ def parse_jdl(job_id, job_jdl): **attribute_changes[job_id], **set_status_result.model_dump(), } - for job_id, set_status_result in set_job_status_result.success.items() if job_id not in failed + for job_id, set_status_result in set_job_status_result.success.items() + if job_id not in failed }, } From ed019f15fcba43639eecd9d7ddce822f84f281f0 Mon Sep 17 00:00:00 2001 From: Ryunosuke O'Neil Date: Thu, 19 Dec 2024 15:49:53 +0100 Subject: [PATCH 3/3] Added test for this case --- diracx-routers/tests/test_job_manager.py | 74 ++++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/diracx-routers/tests/test_job_manager.py b/diracx-routers/tests/test_job_manager.py index 5330ef63..59b1f6c0 100644 --- a/diracx-routers/tests/test_job_manager.py +++ b/diracx-routers/tests/test_job_manager.py @@ -934,6 +934,80 @@ def test_insert_and_reschedule(normal_user_client: TestClient): } +## test edge case for rescheduling + + +def test_reschedule_job_attr_update(normal_user_client: TestClient): + job_definitions = [TEST_JDL] * 15 + + r = normal_user_client.post("/api/jobs/jdl", json=job_definitions) + assert r.status_code == 200, r.json() + assert len(r.json()) == len(job_definitions) + + submitted_job_ids = sorted([job_dict["JobID"] for job_dict in r.json()]) + + # Test /jobs/reschedule and + # test max_reschedule + + max_resched = 3 + + fail_resched_ids = submitted_job_ids[0:5] + good_resched_ids = list(set(submitted_job_ids) - set(fail_resched_ids)) + + for i in range(max_resched): + r = normal_user_client.post( + "/api/jobs/reschedule", + params={"job_ids": fail_resched_ids}, + ) + assert r.status_code == 200, r.json() + result = r.json() + successful_results = result["success"] + for jid in fail_resched_ids: + assert str(jid) in successful_results, result + assert successful_results[str(jid)]["Status"] == JobStatus.RECEIVED + assert successful_results[str(jid)]["MinorStatus"] == "Job Rescheduled" + assert successful_results[str(jid)]["RescheduleCounter"] == i + 1 + + for i in range(max_resched): + r = normal_user_client.post( + "/api/jobs/reschedule", + params={"job_ids": submitted_job_ids}, + ) + assert r.status_code == 200, r.json() + result = r.json() + successful_results = result["success"] + failed_results = result["failed"] + for jid in good_resched_ids: + assert str(jid) in successful_results, result + assert successful_results[str(jid)]["Status"] == JobStatus.RECEIVED + assert successful_results[str(jid)]["MinorStatus"] == "Job Rescheduled" + assert successful_results[str(jid)]["RescheduleCounter"] == i + 1 + for jid in fail_resched_ids: + assert str(jid) in failed_results, result + # assert successful_results[jid]["Status"] == JobStatus.RECEIVED + # assert successful_results[jid]["MinorStatus"] == "Job Rescheduled" + # assert successful_results[jid]["RescheduleCounter"] == i + 1 + + r = normal_user_client.post( + "/api/jobs/reschedule", + params={"job_ids": submitted_job_ids}, + ) + assert ( + r.status_code != 200 + ), f"Rescheduling more than {max_resched} times should have failed by now {r.json()}" + assert r.json() == { + "detail": { + "success": [], + "failed": { + str(i): { + "detail": f"Maximum number of reschedules exceeded ({max_resched})" + } + for i in good_resched_ids + fail_resched_ids + }, + } + } + + # Test delete job