Merge pull request alteryx#22 from yhuai/pr3431yin

Remove Option from createRelation.
scwf committed Jan 10, 2015
2 parents a852b10 + 38f634e · commit 7e79ce5
Showing 5 changed files with 56 additions and 35 deletions.
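
For orientation, a condensed before/after of the interface change this commit makes, lifted from the SchemaRelationProvider hunk later in the diff. The trait names below are illustrative only; the real trait keeps its name and only the createRelation signature changes:

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.types.StructType
import org.apache.spark.sql.sources.BaseRelation

// Before: the user-specified schema arrived wrapped in an Option.
trait SchemaRelationProviderBefore {
  def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      schema: Option[StructType]): BaseRelation
}

// After: SchemaRelationProvider always receives a concrete StructType. Sources that can
// infer their own schema implement RelationProvider instead, or both, as the JSON source
// below does.
trait SchemaRelationProviderAfter {
  def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      schema: StructType): BaseRelation
}
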
@@ -21,16 +21,27 @@ import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.types.StructType
import org.apache.spark.sql.sources._

private[sql] class DefaultSource extends SchemaRelationProvider {
/** Returns a new base relation with the given parameters. */
private[sql] class DefaultSource extends RelationProvider with SchemaRelationProvider {

/** Returns a new base relation with the parameters. */
override def createRelation(
sqlContext: SQLContext,
parameters: Map[String, String]): BaseRelation = {
val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)

JSONRelation(fileName, samplingRatio, None)(sqlContext)
}

/** Returns a new base relation with the given schema and parameters. */
override def createRelation(
sqlContext: SQLContext,
parameters: Map[String, String],
schema: Option[StructType]): BaseRelation = {
schema: StructType): BaseRelation = {
val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)

JSONRelation(fileName, samplingRatio, schema)(sqlContext)
JSONRelation(fileName, samplingRatio, Some(schema))(sqlContext)
}
}

@@ -22,37 +22,37 @@ import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce.{JobContext, InputSplit, Job}
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate

import parquet.hadoop.ParquetInputFormat
import parquet.hadoop.util.ContextUtil

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.{Partition => SparkPartition, Logging}
import org.apache.spark.rdd.{NewHadoopPartition, RDD}
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate

import org.apache.spark.sql.{SQLConf, Row, SQLContext}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types.{IntegerType, StructField, StructType}
import org.apache.spark.sql.catalyst.types.{StringType, IntegerType, StructField, StructType}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.{SQLConf, SQLContext}

import scala.collection.JavaConversions._


/**
* Allows creation of parquet based tables using the syntax
* `CREATE TEMPORARY TABLE ... USING org.apache.spark.sql.parquet`. Currently the only option
* required is `path`, which should be the location of a collection of, optionally partitioned,
* parquet files.
*/
class DefaultSource extends SchemaRelationProvider {
class DefaultSource extends RelationProvider {
/** Returns a new base relation with the given parameters. */
override def createRelation(
sqlContext: SQLContext,
parameters: Map[String, String],
schema: Option[StructType]): BaseRelation = {
parameters: Map[String, String]): BaseRelation = {
val path =
parameters.getOrElse("path", sys.error("'path' must be specified for parquet tables."))

ParquetRelation2(path, schema)(sqlContext)
ParquetRelation2(path)(sqlContext)
}
}

@@ -82,9 +82,7 @@ private[parquet] case class Partition(partitionValues: Map[String, Any], files:
* discovery.
*/
@DeveloperApi
case class ParquetRelation2(
path: String,
userSpecifiedSchema: Option[StructType])(@transient val sqlContext: SQLContext)
case class ParquetRelation2(path: String)(@transient val sqlContext: SQLContext)
extends CatalystScan with Logging {

def sparkContext = sqlContext.sparkContext
@@ -135,13 +133,12 @@ case class ParquetRelation2(

override val sizeInBytes = partitions.flatMap(_.files).map(_.getLen).sum

val dataSchema = userSpecifiedSchema.getOrElse(
StructType.fromAttributes( // TODO: Parquet code should not deal with attributes.
ParquetTypesConverter.readSchemaFromFile(
partitions.head.files.head.getPath,
Some(sparkContext.hadoopConfiguration),
sqlContext.isParquetBinaryAsString))
)
val dataSchema = StructType.fromAttributes( // TODO: Parquet code should not deal with attributes.
ParquetTypesConverter.readSchemaFromFile(
partitions.head.files.head.getPath,
Some(sparkContext.hadoopConfiguration),
sqlContext.isParquetBinaryAsString))

val dataIncludesKey =
partitionKeys.headOption.map(dataSchema.fieldNames.contains(_)).getOrElse(true)

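
As a usage sketch for the class comment in the parquet DefaultSource above: a minimal example, assuming an existing SQLContext named sqlContext; the table name and path are made up. Since the parquet source is now a plain RelationProvider, the schema always comes from the parquet files themselves rather than from the user:

// Hypothetical DDL invocation; only the 'path' option is required.
sqlContext.sql(
  """CREATE TEMPORARY TABLE parquetTable
    |USING org.apache.spark.sql.parquet
    |OPTIONS (path '/tmp/parquet-data')""".stripMargin)
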
31 changes: 22 additions & 9 deletions sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -190,15 +190,28 @@ private[sql] case class CreateTableUsing(
sys.error(s"Failed to load class for data source: $provider")
}
}
val relation = clazz.newInstance match {
case dataSource: org.apache.spark.sql.sources.RelationProvider =>
dataSource
.asInstanceOf[org.apache.spark.sql.sources.RelationProvider]
.createRelation(sqlContext, new CaseInsensitiveMap(options))
case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
dataSource
.asInstanceOf[org.apache.spark.sql.sources.SchemaRelationProvider]
.createRelation(sqlContext, new CaseInsensitiveMap(options), userSpecifiedSchema)

val relation = userSpecifiedSchema match {
case Some(schema: StructType) => {
clazz.newInstance match {
case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
dataSource
.asInstanceOf[org.apache.spark.sql.sources.SchemaRelationProvider]
.createRelation(sqlContext, new CaseInsensitiveMap(options), schema)
case _ =>
sys.error(s"${clazz.getCanonicalName} should extend SchemaRelationProvider.")
}
}
case None => {
clazz.newInstance match {
case dataSource: org.apache.spark.sql.sources.RelationProvider =>
dataSource
.asInstanceOf[org.apache.spark.sql.sources.RelationProvider]
.createRelation(sqlContext, new CaseInsensitiveMap(options))
case _ =>
sys.error(s"${clazz.getCanonicalName} should extend RelationProvider.")
}
}
}

sqlContext.baseRelationToSchemaRDD(relation).registerTempTable(tableName)
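
A rough illustration of the dispatch above, exercising the JSON source from the first hunk; table names, column names, and paths are hypothetical, and an existing SQLContext named sqlContext is assumed:

// No column list: userSpecifiedSchema is None, so the RelationProvider overload runs and
// the JSON schema is inferred from the data.
sqlContext.sql(
  """CREATE TEMPORARY TABLE people
    |USING org.apache.spark.sql.json
    |OPTIONS (path '/tmp/people.json')""".stripMargin)

// With a column list: userSpecifiedSchema is Some(schema), so the SchemaRelationProvider
// overload receives the StructType directly. A provider that only extends RelationProvider
// would fail here with the "should extend SchemaRelationProvider" error above.
sqlContext.sql(
  """CREATE TEMPORARY TABLE typedPeople (name STRING, age INT)
    |USING org.apache.spark.sql.json
    |OPTIONS (path '/tmp/people.json')""".stripMargin)
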
@@ -68,7 +68,7 @@ trait SchemaRelationProvider {
def createRelation(
sqlContext: SQLContext,
parameters: Map[String, String],
schema: Option[StructType]): BaseRelation
schema: StructType): BaseRelation
}

/**
@@ -45,18 +45,18 @@ class AllDataTypesScanSource extends SchemaRelationProvider {
override def createRelation(
sqlContext: SQLContext,
parameters: Map[String, String],
schema: Option[StructType]): BaseRelation = {
schema: StructType): BaseRelation = {
AllDataTypesScan(parameters("from").toInt, parameters("TO").toInt, schema)(sqlContext)
}
}

case class AllDataTypesScan(
from: Int,
to: Int,
userSpecifiedSchema: Option[StructType])(@transient val sqlContext: SQLContext)
userSpecifiedSchema: StructType)(@transient val sqlContext: SQLContext)
extends TableScan {

override def schema = userSpecifiedSchema.get
override def schema = userSpecifiedSchema

override def buildScan() = {
sqlContext.sparkContext.parallelize(from to to).map { i =>
