Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated handlers to support async scheduler apis #41

Merged
merged 1 commit into from
Sep 23, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 60 additions & 31 deletions jupyter_scheduler/handlers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import json
import re
from dataclasses import asdict
Expand Down Expand Up @@ -60,16 +61,20 @@ def post(self):

class CreateJobWithDefinitionHandler(ExtensionHandlerMixin, JobHandlersMixin, APIHandler):
@tornado.web.authenticated
def post(self, job_definition_id: str):
async def post(self, job_definition_id: str):
job_definition = self.scheduler.get_job_definition(job_definition_id)
if job_definition is None:
raise tornado.web.HTTPError(
404, f"Job definition with id: {job_definition_id} not found"
)

payload = self.get_json_body()

job_id = self.scheduler.create_job(CreateJob(**job_definition.dict().merge(payload)))
if inspect.isawaitable(self.scheduler.create_job):
job_id = await self.scheduler.create_job(
CreateJob(**job_definition.dict().merge(payload))
)
else:
job_id = self.scheduler.create_job(CreateJob(**job_definition.dict().merge(payload)))
self.finish(json.dumps(dict(job_id=job_id)))


Expand All @@ -91,76 +96,97 @@ def compute_sort_model(query_argument):

class JobHandler(ExtensionHandlerMixin, JobHandlersMixin, APIHandler):
@tornado.web.authenticated
def get(self, job_id=None):
async def get(self, job_id=None):
if job_id:
job = self.scheduler.get_job(job_id)
self.finish(job.json())
else:
status = self.get_query_argument("status", None)
start_time = self.get_query_argument("start_time", None)
sort_by = compute_sort_model(self.get_query_arguments("sort_by"))
list_jobs_response = self.scheduler.list_jobs(
ListJobsQuery(
job_definition_id=self.get_query_argument("job_definition_id", None),
status=Status(status.upper()) if status else None,
name=self.get_query_argument("name", None),
tags=self.get_query_arguments("tags", None),
start_time=int(start_time) if start_time else None,
sort_by=sort_by if sort_by else [DEFAULT_SORT],
max_items=self.get_query_argument("max_items", DEFAULT_MAX_ITEMS),
next_token=self.get_query_argument("next_token", None),
)
list_jobs_query = ListJobsQuery(
job_definition_id=self.get_query_argument("job_definition_id", None),
status=Status(status.upper()) if status else None,
name=self.get_query_argument("name", None),
tags=self.get_query_arguments("tags", None),
start_time=int(start_time) if start_time else None,
sort_by=sort_by if sort_by else [DEFAULT_SORT],
max_items=self.get_query_argument("max_items", DEFAULT_MAX_ITEMS),
next_token=self.get_query_argument("next_token", None),
)
if inspect.isawaitable(self.scheduler.list_jobs):
list_jobs_response = await self.scheduler.list_jobs(list_jobs_query)
else:
list_jobs_response = self.scheduler.list_jobs(list_jobs_query)
self.finish(list_jobs_response.json(exclude_none=True))

@tornado.web.authenticated
def post(self):
async def post(self):
payload = self.get_json_body()

job_id = self.scheduler.create_job(CreateJob(**payload))
if inspect.isawaitable(self.scheduler.create_job):
job_id = await self.scheduler.create_job(CreateJob(**payload))
else:
job_id = self.scheduler.create_job(CreateJob(**payload))

self.finish(json.dumps(dict(job_id=job_id)))

@tornado.web.authenticated
def patch(self, job_id):
async def patch(self, job_id):
payload = self.get_json_body()

if "status" not in payload:
raise tornado.web.HTTPError(500, "Field 'status' missing in request body")

status = Status(payload.get("status"))
if status == Status.STOPPED:
self.scheduler.stop_job(job_id)
if inspect.isawaitable(self.scheduler.stop_job):
await self.scheduler.stop_job(job_id)
else:
self.scheduler.stop_job(job_id)
else:
self.scheduler.update_job(UpdateJob(job_id=job_id, status=str(status)))
if inspect.isawaitable(self.scheduler.update_job):
await self.scheduler.update_job(UpdateJob(job_id=job_id, status=str(status)))
else:
self.scheduler.update_job(UpdateJob(job_id=job_id, status=str(status)))
self.set_status(204)
self.finish()

@tornado.web.authenticated
def delete(self, job_id):
self.scheduler.delete_job(job_id)
async def delete(self, job_id):
if inspect.isawaitable(self.scheduler.delete_job):
await self.scheduler.delete_job(job_id)
else:
self.scheduler.delete_job(job_id)
self.set_status(204)
self.finish()


class BatchJobHandler(ExtensionHandlerMixin, JobHandlersMixin, APIHandler):
@tornado.web.authenticated
def delete(self):
async def delete(self):
job_ids = self.get_query_arguments("job_id")
for job_id in job_ids:
self.scheduler.delete_job(job_id)
if inspect.isawaitable(self.scheduler.delete_job):
for job_id in job_ids:
await self.scheduler.delete_job(job_id)
else:
for job_id in job_ids:
self.scheduler.delete_job(job_id)

self.set_status(204)
self.finish()


class JobsCountHandler(ExtensionHandlerMixin, JobHandlersMixin, APIHandler):
@tornado.web.authenticated
def get(self):
async def get(self):
status = self.get_query_argument("status", None)
count = self.scheduler.count_jobs(
CountJobsQuery(status=Status(status.upper()) if status else Status.IN_PROGRESS)
count_jobs_query = CountJobsQuery(
status=Status(status.upper()) if status else Status.IN_PROGRESS
)
if inspect.isawaitable(self.scheduler.count_jobs):
count = await self.scheduler.count_jobs(count_jobs_query)
else:
count = self.scheduler.count_jobs(count_jobs_query)
self.finish(json.dumps(dict(count=count)))


Expand All @@ -174,11 +200,14 @@ def environment_manager(self):
self._environment_manager = config.environments_manager_class()
return self._environment_manager

def get(self):
async def get(self):
"""Returns names of available runtime environments"""

try:
environments = self.environment_manager.list_environments()
if inspect.isawaitable(self.environment_manager.list_environments):
environments = await self.environment_manager.list_environments()
else:
environments = self.environment_manager.list_environments()
except EnvironmentRetrievalError as e:
raise tornado.web.HTTPError(500, str(e))

Expand Down