From 8f4241d8fde3cdd3a37f1b3efb5e8e4932115257 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Fri, 11 Oct 2024 13:04:31 +0300 Subject: [PATCH] Implement a TypeVar for ParallelRun subclasses --- hai/parallel.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/hai/parallel.py b/hai/parallel.py index 4a2f30a..67e5204 100644 --- a/hai/parallel.py +++ b/hai/parallel.py @@ -34,6 +34,7 @@ def failed_task_names(self) -> Set[str]: RT = TypeVar("RT") +ParallelRunType = TypeVar("ParallelRunType", bound="BaseParallelRun") class BaseParallelRun: @@ -86,7 +87,7 @@ def wait( self, fail_fast: bool = True, interval: float = 0.5, - callback: Optional[Callable[["BaseParallelRun"], None]] = None, + callback: Optional[Callable[[ParallelRunType], None]] = None, max_wait: Optional[float] = None, ) -> List["ApplyResult[Any]"]: """ @@ -128,7 +129,7 @@ def wait( had_any_incomplete_task = self._wait_tick(fail_fast) if callback: - callback(self) + callback(self) # type: ignore[arg-type] # If there were no incomplete tasks left last iteration, quit. if not had_any_incomplete_task: