Skip to content

Commit

Permalink
Updated handlers to support async scheduler apis
Browse files Browse the repository at this point in the history
  • Loading branch information
3coins committed Sep 23, 2022
1 parent 92fcadd commit 8d2b0e9
Showing 1 changed file with 60 additions and 31 deletions.
91 changes: 60 additions & 31 deletions jupyter_scheduler/handlers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import json
import re
from dataclasses import asdict
Expand Down Expand Up @@ -60,16 +61,20 @@ def post(self):

class CreateJobWithDefinitionHandler(ExtensionHandlerMixin, JobHandlersMixin, APIHandler):
@tornado.web.authenticated
def post(self, job_definition_id: str):
async def post(self, job_definition_id: str):
job_definition = self.scheduler.get_job_definition(job_definition_id)
if job_definition is None:
raise tornado.web.HTTPError(
404, f"Job definition with id: {job_definition_id} not found"
)

payload = self.get_json_body()

job_id = self.scheduler.create_job(CreateJob(**job_definition.dict().merge(payload)))
if inspect.isawaitable(self.scheduler.create_job):
job_id = await self.scheduler.create_job(
CreateJob(**job_definition.dict().merge(payload))
)
else:
job_id = self.scheduler.create_job(CreateJob(**job_definition.dict().merge(payload)))
self.finish(json.dumps(dict(job_id=job_id)))


Expand All @@ -91,76 +96,97 @@ def compute_sort_model(query_argument):

class JobHandler(ExtensionHandlerMixin, JobHandlersMixin, APIHandler):
@tornado.web.authenticated
def get(self, job_id=None):
async def get(self, job_id=None):
if job_id:
job = self.scheduler.get_job(job_id)
self.finish(job.json())
else:
status = self.get_query_argument("status", None)
start_time = self.get_query_argument("start_time", None)
sort_by = compute_sort_model(self.get_query_arguments("sort_by"))
list_jobs_response = self.scheduler.list_jobs(
ListJobsQuery(
job_definition_id=self.get_query_argument("job_definition_id", None),
status=Status(status.upper()) if status else None,
name=self.get_query_argument("name", None),
tags=self.get_query_arguments("tags", None),
start_time=int(start_time) if start_time else None,
sort_by=sort_by if sort_by else [DEFAULT_SORT],
max_items=self.get_query_argument("max_items", DEFAULT_MAX_ITEMS),
next_token=self.get_query_argument("next_token", None),
)
list_jobs_query = ListJobsQuery(
job_definition_id=self.get_query_argument("job_definition_id", None),
status=Status(status.upper()) if status else None,
name=self.get_query_argument("name", None),
tags=self.get_query_arguments("tags", None),
start_time=int(start_time) if start_time else None,
sort_by=sort_by if sort_by else [DEFAULT_SORT],
max_items=self.get_query_argument("max_items", DEFAULT_MAX_ITEMS),
next_token=self.get_query_argument("next_token", None),
)
if inspect.isawaitable(self.scheduler.list_jobs):
list_jobs_response = await self.scheduler.list_jobs(list_jobs_query)
else:
list_jobs_response = self.scheduler.list_jobs(list_jobs_query)
self.finish(list_jobs_response.json(exclude_none=True))

@tornado.web.authenticated
def post(self):
async def post(self):
payload = self.get_json_body()

job_id = self.scheduler.create_job(CreateJob(**payload))
if inspect.isawaitable(self.scheduler.create_job):
job_id = await self.scheduler.create_job(CreateJob(**payload))
else:
job_id = self.scheduler.create_job(CreateJob(**payload))

self.finish(json.dumps(dict(job_id=job_id)))

@tornado.web.authenticated
def patch(self, job_id):
async def patch(self, job_id):
payload = self.get_json_body()

if "status" not in payload:
raise tornado.web.HTTPError(500, "Field 'status' missing in request body")

status = Status(payload.get("status"))
if status == Status.STOPPED:
self.scheduler.stop_job(job_id)
if inspect.isawaitable(self.scheduler.stop_job):
await self.scheduler.stop_job(job_id)
else:
self.scheduler.stop_job(job_id)
else:
self.scheduler.update_job(UpdateJob(job_id=job_id, status=str(status)))
if inspect.isawaitable(self.scheduler.update_job):
await self.scheduler.update_job(UpdateJob(job_id=job_id, status=str(status)))
else:
self.scheduler.update_job(UpdateJob(job_id=job_id, status=str(status)))
self.set_status(204)
self.finish()

@tornado.web.authenticated
def delete(self, job_id):
self.scheduler.delete_job(job_id)
async def delete(self, job_id):
if inspect.isawaitable(self.scheduler.delete_job):
await self.scheduler.delete_job(job_id)
else:
self.scheduler.delete_job(job_id)
self.set_status(204)
self.finish()


class BatchJobHandler(ExtensionHandlerMixin, JobHandlersMixin, APIHandler):
@tornado.web.authenticated
def delete(self):
async def delete(self):
job_ids = self.get_query_arguments("job_id")
for job_id in job_ids:
self.scheduler.delete_job(job_id)
if inspect.isawaitable(self.scheduler.delete_job):
for job_id in job_ids:
await self.scheduler.delete_job(job_id)
else:
for job_id in job_ids:
self.scheduler.delete_job(job_id)

self.set_status(204)
self.finish()


class JobsCountHandler(ExtensionHandlerMixin, JobHandlersMixin, APIHandler):
@tornado.web.authenticated
def get(self):
async def get(self):
status = self.get_query_argument("status", None)
count = self.scheduler.count_jobs(
CountJobsQuery(status=Status(status.upper()) if status else Status.IN_PROGRESS)
count_jobs_query = CountJobsQuery(
status=Status(status.upper()) if status else Status.IN_PROGRESS
)
if inspect.isawaitable(self.scheduler.count_jobs):
count = await self.scheduler.count_jobs(count_jobs_query)
else:
count = self.scheduler.count_jobs(count_jobs_query)
self.finish(json.dumps(dict(count=count)))


Expand All @@ -174,11 +200,14 @@ def environment_manager(self):
self._environment_manager = config.environments_manager_class()
return self._environment_manager

def get(self):
async def get(self):
"""Returns names of available runtime environments"""

try:
environments = self.environment_manager.list_environments()
if inspect.isawaitable(self.environment_manager.list_environments):
environments = await self.environment_manager.list_environments()
else:
environments = self.environment_manager.list_environments()
except EnvironmentRetrievalError as e:
raise tornado.web.HTTPError(500, str(e))

Expand Down

0 comments on commit 8d2b0e9

Please sign in to comment.