Skip to content

Commit

Permalink
[SPARK-14387][SPARK-16628][SPARK-18355][SQL] Use Spark schema to read…
Browse files Browse the repository at this point in the history
… ORC table instead of ORC file schema

Before Hive 2.0, ORC File schema has invalid column names like `_col1` and `_col2`. This is a well-known limitation and there are several Apache Spark issues with `spark.sql.hive.convertMetastoreOrc=true`. This PR ignores ORC File schema and use Spark schema.

Pass the newly added test case.

Author: Dongjoon Hyun <[email protected]>

Closes #19470 from dongjoon-hyun/SPARK-18355.

(cherry picked from commit e6e3600)
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
dongjoon-hyun authored and cloud-fan committed Oct 13, 2017
1 parent c9187db commit 30d5c9f
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,11 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable
// SPARK-8501: Empty ORC files always have an empty schema stored in their footer. In this
// case, `OrcFileOperator.readSchema` returns `None`, and we can't read the underlying file
// using the given physical schema. Instead, we simply return an empty iterator.
val maybePhysicalSchema = OrcFileOperator.readSchema(Seq(file.filePath), Some(conf))
if (maybePhysicalSchema.isEmpty) {
val isEmptyFile = OrcFileOperator.readSchema(Seq(file.filePath), Some(conf)).isEmpty
if (isEmptyFile) {
Iterator.empty
} else {
val physicalSchema = maybePhysicalSchema.get
OrcRelation.setRequiredColumns(conf, physicalSchema, requiredSchema)
OrcRelation.setRequiredColumns(conf, dataSchema, requiredSchema)

val orcRecordReader = {
val job = Job.getInstance(conf)
Expand All @@ -163,6 +162,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable
// Unwraps `OrcStruct`s to `UnsafeRow`s
OrcRelation.unwrapOrcStructs(
conf,
dataSchema,
requiredSchema,
Some(orcRecordReader.getObjectInspector.asInstanceOf[StructObjectInspector]),
recordsIterator)
Expand Down Expand Up @@ -272,25 +272,32 @@ private[orc] object OrcRelation extends HiveInspectors {
def unwrapOrcStructs(
conf: Configuration,
dataSchema: StructType,
requiredSchema: StructType,
maybeStructOI: Option[StructObjectInspector],
iterator: Iterator[Writable]): Iterator[InternalRow] = {
val deserializer = new OrcSerde
val mutableRow = new SpecificInternalRow(dataSchema.map(_.dataType))
val unsafeProjection = UnsafeProjection.create(dataSchema)
val mutableRow = new SpecificInternalRow(requiredSchema.map(_.dataType))
val unsafeProjection = UnsafeProjection.create(requiredSchema)

def unwrap(oi: StructObjectInspector): Iterator[InternalRow] = {
val (fieldRefs, fieldOrdinals) = dataSchema.zipWithIndex.map {
case (field, ordinal) => oi.getStructFieldRef(field.name) -> ordinal
val (fieldRefs, fieldOrdinals) = requiredSchema.zipWithIndex.map {
case (field, ordinal) =>
var ref = oi.getStructFieldRef(field.name)
if (ref == null) {
ref = oi.getStructFieldRef("_col" + dataSchema.fieldIndex(field.name))
}
ref -> ordinal
}.unzip

val unwrappers = fieldRefs.map(unwrapperFor)
val unwrappers = fieldRefs.map(r => if (r == null) null else unwrapperFor(r))

iterator.map { value =>
val raw = deserializer.deserialize(value)
var i = 0
val length = fieldRefs.length
while (i < length) {
val fieldValue = oi.getStructFieldData(raw, fieldRefs(i))
val fieldRef = fieldRefs(i)
val fieldValue = if (fieldRef == null) null else oi.getStructFieldData(raw, fieldRef)
if (fieldValue == null) {
mutableRow.setNullAt(fieldOrdinals(i))
} else {
Expand All @@ -306,8 +313,8 @@ private[orc] object OrcRelation extends HiveInspectors {
}

def setRequiredColumns(
conf: Configuration, physicalSchema: StructType, requestedSchema: StructType): Unit = {
val ids = requestedSchema.map(a => physicalSchema.fieldIndex(a.name): Integer)
conf: Configuration, dataSchema: StructType, requestedSchema: StructType): Unit = {
val ids = requestedSchema.map(a => dataSchema.fieldIndex(a.name): Integer)
val (sortedIDs, sortedNames) = ids.zip(requestedSchema.fieldNames).sorted.unzip
HiveShim.appendReadColumns(conf, sortedIDs, sortedNames)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.HiveUtils
import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
Expand Down Expand Up @@ -2034,4 +2034,64 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
assert(setOfPath.size() == pathSizeToDeleteOnExit)
}
}

Seq("orc", "parquet").foreach { format =>
test(s"SPARK-18355 Read data from a hive table with a new column - $format") {
val client = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client

Seq("true", "false").foreach { value =>
withSQLConf(
HiveUtils.CONVERT_METASTORE_ORC.key -> value,
HiveUtils.CONVERT_METASTORE_PARQUET.key -> value) {
withTempDatabase { db =>
client.runSqlHive(
s"""
|CREATE TABLE $db.t(
| click_id string,
| search_id string,
| uid bigint)
|PARTITIONED BY (
| ts string,
| hour string)
|STORED AS $format
""".stripMargin)

client.runSqlHive(
s"""
|INSERT INTO TABLE $db.t
|PARTITION (ts = '98765', hour = '01')
|VALUES (12, 2, 12345)
""".stripMargin
)

checkAnswer(
sql(s"SELECT click_id, search_id, uid, ts, hour FROM $db.t"),
Row("12", "2", 12345, "98765", "01"))

client.runSqlHive(s"ALTER TABLE $db.t ADD COLUMNS (dummy string)")

checkAnswer(
sql(s"SELECT click_id, search_id FROM $db.t"),
Row("12", "2"))

checkAnswer(
sql(s"SELECT search_id, click_id FROM $db.t"),
Row("2", "12"))

checkAnswer(
sql(s"SELECT search_id FROM $db.t"),
Row("2"))

checkAnswer(
sql(s"SELECT dummy, click_id FROM $db.t"),
Row(null, "12"))

checkAnswer(
sql(s"SELECT click_id, search_id, uid, dummy, ts, hour FROM $db.t"),
Row("12", "2", 12345, null, "98765", "01"))
}
}
}
}
}
}

0 comments on commit 30d5c9f

Please sign in to comment.