Skip to content

Commit

Permalink
Revise dataset processing order
Browse files (browse the repository at this point in the history)
* shuffle before applying the preprocessing function
* call repeat after batching the dataset; otherwise, batches can contain
  examples from future epochs
  • Loading branch information
guillaumekln committed Nov 17, 2017
1 parent 5a4d715 commit 6f3fe6e
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions opennmt/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,9 @@ def _input_fn_impl(self,
padded_shapes_fn = lambda: (
feat_padded_shapes_fn(), labels_padded_shapes_fn())

if mode == tf.estimator.ModeKeys.TRAIN:
dataset = dataset.shuffle(buffer_size, seed=int(time.time()))

dataset = dataset.map(
process_fn,
num_parallel_calls=num_parallel_process_calls).prefetch(buffer_size)
Expand All @@ -386,8 +389,6 @@ def _input_fn_impl(self,
labels,
maximum_features_length=maximum_features_length,
maximum_labels_length=maximum_labels_length))
dataset = dataset.shuffle(buffer_size, seed=int(time.time()))
dataset = dataset.repeat()

num_buckets = num_buckets or 1

Expand Down Expand Up @@ -432,6 +433,9 @@ def _reduce_func(unused_key, dataset):
batch_size,
padded_shapes=padded_shapes)

if mode == tf.estimator.ModeKeys.TRAIN:
dataset = dataset.repeat()

iterator = dataset.make_initializable_iterator()

# Add the initializer to a standard collection for it to be initialized.
Expand Down

0 comments on commit 6f3fe6e

Please sign in to comment.