Skip to content

Commit

Permalink
Refactor execute_tool operation
Browse files Browse the repository at this point in the history
  • Loading branch information
heisner-tillman committed Jan 14, 2024
1 parent 0dd767c commit 3a20cfd
Showing 1 changed file with 32 additions and 23 deletions.
55 changes: 32 additions & 23 deletions lib/galaxy/webapps/galaxy/api/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
from fastapi import (
Body,
Depends,
HTTPException,
Request,
UploadFile,
)
from starlette.datastructures import UploadFile as StarletteUploadFile
from typing_extensions import Annotated

from galaxy import (
exceptions,
Expand All @@ -31,6 +33,7 @@
FetchDataFormPayload,
FetchDataPayload,
)
from galaxy.schema.tools import ExecuteToolPayload
from galaxy.tools.evaluation import global_tool_errors
from galaxy.util.zipstream import ZipstreamWrapper
from galaxy.web import (
Expand Down Expand Up @@ -59,6 +62,15 @@
# Tool search bypasses the fulltext for the following list of terms
SEARCH_RESERVED_TERMS_FAVORITES = ["#favs", "#favorites", "#favourites"]

ExecuteToolBody = Annotated[
ExecuteToolPayload,
Body(
default=...,
title="",
description="",
),
]


class FormDataApiRoute(APIContentTypeRoute):
match_content_type = "multipart/form-data"
Expand Down Expand Up @@ -103,6 +115,26 @@ async def fetch_form(
files2.append(value)
return self.service.create_fetch(trans, payload, files2)

@router.post(
"/api/tools",
summary="Execute tool with a given parameter payload",
)
def execute_tool(
self,
payload: ExecuteToolBody,
# input_format: InputFormatQueryParameter,
trans: ProvidesHistoryContext = DependsOnTrans,
):
tool_id = payload.tool_id
tool_uuid = payload.tool_uuid
if tool_id in PROTECTED_TOOLS:
raise HTTPException(
status_code=400, detail=f"Cannot execute tool [{tool_id}] directly, must use alternative endpoint."
)
if tool_id is None and tool_uuid is None:
raise HTTPException(status_code=400, detail="Must specify a valid tool_id to use this endpoint.")
return self.service.create(trans, payload)


class ToolsController(BaseGalaxyAPIController, UsesVisualizationMixin):
"""
Expand Down Expand Up @@ -569,29 +601,6 @@ def error_stack(self, trans: GalaxyWebTransaction, **kwd):
"""
return global_tool_errors.error_stack

@expose_api_anonymous
def create(self, trans: GalaxyWebTransaction, payload, **kwd):
"""
POST /api/tools
Execute tool with a given parameter payload
:param input_format: input format for the payload. Possible values are
the default 'legacy' (where inputs nested inside conditionals or
repeats are identified with e.g. '<conditional_name>|<input_name>') or
'21.01' (where inputs inside conditionals or repeats are nested
elements).
:type input_format: str
"""
tool_id = payload.get("tool_id")
tool_uuid = payload.get("tool_uuid")
if tool_id in PROTECTED_TOOLS:
raise exceptions.RequestParameterInvalidException(
f"Cannot execute tool [{tool_id}] directly, must use alternative endpoint."
)
if tool_id is None and tool_uuid is None:
raise exceptions.RequestParameterInvalidException("Must specify a valid tool_id to use this endpoint.")
return self.service._create(trans, payload, **kwd)


def _kwd_or_payload(kwd: Dict[str, Any]) -> Dict[str, Any]:
if "payload" in kwd:
Expand Down

0 comments on commit 3a20cfd

Please sign in to comment.