Skip to content

Commit

Permalink
Fixing tests...
Browse files Browse the repository at this point in the history
  • Loading branch information
ryuwd committed Dec 14, 2024
1 parent 97271ca commit 46b9a4d
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 84 deletions.
2 changes: 1 addition & 1 deletion diracx-cli/tests/test_jobs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

from io import StringIO
import json
import os
import tempfile
from io import StringIO

import pytest
from pytest import raises
Expand Down
12 changes: 8 additions & 4 deletions diracx-db/src/diracx/db/sql/job/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,12 @@ def _get_columns(table, parameters):
columns = [c for c in columns if c.name in parameters]
return columns


async def get_inserted_job_ids(conn, table, rows):
# TODO: We are assuming contiguous inserts for MySQL. Is that the correct thing? Should we be stricter
# about enforcing that with an explicit transaction handling?
# Retrieve the first inserted ID

if conn.engine.name == "mysql":
# Bulk insert for MySQL
await conn.execute(table.insert(), rows)
Expand Down Expand Up @@ -216,7 +217,9 @@ async def setJobAttributesBulk(self, jobData):
)

await self.conn.execute(
Jobs.__table__.update().where(Jobs.__table__.c.JobID == bindparam("b_JobID")),
Jobs.__table__.update().where(
Jobs.__table__.c.JobID == bindparam("b_JobID")
),
[{"b_JobID": job_id, **attrs} for job_id, attrs in jobData.items()],
)

Expand Down Expand Up @@ -309,7 +312,6 @@ async def insert_bulk(
print(jobManifest_.__dict__)
jobManifest_.setOption("JobID", job_id)


# 2.- Check JDL and Prepare DIRAC JDL
jobJDL = jobManifest_.dumpAsJDL()

Expand Down Expand Up @@ -366,7 +368,9 @@ async def insert_bulk(
{"JobID": job_id, "LFN": lfn} for lfn in inputData if lfn
]
await self.conn.execute(
JobJDLs.__table__.update().where(JobJDLs.__table__.c.JobID == bindparam("b_JobID")),
JobJDLs.__table__.update().where(
JobJDLs.__table__.c.JobID == bindparam("b_JobID")
),
jdls_to_update,
)

Expand Down
12 changes: 8 additions & 4 deletions diracx-db/src/diracx/db/sql/job_logging/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,16 @@ def get_epoc(date):
+ date.microsecond / 1000000.0
- MAGIC_EPOC_NUMBER
)

# First, fetch the maximum SeqNums for the given job_ids
seqnum_stmt = select(LoggingInfo.JobID, func.coalesce(func.max(LoggingInfo.SeqNum) + 1, 1)).where(
LoggingInfo.JobID.in_([record.job_id for record in records])
).group_by(LoggingInfo.JobID)
seqnum_stmt = (
select(
LoggingInfo.JobID, func.coalesce(func.max(LoggingInfo.SeqNum) + 1, 1)
)
.where(LoggingInfo.JobID.in_([record.job_id for record in records]))
.group_by(LoggingInfo.JobID)
)

# seqnum_stmt = select(LoggingInfo.JobID, func.coalesce(func.max(LoggingInfo.SeqNum) + 1, 1)).where(LoggingInfo.JobID.in_([record.job_id for record in records]))
seqnum = {jid: seqnum for jid, seqnum in (await self.conn.execute(seqnum_stmt))}
# IF a seqnum is not found, then assume it does not exist and the first sequence number is 1.

Expand Down
25 changes: 11 additions & 14 deletions diracx-db/src/diracx/db/sql/utils/job_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ async def set_job_status_bulk(
*,
force: bool = False,
additional_attributes: dict[int, dict[str, str]] = {},
) -> dict[str, dict[int, Union[SetJobStatusReturn, dict[str, str]] ]]:
) -> dict[str, dict[int, Union[SetJobStatusReturn, dict[str, str]]]]:
"""Set various status fields for job specified by its jobId.
Set only the last status in the JobDB, updating all the status
logging information in the JobLoggingDB. The status dict has datetime
Expand Down Expand Up @@ -261,15 +261,14 @@ async def set_job_status_bulk(
}
print(f"Status changes requested {status_dicts=}")


# search all jobs at once
_, results = await job_db.search(
parameters=["Status", "StartExecTime", "EndExecTime", "JobID"],
search=[
{
"parameter": "JobID",
"operator": VectorSearchOperator.IN,
"values": set(status_changes.keys()),
"values": list(set(status_changes.keys())),
}
],
sorts=[],
Expand All @@ -282,15 +281,16 @@ async def set_job_status_bulk(
}

found_jobs = set(int(res["JobID"]) for res in results)
failed.update({
int(nf_job_id): {"detail": "Not found"}
for nf_job_id in set(status_changes.keys()) - found_jobs
})
failed.update(
{
int(nf_job_id): {"detail": "Not found"}
for nf_job_id in set(status_changes.keys()) - found_jobs
}
)
# Get the latest time stamps of major status updates
wms_time_stamps = await job_logging_db.get_wms_time_stamps_bulk(found_jobs)
print("timestamps", wms_time_stamps)


for res in results:
job_id = int(res["JobID"])
currentStatus = res["Status"]
Expand Down Expand Up @@ -318,7 +318,7 @@ async def set_job_status_bulk(

print("updateTimes", updateTimes, "lastTime", lastTime)

job_data = {}
job_data: dict[str, str] = {}
if updateTimes[-1] >= lastTime:
new_status, new_minor, new_application = (
returnValueOrRaise( # TODO: Catch this
Expand All @@ -332,12 +332,12 @@ async def set_job_status_bulk(
MagicMock(), # FIXME
)
)
)
)
print(f"statusDict is {statusDict}")
print(f"Update state to {new_status}")

if new_status:
job_data.update(additional_attributes)
job_data.update(additional_attributes.get(job_id, {}))
job_data["Status"] = new_status
job_data["LastUpdateTime"] = str(datetime.now(timezone.utc))
if new_minor:
Expand Down Expand Up @@ -399,9 +399,6 @@ async def set_job_status_bulk(

await job_logging_db.bulk_insert_record(job_logging_updates)

print(f"{job_attribute_updates=}")
print(f"{job_logging_updates=}")

return {
"success": job_attribute_updates,
"failed": failed,
Expand Down
4 changes: 2 additions & 2 deletions diracx-routers/src/diracx/routers/jobs/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from datetime import datetime
from http import HTTPStatus
from typing import Annotated, Union
from typing import Annotated

from fastapi import BackgroundTasks, HTTPException, Query

Expand Down Expand Up @@ -82,7 +82,7 @@ async def set_job_statuses(
background_task: BackgroundTasks,
check_permissions: CheckWMSPolicyCallable,
force: bool = False,
) -> dict[str, dict[int, SetJobStatusReturn | dict[str, str] ]]:
) -> dict[str, dict[int, SetJobStatusReturn | dict[str, str]]]:
await check_permissions(
action=ActionType.MANAGE, job_db=job_db, job_ids=list(job_update)
)
Expand Down
2 changes: 1 addition & 1 deletion diracx-routers/src/diracx/routers/jobs/submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def __init__(self, user_info: AuthorizedUserInfo, allInfo: bool = True):
# FIXME why?
return result
jobDescList = result["Value"]
else:
else:
# if we are here, then jobDesc was the description of a single job.
jobDescList = job_definitions
else:
Expand Down
Loading

0 comments on commit 46b9a4d

Please sign in to comment.