This repository has been archived by the owner on Dec 29, 2022. It is now read-only.

Data parallelism across multiple GPUs #121

Open
wants to merge 7 commits into master
5 changes: 4 additions & 1 deletion bin/train.py
@@ -105,6 +105,8 @@
"""In addition to keeping the most recent checkpoint
files, keep one checkpoint file for every N hours of
training.""")
tf.flags.DEFINE_boolean("log_device_placement", False,
"""If true, logs device placement.""")

FLAGS = tf.flags.FLAGS

@@ -121,7 +123,8 @@ def create_experiment(output_dir):
save_checkpoints_secs=FLAGS.save_checkpoints_secs,
save_checkpoints_steps=FLAGS.save_checkpoints_steps,
keep_checkpoint_max=FLAGS.keep_checkpoint_max,
keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours
keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
log_device_placement=FLAGS.log_device_placement,
)

train_options = training_utils.TrainOptions(
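For context, `log_device_placement` is the standard `tf.ConfigProto` option that makes TensorFlow log which device each op is assigned to, which is handy for verifying that replicas really land on separate GPUs. A minimal standalone sketch of its effect (toy ops, illustrative only, not part of this diff):

import tensorflow as tf

config = tf.ConfigProto(log_device_placement=True)
with tf.Session(config=config) as sess:
    a = tf.constant([1.0, 2.0], name="a")
    b = tf.constant([3.0, 4.0], name="b")
    print(sess.run(a + b))  # per-op device assignments are logged to stderr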
90 changes: 86 additions & 4 deletions seq2seq/models/model_base.py
@@ -72,6 +72,8 @@ def _clip_gradients(self, grads_and_vars):

def _build_train_op(self, loss):
"""Creates the training operation"""
colocate_gradients_with_ops = self.params["training.data_parallelism"] > 1

learning_rate_decay_fn = training_utils.create_learning_rate_decay_fn(
decay_type=self.params["optimizer.lr_decay_type"] or None,
decay_steps=self.params["optimizer.lr_decay_steps"],
@@ -88,7 +90,8 @@ def _build_train_op(self, loss):
learning_rate_decay_fn=learning_rate_decay_fn,
clip_gradients=self._clip_gradients,
optimizer=self.params["optimizer.name"],
summaries=["learning_rate", "loss", "gradients", "gradient_norm"])
summaries=["learning_rate", "loss", "gradients", "gradient_norm"],
colocate_gradients_with_ops=colocate_gradients_with_ops)

@staticmethod
def default_params():
@@ -104,22 +107,101 @@ def default_params():
"optimizer.lr_min_learning_rate": 1e-12,
"optimizer.lr_staircase": False,
"optimizer.clip_gradients": 5.0,
"training.data_parallelism": 0
}

def batch_size(self, features, labels):
"""Returns the batch size for a batch of examples"""
raise NotImplementedError()

def _build_parallel(self, features, labels, params):
"""Builds one or more model replicas on GPU devices. If
`training.data_parallelism` is set to 0 this function does nothing and
just calls the build method.

If `training.data_parallelism` is >= 1 and not enough GPUs are
available, this will raise an error.

If `training.data_parallelism` = N >= 1 and enough GPUs are available, this
will create a model replica on each GPU, split the training batch into
N pieces, merge the predictions, and average the losses.

If the model is not in training mode this does nothing and just calls the
build method.
"""
parallelism = self.params["training.data_parallelism"]

# Create model template function
template_build = tf.make_template(
self.name, self._build, create_scope_now_=True)

# Data parallelism is disabled
if parallelism <= 0:
with tf.variable_scope(self.name):
return template_build(features, labels, params)

# Not training
if self.mode != tf.contrib.learn.ModeKeys.TRAIN:
with tf.variable_scope(self.name):
return template_build(features, labels, params)

# Data parallelism is enabled
available_gpus = training_utils.get_available_gpus()
tf.logging.info("Available GPUs: %s", available_gpus)

# Make sure we have enough GPUs
if len(available_gpus) < parallelism:
raise ValueError(
"Data Parallelism set to {}, but only {} GPUs available""".format(
parallelism, len(available_gpus)))

# Split all features and labels
features_split = {k: tf.split(v, parallelism) for k, v in features.items()}
labels_split = {k: tf.split(v, parallelism) for k, v in labels.items()}
tf.logging.info(features_split)

all_losses = []
all_predictions = []

for idx in range(parallelism):
# Create each model replica
gpu_device = available_gpus[idx]
tf.logging.info("Creating replica %d on device %s", idx, gpu_device)
with tf.device(gpu_device):
rep_features = {k: v[idx] for k, v in features_split.items()}
rep_labels = {k: v[idx] for k, v in labels_split.items()}
rep_pred, rep_loss = template_build(rep_features, rep_labels, params)
all_losses.append(rep_loss)
all_predictions.append(rep_pred)

# Concat all predictions
prediction_keys = all_predictions[0].keys()
predictions = {
k: tf.concat([_[k] for _ in all_predictions], 0)
for k in prediction_keys
}

# Take the average loss
loss = tf.reduce_mean(all_losses)

return predictions, loss


def __call__(self, features, labels, params):
"""Creates the model graph. See the model_fn documentation in
tf.contrib.learn.Estimator class for a more detailed explanation.
"""
with tf.variable_scope("model"):
with tf.variable_scope(self.name):
return self._build(features, labels, params)
predictions, loss = self._build_parallel(features, labels, params)
if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
train_op = self._build_train_op(loss)
else:
train_op = None
return predictions, loss, train_op

def _build(self, features, labels, params):
"""Subclasses should implement this method. See the `model_fn` documentation
in tf.contrib.learn.Estimator class for a more detailed explanation.
in tf.contrib.learn.Estimator class for a more detailed explanation. This
function should return a tuple of (predictions, loss).
"""
raise NotImplementedError
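To make the replication scheme in `_build_parallel` easy to follow in isolation, here is a self-contained sketch of the same pattern: one `tf.make_template` model so all replicas share variables, the batch split along axis 0, predictions concatenated, losses averaged. The toy dense-layer model and the CPU stand-in devices are assumptions for illustration only:

import tensorflow as tf

def toy_model(features, labels):
    # A single dense layer standing in for the real seq2seq model.
    logits = tf.layers.dense(features, 1)
    loss = tf.losses.mean_squared_error(labels, logits)
    return {"logits": logits}, loss

# One template so every replica reuses the same variables.
template = tf.make_template("model", toy_model, create_scope_now_=True)

features = tf.random_normal([8, 4])
labels = tf.random_normal([8, 1])
devices = ["/cpu:0", "/cpu:0"]  # stand-ins for /gpu:0, /gpu:1

features_split = tf.split(features, len(devices))
labels_split = tf.split(labels, len(devices))

all_losses, all_predictions = [], []
for device, f, l in zip(devices, features_split, labels_split):
    with tf.device(device):
        preds, loss = template(f, l)
        all_predictions.append(preds)
        all_losses.append(loss)

# Merge: concat predictions along the batch axis, average the scalar losses.
predictions = {
    k: tf.concat([p[k] for p in all_predictions], 0) for k in all_predictions[0]
}
loss = tf.reduce_mean(all_losses)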
19 changes: 10 additions & 9 deletions seq2seq/models/seq2seq_model.py
@@ -247,6 +247,11 @@ def _preprocess(self, features, labels):
"target.max_seq_len"]]
labels["target_len"] = tf.minimum(labels["target_len"],
self.params["target.max_seq_len"])
# Slice up to the longest example in this batch.
# This is required for multi-GPU training, where each batch is split
# across replicas.
labels["target_tokens"] = labels["target_tokens"][:, :tf.reduce_max(
labels["target_len"])]

# Look up the target ids in the vocabulary
labels["target_ids"] = target_vocab_to_id.lookup(labels["target_tokens"])
@@ -255,9 +260,10 @@
tf.summary.histogram("target_len", tf.to_float(labels["target_len"]))

# Keep track of the number of processed tokens
num_tokens = tf.reduce_sum(labels["target_len"])
num_tokens += tf.reduce_sum(features["source_len"])
token_counter_var = tf.Variable(0, "tokens_counter")
num_tokens = tf.to_int64(tf.reduce_sum(labels["target_len"]))
num_tokens += tf.to_int64(tf.reduce_sum(features["source_len"]))
token_counter_var = tf.get_variable("tokens_counter", [],
dtype=tf.int64, initializer=tf.constant_initializer(0))
total_tokens = tf.assign_add(token_counter_var, num_tokens)
tf.summary.scalar("num_tokens", total_tokens)

@@ -301,14 +307,9 @@ def _build(self, features, labels, params):
predictions = self._create_predictions(
decoder_output=decoder_output, features=features, labels=labels)
loss = None
train_op = None
else:
losses, loss = self.compute_loss(decoder_output, features, labels)

train_op = None
if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
train_op = self._build_train_op(loss)

predictions = self._create_predictions(
decoder_output=decoder_output,
features=features,
@@ -319,4 +320,4 @@ def _build(self, features, labels, params):
# can easily find them in our hooks/monitors.
graph_utils.add_dict_to_collection(predictions, "predictions")

return predictions, loss, train_op
return predictions, loss
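A small self-contained sketch of the trimming step added to `_preprocess` above: padded targets are sliced to the longest sequence actually present in the batch before the batch is split across replicas (toy values, illustrative only):

import tensorflow as tf

target_tokens = tf.constant([[1, 2, 3, 0, 0],
                             [4, 5, 0, 0, 0]])           # padded to length 5
target_len = tf.constant([3, 2])
# Slice up to the longest example in this batch, as in _preprocess.
trimmed = target_tokens[:, :tf.reduce_max(target_len)]   # shape [2, 3]
per_replica = tf.split(trimmed, 2)                        # one [1, 3] piece per replica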
3 changes: 2 additions & 1 deletion seq2seq/test/example_config_test.py
@@ -35,6 +35,8 @@
EXAMPLE_CONFIG_DIR = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../../example_configs"))

# Do not test multi-device support here - it takes too long
delattr(EncoderDecoderTests, "test_train_multi_device")

def _load_model_from_config(config_path, hparam_overrides, vocab_file, mode):
"""Loads model from a configuration file"""
@@ -50,7 +52,6 @@ def _load_model_from_config(config_path, hparam_overrides, vocab_file, mode):
model_params["vocab_target"] = vocab_file
return model_cls(params=model_params, mode=mode)


class ExampleConfigTest(object):
"""Interface for configuration-based tests"""

25 changes: 25 additions & 0 deletions seq2seq/test/models_test.py
@@ -130,6 +130,31 @@ def _test_pipeline(self, mode, params=None):

return model, fetches_

def test_train_multi_device(self):
parallelism = 4
self.batch_size = self.batch_size * parallelism
# Return 4 CPU devices to make sure model splitting works without real GPUs
training_utils.get_available_gpus = lambda: ["cpu:0"] * parallelism
model, fetches_ = self._test_pipeline(
tf.contrib.learn.ModeKeys.TRAIN,
params={"training.data_parallelism": parallelism})

predictions_, loss_, _ = fetches_

target_len = self.sequence_length + 10 + 2
max_decode_length = model.params["target.max_seq_len"]
expected_decode_len = np.minimum(target_len, max_decode_length)

np.testing.assert_array_equal(predictions_["logits"].shape, [
self.batch_size, expected_decode_len - 1,
model.target_vocab_info.total_size
])
np.testing.assert_array_equal(predictions_["losses"].shape,
[self.batch_size, expected_decode_len - 1])
np.testing.assert_array_equal(predictions_["predicted_ids"].shape,
[self.batch_size, expected_decode_len - 1])
self.assertFalse(np.isnan(loss_))

def test_train(self):
model, fetches_ = self._test_pipeline(tf.contrib.learn.ModeKeys.TRAIN)
predictions_, loss_, _ = fetches_
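One detail worth calling out in `test_train_multi_device`: the batch size is multiplied by `parallelism` because `tf.split` along the batch axis requires the batch to divide evenly into the requested number of pieces. A quick sketch of that constraint (shapes are illustrative):

import tensorflow as tf

batch = tf.zeros([8, 5])
pieces = tf.split(batch, 4)       # OK: four tensors of shape [2, 5]
# tf.split(tf.zeros([7, 5]), 4)   # would fail: 7 is not evenly divisible by 4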
2 changes: 1 addition & 1 deletion seq2seq/test/utils.py
@@ -20,8 +20,8 @@
from __future__ import unicode_literals

import tempfile
import tensorflow as tf

import tensorflow as tf

def create_temp_parallel_data(sources, targets):
"""
8 changes: 8 additions & 0 deletions seq2seq/training/utils.py
@@ -30,10 +30,18 @@

import tensorflow as tf
from tensorflow import gfile
from tensorflow.python.client import device_lib # pylint: disable=E0611

from seq2seq.contrib import rnn_cell


def get_available_gpus():
"""Returns a list of available GPU device names.
"""
local_device_protos = device_lib.list_local_devices()
return [x.name for x in local_device_protos if x.device_type == "GPU"]


class TrainOptions(object):
"""A collection of options that are passed to the training script
and can be saved to perform inference later.
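For reference, `get_available_gpus` is a thin wrapper around `device_lib.list_local_devices()`, which reports CPUs as well as GPUs; a quick way to inspect what it sees on a machine (illustrative; the exact device-name format varies across TensorFlow versions):

from tensorflow.python.client import device_lib

for device in device_lib.list_local_devices():
    print(device.name, device.device_type)  # one line per local device, e.g. the CPU and each GPU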