[SQL] SPARK-1366 Consistent sql function across different types of SQLContexts #319

Closed · wants to merge 2 commits
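Note on intent (not part of the diff): after this change, sql() behaves the same on every kind of context — it always goes through the standard SQL parser — while HiveQL is run through the new hiveql() method on HiveContext, or its shorter alias hql(). A minimal sketch of the resulting usage, assuming the Spark SQL API as of this PR (table names are illustrative):

    import org.apache.spark.SparkContext
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.hive.LocalHiveContext

    object SqlVsHqlSketch {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "SqlVsHqlSketch")

        // On a plain SQLContext, sql() parses standard SQL.
        val sqlContext = new SQLContext(sc)
        // sqlContext.sql("SELECT key, value FROM someRegisteredTable")

        // On a HiveContext, sql() still means standard SQL; HiveQL goes
        // through hiveql() or its alias hql(), both added by this patch.
        val hiveContext = new LocalHiveContext(sc)
        import hiveContext._
        hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
        hql("SELECT * FROM src").collect().foreach(println)
        sql("SELECT 1").collect()  // plain SQL parser, same as on SQLContext
      }
    }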
@@ -33,20 +33,20 @@ object HiveFromSpark {
val hiveContext = new LocalHiveContext(sc)
import hiveContext._

sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
sql("LOAD DATA LOCAL INPATH 'src/main/resources/kv1.txt' INTO TABLE src")
hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
hql("LOAD DATA LOCAL INPATH 'src/main/resources/kv1.txt' INTO TABLE src")

// Queries are expressed in HiveQL
println("Result of 'SELECT *': ")
sql("SELECT * FROM src").collect.foreach(println)
hql("SELECT * FROM src").collect.foreach(println)

// Aggregation queries are also supported.
-val count = sql("SELECT COUNT(*) FROM src").collect().head.getInt(0)
+val count = hql("SELECT COUNT(*) FROM src").collect().head.getInt(0)
println(s"COUNT(*): $count")

// The results of SQL queries are themselves RDDs and support all normal RDD functions. The
// items in the RDD are of type Row, which allows you to access each column by ordinal.
-val rddFromSql = sql("SELECT key, value FROM src WHERE key < 10 ORDER BY key")
+val rddFromSql = hql("SELECT key, value FROM src WHERE key < 10 ORDER BY key")

println("Result of RDD.map:")
val rddAsStrings = rddFromSql.map {
@@ -59,6 +59,6 @@ object HiveFromSpark {

// Queries can then join RDD data with data stored in Hive.
println("Result of SELECT *:")
sql("SELECT * FROM records r JOIN src s ON r.key = s.key").collect().foreach(println)
hql("SELECT * FROM records r JOIN src s ON r.key = s.key").collect().foreach(println)
}
}
@@ -67,14 +67,13 @@ class LocalHiveContext(sc: SparkContext) extends HiveContext(sc) {
class HiveContext(sc: SparkContext) extends SQLContext(sc) {
self =>

-override def parseSql(sql: String): LogicalPlan = HiveQl.parseSql(sql)
-override def executePlan(plan: LogicalPlan): this.QueryExecution =
+override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution =
new this.QueryExecution { val logical = plan }

/**
* Executes a query expressed in HiveQL using Spark, returning the result as a SchemaRDD.
*/
-def hql(hqlQuery: String): SchemaRDD = {
+def hiveql(hqlQuery: String): SchemaRDD = {
val result = new SchemaRDD(this, HiveQl.parseSql(hqlQuery))
// We force query optimization to happen right away instead of letting it happen lazily like
// when using the query DSL. This is so DDL commands behave as expected. This is only
@@ -83,6 +82,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
result
}

+/** An alias for `hiveql`. */
+def hql(hqlQuery: String): SchemaRDD = hiveql(hqlQuery)

// Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur.
@transient
protected val outputBuffer = new java.io.OutputStream {
@@ -120,7 +122,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {

/* A catalyst metadata catalog that points to the Hive Metastore. */
@transient
-override lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog {
+override protected[sql] lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog {
override def lookupRelation(
databaseName: Option[String],
tableName: String,
@@ -132,7 +134,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {

/* An analyzer that uses the Hive metastore. */
@transient
-override lazy val analyzer = new Analyzer(catalog, HiveFunctionRegistry, caseSensitive = false)
+override protected[sql] lazy val analyzer =
+new Analyzer(catalog, HiveFunctionRegistry, caseSensitive = false)

/**
* Runs the specified SQL query using Hive.
@@ -214,14 +217,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
}

@transient
-override val planner = hivePlanner
+override protected[sql] val planner = hivePlanner

@transient
protected lazy val emptyResult =
sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1)

/** Extends QueryExecution with hive specific features. */
-abstract class QueryExecution extends super.QueryExecution {
+protected[sql] abstract class QueryExecution extends super.QueryExecution {
// TODO: Create mixin for the analyzer instead of overriding things here.
override lazy val optimizedPlan =
optimizer(catalog.PreInsertionCasts(catalog.CreateTables(analyzed)))
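Aside from the hiveql/hql rename, the changes above also tighten visibility: executePlan, catalog, analyzer, planner, and QueryExecution become protected[sql], so they stay usable throughout the org.apache.spark.sql packages but drop out of the public API. A hypothetical, self-contained illustration of Scala's package-qualified protected (the names below are made up, not Spark's):

    package demo {
      package sql {
        class Context {
          // Visible anywhere under package demo.sql (and to subclasses),
          // but hidden from user code outside that package.
          protected[sql] def planner: String = "internal planner"
          def query(q: String): String = planner + " -> " + q
        }
        object Internal {
          def peek(c: Context): String = c.planner // compiles: same package as Context
        }
      }
      object UserCode {
        def main(args: Array[String]) {
          val c = new sql.Context
          println(c.query("SELECT 1"))
          // println(c.planner) // would not compile: planner is protected[sql]
        }
      }
    }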
sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala (12 changes: 6 additions, 6 deletions)
@@ -110,10 +110,10 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) {

val describedTable = "DESCRIBE (\\w+)".r

-class SqlQueryExecution(sql: String) extends this.QueryExecution {
-lazy val logical = HiveQl.parseSql(sql)
-def hiveExec() = runSqlHive(sql)
-override def toString = sql + "\n" + super.toString
+protected[hive] class HiveQLQueryExecution(hql: String) extends this.QueryExecution {
+lazy val logical = HiveQl.parseSql(hql)
+def hiveExec() = runSqlHive(hql)
+override def toString = hql + "\n" + super.toString
}

/**
Expand All @@ -140,8 +140,8 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) {

case class TestTable(name: String, commands: (()=>Unit)*)

-implicit class SqlCmd(sql: String) {
-def cmd = () => new SqlQueryExecution(sql).stringResult(): Unit
+protected[hive] implicit class SqlCmd(sql: String) {
+def cmd = () => new HiveQLQueryExecution(sql).stringResult(): Unit
}

/**
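For context, the renamed HiveQLQueryExecution and the SqlCmd implicit above are what let the test context declare tables as deferred setup commands. A hedged sketch of that pattern, assuming the TestTable case class and SqlCmd implicit shown above are in scope (the table name and file path are placeholders, not the actual registrations in TestHive.scala):

    // Each `.cmd` wraps the HiveQL string in a deferred () => Unit that runs
    // through HiveQLQueryExecution only when a test first touches the table.
    val exampleTable = TestTable("example_src",
      "CREATE TABLE example_src (key INT, value STRING)".cmd,
      "LOAD DATA LOCAL INPATH '/path/to/kv1.txt' INTO TABLE example_src".cmd)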
@@ -125,7 +125,7 @@ abstract class HiveComparisonTest
}

protected def prepareAnswer(
-hiveQuery: TestHive.type#SqlQueryExecution,
+hiveQuery: TestHive.type#HiveQLQueryExecution,
answer: Seq[String]): Seq[String] = {
val orderedAnswer = hiveQuery.logical match {
// Clean out non-deterministic time schema info.
@@ -227,7 +227,7 @@

try {
// MINOR HACK: You must run a query before calling reset the first time.
-TestHive.sql("SHOW TABLES")
+TestHive.hql("SHOW TABLES")
if (reset) { TestHive.reset() }

val hiveCacheFiles = queryList.zipWithIndex.map {
@@ -256,7 +256,7 @@
hiveCachedResults
} else {

-val hiveQueries = queryList.map(new TestHive.SqlQueryExecution(_))
+val hiveQueries = queryList.map(new TestHive.HiveQLQueryExecution(_))
// Make sure we can at least parse everything before attempting hive execution.
hiveQueries.foreach(_.logical)
val computedResults = (queryList.zipWithIndex, hiveQueries, hiveCacheFiles).zipped.map {
@@ -302,7 +302,7 @@

// Run w/ catalyst
val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) =>
-val query = new TestHive.SqlQueryExecution(queryString)
+val query = new TestHive.HiveQLQueryExecution(queryString)
try { (query, prepareAnswer(query, query.stringResult())) } catch {
case e: Exception =>
val errorMessage =
@@ -359,7 +359,7 @@
// When we encounter an error we check to see if the environment is still okay by running a simple query.
// If this fails then we halt testing since something must have gone seriously wrong.
try {
-new TestHive.SqlQueryExecution("SELECT key FROM src").stringResult()
+new TestHive.HiveQLQueryExecution("SELECT key FROM src").stringResult()
TestHive.runSqlHive("SELECT key FROM src")
} catch {
case e: Exception =>
@@ -23,6 +23,16 @@ import org.apache.spark.sql.hive.TestHive._
* A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution.
*/
class HiveQuerySuite extends HiveComparisonTest {

test("Query expressed in SQL") {
assert(sql("SELECT 1").collect() === Array(Seq(1)))
}

test("Query expressed in HiveQL") {
hql("FROM src SELECT key").collect()
hiveql("FROM src SELECT key").collect()
}

createQueryTest("Simple Average",
"SELECT AVG(key) FROM src")

@@ -133,7 +143,7 @@ class HiveQuerySuite extends HiveComparisonTest {
"SELECT * FROM src LATERAL VIEW explode(map(key+3,key+4)) D as k, v")

test("sampling") {
sql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s")
hql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s")
}

}
@@ -56,7 +56,7 @@ class HiveResolutionSuite extends HiveComparisonTest {
TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2)) :: Nil)
.registerAsTable("caseSensitivityTest")

sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest")
hql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest")
}

/**
@@ -136,7 +136,7 @@ class PruningSuite extends HiveComparisonTest {
expectedScannedColumns: Seq[String],
expectedPartValues: Seq[Seq[String]]) = {
test(s"$testCaseName - pruning test") {
-val plan = new TestHive.SqlQueryExecution(sql).executedPlan
+val plan = new TestHive.HiveQLQueryExecution(sql).executedPlan
val actualOutputColumns = plan.output.map(_.name)
val (actualScannedColumns, actualPartValues) = plan.collect {
case p @ HiveTableScan(columns, relation, _) =>
@@ -57,34 +57,34 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft
}

test("SELECT on Parquet table") {
val rdd = sql("SELECT * FROM testsource").collect()
val rdd = hql("SELECT * FROM testsource").collect()
assert(rdd != null)
assert(rdd.forall(_.size == 6))
}

test("Simple column projection + filter on Parquet table") {
val rdd = sql("SELECT myboolean, mylong FROM testsource WHERE myboolean=true").collect()
val rdd = hql("SELECT myboolean, mylong FROM testsource WHERE myboolean=true").collect()
assert(rdd.size === 5, "Filter returned incorrect number of rows")
assert(rdd.forall(_.getBoolean(0)), "Filter returned incorrect Boolean field value")
}

test("Converting Hive to Parquet Table via saveAsParquetFile") {
sql("SELECT * FROM src").saveAsParquetFile(dirname.getAbsolutePath)
hql("SELECT * FROM src").saveAsParquetFile(dirname.getAbsolutePath)
parquetFile(dirname.getAbsolutePath).registerAsTable("ptable")
-val rddOne = sql("SELECT * FROM src").collect().sortBy(_.getInt(0))
-val rddTwo = sql("SELECT * from ptable").collect().sortBy(_.getInt(0))
+val rddOne = hql("SELECT * FROM src").collect().sortBy(_.getInt(0))
+val rddTwo = hql("SELECT * from ptable").collect().sortBy(_.getInt(0))
compareRDDs(rddOne, rddTwo, "src (Hive)", Seq("key:Int", "value:String"))
}

test("INSERT OVERWRITE TABLE Parquet table") {
sql("SELECT * FROM testsource").saveAsParquetFile(dirname.getAbsolutePath)
hql("SELECT * FROM testsource").saveAsParquetFile(dirname.getAbsolutePath)
parquetFile(dirname.getAbsolutePath).registerAsTable("ptable")
// let's do three overwrites for good measure
sql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect()
sql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect()
sql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect()
val rddCopy = sql("SELECT * FROM ptable").collect()
val rddOrig = sql("SELECT * FROM testsource").collect()
hql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect()
hql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect()
hql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect()
val rddCopy = hql("SELECT * FROM ptable").collect()
val rddOrig = hql("SELECT * FROM testsource").collect()
assert(rddCopy.size === rddOrig.size, "INSERT OVERWRITE changed size of table??")
compareRDDs(rddOrig, rddCopy, "testsource", ParquetTestData.testSchemaFieldNames)
}
@@ -93,13 +93,13 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft
createParquetFile(dirname.getAbsolutePath, ("key", IntegerType), ("value", StringType))
.registerAsTable("tmp")
val rddCopy =
sql("INSERT INTO TABLE tmp SELECT * FROM src")
hql("INSERT INTO TABLE tmp SELECT * FROM src")
.collect()
.sortBy[Int](_.apply(0) match {
case x: Int => x
case _ => 0
})
-val rddOrig = sql("SELECT * FROM src")
+val rddOrig = hql("SELECT * FROM src")
.collect()
.sortBy(_.getInt(0))
compareRDDs(rddOrig, rddCopy, "src (Hive)", Seq("key:Int", "value:String"))
@@ -108,22 +108,22 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft
test("Appending to Parquet table") {
createParquetFile(dirname.getAbsolutePath, ("key", IntegerType), ("value", StringType))
.registerAsTable("tmpnew")
sql("INSERT INTO TABLE tmpnew SELECT * FROM src").collect()
sql("INSERT INTO TABLE tmpnew SELECT * FROM src").collect()
sql("INSERT INTO TABLE tmpnew SELECT * FROM src").collect()
val rddCopies = sql("SELECT * FROM tmpnew").collect()
val rddOrig = sql("SELECT * FROM src").collect()
hql("INSERT INTO TABLE tmpnew SELECT * FROM src").collect()
hql("INSERT INTO TABLE tmpnew SELECT * FROM src").collect()
hql("INSERT INTO TABLE tmpnew SELECT * FROM src").collect()
val rddCopies = hql("SELECT * FROM tmpnew").collect()
val rddOrig = hql("SELECT * FROM src").collect()
assert(rddCopies.size === 3 * rddOrig.size, "number of copied rows via INSERT INTO did not match correct number")
}

test("Appending to and then overwriting Parquet table") {
createParquetFile(dirname.getAbsolutePath, ("key", IntegerType), ("value", StringType))
.registerAsTable("tmp")
sql("INSERT INTO TABLE tmp SELECT * FROM src").collect()
sql("INSERT INTO TABLE tmp SELECT * FROM src").collect()
sql("INSERT OVERWRITE TABLE tmp SELECT * FROM src").collect()
val rddCopies = sql("SELECT * FROM tmp").collect()
val rddOrig = sql("SELECT * FROM src").collect()
hql("INSERT INTO TABLE tmp SELECT * FROM src").collect()
hql("INSERT INTO TABLE tmp SELECT * FROM src").collect()
hql("INSERT OVERWRITE TABLE tmp SELECT * FROM src").collect()
val rddCopies = hql("SELECT * FROM tmp").collect()
val rddOrig = hql("SELECT * FROM src").collect()
assert(rddCopies.size === rddOrig.size, "INSERT OVERWRITE did not actually overwrite")
}
