Skip to content

Commit

Permalink
Add an option to RunConfig and train_and_evaluate to run distribute c…
Browse files Browse the repository at this point in the history
…oordinator.

This is necessary to run multi-worker MirroredStrategy and CollectiveAllReduceStrategy with estimator.

PiperOrigin-RevId: 210192378
  • Loading branch information
Yuefeng Zhou authored and tensorflower-gardener committed Aug 25, 2018
1 parent 9599b47 commit ca94990
Show file tree
Hide file tree
Showing 11 changed files with 1,100 additions and 26 deletions.
1 change: 1 addition & 0 deletions tensorflow/contrib/distribute/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,6 @@ py_library(
"//tensorflow/contrib/distribute/python:tpu_strategy",
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python/distribute:distribute_config",
],
)
2 changes: 2 additions & 0 deletions tensorflow/contrib/distribute/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from tensorflow.contrib.distribute.python.parameter_server_strategy import ParameterServerStrategy
from tensorflow.contrib.distribute.python.step_fn import *
from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy
from tensorflow.python.distribute.distribute_config import DistributeConfig
from tensorflow.python.training.distribute import *
from tensorflow.python.training.distribution_strategy_context import *

Expand All @@ -37,6 +38,7 @@
'AllReduceCrossTowerOps',
'CollectiveAllReduceStrategy',
'CrossTowerOps',
'DistributeConfig',
'DistributionStrategy',
'MirroredStrategy',
'Monitor',
Expand Down
26 changes: 26 additions & 0 deletions tensorflow/contrib/distribute/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,32 @@ cuda_py_test(
],
)

cuda_py_test(
name = "estimator_training_test",
size = "large",
srcs = ["estimator_training_test.py"],
additional_deps = [
":combinations",
":mirrored_strategy",
":multi_worker_test_base",
":parameter_server_strategy",
"//third_party/py/numpy",
"//tensorflow/contrib/optimizer_v2:training",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/distribute",
"//tensorflow/python/eager:test",
"//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/feature_column",
"//tensorflow/python:framework_ops",
"//tensorflow/python:platform",
"//tensorflow/python:summary",
],
tags = [
"multi_and_single_gpu",
"no_pip",
],
)

py_library(
name = "single_loss_example",
srcs = ["single_loss_example.py"],
Expand Down
Loading

0 comments on commit ca94990

Please sign in to comment.