Skip to content

Commit

Permalink
Merge pull request #791 from pkuyym/fix-790
Browse files Browse the repository at this point in the history
Revert AsyncDataReader to avoid using shared memory.
  • Loading branch information
pkuyym authored Mar 29, 2018
2 parents 59bc4c1 + 7f0b566 commit bf38065
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 260 deletions.
176 changes: 84 additions & 92 deletions fluid/DeepASR/data_utils/async_data_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.augmentor.trans_add_delta as trans_add_delta
from data_utils.util import suppress_complaints, suppress_signal
from data_utils.util import SharedNDArray, SharedMemoryPoolManager
from data_utils.util import DaemonProcessGroup, batch_to_ndarray
from data_utils.util import CriticalException, ForceExitWrapper, EpochEndSignal
from data_utils.util import CriticalException, ForceExitWrapper


class SampleInfo(object):
"""SampleInfo holds the necessary information to load a sample from disk.
Args:
feature_bin_path (str): File containing the feature data.
feature_start (int): Start position of the sample's feature data.
Expand Down Expand Up @@ -54,6 +53,7 @@ class SampleInfoBucket(object):
data, sample start position, sample byte number etc.) to access samples'
feature data and the same with the label description file. SampleInfoBucket
is the minimum unit to do shuffle.
Args:
feature_bin_paths (list|tuple): Files containing the binary feature
data.
Expand All @@ -67,8 +67,8 @@ class SampleInfoBucket(object):
split_sentence_threshold(int): Sentence whose length larger than
the value will trigger split operation.
split_sub_sentence_len(int): sub-sentence length is equal to
(split_sub_sentence_len + \
rand() % split_perturb).
(split_sub_sentence_len
+ rand() % split_perturb).
"""

def __init__(self,
Expand Down Expand Up @@ -160,9 +160,14 @@ def generate_sample_info_list(self):
return sample_info_list


class EpochEndSignal(object):
    """Sentinel object marking the end of an epoch's data stream.

    Producers put an instance on a queue after their last real item and
    consumers detect it with ``isinstance``. The class carries no state.

    It derives from ``object`` so that it is a new-style class, consistent
    with the other classes defined in this module.
    """


class AsyncDataReader(object):
"""DataReader provides basic audio sample preprocessing pipeline including
data loading and data augmentation.
Args:
feature_file_list (str): File containing paths of feature data file and
corresponding description file.
Expand Down Expand Up @@ -206,17 +211,12 @@ def __init__(self,
self.generate_bucket_list(True)
self._order_id = 0
self._manager = Manager()
self._sample_buffer_size = sample_buffer_size
self._sample_info_buffer_size = sample_info_buffer_size
self._batch_buffer_size = batch_buffer_size
self._proc_num = proc_num
if self._proc_num <= 2:
raise ValueError("Value of `proc_num` should be greater than 2.")
self._sample_proc_num = self._proc_num - 2
self._verbose = verbose
self._force_exit = ForceExitWrapper(self._manager.Value('b', False))
# buffer queue
self._sample_info_queue = self._manager.Queue(sample_info_buffer_size)
self._sample_queue = self._manager.Queue(sample_buffer_size)
self._batch_queue = self._manager.Queue(batch_buffer_size)

def generate_bucket_list(self, is_shuffle):
if self._block_info_list is None:
Expand Down Expand Up @@ -250,21 +250,13 @@ def generate_bucket_list(self, is_shuffle):
def set_transformers(self, transformers):
    """Store the transformers on this reader.

    Args:
        transformers: Object kept as ``self._transformers``. Presumably the
            augmentation transformers applied to each sample; their use is
            not visible in this part of the file (TODO confirm).
    """
    self._transformers = transformers

def recycle(self, *args):
    """Hand shared ndarrays back to the pool owned by ``self._pool_manager``.

    Args:
        *args: ``SharedNDArray`` objects; each one is recycled into
            ``self._pool_manager.pool`` in the order given.

    Raises:
        ValueError: If an argument is not a ``SharedNDArray``. Arguments
            that precede the offending one have already been recycled.
    """
    for shared_ndarray in args:
        if not isinstance(shared_ndarray, SharedNDArray):
            # The original raised the undefined name `Value`, which
            # surfaced as a NameError instead of the intended error.
            raise ValueError("Only support recycle SharedNDArray object.")
        shared_ndarray.recycle(self._pool_manager.pool)

def _start_async_processing(self):
def _sample_generator(self):
sample_info_queue = self._manager.Queue(self._sample_info_buffer_size)
sample_queue = self._manager.Queue(self._sample_buffer_size)
self._order_id = 0

@suppress_complaints(verbose=self._verbose, notify=self._force_exit)
def ordered_feeding_task(sample_info_queue):
if self._verbose == 0:
signal.signal(signal.SIGTERM, suppress_signal)
signal.signal(signal.SIGINT, suppress_signal)

for sample_info_bucket in self._bucket_list:
try:
sample_info_list = \
Expand All @@ -277,14 +269,13 @@ def ordered_feeding_task(sample_info_queue):
sample_info_queue.put((sample_info, self._order_id))
self._order_id += 1

for i in xrange(self._sample_proc_num):
for i in xrange(self._proc_num):
sample_info_queue.put(EpochEndSignal())

feeding_proc = DaemonProcessGroup(
proc_num=1,
target=ordered_feeding_task,
args=(self._sample_info_queue, ))
feeding_proc.start_all()
feeding_thread = Thread(
target=ordered_feeding_task, args=(sample_info_queue, ))
feeding_thread.daemon = True
feeding_thread.start()

@suppress_complaints(verbose=self._verbose, notify=self._force_exit)
def ordered_processing_task(sample_info_queue, sample_queue, out_order):
Expand Down Expand Up @@ -312,11 +303,12 @@ def read_bytes(fpath, start, size):
sample_info.feature_size)

assert sample_info.feature_frame_num \
* sample_info.feature_dim * 4 == len(feature_bytes), \
(sample_info.feature_bin_path,
sample_info.feature_frame_num,
sample_info.feature_dim,
len(feature_bytes))
* sample_info.feature_dim * 4 \
== len(feature_bytes), \
(sample_info.feature_bin_path,
sample_info.feature_frame_num,
sample_info.feature_dim,
len(feature_bytes))

label_bytes = read_bytes(sample_info.label_bin_path,
sample_info.label_start,
Expand Down Expand Up @@ -360,83 +352,83 @@ def read_bytes(fpath, start, size):
sample_queue.put(EpochEndSignal())

out_order = self._manager.list([0])
args = (self._sample_info_queue, self._sample_queue, out_order)
sample_proc = DaemonProcessGroup(
proc_num=self._sample_proc_num,
target=ordered_processing_task,
args=args)
sample_proc.start_all()
args = (sample_info_queue, sample_queue, out_order)
workers = [
Process(
target=ordered_processing_task, args=args)
for _ in xrange(self._proc_num)
]

def batch_iterator(self, batch_size, minimum_batch_size):
@suppress_complaints(verbose=self._verbose, notify=self._force_exit)
def batch_assembling_task(sample_queue, batch_queue, pool):
def conv_to_shared(ndarray):
while self._force_exit == False:
try:
(name, shared_ndarray) = pool.popitem()
except Exception as e:
time.sleep(0.001)
else:
shared_ndarray.copy(ndarray)
return shared_ndarray
for w in workers:
w.daemon = True
w.start()

if self._verbose == 0:
signal.signal(signal.SIGTERM, suppress_signal)
signal.signal(signal.SIGINT, suppress_signal)
finished_proc_num = 0

batch_samples = []
lod = [0]
done_num = 0
while done_num < self._sample_proc_num:
sample = sample_queue.get()
while self._force_exit == False:
try:
sample = sample_queue.get_nowait()
except Queue.Empty:
time.sleep(0.001)
else:
if isinstance(sample, EpochEndSignal):
done_num += 1
else:
batch_samples.append(sample)
lod.append(lod[-1] + sample[0].shape[0])
if len(batch_samples) == batch_size:
feature, label = batch_to_ndarray(batch_samples, lod)

feature = conv_to_shared(feature)
label = conv_to_shared(label)
lod = conv_to_shared(np.array(lod).astype('int64'))
finished_proc_num += 1
if finished_proc_num >= self._proc_num:
break
else:
continue

batch_queue.put((feature, label, lod))
batch_samples = []
lod = [0]
yield sample

if len(batch_samples) >= minimum_batch_size:
(feature, label) = batch_to_ndarray(batch_samples, lod)
def batch_iterator(self, batch_size, minimum_batch_size):
def batch_to_ndarray(batch_samples, lod):
    """Pack per-sample arrays into one feature array and one label array.

    Args:
        batch_samples (list): Non-empty list of samples. ``sample[0]`` is a
            2-D feature array and ``sample[1]`` holds the labels assigned to
            the same rows.
        lod (list): Row offsets; ``lod[-1]`` is the total number of rows
            allocated for the batch.

    Returns:
        tuple: ``(batch_feature, batch_label)``, a float32 array of shape
        ``(lod[-1], frame_dim)`` and an int64 array of shape ``(lod[-1], 1)``.
    """
    assert len(batch_samples)
    # The feature width is taken from the first sample.
    feature_width = batch_samples[0][0].shape[1]
    packed_feature = np.zeros((lod[-1], feature_width), dtype="float32")
    packed_label = np.zeros((lod[-1], 1), dtype="int64")
    offset = 0
    for item in batch_samples:
        # Each sample fills the next `rows` rows of both output arrays.
        rows = item[0].shape[0]
        packed_feature[offset:offset + rows, :] = item[0]
        packed_label[offset:offset + rows, :] = item[1]
        offset += rows
    return (packed_feature, packed_label)

feature = conv_to_shared(feature)
label = conv_to_shared(label)
lod = conv_to_shared(np.array(lod).astype('int64'))
@suppress_complaints(verbose=self._verbose, notify=self._force_exit)
def batch_assembling_task(sample_generator, batch_queue):
batch_samples = []
lod = [0]
for sample in sample_generator():
batch_samples.append(sample)
lod.append(lod[-1] + sample[0].shape[0])
if len(batch_samples) == batch_size:
(batch_feature, batch_label) = batch_to_ndarray(
batch_samples, lod)
batch_queue.put((batch_feature, batch_label, lod))
batch_samples = []
lod = [0]

batch_queue.put((feature, label, lod))
if len(batch_samples) >= minimum_batch_size:
(batch_feature, batch_label) = batch_to_ndarray(batch_samples,
lod)
batch_queue.put((batch_feature, batch_label, lod))

batch_queue.put(EpochEndSignal())

self._start_async_processing()
batch_queue = Queue.Queue(self._batch_buffer_size)

self._pool_manager = SharedMemoryPoolManager(self._batch_buffer_size *
3, self._manager)

assembling_proc = DaemonProcessGroup(
proc_num=1,
assembling_thread = Thread(
target=batch_assembling_task,
args=(self._sample_queue, self._batch_queue,
self._pool_manager.pool))
assembling_proc.start_all()
args=(self._sample_generator, batch_queue))
assembling_thread.daemon = True
assembling_thread.start()

while self._force_exit == False:
try:
batch_data = self._batch_queue.get_nowait()
batch_data = batch_queue.get_nowait()
except Queue.Empty:
time.sleep(0.001)
else:
if isinstance(batch_data, EpochEndSignal):
break
yield batch_data

# clean the shared memory
del self._pool_manager
Loading

0 comments on commit bf38065

Please sign in to comment.