Skip to content

Commit

Permalink
Pass session_config to device_lib.list_local_devices when possible (#293)

Browse files Browse the repository at this point in the history
  • Loading branch information
guillaumekln authored Dec 21, 2018
1 parent 2e8a4a4 commit ff6b7aa
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 20 deletions.
7 changes: 5 additions & 2 deletions opennmt/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,21 +59,24 @@ def __call__(self, features, labels, params, mode, config=None):
"""
return self._build(features, labels, params, mode, config=config)

def model_fn(self, num_devices=1, eval_prediction_hooks_fn=None):
def model_fn(self, num_devices=1, eval_prediction_hooks_fn=None, devices=None):
"""Returns the model function.
Args:
num_devices: The number of devices used for training.
eval_prediction_hooks_fn: A callable that takes the model predictions
during evaluation and return an iterable of evaluation hooks (e.g. for
saving predictions on disk, running external evaluators, etc.).
devices: The list of devices used for training, if known.
See Also:
``tf.estimator.Estimator`` 's ``model_fn`` argument for more details about
arguments and the returned value.
"""
dispatcher = GraphDispatcher(
num_devices, daisy_chain_variables=self.daisy_chain_variables)
num_devices=num_devices,
daisy_chain_variables=self.daisy_chain_variables,
devices=devices)

def _loss_op(features, labels, params, mode, config):
"""Single callable to compute the loss."""
Expand Down
10 changes: 4 additions & 6 deletions opennmt/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from opennmt.utils import hooks, checkpoint, misc
from opennmt.utils.evaluator import external_evaluation_fn
from opennmt.utils.misc import format_translation_output, OrderRestorer
from opennmt.utils.parallel import get_devices


# These options require a value but we can fallback to a default one.
Expand Down Expand Up @@ -114,10 +115,6 @@ def __init__(self,
session_config=session_config,
tf_random_seed=seed)

# Create a first session to enforce GPU options.
# See https://github.com/OpenNMT/OpenNMT-tf/issues/80.
_ = tf.Session(config=session_config)

np.random.seed(seed)
random.seed(seed)

Expand All @@ -136,10 +133,11 @@ def __init__(self,
run_config = run_config.replace(
keep_checkpoint_max=self._config["train"]["keep_checkpoint_max"])

devices = get_devices(num_devices=num_devices, session_config=session_config)
self._estimator = tf.estimator.Estimator(
self._model.model_fn(
num_devices=self._num_devices,
eval_prediction_hooks_fn=self._make_eval_prediction_hooks_fn()),
eval_prediction_hooks_fn=self._make_eval_prediction_hooks_fn(),
devices=devices),
config=run_config,
params=self._config["params"])

Expand Down
60 changes: 48 additions & 12 deletions opennmt/utils/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,41 @@
import tensorflow as tf

from tensorflow.python.client import device_lib
from tensorflow.python.estimator.util import fn_args


class GraphDispatcher(object):
"""Helper class to replicate graph parts on multiple devices and dispatch
sharded batches.
"""

def __init__(self, num_devices, daisy_chain_variables=True):
def __init__(self,
num_devices=None,
daisy_chain_variables=True,
devices=None,
session_config=None):
"""Initializes the dispatcher.
Args:
num_devices: The number of devices to dispatch on.
daisy_chain_variables: If ``True``, variables are copied in a daisy chain
fashion between devices (credits to Tensor2Tensor).
devices: List of devices to use (takes priority over :obj:`num_devices`).
session_config: Session configuration to use when querying available
devices.
Raises:
ValueError: if the number of visible devices is lower than
:obj:`num_devices`.
"""
devices = [x.name for x in device_lib.list_local_devices() if x.device_type == "GPU"]
self._daisy_chain_variables = daisy_chain_variables

if not devices:
self._n = 1
self._devices = [None]
elif len(devices) < num_devices:
raise ValueError("Only %d devices are visible but %d were requested"
% (len(devices), num_devices))
if devices:
self._devices = devices
elif num_devices is not None:
self._devices = get_devices(num_devices=num_devices, session_config=session_config)
else:
self._n = num_devices
self._devices = devices[:self._n]
self._devices = [None]
self._n = len(self._devices)
self._daisy_chain_variables = daisy_chain_variables

def shard(self, data):
"""Shards a structure of ``tf.Tensor`` for dispatching.
Expand Down Expand Up @@ -178,3 +182,35 @@ def _split_dictionary(dictionary):
data_shards = tf.split(data, num_shards)

return data_shards

def get_devices(num_devices=None, session_config=None):
  """Returns available devices.

  Args:
    num_devices: The number of devices to get.
    session_config: An optional session configuration to use when querying
      available devices.

  Returns:
    A list of devices.

  Raises:
    ValueError: if :obj:`num_devices` is set but the number of visible devices
      is lower than it.
  """
  list_kwargs = {}
  if "session_config" in fn_args(device_lib.list_local_devices):
    list_kwargs["session_config"] = session_config
  else:
    # Create a first session to enforce config, otherwise list_local_devices()
    # will run some initialization with default options.
    _ = tf.Session(config=session_config)
  local_devices = device_lib.list_local_devices(**list_kwargs)
  gpu_devices = [
      device.name for device in local_devices if device.device_type == "GPU"]
  if not gpu_devices:
    return [None]
  if num_devices is None:
    return gpu_devices
  if len(gpu_devices) < num_devices:
    raise ValueError("Only %d devices are visible but %d were requested"
                     % (len(gpu_devices), num_devices))
  return gpu_devices[:num_devices]

0 comments on commit ff6b7aa

Please sign in to comment.