Revert AsyncDataReader to avoid using shared memory. #791

Merged · 1 commit · Mar 29, 2018
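In effect, the patch drops the shared-memory transport (SharedNDArray buffers drawn from a SharedMemoryPoolManager and recycled by the consumer) and returns to passing ordinary pickled arrays through Manager queues, with daemon threads in the parent process driving the feeding and assembling stages and daemon processes doing the sample preprocessing. Below is a minimal sketch of that restored pattern; the names (feed, worker) and queue sizes are illustrative only, not the module's API:

    from multiprocessing import Manager, Process
    from threading import Thread


    class EpochEndSignal(object):
        # empty sentinel marking the end of an epoch
        pass


    def worker(in_queue, out_queue):
        # stand-in for ordered_processing_task: transform items until the
        # sentinel arrives, then forward it so the consumer stops too
        while True:
            item = in_queue.get()
            if isinstance(item, EpochEndSignal):
                out_queue.put(item)
                break
            out_queue.put(item * 2)


    if __name__ == '__main__':
        manager = Manager()
        in_queue, out_queue = manager.Queue(8), manager.Queue(8)

        def feed():
            for i in range(16):
                in_queue.put(i)
            in_queue.put(EpochEndSignal())

        feeder = Thread(target=feed)
        feeder.daemon = True  # daemon threads never block interpreter exit
        feeder.start()

        proc = Process(target=worker, args=(in_queue, out_queue))
        proc.daemon = True  # killed automatically if the parent dies
        proc.start()

        while True:
            item = out_queue.get()
            if isinstance(item, EpochEndSignal):
                break
            # a real consumer would assemble batches here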
176 changes: 84 additions & 92 deletions fluid/DeepASR/data_utils/async_data_reader.py
@@ -15,13 +15,12 @@
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.augmentor.trans_add_delta as trans_add_delta
from data_utils.util import suppress_complaints, suppress_signal
from data_utils.util import SharedNDArray, SharedMemoryPoolManager
from data_utils.util import DaemonProcessGroup, batch_to_ndarray
from data_utils.util import CriticalException, ForceExitWrapper, EpochEndSignal
from data_utils.util import CriticalException, ForceExitWrapper


class SampleInfo(object):
"""SampleInfo holds the necessary information to load a sample from disk.

Args:
feature_bin_path (str): File containing the feature data.
feature_start (int): Start position of the sample's feature data.
@@ -54,6 +53,7 @@ class SampleInfoBucket(object):
data, sample start position, sample byte number, etc.) needed to access samples'
feature data, and likewise for the label description file. SampleInfoBucket
is the minimum unit of shuffling.

Args:
feature_bin_paths (list|tuple): Files containing the binary feature
data.
@@ -67,8 +67,8 @@
split_sentence_threshold(int): Sentences whose length exceeds
this value will be split.
split_sub_sentence_len(int): sub-sentence length is equal to
(split_sub_sentence_len + \
rand() % split_perturb).
(split_sub_sentence_len
+ rand() % split_perturb).
"""

def __init__(self,
@@ -160,9 +160,14 @@ def generate_sample_info_list(self):
return sample_info_list


class EpochEndSignal():
pass


class AsyncDataReader(object):
"""DataReader provides basic audio sample preprocessing pipeline including
data loading and data augmentation.

Args:
feature_file_list (str): File containing paths of feature data file and
corresponding description file.
@@ -206,17 +211,12 @@ def __init__(self,
self.generate_bucket_list(True)
self._order_id = 0
self._manager = Manager()
self._sample_buffer_size = sample_buffer_size
self._sample_info_buffer_size = sample_info_buffer_size
self._batch_buffer_size = batch_buffer_size
self._proc_num = proc_num
if self._proc_num <= 2:
raise ValueError("Value of `proc_num` should be greater than 2.")
self._sample_proc_num = self._proc_num - 2
self._verbose = verbose
self._force_exit = ForceExitWrapper(self._manager.Value('b', False))
# buffer queue
self._sample_info_queue = self._manager.Queue(sample_info_buffer_size)
self._sample_queue = self._manager.Queue(sample_buffer_size)
self._batch_queue = self._manager.Queue(batch_buffer_size)

def generate_bucket_list(self, is_shuffle):
if self._block_info_list is None:
@@ -250,21 +250,13 @@ def generate_bucket_list(self, is_shuffle):
def set_transformers(self, transformers):
self._transformers = transformers

def recycle(self, *args):
for shared_ndarray in args:
if not isinstance(shared_ndarray, SharedNDArray):
raise ValueError("Only SharedNDArray objects can be recycled.")
shared_ndarray.recycle(self._pool_manager.pool)

def _start_async_processing(self):
def _sample_generator(self):
sample_info_queue = self._manager.Queue(self._sample_info_buffer_size)
sample_queue = self._manager.Queue(self._sample_buffer_size)
self._order_id = 0

@suppress_complaints(verbose=self._verbose, notify=self._force_exit)
def ordered_feeding_task(sample_info_queue):
if self._verbose == 0:
signal.signal(signal.SIGTERM, suppress_signal)
signal.signal(signal.SIGINT, suppress_signal)

for sample_info_bucket in self._bucket_list:
try:
sample_info_list = \
Expand All @@ -277,14 +269,13 @@ def ordered_feeding_task(sample_info_queue):
sample_info_queue.put((sample_info, self._order_id))
self._order_id += 1

for i in xrange(self._sample_proc_num):
for i in xrange(self._proc_num):
sample_info_queue.put(EpochEndSignal())

feeding_proc = DaemonProcessGroup(
proc_num=1,
target=ordered_feeding_task,
args=(self._sample_info_queue, ))
feeding_proc.start_all()
feeding_thread = Thread(
target=ordered_feeding_task, args=(sample_info_queue, ))
feeding_thread.daemon = True
feeding_thread.start()

@suppress_complaints(verbose=self._verbose, notify=self._force_exit)
def ordered_processing_task(sample_info_queue, sample_queue, out_order):
@@ -312,11 +303,12 @@ def read_bytes(fpath, start, size):
sample_info.feature_size)

assert sample_info.feature_frame_num \
* sample_info.feature_dim * 4 == len(feature_bytes), \
(sample_info.feature_bin_path,
sample_info.feature_frame_num,
sample_info.feature_dim,
len(feature_bytes))
* sample_info.feature_dim * 4 \
== len(feature_bytes), \
(sample_info.feature_bin_path,
sample_info.feature_frame_num,
sample_info.feature_dim,
len(feature_bytes))
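# (The factor of 4 in the assertion above is the byte width of a float32:
# each sample stores feature_frame_num * feature_dim float32 values, so the
# byte count read from disk must match exactly.)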

label_bytes = read_bytes(sample_info.label_bin_path,
sample_info.label_start,
@@ -360,83 +352,83 @@ def read_bytes(fpath, start, size):
sample_queue.put(EpochEndSignal())

out_order = self._manager.list([0])
args = (self._sample_info_queue, self._sample_queue, out_order)
sample_proc = DaemonProcessGroup(
proc_num=self._sample_proc_num,
target=ordered_processing_task,
args=args)
sample_proc.start_all()
args = (sample_info_queue, sample_queue, out_order)
workers = [
Process(
target=ordered_processing_task, args=args)
for _ in xrange(self._proc_num)
]

def batch_iterator(self, batch_size, minimum_batch_size):
@suppress_complaints(verbose=self._verbose, notify=self._force_exit)
def batch_assembling_task(sample_queue, batch_queue, pool):
def conv_to_shared(ndarray):
while self._force_exit == False:
try:
(name, shared_ndarray) = pool.popitem()
except Exception as e:
time.sleep(0.001)
else:
shared_ndarray.copy(ndarray)
return shared_ndarray
for w in workers:
w.daemon = True
w.start()
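# Daemonizing the workers means they are killed automatically when the
# parent process exits, so a crash or Ctrl-C cannot leave orphaned
# preprocessing processes behind.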

if self._verbose == 0:
signal.signal(signal.SIGTERM, suppress_signal)
signal.signal(signal.SIGINT, suppress_signal)
finished_proc_num = 0

batch_samples = []
lod = [0]
done_num = 0
while done_num < self._sample_proc_num:
sample = sample_queue.get()
while self._force_exit == False:
try:
sample = sample_queue.get_nowait()
except Queue.Empty:
time.sleep(0.001)
else:
if isinstance(sample, EpochEndSignal):
done_num += 1
else:
batch_samples.append(sample)
lod.append(lod[-1] + sample[0].shape[0])
if len(batch_samples) == batch_size:
feature, label = batch_to_ndarray(batch_samples, lod)

feature = conv_to_shared(feature)
label = conv_to_shared(label)
lod = conv_to_shared(np.array(lod).astype('int64'))
finished_proc_num += 1
if finished_proc_num >= self._proc_num:
break
else:
continue

batch_queue.put((feature, label, lod))
batch_samples = []
lod = [0]
yield sample

if len(batch_samples) >= minimum_batch_size:
(feature, label) = batch_to_ndarray(batch_samples, lod)
def batch_iterator(self, batch_size, minimum_batch_size):
def batch_to_ndarray(batch_samples, lod):
assert len(batch_samples)
frame_dim = batch_samples[0][0].shape[1]
batch_feature = np.zeros((lod[-1], frame_dim), dtype="float32")
batch_label = np.zeros((lod[-1], 1), dtype="int64")
start = 0
for sample in batch_samples:
frame_num = sample[0].shape[0]
batch_feature[start:start + frame_num, :] = sample[0]
batch_label[start:start + frame_num, :] = sample[1]
start += frame_num
return (batch_feature, batch_label)
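# A worked example of the layout built above, assuming two samples with 3
# and 2 frames and frame_dim == 4:
#
#     lod == [0, 3, 5]                # prefix sums of per-sample frame counts
#     batch_feature.shape == (5, 4)   # all frames, concatenated row-wise
#     batch_label.shape == (5, 1)     # one int64 label per frame
#
# lod carries the sequence offsets that let downstream code (e.g. a Fluid
# LoDTensor) recover per-sample boundaries from the flat batch.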

feature = conv_to_shared(feature)
label = conv_to_shared(label)
lod = conv_to_shared(np.array(lod).astype('int64'))
@suppress_complaints(verbose=self._verbose, notify=self._force_exit)
def batch_assembling_task(sample_generator, batch_queue):
batch_samples = []
lod = [0]
for sample in sample_generator():
batch_samples.append(sample)
lod.append(lod[-1] + sample[0].shape[0])
if len(batch_samples) == batch_size:
(batch_feature, batch_label) = batch_to_ndarray(
batch_samples, lod)
batch_queue.put((batch_feature, batch_label, lod))
batch_samples = []
lod = [0]

batch_queue.put((feature, label, lod))
if len(batch_samples) >= minimum_batch_size:
(batch_feature, batch_label) = batch_to_ndarray(batch_samples,
lod)
batch_queue.put((batch_feature, batch_label, lod))

batch_queue.put(EpochEndSignal())

self._start_async_processing()
batch_queue = Queue.Queue(self._batch_buffer_size)

self._pool_manager = SharedMemoryPoolManager(self._batch_buffer_size *
3, self._manager)

assembling_proc = DaemonProcessGroup(
proc_num=1,
assembling_thread = Thread(
target=batch_assembling_task,
args=(self._sample_queue, self._batch_queue,
self._pool_manager.pool))
assembling_proc.start_all()
args=(self._sample_generator, batch_queue))
assembling_thread.daemon = True
assembling_thread.start()

while self._force_exit == False:
try:
batch_data = self._batch_queue.get_nowait()
batch_data = batch_queue.get_nowait()
except Queue.Empty:
time.sleep(0.001)
else:
if isinstance(batch_data, EpochEndSignal):
break
yield batch_data

# clean the shared memory
del self._pool_manager
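For context, a sketch of how the reverted reader is consumed; the feature list path and the transformer constructor arguments below are assumptions based on the docstrings in this diff, not a verified invocation:

    import data_utils.augmentor.trans_add_delta as trans_add_delta
    import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
    from data_utils.async_data_reader import AsyncDataReader

    # 'feature_file_list.txt' and the constructor arguments are hypothetical
    reader = AsyncDataReader('feature_file_list.txt')
    reader.set_transformers([
        trans_mean_variance_norm.TransMeanVarianceNorm('global_mean_var'),
        trans_add_delta.TransAddDelta(),
    ])

    for batch_feature, batch_label, lod in reader.batch_iterator(
            batch_size=32, minimum_batch_size=1):
        # feed (batch_feature, lod) and (batch_label, lod) to the network
        pass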