Add distributed training example using the TF 2.1 Strategy API (#1164)
* Add a distributed training example using the TF 2.1 Strategy API:
  * Add a multi-worker strategy example to the keras-API directory.
  * Move the existing strategy example to the estimator-API directory, to distinguish the two strategy examples.
* Remove the duplicated example (it was moved to the estimator-API directory).
Showing 10 changed files with 234 additions and 0 deletions.
5 files renamed without changes (the existing strategy example, moved into the estimator-API directory).
examples/v1/distribution_strategy/keras-API/Dockerfile (6 additions)

```dockerfile
FROM tensorflow/tensorflow:2.1.0-gpu-py3

RUN pip install tensorflow_datasets

COPY multi_worker_strategy-with-keras.py /
ENTRYPOINT ["python", "/multi_worker_strategy-with-keras.py", "--saved_model_dir", "/train/saved_model/", "--checkpoint_dir", "/train/checkpoint"]
```
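
The README builds this image as `kubeflow/multi_worker_strategy:v1.0`; for a real cluster the image also has to be pushed to a registry the nodes can pull from. A sketch, with `<registry>` as a hypothetical placeholder:

```
docker build -f Dockerfile -t <registry>/multi_worker_strategy:v1.0 .
docker push <registry>/multi_worker_strategy:v1.0
```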
examples/v1/distribution_strategy/keras-API/README.md (29 additions)
# Multi-worker training with Keras

This directory contains an example of running multi-worker distributed training
with the TensorFlow 2.1 Keras API on Kubeflow. For more information about the
source code, see the TensorFlow tutorials [here](https://www.tensorflow.org/tutorials/distribute/keras) and [here](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras).

## Prerequisite

Your cluster must be configured to use multiple GPUs;
please follow the [instructions](https://www.kubeflow.org/docs/components/training/tftraining/#using-gpus).

## Steps

1. Build an image:
   ```
   docker build -f Dockerfile -t kubeflow/multi_worker_strategy:v1.0 .
   ```
2. Specify your storageClassName and create a persistent volume claim to save
   models and checkpoints:
   ```
   kubectl -n ${NAMESPACE} create -f pvc.yaml
   ```
3. Create a TFJob. If you use GPUs from a vendor other than NVIDIA, replace
   `nvidia.com/gpu` with your GPU vendor in the `limits` section:
   ```
   kubectl -n ${NAMESPACE} create -f multi_worker_tfjob.yaml
   ```
   A quick way to check on the job after creation is shown below.
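
Once the TFJob is created, you can inspect it with standard kubectl commands; a minimal check, using the job name `multi-worker` from `multi_worker_tfjob.yaml`:

```
kubectl -n ${NAMESPACE} get tfjob multi-worker
kubectl -n ${NAMESPACE} describe tfjob multi-worker
```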
examples/v1/distribution_strategy/keras-API/multi_worker_strategy-with-keras.py (161 additions)

```python
# Copyright 2018 The Kubeflow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An example of multi-worker training with a Keras model using the Strategy API."""

from __future__ import absolute_import, division, print_function

import argparse
import json
import os

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import layers, models


def make_datasets_unbatched():
  BUFFER_SIZE = 10000

  # Scale MNIST pixel values from [0, 255] to [0., 1.].
  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255
    return image, label

  datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)

  return datasets['train'].map(scale).cache().shuffle(BUFFER_SIZE)


def build_and_compile_cnn_model():
  model = models.Sequential()
  model.add(
      layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
  model.add(layers.MaxPooling2D((2, 2)))
  model.add(layers.Conv2D(64, (3, 3), activation='relu'))
  model.add(layers.MaxPooling2D((2, 2)))
  model.add(layers.Conv2D(64, (3, 3), activation='relu'))
  model.add(layers.Flatten())
  model.add(layers.Dense(64, activation='relu'))
  model.add(layers.Dense(10, activation='softmax'))

  model.summary()

  model.compile(optimizer='adam',
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'])

  return model


# Function for decaying the learning rate; you can define any decay
# function you need.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch < 7:
    return 1e-4
  else:
    return 1e-5


def main(args):
  # MultiWorkerMirroredStrategy creates copies of all variables in the model's
  # layers on each device across all workers. If your GPUs don't support NCCL,
  # replace "communication" with another option
  # (e.g. CollectiveCommunication.AUTO).
  strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
      communication=tf.distribute.experimental.CollectiveCommunication.NCCL)

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

  with strategy.scope():
    ds_train = make_datasets_unbatched().batch(BATCH_SIZE).repeat()
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = \
        tf.data.experimental.AutoShardPolicy.DATA
    ds_train = ds_train.with_options(options)
    # Model building/compiling needs to be within `strategy.scope()`.
    multi_worker_model = build_and_compile_cnn_model()

  # Define the checkpoint directory to store the checkpoints.
  checkpoint_dir = args.checkpoint_dir

  # Name of the checkpoint files.
  checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

  # Callback for printing the learning rate at the end of each epoch.
  class PrintLR(tf.keras.callbacks.Callback):

    def on_epoch_end(self, epoch, logs=None):
      print('\nLearning rate for epoch {} is {}'.format(
          epoch + 1, multi_worker_model.optimizer.lr.numpy()))

  callbacks = [
      tf.keras.callbacks.TensorBoard(log_dir='./logs'),
      tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                         save_weights_only=True),
      tf.keras.callbacks.LearningRateScheduler(decay),
      PrintLR()
  ]

  # Keras' `model.fit()` trains the model for the specified number of epochs
  # and steps per epoch. The numbers here are for demonstration purposes only
  # and may not produce a model of good quality.
  multi_worker_model.fit(ds_train,
                         epochs=10,
                         steps_per_epoch=70,
                         callbacks=callbacks)

  # Saving the model: `is_chief` checks the task index from the cluster spec
  # and returns True if this worker is the chief, False otherwise.
  def is_chief():
    return TASK_INDEX == 0

  if is_chief():
    model_path = args.saved_model_dir
  else:
    # Save to a path that is unique across workers.
    model_path = args.saved_model_dir + '/worker_tmp_' + str(TASK_INDEX)

  multi_worker_model.save(model_path)


if __name__ == '__main__':
  os.environ['NCCL_DEBUG'] = 'INFO'

  tfds.disable_progress_bar()

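  # For illustration only (the worker address below is made up): the TF_CONFIG
  # injected into each pod by the TFJob operator looks roughly like
  #   {"cluster": {"worker": ["multi-worker-worker-0:2222", ...]},
  #    "task": {"type": "worker", "index": 0}}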
  # To decide whether this worker is the chief, read the task index from the
  # cluster info in TF_CONFIG.
  tf_config = json.loads(os.environ.get('TF_CONFIG') or '{}')
  TASK_INDEX = tf_config['task']['index']

  parser = argparse.ArgumentParser()
  parser.add_argument('--saved_model_dir',
                      type=str,
                      required=True,
                      help='TensorFlow export directory.')

  parser.add_argument('--checkpoint_dir',
                      type=str,
                      required=True,
                      help='TensorFlow checkpoint directory.')

  args = parser.parse_args()
  main(args)
```
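
For a quick local sanity check of the script outside Kubernetes, one can hand-craft a single-worker TF_CONFIG; a sketch, assuming a machine with an NVIDIA GPU (the script pins NCCL collectives) and the packages from the Dockerfile installed:

```
export TF_CONFIG='{"cluster": {"worker": ["localhost:2222"]}, "task": {"type": "worker", "index": 0}}'
python multi_worker_strategy-with-keras.py --saved_model_dir /tmp/saved_model --checkpoint_dir /tmp/checkpoint
```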
examples/v1/distribution_strategy/keras-API/multi_worker_tfjob.yaml (25 additions)

Note that the manifest references a prebuilt image; if you built your own image in step 1 of the README, point `image:` at your tag instead.

```yaml
apiVersion: kubeflow.org/v1
kind: TFJob
metadata:
  name: multi-worker
spec:
  cleanPodPolicy: None
  tfReplicaSpecs:
    Worker:
      replicas: 2
      restartPolicy: Never
      template:
        spec:
          containers:
            - name: tensorflow
              image: kubeflowimages/multi_worker_strategy:v20200522-2a5b081c
              volumeMounts:
                - mountPath: /train
                  name: training
              resources:
                limits:
                  nvidia.com/gpu: 1
          volumes:
            - name: training
              persistentVolumeClaim:
                claimName: strategy-volume
```
examples/v1/distribution_strategy/keras-API/pvc.yaml (13 additions)

```yaml
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: strategy-volume
  labels:
    app: strategy-volume
spec:
  storageClassName: "Your storageClassName"
  accessModes:
    - ReadWriteMany
  resources:
    requests:
      storage: 10Gi
```
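
Before starting the job it is worth confirming that the claim binds; a minimal check:

```
kubectl -n ${NAMESPACE} get pvc strategy-volume
```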