diff --git a/arrow-data-source/common/src/main/scala/org/apache/spark/sql/execution/datasources/v2/arrow/SparkSchemaUtils.scala b/arrow-data-source/common/src/main/scala/org/apache/spark/sql/execution/datasources/v2/arrow/SparkSchemaUtils.scala
index f7ba1029d..efbfd9261 100644
--- a/arrow-data-source/common/src/main/scala/org/apache/spark/sql/execution/datasources/v2/arrow/SparkSchemaUtils.scala
+++ b/arrow-data-source/common/src/main/scala/org/apache/spark/sql/execution/datasources/v2/arrow/SparkSchemaUtils.scala
@@ -20,10 +20,10 @@ package org.apache.spark.sql.execution.datasources.v2.arrow
 import java.util.Objects
 import java.util.TimeZone
 
-import org.apache.arrow.vector.types.pojo.Schema
+import org.apache.arrow.vector.types.pojo.{Field, Schema}
 
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.sql.util.ArrowUtils
 
 object SparkSchemaUtils {
@@ -36,6 +36,11 @@ object SparkSchemaUtils {
     ArrowUtils.toArrowSchema(schema, timeZoneId)
   }
 
+  def toArrowField(
+      name: String, dt: DataType, nullable: Boolean, timeZoneId: String): Field = {
+    ArrowUtils.toArrowField(name, dt, nullable, timeZoneId)
+  }
+
   @deprecated // experimental
   def getGandivaCompatibleTimeZoneID(): String = {
     val zone = SQLConf.get.sessionLocalTimeZone
diff --git a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala
index 1f172b043..155ecc21c 100644
--- a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala
+++ b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala
@@ -20,6 +20,7 @@ package com.intel.oap.spark.sql.execution.datasources.arrow
 import java.net.URLDecoder
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 
 import com.intel.oap.spark.sql.ArrowWriteExtension.FakeRow
 import com.intel.oap.spark.sql.ArrowWriteQueue
@@ -27,6 +28,7 @@ import com.intel.oap.spark.sql.execution.datasources.v2.arrow.{ArrowFilters, Arr
 import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowSQLConf._
 import com.intel.oap.vectorized.ArrowWritableColumnVector
 import org.apache.arrow.dataset.scanner.ScanOptions
+import org.apache.arrow.vector.types.pojo.Schema
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.mapreduce.Job
@@ -117,6 +119,7 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Serializab
     val sqlConf = sparkSession.sessionState.conf;
     val batchSize = sqlConf.parquetVectorizedReaderBatchSize
     val enableFilterPushDown = sqlConf.arrowFilterPushDown
+    val caseSensitive = sqlConf.caseSensitiveAnalysis
 
     (file: PartitionedFile) => {
       val factory = ArrowUtils.makeArrowDiscovery(
@@ -126,16 +129,34 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Serializab
           options.asJava).asScala.toMap))
 
       // todo predicate validation / pushdown
-      val dataset = factory.finish(ArrowUtils.toArrowSchema(requiredSchema));
+      val parquetFileFields = factory.inspect().getFields.asScala
+      val caseInsensitiveFieldMap = mutable.Map[String, String]()
+      val requiredFields = if (caseSensitive) {
+        new Schema(requiredSchema.map { field =>
+          parquetFileFields.find(_.getName.equals(field.name))
+            .getOrElse(ArrowUtils.toArrowField(field))
+        }.asJava)
+      } else {
+        new Schema(requiredSchema.map { readField =>
+          parquetFileFields.find(_.getName.equalsIgnoreCase(readField.name))
+            .map { field =>
+              caseInsensitiveFieldMap += (readField.name -> field.getName)
+              field
+            }.getOrElse(ArrowUtils.toArrowField(readField))
+        }.asJava)
+      }
+      val dataset = factory.finish(requiredFields)
       val filter = if (enableFilterPushDown) {
-        ArrowFilters.translateFilters(filters)
+        ArrowFilters.translateFilters(filters, caseInsensitiveFieldMap.toMap)
       } else {
         org.apache.arrow.dataset.filter.Filter.EMPTY
       }
 
-      val scanOptions = new ScanOptions(requiredSchema.map(f => f.name).toArray,
-        filter, batchSize)
+      val scanOptions = new ScanOptions(
+        requiredFields.getFields.asScala.map(f => f.getName).toArray,
+        filter,
+        batchSize)
       val scanner = dataset.newScan(scanOptions)
 
       val taskList = scanner
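Note (illustrative, not part of the patch): the schema resolution added above matches each required Spark column against the Parquet file's fields, records a read-name -> file-name mapping when analysis is case insensitive, and falls back to converting the Spark field when no match exists. The sketch below mirrors that rule with plain strings instead of Arrow Field/Schema so it runs without the Arrow dataset jars; all field names here are made up.

// Standalone sketch of the resolution rule above, using plain strings.
object CaseResolutionSketch {
  def main(args: Array[String]): Unit = {
    val parquetFileFields = Seq("Id", "Name")        // spellings found in the file
    val requiredSchema = Seq("id", "name", "extra")  // columns requested by Spark
    val caseSensitive = false

    val caseInsensitiveFieldMap = scala.collection.mutable.Map[String, String]()
    val resolved = requiredSchema.map { readField =>
      val matched =
        if (caseSensitive) parquetFileFields.find(_.equals(readField))
        else parquetFileFields.find(_.equalsIgnoreCase(readField))
      matched match {
        case Some(fileField) =>
          // Scan with the file's spelling; remember how to translate filters later.
          caseInsensitiveFieldMap += (readField -> fileField)
          fileField
        case None =>
          // Column missing from the file: keep the requested name (the patch
          // converts the Spark field via ArrowUtils.toArrowField instead).
          readField
      }
    }

    println(resolved)                // List(Id, Name, extra)
    println(caseInsensitiveFieldMap) // e.g. Map(id -> Id, name -> Name)
  }
}

The scan therefore uses the file's spelling of each column, while the map remembers how to rewrite pushed-down filter attributes back to that spelling.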
diff --git a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowFilters.scala b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowFilters.scala
index f33c7995a..0bcfd3812 100644
--- a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowFilters.scala
+++ b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowFilters.scala
@@ -20,6 +20,7 @@ package com.intel.oap.spark.sql.execution.datasources.v2.arrow
 import org.apache.arrow.dataset.DatasetTypes
 import org.apache.arrow.dataset.DatasetTypes.TreeNode
 import org.apache.arrow.dataset.filter.FilterImpl
+import org.apache.arrow.vector.types.pojo.Field
 
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.StructType
@@ -56,9 +57,11 @@ object ArrowFilters {
     false
   }
 
-  def translateFilters(pushedFilters: Seq[Filter]): org.apache.arrow.dataset.filter.Filter = {
+  def translateFilters(
+      pushedFilters: Seq[Filter],
+      caseInsensitiveFieldMap: Map[String, String]): org.apache.arrow.dataset.filter.Filter = {
     val node = pushedFilters
-      .flatMap(translateFilter)
+      .flatMap(filter => translateFilter(filter, caseInsensitiveFieldMap))
      .reduceOption((t1: TreeNode, t2: TreeNode) => {
         DatasetTypes.TreeNode.newBuilder.setAndNode(
           DatasetTypes.AndNode.newBuilder()
@@ -100,28 +103,35 @@ object ArrowFilters {
     }
   }
 
-  private def translateFilter(pushedFilter: Filter): Option[TreeNode] = {
+  private def translateFilter(
+      pushedFilter: Filter,
+      caseInsensitiveFieldMap: Map[String, String]): Option[TreeNode] = {
     pushedFilter match {
       case EqualTo(attribute, value) =>
-        createComparisonNode("equal", attribute, value)
+        createComparisonNode(
+          "equal", caseInsensitiveFieldMap.getOrElse(attribute, attribute), value)
       case GreaterThan(attribute, value) =>
-        createComparisonNode("greater", attribute, value)
+        createComparisonNode(
+          "greater", caseInsensitiveFieldMap.getOrElse(attribute, attribute), value)
       case GreaterThanOrEqual(attribute, value) =>
-        createComparisonNode("greater_equal", attribute, value)
+        createComparisonNode(
+          "greater_equal", caseInsensitiveFieldMap.getOrElse(attribute, attribute), value)
       case LessThan(attribute, value) =>
-        createComparisonNode("less", attribute, value)
+        createComparisonNode(
+          "less", caseInsensitiveFieldMap.getOrElse(attribute, attribute), value)
       case LessThanOrEqual(attribute, value) =>
-        createComparisonNode("less_equal", attribute, value)
+        createComparisonNode(
+          "less_equal", caseInsensitiveFieldMap.getOrElse(attribute, attribute), value)
       case Not(child) =>
-        createNotNode(child)
+        createNotNode(child, caseInsensitiveFieldMap)
       case And(left, right) =>
-        createAndNode(left, right)
+        createAndNode(left, right, caseInsensitiveFieldMap)
       case Or(left, right) =>
-        createOrNode(left, right)
+        createOrNode(left, right, caseInsensitiveFieldMap)
       case IsNotNull(attribute) =>
-        createIsNotNullNode(attribute)
+        createIsNotNullNode(caseInsensitiveFieldMap.getOrElse(attribute, attribute))
       case IsNull(attribute) =>
-        createIsNullNode(attribute)
+        createIsNullNode(caseInsensitiveFieldMap.getOrElse(attribute, attribute))
       case _ => None // fixme complete this
     }
   }
@@ -145,8 +155,10 @@ object ArrowFilters {
     }
   }
 
-  def createNotNode(child: Filter): Option[TreeNode] = {
-    val translatedChild = translateFilter(child)
+  def createNotNode(
+      child: Filter,
+      caseInsensitiveFieldMap: Map[String, String]): Option[TreeNode] = {
+    val translatedChild = translateFilter(child, caseInsensitiveFieldMap)
     if (translatedChild.isEmpty) {
       return None
     }
@@ -176,9 +188,12 @@ object ArrowFilters {
           .build()).build()).build()).build())
   }
 
-  def createAndNode(left: Filter, right: Filter): Option[TreeNode] = {
-    val translatedLeft = translateFilter(left)
-    val translatedRight = translateFilter(right)
+  def createAndNode(
+      left: Filter,
+      right: Filter,
+      caseInsensitiveFieldMap: Map[String, String]): Option[TreeNode] = {
+    val translatedLeft = translateFilter(left, caseInsensitiveFieldMap)
+    val translatedRight = translateFilter(right, caseInsensitiveFieldMap)
     if (translatedLeft.isEmpty || translatedRight.isEmpty) {
       return None
     }
@@ -190,9 +205,12 @@ object ArrowFilters {
       .build())
   }
 
-  def createOrNode(left: Filter, right: Filter): Option[TreeNode] = {
-    val translatedLeft = translateFilter(left)
-    val translatedRight = translateFilter(right)
+  def createOrNode(
+      left: Filter,
+      right: Filter,
+      caseInsensitiveFieldMap: Map[String, String]): Option[TreeNode] = {
+    val translatedLeft = translateFilter(left, caseInsensitiveFieldMap)
+    val translatedRight = translateFilter(right, caseInsensitiveFieldMap)
     if (translatedLeft.isEmpty || translatedRight.isEmpty) {
       return None
     }
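Note (illustrative, not part of the patch): how a caller is expected to use the new translateFilters signature. Pushed filters keep referencing the Spark-side attribute names; the map built during schema resolution rewrites them to the file's spelling via getOrElse(attribute, attribute) before the Arrow filter tree is assembled. The map contents below are made up, and the arrow-data-source jars must be on the classpath.

// Usage sketch for the new two-argument translateFilters.
import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowFilters
import org.apache.spark.sql.sources.{EqualTo, Filter, IsNotNull}

object TranslateFiltersSketch {
  def main(args: Array[String]): Unit = {
    // Spark pushed these down using the lower-case read-schema name...
    val pushed: Seq[Filter] = Seq(IsNotNull("id"), EqualTo("id", 1))
    // ...but the Parquet file spells the column "Id".
    val fieldMap = Map("id" -> "Id")

    // Attributes missing from the map pass through unchanged, so an empty map
    // reproduces the old case-sensitive behaviour.
    val arrowFilter = ArrowFilters.translateFilters(pushed, fieldMap)
    println(arrowFilter)
  }
}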
diff --git a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowPartitionReaderFactory.scala b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowPartitionReaderFactory.scala
index 1e443a9d2..0ff3a2d56 100644
--- a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowPartitionReaderFactory.scala
+++ b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowPartitionReaderFactory.scala
@@ -19,10 +19,12 @@ package com.intel.oap.spark.sql.execution.datasources.v2.arrow
 import java.net.URLDecoder
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 
 import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowPartitionReaderFactory.ColumnarBatchRetainer
 import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowSQLConf._
 import org.apache.arrow.dataset.scanner.ScanOptions
+import org.apache.arrow.vector.types.pojo.Schema
 
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.catalyst.InternalRow
@@ -59,9 +61,27 @@ case class ArrowPartitionReaderFactory(
     val path = partitionedFile.filePath
     val factory = ArrowUtils.makeArrowDiscovery(URLDecoder.decode(path, "UTF-8"),
       partitionedFile.start, partitionedFile.length, options)
-    val dataset = factory.finish(ArrowUtils.toArrowSchema(readDataSchema))
+    val parquetFileFields = factory.inspect().getFields.asScala
+    val caseInsensitiveFieldMap = mutable.Map[String, String]()
+    val requiredFields = if (sqlConf.caseSensitiveAnalysis) {
+      new Schema(readDataSchema.map { field =>
+        parquetFileFields.find(_.getName.equals(field.name))
+          .getOrElse(ArrowUtils.toArrowField(field))
+      }.asJava)
+    } else {
+      new Schema(readDataSchema.map { readField =>
+        parquetFileFields.find(_.getName.equalsIgnoreCase(readField.name))
+          .map { field =>
+            caseInsensitiveFieldMap += (readField.name -> field.getName)
+            field
+          }.getOrElse(ArrowUtils.toArrowField(readField))
+      }.asJava)
+    }
+    val dataset = factory.finish(requiredFields)
     val filter = if (enableFilterPushDown) {
-      ArrowFilters.translateFilters(ArrowFilters.pruneWithSchema(pushedFilters, readDataSchema))
+      ArrowFilters.translateFilters(
+        ArrowFilters.pruneWithSchema(pushedFilters, readDataSchema),
+        caseInsensitiveFieldMap.toMap)
     } else {
       org.apache.arrow.dataset.filter.Filter.EMPTY
     }
diff --git a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala
index c30571b9e..5603326f3 100644
--- a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala
+++ b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala
@@ -26,7 +26,7 @@ import scala.collection.JavaConverters._
 import com.intel.oap.vectorized.{ArrowColumnVectorUtils, ArrowWritableColumnVector}
 import org.apache.arrow.dataset.file.FileSystemDatasetFactory
 import org.apache.arrow.vector.ipc.message.ArrowRecordBatch
-import org.apache.arrow.vector.types.pojo.Schema
+import org.apache.arrow.vector.types.pojo.{Field, Schema}
 import org.apache.hadoop.fs.FileStatus
 
 import org.apache.spark.sql.catalyst.InternalRow
@@ -34,10 +34,9 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.datasources.v2.arrow.{SparkMemoryUtils, SparkSchemaUtils}
 import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StructField, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
-import org.apache.spark.sql.vectorized.ColumnVector
-import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
 
 object ArrowUtils {
 
@@ -89,6 +88,11 @@ object ArrowUtils {
     SparkSchemaUtils.toArrowSchema(t, SparkSchemaUtils.getLocalTimezoneID())
   }
 
+  def toArrowField(t: StructField): Field = {
+    SparkSchemaUtils.toArrowField(
+      t.name, t.dataType, t.nullable, SparkSchemaUtils.getLocalTimezoneID())
+  }
+
   def loadBatch(input: ArrowRecordBatch, partitionValues: InternalRow,
                 partitionSchema: StructType, dataSchema: StructType): ColumnarBatch = {
     val rowCount: Int = input.getLength
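Note (illustrative, not part of the patch): the new ArrowUtils.toArrowField overload is the fallback used above when a required column is absent from the file; it converts a Spark StructField to an Arrow Field using the session-local timezone id. A minimal usage sketch, assuming the arrow-data-source jars are on the classpath:

import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowUtils
import org.apache.spark.sql.types.{LongType, StructField}

object ToArrowFieldSketch {
  def main(args: Array[String]): Unit = {
    // A Spark field for a nullable BIGINT column becomes an Arrow Field.
    val field = ArrowUtils.toArrowField(StructField("id", LongType, nullable = true))
    println(field) // e.g. "id: Int(64, true)"
  }
}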
diff --git a/arrow-data-source/standard/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTest.scala b/arrow-data-source/standard/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTest.scala
index 9896ac1b4..c6c8ebe7f 100644
--- a/arrow-data-source/standard/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTest.scala
+++ b/arrow-data-source/standard/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTest.scala
@@ -286,6 +286,36 @@ class ArrowDataSourceTest extends QueryTest with SharedSparkSession {
 
   }
 
+  test("read and write with case sensitive or insensitive") {
+    val caseSensitiveAnalysisEnabled = Seq[Boolean](true, false)
+    val v1SourceList = Seq[String]("", "arrow")
+    caseSensitiveAnalysisEnabled.foreach { caseSensitiveAnalysis =>
+      v1SourceList.foreach { v1Source =>
+        withSQLConf(
+          SQLConf.CASE_SENSITIVE.key -> caseSensitiveAnalysis.toString,
+          SQLConf.USE_V1_SOURCE_LIST.key -> v1Source) {
+          withTempPath { tempPath =>
+            spark.range(0, 100)
+              .withColumnRenamed("id", "Id")
+              .write
+              .mode("overwrite")
+              .arrow(tempPath.getPath)
+            val selectColName = if (caseSensitiveAnalysis) {
+              "Id"
+            } else {
+              "id"
+            }
+            val df = spark.read
+              .schema(s"$selectColName long")
+              .arrow(tempPath.getPath)
+              .filter(s"$selectColName <= 2")
+            checkAnswer(df, Row(0) :: Row(1) :: Row(2) :: Nil)
+          }
+        }
+      }
+    }
+  }
+
   test("file descriptor leak") {
     val path = ArrowDataSourceTest.locateResourcePath(parquetFile1)
     val frame = spark.read