Skip to content

Commit

Permalink
add runtime job cancel in serverless job.stop()
Browse files Browse the repository at this point in the history
  • Loading branch information
akihikokuroda committed May 7, 2024
1 parent dbde93e commit fea165f
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 13 deletions.
22 changes: 16 additions & 6 deletions client/qiskit_serverless/core/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from ray.dashboard.modules.job.sdk import JobSubmissionClient

from opentelemetry import trace
from qiskit_ibm_runtime import QiskitRuntimeService

from qiskit_serverless.core.constants import (
OT_PROGRAM_NAME,
Expand Down Expand Up @@ -130,7 +131,7 @@ def status(self, job_id: str):
"""Check status."""
raise NotImplementedError

def stop(self, job_id: str):
def stop(self, job_id: str, service: Optional[QiskitRuntimeService] = None):
"""Stops job/program."""
raise NotImplementedError

Expand Down Expand Up @@ -166,7 +167,7 @@ def __init__(self, client: JobSubmissionClient):
def status(self, job_id: str):
return self._job_client.get_job_status(job_id).value

def stop(self, job_id: str):
def stop(self, job_id: str, service: Optional[QiskitRuntimeService] = None):
return self._job_client.stop_job(job_id)

def logs(self, job_id: str):
Expand Down Expand Up @@ -239,7 +240,7 @@ def __init__(self):
def status(self, job_id: str):
return self._jobs[job_id]["status"]

def stop(self, job_id: str):
def stop(self, job_id: str, service: Optional[QiskitRuntimeService] = None):
"""Stops job/program."""
return f"job:{job_id} has already stopped"

Expand Down Expand Up @@ -528,14 +529,23 @@ def status(self, job_id: str):

return response_data.get("status", default_status)

def stop(self, job_id: str):
def stop(self, job_id: str, service: Optional[QiskitRuntimeService] = None):
tracer = trace.get_tracer("client.tracer")
with tracer.start_as_current_span("job.stop"):
if service:
data = {
"service": json.dumps(service, cls=QiskitObjectsEncoder),
}
else:
data = {
"service": None,
}
response_data = safe_json_request(
request=lambda: requests.post(
f"{self.host}/api/{self.version}/jobs/{job_id}/stop/",
headers={"Authorization": f"Bearer {self._token}"},
timeout=REQUESTS_TIMEOUT,
json=data,
)
)

Expand Down Expand Up @@ -668,9 +678,9 @@ def status(self):
"""Returns status of the job."""
return _map_status_to_serverless(self._job_client.status(self.job_id))

def stop(self):
def stop(self, service: Optional[QiskitRuntimeService] = None):
"""Stops the job from running."""
return self._job_client.stop(self.job_id)
return self._job_client.stop(self.job_id, service=service)

def logs(self) -> str:
"""Returns logs of the job."""
Expand Down
30 changes: 24 additions & 6 deletions gateway/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@
from rest_framework.decorators import action
from rest_framework.generics import get_object_or_404
from rest_framework.response import Response

from quantum_serverless.serializers.program_serializers import (
QiskitObjectsDecoder,
)
from qiskit_ibm_runtime import RuntimeInvalidStateError
from utils import sanitize_file_path

from .models import VIEW_PROGRAM_PERMISSION, Program, Job, RuntimeJob
Expand Down Expand Up @@ -415,9 +420,10 @@ def logs(self, request, pk=None): # pylint: disable=invalid-name,unused-argumen
logs = job.logs
return Response({"logs": logs})

def get_runtime_job(self, job):
def get_runtime_job(self, job):
"""get runtime job for job"""
return RuntimeJob.objects.filter(job=job)

@action(methods=["POST"], detail=True)
def stop(self, request, pk=None): # pylint: disable=invalid-name,unused-argument
"""Stops job"""
Expand All @@ -430,10 +436,22 @@ def stop(self, request, pk=None): # pylint: disable=invalid-name,unused-argumen
job.save(update_fields=["status"])
message = "Job has been stopped."
runtime_jobs = self.get_runtime_job(job)
for runtime_job_entry in runtime_jobs:
print(runtime_job_entry.runtime_job)


if runtime_jobs and len(runtime_jobs) != 0:
if request.data.get("service"):
service = json.loads(
request.data.get("service"), cls=QiskitObjectsDecoder
)
for runtime_job_entry in runtime_jobs:
jobinstance = service.job(runtime_job_entry.runtime_job)
if jobinstance:
try:
logger.info(
"canceling [%s]", runtime_job_entry.runtime_job
)
jobinstance.cancel()
except RuntimeInvalidStateError:
logger.warning("cancel failed")

if job.compute_resource:
if job.compute_resource.active:
job_handler = get_job_handler(job.compute_resource.host)
Expand Down
3 changes: 2 additions & 1 deletion gateway/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ drf-yasg>=1.21.7
cryptography>=41.0.1
# Django dependency, but we need a newer version (IBMQ#246)
sqlparse>=0.5.0

qiskit_ibm_runtime
quantum_serverless

0 comments on commit fea165f

Please sign in to comment.