use BatchSetGenerator caching only if possible
albertz committed Jan 14, 2017
1 parent d01b9ff commit a925c7a
Showing 5 changed files with 42 additions and 11 deletions.
3 changes: 3 additions & 0 deletions CachedDataset.py
@@ -79,6 +79,9 @@ def init_seq_order(self, epoch=None, seq_list=None):
       return False
     return True
 
+  def batch_set_generator_cache_whole_epoch(self):
+    return True
+
   def _init_alloc_intervals(self):
     assert self.num_seqs > 0
     assert self.num_inputs > 0
21 changes: 20 additions & 1 deletion Dataset.py
@@ -538,6 +538,21 @@ def _generate_batches(self, recurrent_net, batch_size, max_seqs=-1, seq_drop=0.0
       if batch.get_all_slices_num_frames() > 0:
         yield batch
 
+  def batch_set_generator_cache_whole_epoch(self):
+    """
+    The BatchSetGenerator can cache the list of batches which we generated across epochs.
+    See self.generate_batches() and self._generate_batches().
+    In many cases, the dataset does not support this, and in that case,
+    we should not enable this cache, as it would only waste memory.
+    Caching combined with the option shuffle_batches can also mean that
+    there will be self.load_seqs() calls with non-monotonic seq indices.
+    Currently, the only dataset which enables this is CachedDataset, and thus HDFDataset.
+    :return: whether we should enable this cache
+    :rtype: bool
+    """
+    return False
+
   def generate_batches(self, recurrent_net, batch_size, max_seqs=-1, seq_drop=0.0, max_seq_length=sys.maxsize, shuffle_batches=False):
     """
     :type recurrent_net: bool
@@ -546,7 +561,11 @@ def generate_batches(self, recurrent_net, batch_size, max_seqs=-1, seq_drop=0.0,
     :type shuffle_batches: bool
     :rtype: BatchSetGenerator
     """
-    return BatchSetGenerator(self, self._generate_batches(recurrent_net, batch_size, max_seqs, seq_drop, max_seq_length), shuffle_batches)
+    return BatchSetGenerator(
+      dataset=self,
+      generator=self._generate_batches(recurrent_net, batch_size, max_seqs, seq_drop, max_seq_length),
+      shuffle_batches=shuffle_batches,
+      cache_whole_epoch=self.batch_set_generator_cache_whole_epoch())
 
   def shapes_for_batches(self, batches, data_keys, batch_dim_first=False):
     """
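
As a hedged illustration of the new hook (a minimal sketch, not part of this commit): a dataset that keeps all of its data in memory can opt into whole-epoch batch caching by overriding the method, while every other dataset inherits the False default and keeps regenerating batches each epoch. The subclass name below is hypothetical.

# Hypothetical subclass, for illustration only.
from Dataset import Dataset

class MyInMemoryDataset(Dataset):

  def batch_set_generator_cache_whole_epoch(self):
    # Safe here because all data is held in memory, so the
    # self.load_seqs() calls with non-monotonic seq indices that a
    # cached (and possibly shuffled) batch list can produce stay cheap.
    return True
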
4 changes: 2 additions & 2 deletions Engine.py
@@ -451,7 +451,7 @@ def train_epoch(self):
       self.print_network_info()
 
     training_devices = self.devices
-    if not 'train' in self.dataset_batches:
+    if 'train' not in self.dataset_batches or not self.train_data.batch_set_generator_cache_whole_epoch():
       self.dataset_batches['train'] = self.train_data.generate_batches(recurrent_net=self.network.recurrent,
                                                                        batch_size=self.batch_size,
                                                                        max_seqs=self.max_seqs,
@@ -498,7 +498,7 @@ def format_score(self, score):
   def eval_model(self):
     eval_dump_str = []
     for dataset_name, dataset in self.get_eval_datasets().items():
-      if not dataset_name in self.dataset_batches:
+      if dataset_name not in self.dataset_batches or not dataset.batch_set_generator_cache_whole_epoch():
         self.dataset_batches[dataset_name] = dataset.generate_batches(recurrent_net=self.network.recurrent,
                                                                       batch_size=self.batch_size,
                                                                       max_seqs=self.max_seqs,
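
Both engines apply the same decision; a condensed, self-contained sketch of that logic follows. The helper name is hypothetical, and the reset() branch for the cached case is an assumption about the part of the diff elided above.

# Sketch of the caching decision in train_epoch() / eval_model().
def get_or_make_batches(dataset_batches, name, dataset, **gen_kwargs):
  """
  :param dict[str,EngineBatch.BatchSetGenerator] dataset_batches: per-name cache
  :param str name: e.g. "train" or an eval dataset name
  :param Dataset.Dataset dataset:
  :rtype: EngineBatch.BatchSetGenerator
  """
  if name not in dataset_batches or not dataset.batch_set_generator_cache_whole_epoch():
    # No usable cache: (re)generate the batches for this epoch.
    dataset_batches[name] = dataset.generate_batches(**gen_kwargs)
  else:
    # The dataset allows whole-epoch caching: reuse the batch list
    # which the BatchSetGenerator cached in a previous epoch.
    dataset_batches[name].reset()
  return dataset_batches[name]
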
21 changes: 15 additions & 6 deletions EngineBatch.py
@@ -133,28 +133,37 @@ class BatchSetGenerator:
   you call self.advance() explicitly to go forward to next batches.
   """
 
-  def __init__(self, dataset, generator, shuffle_batches=True):
+  def __init__(self, dataset, generator, shuffle_batches=True, cache_whole_epoch=True):
     """
     :type dataset: Dataset.Dataset
     :type generator: iter[Batch]
     """
     self.dataset = dataset
     self.generator = generator
     self.shuffle_batches = shuffle_batches
-    self.cache = []; " :type: list[Batch] "
+    # In some cases, it might be faster to cache the list of batches.
+    self.cache_whole_epoch = cache_whole_epoch
+    self.cache = []  # type: list[Batch]
     self.reached_end = False
     random.seed(1234)
-    self.reset()
+    self._reset()
 
-  def reset(self):
+  def _reset(self):
     self.buffer = self.cache[:]
     if self.shuffle_batches:
       random.shuffle(self.buffer)
     self.cache_active = self.reached_end
     self.reached_end = False
-    self.last_batch = None; " :type: Batch "
+    self.last_batch = None  # type: Batch
     self.current_batch_idx = 0
 
+  def reset(self):
+    """
+    Call this after one epoch to reuse the previously cached batches.
+    """
+    assert self.cache_whole_epoch
+    self._reset()
+
   def _read_next(self):
     if self.reached_end:
       return False
@@ -165,7 +174,7 @@ def _read_next(self):
       return False
     else:
       self.buffer += [batch]
-      if not self.cache_active:
+      if self.cache_whole_epoch and not self.cache_active:
         self.cache += [batch]
     return True
 
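
To make the intended lifecycle concrete, a hedged driver-loop sketch: it assumes a caching-capable dataset such as HDFDataset and the generator's has_more()/peek_next_n()/advance() interface (only advance() is visible in the hunks above). During the first epoch the generator fills its cache as batches are consumed; afterwards reset() replays (and, with shuffle_batches, reshuffles) the cached list instead of re-running the batch generator.

# Hypothetical driver loop, for illustration only.
def run_epochs(dataset, num_epochs):
  batches = None
  for epoch in range(1, num_epochs + 1):
    dataset.init_seq_order(epoch=epoch)
    if batches is not None and dataset.batch_set_generator_cache_whole_epoch():
      batches.reset()  # replay the batch list cached during the first epoch
    else:
      batches = dataset.generate_batches(recurrent_net=True, batch_size=5000,
                                         max_seqs=40, shuffle_batches=True)
    while batches.has_more():
      batch, = batches.peek_next_n(1)  # hand this batch to the device/trainer
      batches.advance(1)
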
4 changes: 2 additions & 2 deletions TFEngine.py
@@ -845,7 +845,7 @@ def train_epoch(self):
     if self.is_pretrain_epoch():
       self.print_network_info()
 
-    if 'train' not in self.dataset_batches:
+    if 'train' not in self.dataset_batches or not self.train_data.batch_set_generator_cache_whole_epoch():
      self.dataset_batches['train'] = self.train_data.generate_batches(recurrent_net=self.network.recurrent,
                                                                       batch_size=self.batch_size,
                                                                       max_seqs=self.max_seqs,
@@ -887,7 +887,7 @@ def format_score(self, score):
   def eval_model(self):
     eval_dump_str = []
     for dataset_name, dataset in self.get_eval_datasets().items():
-      if dataset_name not in self.dataset_batches:
+      if dataset_name not in self.dataset_batches or not dataset.batch_set_generator_cache_whole_epoch():
         self.dataset_batches[dataset_name] = dataset.generate_batches(recurrent_net=self.network.recurrent,
                                                                       batch_size=self.batch_size,
                                                                       max_seqs=self.max_seqs,
