Skip to content

Commit

Permalink
Implement a TypeVar for ParallelRun subclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
akx committed Oct 11, 2024
1 parent b933089 commit 8f4241d
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions hai/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def failed_task_names(self) -> Set[str]:


RT = TypeVar("RT")
ParallelRunType = TypeVar("ParallelRunType", bound="BaseParallelRun")


class BaseParallelRun:
Expand Down Expand Up @@ -86,7 +87,7 @@ def wait(
self,
fail_fast: bool = True,
interval: float = 0.5,
callback: Optional[Callable[["BaseParallelRun"], None]] = None,
callback: Optional[Callable[[ParallelRunType], None]] = None,
max_wait: Optional[float] = None,
) -> List["ApplyResult[Any]"]:
"""
Expand Down Expand Up @@ -128,7 +129,7 @@ def wait(
had_any_incomplete_task = self._wait_tick(fail_fast)

if callback:
callback(self)
callback(self) # type: ignore[arg-type]

# If there were no incomplete tasks left last iteration, quit.
if not had_any_incomplete_task:
Expand Down

0 comments on commit 8f4241d

Please sign in to comment.