diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
index fc70c183437f6..52b79c1ca8c29 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
@@ -18,31 +18,38 @@
 package org.apache.spark.sql.json
 
 import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.types.StructType
 import org.apache.spark.sql.sources._
 
 private[sql] class DefaultSource extends RelationProvider {
   /** Returns a new base relation with the given parameters. */
   override def createRelation(
       sqlContext: SQLContext,
-      parameters: Map[String, String]): BaseRelation = {
+      parameters: Map[String, String],
+      schema: Option[StructType]): BaseRelation = {
     val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
     val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
 
-    JSONRelation(fileName, samplingRatio)(sqlContext)
+    JSONRelation(fileName, samplingRatio, schema)(sqlContext)
   }
 }
 
-private[sql] case class JSONRelation(fileName: String, samplingRatio: Double)(
+private[sql] case class JSONRelation(
+    fileName: String,
+    samplingRatio: Double,
+    userSpecifiedSchema: Option[StructType])(
     @transient val sqlContext: SQLContext)
   extends TableScan {
 
   private def baseRDD = sqlContext.sparkContext.textFile(fileName)
 
   override val schema =
+    userSpecifiedSchema.getOrElse(
     JsonRDD.inferSchema(
       baseRDD,
       samplingRatio,
       sqlContext.columnNameOfCorruptRecord)
+    )
 
   override def buildScan() =
     JsonRDD.jsonStringToRow(baseRDD, schema, sqlContext.columnNameOfCorruptRecord)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
index 2e0c6c51c00e5..a237d27794a9e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
@@ -22,22 +22,21 @@ import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
 import org.apache.hadoop.conf.{Configurable, Configuration}
 import org.apache.hadoop.io.Writable
 import org.apache.hadoop.mapreduce.{JobContext, InputSplit, Job}
-import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
-
 import parquet.hadoop.ParquetInputFormat
 import parquet.hadoop.util.ContextUtil
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.{Partition => SparkPartition, Logging}
 import org.apache.spark.rdd.{NewHadoopPartition, RDD}
-
-import org.apache.spark.sql.{SQLConf, Row, SQLContext}
+import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.types.{StringType, IntegerType, StructField, StructType}
+import org.apache.spark.sql.catalyst.types.{IntegerType, StructField, StructType}
 import org.apache.spark.sql.sources._
+import org.apache.spark.sql.{SQLConf, SQLContext}
 
 import scala.collection.JavaConversions._
 
+
 /**
  * Allows creation of parquet based tables using the syntax
  * `CREATE TEMPORARY TABLE ... USING org.apache.spark.sql.parquet`.  Currently the only option
@@ -48,11 +47,12 @@ class DefaultSource extends RelationProvider {
   /** Returns a new base relation with the given parameters. */
   override def createRelation(
       sqlContext: SQLContext,
-      parameters: Map[String, String]): BaseRelation = {
+      parameters: Map[String, String],
+      schema: Option[StructType]): BaseRelation = {
     val path =
       parameters.getOrElse("path", sys.error("'path' must be specified for parquet tables."))
 
-    ParquetRelation2(path)(sqlContext)
+    ParquetRelation2(path, schema)(sqlContext)
   }
 }
 
@@ -82,7 +82,9 @@ private[parquet] case class Partition(partitionValues: Map[String, Any], files:
  * discovery.
  */
 @DeveloperApi
-case class ParquetRelation2(path: String)(@transient val sqlContext: SQLContext)
+case class ParquetRelation2(
+    path: String,
+    userSpecifiedSchema: Option[StructType])(@transient val sqlContext: SQLContext)
   extends CatalystScan with Logging {
 
   def sparkContext = sqlContext.sparkContext
@@ -133,12 +135,13 @@ case class ParquetRelation2(path: String)(@transient val sqlContext: SQLContext)
 
   override val sizeInBytes = partitions.flatMap(_.files).map(_.getLen).sum
 
-  val dataSchema = StructType.fromAttributes( // TODO: Parquet code should not deal with attributes.
-    ParquetTypesConverter.readSchemaFromFile(
-      partitions.head.files.head.getPath,
-      Some(sparkContext.hadoopConfiguration),
-      sqlContext.isParquetBinaryAsString))
-
+  val dataSchema = userSpecifiedSchema.getOrElse(
+    StructType.fromAttributes( // TODO: Parquet code should not deal with attributes.
+      ParquetTypesConverter.readSchemaFromFile(
+        partitions.head.files.head.getPath,
+        Some(sparkContext.hadoopConfiguration),
+        sqlContext.isParquetBinaryAsString))
+  )
   val dataIncludesKey =
     partitionKeys.headOption.map(dataSchema.fieldNames.contains(_)).getOrElse(true)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
index 8a66ac31f2dfb..69fa64affd961 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -17,16 +17,16 @@
 package org.apache.spark.sql.sources
 
+import scala.language.implicitConversions
+import scala.util.parsing.combinator.syntactical.StandardTokenParsers
+import scala.util.parsing.combinator.{RegexParsers, PackratParsers}
+
 import org.apache.spark.Logging
+import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.sql.execution.RunnableCommand
 import org.apache.spark.util.Utils
-
-import scala.language.implicitConversions
-import scala.util.parsing.combinator.lexical.StdLexical
-import scala.util.parsing.combinator.syntactical.StandardTokenParsers
-import scala.util.parsing.combinator.PackratParsers
-
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.SqlLexical
 
@@ -49,6 +49,21 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
   protected implicit def asParser(k: Keyword): Parser[String] =
     lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)
 
+  // data types
+  protected val STRING = Keyword("STRING")
+  protected val DOUBLE = Keyword("DOUBLE")
+  protected val BOOLEAN = Keyword("BOOLEAN")
+  protected val FLOAT = Keyword("FLOAT")
+  protected val INT = Keyword("INT")
+  protected val TINYINT = Keyword("TINYINT")
+  protected val SMALLINT = Keyword("SMALLINT")
+  protected val BIGINT = Keyword("BIGINT")
+  protected val BINARY = Keyword("BINARY")
+  protected val DECIMAL = Keyword("DECIMAL")
+  protected val DATE = Keyword("DATE")
+  protected val TIMESTAMP = Keyword("TIMESTAMP")
+  protected val VARCHAR = Keyword("VARCHAR")
+
   protected val CREATE = Keyword("CREATE")
   protected val TEMPORARY = Keyword("TEMPORARY")
   protected val TABLE = Keyword("TABLE")
@@ -67,15 +82,30 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
   protected lazy val ddl: Parser[LogicalPlan] = createTable
 
   /**
-   * CREATE TEMPORARY TABLE avroTable
+   * `CREATE TEMPORARY TABLE avroTable
    * USING org.apache.spark.sql.avro
-   * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")
+   * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")`
+   * or
+   * `CREATE TEMPORARY TABLE avroTable(intField int, stringField string...)
+   * USING org.apache.spark.sql.avro
+   * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")`
    */
   protected lazy val createTable: Parser[LogicalPlan] =
-    CREATE ~ TEMPORARY ~ TABLE ~> ident ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ {
+    ( CREATE ~ TEMPORARY ~ TABLE ~> ident ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ {
       case tableName ~ provider ~ opts =>
-        CreateTableUsing(tableName, provider, opts)
+        CreateTableUsing(tableName, Seq.empty, provider, opts)
+      }
+    |
+      CREATE ~ TEMPORARY ~ TABLE ~> ident
+        ~ tableCols ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ {
+      case tableName ~ tableColumns ~ provider ~ opts =>
+        CreateTableUsing(tableName, tableColumns, provider, opts)
     }
+    )
+
+  protected lazy val metastoreTypes = new MetastoreTypes
+
+  protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")"
 
   protected lazy val options: Parser[Map[String, String]] =
     "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap }
@@ -83,10 +113,98 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
   protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")}
 
   protected lazy val pair: Parser[(String, String)] = ident ~ stringLit ^^ { case k ~ v => (k,v) }
+
+  protected lazy val column: Parser[StructField] =
+    ident ~ ident ^^ { case name ~ typ =>
+      StructField(name, metastoreTypes.toDataType(typ))
+    }
+}
+
+/**
+ * :: DeveloperApi ::
+ * Provides a parser for data types.
+ */
+@DeveloperApi
+private[sql] class MetastoreTypes extends RegexParsers {
+  protected lazy val primitiveType: Parser[DataType] =
+    "string" ^^^ StringType |
+    "float" ^^^ FloatType |
+    "int" ^^^ IntegerType |
+    "tinyint" ^^^ ByteType |
+    "smallint" ^^^ ShortType |
+    "double" ^^^ DoubleType |
+    "bigint" ^^^ LongType |
+    "binary" ^^^ BinaryType |
+    "boolean" ^^^ BooleanType |
+    fixedDecimalType |                     // Hive 0.13+ decimal with precision/scale
+    "decimal" ^^^ DecimalType.Unlimited |  // Hive 0.12 decimal with no precision/scale
+    "date" ^^^ DateType |
+    "timestamp" ^^^ TimestampType |
+    "varchar\\((\\d+)\\)".r ^^^ StringType
+
+  protected lazy val fixedDecimalType: Parser[DataType] =
+    ("decimal" ~> "(" ~> "\\d+".r) ~ ("," ~> "\\d+".r <~ ")") ^^ {
+      case precision ~ scale =>
+        DecimalType(precision.toInt, scale.toInt)
+    }
+
+  protected lazy val arrayType: Parser[DataType] =
+    "array" ~> "<" ~> dataType <~ ">" ^^ {
+      case tpe => ArrayType(tpe)
+    }
+
+  protected lazy val mapType: Parser[DataType] =
+    "map" ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ {
+      case t1 ~ _ ~ t2 => MapType(t1, t2)
+    }
+
+  protected lazy val structField: Parser[StructField] =
+    "[a-zA-Z0-9_]*".r ~ ":" ~ dataType ^^ {
+      case name ~ _ ~ tpe => StructField(name, tpe, nullable = true)
+    }
+
+  protected lazy val structType: Parser[DataType] =
+    "struct" ~> "<" ~> repsep(structField,",") <~ ">" ^^ {
+      case fields => new StructType(fields)
+    }
+
+  private[sql] lazy val dataType: Parser[DataType] =
+    arrayType |
+    mapType |
+    structType |
+    primitiveType
+
+  def toDataType(metastoreType: String): DataType = parseAll(dataType, metastoreType) match {
+    case Success(result, _) => result
+    case failure: NoSuccess => sys.error(s"Unsupported dataType: $metastoreType")
+  }
+
+  def toMetastoreType(dt: DataType): String = dt match {
+    case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>"
+    case StructType(fields) =>
+      s"struct<${fields.map(f => s"${f.name}:${toMetastoreType(f.dataType)}").mkString(",")}>"
+    case MapType(keyType, valueType, _) =>
+      s"map<${toMetastoreType(keyType)},${toMetastoreType(valueType)}>"
+    case StringType => "string"
+    case FloatType => "float"
+    case IntegerType => "int"
+    case ByteType => "tinyint"
+    case ShortType => "smallint"
+    case DoubleType => "double"
+    case LongType => "bigint"
+    case BinaryType => "binary"
+    case BooleanType => "boolean"
+    case DateType => "date"
+    case d: DecimalType => "decimal"
+    case TimestampType => "timestamp"
+    case NullType => "void"
+    case udt: UserDefinedType[_] => toMetastoreType(udt.sqlType)
+  }
 }
 
 private[sql] case class CreateTableUsing(
     tableName: String,
+    tableCols: Seq[StructField],
     provider: String,
     options: Map[String, String]) extends RunnableCommand {
 
@@ -100,7 +218,8 @@ private[sql] case class CreateTableUsing(
       }
     }
     val dataSource = clazz.newInstance().asInstanceOf[org.apache.spark.sql.sources.RelationProvider]
-    val relation = dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options))
+    val relation = dataSource.createRelation(
+      sqlContext, new CaseInsensitiveMap(options), Some(StructType(tableCols)))
     sqlContext.baseRelationToSchemaRDD(relation).registerTempTable(tableName)
 
     Seq.empty
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 02eff80456dbe..5f9e8a35ef84e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -41,7 +41,10 @@ trait RelationProvider {
    * Note: the parameters' keywords are case insensitive and this insensitivity is enforced
    * by the Map that is passed to the function.
    */
-  def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation
+  def createRelation(
+      sqlContext: SQLContext,
+      parameters: Map[String, String],
+      schema: Option[StructType]): BaseRelation
 }
 
 /**
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index b31a3ec25096b..accdaf591b5ea 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -20,12 +20,7 @@ package org.apache.spark.sql.hive
 import java.io.IOException
 import java.util.{List => JList}
 
-import org.apache.spark.sql.execution.SparkPlan
-
-import scala.util.parsing.combinator.RegexParsers
-
 import org.apache.hadoop.util.ReflectionUtils
-
 import org.apache.hadoop.hive.metastore.TableType
 import org.apache.hadoop.hive.metastore.api.FieldSchema
 import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition}
@@ -36,7 +31,6 @@ import org.apache.hadoop.hive.serde2.{Deserializer, SerDeException}
 import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
 
 import org.apache.spark.Logging
-import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.analysis.{Catalog, OverrideCatalog}
 import org.apache.spark.sql.catalyst.expressions._
@@ -44,6 +38,7 @@ import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.sources.MetastoreTypes
 import org.apache.spark.util.Utils
 
 /* Implicit conversions */
@@ -386,88 +381,6 @@ private[hive] case class InsertIntoHiveTable(
   }
 }
 
-/**
- * :: DeveloperApi ::
- * Provides conversions between Spark SQL data types and Hive Metastore types.
- */
-@DeveloperApi
-object HiveMetastoreTypes extends RegexParsers {
-  protected lazy val primitiveType: Parser[DataType] =
-    "string" ^^^ StringType |
-    "float" ^^^ FloatType |
-    "int" ^^^ IntegerType |
-    "tinyint" ^^^ ByteType |
-    "smallint" ^^^ ShortType |
-    "double" ^^^ DoubleType |
-    "bigint" ^^^ LongType |
-    "binary" ^^^ BinaryType |
-    "boolean" ^^^ BooleanType |
-    fixedDecimalType |                     // Hive 0.13+ decimal with precision/scale
-    "decimal" ^^^ DecimalType.Unlimited |  // Hive 0.12 decimal with no precision/scale
-    "date" ^^^ DateType |
-    "timestamp" ^^^ TimestampType |
-    "varchar\\((\\d+)\\)".r ^^^ StringType
-
-  protected lazy val fixedDecimalType: Parser[DataType] =
-    ("decimal" ~> "(" ~> "\\d+".r) ~ ("," ~> "\\d+".r <~ ")") ^^ {
-      case precision ~ scale =>
-        DecimalType(precision.toInt, scale.toInt)
-    }
-
-  protected lazy val arrayType: Parser[DataType] =
-    "array" ~> "<" ~> dataType <~ ">" ^^ {
-      case tpe => ArrayType(tpe)
-    }
-
-  protected lazy val mapType: Parser[DataType] =
-    "map" ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ {
-      case t1 ~ _ ~ t2 => MapType(t1, t2)
-    }
-
-  protected lazy val structField: Parser[StructField] =
-    "[a-zA-Z0-9_]*".r ~ ":" ~ dataType ^^ {
-      case name ~ _ ~ tpe => StructField(name, tpe, nullable = true)
-    }
-
-  protected lazy val structType: Parser[DataType] =
-    "struct" ~> "<" ~> repsep(structField,",") <~ ">" ^^ {
-      case fields => new StructType(fields)
-    }
-
-  protected lazy val dataType: Parser[DataType] =
-    arrayType |
-    mapType |
-    structType |
-    primitiveType
-
-  def toDataType(metastoreType: String): DataType = parseAll(dataType, metastoreType) match {
-    case Success(result, _) => result
-    case failure: NoSuccess => sys.error(s"Unsupported dataType: $metastoreType")
-  }
-
-  def toMetastoreType(dt: DataType): String = dt match {
-    case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>"
-    case StructType(fields) =>
-      s"struct<${fields.map(f => s"${f.name}:${toMetastoreType(f.dataType)}").mkString(",")}>"
-    case MapType(keyType, valueType, _) =>
-      s"map<${toMetastoreType(keyType)},${toMetastoreType(valueType)}>"
-    case StringType => "string"
-    case FloatType => "float"
-    case IntegerType => "int"
-    case ByteType => "tinyint"
-    case ShortType => "smallint"
-    case DoubleType => "double"
-    case LongType => "bigint"
-    case BinaryType => "binary"
-    case BooleanType => "boolean"
-    case DateType => "date"
-    case d: DecimalType => HiveShim.decimalMetastoreString(d)
-    case TimestampType => "timestamp"
-    case NullType => "void"
-    case udt: UserDefinedType[_] => toMetastoreType(udt.sqlType)
-  }
-}
-
 private[hive] case class MetastoreRelation
     (databaseName: String, tableName: String, alias: Option[String])
     (val table: TTable, val partitions: Seq[TPartition])
@@ -545,3 +458,28 @@ private[hive] case class MetastoreRelation
   /** An attribute map for determining the ordinal for non-partition columns. */
   val columnOrdinals = AttributeMap(attributes.zipWithIndex)
 }
+
+
+object HiveMetastoreTypes extends MetastoreTypes {
+  override def toMetastoreType(dt: DataType): String = dt match {
+    case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>"
+    case StructType(fields) =>
+      s"struct<${fields.map(f => s"${f.name}:${toMetastoreType(f.dataType)}").mkString(",")}>"
+    case MapType(keyType, valueType, _) =>
+      s"map<${toMetastoreType(keyType)},${toMetastoreType(valueType)}>"
+    case StringType => "string"
+    case FloatType => "float"
+    case IntegerType => "int"
+    case ByteType => "tinyint"
+    case ShortType => "smallint"
+    case DoubleType => "double"
+    case LongType => "bigint"
+    case BinaryType => "binary"
+    case BooleanType => "boolean"
+    case DateType => "date"
+    case d: DecimalType => HiveShim.decimalMetastoreString(d)
+    case TimestampType => "timestamp"
+    case NullType => "void"
+    case udt: UserDefinedType[_] => toMetastoreType(udt.sqlType)
+  }
+}
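
For reference, a minimal usage sketch of the extended DDL, not part of the patch itself: the table name, column names, and path are placeholders, and it assumes an existing SQLContext named `sqlContext` whose SQL entry point routes CREATE TEMPORARY TABLE statements through the DDLParser shown above.

// Illustrative only; mirrors the second `createTable` alternative documented in ddl.scala.
// "jsonTable", the column names, and the path are placeholders, not part of this patch.
sqlContext.sql(
  """CREATE TEMPORARY TABLE jsonTable (intField int, stringField string)
    |USING org.apache.spark.sql.json
    |OPTIONS (path "/path/to/data.json")""".stripMargin)

// The column list is parsed by `tableCols`/`column` into a Seq[StructField] (types resolved via
// MetastoreTypes.toDataType), carried on CreateTableUsing, and handed to
// RelationProvider.createRelation as Some(StructType(tableCols)); JSONRelation then picks it up
// through userSpecifiedSchema.getOrElse(...) instead of sampling the data to infer a schema.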