diff --git a/yapapi/engine.py b/yapapi/engine.py index 9c7b8dda7..c32608e0d 100644 --- a/yapapi/engine.py +++ b/yapapi/engine.py @@ -176,6 +176,8 @@ def __init__( self._jobs: Set[Job] = set() self._process_invoices_job: Optional[asyncio.Task] = None + # a set of async generators created by executors that use this engine + self._generators: Set[AsyncGenerator] = set() self._services: Set[asyncio.Task] = set() self._stack = AsyncExitStack() @@ -267,7 +269,13 @@ async def _shutdown(self, *exc_info): logger.info("Golem is shutting down...") + # Some generators created by `execute_tasks` may still have elements; + # if we don't close them now, their jobs will never be marked as finished. + for gen in self._generators: + await gen.aclose() + # Wait until all computations are finished + logger.debug("Waiting for the jobs to finish...") await asyncio.gather(*[job.finished.wait() for job in self._jobs]) logger.info("All jobs have finished") @@ -306,8 +314,8 @@ async def _shutdown(self, *exc_info): except Exception: logger.debug("Got error when waiting for services to finish", exc_info=True) - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self._stack.aclose() + async def __aexit__(self, *exc_info) -> Optional[bool]: + return await self._stack.__aexit__(*exc_info) async def _create_allocations(self) -> rest.payment.MarketDecoration: @@ -458,6 +466,10 @@ def finalize_job(job: "Job"): """Mark a job as finished.""" job.finished.set() + def register_generator(self, generator: AsyncGenerator) -> None: + """Register a generator with this engine.""" + self._generators.add(generator) + @dataclass class PaymentDecorator(DemandDecorator): """A `DemandDecorator` that adds payment-related constraints and properties to a Demand.""" diff --git a/yapapi/executor/__init__.py b/yapapi/executor/__init__.py index 3d33f9f39..5c8f474ea 100644 --- a/yapapi/executor/__init__.py +++ b/yapapi/executor/__init__.py @@ -215,7 +215,7 @@ def emit(self, event: events.Event) -> None: """Emit a computation event using this `Executor`'s engine.""" self._engine.emit(event) - async def submit( + def submit( self, worker: Callable[ [WorkContext, AsyncIterator[Task[D, R]]], @@ -231,6 +231,19 @@ async def submit( on providers :return: yields computation progress events """ + generator = self._create_task_generator(worker, data) + self._engine.register_generator(generator) + return generator + + async def _create_task_generator( + self, + worker: Callable[ + [WorkContext, AsyncIterator[Task[D, R]]], + AsyncGenerator[WorkItem, Awaitable[List[events.CommandEvent]]], + ], + data: Union[AsyncIterator[Task[D, R]], Iterable[Task[D, R]]], + ) -> AsyncGenerator[Task[D, R], None]: + """Create an async generator yielding completed tasks.""" job = Job(self._engine, expiration_time=self._expires, payload=self._payload) self._engine.add_job(job)