From be8361aa32e01952bcf61420e52c486d132d6e2c Mon Sep 17 00:00:00 2001 From: "Akihiko (Aki) Kuroda" <16141898+akihikokuroda@users.noreply.github.com> Date: Fri, 10 May 2024 08:19:20 -0400 Subject: [PATCH] Cancel (#1318) * cancel backend runtime jobs in qiskit_serverless job.stop() Signed-off-by: Akihiko Kuroda --- client/qiskit_serverless/core/job.py | 22 ++++++++++++++++------ gateway/api/views.py | 25 +++++++++++++++++++++++++ gateway/requirements.txt | 1 + 3 files changed, 42 insertions(+), 6 deletions(-) diff --git a/client/qiskit_serverless/core/job.py b/client/qiskit_serverless/core/job.py index acb6253dd..c93d657b8 100644 --- a/client/qiskit_serverless/core/job.py +++ b/client/qiskit_serverless/core/job.py @@ -49,6 +49,7 @@ from ray.dashboard.modules.job.sdk import JobSubmissionClient from opentelemetry import trace +from qiskit_ibm_runtime import QiskitRuntimeService from qiskit_serverless.core.constants import ( OT_PROGRAM_NAME, @@ -122,7 +123,7 @@ def status(self, job_id: str): """Check status.""" raise NotImplementedError - def stop(self, job_id: str): + def stop(self, job_id: str, service: Optional[QiskitRuntimeService] = None): """Stops job/program.""" raise NotImplementedError @@ -158,7 +159,7 @@ def __init__(self, client: JobSubmissionClient): def status(self, job_id: str): return self._job_client.get_job_status(job_id).value - def stop(self, job_id: str): + def stop(self, job_id: str, service: Optional[QiskitRuntimeService] = None): return self._job_client.stop_job(job_id) def logs(self, job_id: str): @@ -230,7 +231,7 @@ def __init__(self): def status(self, job_id: str): return self._jobs[job_id]["status"] - def stop(self, job_id: str): + def stop(self, job_id: str, service: Optional[QiskitRuntimeService] = None): """Stops job/program.""" return f"job:{job_id} has already stopped" @@ -418,14 +419,23 @@ def status(self, job_id: str): return response_data.get("status", default_status) - def stop(self, job_id: str): + def stop(self, job_id: str, service: Optional[QiskitRuntimeService] = None): tracer = trace.get_tracer("client.tracer") with tracer.start_as_current_span("job.stop"): + if service: + data = { + "service": json.dumps(service, cls=QiskitObjectsEncoder), + } + else: + data = { + "service": None, + } response_data = safe_json_request( request=lambda: requests.post( f"{self.host}/api/{self.version}/jobs/{job_id}/stop/", headers={"Authorization": f"Bearer {self._token}"}, timeout=REQUESTS_TIMEOUT, + json=data, ) ) @@ -556,9 +566,9 @@ def status(self): """Returns status of the job.""" return _map_status_to_serverless(self._job_client.status(self.job_id)) - def stop(self): + def stop(self, service: Optional[QiskitRuntimeService] = None): """Stops the job from running.""" - return self._job_client.stop(self.job_id) + return self._job_client.stop(self.job_id, service=service) def logs(self) -> str: """Returns logs of the job.""" diff --git a/gateway/api/views.py b/gateway/api/views.py index b4cb49ced..097924295 100644 --- a/gateway/api/views.py +++ b/gateway/api/views.py @@ -29,6 +29,8 @@ from rest_framework.decorators import action from rest_framework.generics import get_object_or_404 from rest_framework.response import Response + +from qiskit_ibm_runtime import RuntimeInvalidStateError, QiskitRuntimeService from utils import sanitize_file_path from .models import VIEW_PROGRAM_PERMISSION, Program, Job, RuntimeJob @@ -318,6 +320,10 @@ def logs(self, request, pk=None): # pylint: disable=invalid-name,unused-argumen logs = job.logs return Response({"logs": logs}) + def get_runtime_job(self, job): + """get runtime job for job""" + return RuntimeJob.objects.filter(job=job) + @action(methods=["POST"], detail=True) def stop(self, request, pk=None): # pylint: disable=invalid-name,unused-argument """Stops job""" @@ -329,6 +335,25 @@ def stop(self, request, pk=None): # pylint: disable=invalid-name,unused-argumen job.status = Job.STOPPED job.save(update_fields=["status"]) message = "Job has been stopped." + runtime_jobs = self.get_runtime_job(job) + if runtime_jobs and len(runtime_jobs) != 0: + if request.data.get("service"): + service = QiskitRuntimeService( + **json.loads(request.data.get("service"), cls=json.JSONDecoder)[ + "__value__" + ] + ) + for runtime_job_entry in runtime_jobs: + jobinstance = service.job(runtime_job_entry.runtime_job) + if jobinstance: + try: + logger.info( + "canceling [%s]", runtime_job_entry.runtime_job + ) + jobinstance.cancel() + except RuntimeInvalidStateError: + logger.warning("cancel failed") + if job.compute_resource: if job.compute_resource.active: job_handler = get_job_handler(job.compute_resource.host) diff --git a/gateway/requirements.txt b/gateway/requirements.txt index da9310d09..b5f465f3a 100644 --- a/gateway/requirements.txt +++ b/gateway/requirements.txt @@ -19,4 +19,5 @@ drf-yasg>=1.21.7 cryptography>=41.0.1 # Django dependency, but we need a newer version (IBMQ#246) sqlparse>=0.5.0 +qiskit_ibm_runtime>=0.22.0