[NSE-979] Support reading parquet with case sensitivity (#980)
* support reading parquet with case sensitivity

* support filter pushdown with case sensitivity

* add unit test

* change code style

* fix mvn compile failure
jackylee-ch authored Jun 27, 2022
1 parent 174aad0 commit b28ec12
Showing 6 changed files with 131 additions and 33 deletions.
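In effect, when spark.sql.caseSensitive is false (Spark's default), the arrow data source now resolves read-schema column names and pushed-down filter attributes against the parquet file's columns case-insensitively. A minimal usage sketch, modeled on the unit test added in this commit (the path is illustrative, and the .arrow read/write helpers are assumed to come from this project's implicit extensions):

```scala
import org.apache.spark.sql.internal.SQLConf

// The physical column in the written file is capitalized: "Id".
spark.range(0, 100)
  .withColumnRenamed("id", "Id")
  .write
  .mode("overwrite")
  .arrow("/tmp/case_insensitive_demo")   // illustrative path

// With case-insensitive analysis, the lower-case "id" in both the read
// schema and the filter resolves to the file's "Id" column.
spark.conf.set(SQLConf.CASE_SENSITIVE.key, "false")
val df = spark.read
  .schema("id long")
  .arrow("/tmp/case_insensitive_demo")
  .filter("id <= 2")
// Expected result: Row(0), Row(1), Row(2)
```

With spark.sql.caseSensitive set to true, the read schema has to use the file's exact casing ("Id"), which is what the other branch of the new test covers.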
@@ -20,10 +20,10 @@ package org.apache.spark.sql.execution.datasources.v2.arrow
import java.util.Objects
import java.util.TimeZone

import org.apache.arrow.vector.types.pojo.Schema
import org.apache.arrow.vector.types.pojo.{Field, Schema}

import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.util.ArrowUtils

object SparkSchemaUtils {
@@ -36,6 +36,11 @@ object SparkSchemaUtils {
ArrowUtils.toArrowSchema(schema, timeZoneId)
}

def toArrowField(
name: String, dt: DataType, nullable: Boolean, timeZoneId: String): Field = {
ArrowUtils.toArrowField(name, dt, nullable, timeZoneId)
}

@deprecated // experimental
def getGandivaCompatibleTimeZoneID(): String = {
val zone = SQLConf.get.sessionLocalTimeZone
@@ -20,13 +20,15 @@ package com.intel.oap.spark.sql.execution.datasources.arrow
import java.net.URLDecoder

import scala.collection.JavaConverters._
import scala.collection.mutable

import com.intel.oap.spark.sql.ArrowWriteExtension.FakeRow
import com.intel.oap.spark.sql.ArrowWriteQueue
import com.intel.oap.spark.sql.execution.datasources.v2.arrow.{ArrowFilters, ArrowOptions, ArrowUtils}
import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowSQLConf._
import com.intel.oap.vectorized.ArrowWritableColumnVector
import org.apache.arrow.dataset.scanner.ScanOptions
import org.apache.arrow.vector.types.pojo.Schema
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.mapreduce.Job
@@ -117,6 +119,7 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Serializab
val sqlConf = sparkSession.sessionState.conf;
val batchSize = sqlConf.parquetVectorizedReaderBatchSize
val enableFilterPushDown = sqlConf.arrowFilterPushDown
val caseSensitive = sqlConf.caseSensitiveAnalysis

(file: PartitionedFile) => {
val factory = ArrowUtils.makeArrowDiscovery(
@@ -126,16 +129,34 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Serializab
options.asJava).asScala.toMap))

// todo predicate validation / pushdown
val dataset = factory.finish(ArrowUtils.toArrowSchema(requiredSchema));
val parquetFileFields = factory.inspect().getFields.asScala
val caseInsensitiveFieldMap = mutable.Map[String, String]()
val requiredFields = if (sqlConf.caseSensitiveAnalysis) {
new Schema(requiredSchema.map { field =>
parquetFileFields.find(_.getName.equals(field.name))
.getOrElse(ArrowUtils.toArrowField(field))
}.asJava)
} else {
new Schema(requiredSchema.map { readField =>
parquetFileFields.find(_.getName.equalsIgnoreCase(readField.name))
.map{ field =>
caseInsensitiveFieldMap += (readField.name -> field.getName)
field
}.getOrElse(ArrowUtils.toArrowField(readField))
}.asJava)
}
val dataset = factory.finish(requiredFields)

val filter = if (enableFilterPushDown) {
ArrowFilters.translateFilters(filters)
ArrowFilters.translateFilters(filters, caseInsensitiveFieldMap.toMap)
} else {
org.apache.arrow.dataset.filter.Filter.EMPTY
}

val scanOptions = new ScanOptions(requiredSchema.map(f => f.name).toArray,
filter, batchSize)
val scanOptions = new ScanOptions(
requiredFields.getFields.asScala.map(f => f.getName).toArray,
filter,
batchSize)
val scanner = dataset.newScan(scanOptions)

val taskList = scanner
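Before this change, ArrowFileFormat built the dataset schema directly from the Spark read schema, so a column stored as "Id" in the file could not be found when the query referred to it as "id". The hunks above replace that with a field-resolution step; the standalone sketch below (helper and parameter names are illustrative, not the project's exact code) shows the idea: match each requested field against the file's fields either exactly or ignoring case, record the Spark-name to physical-name mapping for filter pushdown, and fall back to converting the Spark field when the file has no match.

```scala
import scala.collection.JavaConverters._
import scala.collection.mutable

import org.apache.arrow.vector.types.pojo.{Field, Schema}
import org.apache.spark.sql.types.{StructField, StructType}

// Illustrative consolidation of the logic added above; `toArrowField` stands in
// for the ArrowUtils.toArrowField helper introduced by this commit.
def resolveRequiredFields(
    requiredSchema: StructType,
    parquetFileFields: Seq[Field],
    caseSensitive: Boolean,
    toArrowField: StructField => Field): (Schema, Map[String, String]) = {
  val caseInsensitiveFieldMap = mutable.Map[String, String]()
  val fields = requiredSchema.map { readField =>
    val matched =
      if (caseSensitive) parquetFileFields.find(_.getName.equals(readField.name))
      else parquetFileFields.find(_.getName.equalsIgnoreCase(readField.name))
    matched.map { field =>
      if (!caseSensitive) {
        // Remember which physical column the Spark attribute resolved to, so
        // pushed-down filters can be rewritten against the file's casing.
        caseInsensitiveFieldMap += (readField.name -> field.getName)
      }
      field
    }.getOrElse {
      // No matching column in the file: keep the requested field, mirroring the
      // getOrElse(ArrowUtils.toArrowField(...)) fallback above.
      toArrowField(readField)
    }
  }
  (new Schema(fields.asJava), caseInsensitiveFieldMap.toMap)
}
```

Both ArrowFileFormat and ArrowPartitionReaderFactory repeat this pattern; in case-sensitive mode the map stays empty, so filter translation behaves as before.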
@@ -20,6 +20,7 @@ package com.intel.oap.spark.sql.execution.datasources.v2.arrow
import org.apache.arrow.dataset.DatasetTypes
import org.apache.arrow.dataset.DatasetTypes.TreeNode
import org.apache.arrow.dataset.filter.FilterImpl
import org.apache.arrow.vector.types.pojo.Field

import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
@@ -56,9 +57,11 @@ object ArrowFilters {
false
}

def translateFilters(pushedFilters: Seq[Filter]): org.apache.arrow.dataset.filter.Filter = {
def translateFilters(
pushedFilters: Seq[Filter],
caseInsensitiveFieldMap: Map[String, String]): org.apache.arrow.dataset.filter.Filter = {
val node = pushedFilters
.flatMap(translateFilter)
.flatMap(filter => translateFilter(filter, caseInsensitiveFieldMap))
.reduceOption((t1: TreeNode, t2: TreeNode) => {
DatasetTypes.TreeNode.newBuilder.setAndNode(
DatasetTypes.AndNode.newBuilder()
@@ -100,28 +103,35 @@ object ArrowFilters {
}
}

private def translateFilter(pushedFilter: Filter): Option[TreeNode] = {
private def translateFilter(
pushedFilter: Filter,
caseInsensitiveFieldMap: Map[String, String]): Option[TreeNode] = {
pushedFilter match {
case EqualTo(attribute, value) =>
createComparisonNode("equal", attribute, value)
createComparisonNode(
"equal", caseInsensitiveFieldMap.getOrElse(attribute, attribute), value)
case GreaterThan(attribute, value) =>
createComparisonNode("greater", attribute, value)
createComparisonNode(
"greater", caseInsensitiveFieldMap.getOrElse(attribute, attribute), value)
case GreaterThanOrEqual(attribute, value) =>
createComparisonNode("greater_equal", attribute, value)
createComparisonNode(
"greater_equal", caseInsensitiveFieldMap.getOrElse(attribute, attribute), value)
case LessThan(attribute, value) =>
createComparisonNode("less", attribute, value)
createComparisonNode(
"less", caseInsensitiveFieldMap.getOrElse(attribute, attribute), value)
case LessThanOrEqual(attribute, value) =>
createComparisonNode("less_equal", attribute, value)
createComparisonNode(
"less_equal", caseInsensitiveFieldMap.getOrElse(attribute, attribute), value)
case Not(child) =>
createNotNode(child)
createNotNode(child, caseInsensitiveFieldMap)
case And(left, right) =>
createAndNode(left, right)
createAndNode(left, right, caseInsensitiveFieldMap)
case Or(left, right) =>
createOrNode(left, right)
createOrNode(left, right, caseInsensitiveFieldMap)
case IsNotNull(attribute) =>
createIsNotNullNode(attribute)
createIsNotNullNode(caseInsensitiveFieldMap.getOrElse(attribute, attribute))
case IsNull(attribute) =>
createIsNullNode(attribute)
createIsNullNode(caseInsensitiveFieldMap.getOrElse(attribute, attribute))
case _ => None // fixme complete this
}
}
@@ -145,8 +155,10 @@ object ArrowFilters {
}
}

def createNotNode(child: Filter): Option[TreeNode] = {
val translatedChild = translateFilter(child)
def createNotNode(
child: Filter,
caseInsensitiveFieldMap: Map[String, String]): Option[TreeNode] = {
val translatedChild = translateFilter(child, caseInsensitiveFieldMap)
if (translatedChild.isEmpty) {
return None
}
@@ -176,9 +188,12 @@
.build()).build()).build()).build())
}

def createAndNode(left: Filter, right: Filter): Option[TreeNode] = {
val translatedLeft = translateFilter(left)
val translatedRight = translateFilter(right)
def createAndNode(
left: Filter,
right: Filter,
caseInsensitiveFieldMap: Map[String, String]): Option[TreeNode] = {
val translatedLeft = translateFilter(left, caseInsensitiveFieldMap)
val translatedRight = translateFilter(right, caseInsensitiveFieldMap)
if (translatedLeft.isEmpty || translatedRight.isEmpty) {
return None
}
@@ -190,9 +205,12 @@
.build())
}

def createOrNode(left: Filter, right: Filter): Option[TreeNode] = {
val translatedLeft = translateFilter(left)
val translatedRight = translateFilter(right)
def createOrNode(
left: Filter,
right: Filter,
caseInsensitiveFieldMap: Map[String, String]): Option[TreeNode] = {
val translatedLeft = translateFilter(left, caseInsensitiveFieldMap)
val translatedRight = translateFilter(right, caseInsensitiveFieldMap)
if (translatedLeft.isEmpty || translatedRight.isEmpty) {
return None
}
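The extra caseInsensitiveFieldMap parameter threaded through translateFilters and the create*Node helpers above exists so that pushed-down predicates refer to the file's physical column names rather than the analyzed attribute names. A hedged illustration of the call (the filter values are made up):

```scala
import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowFilters
import org.apache.spark.sql.sources.{Filter, IsNotNull, LessThanOrEqual}

// Spark analyzed the query case-insensitively, so the pushed filters use "id",
// while the parquet file stores the column as "Id".
val pushedFilters: Seq[Filter] = Seq(IsNotNull("id"), LessThanOrEqual("id", 2L))
val caseInsensitiveFieldMap = Map("id" -> "Id")

// Each attribute is looked up with getOrElse, so names that already match the
// file (or a case-sensitive session, which passes an empty map) are unchanged.
val arrowFilter = ArrowFilters.translateFilters(pushedFilters, caseInsensitiveFieldMap)
```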
@@ -19,10 +19,12 @@ package com.intel.oap.spark.sql.execution.datasources.v2.arrow
import java.net.URLDecoder

import scala.collection.JavaConverters._
import scala.collection.mutable

import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowPartitionReaderFactory.ColumnarBatchRetainer
import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowSQLConf._
import org.apache.arrow.dataset.scanner.ScanOptions
import org.apache.arrow.vector.types.pojo.Schema

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.catalyst.InternalRow
@@ -59,9 +61,27 @@ case class ArrowPartitionReaderFactory(
val path = partitionedFile.filePath
val factory = ArrowUtils.makeArrowDiscovery(URLDecoder.decode(path, "UTF-8"),
partitionedFile.start, partitionedFile.length, options)
val dataset = factory.finish(ArrowUtils.toArrowSchema(readDataSchema))
val parquetFileFields = factory.inspect().getFields.asScala
val caseInsensitiveFieldMap = mutable.Map[String, String]()
val requiredFields = if (sqlConf.caseSensitiveAnalysis) {
new Schema(readDataSchema.map { field =>
parquetFileFields.find(_.getName.equals(field.name))
.getOrElse(ArrowUtils.toArrowField(field))
}.asJava)
} else {
new Schema(readDataSchema.map { readField =>
parquetFileFields.find(_.getName.equalsIgnoreCase(readField.name))
.map{ field =>
caseInsensitiveFieldMap += (readField.name -> field.getName)
field
}.getOrElse(ArrowUtils.toArrowField(readField))
}.asJava)
}
val dataset = factory.finish(requiredFields)
val filter = if (enableFilterPushDown) {
ArrowFilters.translateFilters(ArrowFilters.pruneWithSchema(pushedFilters, readDataSchema))
ArrowFilters.translateFilters(
ArrowFilters.pruneWithSchema(pushedFilters, readDataSchema),
caseInsensitiveFieldMap.toMap)
} else {
org.apache.arrow.dataset.filter.Filter.EMPTY
}
@@ -26,18 +26,17 @@ import scala.collection.JavaConverters._
import com.intel.oap.vectorized.{ArrowColumnVectorUtils, ArrowWritableColumnVector}
import org.apache.arrow.dataset.file.FileSystemDatasetFactory
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch
import org.apache.arrow.vector.types.pojo.Schema
import org.apache.arrow.vector.types.pojo.{Field, Schema}
import org.apache.hadoop.fs.FileStatus

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.datasources.v2.arrow.{SparkMemoryUtils, SparkSchemaUtils}
import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.sql.vectorized.ColumnVector
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}

object ArrowUtils {

@@ -89,6 +88,11 @@ object ArrowUtils {
SparkSchemaUtils.toArrowSchema(t, SparkSchemaUtils.getLocalTimezoneID())
}

def toArrowField(t: StructField): Field = {
SparkSchemaUtils.toArrowField(
t.name, t.dataType, t.nullable, SparkSchemaUtils.getLocalTimezoneID())
}

def loadBatch(input: ArrowRecordBatch, partitionValues: InternalRow,
partitionSchema: StructType, dataSchema: StructType): ColumnarBatch = {
val rowCount: Int = input.getLength
@@ -286,6 +286,36 @@ class ArrowDataSourceTest extends QueryTest with SharedSparkSession {

}

test("read and write with case sensitive or insensitive") {
val caseSensitiveAnalysisEnabled = Seq[Boolean](true, false)
val v1SourceList = Seq[String]("", "arrow")
caseSensitiveAnalysisEnabled.foreach{ caseSensitiveAnalysis =>
v1SourceList.foreach{v1Source =>
withSQLConf(
SQLConf.CASE_SENSITIVE.key -> caseSensitiveAnalysis.toString,
SQLConf.USE_V1_SOURCE_LIST.key -> v1Source) {
withTempPath { tempPath =>
spark.range(0, 100)
.withColumnRenamed("id", "Id")
.write
.mode("overwrite")
.arrow(tempPath.getPath)
val selectColName = if (caseSensitiveAnalysis) {
"Id"
} else {
"id"
}
val df = spark.read
.schema(s"$selectColName long")
.arrow(tempPath.getPath)
.filter(s"$selectColName <= 2")
checkAnswer(df, Row(0) :: Row(1) :: Row(2) :: Nil)
}
}
}
}
}

test("file descriptor leak") {
val path = ArrowDataSourceTest.locateResourcePath(parquetFile1)
val frame = spark.read
