Skip to content

Commit

Permalink
support spark dataframe in tf2 estimator (#3113)
Browse files Browse the repository at this point in the history
* support spark dataframe in tf2 estimator

* fix style
  • Loading branch information
yangw1234 committed Sep 27, 2021
1 parent 3132daa commit d868479
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 4 deletions.
93 changes: 89 additions & 4 deletions python/orca/src/bigdl/orca/learn/tf2/tf_ray_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from zoo.orca.learn.tf2.tf_runner import TFRunner
from zoo.ray import RayContext
from zoo.tfpark.tf_dataset import convert_row_to_numpy

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -56,6 +57,42 @@ def process_spark_xshards(spark_xshards, num_workers):
return max_length, ray_xshards


def arrays2dict(iter, feature_cols, label_cols):

feature_lists = [[] for col in feature_cols]
if label_cols is not None:
label_lists = [[] for col in label_cols]
else:
label_lists = None

for row in iter:
# feature
if not isinstance(row[0], list):
features = [row[0]]
else:
features = row[0]

for i, arr in enumerate(features):
feature_lists[i].append(arr)

# label
if label_cols is not None:
if not isinstance(row[1], list):
labels = [row[1]]
else:
labels = row[1]

for i, arr in enumerate(labels):
label_lists[i].append(arr)

feature_arrs = [np.stack(l) for l in feature_lists]
if label_lists is not None:
label_arrs = [np.stack(l) for l in label_lists]
return [{"x": feature_arrs, "y": label_arrs}]

return [{"x": feature_arrs}]


class Estimator:
def __init__(self,
model_creator,
Expand Down Expand Up @@ -156,7 +193,8 @@ def from_keras(cls, model_creator,
def fit(self, data_creator, epochs=1, verbose=1,
callbacks=None, validation_data_creator=None, class_weight=None,
steps_per_epoch=None, validation_steps=None, validation_freq=1,
data_config=None):
data_config=None, feature_cols=None,
label_cols=None,):
"""Runs a training epoch."""
params = dict(
epochs=epochs,
Expand All @@ -170,6 +208,22 @@ def fit(self, data_creator, epochs=1, verbose=1,
)

from zoo.orca.data import SparkXShards
from pyspark.sql import DataFrame
if isinstance(data_creator, DataFrame):
assert feature_cols is not None,\
"feature_col must be provided if data_creator is a spark dataframe"
assert label_cols is not None,\
"label_cols must be provided if data_creator is a spark dataframe"
schema = data_creator.schema
numpy_rdd = data_creator.rdd.map(lambda row: convert_row_to_numpy(row,
schema,
feature_cols,
label_cols))
shard_rdd = numpy_rdd.mapPartitions(lambda x: arrays2dict(x,
feature_cols,
label_cols))
data_creator = SparkXShards(shard_rdd)

if isinstance(data_creator, SparkXShards):
max_length, ray_xshards = process_spark_xshards(data_creator, self.num_workers)

Expand Down Expand Up @@ -208,7 +262,8 @@ def zip_func(worker, this_shards_ref, that_shards_ref):
return stats

def evaluate(self, data_creator, verbose=1, sample_weight=None,
steps=None, callbacks=None, data_config=None):
steps=None, callbacks=None, data_config=None,
feature_cols=None, label_cols=None):
"""Evaluates the model on the validation data set."""
logger.info("Starting validation step.")
params = dict(
Expand All @@ -219,11 +274,27 @@ def evaluate(self, data_creator, verbose=1, sample_weight=None,
data_config=data_config,
)
from zoo.orca.data import SparkXShards
from pyspark.sql import DataFrame

if isinstance(data_creator, DataFrame):
assert feature_cols is not None,\
"feature_col must be provided if data_creator is a spark dataframe"
assert label_cols is not None,\
"label_cols must be provided if data_creator is a spark dataframe"
schema = data_creator.schema
numpy_rdd = data_creator.rdd.map(lambda row: convert_row_to_numpy(row,
schema,
feature_cols,
label_cols))
shard_rdd = numpy_rdd.mapPartitions(lambda x: arrays2dict(x,
feature_cols,
label_cols))
data_creator = SparkXShards(shard_rdd)

if isinstance(data_creator, SparkXShards):
data = data_creator
if data.num_partitions() != self.num_workers:
data = data.repartition(self.num_workers)
max_length = data.rdd.map(data_length).max()

ray_xshards = RayXShards.from_spark_xshards(data)

Expand All @@ -247,7 +318,8 @@ def transform_func(worker, shards_ref):
return stats

def predict(self, data_creator, batch_size=None, verbose=1,
steps=None, callbacks=None, data_config=None):
steps=None, callbacks=None, data_config=None,
feature_cols=None):
"""Evaluates the model on the validation data set."""
logger.info("Starting predict step.")
params = dict(
Expand All @@ -258,6 +330,19 @@ def predict(self, data_creator, batch_size=None, verbose=1,
data_config=data_config,
)
from zoo.orca.data import SparkXShards
from pyspark.sql import DataFrame
if isinstance(data_creator, DataFrame):
assert feature_cols is not None,\
"feature_col must be provided if data_creator is a spark dataframe"
schema = data_creator.schema
numpy_rdd = data_creator.rdd.map(lambda row: convert_row_to_numpy(row,
schema,
feature_cols,
None))
shard_rdd = numpy_rdd.mapPartitions(lambda x: arrays2dict(x,
feature_cols,
None))
data_creator = SparkXShards(shard_rdd)
if isinstance(data_creator, SparkXShards):
ray_xshards = RayXShards.from_spark_xshards(data_creator)

Expand Down
28 changes: 28 additions & 0 deletions python/orca/test/bigdl/orca/learn/ray/tf/test_tf_ray_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import numpy as np
import pytest
import tensorflow as tf

from zoo import init_nncontext
from zoo.orca.data import XShards

import zoo.orca.data.pandas
Expand Down Expand Up @@ -324,6 +326,32 @@ def test_sparkxshards(self):
trainer.fit(train_data_shard, epochs=1, steps_per_epoch=25)
trainer.evaluate(train_data_shard, steps=25)

def test_dataframe(self):

sc = init_nncontext()
rdd = sc.range(0, 10)
from pyspark.sql import SparkSession
spark = SparkSession(sc)
from pyspark.ml.linalg import DenseVector
df = rdd.map(lambda x: (DenseVector(np.random.randn(1,).astype(np.float)),
int(np.random.randint(0, 1, size=())))).toDF(["feature", "label"])

config = {
"batch_size": 4,
"lr": 0.8
}
trainer = Estimator(
model_creator=model_creator,
verbose=True,
config=config,
workers_per_node=2)

trainer.fit(df, epochs=1, steps_per_epoch=25,
feature_cols=["feature"],
label_cols=["label"])
trainer.evaluate(df, steps=25, feature_cols=["feature"], label_cols=["label"])
trainer.predict(df, feature_cols=["feature"]).collect()

def test_sparkxshards_with_inbalanced_data(self):

train_data_shard = XShards.partition({"x": np.random.randn(100, 1),
Expand Down

0 comments on commit d868479

Please sign in to comment.