diff --git a/yapapi/executor/__init__.py b/yapapi/executor/__init__.py index bd9da346d..0cc302e09 100644 --- a/yapapi/executor/__init__.py +++ b/yapapi/executor/__init__.py @@ -309,9 +309,10 @@ async def _shutdown(self, *exc_info): "%s still unpaid, waiting for invoices...", pluralize(len(self._agreements_to_pay), "agreement"), ) - await asyncio.wait( - {self._process_invoices_job}, timeout=30, return_when=asyncio.ALL_COMPLETED - ) + try: + await asyncio.wait_for(self._process_invoices_job, timeout=30) + except asyncio.TimeoutError: + logger.debug("process_invoices_job cancelled") if self._agreements_to_pay: logger.warning("Unpaid agreements: %s", self._agreements_to_pay) diff --git a/yapapi/executor/services.py b/yapapi/executor/services.py index b4e1759cc..be4350454 100644 --- a/yapapi/executor/services.py +++ b/yapapi/executor/services.py @@ -203,13 +203,14 @@ def __init__( self._payload = payload self._num_instances = num_instances self._expiration = expiration or datetime.now(timezone.utc) + DEFAULT_SERVICE_EXPIRATION + self._task_ids = itertools.count(1) + self._stack = AsyncExitStack() self.__instances: List[ServiceInstance] = [] """List of Service instances""" - self._task_ids = itertools.count(1) - - self._stack = AsyncExitStack() + self._instance_tasks: Set[asyncio.Task] = set() + """Set of asyncio tasks that run spawn_service()""" def __repr__(self): return f"Cluster {self.id}: {self._num_instances}x[Service: {self._service_class.__name__}, Payload: {self._payload}]" @@ -237,6 +238,17 @@ async def agreements_pool_cycler(): async def __aexit__(self, exc_type, exc_val, exc_tb): logger.debug("%s is shutting down...", self) + # Give the instance tasks some time to terminate gracefully. + # Then cancel them without mercy! + if self._instance_tasks: + logger.debug("Waiting for service instances to terminate...") + _, still_running = await asyncio.wait(self._instance_tasks, timeout=10) + if still_running: + for task in still_running: + logger.debug("Cancelling task: %s", task) + task.cancel() + await asyncio.gather(*still_running, return_exceptions=True) + # TODO: should be different if we stop due to an error termination_reason = { "message": "Successfully finished all work", @@ -446,7 +458,8 @@ def spawn_instances(self, num_instances: Optional[int] = None) -> None: loop = asyncio.get_event_loop() for i in range(num_instances): - loop.create_task(self.spawn_instance()) + task = loop.create_task(self.spawn_instance()) + self._instance_tasks.add(task) def stop(self): """Signal the whole cluster to stop."""