Implement canceling fetch_data task
Two key things: move stream_url_to_file into a subprocess, so we're not
blocking when reading from a socket, and cancel the task directly via
job.state. I couldn't get AbortableTask and revoke to work, and it
seems like that only works via the database backend anyway.
mvdbeek committed Apr 27, 2022
1 parent a723b35 commit 6a7a385
Showing 5 changed files with 38 additions and 10 deletions.
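
The heart of the cancellation half is a race: the task body and an abort watcher run concurrently, and whichever finishes first cancels the other. A minimal, self-contained sketch of that pattern — the names run_cancellable and watch_for_abort are hypothetical, and an in-memory flag stands in for the job-state query the real code uses:

    import asyncio

    async def watch_for_abort(is_aborted, poll_interval=1.0):
        # Poll an abort flag until it flips.
        while not is_aborted():
            await asyncio.sleep(poll_interval)

    async def run_cancellable(work_coro, is_aborted):
        # Race the work against the abort watcher; cancel whichever loses.
        work = asyncio.ensure_future(work_coro)
        abort = asyncio.ensure_future(watch_for_abort(is_aborted))
        done, pending = await asyncio.wait({work, abort}, return_when=asyncio.FIRST_COMPLETED)
        for fut in pending:
            fut.cancel()
        return work.result() if work in done else None

    flag = {"aborted": False}
    print(asyncio.run(run_cancellable(asyncio.sleep(0.1, result="done"), lambda: flag["aborted"])))

Note that cancelling the pending future does not interrupt a blocking call already running in a worker thread; that is presumably why the socket read itself was moved into a subprocess (the sniff.py change below).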
33 changes: 25 additions & 8 deletions lib/galaxy/celery/__init__.py
@@ -21,9 +21,11 @@
 )
 from celery.contrib.abortable import AbortableTask
 from kombu import serialization
+from galaxy import model
 
 from galaxy.config import Configuration
 from galaxy.main_config import find_config
+from galaxy.model.scoped_session import galaxy_scoped_session
 from galaxy.util import ExecutionTimer
 from galaxy.util.custom_logging import get_logger
 from galaxy.util.properties import load_app_properties
@@ -156,19 +158,36 @@ def get_cleanup_short_term_storage_interval():
 celery_app.conf.timezone = "UTC"
 
 
-async def cancellable_task(f: Callable, task, request_id, *args, **kwargs):
-    task.request.id = request_id
+async def cancellable_task(f: Callable, *args, **kwargs):
     coro = sync_to_async(f)
-    done, pending = await asyncio.wait([coro(*args, **kwargs), is_aborted(task)], return_when=asyncio.FIRST_COMPLETED)
+    done, pending = await asyncio.wait(
+        [coro(*args, **kwargs), is_aborted(kwargs["job_id"])], return_when=asyncio.FIRST_COMPLETED
+    )
     for pending_future in pending:
         pending_future.cancel()
     for done_future in done:
         return done_future.result()
 
 
-async def is_aborted(task: AbortableTask):
-    while not task.is_aborted():
+async def is_aborted(job_id):
+    # This is not ideal:
+    # 1. we're calling a method that does I/O (a SQLAlchemy query)
+    # 2. we're polling every second
+    # Maybe we should listen for broadcast messages?
+    # Alternatively we could construct an AsyncSession and decrease the polling frequency over time.
+    # Other notes: AbortableTask could be revoked, but that only works with the database backend
+    # ... which shouldn't be used.
+    app = get_galaxy_app()
+    session = app[galaxy_scoped_session]
+
+    def get_state():
+        return session.query(model.Job.state).filter_by(id=job_id).one()[0]
+
+    state = get_state()
+    while state not in {model.Job.states.DELETED, model.Job.states.DELETING}:
         await asyncio.sleep(1)
+        state = get_state()
+    log.debug(f"Job {job_id} aborted")
 
 
 def galaxy_task(*args, action=None, **celery_task_kwd):
@@ -193,9 +212,7 @@ def wrapper(*args, **kwds):
            try:
                partial_task_function = app.magic_partial(func)
                if args and isinstance(args[0], AbortableTask):
-                    task = args[0]
-                    request_id = task.request.id
-                    rval = async_to_sync(cancellable_task)(partial_task_function, args[0], request_id, *args, **kwds)
+                    rval = async_to_sync(cancellable_task)(partial_task_function, *args, **kwds)
                else:
                    rval = partial_task_function(*args, **kwds)
                message = f"Successfully executed Celery task {desc} {timer}"
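
The comments in is_aborted above float an alternative to fixed one-second polling: back the interval off over time. A sketch of that idea, assuming a get_state callable in place of the SQLAlchemy query and the same set of terminal states:

    import asyncio

    async def wait_until_aborted(get_state, abort_states, initial=0.5, maximum=10.0):
        # Exponential backoff: poll quickly at first, then less and less often,
        # capped so a long-running job is still noticed within `maximum` seconds.
        delay = initial
        while get_state() not in abort_states:
            await asyncio.sleep(delay)
            delay = min(delay * 2, maximum)

This keeps the early-abort case responsive while cutting database traffic for jobs that run for minutes.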
1 change: 1 addition & 0 deletions lib/galaxy/celery/tasks.py
@@ -133,6 +133,7 @@ def _fetch_data(setup_return, datatypes_registry: DatatypesRegistry):
 def fetch_data(
     self,
     setup_return,
+    job_id,
     datatypes_registry: DatatypesRegistry,
 ):
     return _fetch_data(setup_return=setup_return, datatypes_registry=datatypes_registry)
8 changes: 8 additions & 0 deletions lib/galaxy/datatypes/sniff.py
@@ -64,6 +64,14 @@ def sniff_with_cls(cls, fname):
 
 
 def stream_url_to_file(path: str, file_sources: Optional[ConfiguredFileSources] = None):
+    from concurrent.futures import ProcessPoolExecutor
+
+    pool = ProcessPoolExecutor()
+    fut = pool.submit(_stream_url_to_file, path, file_sources)
+    return fut.result()
+
+
+def _stream_url_to_file(path: str, file_sources: Optional[ConfiguredFileSources] = None):
     prefix = "url_paste"
     if file_sources and file_sources.looks_like_uri(path):
         file_source_path = file_sources.get_file_source_path(path)
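
stream_url_to_file now just proxies to a child process. The same submit-and-wait pattern in isolation, with slow_read as a hypothetical stand-in for _stream_url_to_file:

    import time
    from concurrent.futures import ProcessPoolExecutor

    def slow_read(n: int) -> int:
        # Stand-in for a blocking download; runs in the child process.
        time.sleep(0.1)
        return n * 2

    if __name__ == "__main__":
        # The with-block also shuts the pool down once the result is in.
        with ProcessPoolExecutor(max_workers=1) as pool:
            future = pool.submit(slow_read, 21)
            print(future.result())  # -> 42

A child process, unlike a thread, can be killed while blocked in a socket read, so the worker is never permanently wedged on a dead download.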
2 changes: 1 addition & 1 deletion lib/galaxy/tools/execute.py
@@ -180,7 +180,7 @@ def execute_single_job(execution_slice, completed_job):
        raw_tool_source = tool.tool_source.to_string()
        async_result = (
            setup_fetch_data.s(job_id, raw_tool_source=raw_tool_source)
-            | fetch_data.s()
+            | fetch_data.s(job_id=job_id)
            | set_job_metadata.s(extended_metadata_collection="extended" in tool.app.config.metadata_strategy)
            | finish_job.si(job_id=job_id, raw_tool_source=raw_tool_source)
        )()
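
In the chain above, the distinction between .s() and .si() is what carries job_id through: a partial signature (.s) is called with the upstream task's return value prepended as its first positional argument, while an immutable signature (.si) ignores the upstream result entirely. A minimal sketch — not the Galaxy tasks — run eagerly so no broker is needed:

    from celery import Celery, chain

    app = Celery("sketch")

    @app.task
    def setup(job_id: int):
        return {"job_id": job_id}  # becomes the next task's first argument

    @app.task
    def fetch(setup_return: dict, job_id: int):
        # setup_return arrives implicitly from the chain; job_id from the kwarg.
        return f"fetching for job {job_id}: {setup_return}"

    sig = chain(setup.s(42), fetch.s(job_id=42))  # same shape as setup_fetch_data | fetch_data
    print(sig.apply().get())  # eager, in-process execution

This is why fetch_data only needed the extra job_id parameter: setup_return keeps flowing in from setup_fetch_data exactly as before.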
4 changes: 3 additions & 1 deletion lib/galaxy_test/api/test_tools_upload.py
@@ -354,6 +354,8 @@ def test_fetch_html_from_url(self, history_id):
 
     @uses_test_history(require_new=False)
     def test_abort_fetch_job(self, history_id):
+        # This should probably be an integration test that also verifies
+        # that the celery chord is properly canceled.
         item = {
             "src": "url",
             "url": "https://httpstat.us/200?sleep=10000",
@@ -382,7 +384,7 @@
             history_id, dataset_id=response["outputs"][0]["id"], assert_ok=False
         )
         assert dataset["file_size"] == 0
-        assert dataset["state"] == "deleted"
+        assert dataset["state"] == "discarded"
 
     @skip_without_datatype("velvet")
     def test_composite_datatype(self):
