Skip to content

Commit

Permalink
[SPARK-19993] Caching logical plans containing subquery expressions d…
Browse files Browse the repository at this point in the history
…oes not work
  • Loading branch information
dilipbiswal committed Apr 10, 2017
1 parent 3d7f201 commit 7346dca
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@ abstract class SubqueryExpression(
children.zip(p.children).forall(p => p._1.semanticEquals(p._2))
case _ => false
}

/**
 * Builds a canonical form of this subquery expression relative to `attrs`.
 *
 * Every `OuterReference` found in the subquery plan is replaced by the result of
 * `plan.normalizeExprId` applied to the wrapped expression; the expression is then rebuilt
 * around the rewritten plan and canonicalized.
 *
 * @param attrs attributes against which the outer references are normalized
 * @return the canonicalized expression, cast to `SubqueryExpression`
 */
def canonicalize(attrs: AttributeSeq): SubqueryExpression = {
  val normalizedPlan = plan.transformAllExpressions {
    case OuterReference(outerExpr) => plan.normalizeExprId(outerExpr, attrs)
  }
  val rebuilt = withNewPlan(normalizedPlan)
  rebuilt.canonicalized.asInstanceOf[SubqueryExpression]
}
}

object SubqueryExpression {
Expand Down Expand Up @@ -236,6 +244,12 @@ case class ScalarSubquery(
override def nullable: Boolean = true
override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(plan = plan)
override def toString: String = s"scalar-subquery#${exprId.id} $conditionString"
/**
 * Canonical form of this scalar subquery: a new `ScalarSubquery` built from the
 * canonicalized plan, the canonicalized children and the fixed `ExprId(0)`.
 */
override lazy val canonicalized: Expression = {
  val canonicalPlan = plan.canonicalized
  val canonicalChildren = children.map(child => child.canonicalized)
  ScalarSubquery(canonicalPlan, canonicalChildren, ExprId(0))
}
}

object ScalarSubquery {
Expand Down Expand Up @@ -268,6 +282,12 @@ case class ListQuery(
override def nullable: Boolean = false
override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan)
override def toString: String = s"list#${exprId.id} $conditionString"
/**
 * Canonical form of this list query: a new `ListQuery` built from the canonicalized
 * plan, the canonicalized children and the fixed `ExprId(0)`.
 */
override lazy val canonicalized: Expression = {
  val canonicalPlan = plan.canonicalized
  val canonicalChildren = children.map(child => child.canonicalized)
  ListQuery(canonicalPlan, canonicalChildren, ExprId(0))
}
}

/**
Expand All @@ -290,4 +310,10 @@ case class Exists(
override def nullable: Boolean = false
override def withNewPlan(plan: LogicalPlan): Exists = copy(plan = plan)
override def toString: String = s"exists#${exprId.id} $conditionString"
/**
 * Canonical form of this EXISTS subquery: a new `Exists` built from the canonicalized
 * plan, the canonicalized children and the fixed `ExprId(0)`.
 */
override lazy val canonicalized: Expression = {
  val canonicalPlan = plan.canonicalized
  val canonicalChildren = children.map(child => child.canonicalized)
  Exists(canonicalPlan, canonicalChildren, ExprId(0))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -401,8 +401,9 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
* do not use `BindReferences` here as the plan may take the expression as a parameter with type
* `Attribute`, and replace it with `BoundReference` will cause error.
*/
protected def normalizeExprId[T <: Expression](e: T, input: AttributeSeq = allAttributes): T = {
def normalizeExprId[T <: Expression](e: T, input: AttributeSeq = allAttributes): T = {
e.transformUp {
case s: SubqueryExpression => s.canonicalize(input)
case ar: AttributeReference =>
val ordinal = input.indexOf(ar.exprId)
if (ordinal == -1) {
Expand Down
144 changes: 143 additions & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.scalatest.concurrent.Eventually._
import org.apache.spark.CleanerListener
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.execution.RDDScanExec
import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan}
import org.apache.spark.sql.execution.columnar._
import org.apache.spark.sql.execution.exchange.ShuffleExchange
import org.apache.spark.sql.functions._
Expand Down Expand Up @@ -76,6 +76,13 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
sum
}

/**
 * Counts the `InMemoryTableScanExec` nodes collected from `plan`. For each scan found,
 * the child plan of its relation is searched recursively and those scans are added to
 * the total as well.
 */
private def getNumInMemoryTableScanExecs(plan: SparkPlan): Int = {
  val countsPerScan = plan.collect {
    case InMemoryTableScanExec(_, _, cachedRelation) =>
      // One for this scan, plus whatever is nested below its relation.
      1 + getNumInMemoryTableScanExecs(cachedRelation.child)
  }
  countsPerScan.sum
}

test("withColumn doesn't invalidate cached dataframe") {
var evalCount = 0
val myUDF = udf((x: String) => { evalCount += 1; "result" })
Expand Down Expand Up @@ -670,4 +677,139 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
assert(spark.read.parquet(path).filter($"id" > 4).count() == 15)
}
}

test("SPARK-19993 simple subquery caching") {
  withTempView("t1", "t2") {
    Seq("t1", "t2").foreach { viewName =>
      Seq(1).toDF("c1").createOrReplaceTempView(viewName)
    }

    // Query containing an uncorrelated NOT EXISTS subquery; it is cached first and then
    // planned a second time from the same text.
    val notExistsQuery =
      """
|SELECT * FROM t1
|WHERE
|NOT EXISTS (SELECT * FROM t1)
""".stripMargin

    // Same outer query, but the subquery carries an additional predicate.
    val filteredSubqueryQuery =
      """
|SELECT * FROM t1
|WHERE
|NOT EXISTS (SELECT * FROM t1 where c1 = 0)
""".stripMargin

    sql(notExistsQuery).cache()

    // The identical query is expected to contain exactly one in-memory relation.
    val cachedDs = sql(notExistsQuery)
    assert(getNumInMemoryRelations(cachedDs) == 1)

    // Additional predicate in the subquery plan should cause a cache miss
    val cachedMissDs = sql(filteredSubqueryQuery)
    assert(getNumInMemoryRelations(cachedMissDs) == 0)
  }
}

test("SPARK-19993 subquery caching with correlated predicates") {
  withTempView("t1", "t2") {
    Seq("t1", "t2").foreach { viewName =>
      Seq(1).toDF("c1").createOrReplaceTempView(viewName)
    }

    // Simple correlated predicate in subquery: the IN subquery refers to t1.c1 from the
    // outer query.
    val correlatedInQuery =
      """
|SELECT * FROM t1
|WHERE
|t1.c1 in (SELECT t2.c1 FROM t2 where t1.c1 = t2.c1)
""".stripMargin

    sql(correlatedInQuery).cache()

    // Planning the same text again is expected to contain exactly one in-memory relation.
    val cachedDs = sql(correlatedInQuery)
    assert(getNumInMemoryRelations(cachedDs) == 1)
  }
}

test("SPARK-19993 subquery with cached underlying relation") {
  withTempView("t1", "t2") {
    Seq("t1", "t2").foreach { viewName =>
      Seq(1).toDF("c1").createOrReplaceTempView(viewName)
    }
    spark.catalog.cacheTable("t1")

    // t1 appears twice in this query: once in the outer query and once in the subquery.
    val notExistsQuery =
      """
|SELECT * FROM t1
|WHERE
|NOT EXISTS (SELECT * FROM t1)
""".stripMargin

    // underlying table t1 is cached as well as the query that refers to it.
    val ds = sql(notExistsQuery)
    assert(getNumInMemoryRelations(ds) == 2)

    // After caching the query itself, three in-memory table scans are expected in the
    // physical plan, counting scans nested below cached relations.
    val cachedDs = sql(notExistsQuery).cache()
    val physicalPlan = cachedDs.queryExecution.sparkPlan
    assert(getNumInMemoryTableScanExecs(physicalPlan) == 3)
  }
}

test("SPARK-19993 nested subquery caching and scalar + predicate subqueries") {
  withTempView("t1", "t2", "t3", "t4") {
    // t2 holds the value 2; every other view holds the value 1.
    Seq("t1" -> 1, "t2" -> 2, "t3" -> 1, "t4" -> 1).foreach { case (viewName, value) =>
      Seq(value).toDF("c1").createOrReplaceTempView(viewName)
    }

    // Nested predicate subquery
    val nestedInQuery =
      """
|SELECT * FROM t1
|WHERE
|c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 FROM t3 WHERE c1 = 1))
""".stripMargin

    sql(nestedInQuery).cache()

    // Planning the same text again is expected to contain exactly one in-memory relation.
    val cachedDs = sql(nestedInQuery)
    assert(getNumInMemoryRelations(cachedDs) == 1)

    // Scalar subquery and predicate subquery
    val mixedSubqueryQuery =
      """
|SELECT * FROM (SELECT max(c1) FROM t1 GROUP BY c1)
|WHERE
|c1 = (SELECT max(c1) FROM t2 GROUP BY c1)
|OR
|EXISTS (SELECT c1 FROM t3)
|OR
|c1 IN (SELECT c1 FROM t4)
""".stripMargin

    sql(mixedSubqueryQuery).cache()

    // The query mixing scalar, EXISTS and IN subqueries is likewise expected to contain
    // exactly one in-memory relation when planned again.
    val cachedDs2 = sql(mixedSubqueryQuery)
    assert(getNumInMemoryRelations(cachedDs2) == 1)
  }
}
}

0 comments on commit 7346dca

Please sign in to comment.