From 378449be85caa1947f8c694fdf2c10c0cf5785f6 Mon Sep 17 00:00:00 2001 From: azawlocki Date: Tue, 1 Jun 2021 15:14:08 +0200 Subject: [PATCH] Wait for instance tasks when shutting down Cluster --- yapapi/executor/services.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/yapapi/executor/services.py b/yapapi/executor/services.py index df56b88b2..9e3b27e3c 100644 --- a/yapapi/executor/services.py +++ b/yapapi/executor/services.py @@ -203,13 +203,14 @@ def __init__( self._payload = payload self._num_instances = num_instances self._expiration = expiration or datetime.now(timezone.utc) + DEFAULT_SERVICE_EXPIRATION + self._task_ids = itertools.count(1) + self._stack = AsyncExitStack() self.__instances: List[ServiceInstance] = [] """List of Service instances""" - self._task_ids = itertools.count(1) - - self._stack = AsyncExitStack() + self._instance_tasks: Set[asyncio.Task] = set() + """Set of asyncio tasks that run spawn_instance()""" def __repr__(self): return f"Cluster {self.id}: {self._num_instances}x[Service: {self._service_class.__name__}, Payload: {self._payload}]" @@ -237,6 +238,17 @@ async def agreements_pool_cycler(): async def __aexit__(self, exc_type, exc_val, exc_tb): logger.debug("%s is shutting down...", self) + # Give the instance tasks some time to terminate gracefully. + # Then cancel them without mercy! 
+ if self._instance_tasks: + logger.debug("Waiting for service instances to terminate...") + _, still_running = await asyncio.wait(self._instance_tasks, timeout=10) + if still_running: + for task in still_running: + logger.debug("Cancelling task: %s", task) + task.cancel() + await asyncio.gather(*still_running, return_exceptions=True) + # TODO: should be different if we stop due to an error termination_reason = { "message": "Successfully finished all work", @@ -443,7 +455,8 @@ def spawn_instances(self, num_instances: Optional[int] = None) -> None: loop = asyncio.get_event_loop() for i in range(num_instances): - loop.create_task(self.spawn_instance()) + task = loop.create_task(self.spawn_instance()) + self._instance_tasks.add(task) def stop(self): """Signal the whole cluster to stop."""