diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py index e0b6aec294a0..412d3134476b 100644 --- a/python/mxnet/gluon/data/dataloader.py +++ b/python/mxnet/gluon/data/dataloader.py @@ -183,7 +183,8 @@ def fetcher_loop(data_queue, data_buffer, pin_memory=False): class _MultiWorkerIter(object): """Interal multi-worker iterator for DataLoader.""" - def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory=False): + def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory=False, + worker_fn=worker_loop): assert num_workers > 0, "_MultiWorkerIter is not for {} workers".format(num_workers) self._num_workers = num_workers self._dataset = dataset @@ -200,7 +201,7 @@ def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory= workers = [] for _ in range(self._num_workers): worker = multiprocessing.Process( - target=worker_loop, + target=worker_fn, args=(self._dataset, self._key_queue, self._data_queue, self._batchify_fn)) worker.daemon = True worker.start()