diff --git a/opennmt/models/model.py b/opennmt/models/model.py index 8dcc997b1..6732b1821 100644 --- a/opennmt/models/model.py +++ b/opennmt/models/model.py @@ -375,6 +375,9 @@ def _input_fn_impl(self, padded_shapes_fn = lambda: ( feat_padded_shapes_fn(), labels_padded_shapes_fn()) + if mode == tf.estimator.ModeKeys.TRAIN: + dataset = dataset.shuffle(buffer_size, seed=int(time.time())) + dataset = dataset.map( process_fn, num_parallel_calls=num_parallel_process_calls).prefetch(buffer_size) @@ -386,8 +389,6 @@ def _input_fn_impl(self, labels, maximum_features_length=maximum_features_length, maximum_labels_length=maximum_labels_length)) - dataset = dataset.shuffle(buffer_size, seed=int(time.time())) - dataset = dataset.repeat() num_buckets = num_buckets or 1 @@ -432,6 +433,9 @@ def _reduce_func(unused_key, dataset): batch_size, padded_shapes=padded_shapes) + if mode == tf.estimator.ModeKeys.TRAIN: + dataset = dataset.repeat() + iterator = dataset.make_initializable_iterator() # Add the initializer to a standard collection for it to be initialized.