support convert spark df to shards with arrow (#5558)
dding3 authored Sep 7, 2022
1 parent 382536a commit 5ce6c7a
Showing 7 changed files with 232 additions and 67 deletions.
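
For context, a minimal usage sketch of the conversion path this commit adds, mirroring the new test_spark_df_to_shards test at the bottom of this diff; the CSV path is a placeholder and a Spark environment with the BigDL Orca jars on the classpath is assumed.

from pyspark.sql import SparkSession
from bigdl.orca.data.utils import spark_df_to_pd_sparkxshards

# Build a local SparkSession the same way the new test does.
spark = SparkSession.builder.master("local[1]") \
    .appName("spark_df_to_shards_demo") \
    .config("spark.driver.memory", "6g").getOrCreate()

df = spark.read.csv("/path/to/orca/data/csv")  # placeholder input path

# Convert the Spark DataFrame into a SparkXShards of pandas DataFrames.
# The Arrow-based path is tried first; the row-based path is the fallback.
data_shards = spark_df_to_pd_sparkxshards(df)
print(data_shards.collect()[0].head())  # each shard element is a pandas DataFrame
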
2 changes: 1 addition & 1 deletion python/orca/src/bigdl/orca/__init__.py
@@ -25,6 +25,6 @@
JavaCreator.add_creator_class("com.intel.analytics.bigdl.orca.net.python.PythonZooNet")
JavaCreator.add_creator_class("com.intel.analytics.bigdl.orca.python.PythonOrca")
JavaCreator.add_creator_class("com.intel.analytics.bigdl.orca.inference.PythonInferenceModel")
JavaCreator.add_creator_class("org.apache.spark.sql.PythonOrcaSQLUtils")
JavaCreator.add_creator_class("com.intel.analytics.bigdl.orca.python.PythonOrcaSQLUtils")
for clz in creator_classes:
JavaCreator.add_creator_class(clz)
3 changes: 2 additions & 1 deletion python/orca/src/bigdl/orca/data/shard.py
@@ -570,7 +570,6 @@ def f(iter):
os.mkdir(tmpFile)

arrow_types = [to_arrow_type(f.dataType) for f in sdf_schema.fields]

arrow_data = [[(c, t) for (_, c), t in zip(pdf.iteritems(), arrow_types)]]
col_by_name = True
safecheck = False
@@ -645,6 +644,7 @@ def get_schema(self):
self.type['schema'] = pdf_schema
self.type['spark_df_schema'] = sdf_schema
return self.type['schema']
return None

def _get_spark_df_schema(self):
if 'spark_df_schema' in self.type:
@@ -657,6 +657,7 @@ def _get_spark_df_schema(self):
self.type['schema'] = pdf_schema
self.type['spark_df_schema'] = sdf_schema
return self.type['spark_df_schema']
return None

def _get_class_name(self):
if 'class_name' in self.type:
98 changes: 64 additions & 34 deletions python/orca/src/bigdl/orca/data/utils.py
@@ -15,9 +15,12 @@
#
import os
import numpy as np
import pandas as pd

from bigdl.dllib.utils.file_utils import get_file_list

from bigdl.dllib.utils.file_utils import get_file_list, callZooFunc
from bigdl.dllib.utils.utils import convert_row_to_numpy
from bigdl.dllib.utils.common import *
from bigdl.dllib.utils.log4Error import *


@@ -378,7 +381,6 @@ def spark_df_to_rdd_pd(df, squeeze=False, index_col=None,
dtype=None, index_map=None):
from bigdl.orca.data import SparkXShards
from bigdl.orca import OrcaContext
columns = df.columns

import pyspark.sql.functions as F
import pyspark.sql.types as T
@@ -388,46 +390,55 @@
df = df.withColumn(colName, to_array(colName))

shard_size = OrcaContext._shard_size
pd_rdd = df.rdd.mapPartitions(to_pandas(df.columns, squeeze, index_col, dtype, index_map,
batch_size=shard_size))
return pd_rdd

try:
pd_rdd = to_pandas(df, squeeze, index_col, dtype, index_map,
batch_size=shard_size)
return pd_rdd
except Exception as e:
print(f"create shards from Spark DataFrame attempted Arrow optimization failed as:"
f" {str(e)}. Will try without Arrow optimization")
pd_rdd = df.rdd.mapPartitions(to_pandas_without_arrow(df.columns, squeeze, index_col,
dtype, index_map,
batch_size=shard_size))
return pd_rdd


def spark_df_to_pd_sparkxshards(df, squeeze=False, index_col=None,
dtype=None, index_map=None):
pd_rdd = spark_df_to_rdd_pd(df, squeeze, index_col, dtype, index_map)
from bigdl.orca.data import SparkXShards
spark_xshards = SparkXShards(pd_rdd)
df.unpersist()
return spark_xshards


def to_pandas(columns, squeeze=False, index_col=None, dtype=None, index_map=None,
batch_size=None):
def postprocess(pd_df):
if dtype is not None:
if isinstance(dtype, dict):
for col, type in dtype.items():
if isinstance(col, str):
if col not in pd_df.columns:
invalidInputError(False,
"column to be set type is not"
" in current dataframe")
pd_df[col] = pd_df[col].astype(type)
elif isinstance(col, int):
if index_map[col] not in pd_df.columns:
invalidInputError(False,
"column index to be set type is not"
" in current dataframe")
pd_df[index_map[col]] = pd_df[index_map[col]].astype(type)
else:
pd_df = pd_df.astype(dtype)
if squeeze and len(pd_df.columns) == 1:
pd_df = pd_df.iloc[:, 0]
if index_col:
pd_df = pd_df.set_index(index_col)
return pd_df
def set_pandas_df_type_index(pd_df, squeeze=False, index_col=None, dtype=None, index_map=None):
if dtype is not None:
if isinstance(dtype, dict):
for col, type in dtype.items():
if isinstance(col, str):
if col not in pd_df.columns:
invalidInputError(False,
"column to be set type is not"
" in current dataframe")
pd_df[col] = pd_df[col].astype(type)
elif isinstance(col, int):
if index_map[col] not in pd_df.columns:
invalidInputError(False,
"column index to be set type is not"
" in current dataframe")
pd_df[index_map[col]] = pd_df[index_map[col]].astype(type)
else:
pd_df = pd_df.astype(dtype)
if squeeze and len(pd_df.columns) == 1:
pd_df = pd_df.iloc[:, 0]
if index_col:
pd_df = pd_df.set_index(index_col)
return pd_df


def to_pandas_without_arrow(columns, squeeze=False, index_col=None, dtype=None, index_map=None,
batch_size=None):
def f(iter):
import pandas as pd
counter = 0
@@ -437,17 +448,37 @@ def f(iter):
data.append(row)
if batch_size and counter % batch_size == 0:
pd_df = pd.DataFrame(data, columns=columns)
pd_df = postprocess(pd_df)
pd_df = set_pandas_df_type_index(pd_df, squeeze, index_col, dtype, index_map)
yield pd_df
data = []
if data:
pd_df = pd.DataFrame(data, columns=columns)
pd_df = postprocess(pd_df)
pd_df = set_pandas_df_type_index(pd_df, squeeze, index_col, dtype, index_map)
yield pd_df

return f


def to_pandas(df, squeeze=False, index_col=None, dtype=None, index_map=None, batch_size=None):
def farrow(iter):
for fileName in iter:
from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer
ser = ArrowStreamPandasSerializer(timezone, False, True)
with open(fileName, "rb") as stream:
t = ser.load_stream(stream)
pd_df = pd.concat(next(t), axis=1)
pd_df = set_pandas_df_type_index(pd_df, squeeze, index_col, dtype, index_map)
yield pd_df

sqlContext = get_spark_sql_context(get_spark_context())
timezone = sqlContext._conf.sessionLocalTimeZone()

batch_size = -1 if not batch_size else batch_size
rdd_file = callZooFunc("float", "sparkdfTopdf", df._jdf, sqlContext, batch_size)
pd_rdd = rdd_file.mapPartitions(farrow)
return pd_rdd


def spark_xshards_to_ray_dataset(spark_xshards):
from bigdl.orca.data.ray_xshards import RayXShards
import ray
Expand All @@ -460,7 +491,6 @@ def spark_xshards_to_ray_dataset(spark_xshards):


def generate_string_idx(df, columns, freq_limit, order_by_freq):
from bigdl.dllib.utils.file_utils import callZooFunc
return callZooFunc("float", "generateStringIdx", df, columns, freq_limit, order_by_freq)


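As a side note, the per-partition read-back that the farrow() helper above performs through Spark's ArrowStreamPandasSerializer can be illustrated with plain pyarrow. The standalone sketch below is not part of the commit; it assumes a file written in the Arrow IPC stream format, which is what ArrowStreamWriter produces in sparkdfTopdf on the Scala side.

import pyarrow as pa

def arrow_stream_file_to_pandas(file_name):
    # The stream reader consumes the same IPC stream format that
    # ArrowStreamWriter writes, one record batch per Spark batch.
    with pa.OSFile(file_name, "rb") as source:
        reader = pa.ipc.open_stream(source)
        table = reader.read_all()  # concatenate all record batches
    return table.to_pandas()
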
12 changes: 11 additions & 1 deletion python/orca/test/bigdl/orca/data/test_spark_backend.py
@@ -25,6 +25,7 @@
from bigdl.orca import OrcaContext
from bigdl.dllib.nncontext import *
from bigdl.orca.data.image import write_tfrecord, read_tfrecord
from bigdl.orca.data.utils import *


class TestSparkBackend(TestCase):
@@ -77,7 +78,7 @@ def test_dtype(self):
df = data[0]
assert df.location.dtype == "float64"
assert df.ID.dtype == "float64"
data_shard = bigdl.orca.data.pandas.read_csv(file_path, dtype={"sale_price": np.float32})
data_shard = bigdl.orca.data.pandas.read_csv(file_path, dtype={"sale_price": np.float32, "ID": np.int64})
data = data_shard.collect()
df2 = data[0]
assert df2.sale_price.dtype == "float32" and df2.ID.dtype == "int64"
@@ -221,6 +222,15 @@ def test_to_spark_df(self):
df = data_shard.to_spark_df()
df.show()

def test_spark_df_to_shards(self):
file_path = os.path.join(self.resource_path, "orca/data/csv")
from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[1]")\
.appName('test_spark_backend')\
.config("spark.driver.memory", "6g").getOrCreate()
df = spark.read.csv(file_path)
data_shards = spark_df_to_pd_sparkxshards(df)


if __name__ == "__main__":
pytest.main([__file__])
32 changes: 32 additions & 0 deletions (new file)
@@ -0,0 +1,32 @@
/*
* Copyright 2016 The BigDL Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import org.apache.spark.api.java.JavaRDD
import org.apache.spark.rdd.RDD


class OrcaArrowUtils() {
def orcaToDataFrame(jrdd: JavaRDD[String], schemaString: String,
sqlContext: SQLContext): DataFrame = {
null.asInstanceOf[DataFrame]
}

def sparkdfTopdf(sdf: DataFrame, sqlContext: SQLContext, batchSize: Int = -1): RDD[String] = {
null.asInstanceOf[RDD[String]]
}
}
119 changes: 119 additions & 0 deletions (new file)
@@ -0,0 +1,119 @@
/*
* Copyright 2016 The BigDL Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import java.io.{DataOutputStream, FileInputStream, FileOutputStream}

import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamWriter
import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.api.python.PythonSQLUtils
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowWriter}
import org.apache.spark.sql.execution.python.BatchIterator
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.util.Utils

import org.apache.spark.util.{ShutdownHookManager, Utils}

import java.io._


class OrcaArrowUtils() {
def orcaToDataFrame(jrdd: JavaRDD[String], schemaString: String,
sqlContext: SQLContext): DataFrame = {
val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone
val rdd = jrdd.rdd.mapPartitions { iter =>
val context = TaskContext.get()
val file = iter.next()
val dir = new File(file)
ShutdownHookManager.registerShutdownDeleteDir(dir)

Utils.tryWithResource(new FileInputStream(file)) { fileStream =>
// Create array to consume iterator so that we can safely close the file
val batches = ArrowConverters.getBatchesFromStream(fileStream.getChannel)
ArrowConverters.fromBatchIterator(batches,
DataType.fromJson(schemaString).asInstanceOf[StructType], timeZoneId, context)
}
}
sqlContext.internalCreateDataFrame(rdd.setName("arrow"), schema)
}

// Below code is adapted from spark, https://github.com/apache/spark/blob/branch-3.1/sql/core/
// src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
def sparkdfTopdf(sdf: DataFrame, sqlContext: SQLContext, batchSize: Int = -1): RDD[String] = {
val schemaCaptured = sdf.schema
val maxRecordsPerBatch = if (batchSize == -1) {
sqlContext.sessionState.conf.arrowMaxRecordsPerBatch
} else batchSize
val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone


val schema = sdf.schema
sdf.rdd.mapPartitions {iter =>
val batchIter = if (maxRecordsPerBatch > 0) {
new BatchIterator(iter, maxRecordsPerBatch)
} else Iterator(iter)

val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
val allocator = ArrowUtils.rootAllocator.newChildAllocator("ItertoFile", 0, Long.MaxValue)
val root = VectorSchemaRoot.create(arrowSchema, allocator)

val conf = SparkEnv.get.conf
val sparkFilesDir =
Utils.createTempDir(Utils.getLocalDir(conf), "arrowCommunicate").getAbsolutePath

val filename = sparkFilesDir + "/arrowdata"
val fos = new FileOutputStream(filename)
val dataOutput = new DataOutputStream(fos)

Utils.tryWithSafeFinally {
val arrowWriter = ArrowWriter.create(root)
val writer = new ArrowStreamWriter(root, null, dataOutput)
writer.start()

while (batchIter.hasNext) {
val nextBatch = batchIter.next()

while (nextBatch.hasNext) {
val nxtIternalRow = CatalystTypeConverters.convertToCatalyst(nextBatch.next())
arrowWriter.write(nxtIternalRow.asInstanceOf[InternalRow])
}

arrowWriter.finish()
writer.writeBatch()
arrowWriter.reset()
}
writer.end()
} {
root.close()
allocator.close()
if (dataOutput != null) {
dataOutput.close()
}
if (fos != null) {
fos.close()
}
}
Iterator(filename)
}
}
}
(diff for the seventh changed file did not load)
