From 10095b2998f5c520b44b0cb5e0cf2b184eab7e67 Mon Sep 17 00:00:00 2001
From: jenniew <jenniewang123@gmail.com>
Date: Tue, 11 Aug 2020 12:24:30 -0700
Subject: [PATCH] Predict result as xshard (#2590)

* predict result as xshard
* merge2
* update predict result
* update test
* fix style
---
 .../orca/src/bigdl/orca/learn/tf/estimator.py |  4 +++
 python/orca/src/bigdl/orca/learn/tf/utils.py  | 26 ++++++++++++--
 .../learn/spark/test_estimator_for_spark.py   |  1 +
 .../spark/test_estimator_keras_for_spark.py   | 34 ++++++++++++++++++-
 4 files changed, 62 insertions(+), 3 deletions(-)

diff --git a/python/orca/src/bigdl/orca/learn/tf/estimator.py b/python/orca/src/bigdl/orca/learn/tf/estimator.py
index 33a6352dcfa..711818bb3c0 100644
--- a/python/orca/src/bigdl/orca/learn/tf/estimator.py
+++ b/python/orca/src/bigdl/orca/learn/tf/estimator.py
@@ -201,6 +201,8 @@ def predict(self, data, batch_size=4,
         predicted_rdd = tfnet.predict(dataset)
         if isinstance(data, DataFrame):
             return convert_predict_to_dataframe(data, predicted_rdd)
+        elif isinstance(data, SparkXShards):
+            return convert_predict_to_xshard(data, predicted_rdd)
         else:
             return predicted_rdd
 
@@ -297,6 +299,8 @@ def predict(self, data, batch_size=4,
         predicted_rdd = self.model.predict(dataset, batch_size)
         if isinstance(data, DataFrame):
             return convert_predict_to_dataframe(data, predicted_rdd)
+        elif isinstance(data, SparkXShards):
+            return convert_predict_to_xshard(data, predicted_rdd)
         else:
             return predicted_rdd
 
diff --git a/python/orca/src/bigdl/orca/learn/tf/utils.py b/python/orca/src/bigdl/orca/learn/tf/utils.py
index c6fca6c3f46..ea44b730d9e 100644
--- a/python/orca/src/bigdl/orca/learn/tf/utils.py
+++ b/python/orca/src/bigdl/orca/learn/tf/utils.py
@@ -19,6 +19,7 @@
 import re
 import shutil
 import tensorflow as tf
+import numpy as np
 from pyspark.sql.dataframe import DataFrame
 from zoo.orca.data import SparkXShards
 
@@ -105,14 +106,20 @@ def to_dataset(data, batch_size, batch_per_thread, validation_data,
 
 def convert_predict_to_dataframe(df, prediction_rdd):
     from pyspark.sql import Row
-    from pyspark.sql.types import StructType, StructField, FloatType
+    from pyspark.sql.types import StructType, StructField, FloatType, ArrayType
     from pyspark.ml.linalg import VectorUDT, Vectors
 
     def combine(pair):
+        # list of np array
+        if isinstance(pair[1], list):
+            row = Row(*([pair[0][col] for col in pair[0].__fields__] +
+                        [[Vectors.dense(elem) for elem in pair[1]]]))
+            return row, ArrayType(VectorUDT())
         # scalar
-        if len(pair[1].shape) == 0:
+        elif len(pair[1].shape) == 0:
             row = Row(*([pair[0][col] for col in pair[0].__fields__] + [float(pair[1].item(0))]))
             return row, FloatType()
+        # np array
         else:
             row = Row(*([pair[0][col] for col in pair[0].__fields__] + [Vectors.dense(pair[1])]))
             return row, VectorUDT()
@@ -125,6 +132,21 @@ def combine(pair):
     return result_df
 
 
+def convert_predict_to_xshard(data_shard, prediction_rdd):
+    def transform_predict(iter):
+        predictions = list(iter)
+        # list of np array
+        if isinstance(predictions[0], list):
+            predictions = np.array(predictions).T.tolist()
+            result = [np.array(predict) for predict in predictions]
+            return [{'prediction': result}]
+        # np array
+        else:
+            return [{'prediction': np.array(predictions)}]
+
+    return SparkXShards(prediction_rdd.mapPartitions(transform_predict))
+
+
 def find_latest_checkpoint(model_dir):
     import os
     import re
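The new convert_predict_to_xshard helper turns each partition of the prediction RDD into a single shard dict keyed by 'prediction'. For multi-output models, transform_predict transposes each partition's per-sample lists of output arrays into per-output arrays batched over the samples. Below is a minimal standalone sketch of that logic in plain NumPy (no Spark required); the example partition data is invented for illustration and assumes a two-output model with scalar outputs.

import numpy as np

def transform_predict(iterator):
    # Mirrors the patch's per-partition transform: one dict per partition.
    predictions = list(iterator)
    if isinstance(predictions[0], list):
        # Multi-output model: each sample yields [out1, out2, ...];
        # transpose to per-output arrays stacked over the samples.
        predictions = np.array(predictions).T.tolist()
        result = [np.array(predict) for predict in predictions]
        return [{'prediction': result}]
    else:
        # Single-output model: stack the per-sample arrays directly.
        return [{'prediction': np.array(predictions)}]

# Invented data: two samples from a model with two scalar outputs.
partition = [[np.array(0.9), np.array(0.1)],
             [np.array(0.2), np.array(0.8)]]
shard = transform_predict(iter(partition))[0]
print(shard['prediction'][0])  # [0.9 0.2] -- first output over both samples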
diff --git a/python/orca/test/bigdl/orca/learn/spark/test_estimator_for_spark.py b/python/orca/test/bigdl/orca/learn/spark/test_estimator_for_spark.py
index fd2069d1848..bcdcb3652b3 100644
--- a/python/orca/test/bigdl/orca/learn/spark/test_estimator_for_spark.py
+++ b/python/orca/test/bigdl/orca/learn/spark/test_estimator_for_spark.py
@@ -88,6 +88,7 @@ def transform(df):
         data_shard = data_shard.transform_shard(transform)
 
         predictions = est.predict(data_shard).collect()
+        assert 'prediction' in predictions[0]
         print(predictions)
 
     def test_estimator_graph_fit(self):
diff --git a/python/orca/test/bigdl/orca/learn/spark/test_estimator_keras_for_spark.py b/python/orca/test/bigdl/orca/learn/spark/test_estimator_keras_for_spark.py
index 731b5af74b1..156c95e6b0d 100644
--- a/python/orca/test/bigdl/orca/learn/spark/test_estimator_keras_for_spark.py
+++ b/python/orca/test/bigdl/orca/learn/spark/test_estimator_keras_for_spark.py
@@ -22,6 +22,7 @@
 from bigdl.optim.optimizer import SeveralIteration
 from zoo.orca.learn.tf.estimator import Estimator
 from zoo.common.nncontext import *
+from zoo.orca.learn.tf.utils import convert_predict_to_dataframe
 import zoo.orca.data.pandas
 
 
@@ -70,6 +71,8 @@ def create_model_with_clip(self):
     def test_estimator_keras_xshards(self):
         import zoo.orca.data.pandas
 
+        tf.reset_default_graph()
+
         model = self.create_model()
         file_path = os.path.join(self.resource_path, "orca/learn/ncf.csv")
         data_shard = zoo.orca.data.pandas.read_csv(file_path)
@@ -104,11 +107,13 @@ def transform(df):
         data_shard = data_shard.transform_shard(transform)
 
         predictions = est.predict(data_shard).collect()
-        assert len(predictions[0]) == 2
+        assert predictions[0]['prediction'].shape[1] == 2
 
     def test_estimator_keras_xshards_options(self):
         import zoo.orca.data.pandas
 
+        tf.reset_default_graph()
+
         model = self.create_model()
         file_path = os.path.join(self.resource_path, "orca/learn/ncf.csv")
         data_shard = zoo.orca.data.pandas.read_csv(file_path)
@@ -157,6 +162,8 @@ def transform(df):
     def test_estimator_keras_xshards_clip(self):
         import zoo.orca.data.pandas
 
+        tf.reset_default_graph()
+
         model = self.create_model_with_clip()
         file_path = os.path.join(self.resource_path, "orca/learn/ncf.csv")
         data_shard = zoo.orca.data.pandas.read_csv(file_path)
@@ -180,6 +187,8 @@ def transform(df):
     def test_estimator_keras_xshards_checkpoint(self):
         import zoo.orca.data.pandas
 
+        tf.reset_default_graph()
+
         import tensorflow.keras.backend as K
         K.clear_session()
         tf.reset_default_graph()
@@ -227,6 +236,9 @@ def transform(df):
         shutil.rmtree(temp)
 
     def test_estimator_keras_dataframe(self):
+
+        tf.reset_default_graph()
+
         model = self.create_model()
         sc = init_nncontext()
         sqlcontext = SQLContext(sc)
@@ -253,6 +265,9 @@ def test_estimator_keras_dataframe(self):
         assert len(predictions) == 10
 
     def test_estimator_keras_dataframe_no_fit(self):
+
+        tf.reset_default_graph()
+
         model = self.create_model()
         sc = init_nncontext()
         sqlcontext = SQLContext(sc)
@@ -272,6 +287,23 @@ def test_estimator_keras_dataframe_no_fit(self):
         predictions = prediction_df.collect()
         assert len(predictions) == 10
 
+    def test_convert_predict_list_of_array(self):
+
+        tf.reset_default_graph()
+
+        sc = init_nncontext()
+        sqlcontext = SQLContext(sc)
+        rdd = sc.parallelize([(1, 2, 3), (4, 5, 6), (7, 8, 9)])
+        df = rdd.toDF(["feature", "label", "c"])
+        predict_rdd = df.rdd.map(lambda row: [np.array([1, 2]), np.array(0)])
+        resultDF = convert_predict_to_dataframe(df, predict_rdd)
+        resultDF.printSchema()
+        print(resultDF.collect()[0])
+        predict_rdd = df.rdd.map(lambda row: np.array(1))
+        resultDF = convert_predict_to_dataframe(df, predict_rdd)
+        resultDF.printSchema()
+        print(resultDF.collect()[0])
+
 if __name__ == "__main__":
     import pytest
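Taken together, these changes define a new contract for SparkXShards input: Estimator.predict now returns a SparkXShards of dicts keyed by 'prediction' instead of a bare RDD of raw predictions. The sketch below, modeled on the updated tests, shows the intended call pattern; the CSV path, column names, and the toy model are illustrative assumptions, not part of this patch.

import tensorflow as tf
import zoo.orca.data.pandas
from zoo.common.nncontext import init_nncontext
from zoo.orca.learn.tf.estimator import Estimator

sc = init_nncontext()

# Assumed: a CSV with 'user' and 'item' columns, as in the tests above.
data_shard = zoo.orca.data.pandas.read_csv("orca/learn/ncf.csv")
data_shard = data_shard.transform_shard(
    lambda df: {'x': (df['user'].to_numpy(), df['item'].to_numpy())})

# A toy two-input Keras model standing in for the tests' create_model().
user = tf.keras.layers.Input(shape=[1])
item = tf.keras.layers.Input(shape=[1])
merged = tf.keras.layers.concatenate([user, item], axis=1)
out = tf.keras.layers.Dense(2, activation='softmax')(merged)
model = tf.keras.models.Model(inputs=[user, item], outputs=out)
model.compile(optimizer='rmsprop', loss='sparse_categorical_crossentropy')

est = Estimator.from_keras(keras_model=model)

# With SparkXShards input, predict now returns a SparkXShards of dicts
# keyed by 'prediction' (previously a bare RDD of raw predictions).
predictions = est.predict(data_shard, batch_size=4).collect()
assert 'prediction' in predictions[0]
print(predictions[0]['prediction'].shape)  # (n_samples_in_partition, 2)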