From ae4a6971e55837ef1a5a7ef7cdc2a0086695b46b Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 23 Feb 2017 10:25:18 -0800 Subject: [PATCH] [SPARK-19459] Support for nested char/varchar fields in ORC ## What changes were proposed in this pull request? This PR is a small follow-up on https://github.com/apache/spark/pull/16804. This PR also adds support for nested char/varchar fields in orc. ## How was this patch tested? I have added a regression test to the OrcSourceSuite. Author: Herman van Hovell Closes #17030 from hvanhovell/SPARK-19459-follow-up. --- .../sql/catalyst/parser/AstBuilder.scala | 34 +++++---- .../spark/sql/types/HiveStringType.scala | 73 +++++++++++++++++++ .../spark/sql/hive/orc/OrcSourceSuite.scala | 12 ++- 3 files changed, 100 insertions(+), 19 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 926a37b363f1b..d2e091f4dda69 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -76,7 +76,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) { - visit(ctx.dataType).asInstanceOf[DataType] + visitSparkDataType(ctx.dataType) } /* ******************************************************************************************** @@ -1006,7 +1006,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * Create a [[Cast]] expression. 
*/ override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) { - Cast(expression(ctx.expression), typedVisit(ctx.dataType)) + Cast(expression(ctx.expression), visitSparkDataType(ctx.dataType)) } /** @@ -1424,6 +1424,13 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { /* ******************************************************************************************** * DataType parsing * ******************************************************************************************** */ + /** + * Create a Spark DataType. + */ + private def visitSparkDataType(ctx: DataTypeContext): DataType = { + HiveStringType.replaceCharType(typedVisit(ctx)) + } + /** * Resolve/create a primitive type. */ @@ -1438,8 +1445,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { case ("double", Nil) => DoubleType case ("date", Nil) => DateType case ("timestamp", Nil) => TimestampType - case ("char" | "varchar" | "string", Nil) => StringType - case ("char" | "varchar", _ :: Nil) => StringType + case ("string", Nil) => StringType + case ("char", length :: Nil) => CharType(length.getText.toInt) + case ("varchar", length :: Nil) => VarcharType(length.getText.toInt) case ("binary", Nil) => BinaryType case ("decimal", Nil) => DecimalType.USER_DEFAULT case ("decimal", precision :: Nil) => DecimalType(precision.getText.toInt, 0) @@ -1461,7 +1469,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { case SqlBaseParser.MAP => MapType(typedVisit(ctx.dataType(0)), typedVisit(ctx.dataType(1))) case SqlBaseParser.STRUCT => - createStructType(ctx.complexColTypeList()) + StructType(Option(ctx.complexColTypeList).toSeq.flatMap(visitComplexColTypeList)) } } @@ -1480,7 +1488,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** - * Create a [[StructField]] from a column definition. + * Create a top level [[StructField]] from a column definition. 
*/ override def visitColType(ctx: ColTypeContext): StructField = withOrigin(ctx) { import ctx._ @@ -1491,19 +1499,15 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { builder.putString("comment", string(STRING)) } // Add Hive type string to metadata. - dataType match { - case p: PrimitiveDataTypeContext => - p.identifier.getText.toLowerCase match { - case "varchar" | "char" => - builder.putString(HIVE_TYPE_STRING, dataType.getText.toLowerCase) - case _ => - } - case _ => + val rawDataType = typedVisit[DataType](ctx.dataType) + val cleanedDataType = HiveStringType.replaceCharType(rawDataType) + if (rawDataType != cleanedDataType) { + builder.putString(HIVE_TYPE_STRING, rawDataType.catalogString) } StructField( identifier.getText, - typedVisit(dataType), + cleanedDataType, nullable = true, builder.build()) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala new file mode 100644 index 0000000000000..b319eb70bc13c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.types + +import scala.math.Ordering +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.sql.catalyst.ScalaReflectionLock +import org.apache.spark.unsafe.types.UTF8String + +/** + * A hive string type for compatibility. These datatypes should only be used for parsing, + * and should NOT be used anywhere else. Any instance of these data types should be + * replaced by a [[StringType]] before analysis. + */ +sealed abstract class HiveStringType extends AtomicType { + private[sql] type InternalType = UTF8String + + private[sql] val ordering = implicitly[Ordering[InternalType]] + + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { + typeTag[InternalType] + } + + override def defaultSize: Int = length + + private[spark] override def asNullable: HiveStringType = this + + def length: Int +} + +object HiveStringType { + def replaceCharType(dt: DataType): DataType = dt match { + case ArrayType(et, nullable) => + ArrayType(replaceCharType(et), nullable) + case MapType(kt, vt, nullable) => + MapType(replaceCharType(kt), replaceCharType(vt), nullable) + case StructType(fields) => + StructType(fields.map { field => + field.copy(dataType = replaceCharType(field.dataType)) + }) + case _: HiveStringType => StringType + case _ => dt + } +} + +/** + * Hive char type. + */ +case class CharType(length: Int) extends HiveStringType { + override def simpleString: String = s"char($length)" +} + +/** + * Hive varchar type.
+ */ +case class VarcharType(length: Int) extends HiveStringType { + override def simpleString: String = s"varchar($length)" +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 59ea8916efae9..11dda5425cf94 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -162,13 +162,16 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA |CREATE EXTERNAL TABLE hive_orc( | a STRING, | b CHAR(10), - | c VARCHAR(10)) + | c VARCHAR(10), + | d ARRAY) |STORED AS orc""".stripMargin) // Hive throws an exception if I assign the location in the create table statement. hiveClient.runSqlHive( s"ALTER TABLE hive_orc SET LOCATION '$uri'") hiveClient.runSqlHive( - "INSERT INTO TABLE hive_orc SELECT 'a', 'b', 'c' FROM (SELECT 1) t") + """INSERT INTO TABLE hive_orc + |SELECT 'a', 'b', 'c', ARRAY(CAST('d' AS CHAR(3))) + |FROM (SELECT 1) t""".stripMargin) // We create a different table in Spark using the same schema which points to // the same location. @@ -177,10 +180,11 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA |CREATE EXTERNAL TABLE spark_orc( | a STRING, | b CHAR(10), - | c VARCHAR(10)) + | c VARCHAR(10), + | d ARRAY) |STORED AS orc |LOCATION '$uri'""".stripMargin) - val result = Row("a", "b ", "c") + val result = Row("a", "b ", "c", Seq("d ")) checkAnswer(spark.table("hive_orc"), result) checkAnswer(spark.table("spark_orc"), result) } finally {