diff --git a/torchdata/datapipes/iter/transform/callable.py b/torchdata/datapipes/iter/transform/callable.py index 3e2d1b85c..b7493c2f4 100644 --- a/torchdata/datapipes/iter/transform/callable.py +++ b/torchdata/datapipes/iter/transform/callable.py @@ -619,7 +619,7 @@ class ThreadPoolMapperIterDataPipe(IterDataPipe[T_co]): - Key is used for dict. New key is acceptable. scheduled_tasks: How many tasks will be scheduled at any given time (Default value: 64) - max_workers: Maximum number of threads to execute function calls. (Default value: None) + max_workers: Maximum number of threads to execute function calls **threadpool_kwargs: additional arguments to be given to the ``ThreadPoolExecutor`` Note: @@ -628,7 +628,8 @@ class ThreadPoolMapperIterDataPipe(IterDataPipe[T_co]): Note: For optimal use of all threads, we recommend ``scheduled_tasks`` > ``max_workers``. High value of ``scheduled_tasks`` - might lead to long waiting period until the first element is yielded as tasks are executed out of order. + might lead to long waiting period until the first element is yielded as ``next`` is called + ``scheduled_tasks`` many times on ``source_datapipe`` before yielding. Example: @@ -683,7 +684,7 @@ def mul_ten(x): def __init__( self, - datapipe: IterDataPipe, + source_datapipe: IterDataPipe, fn: Callable, input_col=None, output_col=None, @@ -692,7 +693,7 @@ def __init__( **threadpool_kwargs, ) -> None: super().__init__() - self.datapipe = datapipe + self.datapipe = source_datapipe _check_unpickable_fn(fn) self.fn = fn # type: ignore[assignment]