support convert rdd of pandas df to spark df with arrow (intel-analyt…
dding3 authored and ForJadeForest committed Sep 20, 2022
1 parent 9f1884c commit 2fd9dfc
Showing 6 changed files with 177 additions and 28 deletions.
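
For context, a minimal usage sketch of the feature this commit adds (not part of the diff). It assumes a BigDL Orca environment with a running Spark context; "/path/to/csv" is a hypothetical placeholder, and the calls mirror the new test case further below.

# Illustrative sketch only, not part of this commit.
import bigdl.orca.data.pandas
from bigdl.orca import init_orca_context

sc = init_orca_context(cluster_mode="local")
shards = bigdl.orca.data.pandas.read_csv("/path/to/csv", header=0,
                                         names=["user", "item"], usecols=[0, 1])
df = shards.to_spark_df()   # Arrow-based conversion added in this commit
df.show()
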
2 changes: 2 additions & 0 deletions python/dllib/src/bigdl/dllib/utils/common.py
@@ -702,6 +702,8 @@ def _py2java(gateway, obj):
obj = obj._jdf
elif isinstance(obj, SparkContext):
obj = obj._jsc
elif isinstance(obj, SQLContext):
obj = obj._jsqlContext
elif isinstance(obj, (list, tuple)):
obj = ListConverter().convert([_py2java(gateway, x) for x in obj],
gateway._gateway_client)
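
The change above teaches _py2java to unwrap a Python SQLContext into its underlying JVM handle before a Py4J call. A self-contained sketch of that unwrapping pattern (illustrative only; to_jvm is a hypothetical stand-in for _py2java):

# Illustrative only; `to_jvm` is a hypothetical stand-in for _py2java.
from pyspark import SparkContext
from pyspark.sql import DataFrame, SQLContext

def to_jvm(obj):
    # Python wrappers keep a reference to their JVM counterpart; Py4J calls
    # need the JVM object, so unwrap before passing it across the gateway.
    if isinstance(obj, DataFrame):
        return obj._jdf
    elif isinstance(obj, SparkContext):
        return obj._jsc
    elif isinstance(obj, SQLContext):
        return obj._jsqlContext   # the branch added by this commit
    return obj
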
1 change: 1 addition & 0 deletions python/orca/src/bigdl/orca/__init__.py
@@ -25,5 +25,6 @@
JavaCreator.add_creator_class("com.intel.analytics.bigdl.orca.net.python.PythonZooNet")
JavaCreator.add_creator_class("com.intel.analytics.bigdl.orca.python.PythonOrca")
JavaCreator.add_creator_class("com.intel.analytics.bigdl.orca.inference.PythonInferenceModel")
JavaCreator.add_creator_class("org.apache.spark.sql.PythonOrcaSQLUtils")
for clz in creator_classes:
JavaCreator.add_creator_class(clz)
2 changes: 1 addition & 1 deletion python/orca/src/bigdl/orca/data/pandas/preprocessing.py
@@ -236,7 +236,7 @@ def loadFile(iterator):
pd_rdd = spark_df_to_rdd_pd(df, squeeze, index_col, dtype, index_map)

try:
data_shards = SparkXShards(pd_rdd)
data_shards = SparkXShards(pd_rdd, class_name="pandas.core.frame.DataFrame")
except Exception as e:
alternative_backend = "pandas" if backend == "spark" else "spark"
print("An error occurred when reading files with '%s' backend, you may switch to '%s' "
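
Passing class_name here records the element type up front. A sketch of why that matters (illustrative only, not part of this commit; assumes a SparkContext is already running):

# Illustrative only. get_spark_context comes from bigdl.dllib.utils.common,
# which shard.py also uses.
import pandas as pd
from bigdl.dllib.utils.common import get_spark_context
from bigdl.orca.data.shard import SparkXShards

sc = get_spark_context()
rdd_of_pdfs = sc.parallelize([pd.DataFrame({"user": [1], "item": [10]}),
                              pd.DataFrame({"user": [2], "item": [20]})], numSlices=2)

# With class_name set, later calls such as _get_class_name() or to_spark_df()
# do not need an extra rdd.first() job to discover that the shards hold
# pandas DataFrames.
shards = SparkXShards(rdd_of_pdfs, class_name="pandas.core.frame.DataFrame")
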
133 changes: 106 additions & 27 deletions python/orca/src/bigdl/orca/data/shard.py
@@ -13,13 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import numpy
from py4j.protocol import Py4JError

from bigdl.orca.data.utils import *
from bigdl.orca import OrcaContext
from bigdl.dllib.nncontext import init_nncontext, ZooContext
from bigdl.dllib.utils.common import get_node_and_core_number
from bigdl.dllib.utils.common import *
from bigdl.dllib.utils import nest
from bigdl.dllib.utils.log4Error import *

@@ -138,7 +137,7 @@ class SparkXShards(XShards):
A collection of data which can be pre-processed in parallel on Spark
"""
def __init__(self, rdd, transient=False):
def __init__(self, rdd, transient=False, class_name=None):
self.rdd = rdd
self.user_cached = False
if transient:
@@ -149,6 +148,8 @@ def __init__(self, rdd, transient=False):
if self.eager:
self.compute()
self.type = {}
if class_name:
self.type['class_name'] = class_name

def transform_shard(self, func, *args):
"""
@@ -261,7 +262,7 @@ def combine_df(iter):
return iter
rdd = self.rdd.coalesce(num_partitions)
repartitioned_shard = SparkXShards(rdd.mapPartitions(combine_df))
elif self._get_class_name() == 'list':
elif self._get_class_name() == 'builtins.list':
if num_partitions > self.rdd.getNumPartitions():
rdd = self.rdd \
.flatMap(lambda data: data) \
@@ -294,7 +295,7 @@ def combine_df(iter):
lambda iter: [np.concatenate(list(iter), axis=0)]))
else:
repartitioned_shard = SparkXShards(self.rdd.repartition(num_partitions))
elif self._get_class_name() == "dict":
elif self._get_class_name() == "builtins.dict":
elem = self.rdd.first()
keys = list(elem.keys())
dtypes = []
@@ -532,28 +533,66 @@ def zip(self, other):
"The two SparkXShards should have the same number of elements "
"in each partition")

def to_spark_df(self):
if self._get_class_name() != 'pandas.core.frame.DataFrame':
invalidInputError(False,
"Currently only support to_spark_df on XShards of Pandas DataFrame")

def _to_spark_df_without_arrow(self):
def f(iter):
for pdf in iter:
np_records = pdf.to_records(index=False)
return [r.tolist() for r in np_records]

def getSchema(iter):
for pdf in iter:
return [pdf.columns.values]

rdd = self.rdd.mapPartitions(f)
column = self.rdd.mapPartitions(getSchema).first()
column = self.get_schema()['columns']
df = rdd.toDF(list(column))
df.cache()
df.count()
self.uncache()
return df

# to_spark_df adapted from pyspark
# https://github.com/apache/spark/blob/master/python/pyspark/sql/pandas/conversion.py
def to_spark_df(self):
if self._get_class_name() != 'pandas.core.frame.DataFrame':
invalidInputError(False,
"Currently only support to_spark_df on XShards of Pandas DataFrame")

try:
import pyarrow as pa
sdf_schema = self._get_spark_df_schema()

sqlContext = get_spark_sql_context(get_spark_context())
timezone = sqlContext._conf.sessionLocalTimeZone()

def f(iter):
for pdf in iter:
import os
import uuid
from pyspark.sql.pandas.types import to_arrow_type
from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer
from tempfile import NamedTemporaryFile

tmpFile = "/tmp/" + str(uuid.uuid1())
os.mkdir(tmpFile)

arrow_types = [to_arrow_type(f.dataType) for f in sdf_schema.fields]

arrow_data = [[(c, t) for (_, c), t in zip(pdf.iteritems(), arrow_types)]]
col_by_name = True
safecheck = False
ser = ArrowStreamPandasSerializer(timezone, safecheck, col_by_name)

tempFile = NamedTemporaryFile(delete=False, dir=tmpFile)
try:
ser.dump_stream(arrow_data, tempFile)
finally:
tempFile.close()
return [tempFile.name]

jiter = self.rdd.mapPartitions(f)
from bigdl.dllib.utils.file_utils import callZooFunc

df = callZooFunc("float", "orcaToDataFrame", jiter, sdf_schema.json(), sqlContext)
return df
except Exception as e:
print(f"createDataFrame from shards attempted Arrow optimization failed as: {str(e)},"
f"Will try without Arrow optimization")
return self._to_spark_df_without_arrow()

def __len__(self):
return self.rdd.map(lambda data: len(data) if hasattr(data, '__len__') else 1)\
.reduce(lambda l1, l2: l1 + l2)
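
To make the Arrow path above easier to follow: each partition's pandas DataFrame is serialized into an Arrow IPC stream in a temp file, and the JVM side later re-reads those batches. A standalone round trip under the same contract (illustrative only, not part of this commit; plain pyarrow is used in place of pyspark's ArrowStreamPandasSerializer):

# Illustrative only. Shows the write-then-read contract between the Python
# producer and the JVM consumer using plain pyarrow.
import pandas as pd
import pyarrow as pa
from tempfile import NamedTemporaryFile

pdf = pd.DataFrame({"user": [1, 2], "item": [10, 20]})
arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)

tmp = NamedTemporaryFile(delete=False, suffix=".arrow")
tmp.close()

# Producer side: dump the DataFrame as one record batch in Arrow stream format.
with open(tmp.name, "wb") as sink:
    with pa.ipc.new_stream(sink, arrow_schema) as writer:
        writer.write_batch(
            pa.RecordBatch.from_pandas(pdf, schema=arrow_schema, preserve_index=False))

# Consumer side: this is the role ArrowConverters.getBatchesFromStream plays on the JVM.
with open(tmp.name, "rb") as source:
    restored = pa.ipc.open_stream(source).read_all().to_pandas()

print(restored)
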
@@ -601,27 +640,67 @@ def get_schema(self):

if 'class_name' not in self.type\
or self.type['class_name'] == 'pandas.core.frame.DataFrame':
class_name, schema = self._get_schema_class_name()
class_name, pdf_schema, sdf_schema = self._get_schema_class_name()
self.type['class_name'] = class_name
self.type['schema'] = schema
self.type['schema'] = pdf_schema
self.type['spark_df_schema'] = sdf_schema
return self.type['schema']

def _get_spark_df_schema(self):
if 'spark_df_schema' in self.type:
return self.type['spark_df_schema']

if 'class_name' not in self.type\
or self.type['class_name'] == 'pandas.core.frame.DataFrame':
class_name, pdf_schema, sdf_schema = self._get_schema_class_name()
self.type['class_name'] = class_name
self.type['schema'] = pdf_schema
self.type['spark_df_schema'] = sdf_schema
return self.type['spark_df_schema']

def _get_class_name(self):
if 'class_name' in self.type:
return self.type['class_name']
else:
class_name, schema = self._get_schema_class_name()
class_name, schema, sdf_schema = self._get_schema_class_name()
self.type['class_name'] = class_name
self.type['schema'] = schema
self.type['spark_df_schema'] = sdf_schema
return self.type['class_name']

def _get_schema_class_name(self):
def func(x):
class_name = get_class_name(x)
schema = None
if class_name == 'pandas.core.frame.DataFrame':
schema = {'columns': x.columns, 'dtypes': x.dtypes}
return (class_name, schema)
class_name = self.type['class_name'] if 'class_name' in self.type else None
import pyspark
spark_version = pyspark.version.__version__
major_version = spark_version.split(".")[0]

def func(pdf):
pdf_schema = None
spark_df_schema = None
_class_name = class_name
if not _class_name:
_class_name = pdf.__class__.__module__ + '.' + pdf.__class__.__name__

if _class_name == 'pandas.core.frame.DataFrame':
schema = [str(x) if not isinstance(x, str) else x for x in pdf.columns]
pdf_schema = {'columns': schema, 'dtypes': pdf.dtypes}

if major_version >= '3':
from pyspark.sql.pandas.types import from_arrow_type
from pyspark.sql.types import StructType

if isinstance(schema, (list, tuple)):
import pyarrow as pa
arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)
struct = StructType()
for name, field in zip(schema, arrow_schema):
struct.add(
name, from_arrow_type(field.type), nullable=field.nullable
)
spark_df_schema = struct

return (_class_name, pdf_schema, spark_df_schema)

return self.rdd.map(lambda x: func(x)).first()


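
The schema handling added above is interleaved with caching logic. Condensed into a standalone helper, the pandas-to-Spark schema derivation looks roughly like this (illustrative only, not part of this commit; assumes pyspark>=3.0 and pyarrow are installed):

# Illustrative condensation of the derivation in _get_schema_class_name.
import pandas as pd
import pyarrow as pa
from pyspark.sql.types import StructType
from pyspark.sql.pandas.types import from_arrow_type

def pandas_to_spark_schema(pdf: pd.DataFrame) -> StructType:
    # Let Arrow infer column types from the pandas data, then map each Arrow
    # type onto the corresponding Spark SQL type.
    arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)
    struct = StructType()
    for field in arrow_schema:
        struct.add(field.name, from_arrow_type(field.type), nullable=field.nullable)
    return struct

print(pandas_to_spark_schema(pd.DataFrame({"user": [1], "item": [0.5]})).json())
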
7 changes: 7 additions & 0 deletions python/orca/test/bigdl/orca/data/test_spark_backend.py
@@ -214,6 +214,13 @@ def test_write_read_imagenet(self):
finally:
shutil.rmtree(temp_dir)

def test_to_spark_df(self):
file_path = os.path.join(self.resource_path, "orca/data/csv")
data_shard = bigdl.orca.data.pandas.read_csv(file_path, header=0, names=['user', 'item'],
usecols=[0, 1])
df = data_shard.to_spark_df()
df.show()


if __name__ == "__main__":
pytest.main([__file__])
60 changes: 60 additions & 0 deletions (new Scala file defining org.apache.spark.sql.PythonOrcaSQLUtils)
@@ -0,0 +1,60 @@
/*
* Copyright 2016 The BigDL Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import java.io.FileInputStream

import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath.TensorNumeric
import org.apache.spark.TaskContext
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.sql.api.python.PythonSQLUtils
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.{ShutdownHookManager, Utils}

import java.io._

import scala.reflect.ClassTag

object PythonOrcaSQLUtils {

def ofFloat(): PythonOrcaSQLUtils[Float] = new PythonOrcaSQLUtils[Float]()

def ofDouble(): PythonOrcaSQLUtils[Double] = new PythonOrcaSQLUtils[Double]()
}

class PythonOrcaSQLUtils[T: ClassTag](implicit ev: TensorNumeric[T]) {
def orcaToDataFrame(jrdd: JavaRDD[String], schemaString: String,
sqlContext: SQLContext): DataFrame = {
val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone
val rdd = jrdd.rdd.mapPartitions { iter =>
val context = TaskContext.get()
val file = iter.next()
val dir = new File(file)
ShutdownHookManager.registerShutdownDeleteDir(dir)

Utils.tryWithResource(new FileInputStream(file)) { fileStream =>
// Create array to consume iterator so that we can safely close the file
val batches = ArrowConverters.getBatchesFromStream(fileStream.getChannel)
ArrowConverters.fromBatchIterator(batches,
DataType.fromJson(schemaString).asInstanceOf[StructType], timeZoneId, context)
}
}
sqlContext.internalCreateDataFrame(rdd.setName("arrow"), schema)
}
}
