Predict result as xshard (intel-analytics#2590)
* predict result as xshard

* merge2

* update predict result

* update test

* fix style
jenniew authored and Wang, Yang committed Sep 26, 2021
1 parent 7da434d commit 10095b2
Showing 4 changed files with 62 additions and 3 deletions.
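In effect, Estimator.predict now returns results in the same form as its input: a Spark DataFrame comes back as a DataFrame with a prediction column appended, and a SparkXShards comes back as a SparkXShards whose elements are dicts keyed by 'prediction'. A minimal usage sketch, assuming the NCF-style CSV and the 'x' shard layout used in the tests below (the path, the toy model, and the transform body are illustrative, not from this commit):

    import tensorflow as tf
    import zoo.orca.data.pandas
    from zoo.common.nncontext import init_nncontext
    from zoo.orca.learn.tf.estimator import Estimator

    sc = init_nncontext()

    # Toy two-input Keras model standing in for create_model() in the tests.
    user = tf.keras.layers.Input(shape=[1])
    item = tf.keras.layers.Input(shape=[1])
    merged = tf.keras.layers.concatenate([user, item], axis=1)
    out = tf.keras.layers.Dense(2, activation='softmax')(merged)
    model = tf.keras.models.Model(inputs=[user, item], outputs=out)
    model.compile(optimizer='rmsprop', loss='sparse_categorical_crossentropy')

    est = Estimator.from_keras(keras_model=model)

    # Illustrative path; the tests read orca/learn/ncf.csv with user/item columns.
    data_shard = zoo.orca.data.pandas.read_csv("/path/to/ncf.csv")

    def transform(df):
        # Assumed shard layout: model inputs under the 'x' key.
        return {"x": (df['user'].to_numpy(), df['item'].to_numpy())}

    data_shard = data_shard.transform_shard(transform)

    result_shard = est.predict(data_shard)          # XShards in -> XShards out
    print(result_shard.collect()[0]['prediction'])  # numpy array of outputs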
4 changes: 4 additions & 0 deletions python/orca/src/bigdl/orca/learn/tf/estimator.py
@@ -201,6 +201,8 @@ def predict(self, data, batch_size=4,
         predicted_rdd = tfnet.predict(dataset)
         if isinstance(data, DataFrame):
             return convert_predict_to_dataframe(data, predicted_rdd)
+        elif isinstance(data, SparkXShards):
+            return convert_predict_to_xshard(data, predicted_rdd)
         else:
             return predicted_rdd

@@ -297,6 +299,8 @@ def predict(self, data, batch_size=4,
         predicted_rdd = self.model.predict(dataset, batch_size)
         if isinstance(data, DataFrame):
             return convert_predict_to_dataframe(data, predicted_rdd)
+        elif isinstance(data, SparkXShards):
+            return convert_predict_to_xshard(data, predicted_rdd)
         else:
             return predicted_rdd

26 changes: 24 additions & 2 deletions python/orca/src/bigdl/orca/learn/tf/utils.py
@@ -19,6 +19,7 @@
 import re
 import shutil
 import tensorflow as tf
+import numpy as np
 from pyspark.sql.dataframe import DataFrame

 from zoo.orca.data import SparkXShards
@@ -105,14 +106,20 @@ def to_dataset(data, batch_size, batch_per_thread, validation_data,

 def convert_predict_to_dataframe(df, prediction_rdd):
     from pyspark.sql import Row
-    from pyspark.sql.types import StructType, StructField, FloatType
+    from pyspark.sql.types import StructType, StructField, FloatType, ArrayType
     from pyspark.ml.linalg import VectorUDT, Vectors

     def combine(pair):
+        # list of np array
+        if isinstance(pair[1], list):
+            row = Row(*([pair[0][col] for col in pair[0].__fields__] +
+                        [[Vectors.dense(elem) for elem in pair[1]]]))
+            return row, ArrayType(VectorUDT())
         # scalar
-        if len(pair[1].shape) == 0:
+        elif len(pair[1].shape) == 0:
             row = Row(*([pair[0][col] for col in pair[0].__fields__] + [float(pair[1].item(0))]))
             return row, FloatType()
+        # np array
         else:
             row = Row(*([pair[0][col] for col in pair[0].__fields__] + [Vectors.dense(pair[1])]))
             return row, VectorUDT()
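The combine helper above picks the Spark SQL type for the appended prediction column from the shape of each record's prediction: a list of numpy arrays (multi-output model) maps to ArrayType(VectorUDT()), a 0-d array to a FloatType scalar, and any other array to a single dense vector. A standalone sketch of just that mapping (an illustrative helper, not part of the commit):

    import numpy as np
    from pyspark.ml.linalg import VectorUDT
    from pyspark.sql.types import ArrayType, FloatType

    def predicted_spark_type(prediction):
        # Mirrors the branch order of combine() above.
        if isinstance(prediction, list):      # multi-output: list of arrays
            return ArrayType(VectorUDT())
        elif len(prediction.shape) == 0:      # 0-d array: plain float
            return FloatType()
        else:                                 # 1-d array: one dense vector
            return VectorUDT()

    print(predicted_spark_type([np.array([1, 2]), np.array(0)]))  # ArrayType(VectorUDT)
    print(predicted_spark_type(np.array(1.0)))                    # FloatType
    print(predicted_spark_type(np.array([0.2, 0.8])))             # VectorUDT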
@@ -125,6 +132,21 @@ def combine(pair):
     return result_df


+def convert_predict_to_xshard(data_shard, prediction_rdd):
+    def transform_predict(iter):
+        predictions = list(iter)
+        # list of np array
+        if isinstance(predictions[0], list):
+            predictions = np.array(predictions).T.tolist()
+            result = [np.array(predict) for predict in predictions]
+            return [{'prediction': result}]
+        # np array
+        else:
+            return [{'prediction': np.array(predictions)}]
+
+    return SparkXShards(prediction_rdd.mapPartitions(transform_predict))
+
+
 def find_latest_checkpoint(model_dir):
     import os
     import re
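For a multi-output model, each record's prediction arrives as a list of arrays, so a partition's predictions look like [[out1, out2], [out1, out2], ...]; transform_predict transposes that record-major layout to output-major before stacking each output across the partition. A self-contained sketch of the reshaping (toy values; dtype=object is added here for newer NumPy versions, which refuse to build ragged arrays implicitly):

    import numpy as np

    # Two records from a two-output model: each prediction is [out1, out2].
    partition = [[np.array([0.1, 0.9]), np.array(3.0)],
                 [np.array([0.7, 0.3]), np.array(5.0)]]

    # Same reshaping as transform_predict: transpose, then stack per output.
    grouped = np.array(partition, dtype=object).T.tolist()
    result = [np.array(outputs) for outputs in grouped]

    print(result[0].shape)  # (2, 2): first output stacked over both records
    print(result[1].shape)  # (2,):  second output stacked over both records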
@@ -88,6 +88,7 @@ def transform(df):

         data_shard = data_shard.transform_shard(transform)
         predictions = est.predict(data_shard).collect()
+        assert 'prediction' in predictions[0]
         print(predictions)

     def test_estimator_graph_fit(self):
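The new assertion relies on the XShards result contract introduced above: each collected element is a dict, and `'prediction' in predictions[0]` tests for that key. A toy stand-in for one collected element (values made up):

    import numpy as np

    element = {'prediction': np.array([[0.3, 0.7],
                                       [0.6, 0.4]])}  # one shard's stacked outputs
    assert 'prediction' in element  # dict membership checks keys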
@@ -22,6 +22,7 @@
 from bigdl.optim.optimizer import SeveralIteration
 from zoo.orca.learn.tf.estimator import Estimator
 from zoo.common.nncontext import *
+from zoo.orca.learn.tf.utils import convert_predict_to_dataframe
 import zoo.orca.data.pandas

@@ -70,6 +71,8 @@ def create_model_with_clip(self):
     def test_estimator_keras_xshards(self):
         import zoo.orca.data.pandas

+        tf.reset_default_graph()
+
         model = self.create_model()
         file_path = os.path.join(self.resource_path, "orca/learn/ncf.csv")
         data_shard = zoo.orca.data.pandas.read_csv(file_path)
@@ -104,11 +107,13 @@ def transform(df):

         data_shard = data_shard.transform_shard(transform)
         predictions = est.predict(data_shard).collect()
-        assert len(predictions[0]) == 2
+        assert predictions[0]['prediction'].shape[1] == 2

     def test_estimator_keras_xshards_options(self):
         import zoo.orca.data.pandas

+        tf.reset_default_graph()
+
         model = self.create_model()
         file_path = os.path.join(self.resource_path, "orca/learn/ncf.csv")
         data_shard = zoo.orca.data.pandas.read_csv(file_path)
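The assertion changed because an XShards element now carries a whole shard's predictions stacked into one array: n per-record outputs of shape (2,) from the two-class model become an (n, 2) array, so shape[1] is the number of classes. A toy illustration (values made up):

    import numpy as np

    per_record = [np.array([0.2, 0.8]), np.array([0.9, 0.1]), np.array([0.5, 0.5])]
    stacked = np.array(per_record)  # same stacking as transform_predict
    assert stacked.shape == (3, 2)  # 3 records x 2 classes
    assert stacked.shape[1] == 2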
@@ -157,6 +162,8 @@ def transform(df):
     def test_estimator_keras_xshards_clip(self):
         import zoo.orca.data.pandas

+        tf.reset_default_graph()
+
         model = self.create_model_with_clip()
         file_path = os.path.join(self.resource_path, "orca/learn/ncf.csv")
         data_shard = zoo.orca.data.pandas.read_csv(file_path)
@@ -180,6 +187,8 @@ def transform(df):
     def test_estimator_keras_xshards_checkpoint(self):
         import zoo.orca.data.pandas

+        tf.reset_default_graph()
+
         import tensorflow.keras.backend as K
         K.clear_session()
         tf.reset_default_graph()
@@ -227,6 +236,9 @@ def transform(df):
         shutil.rmtree(temp)

     def test_estimator_keras_dataframe(self):
+
+        tf.reset_default_graph()
+
         model = self.create_model()
         sc = init_nncontext()
         sqlcontext = SQLContext(sc)
@@ -253,6 +265,9 @@ def test_estimator_keras_dataframe(self):
         assert len(predictions) == 10

     def test_estimator_keras_dataframe_no_fit(self):
+
+        tf.reset_default_graph()
+
         model = self.create_model()
         sc = init_nncontext()
         sqlcontext = SQLContext(sc)
@@ -272,6 +287,23 @@ def test_estimator_keras_dataframe_no_fit(self):
         predictions = prediction_df.collect()
         assert len(predictions) == 10

+    def test_convert_predict_list_of_array(self):
+
+        tf.reset_default_graph()
+
+        sc = init_nncontext()
+        sqlcontext = SQLContext(sc)
+        rdd = sc.parallelize([(1, 2, 3), (4, 5, 6), (7, 8, 9)])
+        df = rdd.toDF(["feature", "label", "c"])
+        predict_rdd = df.rdd.map(lambda row: [np.array([1, 2]), np.array(0)])
+        resultDF = convert_predict_to_dataframe(df, predict_rdd)
+        resultDF.printSchema()
+        print(resultDF.collect()[0])
+        predict_rdd = df.rdd.map(lambda row: np.array(1))
+        resultDF = convert_predict_to_dataframe(df, predict_rdd)
+        resultDF.printSchema()
+        print(resultDF.collect()[0])
+

 if __name__ == "__main__":
     import pytest
