Commit
Add environment and rpc_layer to the TF_CONFIG environment variable in distribute coordinator.

PiperOrigin-RevId: 210197404
Yuefeng Zhou authored and tensorflower-gardener committed Aug 25, 2018
1 parent ca94990 commit 04ffe2f
Showing 2 changed files with 117 additions and 12 deletions.
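
For orientation, here is a minimal sketch (not part of the commit) of a TF_CONFIG value carrying the two keys this change teaches the coordinator to read; the cluster addresses and the "grpc" value are illustrative assumptions, while "environment": "google" is the value the new code path checks for.

import json
import os

# Hypothetical TF_CONFIG for one worker task. Besides the usual "cluster" and
# "task" entries, it carries the two keys run_distribute_coordinator now reads:
# "rpc_layer" (protocol passed to the standard server) and "environment"
# (switches _run_std_server to the in-process fake server when "google").
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "worker": ["10.0.0.1:2222", "10.0.0.2:2222"],
        "ps": ["10.0.0.3:2222"],
    },
    "task": {"type": "worker", "index": 0},
    "rpc_layer": "grpc",
    "environment": "google",
})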
65 changes: 54 additions & 11 deletions tensorflow/python/distribute/distribute_coordinator.py
@@ -22,9 +22,12 @@
import json
import os
import threading
+import time

from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator_context
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import monitored_session
from tensorflow.python.training import server_lib

@@ -332,16 +335,38 @@ def _run_std_server(cluster_spec=None,
                    task_type=None,
                    task_id=None,
                    session_config=None,
-                    rpc_layer=None):
+                    rpc_layer=None,
+                    environment=None):
  """Runs a standard server."""
-  server = server_lib.Server(
-      cluster_spec,
-      job_name=task_type,
-      task_index=task_id,
-      config=session_config,
-      protocol=rpc_layer)
-  server.start()
-  return server
+
+  class _FakeServer(object):
+    """A fake server that runs a master session."""
+
+    def start(self):
+      assert cluster_spec
+      target = cluster_spec.task_address(task_type, task_id)
+      if rpc_layer:
+        target = rpc_layer + "://" + target
+      # A tensorflow server starts when a remote session is created.
+      session.Session(target=target, config=session_config)
+
+    def join(self):
+      while True:
+        time.sleep(5)
+
+  if environment == "google":
+    server = _FakeServer()
+    server.start()
+    return server
+  else:
+    server = server_lib.Server(
+        cluster_spec,
+        job_name=task_type,
+        task_index=task_id,
+        config=session_config,
+        protocol=rpc_layer)
+    server.start()
+    return server


def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
@@ -541,8 +566,18 @@ def run_distribute_coordinator(worker_fn,
"`tf.train.ClusterDef` object")
    # TODO(yuefengz): validate cluster_spec.

+  rpc_layer = tf_config.get("rpc_layer", rpc_layer)
+  environment = tf_config.get("environment", None)
+
+  if cluster_spec:
+    logging.info(
+        "Running Distribute Coordinator with mode = %r, cluster_spec = %r, "
+        "task_type = %r, task_id = %r, environment = %r, rpc_layer = %r", mode,
+        cluster_spec.as_dict(), task_type, task_id, environment, rpc_layer)
+
  if not cluster_spec:
    # `mode` is ignored in the local case.
+    logging.info("Running local Distribute Coordinator.")
    _run_single_worker(worker_fn, strategy, None, None, None, session_config,
                       rpc_layer)
    if eval_fn:
@@ -564,7 +599,11 @@
    else:
      # If not a client job, run the standard server.
      server = _run_std_server(
-          cluster_spec=cluster_spec, task_type=task_type, task_id=task_id)
+          cluster_spec=cluster_spec,
+          task_type=task_type,
+          task_id=task_id,
+          rpc_layer=rpc_layer,
+          environment=environment)
      server.join()
  else:
    if mode != CoordinatorMode.INDEPENDENT_WORKER:
@@ -575,7 +614,11 @@

    # Every one starts a standard server.
    server = _run_std_server(
-        cluster_spec=cluster_spec, task_type=task_type, task_id=task_id)
+        cluster_spec=cluster_spec,
+        task_type=task_type,
+        task_id=task_id,
+        rpc_layer=rpc_layer,
+        environment=environment)

    if task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
      if strategy.between_graph:
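As an aside, a small sketch (not from this commit) of the session target the fake server ends up dialing when rpc_layer is set, assuming a hypothetical single-ps cluster at localhost:2222:

from tensorflow.python.training import server_lib

cluster_spec = server_lib.ClusterSpec({"ps": ["localhost:2222"]})
target = cluster_spec.task_address("ps", 0)  # -> "localhost:2222"
rpc_layer = "grpc"
if rpc_layer:
  target = rpc_layer + "://" + target        # -> "grpc://localhost:2222"
# Creating a session against this target is what makes the fake server's
# start() take effect, per the comment in the diff above.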
64 changes: 63 additions & 1 deletion tensorflow/python/distribute/distribute_coordinator_test.py
@@ -20,8 +20,10 @@

import contextlib
import copy
+import json
import os
import sys
+import time
import threading
import six

@@ -59,6 +61,8 @@
NUM_WORKERS = 3
NUM_PS = 2

+original_sys_exit = sys.exit
+

def _bytes_to_str(maybe_bytes):
if isinstance(maybe_bytes, six.string_types):
@@ -369,7 +373,8 @@ def _run_mock_std_server(self,
                           cluster_spec=None,
                           task_type=None,
                           task_id=None,
-                           rpc_layer=None):
+                           rpc_layer=None,
+                           environment=None):
    task_type = str(task_type)
    task_id = task_id or 0
    with self._lock:
@@ -730,6 +735,63 @@ def testInGraphContextWithEval(self):
    self.assertTrue(self._std_servers[WORKER][2].joined)
    self.assertFalse(self._std_servers[EVALUATOR][0].joined)

+  def testRunStdServerInGoogleEnvironment(self):
+    cluster_spec = {"worker": ["fake_worker"], "ps": ["localhost:0"]}
+    tf_config = {"cluster": cluster_spec, "environment": "google"}
+
+    joined = [False]
+
+    def _fake_sleep(_):
+      joined[0] = True
+      original_sys_exit(0)
+
+    def _thread_fn(cluster_spec):
+      distribute_coordinator.run_distribute_coordinator(
+          None,
+          None,
+          mode=INDEPENDENT_WORKER,
+          cluster_spec=cluster_spec,
+          task_type="ps",
+          task_id=0)
+
+    with test.mock.patch.dict(
+        "os.environ",
+        {"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
+            time, "sleep", _fake_sleep):
+      t = threading.Thread(target=_thread_fn, args=(cluster_spec,))
+      t.start()
+      t.join()
+    self.assertTrue(joined[0])
+
+  def testRpcLayerEnvironmentVariable(self):
+    cluster_spec = {"worker": ["fake_worker"], "ps": ["fake_ps"]}
+    tf_config = {"cluster": cluster_spec, "rpc_layer": "cake"}
+
+    rpc_layer_from_coordinator = [None]
+
+    def _run_mock_server(cluster_spec=None,
+                         task_type=None,
+                         task_id=None,
+                         session_config=None,
+                         rpc_layer=None,
+                         environment=None):
+      del cluster_spec, task_type, task_id, session_config, environment
+      rpc_layer_from_coordinator[0] = rpc_layer
+      return MockServer()
+
+    with test.mock.patch.dict(
+        "os.environ",
+        {"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
+            distribute_coordinator, "_run_std_server", _run_mock_server):
+      distribute_coordinator.run_distribute_coordinator(
+          None,
+          None,
+          mode=INDEPENDENT_WORKER,
+          cluster_spec=cluster_spec,
+          task_type="ps",
+          task_id=0)
+    self.assertEqual(rpc_layer_from_coordinator[0], "cake")
+

if __name__ == "__main__":
  # TODO(yuefengz): find a smart way to terminite std server threads.
