From d7344968aacaca0c418653c0ed3bd4daa5f78409 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sat, 16 May 2015 00:17:51 +0800 Subject: [PATCH] Polishes the ORC data source --- .../scala/org/apache/spark/sql/SQLConf.scala | 7 +- .../spark/sql/parquet/ParquetTest.scala | 61 +---- .../apache/spark/sql/test/SQLTestUtils.scala | 81 ++++++ .../sql/hive/orc/HadoopTypeConverter.scala | 3 +- .../spark/sql/hive/orc/OrcFileOperator.scala | 14 +- .../spark/sql/hive/orc/OrcFilters.scala | 146 ++++++++--- .../spark/sql/hive/orc/OrcRelation.scala | 248 +++++++++++++----- .../sql/hive/orc/OrcTableOperations.scala | 119 --------- .../apache/spark/sql/hive/orc/package.scala | 24 +- .../spark/sql/hive/orc/NewOrcQuerySuite.scala | 177 +++++++++++++ ...e.scala => OrcHadoopFsRelationSuite.scala} | 5 +- .../hive/orc/OrcPartitionDiscoverySuite.scala | 1 + .../spark/sql/hive/orc/OrcQuerySuite.scala | 27 +- .../sql/sources/hadoopFsRelationSuites.scala | 6 +- 14 files changed, 591 insertions(+), 328 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/test/SQLTestUtils.scala delete mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcTableOperations.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/NewOrcQuerySuite.scala rename sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/{OrcRelationSuite.scala => OrcHadoopFsRelationSuite.scala} (94%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index f07bb196c11ec..6da910e332e9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -43,6 +43,8 @@ private[spark] object SQLConf { val PARQUET_FILTER_PUSHDOWN_ENABLED = "spark.sql.parquet.filterPushdown" val PARQUET_USE_DATA_SOURCE_API = "spark.sql.parquet.useDataSourceApi" + val ORC_FILTER_PUSHDOWN_ENABLED = "spark.sql.orc.filterPushdown" + val HIVE_VERIFY_PARTITIONPATH = "spark.sql.hive.verifyPartitionPath" val COLUMN_NAME_OF_CORRUPT_RECORD = "spark.sql.columnNameOfCorruptRecord" @@ -143,6 +145,9 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def parquetUseDataSourceApi = getConf(PARQUET_USE_DATA_SOURCE_API, "true").toBoolean + private[spark] def orcFilterPushDown = + getConf(ORC_FILTER_PUSHDOWN_ENABLED, "false").toBoolean + /** When true uses verifyPartitionPath to prune the path which is not exists. */ private[spark] def verifyPartitionPath = getConf(HIVE_VERIFY_PARTITIONPATH, "true").toBoolean @@ -254,7 +259,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def dataFrameRetainGroupColumns: Boolean = getConf(DATAFRAME_RETAIN_GROUP_COLUMNS, "true").toBoolean - + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala index 7a73b6f1ac601..516ba373f41d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala @@ -21,10 +21,9 @@ import java.io.File import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import scala.util.Try -import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} -import org.apache.spark.util.Utils +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.{DataFrame, SaveMode} /** * A helper trait that provides convenient facilities for Parquet testing. @@ -33,54 +32,9 @@ import org.apache.spark.util.Utils * convenient to use tuples rather than special case classes when writing test cases/suites. * Especially, `Tuple1.apply` can be used to easily wrap a single type/value. */ -private[sql] trait ParquetTest { - val sqlContext: SQLContext - +private[sql] trait ParquetTest extends SQLTestUtils { import sqlContext.implicits.{localSeqToDataFrameHolder, rddToDataFrameHolder} - import sqlContext.{conf, sparkContext} - - protected def configuration = sparkContext.hadoopConfiguration - - /** - * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL - * configurations. - * - * @todo Probably this method should be moved to a more general place - */ - protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { - val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(conf.getConf(key)).toOption) - (keys, values).zipped.foreach(conf.setConf) - try f finally { - keys.zip(currentValues).foreach { - case (key, Some(value)) => conf.setConf(key, value) - case (key, None) => conf.unsetConf(key) - } - } - } - - /** - * Generates a temporary path without creating the actual file/directory, then pass it to `f`. If - * a file/directory is created there by `f`, it will be delete after `f` returns. - * - * @todo Probably this method should be moved to a more general place - */ - protected def withTempPath(f: File => Unit): Unit = { - val path = Utils.createTempDir() - path.delete() - try f(path) finally Utils.deleteRecursively(path) - } - - /** - * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` - * returns. - * - * @todo Probably this method should be moved to a more general place - */ - protected def withTempDir(f: File => Unit): Unit = { - val dir = Utils.createTempDir().getCanonicalFile - try f(dir) finally Utils.deleteRecursively(dir) - } + import sqlContext.sparkContext /** * Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f` @@ -105,13 +59,6 @@ private[sql] trait ParquetTest { withParquetFile(data)(path => f(sqlContext.read.parquet(path))) } - /** - * Drops temporary table `tableName` after calling `f`. - */ - protected def withTempTable(tableName: String)(f: => Unit): Unit = { - try f finally sqlContext.dropTempTable(tableName) - } - /** * Writes `data` to a Parquet file, reads it back as a [[DataFrame]] and registers it as a * temporary table named `tableName`, then call `f`. 
The temporary table together with the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/SQLTestUtils.scala new file mode 100644 index 0000000000000..75d290625ec38 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import java.io.File + +import scala.util.Try + +import org.apache.spark.sql.SQLContext +import org.apache.spark.util.Utils + +trait SQLTestUtils { + val sqlContext: SQLContext + + import sqlContext.{conf, sparkContext} + + protected def configuration = sparkContext.hadoopConfiguration + + /** + * Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL + * configurations. + * + * @todo Probably this method should be moved to a more general place + */ + protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val (keys, values) = pairs.unzip + val currentValues = keys.map(key => Try(conf.getConf(key)).toOption) + (keys, values).zipped.foreach(conf.setConf) + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => conf.setConf(key, value) + case (key, None) => conf.unsetConf(key) + } + } + } + + /** + * Generates a temporary path without creating the actual file/directory, then passes it to `f`. If + * a file/directory is created there by `f`, it will be deleted after `f` returns. + * + * @todo Probably this method should be moved to a more general place + */ + protected def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + /** + * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` + * returns. + * + * @todo Probably this method should be moved to a more general place + */ + protected def withTempDir(f: File => Unit): Unit = { + val dir = Utils.createTempDir().getCanonicalFile + try f(dir) finally Utils.deleteRecursively(dir) + } + + /** + * Drops temporary table `tableName` after calling `f`. 
+ */ + protected def withTempTable(tableName: String)(f: => Unit): Unit = { + try f finally sqlContext.dropTempTable(tableName) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/HadoopTypeConverter.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/HadoopTypeConverter.scala index 713c076aee457..b5b5e56079cc3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/HadoopTypeConverter.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/HadoopTypeConverter.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.hive.orc - import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.spark.sql.catalyst.expressions.MutableRow -import org.apache.spark.sql.hive.{HiveInspectors, HiveShim} +import org.apache.spark.sql.hive.HiveInspectors /** * We can consolidate TableReader.unwrappers and HiveInspectors.wrapperFor to use diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala index 4dd2d8951b728..1e51173a19882 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala @@ -28,28 +28,25 @@ import org.apache.spark.sql.hive.HiveMetastoreTypes import org.apache.spark.sql.types.StructType private[orc] object OrcFileOperator extends Logging{ - def getFileReader(pathStr: String, config: Option[Configuration] = None ): Reader = { val conf = config.getOrElse(new Configuration) val fspath = new Path(pathStr) val fs = fspath.getFileSystem(conf) val orcFiles = listOrcFiles(pathStr, conf) - OrcFile.createReader(fs, orcFiles(0)) + + // TODO Need to consider all files when schema evolution is taken into account. 
+ OrcFile.createReader(fs, orcFiles.head) } def readSchema(path: String, conf: Option[Configuration]): StructType = { val reader = getFileReader(path, conf) - val readerInspector: StructObjectInspector = reader.getObjectInspector - .asInstanceOf[StructObjectInspector] + val readerInspector = reader.getObjectInspector.asInstanceOf[StructObjectInspector] val schema = readerInspector.getTypeName HiveMetastoreTypes.toDataType(schema).asInstanceOf[StructType] } def getObjectInspector(path: String, conf: Option[Configuration]): StructObjectInspector = { - val reader = getFileReader(path, conf) - val readerInspector: StructObjectInspector = reader.getObjectInspector - .asInstanceOf[StructObjectInspector] - readerInspector + getFileReader(path, conf).getObjectInspector.asInstanceOf[StructObjectInspector] } def listOrcFiles(pathStr: String, conf: Configuration): Seq[Path] = { @@ -66,6 +63,7 @@ private[orc] object OrcFileOperator extends Logging{ throw new IllegalArgumentException( s"orcFileOperator: path $path does not have valid orc files matching the pattern") } + paths } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index eda1cffe49810..9bee4f59b5854 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -17,8 +17,11 @@ package org.apache.spark.sql.hive.orc +import org.apache.hadoop.hive.common.`type`.{HiveChar, HiveDecimal, HiveVarchar} import org.apache.hadoop.hive.ql.io.sarg.SearchArgument import org.apache.hadoop.hive.ql.io.sarg.SearchArgument.Builder +import org.apache.hadoop.hive.serde2.io.DateWritable + import org.apache.spark.Logging import org.apache.spark.sql.sources._ @@ -29,48 +32,113 @@ import org.apache.spark.sql.sources._ */ private[orc] object OrcFilters extends Logging { def createFilter(expr: Array[Filter]): Option[SearchArgument] = { - if (expr.nonEmpty) { - expr.foldLeft(Some(SearchArgument.FACTORY.newBuilder().startAnd()): Option[Builder]) { - (maybeBuilder, e) => createFilter(e, maybeBuilder) - }.map(_.end().build()) - } else { - None + expr.reduceOption(And).flatMap { conjunction => + val builder = SearchArgument.FACTORY.newBuilder() + buildSearchArgument(conjunction, builder).map(_.build()) } } - private def createFilter(expression: Filter, maybeBuilder: Option[Builder]): Option[Builder] = { - maybeBuilder.flatMap { builder => - expression match { - case p@And(left, right) => - for { - lhs <- createFilter(left, Some(builder.startAnd())) - rhs <- createFilter(right, Some(lhs)) - } yield rhs.end() - case p@Or(left, right) => - for { - lhs <- createFilter(left, Some(builder.startOr())) - rhs <- createFilter(right, Some(lhs)) - } yield rhs.end() - case p@Not(child) => - createFilter(child, Some(builder.startNot())).map(_.end()) - case p@EqualTo(attribute, value) => - Some(builder.equals(attribute, value)) - case p@LessThan(attribute, value) => - Some(builder.lessThan(attribute, value)) - case p@LessThanOrEqual(attribute, value) => - Some(builder.lessThanEquals(attribute, value)) - case p@GreaterThan(attribute, value) => - Some(builder.startNot().lessThanEquals(attribute, value).end()) - case p@GreaterThanOrEqual(attribute, value) => - Some(builder.startNot().lessThan(attribute, value).end()) - case p@IsNull(attribute) => - Some(builder.isNull(attribute)) - case p@IsNotNull(attribute) => - Some(builder.startNot().isNull(attribute).end()) - case 
p@In(attribute, values) => - Some(builder.in(attribute, values)) - case _ => None - } + private def buildSearchArgument(expression: Filter, builder: Builder): Option[Builder] = { + def newBuilder = SearchArgument.FACTORY.newBuilder() + + def isSearchableLiteral(value: Any) = value match { + // These are types recognized by the `SearchArgumentImpl.BuilderImpl.boxLiteral()` method. + case _: String | _: Long | _: Double | _: DateWritable | _: HiveDecimal | _: HiveChar | + _: HiveVarchar | _: Byte | _: Short | _: Integer | _: Float => true + case _ => false + } + + // lian: I probably missed something here, and had to end up with a pretty weird double-checking + // pattern when converting `And`/`Or`/`Not` filters. + // + // The annoying part is that, `SearchArgument` builder methods like `startAnd()` `startOr()`, + // and `startNot()` mutate internal state of the builder instance. This forces us to translate + // all convertible filters with a single builder instance. However, before actually converting a + // filter, we've no idea whether it can be recognized by ORC or not. Thus, when an inconvertible + // filter is found, we may already end up with a builder whose internal state is inconsistent. + // + // For example, to convert an `And` filter with builder `b`, we call `b.startAnd()` first, and + // then try to convert its children. Say we convert `left` child successfully, but find that + // `right` child is inconvertible. Alas, `b.startAnd()` call can't be rolled back, and `b` is + // inconsistent now. + // + // The workaround employed here is that, for `And`/`Or`/`Not`, we first try to convert their + // children with brand new builders, and only do the actual conversion with the right builder + // instance when the children are proven to be convertible. + // + // P.S.: Hive seems to use `SearchArgument` together with `ExprNodeGenericFuncDesc` only. + // Usage of builder methods mentioned above can only be found in test code, where all tested + // filters are known to be convertible. + + expression match { + case And(left, right) => + val tryLeft = buildSearchArgument(left, newBuilder) + val tryRight = buildSearchArgument(right, newBuilder) + + val conjunction = for { + _ <- tryLeft + _ <- tryRight + lhs <- buildSearchArgument(left, builder.startAnd()) + rhs <- buildSearchArgument(right, lhs) + } yield rhs.end() + + // For filter `left AND right`, we can still push down `left` even if `right` is not + // convertible, and vice versa. 
+ conjunction + .orElse(tryLeft.flatMap(_ => buildSearchArgument(left, builder))) + .orElse(tryRight.flatMap(_ => buildSearchArgument(right, builder))) + + case Or(left, right) => + for { + _ <- buildSearchArgument(left, newBuilder) + _ <- buildSearchArgument(right, newBuilder) + lhs <- buildSearchArgument(left, builder.startOr()) + rhs <- buildSearchArgument(right, lhs) + } yield rhs.end() + + case Not(child) => + for { + _ <- buildSearchArgument(child, newBuilder) + negate <- buildSearchArgument(child, builder.startNot()) + } yield negate.end() + + case EqualTo(attribute, value) => + Option(value) + .filter(isSearchableLiteral) + .map(builder.equals(attribute, _)) + + case LessThan(attribute, value) => + Option(value) + .filter(isSearchableLiteral) + .map(builder.lessThan(attribute, _)) + + case LessThanOrEqual(attribute, value) => + Option(value) + .filter(isSearchableLiteral) + .map(builder.lessThanEquals(attribute, _)) + + case GreaterThan(attribute, value) => + Option(value) + .filter(isSearchableLiteral) + .map(builder.startNot().lessThanEquals(attribute, _).end()) + + case GreaterThanOrEqual(attribute, value) => + Option(value) + .filter(isSearchableLiteral) + .map(builder.startNot().lessThan(attribute, _).end()) + + case IsNull(attribute) => + Some(builder.isNull(attribute)) + + case IsNotNull(attribute) => + Some(builder.startNot().isNull(attribute).end()) + + case In(attribute, values) => + Option(values) + .filter(_.forall(isSearchableLiteral)) + .map(builder.in(attribute, _)) + + case _ => None } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 44ac728b09aa3..3e3c8a9e619d5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -17,99 +17,114 @@ package org.apache.spark.sql.hive.orc -import java.util.Objects +import java.util.{Objects, Properties} +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.hadoop.hive.ql.io.orc.{OrcSerde, OrcOutputFormat} -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, StructObjectInspector} -import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfoUtils, TypeInfo} -import org.apache.hadoop.io.{Writable, NullWritable} -import org.apache.hadoop.mapred.{RecordWriter, Reporter, JobConf} -import org.apache.hadoop.mapreduce.{TaskID, TaskAttemptContext} - -import org.apache.spark.Logging +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hadoop.hive.ql.io.orc.{OrcInputFormat, OrcOutputFormat, OrcSerde, OrcSplit} +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils +import org.apache.hadoop.io.{NullWritable, Writable} +import org.apache.hadoop.mapred.{InputFormat => MapRedInputFormat, JobConf, RecordWriter, Reporter} +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat +import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mapred.SparkHadoopMapRedUtil -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.hive.HiveMetastoreTypes +import org.apache.spark.rdd.{HadoopRDD, RDD} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.hive.{HiveMetastoreTypes, HiveShim} +import org.apache.spark.sql.sources.{Filter, _} import 
org.apache.spark.sql.types.StructType import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.sources._ +import org.apache.spark.{Logging, SerializableWritable} /* Implicit conversions */ import scala.collection.JavaConversions._ - -private[sql] class DefaultSource extends FSBasedRelationProvider { +private[sql] class DefaultSource extends HadoopFsRelationProvider { def createRelation( sqlContext: SQLContext, paths: Array[String], schema: Option[StructType], partitionColumns: Option[StructType], - parameters: Map[String, String]): FSBasedRelation ={ + parameters: Map[String, String]): HadoopFsRelation = { val partitionSpec = partitionColumns.map(PartitionSpec(_, Seq.empty[Partition])) - OrcRelation(paths, parameters, - schema, partitionSpec)(sqlContext) + OrcRelation(paths, parameters, schema, partitionSpec)(sqlContext) } } +private[orc] class OrcOutputWriter( + path: String, + dataSchema: StructType, + context: TaskAttemptContext) + extends OutputWriter with SparkHadoopMapRedUtil { -private[sql] class OrcOutputWriter extends OutputWriter with SparkHadoopMapRedUtil { - - var taskAttemptContext: TaskAttemptContext = _ - var serializer: OrcSerde = _ - var wrappers: Array[Any => Any] = _ - var created = false - var path: String = _ - var dataSchema: StructType = _ - var fieldOIs: Array[ObjectInspector] = _ - var structOI: StructObjectInspector = _ - var outputData: Array[Any] = _ - lazy val recordWriter: RecordWriter[NullWritable, Writable] = { - created = true - val conf = taskAttemptContext.getConfiguration - val taskId: TaskID = taskAttemptContext.getTaskAttemptID.getTaskID - val partition: Int = taskId.getId - val filename = f"part-r-$partition%05d-${System.currentTimeMillis}%015d.orc" - val file = new Path(path, filename) - val fs = file.getFileSystem(conf) - val outputFormat = new OrcOutputFormat() - outputFormat.getRecordWriter(fs, - conf.asInstanceOf[JobConf], - file.toUri.getPath, Reporter.NULL) - .asInstanceOf[org.apache.hadoop.mapred.RecordWriter[NullWritable, Writable]] + private val serializer = { + val table = new Properties() + table.setProperty("columns", dataSchema.fieldNames.mkString(",")) + table.setProperty("columns.types", dataSchema.map { f => + HiveMetastoreTypes.toMetastoreType(f.dataType) + }.mkString(":")) + + val serde = new OrcSerde + serde.initialize(context.getConfiguration, table) + serde } - override def init(path: String, - dataSchema: StructType, - context: TaskAttemptContext): Unit = { - this.path = path - taskAttemptContext = context - val orcSchema = HiveMetastoreTypes.toMetastoreType(dataSchema) - serializer = new OrcSerde - val typeInfo: TypeInfo = - TypeInfoUtils.getTypeInfoFromTypeString(orcSchema) - structOI = TypeInfoUtils + // Object inspector converted from the schema of the relation to be written. + private val structOI = { + val typeInfo = + TypeInfoUtils.getTypeInfoFromTypeString( + HiveMetastoreTypes.toMetastoreType(dataSchema)) + + TypeInfoUtils .getStandardJavaObjectInspectorFromTypeInfo(typeInfo) .asInstanceOf[StructObjectInspector] - fieldOIs = structOI - .getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray - outputData = new Array[Any](fieldOIs.length) - wrappers = fieldOIs.map(HadoopTypeConverter.wrappers) + } + + // Used to hold temporary `Writable` fields of the next row to be written. + private val reusableOutputBuffer = new Array[Any](dataSchema.length) + + // Used to convert Catalyst values into Hadoop `Writable`s. 
+ private val wrappers = structOI.getAllStructFieldRefs.map { ref => + HadoopTypeConverter.wrappers(ref.getFieldObjectInspector) + }.toArray + + // `OrcRecordWriter.close()` creates an empty file if no rows are written at all. We use this + // flag to decide whether `OrcRecordWriter.close()` needs to be called. + private var recordWriterInstantiated = false + + private lazy val recordWriter: RecordWriter[NullWritable, Writable] = { + recordWriterInstantiated = true + + val conf = context.getConfiguration + val partition = context.getTaskAttemptID.getTaskID.getId + val filename = f"part-r-$partition%05d-${System.currentTimeMillis}%015d.orc" + + new OrcOutputFormat().getRecordWriter( + new Path(path, filename).getFileSystem(conf), + conf.asInstanceOf[JobConf], + new Path(path, filename).toUri.getPath, + Reporter.NULL + ).asInstanceOf[RecordWriter[NullWritable, Writable]] } override def write(row: Row): Unit = { var i = 0 while (i < row.length) { - outputData(i) = wrappers(i)(row(i)) + reusableOutputBuffer(i) = wrappers(i)(row(i)) i += 1 } - val writable = serializer.serialize(outputData, structOI) - recordWriter.write(NullWritable.get(), writable) + + recordWriter.write( + NullWritable.get(), + serializer.serialize(reusableOutputBuffer, structOI)) } override def close(): Unit = { - if (created) { + if (recordWriterInstantiated) { recordWriter.close(Reporter.NULL) } } @@ -122,13 +137,16 @@ private[sql] case class OrcRelation( maybeSchema: Option[StructType] = None, maybePartitionSpec: Option[PartitionSpec] = None)( @transient val sqlContext: SQLContext) - extends FSBasedRelation(paths, maybePartitionSpec) + extends HadoopFsRelation(maybePartitionSpec) with Logging { - override val dataSchema: StructType = - maybeSchema.getOrElse(OrcFileOperator.readSchema(paths(0), - Some(sqlContext.sparkContext.hadoopConfiguration))) - override def outputWriterClass: Class[_ <: OutputWriter] = classOf[OrcOutputWriter] + override val dataSchema: StructType = maybeSchema.getOrElse { + OrcFileOperator.readSchema( + paths.head, Some(sqlContext.sparkContext.hadoopConfiguration)) + } + + override def userDefinedPartitionColumns: Option[StructType] = + maybePartitionSpec.map(_.partitionColumns) override def needConversion: Boolean = false @@ -155,4 +173,106 @@ private[sql] case class OrcRelation( val output = StructType(requiredColumns.map(dataSchema(_))).toAttributes OrcTableScan(output, this, filters, inputPaths).execute() } + + override def prepareJobForWrite(job: Job): OutputWriterFactory = { + new OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new OrcOutputWriter(path, dataSchema, context) + } + } + } +} + +private[orc] case class OrcTableScan( + attributes: Seq[Attribute], + @transient relation: OrcRelation, + filters: Array[Filter], + inputPaths: Array[String]) extends Logging { + @transient private val sqlContext = relation.sqlContext + + private def addColumnIds( + output: Seq[Attribute], + relation: OrcRelation, + conf: Configuration): Unit = { + val ids = output.map(a => relation.dataSchema.fieldIndex(a.name): Integer) + val (sortedIds, sortedNames) = ids.zip(attributes.map(_.name)).sorted.unzip + HiveShim.appendReadColumns(conf, sortedIds, sortedNames) + } + + // Transform all given raw `Writable`s into `Row`s. 
+ private def fillObject( + path: String, + conf: Configuration, + iterator: Iterator[Writable], + nonPartitionKeyAttrs: Seq[(Attribute, Int)], + mutableRow: MutableRow): Iterator[Row] = { + val deserializer = new OrcSerde + val soi = OrcFileOperator.getObjectInspector(path, Some(conf)) + val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { + case (attr, ordinal) => + soi.getStructFieldRef(attr.name.toLowerCase) -> ordinal + }.unzip + val unwrappers = HadoopTypeConverter.unwrappers(fieldRefs) + // Map each tuple to a row object + iterator.map { value => + val raw = deserializer.deserialize(value) + var i = 0 + while (i < fieldRefs.length) { + val fieldValue = soi.getStructFieldData(raw, fieldRefs(i)) + if (fieldValue == null) { + mutableRow.setNullAt(fieldOrdinals(i)) + } else { + unwrappers(i)(fieldValue, mutableRow, fieldOrdinals(i)) + } + i += 1 + } + mutableRow: Row + } + } + + def execute(): RDD[Row] = { + val job = new Job(sqlContext.sparkContext.hadoopConfiguration) + val conf = job.getConfiguration + + // Tries to push down filters if ORC filter push-down is enabled + if (sqlContext.conf.orcFilterPushDown) { + OrcFilters.createFilter(filters).foreach { f => + conf.set(SARG_PUSHDOWN, f.toKryo) + conf.setBoolean(ConfVars.HIVEOPTINDEXFILTER.varname, true) + } + } + + // Sets requested columns + addColumnIds(attributes, relation, conf) + + if (inputPaths.nonEmpty) { + FileInputFormat.setInputPaths(job, inputPaths.map(new Path(_)): _*) + } + + val inputFormatClass = + classOf[OrcInputFormat] + .asInstanceOf[Class[_ <: MapRedInputFormat[NullWritable, Writable]]] + + val rdd = sqlContext.sparkContext.hadoopRDD( + conf.asInstanceOf[JobConf], + inputFormatClass, + classOf[NullWritable], + classOf[Writable] + ).asInstanceOf[HadoopRDD[NullWritable, Writable]] + + val wrappedConf = new SerializableWritable(conf) + + rdd.mapPartitionsWithInputSplit { case (split: OrcSplit, iterator) => + val mutableRow = new SpecificMutableRow(attributes.map(_.dataType)) + fillObject( + split.getPath.toString, + wrappedConf.value, + iterator.map(_._2), + attributes.zipWithIndex, + mutableRow) + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcTableOperations.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcTableOperations.scala deleted file mode 100644 index 2163b0ce70e99..0000000000000 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcTableOperations.scala +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.hive.orc - -import java.util._ -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path -import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.ql.io.orc._ -import org.apache.hadoop.io.NullWritable -import org.apache.hadoop.io.Writable -import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapreduce.Job -import org.apache.hadoop.mapreduce.lib.input.FileInputFormat - -import org.apache.spark.rdd.{HadoopRDD, RDD} -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.hive.HiveShim -import org.apache.spark.sql.sources.Filter -import org.apache.spark.{Logging, SerializableWritable} - -/* Implicit conversions */ -import scala.collection.JavaConversions._ - -private[orc] case class OrcTableScan(attributes: Seq[Attribute], - @transient relation: OrcRelation, - filters: Array[Filter], - inputPaths: Array[String]) extends Logging { - @transient private val sqlContext = relation.sqlContext - - private def addColumnIds( - output: Seq[Attribute], - relation: OrcRelation, - conf: Configuration): Unit = { - val ids = output.map(a => relation.dataSchema.fieldIndex(a.name): Integer) - val (sortedIds, sortedNames) = ids.zip(attributes.map(_.name)).sorted.unzip - HiveShim.appendReadColumns(conf, sortedIds, sortedNames) - } - - private def buildFilter(job: Job, filters: Array[Filter]): Unit = { - if (ORC_FILTER_PUSHDOWN_ENABLED) { - val conf: Configuration = job.getConfiguration - OrcFilters.createFilter(filters).foreach { f => - conf.set(SARG_PUSHDOWN, toKryo(f)) - conf.setBoolean(ConfVars.HIVEOPTINDEXFILTER.varname, true) - } - } - } - - // Transform all given raw `Writable`s into `Row`s. - private def fillObject( - path: String, - conf: Configuration, - iterator: Iterator[org.apache.hadoop.io.Writable], - nonPartitionKeyAttrs: Seq[(Attribute, Int)], - mutableRow: MutableRow): Iterator[Row] = { - val deserializer = new OrcSerde - val soi = OrcFileOperator.getObjectInspector(path, Some(conf)) - val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { - case (attr, ordinal) => - soi.getStructFieldRef(attr.name.toLowerCase) -> ordinal - }.unzip - val unwrappers = HadoopTypeConverter.unwrappers(fieldRefs) - // Map each tuple to a row object - iterator.map { value => - val raw = deserializer.deserialize(value) - var i = 0 - while (i < fieldRefs.length) { - val fieldValue = soi.getStructFieldData(raw, fieldRefs(i)) - if (fieldValue == null) { - mutableRow.setNullAt(fieldOrdinals(i)) - } else { - unwrappers(i)(fieldValue, mutableRow, fieldOrdinals(i)) - } - i += 1 - } - mutableRow: Row - } - } - - def execute(): RDD[Row] = { - val sc = sqlContext.sparkContext - val job = new Job(sc.hadoopConfiguration) - val conf: Configuration = job.getConfiguration - - buildFilter(job, filters) - addColumnIds(attributes, relation, conf) - FileInputFormat.setInputPaths(job, inputPaths.map(new Path(_)): _*) - - val inputClass = classOf[OrcInputFormat].asInstanceOf[ - Class[_ <: org.apache.hadoop.mapred.InputFormat[NullWritable, Writable]]] - - val rdd = sc.hadoopRDD(conf.asInstanceOf[JobConf], - inputClass, classOf[NullWritable], classOf[Writable]) - .asInstanceOf[HadoopRDD[NullWritable, Writable]] - val wrappedConf = new SerializableWritable(conf) - val rowRdd: RDD[Row] = rdd.mapPartitionsWithInputSplit { case (split: OrcSplit, iter) => - val pathStr = split.getPath.toString - val mutableRow = new SpecificMutableRow(attributes.map(_.dataType)) - fillObject(pathStr, wrappedConf.value, iter.map(_._2), 
attributes.zipWithIndex, mutableRow) - } - rowRdd - } -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala index b219fbb44ca0d..869c8a5b8f1db 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/package.scala @@ -17,14 +17,10 @@ package org.apache.spark.sql.hive -import com.esotericsoftware.kryo.io.Output -import com.esotericsoftware.kryo.Kryo -import org.apache.commons.codec.binary.Base64 -import org.apache.spark.sql.{SaveMode, DataFrame} +import org.apache.spark.sql.{DataFrame, SaveMode} package object orc { implicit class OrcContext(sqlContext: HiveContext) { - import sqlContext._ @scala.annotation.varargs def orcFile(path: String, paths: String*): DataFrame = { val pathArray: Array[String] = { @@ -34,29 +30,21 @@ package object orc { paths.toArray ++ Array(path) } } + val orcRelation = OrcRelation(pathArray, Map.empty)(sqlContext) sqlContext.baseRelationToDataFrame(orcRelation) } } - implicit class OrcSchemaRDD(dataFrame: DataFrame) { + implicit class OrcDataFrame(dataFrame: DataFrame) { def saveAsOrcFile(path: String, mode: SaveMode = SaveMode.Overwrite): Unit = { - dataFrame.save( - path, - source = classOf[DefaultSource].getCanonicalName, - mode) + dataFrame.save(path, source = classOf[DefaultSource].getCanonicalName, mode) } } // Flags for orc copression, predicates pushdown, etc. val orcDefaultCompressVar = "hive.exec.orc.default.compress" - var ORC_FILTER_PUSHDOWN_ENABLED = true - val SARG_PUSHDOWN = "sarg.pushdown" - def toKryo(input: Any): String = { - val out = new Output(4 * 1024, 10 * 1024 * 1024); - new Kryo().writeObject(out, input); - out.close(); - Base64.encodeBase64String(out.toBytes()); - } + // This constant duplicates `OrcInputFormat.SARG_PUSHDOWN`, which is unfortunately not public. + val SARG_PUSHDOWN = "sarg.pushdown" } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/NewOrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/NewOrcQuerySuite.scala new file mode 100644 index 0000000000000..7e326de1335e0 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/NewOrcQuerySuite.scala @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.orc + +import java.io.File + +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql._ + +private[sql] trait OrcTest extends SQLTestUtils { + protected def hiveContext = sqlContext.asInstanceOf[HiveContext] + + import sqlContext.sparkContext + import sqlContext.implicits._ + + /** + * Writes `data` to a Orc file, which is then passed to `f` and will be deleted after `f` + * returns. + */ + protected def withOrcFile[T <: Product: ClassTag: TypeTag] + (data: Seq[T]) + (f: String => Unit): Unit = { + withTempPath { file => + sparkContext.parallelize(data).toDF().saveAsOrcFile(file.getCanonicalPath) + f(file.getCanonicalPath) + } + } + + /** + * Writes `data` to a Orc file and reads it back as a [[DataFrame]], + * which is then passed to `f`. The Orc file will be deleted after `f` returns. + */ + protected def withOrcDataFrame[T <: Product: ClassTag: TypeTag] + (data: Seq[T]) + (f: DataFrame => Unit): Unit = { + import org.apache.spark.sql.hive.orc.OrcContext + withOrcFile(data)(path => f(hiveContext.orcFile(path))) + } + + /** + * Writes `data` to a Orc file, reads it back as a [[DataFrame]] and registers it as a + * temporary table named `tableName`, then call `f`. The temporary table together with the + * Orc file will be dropped/deleted after `f` returns. + */ + protected def withOrcTable[T <: Product: ClassTag: TypeTag] + (data: Seq[T], tableName: String) + (f: => Unit): Unit = { + withOrcDataFrame(data) { df => + hiveContext.registerDataFrameAsTable(df, tableName) + withTempTable(tableName)(f) + } + } + + protected def makeOrcFile[T <: Product: ClassTag: TypeTag]( + data: Seq[T], path: File): Unit = { + data.toDF().save(path.getCanonicalPath, "org.apache.spark.sql.orc", SaveMode.Overwrite) + } + + protected def makeOrcFile[T <: Product: ClassTag: TypeTag]( + df: DataFrame, path: File): Unit = { + df.save(path.getCanonicalPath, "org.apache.spark.sql.orc", SaveMode.Overwrite) + } +} + +class NewOrcQuerySuite extends QueryTest with OrcTest { + override val sqlContext: SQLContext = TestHive + + import sqlContext._ + + test("simple select queries") { + withOrcTable((0 until 10).map(i => (i, i.toString)), "t") { + checkAnswer( + sql("SELECT `_1` FROM t where t.`_1` > 5"), + (6 until 10).map(Row.apply(_))) + + checkAnswer( + sql("SELECT `_1` FROM t as tmp where tmp.`_1` < 5"), + (0 until 5).map(Row.apply(_))) + } + } + + test("appending") { + val data = (0 until 10).map(i => (i, i.toString)) + createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + withOrcTable(data, "t") { + sql("INSERT INTO TABLE t SELECT * FROM tmp") + checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) + } + catalog.unregisterTable(Seq("tmp")) + } + + test("overwriting") { + val data = (0 until 10).map(i => (i, i.toString)) + createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + withOrcTable(data, "t") { + sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") + checkAnswer(table("t"), data.map(Row.fromTuple)) + } + catalog.unregisterTable(Seq("tmp")) + } + + test("self-join") { + // 4 rows, cells of column 1 of row 2 and row 4 are null + val data = (1 to 4).map { i => + val maybeInt = if (i % 2 == 0) None else Some(i) + (maybeInt, i.toString) + } + + withOrcTable(data, "t") { + val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x.`_1` = y.`_1`") + val queryOutput = 
selfJoin.queryExecution.analyzed.output + + assertResult(4, "Field count mismatches")(queryOutput.size) + assertResult(2, "Duplicated expression ID in query plan:\n $selfJoin") { + queryOutput.filter(_.name == "_1").map(_.exprId).size + } + + checkAnswer(selfJoin, List(Row(1, "1", 1, "1"), Row(3, "3", 3, "3"))) + } + } + + test("nested data - struct with array field") { + val data = (1 to 10).map(i => Tuple1((i, Seq("val_$i")))) + withOrcTable(data, "t") { + checkAnswer(sql("SELECT `_1`.`_2`[0] FROM t"), data.map { + case Tuple1((_, Seq(string))) => Row(string) + }) + } + } + + test("nested data - array of struct") { + val data = (1 to 10).map(i => Tuple1(Seq(i -> "val_$i"))) + withOrcTable(data, "t") { + checkAnswer(sql("SELECT `_1`[0].`_2` FROM t"), data.map { + case Tuple1(Seq((_, string))) => Row(string) + }) + } + } + + test("columns only referenced by pushed down filters should remain") { + withOrcTable((1 to 10).map(Tuple1.apply), "t") { + checkAnswer(sql("SELECT `_1` FROM t WHERE `_1` < 10"), (1 to 9).map(Row.apply(_))) + } + } + + test("SPARK-5309 strings stored using dictionary compression in orc") { + withOrcTable((0 until 1000).map(i => ("same", "run_" + i / 100, 1)), "t") { + checkAnswer( + sql("SELECT `_1`, `_2`, SUM(`_3`) FROM t GROUP BY `_1`, `_2`"), + (0 until 10).map(i => Row("same", "run_" + i, 100))) + + checkAnswer( + sql("SELECT `_1`, `_2`, SUM(`_3`) FROM t WHERE `_2` = 'run_5' GROUP BY `_1`, `_2`"), + List(Row("same", "run_5", 100))) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala similarity index 94% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcRelationSuite.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index 1d8c421b90678..90812b03fd2e6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -20,11 +20,10 @@ package org.apache.spark.sql.hive.orc import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.sources.{FSBasedRelationTest} +import org.apache.spark.sql.sources.HadoopFsRelationTest import org.apache.spark.sql.types._ - -class FSBasedOrcRelationSuite extends FSBasedRelationTest { +class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { override val dataSourceName: String = classOf[DefaultSource].getCanonicalName import sqlContext._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala index 31a829a81124d..55d8b8c71d9ef 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala @@ -37,6 +37,7 @@ case class OrcParData(intField: Int, stringField: String) // The data that also includes the partitioning key case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) +// TODO This test suite duplicates ParquetPartitionDiscoverySuite a lot class OrcPartitionDiscoverySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { val defaultPartitionName = ConfVars.DEFAULTPARTITIONNAME.defaultVal diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 475af3d4c94e4..3d52c31eca9f7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -112,15 +112,16 @@ class OrcQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { test("Read/Write All Types with non-primitive type") { val tempDir = getTempFilePath("orcTest").getCanonicalPath - val range = (0 to 255) - val data = sparkContext.parallelize(range) - .map(x => AllDataTypesWithNonPrimitiveType( - s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0, - (0 until x), - (0 until x).map(Option(_).filter(_ % 3 == 0)), - (0 until x).map(i => i -> i.toLong).toMap, - (0 until x).map(i => i -> Option(i.toLong)).toMap + (x -> None), - Data((0 until x), Nested(x, s"$x")))) + val range = 0 to 255 + val data = sparkContext.parallelize(range).map { x => + AllDataTypesWithNonPrimitiveType( + s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0, + 0 until x, + (0 until x).map(Option(_).filter(_ % 3 == 0)), + (0 until x).map(i => i -> i.toLong).toMap, + (0 until x).map(i => i -> Option(i.toLong)).toMap + (x -> None), + Data(0 until x, Nested(x, s"$x"))) + } data.toDF().saveAsOrcFile(tempDir) checkAnswer( @@ -204,11 +205,11 @@ class OrcQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { // We only support zlib in hive0.12.0 now test("Default Compression options for writing to an Orcfile") { // TODO: support other compress codec - var tempDir = getTempFilePath("orcTest").getCanonicalPath - val rdd = sparkContext.parallelize((1 to 100)) + val tempDir = getTempFilePath("orcTest").getCanonicalPath + val rdd = sparkContext.parallelize(1 to 100) .map(i => TestRDDEntry(i, s"val_$i")) rdd.toDF().saveAsOrcFile(tempDir) - var actualCodec = OrcFileOperator.getFileReader(tempDir).getCompression + val actualCodec = OrcFileOperator.getFileReader(tempDir).getCompression assert(actualCodec == CompressionKind.ZLIB) Utils.deleteRecursively(new File(tempDir)) } @@ -217,7 +218,7 @@ class OrcQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { ignore("Other Compression options for writing to an Orcfile - 0.13.1 and above") { TestHive.sparkContext.hadoopConfiguration.set(orcDefaultCompressVar, "SNAPPY") var tempDir = getTempFilePath("orcTest").getCanonicalPath - val rdd = sparkContext.parallelize((1 to 100)) + val rdd = sparkContext.parallelize(1 to 100) .map(i => TestRDDEntry(i, s"val_$i")) rdd.toDF().saveAsOrcFile(tempDir) var actualCodec = OrcFileOperator.getFileReader(tempDir).getCompression diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index f44b3c521e647..082933e0390f3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -23,12 +23,10 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.parquet.ParquetTest +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ -// TODO Don't extend ParquetTest -// This test suite extends ParquetTest for some convenient utility methods. 
These methods should be -// moved to some more general places, maybe QueryTest. -class HadoopFsRelationTest extends QueryTest with ParquetTest { +abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { override val sqlContext: SQLContext = TestHive import sqlContext._
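
For reference, a minimal end-to-end sketch of how the pieces introduced by this patch are expected to be used together: the `spark.sql.orc.filterPushdown` flag added to SQLConf (off by default) and the `orcFile` / `saveAsOrcFile` implicits from the `orc` package object. The SparkContext `sc`, the column names, and the paths below are illustrative assumptions, not part of the patch.

import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.hive.orc._   // brings the orcFile / saveAsOrcFile implicits into scope

// Assumes an existing SparkContext `sc`.
val hiveContext = new HiveContext(sc)
import hiveContext.implicits._

// ORC filter push-down is disabled by default; opt in explicitly.
hiveContext.setConf("spark.sql.orc.filterPushdown", "true")

val df = sc.parallelize(1 to 100).map(i => (i, s"val_$i")).toDF("key", "value")
df.saveAsOrcFile("/tmp/orc-demo", SaveMode.Overwrite)

// With push-down enabled, a convertible predicate such as `key < 10` is translated by
// OrcFilters into a SearchArgument and handed to the ORC reader in OrcTableScan.
hiveContext.orcFile("/tmp/orc-demo").filter("key < 10").show()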
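
The partial push-down behaviour documented in `buildSearchArgument` (for `left AND right`, a convertible side is still pushed down even when the other side is rejected) can be illustrated with a small sketch. Since `OrcFilters` is `private[orc]`, this hypothetical object would have to live in the same package, e.g. next to the test suites; the attribute names and literals are illustrative.

package org.apache.spark.sql.hive.orc

import org.apache.spark.sql.sources._

object OrcFiltersPushDownExample {
  def main(args: Array[String]): Unit = {
    // Both sides convertible: the whole conjunction becomes a single SearchArgument.
    assert(OrcFilters.createFilter(Array(LessThan("a", 10), IsNotNull("b"))).isDefined)

    // Boolean literals are not in the `isSearchableLiteral` whitelist, so this filter
    // alone falls through to `case _ => None` at build time.
    assert(OrcFilters.createFilter(Array(EqualTo("b", true))).isEmpty)

    // In `a < 10 AND b = true`, the convertible `LessThan` side is still pushed down
    // via the `orElse` fallbacks in the `And` case.
    assert(OrcFilters.createFilter(Array(And(LessThan("a", 10), EqualTo("b", true)))).isDefined)
  }
}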
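
The extracted SQLTestUtils helpers are designed to nest: `withSQLConf` scopes a configuration change, `withTempPath` supplies a scratch location, and `withTempTable` cleans up a registered table. A sketch of a hypothetical suite composing them with the new ORC implicits follows; the suite name, data, and query are illustrative, not part of the patch.

package org.apache.spark.sql.hive.orc

import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.{QueryTest, Row, SQLContext}

class OrcPushDownExampleSuite extends QueryTest with SQLTestUtils {
  override val sqlContext: SQLContext = TestHive

  import sqlContext.implicits._
  import sqlContext.sql

  test("round trip with ORC filter push-down enabled") {
    withSQLConf("spark.sql.orc.filterPushdown" -> "true") {
      withTempPath { dir =>
        val path = dir.getCanonicalPath

        // Write a small ORC file, read it back, and query it through a temporary table.
        sqlContext.sparkContext
          .parallelize(1 to 10)
          .map(i => (i, s"val_$i"))
          .toDF("a", "b")
          .saveAsOrcFile(path)

        TestHive.orcFile(path).registerTempTable("t")

        withTempTable("t") {
          checkAnswer(sql("SELECT b FROM t WHERE a = 1"), Seq(Row("val_1")))
        }
      }
    }
  }
}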