diff --git a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkOrcNewRecordReader.java b/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkOrcNewRecordReader.java deleted file mode 100644 index f093637d412f9..0000000000000 --- a/sql/hive/src/main/java/org/apache/hadoop/hive/ql/io/orc/SparkOrcNewRecordReader.java +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.hadoop.hive.ql.io.orc; - -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.io.NullWritable; -import org.apache.hadoop.mapreduce.InputSplit; -import org.apache.hadoop.mapreduce.TaskAttemptContext; - -import java.io.IOException; -import java.util.List; - -/** - * This is based on hive-exec-1.2.1 - * {@link org.apache.hadoop.hive.ql.io.orc.OrcNewInputFormat.OrcRecordReader}. - * This class exposes getObjectInspector which can be used for reducing - * NameNode calls in OrcRelation. - */ -public class SparkOrcNewRecordReader extends - org.apache.hadoop.mapreduce.RecordReader { - private final org.apache.hadoop.hive.ql.io.orc.RecordReader reader; - private final int numColumns; - OrcStruct value; - private float progress = 0.0f; - private ObjectInspector objectInspector; - - public SparkOrcNewRecordReader(Reader file, Configuration conf, - long offset, long length) throws IOException { - List types = file.getTypes(); - numColumns = (types.size() == 0) ? 
0 : types.get(0).getSubtypesCount(); - value = new OrcStruct(numColumns); - this.reader = OrcInputFormat.createReaderFromFile(file, conf, offset, - length); - this.objectInspector = file.getObjectInspector(); - } - - @Override - public void close() throws IOException { - reader.close(); - } - - @Override - public NullWritable getCurrentKey() throws IOException, - InterruptedException { - return NullWritable.get(); - } - - @Override - public OrcStruct getCurrentValue() throws IOException, - InterruptedException { - return value; - } - - @Override - public float getProgress() throws IOException, InterruptedException { - return progress; - } - - @Override - public void initialize(InputSplit split, TaskAttemptContext context) - throws IOException, InterruptedException { - } - - @Override - public boolean nextKeyValue() throws IOException, InterruptedException { - if (reader.hasNext()) { - reader.next(value); - progress = reader.getProgress(); - return true; - } else { - return false; - } - } - - public ObjectInspector getObjectInspector() { - return objectInspector; - } -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index edf2013a4c936..38ac34cccf810 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -25,18 +25,22 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.ql.io.orc._ import org.apache.hadoop.hive.serde2.objectinspector.{SettableStructObjectInspector, StructObjectInspector} import org.apache.hadoop.hive.serde2.typeinfo.{StructTypeInfo, TypeInfoUtils} import org.apache.hadoop.io.{NullWritable, Writable} import org.apache.hadoop.mapred.{JobConf, OutputFormat => MapRedOutputFormat, RecordWriter, Reporter} import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit} +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl +import org.apache.orc._ +import org.apache.orc.mapred.OrcStruct +import org.apache.orc.mapreduce.{OrcInputFormat, OrcOutputFormat} import org.apache.spark.TaskContext import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.hive.{HiveInspectors, HiveShim} import org.apache.spark.sql.sources.{Filter, _} @@ -57,10 +61,10 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { - OrcFileOperator.readSchema( - files.map(_.getPath.toUri.toString), - Some(sparkSession.sessionState.newHadoopConf()) - ) + val conf = sparkSession.sparkContext.hadoopConfiguration + files.map(_.getPath).flatMap(OrcUtils.readSchema(_, conf)).headOption.map { schema => + CatalystSqlParser.parseDataType(schema.toString).asInstanceOf[StructType] + } } override def prepareWrite( @@ -72,16 +76,19 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable val configuration = job.getConfiguration - configuration.set(OrcRelation.ORC_COMPRESSION, 
orcOptions.compressionCodec) + configuration.set(OrcConf.COMPRESS.getAttribute, orcOptions.compressionCodec) + val outputFormatClass = classOf[org.apache.orc.mapred.OrcOutputFormat[OrcStruct]] configuration match { case conf: JobConf => - conf.setOutputFormat(classOf[OrcOutputFormat]) + conf.setOutputFormat(outputFormatClass) case conf => conf.setClass( "mapred.output.format.class", - classOf[OrcOutputFormat], + outputFormatClass, classOf[MapRedOutputFormat[_, _]]) } + configuration.set( + OrcConf.MAPRED_OUTPUT_SCHEMA.getAttribute, OrcUtils.getSchemaString(dataSchema)) new OutputWriterFactory { override def newInstance( @@ -93,8 +100,8 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable override def getFileExtension(context: TaskAttemptContext): String = { val compressionExtension: String = { - val name = context.getConfiguration.get(OrcRelation.ORC_COMPRESSION) - OrcRelation.extensionsForCompressionCodecNames.getOrElse(name, "") + val name = context.getConfiguration.get(OrcConf.COMPRESS.getAttribute) + OrcOptions.extensionsForCompressionCodecNames.getOrElse(name, "") } compressionExtension + ".orc" @@ -109,7 +116,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable true } - override def buildReader( + override def buildReaderWithPartitionValues( sparkSession: SparkSession, dataSchema: StructType, partitionSchema: StructType, @@ -120,10 +127,16 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable if (sparkSession.sessionState.conf.orcFilterPushDown) { // Sets pushed predicates OrcFilters.createFilter(requiredSchema, filters.toArray).foreach { f => - hadoopConf.set(OrcRelation.SARG_PUSHDOWN, f.toKryo) - hadoopConf.setBoolean(ConfVars.HIVEOPTINDEXFILTER.varname, true) + OrcInputFormat.setSearchArgument(hadoopConf, f, dataSchema.fieldNames) } } + // Column Selection: we need to set at least one column due to ORC-233 + val columns = if (requiredSchema.isEmpty) { + "0" + } else { + requiredSchema.map(f => dataSchema.fieldIndex(f.name)).mkString(",") + } + hadoopConf.set(OrcConf.INCLUDE_COLUMNS.getAttribute, columns) val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) @@ -132,183 +145,78 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable val conf = broadcastedHadoopConf.value.value // SPARK-8501: Empty ORC files always have an empty schema stored in their footer. In this - // case, `OrcFileOperator.readSchema` returns `None`, and we can't read the underlying file + // case, we can't read the underlying file // using the given physical schema. Instead, we simply return an empty iterator. - val maybePhysicalSchema = OrcFileOperator.readSchema(Seq(file.filePath), Some(conf)) - if (maybePhysicalSchema.isEmpty) { + val physicalSchema = OrcUtils.getSchema(dataSchema, file.filePath, conf) + if (physicalSchema.getFieldNames.isEmpty) { Iterator.empty } else { - val physicalSchema = maybePhysicalSchema.get - OrcRelation.setRequiredColumns(conf, physicalSchema, requiredSchema) val orcRecordReader = { - val job = Job.getInstance(conf) - FileInputFormat.setInputPaths(job, file.filePath) - val fileSplit = new FileSplit( new Path(new URI(file.filePath)), file.start, file.length, Array.empty ) - // Custom OrcRecordReader is used to get - // ObjectInspector during recordReader creation itself and can - // avoid NameNode call in unwrapOrcStructs per file. - // Specifically would be helpful for partitioned datasets.
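For context, a minimal standalone sketch of the read path this hunk adopts: org.apache.orc.mapreduce.OrcInputFormat driven by a synthetic TaskAttemptContext, using only calls that appear in this patch. It is not part of the patch; the input path is hypothetical and error handling is elided.

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.{JobID, TaskAttemptID, TaskID, TaskType}
import org.apache.hadoop.mapreduce.lib.input.FileSplit
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
import org.apache.orc.OrcFile
import org.apache.orc.mapred.OrcStruct
import org.apache.orc.mapreduce.OrcInputFormat

object OrcReadSketch {
  def main(args: Array[String]): Unit = {
    val conf = new Configuration()
    val path = new Path("/tmp/example.orc")  // hypothetical input file
    val reader = OrcFile.createReader(path, OrcFile.readerOptions(conf))
    // A single split covering the whole file; real code splits by stripe/block.
    val split = new FileSplit(path, 0, reader.getContentLength, Array.empty[String])
    val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
    val context = new TaskAttemptContextImpl(conf, attemptId)
    // OrcInputFormat opens the file inside createRecordReader, so no separate
    // initialize() call is made here, mirroring the patched buildReaderWithPartitionValues.
    val recordReader = new OrcInputFormat[OrcStruct]().createRecordReader(split, context)
    try {
      while (recordReader.nextKeyValue()) {
        val row: OrcStruct = recordReader.getCurrentValue
        println(row)  // OrcStruct.toString prints values such as "{1, spark}"
      }
    } finally {
      recordReader.close()
    }
  }
}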
- val orcReader = OrcFile.createReader( - new Path(new URI(file.filePath)), OrcFile.readerOptions(conf)) - new SparkOrcNewRecordReader(orcReader, conf, fileSplit.getStart, fileSplit.getLength) + val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) + val taskAttemptContext = new TaskAttemptContextImpl(conf, attemptId) + new OrcInputFormat[OrcStruct].createRecordReader(fileSplit, taskAttemptContext) } val recordsIterator = new RecordReaderIterator[OrcStruct](orcRecordReader) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => recordsIterator.close())) + val resultSchema = StructType(requiredSchema.fields ++ partitionSchema.fields) + val mutableRow = new SpecificInternalRow(resultSchema.map(_.dataType)) + val unsafeProjection = UnsafeProjection.create(resultSchema) + val partitionValues = file.partitionValues + for (i <- requiredSchema.length until resultSchema.length) { + val value = partitionValues.get(i - requiredSchema.length, resultSchema(i).dataType) + mutableRow.update(i, value) + } // Unwraps `OrcStruct`s to `UnsafeRow`s - OrcRelation.unwrapOrcStructs( - conf, - requiredSchema, - Some(orcRecordReader.getObjectInspector.asInstanceOf[StructObjectInspector]), - recordsIterator) + val valueWrappers = requiredSchema.fields.map(f => OrcUtils.getValueWrapper(f.dataType)) + recordsIterator.map { value => + unsafeProjection(OrcUtils.convertOrcStructToInternalRow( + value, requiredSchema, Some(valueWrappers), Some(mutableRow))) + } } } } } -private[orc] class OrcSerializer(dataSchema: StructType, conf: Configuration) - extends HiveInspectors { - - def serialize(row: InternalRow): Writable = { - wrapOrcStruct(cachedOrcStruct, structOI, row) - serializer.serialize(cachedOrcStruct, structOI) - } - - private[this] val serializer = { - val table = new Properties() - table.setProperty("columns", dataSchema.fieldNames.mkString(",")) - table.setProperty("columns.types", dataSchema.map(_.dataType.catalogString).mkString(":")) - - val serde = new OrcSerde - serde.initialize(conf, table) - serde - } - - // Object inspector converted from the schema of the relation to be serialized. 
- private[this] val structOI = { - val typeInfo = TypeInfoUtils.getTypeInfoFromTypeString(dataSchema.catalogString) - OrcStruct.createObjectInspector(typeInfo.asInstanceOf[StructTypeInfo]) - .asInstanceOf[SettableStructObjectInspector] - } - - private[this] val cachedOrcStruct = structOI.create().asInstanceOf[OrcStruct] - - // Wrapper functions used to wrap Spark SQL input arguments into Hive specific format - private[this] val wrappers = dataSchema.zip(structOI.getAllStructFieldRefs().asScala.toSeq).map { - case (f, i) => wrapperFor(i.getFieldObjectInspector, f.dataType) - } - - private[this] def wrapOrcStruct( - struct: OrcStruct, - oi: SettableStructObjectInspector, - row: InternalRow): Unit = { - val fieldRefs = oi.getAllStructFieldRefs - var i = 0 - val size = fieldRefs.size - while (i < size) { - - oi.setStructFieldData( - struct, - fieldRefs.get(i), - wrappers(i)(row.get(i, dataSchema(i).dataType)) - ) - i += 1 - } - } -} - private[orc] class OrcOutputWriter( path: String, dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter { + private lazy val orcStruct: OrcStruct = + OrcUtils.createOrcValue(dataSchema).asInstanceOf[OrcStruct] - private[this] val serializer = new OrcSerializer(dataSchema, context.getConfiguration) + private[this] val writableWrappers = + dataSchema.fields.map(f => OrcUtils.getWritableWrapper(f.dataType)) // `OrcRecordWriter.close()` creates an empty file if no rows are written at all. We use this // flag to decide whether `OrcRecordWriter.close()` needs to be called. private var recordWriterInstantiated = false - private lazy val recordWriter: RecordWriter[NullWritable, Writable] = { + private lazy val recordWriter = { recordWriterInstantiated = true - new OrcOutputFormat().getRecordWriter( - new Path(path).getFileSystem(context.getConfiguration), - context.getConfiguration.asInstanceOf[JobConf], - path, - Reporter.NULL - ).asInstanceOf[RecordWriter[NullWritable, Writable]] + new OrcOutputFormat[OrcStruct]() { + override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + new Path(path) + } + }.getRecordWriter(context) } override def write(row: InternalRow): Unit = { - recordWriter.write(NullWritable.get(), serializer.serialize(row)) + recordWriter.write( + NullWritable.get, + OrcUtils.convertInternalRowToOrcStruct( + row, dataSchema, Some(writableWrappers), Some(orcStruct))) } override def close(): Unit = { if (recordWriterInstantiated) { - recordWriter.close(Reporter.NULL) + recordWriter.close(context) } } } - -private[orc] object OrcRelation extends HiveInspectors { - // The references of Hive's classes will be minimized. - val ORC_COMPRESSION = "orc.compress" - - // This constant duplicates `OrcInputFormat.SARG_PUSHDOWN`, which is unfortunately not public. 
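Similarly, a minimal sketch of the write path assumed by the patched OrcOutputWriter above: OrcConf.MAPRED_OUTPUT_SCHEMA and OrcConf.COMPRESS are set on the Configuration, and getDefaultWorkFile is overridden so the record writer lands on a fixed file. Not part of the patch; the schema and output path are hypothetical.

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{IntWritable, NullWritable, Text}
import org.apache.hadoop.mapreduce.{JobID, TaskAttemptContext, TaskAttemptID, TaskID, TaskType}
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
import org.apache.orc.{OrcConf, TypeDescription}
import org.apache.orc.mapred.OrcStruct
import org.apache.orc.mapreduce.OrcOutputFormat

object OrcWriteSketch {
  def main(args: Array[String]): Unit = {
    val schema = "struct<id:int,name:string>"        // hypothetical schema
    val output = new Path("/tmp/example-write.orc")  // hypothetical output path

    val conf = new Configuration()
    conf.set(OrcConf.MAPRED_OUTPUT_SCHEMA.getAttribute, schema)
    conf.set(OrcConf.COMPRESS.getAttribute, "ZLIB")

    val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
    val context = new TaskAttemptContextImpl(conf, attemptId)

    // Pin the output to one known file, as the patched OrcOutputWriter does.
    val writer = new OrcOutputFormat[OrcStruct]() {
      override def getDefaultWorkFile(c: TaskAttemptContext, extension: String): Path = output
    }.getRecordWriter(context)

    val struct = OrcStruct.createValue(TypeDescription.fromString(schema)).asInstanceOf[OrcStruct]
    struct.setFieldValue(0, new IntWritable(1))
    struct.setFieldValue(1, new Text("spark"))
    writer.write(NullWritable.get, struct)
    writer.close(context)
  }
}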
- private[orc] val SARG_PUSHDOWN = "sarg.pushdown" - - // The extensions for ORC compression codecs - val extensionsForCompressionCodecNames = Map( - "NONE" -> "", - "SNAPPY" -> ".snappy", - "ZLIB" -> ".zlib", - "LZO" -> ".lzo") - - def unwrapOrcStructs( - conf: Configuration, - dataSchema: StructType, - maybeStructOI: Option[StructObjectInspector], - iterator: Iterator[Writable]): Iterator[InternalRow] = { - val deserializer = new OrcSerde - val mutableRow = new SpecificInternalRow(dataSchema.map(_.dataType)) - val unsafeProjection = UnsafeProjection.create(dataSchema) - - def unwrap(oi: StructObjectInspector): Iterator[InternalRow] = { - val (fieldRefs, fieldOrdinals) = dataSchema.zipWithIndex.map { - case (field, ordinal) => oi.getStructFieldRef(field.name) -> ordinal - }.unzip - - val unwrappers = fieldRefs.map(unwrapperFor) - - iterator.map { value => - val raw = deserializer.deserialize(value) - var i = 0 - val length = fieldRefs.length - while (i < length) { - val fieldValue = oi.getStructFieldData(raw, fieldRefs(i)) - if (fieldValue == null) { - mutableRow.setNullAt(fieldOrdinals(i)) - } else { - unwrappers(i)(fieldValue, mutableRow, fieldOrdinals(i)) - } - i += 1 - } - unsafeProjection(mutableRow) - } - } - - maybeStructOI.map(unwrap).getOrElse(Iterator.empty) - } - - def setRequiredColumns( - conf: Configuration, physicalSchema: StructType, requestedSchema: StructType): Unit = { - val ids = requestedSchema.map(a => physicalSchema.fieldIndex(a.name): Integer) - val (sortedIDs, sortedNames) = ids.zip(requestedSchema.fieldNames).sorted.unzip - HiveShim.appendReadColumns(conf, sortedIDs, sortedNames) - } -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index d9efd0cb457cd..df79d3c19cde8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -17,8 +17,9 @@ package org.apache.spark.sql.hive.orc -import org.apache.hadoop.hive.ql.io.sarg.{SearchArgument, SearchArgumentFactory} -import org.apache.hadoop.hive.ql.io.sarg.SearchArgument.Builder +import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument, SearchArgumentFactory} +import org.apache.orc.storage.ql.io.sarg.SearchArgument.Builder +import org.apache.orc.storage.serde2.io.HiveDecimalWritable import org.apache.spark.internal.Logging import org.apache.spark.sql.sources._ @@ -73,6 +74,30 @@ private[orc] object OrcFilters extends Logging { } yield builder.build() } + private def getPredicateLeafType(dataType: DataType) = dataType match { + case BooleanType => PredicateLeaf.Type.BOOLEAN + case ByteType | ShortType | IntegerType | LongType => PredicateLeaf.Type.LONG + case FloatType | DoubleType => PredicateLeaf.Type.FLOAT + case StringType => PredicateLeaf.Type.STRING + case DateType => PredicateLeaf.Type.DATE + case TimestampType => PredicateLeaf.Type.TIMESTAMP + case _: DecimalType => PredicateLeaf.Type.DECIMAL + case _ => throw new UnsupportedOperationException(s"DataType: $dataType") + } + + private def castLiteralValue(value: Any, dataType: DataType): Any = dataType match { + case ByteType | ShortType | IntegerType | LongType => + value.asInstanceOf[Number].longValue + case FloatType | DoubleType => + value.asInstanceOf[Number].doubleValue() + case _: DecimalType => + val decimal = value.asInstanceOf[java.math.BigDecimal] + val decimalWritable = new HiveDecimalWritable(decimal.longValue) + 
decimalWritable.mutateEnforcePrecisionScale(decimal.precision, decimal.scale) + decimalWritable + case _ => value + } + private def buildSearchArgument( dataTypeMap: Map[String, DataType], expression: Filter, @@ -88,6 +113,9 @@ private[orc] object OrcFilters extends Logging { case _ => false } + def getType(attribute: String): PredicateLeaf.Type = + getPredicateLeafType(dataTypeMap(attribute)) + expression match { case And(left, right) => // At here, it is not safe to just convert one side if we do not understand the @@ -123,31 +151,39 @@ private[orc] object OrcFilters extends Logging { // wrapped by a "parent" predicate (`And`, `Or`, or `Not`). case EqualTo(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - Some(builder.startAnd().equals(attribute, value).end()) + val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + Some(builder.startAnd().equals(attribute, getType(attribute), castedValue).end()) case EqualNullSafe(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - Some(builder.startAnd().nullSafeEquals(attribute, value).end()) + val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + Some(builder.startAnd().nullSafeEquals(attribute, getType(attribute), castedValue).end()) case LessThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - Some(builder.startAnd().lessThan(attribute, value).end()) + val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + Some(builder.startAnd().lessThan(attribute, getType(attribute), castedValue).end()) case LessThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - Some(builder.startAnd().lessThanEquals(attribute, value).end()) + val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + Some(builder.startAnd().lessThanEquals(attribute, getType(attribute), castedValue).end()) case GreaterThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - Some(builder.startNot().lessThanEquals(attribute, value).end()) + val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + Some(builder.startNot().lessThanEquals(attribute, getType(attribute), castedValue).end()) case GreaterThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => - Some(builder.startNot().lessThan(attribute, value).end()) + val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + Some(builder.startNot().lessThan(attribute, getType(attribute), castedValue).end()) case IsNull(attribute) if isSearchableType(dataTypeMap(attribute)) => - Some(builder.startAnd().isNull(attribute).end()) + Some(builder.startAnd().isNull(attribute, getType(attribute)).end()) case IsNotNull(attribute) if isSearchableType(dataTypeMap(attribute)) => - Some(builder.startNot().isNull(attribute).end()) + Some(builder.startNot().isNull(attribute, getType(attribute)).end()) case In(attribute, values) if isSearchableType(dataTypeMap(attribute)) => - Some(builder.startAnd().in(attribute, values.map(_.asInstanceOf[AnyRef]): _*).end()) + val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(attribute))) + Some(builder.startAnd().in(attribute, getType(attribute), + castedValues.map(_.asInstanceOf[AnyRef]): _*).end()) case _ => None } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala index 7f94c8c579026..da9e2eadf7874 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala +++ 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.hive.orc import java.util.Locale +import org.apache.orc.OrcConf + import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.internal.SQLConf @@ -42,7 +44,7 @@ private[orc] class OrcOptions( val compressionCodec: String = { // `compression`, `orc.compress`, and `spark.sql.orc.compression.codec` are // in order of precedence from highest to lowest. - val orcCompressionConf = parameters.get(OrcRelation.ORC_COMPRESSION) + val orcCompressionConf = parameters.get(OrcConf.COMPRESS.getAttribute) val codecName = parameters .get("compression") .orElse(orcCompressionConf) @@ -65,4 +67,11 @@ private[orc] object OrcOptions { "snappy" -> "SNAPPY", "zlib" -> "ZLIB", "lzo" -> "LZO") + + val extensionsForCompressionCodecNames = Map( + "NONE" -> "", + "UNCOMPRESSED" -> "", + "SNAPPY" -> ".snappy", + "ZLIB" -> ".zlib", + "LZO" -> ".lzo") } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcUtils.scala new file mode 100644 index 0000000000000..3964e455a8bcc --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcUtils.scala @@ -0,0 +1,281 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.io.IOException + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.io._ +import org.apache.orc.{OrcFile, TypeDescription} +import org.apache.orc.mapred.{OrcList, OrcMap, OrcStruct, OrcTimestamp} +import org.apache.orc.storage.common.`type`.HiveDecimal +import org.apache.orc.storage.serde2.io.{DateWritable, HiveDecimalWritable} + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +object OrcUtils { + /** + * Read ORC file schema. This method is used in `inferSchema`. + */ + private[orc] def readSchema(file: Path, conf: Configuration): Option[TypeDescription] = { + try { + val options = OrcFile.readerOptions(conf).filesystem(FileSystem.get(conf)) + val reader = OrcFile.createReader(file, options) + val schema = reader.getSchema + if (schema.getFieldNames.isEmpty) { + None + } else { + Some(schema) + } + } catch { + case _: IOException => None + } + } + + /** + * Return ORC schema with schema field name correction. 
+ */ + private[orc] def getSchema(dataSchema: StructType, filePath: String, conf: Configuration) = { + val hdfsPath = new Path(filePath) + val fs = hdfsPath.getFileSystem(conf) + val reader = OrcFile.createReader(hdfsPath, OrcFile.readerOptions(conf).filesystem(fs)) + val rawSchema = reader.getSchema + val orcSchema = if (!rawSchema.getFieldNames.isEmpty && + rawSchema.getFieldNames.asScala.forall(_.startsWith("_col"))) { + var schemaString = rawSchema.toString + dataSchema.zipWithIndex.foreach { case (field: StructField, index: Int) => + schemaString = schemaString.replace(s"_col$index:", s"${field.name}:") + } + TypeDescription.fromString(schemaString) + } else { + rawSchema + } + orcSchema + } + + /** + * Return an ORC schema string for OrcStruct. + */ + private[orc] def getSchemaString(schema: StructType): String = { + schema.fields.map(f => s"${f.name}:${f.dataType.catalogString}").mkString("struct<", ",", ">") + } + + private[orc] def getTypeDescription(dataType: DataType) = dataType match { + case st: StructType => TypeDescription.fromString(getSchemaString(st)) + case _ => TypeDescription.fromString(dataType.catalogString) + } + + /** + * Return an ORC value object for the given Spark schema. + */ + private[orc] def createOrcValue(dataType: DataType) = + OrcStruct.createValue(getTypeDescription(dataType)) + + /** + * Convert Apache ORC OrcStruct to Apache Spark InternalRow. + * If internalRow is not None, fill into it. Otherwise, create a SpecificInternalRow and use it. + */ + private[orc] def convertOrcStructToInternalRow( + orcStruct: OrcStruct, + schema: StructType, + valueWrappers: Option[Seq[Any => Any]] = None, + internalRow: Option[InternalRow] = None): InternalRow = { + val mutableRow = internalRow.getOrElse(new SpecificInternalRow(schema.map(_.dataType))) + val wrappers = valueWrappers.getOrElse(schema.fields.map(_.dataType).map(getValueWrapper).toSeq) + for (schemaIndex <- 0 until schema.length) { + val writable = orcStruct.getFieldValue(schema(schemaIndex).name) + if (writable == null) { + mutableRow.setNullAt(schemaIndex) + } else { + mutableRow(schemaIndex) = wrappers(schemaIndex)(writable) + } + } + mutableRow + } + + private def withNullSafe(f: Any => Any): Any => Any = { + input => if (input == null) null else f(input) + } + + /** + * Builds a catalyst-value return function ahead of time according to DataType + * to avoid pattern matching and branching costs per row.
+ */ + private[orc] def getValueWrapper(dataType: DataType): Any => Any = dataType match { + case NullType => _ => null + case BooleanType => withNullSafe(o => o.asInstanceOf[BooleanWritable].get) + case ByteType => withNullSafe(o => o.asInstanceOf[ByteWritable].get) + case ShortType => withNullSafe(o => o.asInstanceOf[ShortWritable].get) + case IntegerType => withNullSafe(o => o.asInstanceOf[IntWritable].get) + case LongType => withNullSafe(o => o.asInstanceOf[LongWritable].get) + case FloatType => withNullSafe(o => o.asInstanceOf[FloatWritable].get) + case DoubleType => withNullSafe(o => o.asInstanceOf[DoubleWritable].get) + case StringType => withNullSafe(o => UTF8String.fromBytes(o.asInstanceOf[Text].copyBytes)) + case BinaryType => + withNullSafe { o => + val binary = o.asInstanceOf[BytesWritable] + val bytes = new Array[Byte](binary.getLength) + System.arraycopy(binary.getBytes, 0, bytes, 0, binary.getLength) + bytes + } + case DateType => + withNullSafe(o => DateTimeUtils.fromJavaDate(o.asInstanceOf[DateWritable].get)) + case TimestampType => + withNullSafe(o => DateTimeUtils.fromJavaTimestamp(o.asInstanceOf[OrcTimestamp])) + case DecimalType.Fixed(precision, scale) => + withNullSafe { o => + val decimal = o.asInstanceOf[HiveDecimalWritable].getHiveDecimal() + val v = Decimal(decimal.bigDecimalValue, decimal.precision(), decimal.scale()) + v.changePrecision(precision, scale) + v + } + case _: StructType => + withNullSafe { o => + val structValue = convertOrcStructToInternalRow( + o.asInstanceOf[OrcStruct], + dataType.asInstanceOf[StructType]) + structValue + } + case ArrayType(elementType, _) => + withNullSafe { o => + val wrapper = getValueWrapper(elementType) + val data = new scala.collection.mutable.ArrayBuffer[Any] + o.asInstanceOf[OrcList[WritableComparable[_]]].asScala.foreach { x => + data += wrapper(x) + } + new GenericArrayData(data.toArray) + } + case MapType(keyType, valueType, _) => + withNullSafe { o => + val keyWrapper = getValueWrapper(keyType) + val valueWrapper = getValueWrapper(valueType) + val map = new java.util.TreeMap[Any, Any] + o.asInstanceOf[OrcMap[WritableComparable[_], WritableComparable[_]]] + .entrySet().asScala.foreach { entry => + map.put(keyWrapper(entry.getKey), valueWrapper(entry.getValue)) + } + ArrayBasedMapData(map.asScala) + } + case udt: UserDefinedType[_] => withNullSafe { o => getValueWrapper(udt.sqlType)(o) } + case _ => throw new UnsupportedOperationException(s"$dataType is not supported yet.") + } + + /** + * Convert Apache Spark InternalRow to Apache ORC OrcStruct. + */ + private[orc] def convertInternalRowToOrcStruct( + row: InternalRow, + schema: StructType, + valueWrappers: Option[Seq[Any => Any]] = None, + struct: Option[OrcStruct] = None): OrcStruct = { + val wrappers = + valueWrappers.getOrElse(schema.fields.map(_.dataType).map(getWritableWrapper).toSeq) + val orcStruct = struct.getOrElse(createOrcValue(schema).asInstanceOf[OrcStruct]) + + for (schemaIndex <- 0 until schema.length) { + val fieldType = schema(schemaIndex).dataType + if (row.isNullAt(schemaIndex)) { + orcStruct.setFieldValue(schemaIndex, null) + } else { + val field = row.get(schemaIndex, fieldType) + val fieldValue = wrappers(schemaIndex)(field).asInstanceOf[WritableComparable[_]] + orcStruct.setFieldValue(schemaIndex, fieldValue) + } + } + orcStruct + } + + /** + * Builds a WritableComparable-return function ahead of time according to DataType + * to avoid pattern matching and branching costs per row. 
+ */ + private[orc] def getWritableWrapper(dataType: DataType): Any => Any = dataType match { + case NullType => _ => null + case BooleanType => withNullSafe(o => new BooleanWritable(o.asInstanceOf[Boolean])) + case ByteType => withNullSafe(o => new ByteWritable(o.asInstanceOf[Byte])) + case ShortType => withNullSafe(o => new ShortWritable(o.asInstanceOf[Short])) + case IntegerType => withNullSafe(o => new IntWritable(o.asInstanceOf[Int])) + case LongType => withNullSafe(o => new LongWritable(o.asInstanceOf[Long])) + case FloatType => withNullSafe(o => new FloatWritable(o.asInstanceOf[Float])) + case DoubleType => withNullSafe(o => new DoubleWritable(o.asInstanceOf[Double])) + case StringType => withNullSafe(o => new Text(o.asInstanceOf[UTF8String].getBytes)) + case BinaryType => withNullSafe(o => new BytesWritable(o.asInstanceOf[Array[Byte]])) + case DateType => + withNullSafe(o => new DateWritable(DateTimeUtils.toJavaDate(o.asInstanceOf[Int]))) + case TimestampType => + withNullSafe { o => + val us = o.asInstanceOf[Long] + var seconds = us / DateTimeUtils.MICROS_PER_SECOND + var micros = us % DateTimeUtils.MICROS_PER_SECOND + if (micros < 0) { + micros += DateTimeUtils.MICROS_PER_SECOND + seconds -= 1 + } + val t = new OrcTimestamp(seconds * 1000) + t.setNanos(micros.toInt * 1000) + t + } + case _: DecimalType => + withNullSafe { o => + new HiveDecimalWritable(HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal)) + } + case st: StructType => + withNullSafe(o => convertInternalRowToOrcStruct(o.asInstanceOf[InternalRow], st)) + case ArrayType(et, _) => + withNullSafe { o => + val data = o.asInstanceOf[ArrayData] + val list = createOrcValue(dataType) + for (i <- 0 until data.numElements()) { + val d = data.get(i, et) + val v = getWritableWrapper(et)(d).asInstanceOf[WritableComparable[_]] + list.asInstanceOf[OrcList[WritableComparable[_]]].add(v) + } + list + } + case MapType(keyType, valueType, _) => + withNullSafe { o => + val keyWrapper = getWritableWrapper(keyType) + val valueWrapper = getWritableWrapper(valueType) + val data = o.asInstanceOf[MapData] + val map = createOrcValue(dataType) + .asInstanceOf[OrcMap[WritableComparable[_], WritableComparable[_]]] + data.foreach(keyType, valueType, { case (k, v) => + map.put( + keyWrapper(k).asInstanceOf[WritableComparable[_]], + valueWrapper(v).asInstanceOf[WritableComparable[_]]) + }) + map + } + case udt: UserDefinedType[_] => + withNullSafe { o => + val udtRow = new SpecificInternalRow(Seq(udt.sqlType)) + udtRow(0) = o + convertInternalRowToOrcStruct( + udtRow, + StructType(Seq(StructField("tmp", udt.sqlType)))).getFieldValue(0) + } + case _ => throw new UnsupportedOperationException(s"$dataType is not supported yet.") + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala index de6f0d67f1734..65c744a5f30d1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala @@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp} import scala.collection.JavaConverters._ -import org.apache.hadoop.hive.ql.io.sarg.{PredicateLeaf, SearchArgument} +import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument} import org.apache.spark.sql.{Column, DataFrame, QueryTest} import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -291,33 +291,25 @@ class OrcFilterSuite extends QueryTest with OrcTest { // This might have to 
be changed after Hive version is upgraded. checkFilterPredicate( '_1.isNotNull, - """leaf-0 = (IS_NULL _1) - |expr = (not leaf-0)""".stripMargin.trim + "leaf-0 = (IS_NULL _1), expr = (not leaf-0)" ) checkFilterPredicate( '_1 =!= 1, - """leaf-0 = (IS_NULL _1) - |leaf-1 = (EQUALS _1 1) - |expr = (and (not leaf-0) (not leaf-1))""".stripMargin.trim + "leaf-0 = (IS_NULL _1), leaf-1 = (EQUALS _1 1), expr = (and (not leaf-0) (not leaf-1))" ) checkFilterPredicate( !('_1 < 4), - """leaf-0 = (IS_NULL _1) - |leaf-1 = (LESS_THAN _1 4) - |expr = (and (not leaf-0) (not leaf-1))""".stripMargin.trim + "leaf-0 = (IS_NULL _1), leaf-1 = (LESS_THAN _1 4), expr = (and (not leaf-0) (not leaf-1))" ) checkFilterPredicate( '_1 < 2 || '_1 > 3, - """leaf-0 = (LESS_THAN _1 2) - |leaf-1 = (LESS_THAN_EQUALS _1 3) - |expr = (or leaf-0 (not leaf-1))""".stripMargin.trim + "leaf-0 = (LESS_THAN _1 2), leaf-1 = (LESS_THAN_EQUALS _1 3), " + + "expr = (or leaf-0 (not leaf-1))" ) checkFilterPredicate( '_1 < 2 && '_1 > 3, - """leaf-0 = (IS_NULL _1) - |leaf-1 = (LESS_THAN _1 2) - |leaf-2 = (LESS_THAN_EQUALS _1 3) - |expr = (and (not leaf-0) leaf-1 (not leaf-2))""".stripMargin.trim + "leaf-0 = (IS_NULL _1), leaf-1 = (LESS_THAN _1 2), leaf-2 = (LESS_THAN_EQUALS _1 3), " + + "expr = (and (not leaf-0) leaf-1 (not leaf-2))" ) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 60ccd996d6d58..606a74b06213a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -17,11 +17,18 @@ package org.apache.spark.sql.hive.orc +import java.io.File import java.nio.charset.StandardCharsets import java.sql.Timestamp import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.hive.ql.io.orc.{OrcStruct, SparkOrcNewRecordReader} +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.mapreduce.{JobID, TaskAttemptID, TaskID, TaskType} +import org.apache.hadoop.mapreduce.lib.input.FileSplit +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl +import org.apache.orc.{OrcConf, OrcFile} +import org.apache.orc.mapred.OrcStruct +import org.apache.orc.mapreduce.OrcInputFormat import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql._ @@ -58,6 +65,13 @@ case class Person(name: String, age: Int, contacts: Seq[Contact]) class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { + private def getFileReader(path: String, extensions: String) = { + val maybeOrcFile = new File(path).listFiles().find(_.getName.endsWith(extensions)) + assert(maybeOrcFile.isDefined) + val orcFilePath = new Path(maybeOrcFile.get.getAbsolutePath) + OrcFile.createReader(orcFilePath, OrcFile.readerOptions(new Configuration())) + } + test("Read/write All Types") { val data = (0 to 255).map { i => (s"$i", i, i.toLong, i.toFloat, i.toDouble, i.toShort, i.toByte, i % 2 == 0) @@ -183,7 +197,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { .option("orc.compress", "ZLIB") .orc(file.getCanonicalPath) val expectedCompressionKind = - OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression + getFileReader(file.getAbsolutePath, ".zlib.orc").getCompressionKind assert("ZLIB" === expectedCompressionKind.name()) } @@ -194,7 +208,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { .option("orc.compress", "SNAPPY") 
.orc(file.getCanonicalPath) val expectedCompressionKind = - OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression + getFileReader(file.getAbsolutePath, ".zlib.orc").getCompressionKind assert("ZLIB" === expectedCompressionKind.name()) } } @@ -206,7 +220,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { .option("compression", "ZLIB") .orc(file.getCanonicalPath) val expectedCompressionKind = - OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression + getFileReader(file.getAbsolutePath, ".zlib.orc").getCompressionKind assert("ZLIB" === expectedCompressionKind.name()) } @@ -215,7 +229,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { .option("compression", "SNAPPY") .orc(file.getCanonicalPath) val expectedCompressionKind = - OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression + getFileReader(file.getAbsolutePath, ".snappy.orc").getCompressionKind assert("SNAPPY" === expectedCompressionKind.name()) } @@ -224,19 +238,18 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { .option("compression", "NONE") .orc(file.getCanonicalPath) val expectedCompressionKind = - OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression + getFileReader(file.getAbsolutePath, ".orc").getCompressionKind assert("NONE" === expectedCompressionKind.name()) } } - // Following codec is not supported in Hive 1.2.1, ignore it now - ignore("LZO compression options for writing to an ORC file not supported in Hive 1.2.1") { + test("LZO compression options for writing to an ORC file") { withTempPath { file => spark.range(0, 10).write .option("compression", "LZO") .orc(file.getCanonicalPath) val expectedCompressionKind = - OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression + getFileReader(file.getAbsolutePath, ".lzo.orc").getCompressionKind assert("LZO" === expectedCompressionKind.name()) } } @@ -592,18 +605,22 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } - test("Empty schema does not read data from ORC file") { + // Please see ORC-233. We will turn on this later. 
+ ignore("Empty schema does not read data from ORC file") { val data = Seq((1, 1), (2, 2)) withOrcFile(data) { path => - val requestedSchema = StructType(Nil) val conf = new Configuration() - val physicalSchema = OrcFileOperator.readSchema(Seq(path), Some(conf)).get - OrcRelation.setRequiredColumns(conf, physicalSchema, requestedSchema) - val maybeOrcReader = OrcFileOperator.getFileReader(path, Some(conf)) - assert(maybeOrcReader.isDefined) - val orcRecordReader = new SparkOrcNewRecordReader( - maybeOrcReader.get, conf, 0, maybeOrcReader.get.getContentLength) - + val maybeOrcFile = new File(path).listFiles().find(_.getName.endsWith(".snappy.orc")) + assert(maybeOrcFile.isDefined) + val orcFilePath = new Path(maybeOrcFile.get.getAbsolutePath) + val reader = OrcFile.createReader(orcFilePath, OrcFile.readerOptions(conf)) + val fileSplit = new FileSplit(orcFilePath, 0, reader.getContentLength, Array.empty) + val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) + + conf.set(OrcConf.INCLUDE_COLUMNS.getAttribute, "") + val taskAttemptContext = new TaskAttemptContextImpl(conf, attemptId) + val orcRecordReader = + new OrcInputFormat[OrcStruct].createRecordReader(fileSplit, taskAttemptContext) val recordsIterator = new RecordReaderIterator[OrcStruct](orcRecordReader) try { assert(recordsIterator.next().toString == "{null, null}") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 781de6631f324..b7de82221f288 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -221,7 +222,7 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA } } -class OrcSourceSuite extends OrcSuite { +class OrcSourceSuite extends OrcSuite with SQLTestUtils { override def beforeAll(): Unit = { super.beforeAll() @@ -249,9 +250,7 @@ class OrcSourceSuite extends OrcSuite { StructField("a", IntegerType, nullable = true), StructField("b", StringType, nullable = true))) assertResult( - """leaf-0 = (LESS_THAN a 10) - |expr = leaf-0 - """.stripMargin.trim + "leaf-0 = (LESS_THAN a 10), expr = leaf-0" ) { OrcFilters.createFilter(schema, Array( LessThan("a", 10), @@ -261,9 +260,7 @@ class OrcSourceSuite extends OrcSuite { // The `LessThan` should be converted while the whole inner `And` shouldn't assertResult( - """leaf-0 = (LESS_THAN a 10) - |expr = leaf-0 - """.stripMargin.trim + "leaf-0 = (LESS_THAN a 10), expr = leaf-0" ) { OrcFilters.createFilter(schema, Array( LessThan("a", 10), @@ -274,4 +271,13 @@ class OrcSourceSuite extends OrcSuite { )).get.toString } } + + test("SPARK-21791 ORC should support column names with dot") { + import spark.implicits._ + withTempDir { dir => + val path = new File(dir, "orc").getCanonicalPath + Seq(Some(1), None).toDF("col.dots").write.orc(path) + assert(spark.read.orc(path).collect().length == 2) + } + } }