diff --git a/bin/train.py b/bin/train.py
index fe352aac..0a6f1b2a 100755
--- a/bin/train.py
+++ b/bin/train.py
@@ -105,6 +105,8 @@
                         """In addition to keeping the most recent checkpoint
                         files, keep one checkpoint file for every N hours of
                         training.""")
+tf.flags.DEFINE_boolean("log_device_placement", False,
+                        """If true, logs device placement.""")
 
 FLAGS = tf.flags.FLAGS
 
@@ -121,7 +123,8 @@ def create_experiment(output_dir):
       save_checkpoints_secs=FLAGS.save_checkpoints_secs,
       save_checkpoints_steps=FLAGS.save_checkpoints_steps,
       keep_checkpoint_max=FLAGS.keep_checkpoint_max,
-      keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours
+      keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
+      log_device_placement=FLAGS.log_device_placement,
   )
 
   train_options = training_utils.TrainOptions(
diff --git a/seq2seq/models/model_base.py b/seq2seq/models/model_base.py
index b5e41552..f8f3508f 100644
--- a/seq2seq/models/model_base.py
+++ b/seq2seq/models/model_base.py
@@ -72,6 +72,8 @@ def _clip_gradients(self, grads_and_vars):
 
   def _build_train_op(self, loss):
     """Creates the training operation"""
+    colocate_gradients_with_ops = self.params["training.data_parallelism"] > 1
+
     learning_rate_decay_fn = training_utils.create_learning_rate_decay_fn(
         decay_type=self.params["optimizer.lr_decay_type"] or None,
         decay_steps=self.params["optimizer.lr_decay_steps"],
@@ -88,7 +90,8 @@ def _build_train_op(self, loss):
         learning_rate_decay_fn=learning_rate_decay_fn,
         clip_gradients=self._clip_gradients,
         optimizer=self.params["optimizer.name"],
-        summaries=["learning_rate", "loss", "gradients", "gradient_norm"])
+        summaries=["learning_rate", "loss", "gradients", "gradient_norm"],
+        colocate_gradients_with_ops=colocate_gradients_with_ops)
 
   @staticmethod
   def default_params():
@@ -104,22 +107,101 @@ def default_params():
         "optimizer.lr_min_learning_rate": 1e-12,
         "optimizer.lr_staircase": False,
         "optimizer.clip_gradients": 5.0,
+        "training.data_parallelism": 0
     }
 
   def batch_size(self, features, labels):
     """Returns the batch size for a batch of examples"""
     raise NotImplementedError()
 
+  def _build_parallel(self, features, labels, params):
+    """Builds one or more model replicas on GPU devices. If
+    `training.data_parallelism` is set to 0 this function does nothing and
+    just calls the build method.
+
+    If `training.data_parallelism` is >= 1 and not enough GPUs are
+    available this will throw an error.
+
+    If `training.data_parallelism` = N >= 1 and enough GPUs are available this
+    will create a model replica on each GPU and split the training batch into
+    N pieces, merge the predictions and average the losses.
+
+    If the model is not in training mode this does nothing and just calls the
+    build method.
+    """
+    parallelism = self.params["training.data_parallelism"]
+
+    # Create model template function
+    template_build = tf.make_template(
+        self.name, self._build, create_scope_now_=True)
+
+    # Data parallelism is disabled
+    if parallelism <= 0:
+      with tf.variable_scope(self.name):
+        return template_build(features, labels, params)
+
+    # Not training
+    if self.mode != tf.contrib.learn.ModeKeys.TRAIN:
+      with tf.variable_scope(self.name):
+        return template_build(features, labels, params)
+
+    # Data parallelism is enabled
+    available_gpus = training_utils.get_available_gpus()
+    tf.logging.info("Available GPUs: %s", available_gpus)
+
+    # Make sure we have enough GPUs
+    if len(available_gpus) < parallelism:
+      raise ValueError(
+          "Data Parallelism set to {}, but only {} GPUs available".format(
+              parallelism, len(available_gpus)))
+
+    # Split all features and labels
+    features_split = {k: tf.split(v, parallelism) for k, v in features.items()}
+    labels_split = {k: tf.split(v, parallelism) for k, v in labels.items()}
+    tf.logging.info(features_split)
+
+    all_losses = []
+    all_predictions = []
+
+    for idx in range(parallelism):
+      # Create each model replica
+      gpu_device = available_gpus[idx]
+      tf.logging.info("Creating replica %d on device %s", idx, gpu_device)
+      with tf.device(gpu_device):
+        rep_features = {k: v[idx] for k, v in features_split.items()}
+        rep_labels = {k: v[idx] for k, v in labels_split.items()}
+        rep_pred, rep_loss = template_build(rep_features, rep_labels, params)
+        all_losses.append(rep_loss)
+        all_predictions.append(rep_pred)
+
+    # Concat all predictions
+    prediction_keys = all_predictions[0].keys()
+    predictions = {
+        k: tf.concat([_[k] for _ in all_predictions], 0)
+        for k in prediction_keys
+    }
+
+    # Take the average loss
+    loss = tf.reduce_mean(all_losses)
+
+    return predictions, loss
+
+
   def __call__(self, features, labels, params):
     """Creates the model graph. See the model_fn documentation in
     tf.contrib.learn.Estimator class for a more detailed explanation.
     """
     with tf.variable_scope("model"):
-      with tf.variable_scope(self.name):
-        return self._build(features, labels, params)
+      predictions, loss = self._build_parallel(features, labels, params)
+      if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
+        train_op = self._build_train_op(loss)
+      else:
+        train_op = None
+      return predictions, loss, train_op
 
   def _build(self, features, labels, params):
     """Subclasses should implement this method. See the `model_fn` documentation
-    in tf.contrib.learn.Estimator class for a more detailed explanation.
+    in tf.contrib.learn.Estimator class for a more detailed explanation. This
+    function should return a tuple of (predictions, loss).
     """
     raise NotImplementedError
diff --git a/seq2seq/models/seq2seq_model.py b/seq2seq/models/seq2seq_model.py
index 423ffb75..73fa2bcf 100644
--- a/seq2seq/models/seq2seq_model.py
+++ b/seq2seq/models/seq2seq_model.py
@@ -247,6 +247,11 @@ def _preprocess(self, features, labels):
           "target.max_seq_len"]]
       labels["target_len"] = tf.minimum(labels["target_len"],
                                         self.params["target.max_seq_len"])
+      # Slice up to the longest example in this batch.
+      # This is required for multi-GPU training, where each batch is split
+      # across replicas.
+      labels["target_tokens"] = labels["target_tokens"][:, :tf.reduce_max(
+          labels["target_len"])]
 
     # Look up the target ids in the vocabulary
     labels["target_ids"] = target_vocab_to_id.lookup(labels["target_tokens"])
@@ -255,9 +260,10 @@ def _preprocess(self, features, labels):
     tf.summary.histogram("target_len", tf.to_float(labels["target_len"]))
 
     # Keep track of the number of processed tokens
-    num_tokens = tf.reduce_sum(labels["target_len"])
-    num_tokens += tf.reduce_sum(features["source_len"])
-    token_counter_var = tf.Variable(0, "tokens_counter")
+    num_tokens = tf.to_int64(tf.reduce_sum(labels["target_len"]))
+    num_tokens += tf.to_int64(tf.reduce_sum(features["source_len"]))
+    token_counter_var = tf.get_variable("tokens_counter", [],
+        dtype=tf.int64, initializer=tf.constant_initializer(0))
     total_tokens = tf.assign_add(token_counter_var, num_tokens)
     tf.summary.scalar("num_tokens", total_tokens)
 
@@ -301,14 +307,9 @@ def _build(self, features, labels, params):
       predictions = self._create_predictions(
           decoder_output=decoder_output, features=features, labels=labels)
       loss = None
-      train_op = None
     else:
       losses, loss = self.compute_loss(decoder_output, features, labels)
 
-      train_op = None
-      if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
-        train_op = self._build_train_op(loss)
-
       predictions = self._create_predictions(
           decoder_output=decoder_output,
           features=features,
@@ -319,4 +320,4 @@ def _build(self, features, labels, params):
     # can easly find them in our hooks/monitors.
     graph_utils.add_dict_to_collection(predictions, "predictions")
 
-    return predictions, loss, train_op
+    return predictions, loss
diff --git a/seq2seq/test/example_config_test.py b/seq2seq/test/example_config_test.py
index 4733548a..aca574d0 100644
--- a/seq2seq/test/example_config_test.py
+++ b/seq2seq/test/example_config_test.py
@@ -35,6 +35,8 @@
 EXAMPLE_CONFIG_DIR = os.path.abspath(
     os.path.join(os.path.dirname(__file__), "../../example_configs"))
 
+# Do not test multi-device support here - it takes too long
+delattr(EncoderDecoderTests, "test_train_multi_device")
 
 def _load_model_from_config(config_path, hparam_overrides, vocab_file, mode):
   """Loads model from a configuration file"""
@@ -50,7 +52,6 @@ def _load_model_from_config(config_path, hparam_overrides, vocab_file, mode):
   model_params["vocab_target"] = vocab_file
   return model_cls(params=model_params, mode=mode)
 
-
 class ExampleConfigTest(object):
   """Interface for configuration-based tests"""
 
diff --git a/seq2seq/test/models_test.py b/seq2seq/test/models_test.py
index d9e19acd..20b353b4 100644
--- a/seq2seq/test/models_test.py
+++ b/seq2seq/test/models_test.py
@@ -130,6 +130,31 @@ def _test_pipeline(self, mode, params=None):
     return model, fetches_
 
+  def test_train_multi_device(self):
+    parallelism = 4
+    self.batch_size = self.batch_size * parallelism
+    # Return 4 CPU devices to make sure the batch gets split across replicas
+    training_utils.get_available_gpus = lambda: ["cpu:0"] * parallelism
+    model, fetches_ = self._test_pipeline(
+        tf.contrib.learn.ModeKeys.TRAIN,
+        params={"training.data_parallelism": parallelism})
+
+    predictions_, loss_, _ = fetches_
+
+    target_len = self.sequence_length + 10 + 2
+    max_decode_length = model.params["target.max_seq_len"]
+    expected_decode_len = np.minimum(target_len, max_decode_length)
+
+    np.testing.assert_array_equal(predictions_["logits"].shape, [
+        self.batch_size, expected_decode_len - 1,
+        model.target_vocab_info.total_size
+    ])
+    np.testing.assert_array_equal(predictions_["losses"].shape,
+                                  [self.batch_size, expected_decode_len - 1])
+    np.testing.assert_array_equal(predictions_["predicted_ids"].shape,
+                                  [self.batch_size, expected_decode_len - 1])
+    self.assertFalse(np.isnan(loss_))
+
   def test_train(self):
     model, fetches_ = self._test_pipeline(tf.contrib.learn.ModeKeys.TRAIN)
     predictions_, loss_, _ = fetches_
diff --git a/seq2seq/test/utils.py b/seq2seq/test/utils.py
index facec60f..5029efc5 100644
--- a/seq2seq/test/utils.py
+++ b/seq2seq/test/utils.py
@@ -20,8 +20,8 @@
 from __future__ import unicode_literals
 
 import tempfile
-import tensorflow as tf
 
+import tensorflow as tf
 
 def create_temp_parallel_data(sources, targets):
   """
diff --git a/seq2seq/training/utils.py b/seq2seq/training/utils.py
index 48aebc4d..ce29725b 100644
--- a/seq2seq/training/utils.py
+++ b/seq2seq/training/utils.py
@@ -30,10 +30,18 @@
 
 import tensorflow as tf
 from tensorflow import gfile
+from tensorflow.python.client import device_lib  # pylint: disable=E0611
 
 from seq2seq.contrib import rnn_cell
 
 
+def get_available_gpus():
+  """Returns a list of available GPU device names.
+  """
+  local_device_protos = device_lib.list_local_devices()
+  return [x.name for x in local_device_protos if x.device_type == "GPU"]
+
+
 class TrainOptions(object):
   """A collection of options that are passed to the training script
   and can be saved to perform inference later.
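
Reviewer note (not part of the patch): below is a minimal, self-contained sketch of the split/replicate/merge pattern that `_build_parallel` introduces, written against plain TensorFlow 1.x ops. The toy model, the fixed batch size, and the CPU device list are illustrative assumptions only and do not come from this repository.

import tensorflow as tf

def toy_build(features):
  # Stand-in for ModelBase._build: must return a (predictions, loss) tuple.
  weights = tf.get_variable("weights", shape=[4, 2])
  predictions = tf.matmul(features, weights)
  loss = tf.reduce_mean(tf.square(predictions))
  return predictions, loss

parallelism = 2
devices = ["/cpu:0"] * parallelism  # stand-in for training_utils.get_available_gpus()

# tf.make_template shares variables between the replicas, as in _build_parallel.
template_build = tf.make_template("toy_model", toy_build)

# The batch dimension (8) must be evenly divisible by `parallelism` for tf.split.
features = tf.placeholder(tf.float32, [8, 4])
feature_splits = tf.split(features, parallelism)

replica_predictions = []
replica_losses = []
for device, split in zip(devices, feature_splits):
  with tf.device(device):
    predictions, loss = template_build(split)
    replica_predictions.append(predictions)
    replica_losses.append(loss)

# Merge: concatenate predictions along the batch axis, average the losses.
merged_predictions = tf.concat(replica_predictions, 0)
merged_loss = tf.reduce_mean(replica_losses)

The even-divisibility requirement of tf.split is also why test_train_multi_device multiplies self.batch_size by the parallelism factor before building the pipeline.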