Hive tests.
rxin committed Jan 20, 2015
1 parent 15681c2 commit 9cdeb7d
Showing 12 changed files with 109 additions and 81 deletions.
```diff
@@ -158,11 +158,11 @@ class JoinedRow extends Row {
     if ((row1 eq null) && (row2 eq null)) {
       "[ empty row ]"
     } else if (row1 eq null) {
-      "[" + row2.mkString(",") + "]"
+      row2.mkString("[", ",", "]")
     } else if (row2 eq null) {
-      "[" + row1.mkString(",") + "]"
+      row1.mkString("[", ",", "]")
     } else {
-      "[" + mkString(",") + "]"
+      mkString("[", ",", "]")
     }
   }
 }
```
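The three-argument `mkString(start, sep, end)` used above produces the same bracketed string as the old manual concatenation, in a single pass. A quick sanity check with hypothetical values:

```scala
val values = Seq(1, "a", null)
// Three-argument mkString wraps the joined elements in start/end delimiters.
assert(values.mkString("[", ",", "]") == "[" + values.mkString(",") + "]")
// Both forms evaluate to "[1,a,null]"
```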
```diff
@@ -144,7 +144,7 @@ package object debug {
       case (null, _) =>

       case (row: Row, StructType(fields)) =>
-        row.toSeq.zip(fields.map(_.dataType)).foreach { case(d,t) => typeCheck(d,t) }
+        row.toSeq.zip(fields.map(_.dataType)).foreach { case (d, t) => typeCheck(d, t) }
       case (s: Seq[_], ArrayType(elemType, _)) =>
         s.foreach(typeCheck(_, elemType))
       case (m: Map[_, _], MapType(keyType, valueType, _)) =>
```
```diff
@@ -475,7 +475,7 @@ private[hive] trait HiveInspectors {
   }

   def wrap(
-      row: Seq[Any],
+      row: Row,
       inspectors: Seq[ObjectInspector],
       cache: Array[AnyRef]): Array[AnyRef] = {
     var i = 0
@@ -486,6 +486,18 @@
     cache
   }

+  def wrap(
+      row: Seq[Any],
+      inspectors: Seq[ObjectInspector],
+      cache: Array[AnyRef]): Array[AnyRef] = {
+    var i = 0
+    while (i < inspectors.length) {
+      cache(i) = wrap(row(i), inspectors(i))
+      i += 1
+    }
+    cache
+  }
+
   /**
    * @param dataType Catalyst data type
    * @return Hive java object inspector (recursively), not the Writable ObjectInspector
```
```diff
@@ -297,7 +297,7 @@ private[hive] case class HiveGenericUdtf(
     val inputProjection = new InterpretedProjection(children)
     val collector = new UDTFCollector
     function.setCollector(collector)
-    function.process(wrap(inputProjection(input).toSeq, inputInspectors, udtInput))
+    function.process(wrap(inputProjection(input), inputInspectors, udtInput))
     collector.collectRows()
   }

@@ -360,7 +360,7 @@ private[hive] case class HiveUdafFunction(
   protected lazy val cached = new Array[AnyRef](exprs.length)

   def update(input: Row): Unit = {
-    val inputs = inputProjection(input).asInstanceOf[Seq[AnyRef]].toArray
+    val inputs = inputProjection(input)
     function.iterate(buffer, wrap(inputs, inspectors, cached))
   }
 }
```
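Both call sites above now hand the projection's `Row` straight to `wrap`, which the new `Row`-typed overload in HiveInspectors accepts without a `toSeq` copy. A self-contained sketch of the overload pattern (the `Row` and `ObjectInspector` below are simplified stand-ins, not Spark's types):

```scala
object WrapSketch {
  type ObjectInspector = String                 // stand-in for Hive's ObjectInspector
  final case class Row(values: Any*) {          // stand-in for catalyst's Row
    def apply(i: Int): Any = values(i)
  }

  private def wrapOne(v: Any, oi: ObjectInspector): AnyRef = v.asInstanceOf[AnyRef]

  // Row overload: indexes columns in place; no intermediate Seq is built.
  def wrap(row: Row, inspectors: Seq[ObjectInspector], cache: Array[AnyRef]): Array[AnyRef] = {
    var i = 0
    while (i < inspectors.length) { cache(i) = wrapOne(row(i), inspectors(i)); i += 1 }
    cache
  }

  // Seq overload kept for callers that still hold a plain Seq[Any].
  def wrap(row: Seq[Any], inspectors: Seq[ObjectInspector], cache: Array[AnyRef]): Array[AnyRef] = {
    var i = 0
    while (i < inspectors.length) { cache(i) = wrapOne(row(i), inspectors(i)); i += 1 }
    cache
  }

  def main(args: Array[String]): Unit = {
    val inspectors = Seq("int", "string")
    val cache = new Array[AnyRef](inspectors.length)
    println(wrap(Row(1, "a"), inspectors, cache).mkString(", "))  // resolves to the Row overload
  }
}
```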
48 changes: 32 additions & 16 deletions sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala
```diff
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.util._
  * So, we duplicate this code here.
  */
 class QueryTest extends PlanTest {
+
   /**
    * Runs the plan and makes sure the answer contains all of the keywords, or the
    * none of keywords are listed in the answer
@@ -56,29 +57,33 @@ class QueryTest extends PlanTest {
    * @param rdd the [[SchemaRDD]] to be executed
    * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ].
    */
-  protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Any): Unit = {
-    val convertedAnswer = expectedAnswer match {
-      case s: Seq[_] if s.isEmpty => s
-      case s: Seq[_] if s.head.isInstanceOf[Product] &&
-        !s.head.isInstanceOf[Seq[_]] => s.map(_.asInstanceOf[Product].productIterator.toIndexedSeq)
-      case s: Seq[_] => s
-      case singleItem => Seq(Seq(singleItem))
+  protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Seq[Row]): Unit = {
+    val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
+    def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
+      // Converts data to types that we can do equality comparison using Scala collections.
+      // For BigDecimal type, the Scala type has a better definition of equality test (similar to
+      // Java's java.math.BigDecimal.compareTo).
+      val converted: Seq[Row] = answer.map { s =>
+        Row.fromSeq(s.toSeq.map {
+          case d: java.math.BigDecimal => BigDecimal(d)
+          case o => o
+        })
+      }
+      if (!isSorted) converted.sortBy(_.toString) else converted
     }

-    val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s}.nonEmpty
-    def prepareAnswer(answer: Seq[Any]) = if (!isSorted) answer.sortBy(_.toString) else answer
     val sparkAnswer = try rdd.collect().toSeq catch {
       case e: Exception =>
         fail(
           s"""
             |Exception thrown while executing query:
             |${rdd.queryExecution}
             |== Exception ==
-            |${stackTraceToString(e)}
+            |$e
+            |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
           """.stripMargin)
     }

-    if(prepareAnswer(convertedAnswer) != prepareAnswer(sparkAnswer)) {
+    if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
       fail(s"""
         |Results do not match for query:
         |${rdd.logicalPlan}
@@ -88,11 +93,22 @@ class QueryTest extends PlanTest {
         |${rdd.queryExecution.executedPlan}
         |== Results ==
         |${sideBySide(
-          s"== Correct Answer - ${convertedAnswer.size} ==" +:
-            prepareAnswer(convertedAnswer).map(_.toString),
-          s"== Spark Answer - ${sparkAnswer.size} ==" +:
-            prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")}
+          s"== Correct Answer - ${expectedAnswer.size} ==" +:
+            prepareAnswer(expectedAnswer).map(_.toString),
+          s"== Spark Answer - ${sparkAnswer.size} ==" +:
+            prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")}
       """.stripMargin)
     }
   }
+
+  protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Row): Unit = {
+    checkAnswer(rdd, Seq(expectedAnswer))
+  }
+
+  def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = {
+    test(sqlString) {
+      checkAnswer(sqlContext.sql(sqlString), expectedAnswer)
+    }
+  }
+
 }
```
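Two details of the new `checkAnswer` deserve a note. `prepareAnswer` converts `java.math.BigDecimal` because Java's equality is scale-sensitive while Scala's `BigDecimal` compares by value, and, for queries without an ORDER BY, `sortBy(_.toString)` makes the comparison order-insensitive. A quick illustration with hypothetical values (the commented lines show how the new helpers would be used inside a suite with an implicit `SQLContext`):

```scala
import org.apache.spark.sql.Row

// Java BigDecimal equality distinguishes 1.0 from 1.00; compareTo does not.
val a = new java.math.BigDecimal("1.0")
val b = new java.math.BigDecimal("1.00")
assert(!a.equals(b) && a.compareTo(b) == 0)
// Scala's BigDecimal equality follows compareTo semantics, so converted rows
// compare the way checkAnswer needs.
assert(BigDecimal(a) == BigDecimal(b))

// Hypothetical usage of the new helpers:
// checkAnswer(sql("SELECT 1"), Row(1))     // single-Row overload
// sqlTest("SELECT 1", Row(1) :: Nil)       // registers a test case named "SELECT 1"
```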
```diff
@@ -43,7 +43,7 @@ class InsertIntoHiveTableSuite extends QueryTest {
     // Make sure the table has also been updated.
     checkAnswer(
       sql("SELECT * FROM createAndInsertTest"),
-      testData.collect().toSeq
+      testData.collect().toSeq.map(Row.fromTuple)
     )

     // Add more data.
@@ -52,7 +52,7 @@
     // Make sure the table has been updated.
     checkAnswer(
       sql("SELECT * FROM createAndInsertTest"),
-      testData.collect().toSeq ++ testData.collect().toSeq
+      testData.toSchemaRDD.collect().toSeq ++ testData.toSchemaRDD.collect().toSeq
     )

     // Now overwrite.
@@ -61,7 +61,7 @@
     // Make sure the registered table has also been updated.
     checkAnswer(
       sql("SELECT * FROM createAndInsertTest"),
-      testData.collect().toSeq
+      testData.collect().toSeq.map(Row.fromTuple)
     )
   }
```
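`Row.fromTuple` converts the tuple-based test data into the `Seq[Row]` that the reworked `checkAnswer` expects. A small illustration with hypothetical tuples (relying on `Row`'s structural equality):

```scala
import org.apache.spark.sql.Row

val tuples = Seq((1, "a"), (2, "b"))
val expected = tuples.map(Row.fromTuple)
assert(expected == Seq(Row(1, "a"), Row(2, "b")))
```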
```diff
@@ -155,7 +155,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {

     checkAnswer(
       sql("SELECT * FROM jsonTable"),
-      ("a", "b") :: Nil)
+      Row("a", "b"))

     FileUtils.deleteDirectory(tempDir)
     sparkContext.parallelize(("a1", "b1", "c1") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath)
@@ -164,14 +164,14 @@
     // will show.
     checkAnswer(
       sql("SELECT * FROM jsonTable"),
-      ("a1", "b1") :: Nil)
+      Row("a1", "b1"))

     refreshTable("jsonTable")

     // Check that the refresh worked
     checkAnswer(
       sql("SELECT * FROM jsonTable"),
-      ("a1", "b1", "c1") :: Nil)
+      Row("a1", "b1", "c1"))
     FileUtils.deleteDirectory(tempDir)
   }

@@ -191,7 +191,7 @@

     checkAnswer(
       sql("SELECT * FROM jsonTable"),
-      ("a", "b") :: Nil)
+      Row("a", "b"))

     FileUtils.deleteDirectory(tempDir)
     sparkContext.parallelize(("a", "b", "c") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath)
@@ -210,7 +210,7 @@
     // New table should reflect new schema.
     checkAnswer(
       sql("SELECT * FROM jsonTable"),
-      ("a", "b", "c") :: Nil)
+      Row("a", "b", "c"))
     FileUtils.deleteDirectory(tempDir)
   }

@@ -253,6 +253,6 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
       |)
     """.stripMargin)

-    sql("DROP TABLE jsonTable").collect.foreach(println)
+    sql("DROP TABLE jsonTable").collect().foreach(println)
   }
 }
```
```diff
@@ -21,7 +21,7 @@ import org.scalatest.BeforeAndAfterAll

 import scala.reflect.ClassTag

-import org.apache.spark.sql.{SQLConf, QueryTest}
+import org.apache.spark.sql.{Row, SQLConf, QueryTest}
 import org.apache.spark.sql.execution.joins._
 import org.apache.spark.sql.hive.test.TestHive
 import org.apache.spark.sql.hive.test.TestHive._
@@ -141,7 +141,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
       before: () => Unit,
       after: () => Unit,
       query: String,
-      expectedAnswer: Seq[Any],
+      expectedAnswer: Seq[Row],
       ct: ClassTag[_]) = {
     before()

@@ -183,7 +183,7 @@

     /** Tests for MetastoreRelation */
     val metastoreQuery = """SELECT * FROM src a JOIN src b ON a.key = 238 AND a.key = b.key"""
-    val metastoreAnswer = Seq.fill(4)((238, "val_238", 238, "val_238"))
+    val metastoreAnswer = Seq.fill(4)(Row(238, "val_238", 238, "val_238"))
     mkTest(
       () => (),
       () => (),
@@ -197,7 +197,7 @@
     val leftSemiJoinQuery =
       """SELECT * FROM src a
         |left semi JOIN src b ON a.key=86 and a.key = b.key""".stripMargin
-    val answer = (86, "val_86") :: Nil
+    val answer = Row(86, "val_86")

     var rdd = sql(leftSemiJoinQuery)

```
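For the self-join above, all four expected rows are identical, so `Seq.fill` is the natural constructor; the comparison works because `Row` equality is structural, which `checkAnswer` relies on. A quick sketch with the same values:

```scala
import org.apache.spark.sql.Row

val metastoreAnswer = Seq.fill(4)(Row(238, "val_238", 238, "val_238"))
assert(metastoreAnswer.size == 4)
// Structural equality collapses the four identical rows to one distinct value.
assert(metastoreAnswer.distinct == Seq(Row(238, "val_238", 238, "val_238")))
```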
```diff
@@ -64,7 +64,7 @@ class HiveUdfSuite extends QueryTest {
   test("SPARK-4785 When called with arguments referring column fields, PMOD throws NPE") {
     checkAnswer(
       sql("SELECT PMOD(CAST(key as INT), 10) FROM src LIMIT 1"),
-      8
+      Row(8)
     )
   }

@@ -115,7 +115,7 @@
     sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '${classOf[UDFIntegerToString].getName}'")
     checkAnswer(
       sql("SELECT testUDFIntegerToString(i) FROM integerTable"), //.collect(),
-      Seq(Seq("1"), Seq("2")))
+      Seq(Row("1"), Row("2")))
     sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFIntegerToString")

     TestHive.reset()
@@ -131,7 +131,7 @@
     sql(s"CREATE TEMPORARY FUNCTION testUDFListListInt AS '${classOf[UDFListListInt].getName}'")
     checkAnswer(
       sql("SELECT testUDFListListInt(lli) FROM listListIntTable"), //.collect(),
-      Seq(Seq(0), Seq(2), Seq(13)))
+      Seq(Row(0), Row(2), Row(13)))
     sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListListInt")

     TestHive.reset()
@@ -146,7 +146,7 @@
     sql(s"CREATE TEMPORARY FUNCTION testUDFListString AS '${classOf[UDFListString].getName}'")
     checkAnswer(
       sql("SELECT testUDFListString(l) FROM listStringTable"), //.collect(),
-      Seq(Seq("a,b,c"), Seq("d,e")))
+      Seq(Row("a,b,c"), Row("d,e")))
     sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListString")

     TestHive.reset()
@@ -160,7 +160,7 @@
     sql(s"CREATE TEMPORARY FUNCTION testStringStringUdf AS '${classOf[UDFStringString].getName}'")
     checkAnswer(
       sql("SELECT testStringStringUdf(\"hello\", s) FROM stringTable"), //.collect(),
-      Seq(Seq("hello world"), Seq("hello goodbye")))
+      Seq(Row("hello world"), Row("hello goodbye")))
     sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUdf")

     TestHive.reset()
@@ -177,7 +177,7 @@
     sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'")
     checkAnswer(
       sql("SELECT testUDFTwoListList(lli, lli) FROM TwoListTable"), //.collect(),
-      Seq(Seq("0, 0"), Seq("2, 2"), Seq("13, 13")))
+      Seq(Row("0, 0"), Row("2, 2"), Row("13, 13")))
     sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList")

     TestHive.reset()
```
```diff
@@ -41,7 +41,7 @@ class SQLQuerySuite extends QueryTest {
   }

   test("CTAS with serde") {
-    sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value").collect
+    sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value").collect()
     sql(
       """CREATE TABLE ctas2
         | ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe"
@@ -51,23 +51,23 @@
         | AS
         |   SELECT key, value
         |   FROM src
-        |   ORDER BY key, value""".stripMargin).collect
+        |   ORDER BY key, value""".stripMargin).collect()
     sql(
       """CREATE TABLE ctas3
         | ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\012'
         | STORED AS textfile AS
         |   SELECT key, value
         |   FROM src
-        |   ORDER BY key, value""".stripMargin).collect
+        |   ORDER BY key, value""".stripMargin).collect()

     // the table schema may like (key: integer, value: string)
     sql(
       """CREATE TABLE IF NOT EXISTS ctas4 AS
-        | SELECT 1 AS key, value FROM src LIMIT 1""".stripMargin).collect
+        | SELECT 1 AS key, value FROM src LIMIT 1""".stripMargin).collect()
     // do nothing cause the table ctas4 already existed.
     sql(
       """CREATE TABLE IF NOT EXISTS ctas4 AS
-        | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect
+        | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect()

     checkAnswer(
       sql("SELECT k, value FROM ctas1 ORDER BY k, value"),
@@ -89,7 +89,7 @@
     intercept[org.apache.hadoop.hive.metastore.api.AlreadyExistsException] {
       sql(
         """CREATE TABLE ctas4 AS
-          | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect
+          | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect()
     }
     checkAnswer(
       sql("SELECT key, value FROM ctas4 ORDER BY key, value"),
@@ -126,7 +126,7 @@
     sparkContext.parallelize(Nested1(Nested2(Nested3(1))) :: Nil).registerTempTable("nested")
     checkAnswer(
       sql("SELECT f1.f2.f3 FROM nested"),
-      1)
+      Row(1))
     checkAnswer(sql("CREATE TABLE test_ctas_1234 AS SELECT * from nested"),
       Seq.empty[Row])
     checkAnswer(
@@ -233,7 +233,7 @@
         | (s struct<innerStruct: struct<s1:string>,
         |           innerArray:array<int>,
         |           innerMap: map<string, int>>)
-      """.stripMargin).collect
+      """.stripMargin).collect()

     sql(
       """
@@ -243,7 +243,7 @@

     checkAnswer(
       sql("SELECT * FROM nullValuesInInnerComplexTypes"),
-      Seq(Seq(Seq(null, null, null)))
+      Row(Row(null, null, null))
     )

     sql("DROP TABLE nullValuesInInnerComplexTypes")
```
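The last change mirrors the nesting in the data: a struct column now comes back as a `Row` inside the result `Row` instead of a `Seq` inside a `Seq`. For the all-null struct `s` in this test, the expected shape is one result row whose single column is itself a three-field row:

```scala
import org.apache.spark.sql.Row

// Outer Row = the result row; inner Row = the struct value s, with its three
// fields (innerStruct, innerArray, innerMap) all null.
val expected = Row(Row(null, null, null))
assert(expected(0) == Row(null, null, null))
```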