diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
index ab013885b7086..49a9382037a84 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
@@ -26,7 +26,7 @@ private[sql] class DefaultSource extends SchemaRelationProvider {
   override def createRelation(
       sqlContext: SQLContext,
       parameters: Map[String, String],
-      schema: Option[StructType] = None): BaseRelation = {
+      schema: Option[StructType]): BaseRelation = {
     val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
     val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
index 6a4a41686eb28..506be8ccde6b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
@@ -48,7 +48,7 @@ class DefaultSource extends SchemaRelationProvider {
   override def createRelation(
       sqlContext: SQLContext,
       parameters: Map[String, String],
-      schema: Option[StructType] = None): BaseRelation = {
+      schema: Option[StructType]): BaseRelation = {
     val path =
       parameters.getOrElse("path", sys.error("'path' must be specified for parquet tables."))
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
index f5b72f3c4ca52..457cbbb39abd6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -64,14 +64,14 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
 
   // Data types.
   protected val STRING = Keyword("STRING")
-  protected val FLOAT = Keyword("FLOAT")
-  protected val INT = Keyword("INT")
+  protected val BINARY = Keyword("BINARY")
+  protected val BOOLEAN = Keyword("BOOLEAN")
   protected val TINYINT = Keyword("TINYINT")
   protected val SMALLINT = Keyword("SMALLINT")
-  protected val DOUBLE = Keyword("DOUBLE")
+  protected val INT = Keyword("INT")
   protected val BIGINT = Keyword("BIGINT")
-  protected val BINARY = Keyword("BINARY")
-  protected val BOOLEAN = Keyword("BOOLEAN")
+  protected val FLOAT = Keyword("FLOAT")
+  protected val DOUBLE = Keyword("DOUBLE")
   protected val DECIMAL = Keyword("DECIMAL")
   protected val DATE = Keyword("DATE")
   protected val TIMESTAMP = Keyword("TIMESTAMP")
@@ -105,8 +105,8 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
     CREATE ~ TEMPORARY ~ TABLE ~> ident ~ (tableCols).?
       ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ {
       case tableName ~ columns ~ provider ~ opts =>
-        val tblColumns = if(columns.isEmpty) Seq.empty else columns.get
-        CreateTableUsing(tableName, tblColumns, provider, opts)
+        val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields)))
+        CreateTableUsing(tableName, userSpecifiedSchema, provider, opts)
       }
   )
 
@@ -184,7 +184,7 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
 
 private[sql] case class CreateTableUsing(
     tableName: String,
-    tableCols: Seq[StructField],
+    userSpecifiedSchema: Option[StructType],
     provider: String,
     options: Map[String, String]) extends RunnableCommand {
 
@@ -203,16 +203,9 @@ private[sql] case class CreateTableUsing(
           .asInstanceOf[org.apache.spark.sql.sources.RelationProvider]
          .createRelation(sqlContext, new CaseInsensitiveMap(options))
       case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
-        if(tableCols.isEmpty) {
-          dataSource
-            .asInstanceOf[org.apache.spark.sql.sources.SchemaRelationProvider]
-            .createRelation(sqlContext, new CaseInsensitiveMap(options))
-        } else {
-          dataSource
-            .asInstanceOf[org.apache.spark.sql.sources.SchemaRelationProvider]
-            .createRelation(
-              sqlContext, new CaseInsensitiveMap(options), Some(StructType(tableCols)))
-        }
+        dataSource
+          .asInstanceOf[org.apache.spark.sql.sources.SchemaRelationProvider]
+          .createRelation(sqlContext, new CaseInsensitiveMap(options), userSpecifiedSchema)
     }
 
     sqlContext.baseRelationToSchemaRDD(relation).registerTempTable(tableName)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 1ad82ecbb6ee6..97157c868cc90 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -68,7 +68,7 @@ trait SchemaRelationProvider {
   def createRelation(
       sqlContext: SQLContext,
       parameters: Map[String, String],
-      schema: Option[StructType] = None): BaseRelation
+      schema: Option[StructType]): BaseRelation
 }
 
 /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/NewTableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/NewTableScanSuite.scala
deleted file mode 100644
index 8272c57c29131..0000000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/NewTableScanSuite.scala
+++ /dev/null
@@ -1,163 +0,0 @@
-/*
-* Licensed to the Apache Software Foundation (ASF) under one or more
-* contributor license agreements. See the NOTICE file distributed with
-* this work for additional information regarding copyright ownership.
-* The ASF licenses this file to You under the Apache License, Version 2.0
-* (the "License"); you may not use this file except in compliance with
-* the License. You may obtain a copy of the License at
-*
-*    http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
-
-package org.apache.spark.sql.sources
-
-import org.apache.spark.sql._
-import java.sql.{Timestamp, Date}
-import org.apache.spark.sql.execution.RDDConversions
-
-case class AllDataTypesData(
-    stringField: String,
-    intField: Int,
-    longField: Long,
-    floatField: Float,
-    doubleField: Double,
-    shortField: Short,
-    byteField: Byte,
-    booleanField: Boolean,
-    decimalField: BigDecimal,
-    date: Date,
-    timestampField: Timestamp,
-    arrayFiled: Seq[Int],
-    mapField: Map[Int, String],
-    structField: Row)
-
-class AllDataTypesScanSource extends SchemaRelationProvider {
-  override def createRelation(
-      sqlContext: SQLContext,
-      parameters: Map[String, String],
-      schema: Option[StructType] = None): BaseRelation = {
-    AllDataTypesScan(parameters("from").toInt, parameters("TO").toInt, schema)(sqlContext)
-  }
-}
-
-case class AllDataTypesScan(
-    from: Int,
-    to: Int,
-    userSpecifiedSchema: Option[StructType])(@transient val sqlContext: SQLContext)
-  extends TableScan {
-
-  override def schema = userSpecifiedSchema.get
-
-  override def buildScan() = {
-    val rdd = sqlContext.sparkContext.parallelize(from to to).map { i =>
-      AllDataTypesData(
-        i.toString,
-        i,
-        i.toLong,
-        i.toFloat,
-        i.toDouble,
-        i.toShort,
-        i.toByte,
-        true,
-        BigDecimal(i),
-        new Date(12345),
-        new Timestamp(12345),
-        Seq(i, i+1),
-        Map(i -> i.toString),
-        Row(i, i.toString))
-    }
-
-    RDDConversions.productToRowRdd(rdd, schema)
-  }
-
-}
-
-class NewTableScanSuite extends DataSourceTest {
-  import caseInsensisitiveContext._
-
-  var records = (1 to 10).map { i =>
-    Row(
-      i.toString,
-      i,
-      i.toLong,
-      i.toFloat,
-      i.toDouble,
-      i.toShort,
-      i.toByte,
-      true,
-      BigDecimal(i),
-      new Date(12345),
-      new Timestamp(12345),
-      Seq(i, i+1),
-      Map(i -> i.toString),
-      Row(i, i.toString))
-  }.toSeq
-
-  before {
-    sql(
-      """
-        |CREATE TEMPORARY TABLE oneToTen(stringField stRIng, intField iNt, longField Bigint,
-        |floatField flOat, doubleField doubLE, shortField smaLlint, byteField tinyint,
-        |booleanField boolean, decimalField decimal(10,2), dateField dAte,
-        |timestampField tiMestamp, arrayField Array<int>, mapField MAP<int, string>,
-        |structField StRuct<key:int, value:string>)
-        |USING org.apache.spark.sql.sources.AllDataTypesScanSource
-        |OPTIONS (
-        |  From '1',
-        |  To '10'
-        |)
-      """.stripMargin)
-  }
-
-  sqlTest(
-    "SELECT * FROM oneToTen",
-    records)
-
-  sqlTest(
-    "SELECT count(*) FROM oneToTen",
-    10)
-
-  sqlTest(
-    "SELECT stringField FROM oneToTen",
-    (1 to 10).map(i =>Row(i.toString)).toSeq)
-
-  sqlTest(
-    "SELECT intField FROM oneToTen WHERE intField < 5",
-    (1 to 4).map(Row(_)).toSeq)
-
-  sqlTest(
-    "SELECT longField * 2 FROM oneToTen",
-    (1 to 10).map(i => Row(i * 2.toLong)).toSeq)
-
-  sqlTest(
-    """SELECT a.floatField, b.floatField FROM oneToTen a JOIN oneToTen b
-      |ON a.floatField = b.floatField + 1""".stripMargin,
-    (2 to 10).map(i => Row(i.toFloat, i - 1.toFloat)).toSeq)
-
-  sqlTest(
-    "SELECT distinct(a.dateField) FROM oneToTen a",
-    Some(new Date(12345)).map(Row(_)).toSeq)
-
-  sqlTest(
-    "SELECT distinct(a.timestampField) FROM oneToTen a",
-    Some(new Timestamp(12345)).map(Row(_)).toSeq)
-
-  sqlTest(
-    "SELECT distinct(arrayField) FROM oneToTen a where intField=1",
-    Some(Seq(1, 2)).map(Row(_)).toSeq)
-
-  sqlTest(
-    "SELECT distinct(mapField) FROM oneToTen a where intField=1",
-    Some(Map(1 -> 1.toString)).map(Row(_)).toSeq)
-
-  sqlTest(
-    "SELECT distinct(structField) FROM oneToTen a where intField=1",
-    Some(Row(1, "1")).map(Row(_)).toSeq)
-
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index 3cd7b0115d567..26191a8a5c769 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -17,7 +17,10 @@
 
 package org.apache.spark.sql.sources
 
+import java.sql.{Timestamp, Date}
+
 import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.types.DecimalType
 
 class DefaultSource extends SimpleScanSource
 
@@ -38,9 +41,77 @@ case class SimpleScan(from: Int, to: Int)(@transient val sqlContext: SQLContext)
   override def buildScan() = sqlContext.sparkContext.parallelize(from to to).map(Row(_))
 }
 
+class AllDataTypesScanSource extends SchemaRelationProvider {
+  override def createRelation(
+      sqlContext: SQLContext,
+      parameters: Map[String, String],
+      schema: Option[StructType]): BaseRelation = {
+    AllDataTypesScan(parameters("from").toInt, parameters("TO").toInt, schema)(sqlContext)
+  }
+}
+
+case class AllDataTypesScan(
+    from: Int,
+    to: Int,
+    userSpecifiedSchema: Option[StructType])(@transient val sqlContext: SQLContext)
+  extends TableScan {
+
+  override def schema = userSpecifiedSchema.get
+
+  override def buildScan() = {
+    sqlContext.sparkContext.parallelize(from to to).map { i =>
+      Row(
+        s"str_$i",
+        s"str_$i".getBytes(),
+        i % 2 == 0,
+        i.toByte,
+        i.toShort,
+        i,
+        i.toLong,
+        i.toFloat,
+        i.toDouble,
+        BigDecimal(i),
+        BigDecimal(i),
+        new Date(10000 + i),
+        new Timestamp(20000 + i),
+        s"varchar_$i",
+        Seq(i, i + 1),
+        Seq(Map(s"str_$i" -> Row(i.toLong))),
+        Map(i -> i.toString),
+        Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)),
+        Row(i, i.toString),
+        Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date(30000 + i)))))
+    }
+  }
+}
+
 class TableScanSuite extends DataSourceTest {
   import caseInsensisitiveContext._
 
+  var tableWithSchemaExpected = (1 to 10).map { i =>
+    Row(
+      s"str_$i",
+      s"str_$i",
+      i % 2 == 0,
+      i.toByte,
+      i.toShort,
+      i,
+      i.toLong,
+      i.toFloat,
+      i.toDouble,
+      BigDecimal(i),
+      BigDecimal(i),
+      new Date(10000 + i),
+      new Timestamp(20000 + i),
+      s"varchar_$i",
+      Seq(i, i + 1),
+      Seq(Map(s"str_$i" -> Row(i.toLong))),
+      Map(i -> i.toString),
+      Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)),
+      Row(i, i.toString),
+      Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date(30000 + i)))))
+  }.toSeq
+
   before {
     sql(
       """
@@ -51,6 +122,37 @@ class TableScanSuite extends DataSourceTest {
        |  To '10'
        |)
      """.stripMargin)
+
+    sql(
+      """
+        |CREATE TEMPORARY TABLE tableWithSchema (
+        |stringField stRIng,
+        |binaryField binary,
+        |booleanField boolean,
+        |byteField tinyint,
+        |shortField smaLlint,
+        |intField iNt,
+        |longField Bigint,
+        |floatField flOat,
+        |doubleField doubLE,
+        |decimalField1 decimal,
+        |decimalField2 decimal(9,2),
+        |dateField dAte,
+        |timestampField tiMestamp,
+        |varcharField varchaR(12),
+        |arrayFieldSimple Array<inT>,
+        |arrayFieldComplex Array<Map<String, Struct<key:bigInt>>>,
+        |mapFieldSimple MAP<iNt, StRing>,
+        |mapFieldComplex Map<Map<stRIng, fLOAT>, Struct<key:bigInt>>,
+        |structFieldSimple StRuct<key:INt, Value:STrING>,
+        |structFieldComplex StRuct<key:Array<String>, Value:struct<value:Array<date>>>
+        |)
+        |USING org.apache.spark.sql.sources.AllDataTypesScanSource
+        |OPTIONS (
+        |  From '1',
+        |  To '10'
+        |)
+      """.stripMargin)
   }
 
   sqlTest(
@@ -73,6 +175,91 @@ class TableScanSuite extends DataSourceTest {
     "SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1",
     (2 to 10).map(i => Row(i, i - 1)).toSeq)
 
+  test("Schema and all fields") {
+    val expectedSchema = StructType(
+      StructField("stringField", StringType, true) ::
+      StructField("binaryField", BinaryType, true) ::
+      StructField("booleanField", BooleanType, true) ::
+      StructField("byteField", ByteType, true) ::
+      StructField("shortField", ShortType, true) ::
+      StructField("intField", IntegerType, true) ::
+      StructField("longField", LongType, true) ::
+      StructField("floatField", FloatType, true) ::
+      StructField("doubleField", DoubleType, true) ::
+      StructField("decimalField1", DecimalType.Unlimited, true) ::
+      StructField("decimalField2", DecimalType(9, 2), true) ::
+      StructField("dateField", DateType, true) ::
+      StructField("timestampField", TimestampType, true) ::
+      StructField("varcharField", StringType, true) ::
+      StructField("arrayFieldSimple", ArrayType(IntegerType), true) ::
+      StructField("arrayFieldComplex",
+        ArrayType(
+          MapType(StringType, StructType(StructField("key", LongType, true) :: Nil))), true) ::
+      StructField("mapFieldSimple", MapType(IntegerType, StringType), true) ::
+      StructField("mapFieldComplex",
+        MapType(
+          MapType(StringType, FloatType),
+          StructType(StructField("key", LongType, true) :: Nil)), true) ::
+      StructField("structFieldSimple",
+        StructType(
+          StructField("key", IntegerType, true) ::
+          StructField("Value", StringType, true) :: Nil), true) ::
+      StructField("structFieldComplex",
+        StructType(
+          StructField("key", ArrayType(StringType), true) ::
+          StructField("Value",
+            StructType(
+              StructField("value", ArrayType(DateType), true) :: Nil), true) :: Nil), true) :: Nil
+    )
+
+    assert(expectedSchema == table("tableWithSchema").schema)
+
+    checkAnswer(
+      sql(
+        """SELECT
+          | stringField,
+          | cast(binaryField as string),
+          | booleanField,
+          | byteField,
+          | shortField,
+          | intField,
+          | longField,
+          | floatField,
+          | doubleField,
+          | decimalField1,
+          | decimalField2,
+          | dateField,
+          | timestampField,
+          | varcharField,
+          | arrayFieldSimple,
+          | arrayFieldComplex,
+          | mapFieldSimple,
+          | mapFieldComplex,
+          | structFieldSimple,
+          | structFieldComplex FROM tableWithSchema""".stripMargin),
+      tableWithSchemaExpected
+    )
+  }
+
+  sqlTest(
+    "SELECT count(*) FROM tableWithSchema",
+    10)
+
+  sqlTest(
+    "SELECT stringField FROM tableWithSchema",
+    (1 to 10).map(i => Row(s"str_$i")).toSeq)
+
+  sqlTest(
+    "SELECT intField FROM tableWithSchema WHERE intField < 5",
+    (1 to 4).map(Row(_)).toSeq)
+
+  sqlTest(
+    "SELECT longField * 2 FROM tableWithSchema",
+    (1 to 10).map(i => Row(i * 2.toLong)).toSeq)
+
+  sqlTest(
+    "SELECT structFieldSimple.key, arrayFieldSimple[1] FROM tableWithSchema a where intField=1",
+    Seq(Seq(1, 2)))
+
   test("Caching") {
     // Cached Query Execution
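Note (not part of the patch): since the default value is removed from the schema parameter, every SchemaRelationProvider implementation must now accept the user-specified schema explicitly, as AllDataTypesScanSource does above. The following is a minimal sketch of an external data source written against the updated trait; the package name, ExampleSource, ExampleRelation, and the "to" option key are made up for illustration.

// Hypothetical example of a data source implementing the updated
// SchemaRelationProvider signature (schema has no default value).
package com.example.sources

import org.apache.spark.sql._
import org.apache.spark.sql.sources._

class ExampleSource extends SchemaRelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      schema: Option[StructType]): BaseRelation = {
    // The schema passed by CreateTableUsing may still be None when the DDL
    // declares no columns, so the relation decides how to handle that case.
    val to = parameters.getOrElse("to", sys.error("Option 'to' not specified")).toInt
    ExampleRelation(to, schema)(sqlContext)
  }
}

case class ExampleRelation(
    to: Int,
    userSpecifiedSchema: Option[StructType])(@transient val sqlContext: SQLContext)
  extends TableScan {

  // Fall back to a single-column schema when no columns were declared in the DDL.
  override def schema =
    userSpecifiedSchema.getOrElse(
      StructType(StructField("i", IntegerType, nullable = false) :: Nil))

  override def buildScan() = sqlContext.sparkContext.parallelize(1 to to).map(Row(_))
}

Such a source would then be exercised the same way as the test tables above, e.g.
CREATE TEMPORARY TABLE t (i int) USING com.example.sources.ExampleSource OPTIONS (to '10').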