Skip to content

Commit

Permalink
[SPARK-19459] Support for nested char/varchar fields in ORC
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?
This PR is a small follow-up to apache#16804. It also adds support for nested char/varchar fields in ORC.

## How was this patch tested?
I have added a regression test to the OrcSourceSuite.

Author: Herman van Hovell <[email protected]>

Closes apache#17030 from hvanhovell/SPARK-19459-follow-up.
  • Loading branch information
hvanhovell authored and Yun Ni committed Feb 27, 2017
1 parent a0ce01e commit ae4a697
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
}

override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) {
visit(ctx.dataType).asInstanceOf[DataType]
visitSparkDataType(ctx.dataType)
}

/* ********************************************************************************************
Expand Down Expand Up @@ -1006,7 +1006,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
* Create a [[Cast]] expression.
*/
override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) {
Cast(expression(ctx.expression), typedVisit(ctx.dataType))
Cast(expression(ctx.expression), visitSparkDataType(ctx.dataType))
}

/**
Expand Down Expand Up @@ -1424,6 +1424,13 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
/* ********************************************************************************************
* DataType parsing
* ******************************************************************************************** */
/**
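* Create a Spark DataType: parse the type and replace every Hive char/varchar type, including
* those nested inside arrays, maps and structs, with a StringType.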
*/
private def visitSparkDataType(ctx: DataTypeContext): DataType = {
HiveStringType.replaceCharType(typedVisit(ctx))
}

/**
* Resolve/create a primitive type.
*/
Expand All @@ -1438,8 +1445,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
case ("double", Nil) => DoubleType
case ("date", Nil) => DateType
case ("timestamp", Nil) => TimestampType
case ("char" | "varchar" | "string", Nil) => StringType
case ("char" | "varchar", _ :: Nil) => StringType
case ("string", Nil) => StringType
case ("char", length :: Nil) => CharType(length.getText.toInt)
case ("varchar", length :: Nil) => VarcharType(length.getText.toInt)
case ("binary", Nil) => BinaryType
case ("decimal", Nil) => DecimalType.USER_DEFAULT
case ("decimal", precision :: Nil) => DecimalType(precision.getText.toInt, 0)
Expand All @@ -1461,7 +1469,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
case SqlBaseParser.MAP =>
MapType(typedVisit(ctx.dataType(0)), typedVisit(ctx.dataType(1)))
case SqlBaseParser.STRUCT =>
createStructType(ctx.complexColTypeList())
StructType(Option(ctx.complexColTypeList).toSeq.flatMap(visitComplexColTypeList))
}
}

Expand All @@ -1480,7 +1488,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
}

/**
* Create a [[StructField]] from a column definition.
* Create a top level [[StructField]] from a column definition.
*/
override def visitColType(ctx: ColTypeContext): StructField = withOrigin(ctx) {
import ctx._
Expand All @@ -1491,19 +1499,15 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
builder.putString("comment", string(STRING))
}
// Add Hive type string to metadata.
dataType match {
case p: PrimitiveDataTypeContext =>
p.identifier.getText.toLowerCase match {
case "varchar" | "char" =>
builder.putString(HIVE_TYPE_STRING, dataType.getText.toLowerCase)
case _ =>
}
case _ =>
val rawDataType = typedVisit[DataType](ctx.dataType)
val cleanedDataType = HiveStringType.replaceCharType(rawDataType)
if (rawDataType != cleanedDataType) {
builder.putString(HIVE_TYPE_STRING, rawDataType.catalogString)
}

StructField(
identifier.getText,
typedVisit(dataType),
cleanedDataType,
nullable = true,
builder.build())
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.types

import scala.math.Ordering
import scala.reflect.runtime.universe.typeTag

import org.apache.spark.sql.catalyst.ScalaReflectionLock
import org.apache.spark.unsafe.types.UTF8String

/**
* A Hive string type for compatibility. These data types should only be used for parsing,
* and should NOT be used anywhere else. Any instance of these data types should be
* replaced by a [[StringType]] before analysis.
*/
sealed abstract class HiveStringType extends AtomicType {
  /** The length declared for this type; each concrete subclass supplies it. */
  def length: Int

  // Values of this type are represented internally as UTF8String.
  private[sql] type InternalType = UTF8String

  // Resolve the implicit ordering available for the internal representation.
  private[sql] val ordering = Ordering[InternalType]

  // The type tag is materialized lazily while holding ScalaReflectionLock.
  @transient private[sql] lazy val tag =
    ScalaReflectionLock.synchronized(typeTag[InternalType])

  /** The default size reported for a value of this type is its declared length. */
  override def defaultSize: Int = length

  /** There is nothing to relax for this type, so the instance itself is returned. */
  private[spark] override def asNullable: HiveStringType = this
}

object HiveStringType {
  /**
   * Recursively rewrite `dt`, substituting [[StringType]] for every [[HiveStringType]] found.
   *
   * Array element types, map key and value types, and struct field types are traversed; the
   * flags carried by arrays and maps, and all other struct field attributes, are copied over
   * unchanged. Any other data type is returned as is.
   *
   * @param dt the data type to clean
   * @return the data type with all Hive string types replaced by [[StringType]]
   */
  def replaceCharType(dt: DataType): DataType = dt match {
    // Leaf case: this is the type being eliminated.
    case _: HiveStringType =>
      StringType
    case ArrayType(elementType, flag) =>
      ArrayType(replaceCharType(elementType), flag)
    case MapType(keyType, valueType, flag) =>
      MapType(replaceCharType(keyType), replaceCharType(valueType), flag)
    case StructType(members) =>
      // Only the data type of each field changes.
      val cleaned = members.map(m => m.copy(dataType = replaceCharType(m.dataType)))
      StructType(cleaned)
    case other =>
      other
  }
}

/**
 * Hive char type.
 *
 * @param length the length declared for the type, as in `char(length)`
 */
case class CharType(length: Int) extends HiveStringType {
  // Rendered as "char(<length>)".
  override def simpleString: String = "char(" + length + ")"
}

/**
 * Hive varchar type.
 *
 * @param length the length declared for the type, as in `varchar(length)`
 */
case class VarcharType(length: Int) extends HiveStringType {
  // Rendered as "varchar(<length>)".
  override def simpleString: String = "varchar(" + length + ")"
}
Original file line number Diff line number Diff line change
Expand Up @@ -162,13 +162,16 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA
|CREATE EXTERNAL TABLE hive_orc(
| a STRING,
| b CHAR(10),
| c VARCHAR(10))
| c VARCHAR(10),
| d ARRAY<CHAR(3)>)
|STORED AS orc""".stripMargin)
// Hive throws an exception if I assign the location in the create table statement.
hiveClient.runSqlHive(
s"ALTER TABLE hive_orc SET LOCATION '$uri'")
hiveClient.runSqlHive(
"INSERT INTO TABLE hive_orc SELECT 'a', 'b', 'c' FROM (SELECT 1) t")
"""INSERT INTO TABLE hive_orc
|SELECT 'a', 'b', 'c', ARRAY(CAST('d' AS CHAR(3)))
|FROM (SELECT 1) t""".stripMargin)

// We create a different table in Spark using the same schema which points to
// the same location.
Expand All @@ -177,10 +180,11 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA
|CREATE EXTERNAL TABLE spark_orc(
| a STRING,
| b CHAR(10),
| c VARCHAR(10))
| c VARCHAR(10),
| d ARRAY<CHAR(3)>)
|STORED AS orc
|LOCATION '$uri'""".stripMargin)
val result = Row("a", "b ", "c")
val result = Row("a", "b ", "c", Seq("d "))
checkAnswer(spark.table("hive_orc"), result)
checkAnswer(spark.table("spark_orc"), result)
} finally {
Expand Down

0 comments on commit ae4a697

Please sign in to comment.