Commit 0ba70df: draft version
scwf committed Dec 30, 2014
1 parent 040d6f2 commit 0ba70df
Showing 5 changed files with 187 additions and 117 deletions.
@@ -18,31 +18,38 @@
package org.apache.spark.sql.json

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.types.StructType
import org.apache.spark.sql.sources._

private[sql] class DefaultSource extends RelationProvider {
/** Returns a new base relation with the given parameters. */
override def createRelation(
sqlContext: SQLContext,
parameters: Map[String, String]): BaseRelation = {
parameters: Map[String, String],
schema: Option[StructType]): BaseRelation = {
val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)

JSONRelation(fileName, samplingRatio)(sqlContext)
JSONRelation(fileName, samplingRatio, schema)(sqlContext)
}
}

private[sql] case class JSONRelation(fileName: String, samplingRatio: Double)(
private[sql] case class JSONRelation(
fileName: String,
samplingRatio: Double,
userSpecifiedSchema: Option[StructType])(
@transient val sqlContext: SQLContext)
extends TableScan {

private def baseRDD = sqlContext.sparkContext.textFile(fileName)

override val schema =
userSpecifiedSchema.getOrElse(
JsonRDD.inferSchema(
baseRDD,
samplingRatio,
sqlContext.columnNameOfCorruptRecord)
)

override def buildScan() =
JsonRDD.jsonStringToRow(baseRDD, schema, sqlContext.columnNameOfCorruptRecord)
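
For context, a hedged usage sketch of the new code path: with an explicit column list, JSONRelation.schema takes the user-specified StructType and never runs JsonRDD.inferSchema. The table name, columns, and path below are invented for illustration, and the snippet assumes the DDL parser extended in this commit is reachable through sqlContext.sql.

// Columns are parsed by the new DDLParser (see ddl.scala below) into a StructType
// that arrives here as userSpecifiedSchema.
sqlContext.sql(
  """CREATE TEMPORARY TABLE people(name string, age int)
    |USING org.apache.spark.sql.json
    |OPTIONS (path "/tmp/people.json")""".stripMargin)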
@@ -22,22 +22,21 @@ import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce.{JobContext, InputSplit, Job}
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate

import parquet.hadoop.ParquetInputFormat
import parquet.hadoop.util.ContextUtil

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.{Partition => SparkPartition, Logging}
import org.apache.spark.rdd.{NewHadoopPartition, RDD}

import org.apache.spark.sql.{SQLConf, Row, SQLContext}
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types.{StringType, IntegerType, StructField, StructType}
import org.apache.spark.sql.catalyst.types.{IntegerType, StructField, StructType}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.{SQLConf, SQLContext}

import scala.collection.JavaConversions._


/**
* Allows creation of parquet based tables using the syntax
* `CREATE TEMPORARY TABLE ... USING org.apache.spark.sql.parquet`. Currently the only option
@@ -48,11 +47,12 @@ class DefaultSource extends RelationProvider {
/** Returns a new base relation with the given parameters. */
override def createRelation(
sqlContext: SQLContext,
parameters: Map[String, String]): BaseRelation = {
parameters: Map[String, String],
schema: Option[StructType]): BaseRelation = {
val path =
parameters.getOrElse("path", sys.error("'path' must be specified for parquet tables."))

ParquetRelation2(path)(sqlContext)
ParquetRelation2(path, schema)(sqlContext)
}
}
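
For reference, a sketch of driving the widened createRelation directly; the path and column names are illustrative, and sqlContext is assumed to be an existing SQLContext.

import org.apache.spark.sql.catalyst.types.{IntegerType, StringType, StructField, StructType}

val userSchema = StructType(Seq(
  StructField("id", IntegerType, nullable = true),
  StructField("name", StringType, nullable = true)))

// Some(schema) becomes ParquetRelation2.userSpecifiedSchema and is used as dataSchema;
// None keeps the schema read from the files via ParquetTypesConverter.readSchemaFromFile.
val relation = new DefaultSource().createRelation(
  sqlContext,
  Map("path" -> "/tmp/events.parquet"),
  Some(userSchema))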

@@ -82,7 +82,9 @@ private[parquet] case class Partition(partitionValues: Map[String, Any], files:
* discovery.
*/
@DeveloperApi
case class ParquetRelation2(path: String)(@transient val sqlContext: SQLContext)
case class ParquetRelation2(
path: String,
userSpecifiedSchema: Option[StructType])(@transient val sqlContext: SQLContext)
extends CatalystScan with Logging {

def sparkContext = sqlContext.sparkContext
@@ -133,12 +135,13 @@ case class ParquetRelation2(path: String)(@transient val sqlContext: SQLContext)

override val sizeInBytes = partitions.flatMap(_.files).map(_.getLen).sum

val dataSchema = StructType.fromAttributes( // TODO: Parquet code should not deal with attributes.
ParquetTypesConverter.readSchemaFromFile(
partitions.head.files.head.getPath,
Some(sparkContext.hadoopConfiguration),
sqlContext.isParquetBinaryAsString))

val dataSchema = userSpecifiedSchema.getOrElse(
StructType.fromAttributes( // TODO: Parquet code should not deal with attributes.
ParquetTypesConverter.readSchemaFromFile(
partitions.head.files.head.getPath,
Some(sparkContext.hadoopConfiguration),
sqlContext.isParquetBinaryAsString))
)
val dataIncludesKey =
partitionKeys.headOption.map(dataSchema.fieldNames.contains(_)).getOrElse(true)

141 changes: 130 additions & 11 deletions sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -17,16 +17,16 @@

package org.apache.spark.sql.sources

import scala.language.implicitConversions
import scala.util.parsing.combinator.syntactical.StandardTokenParsers
import scala.util.parsing.combinator.{RegexParsers, PackratParsers}

import org.apache.spark.Logging
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.util.Utils

import scala.language.implicitConversions
import scala.util.parsing.combinator.lexical.StdLexical
import scala.util.parsing.combinator.syntactical.StandardTokenParsers
import scala.util.parsing.combinator.PackratParsers

import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.SqlLexical

@@ -49,6 +49,21 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
protected implicit def asParser(k: Keyword): Parser[String] =
lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)

// data types
protected val STRING = Keyword("STRING")
protected val DOUBLE = Keyword("DOUBLE")
protected val BOOLEAN = Keyword("BOOLEAN")
protected val FLOAT = Keyword("FLOAT")
protected val INT = Keyword("INT")
protected val TINYINT = Keyword("TINYINT")
protected val SMALLINT = Keyword("SMALLINT")
protected val BIGINT = Keyword("BIGINT")
protected val BINARY = Keyword("BINARY")
protected val DECIMAL = Keyword("DECIMAL")
protected val DATE = Keyword("DATE")
protected val TIMESTAMP = Keyword("TIMESTAMP")
protected val VARCHAR = Keyword("VARCHAR")

protected val CREATE = Keyword("CREATE")
protected val TEMPORARY = Keyword("TEMPORARY")
protected val TABLE = Keyword("TABLE")
@@ -67,26 +82,129 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
protected lazy val ddl: Parser[LogicalPlan] = createTable

/**
* CREATE TEMPORARY TABLE avroTable
* `CREATE TEMPORARY TABLE avroTable
* USING org.apache.spark.sql.avro
* OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")
* OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")`
* or
* `CREATE TEMPORARY TABLE avroTable(intField int, stringField string...)
* USING org.apache.spark.sql.avro
* OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")`
*/
protected lazy val createTable: Parser[LogicalPlan] =
CREATE ~ TEMPORARY ~ TABLE ~> ident ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ {
( CREATE ~ TEMPORARY ~ TABLE ~> ident ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ {
case tableName ~ provider ~ opts =>
CreateTableUsing(tableName, provider, opts)
CreateTableUsing(tableName, Seq.empty, provider, opts)
}
|
CREATE ~ TEMPORARY ~ TABLE ~> ident
~ tableCols ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ {
case tableName ~ tableColumns ~ provider ~ opts =>
CreateTableUsing(tableName, tableColumns, provider, opts)
}
)

protected lazy val metastoreTypes = new MetastoreTypes

protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")"

protected lazy val options: Parser[Map[String, String]] =
"(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap }

protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")}

protected lazy val pair: Parser[(String, String)] = ident ~ stringLit ^^ { case k ~ v => (k,v) }

protected lazy val column: Parser[StructField] =
ident ~ ident ^^ { case name ~ typ =>
StructField(name, metastoreTypes.toDataType(typ))
}
}
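
To make the two createTable alternatives concrete, here is the shape of the plans they produce, built by hand; the option value is taken from the doc comment above, and StructField's two-argument form matches the column rule.

import org.apache.spark.sql.catalyst.types.{IntegerType, StringType, StructField}

// First alternative: no column list, so tableCols is empty and the schema is left
// to the data source. Second alternative: the parsed columns ride along.
val withoutCols = CreateTableUsing(
  "avroTable", Seq.empty, "org.apache.spark.sql.avro",
  Map("path" -> "../hive/src/test/resources/data/files/episodes.avro"))

val withCols = CreateTableUsing(
  "avroTable",
  Seq(StructField("intField", IntegerType), StructField("stringField", StringType)),
  "org.apache.spark.sql.avro",
  Map("path" -> "../hive/src/test/resources/data/files/episodes.avro"))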

/**
* :: DeveloperApi ::
* Provides a parser for data type strings in the Hive metastore format (e.g. `struct<a:int,b:string>`).
*/
@DeveloperApi
private[sql] class MetastoreTypes extends RegexParsers {
protected lazy val primitiveType: Parser[DataType] =
"string" ^^^ StringType |
"float" ^^^ FloatType |
"int" ^^^ IntegerType |
"tinyint" ^^^ ByteType |
"smallint" ^^^ ShortType |
"double" ^^^ DoubleType |
"bigint" ^^^ LongType |
"binary" ^^^ BinaryType |
"boolean" ^^^ BooleanType |
fixedDecimalType | // Hive 0.13+ decimal with precision/scale
"decimal" ^^^ DecimalType.Unlimited | // Hive 0.12 decimal with no precision/scale
"date" ^^^ DateType |
"timestamp" ^^^ TimestampType |
"varchar\\((\\d+)\\)".r ^^^ StringType

protected lazy val fixedDecimalType: Parser[DataType] =
("decimal" ~> "(" ~> "\\d+".r) ~ ("," ~> "\\d+".r <~ ")") ^^ {
case precision ~ scale =>
DecimalType(precision.toInt, scale.toInt)
}

protected lazy val arrayType: Parser[DataType] =
"array" ~> "<" ~> dataType <~ ">" ^^ {
case tpe => ArrayType(tpe)
}

protected lazy val mapType: Parser[DataType] =
"map" ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ {
case t1 ~ _ ~ t2 => MapType(t1, t2)
}

protected lazy val structField: Parser[StructField] =
"[a-zA-Z0-9_]*".r ~ ":" ~ dataType ^^ {
case name ~ _ ~ tpe => StructField(name, tpe, nullable = true)
}

protected lazy val structType: Parser[DataType] =
"struct" ~> "<" ~> repsep(structField,",") <~ ">" ^^ {
case fields => new StructType(fields)
}

private[sql] lazy val dataType: Parser[DataType] =
arrayType |
mapType |
structType |
primitiveType

def toDataType(metastoreType: String): DataType = parseAll(dataType, metastoreType) match {
case Success(result, _) => result
case failure: NoSuccess => sys.error(s"Unsupported dataType: $metastoreType")
}

def toMetastoreType(dt: DataType): String = dt match {
case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>"
case StructType(fields) =>
s"struct<${fields.map(f => s"${f.name}:${toMetastoreType(f.dataType)}").mkString(",")}>"
case MapType(keyType, valueType, _) =>
s"map<${toMetastoreType(keyType)},${toMetastoreType(valueType)}>"
case StringType => "string"
case FloatType => "float"
case IntegerType => "int"
case ByteType => "tinyint"
case ShortType => "smallint"
case DoubleType => "double"
case LongType => "bigint"
case BinaryType => "binary"
case BooleanType => "boolean"
case DateType => "date"
case d: DecimalType => "decimal"
case TimestampType => "timestamp"
case NullType => "void"
case udt: UserDefinedType[_] => toMetastoreType(udt.sqlType)
}
}
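
A small usage sketch of the parser above; the type string is illustrative.

// String form -> Catalyst DataType. Unparseable input hits the sys.error branch
// in toDataType.
val parser = new MetastoreTypes
val dt = parser.toDataType("struct<name:string,scores:array<double>>")

// Catalyst DataType -> string form; for this input the round trip reproduces the
// original string (nullability and decimal precision/scale are not rendered).
val rendered = parser.toMetastoreType(dt)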

private[sql] case class CreateTableUsing(
tableName: String,
tableCols: Seq[StructField],
provider: String,
options: Map[String, String]) extends RunnableCommand {

@@ -100,7 +218,8 @@ private[sql] case class CreateTableUsing(
}
}
val dataSource = clazz.newInstance().asInstanceOf[org.apache.spark.sql.sources.RelationProvider]
val relation = dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options))
val relation = dataSource.createRelation(
sqlContext, new CaseInsensitiveMap(options), Some(StructType(tableCols)))

sqlContext.baseRelationToSchemaRDD(relation).registerTempTable(tableName)
Seq.empty
@@ -41,7 +41,10 @@ trait RelationProvider {
* Note: the parameters' keywords are case insensitive and this insensitivity is enforced
* by the Map that is passed to the function.
*/
def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation
def createRelation(
sqlContext: SQLContext,
parameters: Map[String, String],
schema: Option[StructType]): BaseRelation
}
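
For illustration, a minimal provider written against the widened interface; the class, option handling, and fallback schema here are hypothetical, and only the three-argument createRelation signature comes from this patch.

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.catalyst.types.{StringType, StructField, StructType}
import org.apache.spark.sql.sources._

// Hypothetical source that exposes the 'path' option as a one-row, one-column table.
class EchoSource extends RelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      schema: Option[StructType]): BaseRelation = {
    val path = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
    new EchoRelation(path, schema)(sqlContext)
  }
}

class EchoRelation(
    path: String,
    userSpecifiedSchema: Option[StructType])(
    @transient val sqlContext: SQLContext)
  extends TableScan {

  // Honour a user-supplied schema; otherwise fall back to a fixed single column,
  // mirroring the getOrElse pattern used by the JSON and Parquet relations above.
  override val schema =
    userSpecifiedSchema.getOrElse(StructType(Seq(StructField("value", StringType))))

  override def buildScan(): RDD[Row] = sqlContext.sparkContext.parallelize(Seq(Row(path)))
}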
