diff --git a/python/orca/dev/example/run-example-tests.sh b/python/orca/dev/example/run-example-tests.sh
index fb813e8e68e..02a26d302ae 100644
--- a/python/orca/dev/example/run-example-tests.sh
+++ b/python/orca/dev/example/run-example-tests.sh
@@ -237,17 +237,17 @@ else
   export PYTHONPATH=`pwd`/analytics-zoo-tensorflow-models/slim:$PYTHONPATH
 fi
 
-echo "start example test for tensorflow distributed_training train_lenet 1"
+echo "start example test for TFPark tf_optimizer train_lenet 1"
 ${SPARK_HOME}/bin/spark-submit \
     --master ${MASTER} \
     --driver-memory 200g \
     --executor-memory 200g \
     --properties-file ${ANALYTICS_ZOO_CONF} \
-    --py-files ${ANALYTICS_ZOO_PYZIP},${ANALYTICS_ZOO_ROOT}/pyzoo/zoo/examples/tensorflow/distributed_training/train_lenet.py \
+    --py-files ${ANALYTICS_ZOO_PYZIP},${ANALYTICS_ZOO_ROOT}/pyzoo/zoo/examples/tensorflow/tfpark/tf_optimizer/train_lenet.py \
     --jars ${ANALYTICS_ZOO_JAR} \
     --conf spark.driver.extraClassPath=${ANALYTICS_ZOO_JAR} \
     --conf spark.executor.extraClassPath=${ANALYTICS_ZOO_JAR} \
-    ${ANALYTICS_ZOO_ROOT}/pyzoo/zoo/examples/tensorflow/distributed_training/train_lenet.py 1 1000\
+    ${ANALYTICS_ZOO_ROOT}/pyzoo/zoo/examples/tensorflow/tfpark/tf_optimizer/train_lenet.py 1 1000\
 
 sed "s%/tmp%analytics-zoo-tensorflow-models%g;s%models/slim%slim%g"
 if [ -d analytics-zoo-tensorflow-models/slim ]
@@ -269,11 +269,11 @@ ${SPARK_HOME}/bin/spark-submit \
     --driver-memory 200g \
     --executor-memory 200g \
     --properties-file ${ANALYTICS_ZOO_CONF} \
-    --py-files ${ANALYTICS_ZOO_PYZIP},${ANALYTICS_ZOO_ROOT}/pyzoo/zoo/examples/tensorflow/distributed_training/evaluate_lenet.py \
+    --py-files ${ANALYTICS_ZOO_PYZIP},${ANALYTICS_ZOO_ROOT}/pyzoo/zoo/examples/tensorflow/tfpark/tf_optimizer/evaluate_lenet.py \
     --jars ${ANALYTICS_ZOO_JAR} \
     --conf spark.driver.extraClassPath=${ANALYTICS_ZOO_JAR} \
     --conf spark.executor.extraClassPath=${ANALYTICS_ZOO_JAR} \
-    ${ANALYTICS_ZOO_ROOT}/pyzoo/zoo/examples/tensorflow/distributed_training/evaluate_lenet.py 1000\
+    ${ANALYTICS_ZOO_ROOT}/pyzoo/zoo/examples/tensorflow/tfpark/tf_optimizer/evaluate_lenet.py 1000\
 
-echo "start example test for tensorflow distributed_training train_mnist_keras 3"
+echo "start example test for TFPark tf_optimizer train_mnist_keras 3"
 ${SPARK_HOME}/bin/spark-submit \
@@ -281,11 +281,11 @@ ${SPARK_HOME}/bin/spark-submit \
     --driver-memory 200g \
     --executor-memory 200g \
     --properties-file ${ANALYTICS_ZOO_CONF} \
-    --py-files ${ANALYTICS_ZOO_PYZIP},${ANALYTICS_ZOO_ROOT}/pyzoo/zoo/examples/tensorflow/distributed_training/train_mnist_keras.py \
+    --py-files ${ANALYTICS_ZOO_PYZIP},${ANALYTICS_ZOO_ROOT}/pyzoo/zoo/examples/tensorflow/tfpark/tf_optimizer/train_mnist_keras.py \
     --jars ${ANALYTICS_ZOO_JAR} \
     --conf spark.driver.extraClassPath=${ANALYTICS_ZOO_JAR} \
     --conf spark.executor.extraClassPath=${ANALYTICS_ZOO_JAR} \
-    ${ANALYTICS_ZOO_ROOT}/pyzoo/zoo/examples/tensorflow/distributed_training/train_mnist_keras.py 1 1000\
+    ${ANALYTICS_ZOO_ROOT}/pyzoo/zoo/examples/tensorflow/tfpark/tf_optimizer/train_mnist_keras.py 1 1000\
 
-echo "start example test for tensorflow distributed_training evaluate_lenet 4"
+echo "start example test for TFPark tf_optimizer evaluate_mnist_keras 4"
 ${SPARK_HOME}/bin/spark-submit \
@@ -293,11 +293,11 @@ ${SPARK_HOME}/bin/spark-submit \
     --driver-memory 200g \
    --executor-memory 200g \
     --properties-file ${ANALYTICS_ZOO_CONF} \
-    --py-files ${ANALYTICS_ZOO_PYZIP},${ANALYTICS_ZOO_ROOT}/pyzoo/zoo/examples/tensorflow/distributed_training/evaluate_mnist_keras.py \
+    --py-files ${ANALYTICS_ZOO_PYZIP},${ANALYTICS_ZOO_ROOT}/pyzoo/zoo/examples/tensorflow/tfpark/tf_optimizer/evaluate_mnist_keras.py \
     --jars ${ANALYTICS_ZOO_JAR} \
     --conf spark.driver.extraClassPath=${ANALYTICS_ZOO_JAR} \
     --conf spark.executor.extraClassPath=${ANALYTICS_ZOO_JAR} \
-    ${ANALYTICS_ZOO_ROOT}/pyzoo/zoo/examples/tensorflow/distributed_training/evaluate_mnist_keras.py 1000\
+    ${ANALYTICS_ZOO_ROOT}/pyzoo/zoo/examples/tensorflow/tfpark/tf_optimizer/evaluate_mnist_keras.py 1000\
 
 now=$(date "+%s")
 
diff --git a/python/orca/example/tfpark/README.md b/python/orca/example/tfpark/README.md
index 971624b4594..01ab215777a 100644
--- a/python/orca/example/tfpark/README.md
+++ b/python/orca/example/tfpark/README.md
@@ -41,7 +41,7 @@ Using TFDataset as data input
 export ANALYTICS_ZOO_HOME=... # the directory where you extract the downloaded Analytics Zoo zip package
 export SPARK_HOME=... # the root directory of Spark
 
-sh $ANALYTICS_ZOO_HOME/bin/spark-submit-python-with-zoo.sh --master local[4] keras_dataset.py
+sh $ANALYTICS_ZOO_HOME/bin/spark-submit-python-with-zoo.sh --master local[4] keras/keras_dataset.py
 ```
 
 Using numpy.ndarray as data input
@@ -49,14 +49,14 @@ Using numpy.ndarray as data input
 export ANALYTICS_ZOO_HOME=... # the directory where you extract the downloaded Analytics Zoo zip package
 export SPARK_HOME=... # the root directory of Spark
 
-sh $ANALYTICS_ZOO_HOME/bin/spark-submit-python-with-zoo.sh --master local[4] keras_ndarray.py
+sh $ANALYTICS_ZOO_HOME/bin/spark-submit-python-with-zoo.sh --master local[4] keras/keras_ndarray.py
 ```
 
 ## Run the TFEstimator example after pip install
 
 Using TFDataset as data input
 ```bash
-python estimator_dataset.py
+python estimator/estimator_dataset.py
 ```
 
 Using FeatureSet as data input
@@ -76,7 +76,7 @@ IMAGE_PATH=...
 
 NUM_CLASSES=..
 
-python estimator_inception.py --image-path $IMAGE_PATH --num-classes $NUM_CLASSES
+python estimator/estimator_inception.py --image-path $IMAGE_PATH --num-classes $NUM_CLASSES
 ```
 
 ## Run the TFEstimator example with prebuilt package
@@ -86,7 +86,7 @@ Using TFDataset as data input
 export ANALYTICS_ZOO_HOME=... # the directory where you extract the downloaded Analytics Zoo zip package
 export SPARK_HOME=... # the root directory of Spark
 
-sh $ANALYTICS_ZOO_HOME/bin/spark-submit-python-with-zoo.sh --master local[4] estimator_dataset.py
+sh $ANALYTICS_ZOO_HOME/bin/spark-submit-python-with-zoo.sh --master local[4] estimator/estimator_dataset.py
 ```
 
 Using FeatureSet as data input
@@ -110,5 +110,65 @@ IMAGE_PATH=...
 
 NUM_CLASSES=..
 
-sh $ANALYTICS_ZOO_HOME/bin/spark-submit-python-with-zoo.sh --master local[4] estimator_inception.py --image-path $IMAGE_PATH --num-classes $NUM_CLASSES
+sh $ANALYTICS_ZOO_HOME/bin/spark-submit-python-with-zoo.sh --master local[4] estimator/estimator_inception.py --image-path $IMAGE_PATH --num-classes $NUM_CLASSES
+```
+
+## Run the Training Example using TFOptimizer after pip install
+
+```bash
+python tf_optimizer/train_lenet.py
+```
+
+For Keras users:
+
+```bash
+python tf_optimizer/train_mnist_keras.py
+```
+
+## Run the Training Example using TFOptimizer with prebuilt package
+
+```bash
+export ANALYTICS_ZOO_HOME=... # the directory where you extract the downloaded Analytics Zoo zip package
+export SPARK_HOME=... # the root directory of Spark
+
+sh $ANALYTICS_ZOO_HOME/bin/spark-submit-python-with-zoo.sh --master local[4] tf_optimizer/train_lenet.py
+```
+
+For Keras users:
+
+```bash
+export ANALYTICS_ZOO_HOME=... # the directory where you extract the downloaded Analytics Zoo zip package
+export SPARK_HOME=... # the root directory of Spark
+
+sh $ANALYTICS_ZOO_HOME/bin/spark-submit-python-with-zoo.sh --master local[4] tf_optimizer/train_mnist_keras.py
+```
+
+## Run the Evaluation Example using TFPredictor after pip install
+
+```bash
+python tf_optimizer/evaluate_lenet.py
+```
+
+For Keras users:
+
+```bash
+python tf_optimizer/evaluate_mnist_keras.py
+```
+
+## Run the Evaluation Example using TFPredictor with prebuilt package
+
+```bash
+export ANALYTICS_ZOO_HOME=... # the directory where you extract the downloaded Analytics Zoo zip package
+export SPARK_HOME=... # the root directory of Spark
+
+sh $ANALYTICS_ZOO_HOME/bin/spark-submit-python-with-zoo.sh --master local[4] tf_optimizer/evaluate_lenet.py
+```
+
+For Keras users:
+
+```bash
+export ANALYTICS_ZOO_HOME=... # the directory where you extract the downloaded Analytics Zoo zip package
+export SPARK_HOME=... # the root directory of Spark
+
+sh $ANALYTICS_ZOO_HOME/bin/spark-submit-python-with-zoo.sh --master local[4] tf_optimizer/evaluate_mnist_keras.py
 ```
\ No newline at end of file
diff --git a/python/orca/example/tfpark/estimator/__init__.py b/python/orca/example/tfpark/estimator/__init__.py
new file mode 100644
index 00000000000..5976dc4df02
--- /dev/null
+++ b/python/orca/example/tfpark/estimator/__init__.py
@@ -0,0 +1,15 @@
+#
+# Copyright 2018 Analytics Zoo Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/python/orca/example/tfpark/estimator_dataset.py b/python/orca/example/tfpark/estimator/estimator_dataset.py
similarity index 100%
rename from python/orca/example/tfpark/estimator_dataset.py
rename to python/orca/example/tfpark/estimator/estimator_dataset.py
diff --git a/python/orca/example/tfpark/estimator_inception.py b/python/orca/example/tfpark/estimator/estimator_inception.py
similarity index 100%
rename from python/orca/example/tfpark/estimator_inception.py
rename to python/orca/example/tfpark/estimator/estimator_inception.py
diff --git a/python/orca/example/tfpark/keras/__init__.py b/python/orca/example/tfpark/keras/__init__.py
new file mode 100644
index 00000000000..5976dc4df02
--- /dev/null
+++ b/python/orca/example/tfpark/keras/__init__.py
@@ -0,0 +1,15 @@
+#
+# Copyright 2018 Analytics Zoo Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/python/orca/example/tfpark/keras_dataset.py b/python/orca/example/tfpark/keras/keras_dataset.py
similarity index 100%
rename from python/orca/example/tfpark/keras_dataset.py
rename to python/orca/example/tfpark/keras/keras_dataset.py
diff --git a/python/orca/example/tfpark/keras_ndarray.py b/python/orca/example/tfpark/keras/keras_ndarray.py
similarity index 100%
rename from python/orca/example/tfpark/keras_ndarray.py
rename to python/orca/example/tfpark/keras/keras_ndarray.py
diff --git a/python/orca/example/tfpark/tf_optimizer/__init__.py b/python/orca/example/tfpark/tf_optimizer/__init__.py
new file mode 100644
index 00000000000..5976dc4df02
--- /dev/null
+++ b/python/orca/example/tfpark/tf_optimizer/__init__.py
@@ -0,0 +1,15 @@
+#
+# Copyright 2018 Analytics Zoo Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/python/orca/example/tfpark/tf_optimizer/evaluate_lenet.py b/python/orca/example/tfpark/tf_optimizer/evaluate_lenet.py
new file mode 100644
index 00000000000..df3511ce554
--- /dev/null
+++ b/python/orca/example/tfpark/tf_optimizer/evaluate_lenet.py
@@ -0,0 +1,81 @@
+#
+# Copyright 2018 Analytics Zoo Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tensorflow as tf
+from zoo import init_nncontext
+from zoo.tfpark import TFDataset, TFPredictor
+import numpy as np
+import sys
+
+from bigdl.dataset import mnist
+from bigdl.dataset.transformer import *
+
+sys.path.append("/tmp/models/slim")  # add the slim library
+from nets import lenet
+
+slim = tf.contrib.slim
+
+
+def main(data_num):
+
+    sc = init_nncontext()
+
+    # get data, pre-process and create TFDataset
+    (images_data, labels_data) = mnist.read_data_sets("/tmp/mnist", "test")
+    image_rdd = sc.parallelize(images_data[:data_num])
+    labels_rdd = sc.parallelize(labels_data[:data_num])
+    rdd = image_rdd.zip(labels_rdd) \
+        .map(lambda rec_tuple: [normalizer(rec_tuple[0], mnist.TRAIN_MEAN, mnist.TRAIN_STD),
+                                np.array(rec_tuple[1])])
+
+    dataset = TFDataset.from_rdd(rdd,
+                                 names=["features", "labels"],
+                                 shapes=[[28, 28, 1], [1]],
+                                 types=[tf.float32, tf.int32],
+                                 batch_per_thread=20
+                                 )
+
+    # construct the model from TFDataset
+    images, labels = dataset.tensors
+
+    labels = tf.squeeze(labels)
+
+    with slim.arg_scope(lenet.lenet_arg_scope()):
+        logits, end_points = lenet.lenet(images, num_classes=10, is_training=False)
+
+    predictions = tf.to_int32(tf.argmax(logits, axis=1))
+    correct = tf.expand_dims(tf.to_int32(tf.equal(predictions, labels)), axis=1)
+
+    saver = tf.train.Saver()
+
+    with tf.Session() as sess:
+        sess.run(tf.global_variables_initializer())
+        saver.restore(sess, "/tmp/lenet/model")
+
+        predictor = TFPredictor(sess, [correct])
+
+        accuracy = predictor.predict().mean()
+
+        print("predict accuracy is %s" % accuracy)
+
+
+if __name__ == '__main__':
+
+    data_num = 10000
+
+    if len(sys.argv) > 1:
+        data_num = int(sys.argv[1])
+    main(data_num)
diff --git a/python/orca/example/tfpark/tf_optimizer/evaluate_mnist_keras.py b/python/orca/example/tfpark/tf_optimizer/evaluate_mnist_keras.py
new file mode 100644
index 00000000000..1ec7eca6367
--- /dev/null
+++ b/python/orca/example/tfpark/tf_optimizer/evaluate_mnist_keras.py
@@ -0,0 +1,85 @@
+#
+# Copyright 2018 Analytics Zoo Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tensorflow as tf
+from zoo import init_nncontext
+from zoo.tfpark import TFDataset, TFPredictor
+import numpy as np
+import sys
+from tensorflow.keras.models import Model
+from tensorflow.keras.layers import *
+
+from bigdl.dataset import mnist
+from bigdl.dataset.transformer import *
+
+DISTRIBUTED = True
+
+
+def main(data_num):
+
+    data = Input(shape=[28, 28, 1])
+
+    x = Flatten()(data)
+    x = Dense(64, activation='relu')(x)
+    x = Dense(64, activation='relu')(x)
+    predictions = Dense(10, activation='softmax')(x)
+
+    model = Model(inputs=data, outputs=predictions)
+
+    model.load_weights("/tmp/mnist_keras.h5")
+
+    if DISTRIBUTED:
+        # using RDD api to do distributed evaluation
+        sc = init_nncontext()
+        # get data, pre-process and create TFDataset
+        (images_data, labels_data) = mnist.read_data_sets("/tmp/mnist", "test")
+        image_rdd = sc.parallelize(images_data[:data_num])
+        labels_rdd = sc.parallelize(labels_data[:data_num])
+        rdd = image_rdd.zip(labels_rdd) \
+            .map(lambda rec_tuple: [normalizer(rec_tuple[0], mnist.TRAIN_MEAN, mnist.TRAIN_STD)])
+
+        dataset = TFDataset.from_rdd(rdd,
+                                     names=["features"],
+                                     shapes=[[28, 28, 1]],
+                                     types=[tf.float32],
+                                     batch_per_thread=20
+                                     )
+        predictor = TFPredictor.from_keras(model, dataset)
+
+        accuracy = predictor.predict().zip(labels_rdd).map(lambda x: np.argmax(x[0]) == x[1]).mean()
+
+        print("predict accuracy is %s" % accuracy)
+
+    else:
+        # using keras api for local evaluation
+        model.compile(optimizer='rmsprop',
+                      loss='sparse_categorical_crossentropy',
+                      metrics=['accuracy'])
+
+        (images_data, labels_data) = mnist.read_data_sets("/tmp/mnist", "test")
+        images_data = normalizer(images_data, mnist.TRAIN_MEAN, mnist.TRAIN_STD)
+        result = model.evaluate(images_data, labels_data)
+        print(model.metrics_names)
+        print(result)
+
+
+if __name__ == '__main__':
+
+    data_num = 10000
+
+    if len(sys.argv) > 1:
+        data_num = int(sys.argv[1])
+    main(data_num)
diff --git a/python/orca/example/tfpark/tf_optimizer/train_lenet.py b/python/orca/example/tfpark/tf_optimizer/train_lenet.py
new file mode 100644
index 00000000000..77a9e95345d
--- /dev/null
+++ b/python/orca/example/tfpark/tf_optimizer/train_lenet.py
@@ -0,0 +1,86 @@
+#
+# Copyright 2018 Analytics Zoo Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tensorflow as tf
+from zoo import init_nncontext
+from zoo.tfpark import TFOptimizer, TFDataset
+from bigdl.optim.optimizer import *
+import numpy as np
+import sys
+
+from bigdl.dataset import mnist
+from bigdl.dataset.transformer import *
+
+sys.path.append("/tmp/models/slim")  # add the slim library
+from nets import lenet
+
+slim = tf.contrib.slim
+
+
+def main(max_epoch, data_num):
+    sc = init_nncontext()
+
+    # get data, pre-process and create TFDataset
+    def get_data_rdd(dataset):
+        (images_data, labels_data) = mnist.read_data_sets("/tmp/mnist", dataset)
+        image_rdd = sc.parallelize(images_data[:data_num])
+        labels_rdd = sc.parallelize(labels_data[:data_num])
+        rdd = image_rdd.zip(labels_rdd) \
+            .map(lambda rec_tuple: [normalizer(rec_tuple[0], mnist.TRAIN_MEAN, mnist.TRAIN_STD),
+                                    np.array(rec_tuple[1])])
+        return rdd
+
+    training_rdd = get_data_rdd("train")
+    testing_rdd = get_data_rdd("test")
+    dataset = TFDataset.from_rdd(training_rdd,
+                                 names=["features", "labels"],
+                                 shapes=[[28, 28, 1], []],
+                                 types=[tf.float32, tf.int32],
+                                 batch_size=280,
+                                 val_rdd=testing_rdd
+                                 )
+
+    # construct the model from TFDataset
+    images, labels = dataset.tensors
+
+    with slim.arg_scope(lenet.lenet_arg_scope()):
+        logits, end_points = lenet.lenet(images, num_classes=10, is_training=True)
+
+    loss = tf.reduce_mean(tf.losses.sparse_softmax_cross_entropy(logits=logits, labels=labels))
+
+    # create an optimizer
+    optimizer = TFOptimizer(loss, Adam(1e-3),
+                            val_outputs=[logits],
+                            val_labels=[labels],
+                            val_method=Top1Accuracy())
+    optimizer.set_train_summary(TrainSummary("/tmp/az_lenet", "lenet"))
+    optimizer.set_val_summary(ValidationSummary("/tmp/az_lenet", "lenet"))
+    # kick off training
+    optimizer.optimize(end_trigger=MaxEpoch(max_epoch))
+
+    saver = tf.train.Saver()
+    saver.save(optimizer.sess, "/tmp/lenet/model")
+
+if __name__ == '__main__':
+
+    max_epoch = 5
+    data_num = 60000
+
+    if len(sys.argv) > 1:
+        max_epoch = int(sys.argv[1])
+    if len(sys.argv) > 2:
+        data_num = int(sys.argv[2])
+    main(max_epoch, data_num)
diff --git a/python/orca/example/tfpark/tf_optimizer/train_mnist_keras.py b/python/orca/example/tfpark/tf_optimizer/train_mnist_keras.py
new file mode 100644
index 00000000000..de7339d0d54
--- /dev/null
+++ b/python/orca/example/tfpark/tf_optimizer/train_mnist_keras.py
@@ -0,0 +1,84 @@
+#
+# Copyright 2018 Analytics Zoo Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tensorflow as tf
+from zoo import init_nncontext
+from zoo.tfpark import TFOptimizer, TFDataset
+from bigdl.optim.optimizer import *
+import numpy as np
+import sys
+from tensorflow.keras.models import Model
+from tensorflow.keras.layers import *
+
+from bigdl.dataset import mnist
+from bigdl.dataset.transformer import *
+
+
+def main(max_epoch, data_num):
+    sc = init_nncontext()
+
+    # get data, pre-process and create TFDataset
+    def get_data_rdd(dataset):
+        (images_data, labels_data) = mnist.read_data_sets("/tmp/mnist", dataset)
+        image_rdd = sc.parallelize(images_data[:data_num])
+        labels_rdd = sc.parallelize(labels_data[:data_num])
+        rdd = image_rdd.zip(labels_rdd) \
+            .map(lambda rec_tuple: [normalizer(rec_tuple[0], mnist.TRAIN_MEAN, mnist.TRAIN_STD),
+                                    np.array(rec_tuple[1])])
+        return rdd
+
+    training_rdd = get_data_rdd("train")
+    testing_rdd = get_data_rdd("test")
+    dataset = TFDataset.from_rdd(training_rdd,
+                                 names=["features", "labels"],
+                                 shapes=[[28, 28, 1], []],
+                                 types=[tf.float32, tf.int32],
+                                 batch_size=280,
+                                 val_rdd=testing_rdd
+                                 )
+
+    data = Input(shape=[28, 28, 1])
+
+    x = Flatten()(data)
+    x = Dense(64, activation='relu')(x)
+    x = Dense(64, activation='relu')(x)
+    predictions = Dense(10, activation='softmax')(x)
+
+    model = Model(inputs=data, outputs=predictions)
+
+    model.compile(optimizer='rmsprop',
+                  loss='sparse_categorical_crossentropy',
+                  metrics=['accuracy'])
+
+    optimizer = TFOptimizer.from_keras(model, dataset)
+
+    optimizer.set_train_summary(TrainSummary("/tmp/mnist_log", "mnist"))
+    optimizer.set_val_summary(ValidationSummary("/tmp/mnist_log", "mnist"))
+    # kick off training
+    optimizer.optimize(end_trigger=MaxEpoch(max_epoch))
+
+    model.save_weights("/tmp/mnist_keras.h5")
+
+if __name__ == '__main__':
+
+    max_epoch = 5
+    data_num = 60000
+
+    if len(sys.argv) > 1:
+        max_epoch = int(sys.argv[1])
+    if len(sys.argv) > 2:
+        data_num = int(sys.argv[2])
+    main(max_epoch, data_num)