From 10095b2998f5c520b44b0cb5e0cf2b184eab7e67 Mon Sep 17 00:00:00 2001
From: jenniew <jenniewang123@gmail.com>
Date: Tue, 11 Aug 2020 12:24:30 -0700
Subject: [PATCH] Predict result as xshard (#2590)

* Return predict results as SparkXShards when the input data is a SparkXShards
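
With this change, predict keeps results in the same container type as the
input: a Spark DataFrame in gives a DataFrame out, and a SparkXShards in
gives a SparkXShards out. A minimal usage sketch (the model construction
and the csv path below are hypothetical, for illustration only):

    import zoo.orca.data.pandas
    from zoo.orca.learn.tf.estimator import Estimator

    data_shard = zoo.orca.data.pandas.read_csv("ncf.csv")  # SparkXShards
    est = Estimator.from_keras(keras_model=model)          # model: hypothetical
    result_shard = est.predict(data_shard)                 # SparkXShards again
    # each element of the result is a dict with a 'prediction' key
    print(result_shard.collect()[0]['prediction'].shape)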

* Handle multi-output predictions (a list of np arrays per record) in
  convert_predict_to_dataframe
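
For a multi-output model, each prediction is a list of numpy arrays; the
old code called pair[1].shape, which a list lacks. convert_predict_to_dataframe
now turns such predictions into an ArrayType(VectorUDT()) column. A sketch of
the conversion, mirroring the new test (the input rows are made up):

    import numpy as np
    from pyspark.sql import SQLContext
    from zoo.common.nncontext import init_nncontext
    from zoo.orca.learn.tf.utils import convert_predict_to_dataframe

    sc = init_nncontext()
    sqlcontext = SQLContext(sc)  # enables rdd.toDF
    df = sc.parallelize([(1, 2, 3), (4, 5, 6)]).toDF(["feature", "label", "c"])
    # two outputs per record: a length-2 vector and a scalar
    predict_rdd = df.rdd.map(lambda row: [np.array([1, 2]), np.array(0)])
    result_df = convert_predict_to_dataframe(df, predict_rdd)
    result_df.printSchema()  # prediction column: array of vector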

* Update tests to cover the new prediction formats

* Fix style
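
Note on the xshard layout: convert_predict_to_xshard groups predictions per
Spark partition, stacking single-output predictions into one np array and
regrouping multi-output predictions into a list of per-output arrays. A
standalone illustration of that grouping (this mirrors the helper added in
utils.py; it is a sketch, not the shipped code):

    import numpy as np

    def transform_predict(iterator):
        predictions = list(iterator)
        if isinstance(predictions[0], list):  # multi-output model
            predictions = np.array(predictions).T.tolist()
            return [{'prediction': [np.array(p) for p in predictions]}]
        return [{'prediction': np.array(predictions)}]  # single-output

    # two single-output predictions in one partition -> one stacked array
    print(transform_predict(iter([np.array(0.1), np.array(0.9)])))
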
---
 .../orca/src/bigdl/orca/learn/tf/estimator.py |  4 +++
 python/orca/src/bigdl/orca/learn/tf/utils.py  | 26 ++++++++++++--
 .../learn/spark/test_estimator_for_spark.py   |  1 +
 .../spark/test_estimator_keras_for_spark.py   | 34 ++++++++++++++++++-
 4 files changed, 62 insertions(+), 3 deletions(-)

diff --git a/python/orca/src/bigdl/orca/learn/tf/estimator.py b/python/orca/src/bigdl/orca/learn/tf/estimator.py
index 33a6352dcfa..711818bb3c0 100644
--- a/python/orca/src/bigdl/orca/learn/tf/estimator.py
+++ b/python/orca/src/bigdl/orca/learn/tf/estimator.py
@@ -201,6 +201,8 @@ def predict(self, data, batch_size=4,
         predicted_rdd = tfnet.predict(dataset)
         if isinstance(data, DataFrame):
             return convert_predict_to_dataframe(data, predicted_rdd)
+        elif isinstance(data, SparkXShards):
+            return convert_predict_to_xshard(data, predicted_rdd)
         else:
             return predicted_rdd
 
@@ -297,6 +299,8 @@ def predict(self, data, batch_size=4,
         predicted_rdd = self.model.predict(dataset, batch_size)
         if isinstance(data, DataFrame):
             return convert_predict_to_dataframe(data, predicted_rdd)
+        elif isinstance(data, SparkXShards):
+            return convert_predict_to_xshard(data, predicted_rdd)
         else:
             return predicted_rdd
 
diff --git a/python/orca/src/bigdl/orca/learn/tf/utils.py b/python/orca/src/bigdl/orca/learn/tf/utils.py
index c6fca6c3f46..ea44b730d9e 100644
--- a/python/orca/src/bigdl/orca/learn/tf/utils.py
+++ b/python/orca/src/bigdl/orca/learn/tf/utils.py
@@ -19,6 +19,7 @@
 import re
 import shutil
 import tensorflow as tf
+import numpy as np
 from pyspark.sql.dataframe import DataFrame
 
 from zoo.orca.data import SparkXShards
@@ -105,14 +106,20 @@ def to_dataset(data, batch_size, batch_per_thread, validation_data,
 
 def convert_predict_to_dataframe(df, prediction_rdd):
     from pyspark.sql import Row
-    from pyspark.sql.types import StructType, StructField, FloatType
+    from pyspark.sql.types import StructType, StructField, FloatType, ArrayType
     from pyspark.ml.linalg import VectorUDT, Vectors
 
     def combine(pair):
+        # multi-output: prediction is a list of np arrays
+        if isinstance(pair[1], list):
+            row = Row(*([pair[0][col] for col in pair[0].__fields__] +
+                        [[Vectors.dense(elem) for elem in pair[1]]]))
+            return row, ArrayType(VectorUDT())
         # scalar
-        if len(pair[1].shape) == 0:
+        elif len(pair[1].shape) == 0:
             row = Row(*([pair[0][col] for col in pair[0].__fields__] + [float(pair[1].item(0))]))
             return row, FloatType()
+        # np array
         else:
             row = Row(*([pair[0][col] for col in pair[0].__fields__] + [Vectors.dense(pair[1])]))
             return row, VectorUDT()
@@ -125,6 +132,21 @@ def combine(pair):
     return result_df
 
 
+def convert_predict_to_xshard(data_shard, prediction_rdd):
+    def transform_predict(iterator):
+        predictions = list(iterator)
+        # multi-output: each prediction is a list of np arrays
+        if isinstance(predictions[0], list):
+            predictions = np.array(predictions).T.tolist()
+            result = [np.array(predict) for predict in predictions]
+            return [{'prediction': result}]
+        # np array
+        else:
+            return [{'prediction': np.array(predictions)}]
+
+    return SparkXShards(prediction_rdd.mapPartitions(transform_predict))
+
+
 def find_latest_checkpoint(model_dir):
     import os
     import re
diff --git a/python/orca/test/bigdl/orca/learn/spark/test_estimator_for_spark.py b/python/orca/test/bigdl/orca/learn/spark/test_estimator_for_spark.py
index fd2069d1848..bcdcb3652b3 100644
--- a/python/orca/test/bigdl/orca/learn/spark/test_estimator_for_spark.py
+++ b/python/orca/test/bigdl/orca/learn/spark/test_estimator_for_spark.py
@@ -88,6 +88,7 @@ def transform(df):
 
         data_shard = data_shard.transform_shard(transform)
         predictions = est.predict(data_shard).collect()
+        assert 'prediction' in predictions[0]
         print(predictions)
 
     def test_estimator_graph_fit(self):
diff --git a/python/orca/test/bigdl/orca/learn/spark/test_estimator_keras_for_spark.py b/python/orca/test/bigdl/orca/learn/spark/test_estimator_keras_for_spark.py
index 731b5af74b1..156c95e6b0d 100644
--- a/python/orca/test/bigdl/orca/learn/spark/test_estimator_keras_for_spark.py
+++ b/python/orca/test/bigdl/orca/learn/spark/test_estimator_keras_for_spark.py
@@ -22,6 +22,7 @@
 from bigdl.optim.optimizer import SeveralIteration
 from zoo.orca.learn.tf.estimator import Estimator
 from zoo.common.nncontext import *
+from zoo.orca.learn.tf.utils import convert_predict_to_dataframe
 import zoo.orca.data.pandas
 
 
@@ -70,6 +71,8 @@ def create_model_with_clip(self):
     def test_estimator_keras_xshards(self):
         import zoo.orca.data.pandas
 
+        tf.reset_default_graph()
+
         model = self.create_model()
         file_path = os.path.join(self.resource_path, "orca/learn/ncf.csv")
         data_shard = zoo.orca.data.pandas.read_csv(file_path)
@@ -104,11 +107,13 @@ def transform(df):
 
         data_shard = data_shard.transform_shard(transform)
         predictions = est.predict(data_shard).collect()
-        assert len(predictions[0]) == 2
+        assert predictions[0]['prediction'].shape[1] == 2
 
     def test_estimator_keras_xshards_options(self):
         import zoo.orca.data.pandas
 
+        tf.reset_default_graph()
+
         model = self.create_model()
         file_path = os.path.join(self.resource_path, "orca/learn/ncf.csv")
         data_shard = zoo.orca.data.pandas.read_csv(file_path)
@@ -157,6 +162,8 @@ def transform(df):
     def test_estimator_keras_xshards_clip(self):
         import zoo.orca.data.pandas
 
+        tf.reset_default_graph()
+
         model = self.create_model_with_clip()
         file_path = os.path.join(self.resource_path, "orca/learn/ncf.csv")
         data_shard = zoo.orca.data.pandas.read_csv(file_path)
@@ -180,6 +187,8 @@ def transform(df):
     def test_estimator_keras_xshards_checkpoint(self):
         import zoo.orca.data.pandas
 
+        tf.reset_default_graph()
+
         import tensorflow.keras.backend as K
         K.clear_session()
         tf.reset_default_graph()
@@ -227,6 +236,9 @@ def transform(df):
         shutil.rmtree(temp)
 
     def test_estimator_keras_dataframe(self):
+
+        tf.reset_default_graph()
+
         model = self.create_model()
         sc = init_nncontext()
         sqlcontext = SQLContext(sc)
@@ -253,6 +265,9 @@ def test_estimator_keras_dataframe(self):
         assert len(predictions) == 10
 
     def test_estimator_keras_dataframe_no_fit(self):
+
+        tf.reset_default_graph()
+
         model = self.create_model()
         sc = init_nncontext()
         sqlcontext = SQLContext(sc)
@@ -272,6 +287,23 @@ def test_estimator_keras_dataframe_no_fit(self):
         predictions = prediction_df.collect()
         assert len(predictions) == 10
 
+    def test_convert_predict_list_of_array(self):
+
+        tf.reset_default_graph()
+
+        sc = init_nncontext()
+        sqlcontext = SQLContext(sc)
+        rdd = sc.parallelize([(1, 2, 3), (4, 5, 6), (7, 8, 9)])
+        df = rdd.toDF(["feature", "label", "c"])
+        predict_rdd = df.rdd.map(lambda row: [np.array([1, 2]), np.array(0)])
+        result_df = convert_predict_to_dataframe(df, predict_rdd)
+        result_df.printSchema()
+        print(result_df.collect()[0])
+        predict_rdd = df.rdd.map(lambda row: np.array(1))
+        result_df = convert_predict_to_dataframe(df, predict_rdd)
+        result_df.printSchema()
+        print(result_df.collect()[0])
+
 
 if __name__ == "__main__":
     import pytest