This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
Fix shared memory with gluon dataloader, add option pin_memory #11908
Merged

Commits (13):
7f6a821 use threading for mp dataloader fetching, allow pin_memory option (zhreshold)
80c1368 allow pin tuple of data into cpu_pinned (zhreshold)
e49bb44 fix as_in_context if not cpu_pinned (zhreshold)
2642f58 fix cpu_pinned (zhreshold)
b7481f9 fix unittest for windows, update doc that windows mp is available (zhreshold)
61ddac4 fix pin_memory (zhreshold)
14f5192 fix lint (zhreshold)
91a5102 always use simplequeue for data queue (zhreshold)
c37e3b7 remove main thread clearing for data_queue (zhreshold)
52f96eb do not use outside folder as pythonpath but run nosetests inside (zhreshold)
945b3d1 use :MXNET_LIBRARY_PATH= to locate dll (zhreshold)
d06e631 fix dll path (zhreshold)
b06638f correct dll path (zhreshold)
Diff 1: the gluon DataLoader module

@@ -16,7 +16,7 @@
 # under the License.

 # coding: utf-8
-# pylint: disable=
+# pylint: disable=ungrouped-imports
 """Dataset generator."""
 __all__ = ['DataLoader']

@@ -26,6 +26,7 @@
 import multiprocessing
 import multiprocessing.queues
 from multiprocessing.reduction import ForkingPickler
+import threading
 import numpy as np

 try:

@@ -149,6 +150,14 @@ def default_mp_batchify_fn(data):
                          ctx=context.Context('cpu_shared', 0))


+def _as_in_context(data, ctx):
+    """Move data into new context."""
+    if isinstance(data, nd.NDArray):
+        return data.as_in_context(ctx)
+    elif isinstance(data, (list, tuple)):
+        return [_as_in_context(d, ctx) for d in data]
+    return data
+
 def worker_loop(dataset, key_queue, data_queue, batchify_fn):
     """Worker loop for multiprocessing DataLoader."""
     dataset._fork()
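The `_as_in_context` helper walks nested lists and tuples recursively and moves every NDArray it finds; note that tuples come back as lists. A minimal illustration, assuming the private helper above is in scope (on CPU-only builds, `cpu_pinned` memory may behave like ordinary CPU memory):

```python
import mxnet as mx
from mxnet import nd

# A nested batch: a tuple holding an NDArray and a list of NDArrays.
batch = (nd.zeros((2, 3)), [nd.ones((2,)), nd.ones((2,))])

# Move every NDArray into pinned (page-locked) memory.
moved = _as_in_context(batch, mx.context.cpu_pinned())

print(type(moved))       # list: tuples are rebuilt as lists
print(moved[0].context)  # cpu_pinned(0)
```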
@@ -159,9 +168,21 @@ def worker_loop(dataset, key_queue, data_queue, batchify_fn):
         batch = batchify_fn([dataset[i] for i in samples])
         data_queue.put((idx, batch))

+def fetcher_loop(data_queue, data_buffer, pin_memory=False):
+    """Fetcher loop for fetching data from queue and put in reorder dict."""
+    while True:
+        idx, batch = data_queue.get()
+        if idx is None:
+            break
+        if pin_memory:
+            batch = _as_in_context(batch, context.cpu_pinned())
+        else:
+            batch = _as_in_context(batch, context.cpu())
+        data_buffer[idx] = batch
+
 class _MultiWorkerIter(object):
     """Interal multi-worker iterator for DataLoader."""
-    def __init__(self, num_workers, dataset, batchify_fn, batch_sampler):
+    def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory=False):
         assert num_workers > 0, "_MultiWorkerIter is not for {} workers".format(num_workers)
         self._num_workers = num_workers
         self._dataset = dataset
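This is the core design change of the PR: worker processes keep putting `(idx, batch)` pairs on the shared `data_queue`, but a single daemon thread in the main process now drains that queue into the reorder buffer (and optionally pins the batch there), instead of the consumer racing the queue inside `__next__` (the thread itself is started in the next hunk). A self-contained sketch of that pattern in plain Python, not MXNet code; all names are illustrative:

```python
import multiprocessing
import threading

def worker(key_queue, data_queue):
    # Producer process: fetch work items, "batchify" them, send results back.
    while True:
        idx, samples = key_queue.get()
        if idx is None:                      # sentinel: no more work
            break
        data_queue.put((idx, [s * 2 for s in samples]))

def fetcher(data_queue, data_buffer):
    # Single thread in the main process: drain results into a reorder dict.
    while True:
        idx, batch = data_queue.get()
        if idx is None:                      # sentinel: shut down
            break
        data_buffer[idx] = batch             # optional pinning would happen here

if __name__ == '__main__':
    key_queue = multiprocessing.Queue()
    data_queue = multiprocessing.SimpleQueue()
    data_buffer = {}

    workers = [multiprocessing.Process(target=worker, args=(key_queue, data_queue))
               for _ in range(2)]
    for w in workers:
        w.daemon = True
        w.start()

    fetcher_thread = threading.Thread(target=fetcher, args=(data_queue, data_buffer))
    fetcher_thread.daemon = True
    fetcher_thread.start()

    for i in range(4):                       # submit four "batches"
        key_queue.put((i, [i, i + 1]))
    for _ in workers:                        # one stop sentinel per worker
        key_queue.put((None, None))
    for w in workers:
        w.join()

    data_queue.put((None, None))             # stop the fetcher thread
    fetcher_thread.join()
    print([data_buffer[i] for i in range(4)])  # results in submission order
```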
@@ -184,6 +205,12 @@ def __init__(self, num_workers, dataset, batchify_fn, batch_sampler):
             worker.start()
             workers.append(worker)

+        self._fetcher = threading.Thread(
+            target=fetcher_loop,
+            args=(self._data_queue, self._data_buffer, pin_memory))
+        self._fetcher.daemon = True
+        self._fetcher.start()
+
         # pre-fetch
         for _ in range(2 * self._num_workers):
             self._push_next()

@@ -210,13 +237,11 @@ def __next__(self):
             raise StopIteration

         while True:
-            self._push_next()
             if self._rcvd_idx in self._data_buffer:
                 batch = self._data_buffer.pop(self._rcvd_idx)
                 self._rcvd_idx += 1
+                self._push_next()
                 return batch
-            idx, batch = self._data_queue.get()
-            self._data_buffer[idx] = batch

     def next(self):
         return self.__next__()

@@ -229,11 +254,7 @@ def shutdown(self):
         if not self._shutdown:
             for _ in range(self._num_workers):
                 self._key_queue.put((None, None))
-            try:
-                while not self._data_queue.empty():
-                    self._data_queue.get()
-            except IOError:
-                pass
+            self._data_queue.put((None, None))
             self._shutdown = True
@@ -277,12 +298,16 @@ def default_batchify_fn(data):

     num_workers : int, default 0
         The number of multiprocessing workers to use for data preprocessing.
-        `num_workers > 0` is not supported on Windows yet.
+    pin_memory : boolean, default False
+        If ``True``, the dataloader will copy NDArrays into pinned memory
+        before returning them. Copying from CPU pinned memory to GPU is faster
+        than from normal CPU memory.
     """
     def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
                  last_batch=None, batch_sampler=None, batchify_fn=None,
-                 num_workers=0):
+                 num_workers=0, pin_memory=False):
         self._dataset = dataset
+        self._pin_memory = pin_memory

         if batch_sampler is None:
            if batch_size is None:

Review thread on `pin_memory : boolean, default False`:
- Reviewer: Why default to False?
- zhreshold: It is not necessary for non-GPU instances, and using non-pageable memory for everything can hurt stability.
@@ -315,13 +340,17 @@ def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,

     def __iter__(self):
         if self._num_workers == 0:
-            generator = lambda: [(yield self._batchify_fn([self._dataset[idx] for idx in batch]))
-                                 for batch in self._batch_sampler]
-            return generator()
+            def same_process_iter():
+                for batch in self._batch_sampler:
+                    ret = self._batchify_fn([self._dataset[idx] for idx in batch])
+                    if self._pin_memory:
+                        ret = _as_in_context(ret, context.cpu_pinned())
+                    yield ret
+            return same_process_iter()

         # multi-worker
         return _MultiWorkerIter(self._num_workers, self._dataset,
-                                self._batchify_fn, self._batch_sampler)
+                                self._batchify_fn, self._batch_sampler, self._pin_memory)

     def __len__(self):
         return len(self._batch_sampler)
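For context, here is a hypothetical usage sketch of the new option; the dataset choice and device handling are illustrative and not part of the PR:

```python
import mxnet as mx
from mxnet.gluon.data import DataLoader
from mxnet.gluon.data.vision import MNIST

dataset = MNIST(train=True)

# pin_memory=True returns batches in page-locked (cpu_pinned) memory,
# which makes the subsequent host-to-GPU copy faster.
loader = DataLoader(dataset, batch_size=32, num_workers=2, pin_memory=True)

ctx = mx.gpu(0)  # assumes a GPU is available; use mx.cpu() otherwise
for data, label in loader:
    data = data.as_in_context(ctx)    # faster copy when the source is pinned
    label = label.as_in_context(ctx)
    # ... forward/backward pass would go here ...
```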
Diff 2: the gluon data unit tests
@@ -130,99 +130,89 @@ def test_multi_worker():
     for i, batch in enumerate(loader):
         assert (batch.asnumpy() == i).all()

+class _Dummy(Dataset):
+    """Dummpy dataset for randomized shape arrays."""
+    def __init__(self, random_shape):
+        self.random_shape = random_shape
+
+    def __getitem__(self, idx):
+        key = idx
+        if self.random_shape:
+            out = np.random.uniform(size=(random.randint(1000, 1100), 40))
+            labels = np.random.uniform(size=(random.randint(10, 15)))
+        else:
+            out = np.random.uniform(size=(1000, 40))
+            labels = np.random.uniform(size=(10))
+        return key, out, labels
+
+    def __len__(self):
+        return 50
+
+def _batchify_list(data):
+    """
+    return list of ndarray without stack/concat/pad
+    """
+    if isinstance(data, (tuple, list)):
+        return list(data)
+    if isinstance(data, mx.nd.NDArray):
+        return [data]
+    return data
+
+def _batchify(data):
+    """
+    Collate data into batch. Use shared memory for stacking.
+    :param data: a list of array, with layout of 'NTC'.
+    :return either x and x's unpadded lengths, or x, x's unpadded lengths, y and y's unpadded lengths
+        if labels are not supplied.
+    """
+
+    # input layout is NTC
+    keys, inputs, labels = [item[0] for item in data], [item[1] for item in data], \
+                           [item[2] for item in data]
+
+    if len(data) > 1:
+        max_data_len = max([seq.shape[0] for seq in inputs])
+        max_labels_len = 0 if not labels else max([seq.shape[0] for seq in labels])
+    else:
+        max_data_len = inputs[0].shape[0]
+        max_labels_len = 0 if not labels else labels[0].shape[0]
+
+    x_lens = [item.shape[0] for item in inputs]
+    y_lens = [item.shape[0] for item in labels]
+
+    for i, seq in enumerate(inputs):
+        pad_len = max_data_len - seq.shape[0]
+        inputs[i] = np.pad(seq, ((0, pad_len), (0, 0)), 'constant', constant_values=0)
+        labels[i] = np.pad(labels[i], (0, max_labels_len - labels[i].shape[0]),
+                           'constant', constant_values=-1)
+
+    inputs = np.asarray(inputs, dtype=np.float32)
+    if labels is not None:
+        labels = np.asarray(labels, dtype=np.float32)
+    inputs = inputs.transpose((1, 0, 2))
+    labels = labels.transpose((1, 0))
+
+    return (nd.array(inputs, dtype=inputs.dtype, ctx=context.Context('cpu_shared', 0)),
+            nd.array(x_lens, ctx=context.Context('cpu_shared', 0))) \
+        if labels is None else (
+            nd.array(inputs, dtype=inputs.dtype, ctx=context.Context('cpu_shared', 0)),
+            nd.array(x_lens, ctx=context.Context('cpu_shared', 0)),
+            nd.array(labels, dtype=labels.dtype, ctx=context.Context('cpu_shared', 0)),
+            nd.array(y_lens, ctx=context.Context('cpu_shared', 0)))
+
-@with_seed()
-def test_multi_worker_forked_data_loader():
-    """
-    Test should successfully run its course of multi-process/forked data loader without errors
-    """
-    class Dummy(Dataset):
-        def __init__(self, random_shape):
-            self.random_shape = random_shape
-
-        def __getitem__(self, idx):
-            key = idx
-            if self.random_shape:
-                out = np.random.uniform(size=(random.randint(1000, 1100), 40))
-                labels = np.random.uniform(size=(random.randint(10, 15)))
-            else:
-                out = np.random.uniform(size=(1000, 40))
-                labels = np.random.uniform(size=(10))
-            return key, out, labels
-
-        def __len__(self):
-            return 50
-
-        def batchify_list(self, data):
-            """
-            return list of ndarray without stack/concat/pad
-            """
-            if isinstance(data, (tuple, list)):
-                return list(data)
-            if isinstance(data, mx.nd.NDArray):
-                return [data]
-            return data
-
-        def batchify(self, data):
-            """
-            Collate data into batch. Use shared memory for stacking.
-
-            :param data: a list of array, with layout of 'NTC'.
-            :return either x and x's unpadded lengths, or x, x's unpadded lengths, y and y's unpadded lengths
-                if labels are not supplied.
-            """
-
-            # input layout is NTC
-            keys, inputs, labels = [item[0] for item in data], [item[1] for item in data], \
-                                   [item[2] for item in data]
-
-            if len(data) > 1:
-                max_data_len = max([seq.shape[0] for seq in inputs])
-                max_labels_len = 0 if not labels else max([seq.shape[0] for seq in labels])
-            else:
-                max_data_len = inputs[0].shape[0]
-                max_labels_len = 0 if not labels else labels[0].shape[0]
-
-            x_lens = [item.shape[0] for item in inputs]
-            y_lens = [item.shape[0] for item in labels]
-
-            for i, seq in enumerate(inputs):
-                pad_len = max_data_len - seq.shape[0]
-                inputs[i] = np.pad(seq, ((0, pad_len), (0, 0)), 'constant', constant_values=0)
-                labels[i] = np.pad(labels[i], (0, max_labels_len - labels[i].shape[0]),
-                                   'constant', constant_values=-1)
-
-            inputs = np.asarray(inputs, dtype=np.float32)
-            if labels is not None:
-                labels = np.asarray(labels, dtype=np.float32)
-            inputs = inputs.transpose((1, 0, 2))
-            labels = labels.transpose((1, 0))
-
-            return (nd.array(inputs, dtype=inputs.dtype, ctx=context.Context('cpu_shared', 0)),
-                    nd.array(x_lens, ctx=context.Context('cpu_shared', 0))) \
-                if labels is None else (
-                    nd.array(inputs, dtype=inputs.dtype, ctx=context.Context('cpu_shared', 0)),
-                    nd.array(x_lens, ctx=context.Context('cpu_shared', 0)),
-                    nd.array(labels, dtype=labels.dtype, ctx=context.Context('cpu_shared', 0)),
-                    nd.array(y_lens, ctx=context.Context('cpu_shared', 0)))
-
-    # This test is pointless on Windows because Windows doesn't fork
-    if platform.system() != 'Windows':
-        data = Dummy(True)
-        loader = DataLoader(data, batch_size=40, batchify_fn=data.batchify, num_workers=2)
-        for epoch in range(1):
-            for i, data in enumerate(loader):
-                if i % 100 == 0:
-                    print(data)
-                    print('{}:{}'.format(epoch, i))
-
-        data = Dummy(True)
-        loader = DataLoader(data, batch_size=40, batchify_fn=data.batchify_list, num_workers=2)
-        for epoch in range(1):
-            for i, data in enumerate(loader):
-                if i % 100 == 0:
-                    print(data)
-                    print('{}:{}'.format(epoch, i))
+@with_seed()
+def test_multi_worker_forked_data_loader():
+    data = _Dummy(False)
+    loader = DataLoader(data, batch_size=40, batchify_fn=_batchify, num_workers=2)
+    for epoch in range(1):
+        for i, data in enumerate(loader):
+            pass
+
+    data = _Dummy(True)
+    loader = DataLoader(data, batch_size=40, batchify_fn=_batchify_list, num_workers=2)
+    for epoch in range(1):
+        for i, data in enumerate(loader):
+            pass

 if __name__ == '__main__':
     import nose

Review thread on `"""Dummpy dataset for randomized shape arrays."""`:
- Reviewer: typo?
- zhreshold: will fix that
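A side note on this refactor: one likely reason `_Dummy` and the batchify helpers moved to module level is that when multiprocessing workers are spawned rather than forked (as on Windows), the arguments handed to worker processes must be picklable, and nested functions or methods of locally defined classes are not. A quick demonstration in plain Python, unrelated to MXNet:

```python
import pickle

def module_level_batchify(data):
    # Picklable: can be looked up by its qualified name at import time.
    return list(data)

def make_nested_batchify():
    def nested_batchify(data):   # defined inside a function: not picklable
        return list(data)
    return nested_batchify

pickle.dumps(module_level_batchify)        # works
try:
    pickle.dumps(make_nested_batchify())
except (pickle.PicklingError, AttributeError) as err:
    print('nested function cannot be pickled:', err)
```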
Review thread on the CI script changes:
- Reviewer: Can we explain this? I don't get it; aren't the packed files in the windows_package folder?
- zhreshold: I am trying to make PYTHONPATH point to the git-cloned tree instead of the stashed build, because otherwise, when the unit tests are launched in the git-cloned folder, the pickled functions (from multiprocessing) cannot find the correct DLLs.
- Reviewer: Can you instead set the MXNET_LIBRARY_PATH environment variable? That should be the correct way.
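For reference, a hypothetical sketch of that suggestion (the DLL path below is made up): MXNet's Python package consults the MXNET_LIBRARY_PATH environment variable when locating libmxnet, which the commit "use :MXNET_LIBRARY_PATH= to locate dll" in this PR relies on, so CI can point worker processes at the right binary without rewriting PYTHONPATH.

```python
import os

# Hypothetical CI setup: point every (sub)process at the stashed build's DLL.
os.environ['MXNET_LIBRARY_PATH'] = r'C:\workspace\windows_package\lib\libmxnet.dll'

import mxnet as mx   # the loader checks MXNET_LIBRARY_PATH when locating the library
print(mx.__version__)
```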