diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index 59db28d58afce..d7b493d521ddb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -47,7 +47,6 @@ abstract class SubqueryExpression(
     plan: LogicalPlan,
     children: Seq[Expression],
     exprId: ExprId) extends PlanExpression[LogicalPlan] {
-  override lazy val resolved: Boolean = childrenResolved && plan.resolved
   override lazy val references: AttributeSet =
     if (plan.resolved) super.references -- plan.outputSet else super.references
@@ -59,6 +58,13 @@ abstract class SubqueryExpression(
       children.zip(p.children).forall(p => p._1.semanticEquals(p._2))
     case _ => false
   }
+  def canonicalize(attrs: AttributeSeq): SubqueryExpression = {
+    // Normalize the outer references in the subquery plan.
+    val normalizedPlan = plan.transformAllExpressions {
+      case OuterReference(r) => OuterReference(QueryPlan.normalizeExprId(r, attrs))
+    }
+    withNewPlan(normalizedPlan).canonicalized.asInstanceOf[SubqueryExpression]
+  }
 }

 object SubqueryExpression {
@@ -236,6 +242,12 @@ case class ScalarSubquery(
   override def nullable: Boolean = true
   override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(plan = plan)
   override def toString: String = s"scalar-subquery#${exprId.id} $conditionString"
+  override lazy val canonicalized: Expression = {
+    ScalarSubquery(
+      plan.canonicalized,
+      children.map(_.canonicalized),
+      ExprId(0))
+  }
 }

 object ScalarSubquery {
@@ -268,6 +280,12 @@ case class ListQuery(
   override def nullable: Boolean = false
   override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan)
   override def toString: String = s"list#${exprId.id} $conditionString"
+  override lazy val canonicalized: Expression = {
+    ListQuery(
+      plan.canonicalized,
+      children.map(_.canonicalized),
+      ExprId(0))
+  }
 }

 /**
@@ -290,4 +308,10 @@ case class Exists(
   override def nullable: Boolean = false
   override def withNewPlan(plan: LogicalPlan): Exists = copy(plan = plan)
   override def toString: String = s"exists#${exprId.id} $conditionString"
+  override lazy val canonicalized: Expression = {
+    Exists(
+      plan.canonicalized,
+      children.map(_.canonicalized),
+      ExprId(0))
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 3008e8cb84659..2fb65bd435507 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -377,7 +377,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
         // As the root of the expression, Alias will always take an arbitrary exprId, we need to
         // normalize that for equality testing, by assigning expr id from 0 incrementally. The
         // alias name doesn't matter and should be erased.
-        Alias(normalizeExprId(a.child), "")(ExprId(id), a.qualifier, isGenerated = a.isGenerated)
+        val normalizedChild = QueryPlan.normalizeExprId(a.child, allAttributes)
+        Alias(normalizedChild, "")(ExprId(id), a.qualifier, isGenerated = a.isGenerated)

       case ar: AttributeReference if allAttributes.indexOf(ar.exprId) == -1 =>
         // Top level `AttributeReference` may also be used for output like `Alias`, we should
@@ -385,7 +386,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
         id += 1
         ar.withExprId(ExprId(id))

-      case other => normalizeExprId(other)
+      case other => QueryPlan.normalizeExprId(other, allAttributes)
     }.withNewChildren(canonicalizedChildren)
   }
@@ -395,23 +396,6 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
    */
   protected def preCanonicalized: PlanType = this

-  /**
-   * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference`
-   * with its referenced ordinal from input attributes. It's similar to `BindReferences` but we
-   * do not use `BindReferences` here as the plan may take the expression as a parameter with type
-   * `Attribute`, and replace it with `BoundReference` will cause error.
-   */
-  protected def normalizeExprId[T <: Expression](e: T, input: AttributeSeq = allAttributes): T = {
-    e.transformUp {
-      case ar: AttributeReference =>
-        val ordinal = input.indexOf(ar.exprId)
-        if (ordinal == -1) {
-          ar
-        } else {
-          ar.withExprId(ExprId(ordinal))
-        }
-    }.canonicalized.asInstanceOf[T]
-  }

   /**
    * Returns true when the given query plan will return the same results as this query plan.
@@ -438,3 +422,24 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
    */
   lazy val allAttributes: AttributeSeq = children.flatMap(_.output)
 }
+
+object QueryPlan {
+  /**
+   * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference`
+   * with its referenced ordinal from input attributes. It's similar to `BindReferences` but we
+   * do not use `BindReferences` here as the plan may take the expression as a parameter with type
+   * `Attribute`, and replace it with `BoundReference` will cause error.
+   */
+  def normalizeExprId[T <: Expression](e: T, input: AttributeSeq): T = {
+    e.transformUp {
+      case s: SubqueryExpression => s.canonicalize(input)
+      case ar: AttributeReference =>
+        val ordinal = input.indexOf(ar.exprId)
+        if (ordinal == -1) {
+          ar
+        } else {
+          ar.withExprId(ExprId(ordinal))
+        }
+    }.canonicalized.asInstanceOf[T]
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 3a9132d74ac11..866fa98533218 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
+import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource}
@@ -516,10 +517,10 @@ case class FileSourceScanExec(
   override lazy val canonicalized: FileSourceScanExec = {
     FileSourceScanExec(
       relation,
-      output.map(normalizeExprId(_, output)),
+      output.map(QueryPlan.normalizeExprId(_, output)),
       requiredSchema,
-      partitionFilters.map(normalizeExprId(_, output)),
-      dataFilters.map(normalizeExprId(_, output)),
+      partitionFilters.map(QueryPlan.normalizeExprId(_, output)),
+      dataFilters.map(QueryPlan.normalizeExprId(_, output)),
       None)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 7a7d52b21427a..e66fe97afad45 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -26,7 +26,7 @@ import org.scalatest.concurrent.Eventually._
 import org.apache.spark.CleanerListener
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
-import org.apache.spark.sql.execution.RDDScanExec
+import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan}
 import org.apache.spark.sql.execution.columnar._
 import org.apache.spark.sql.execution.exchange.ShuffleExchange
 import org.apache.spark.sql.functions._
@@ -76,6 +76,13 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
     sum
   }

+  private def getNumInMemoryTablesRecursively(plan: SparkPlan): Int = {
+    plan.collect {
+      case InMemoryTableScanExec(_, _, relation) =>
+        getNumInMemoryTablesRecursively(relation.child) + 1
+    }.sum
+  }
+
   test("withColumn doesn't invalidate cached dataframe") {
     var evalCount = 0
     val myUDF = udf((x: String) => { evalCount += 1; "result" })
@@ -670,4 +677,138 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
       assert(spark.read.parquet(path).filter($"id" > 4).count() == 15)
     }
   }
+
+  test("SPARK-19993 simple subquery caching") {
+    withTempView("t1", "t2") {
+      Seq(1).toDF("c1").createOrReplaceTempView("t1")
+      Seq(2).toDF("c1").createOrReplaceTempView("t2")
+
+      sql(
+        """
+          |SELECT * FROM t1
+          |WHERE
+          |NOT EXISTS (SELECT * FROM t2)
+        """.stripMargin).cache()
+
+      val cachedDs =
+        sql(
+          """
+            |SELECT * FROM t1
+            |WHERE
+            |NOT EXISTS (SELECT * FROM t2)
+          """.stripMargin)
+      assert(getNumInMemoryRelations(cachedDs) == 1)
+
+      // Additional predicate in the subquery plan should cause a cache miss
+      val cachedMissDs =
+        sql(
+          """
+            |SELECT * FROM t1
+            |WHERE
+            |NOT EXISTS (SELECT * FROM t2 where c1 = 0)
+          """.stripMargin)
+      assert(getNumInMemoryRelations(cachedMissDs) == 0)
+    }
+  }
+
+  test("SPARK-19993 subquery caching with correlated predicates") {
+    withTempView("t1", "t2") {
+      Seq(1).toDF("c1").createOrReplaceTempView("t1")
+      Seq(1).toDF("c1").createOrReplaceTempView("t2")
+
+      // Simple correlated predicate in subquery
+      sql(
+        """
+          |SELECT * FROM t1
+          |WHERE
+          |t1.c1 in (SELECT t2.c1 FROM t2 where t1.c1 = t2.c1)
+        """.stripMargin).cache()
+
+      val cachedDs =
+        sql(
+          """
+            |SELECT * FROM t1
+            |WHERE
+            |t1.c1 in (SELECT t2.c1 FROM t2 where t1.c1 = t2.c1)
+          """.stripMargin)
+      assert(getNumInMemoryRelations(cachedDs) == 1)
+    }
+  }
+
+  test("SPARK-19993 subquery with cached underlying relation") {
+    withTempView("t1") {
+      Seq(1).toDF("c1").createOrReplaceTempView("t1")
+      spark.catalog.cacheTable("t1")
+
+      // underlying table t1 is cached as well as the query that refers to it.
+      val ds =
+        sql(
+          """
+            |SELECT * FROM t1
+            |WHERE
+            |NOT EXISTS (SELECT * FROM t1)
+          """.stripMargin)
+      assert(getNumInMemoryRelations(ds) == 2)
+
+      val cachedDs =
+        sql(
+          """
+            |SELECT * FROM t1
+            |WHERE
+            |NOT EXISTS (SELECT * FROM t1)
+          """.stripMargin).cache()
+      assert(getNumInMemoryTablesRecursively(cachedDs.queryExecution.sparkPlan) == 3)
+    }
+  }
+
+  test("SPARK-19993 nested subquery caching and scalar + predicate subqueries") {
+    withTempView("t1", "t2", "t3", "t4") {
+      Seq(1).toDF("c1").createOrReplaceTempView("t1")
+      Seq(2).toDF("c1").createOrReplaceTempView("t2")
+      Seq(1).toDF("c1").createOrReplaceTempView("t3")
+      Seq(1).toDF("c1").createOrReplaceTempView("t4")
+
+      // Nested predicate subquery
+      sql(
+        """
+          |SELECT * FROM t1
+          |WHERE
+          |c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 FROM t3 WHERE c1 = 1))
+        """.stripMargin).cache()
+
+      val cachedDs =
+        sql(
+          """
+            |SELECT * FROM t1
+            |WHERE
+            |c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 FROM t3 WHERE c1 = 1))
+          """.stripMargin)
+      assert(getNumInMemoryRelations(cachedDs) == 1)
+
+      // Scalar subquery and predicate subquery
+      sql(
+        """
+          |SELECT * FROM (SELECT max(c1) FROM t1 GROUP BY c1)
+          |WHERE
+          |c1 = (SELECT max(c1) FROM t2 GROUP BY c1)
+          |OR
+          |EXISTS (SELECT c1 FROM t3)
+          |OR
+          |c1 IN (SELECT c1 FROM t4)
+        """.stripMargin).cache()
+
+      val cachedDs2 =
+        sql(
+          """
+            |SELECT * FROM (SELECT max(c1) FROM t1 GROUP BY c1)
+            |WHERE
+            |c1 = (SELECT max(c1) FROM t2 GROUP BY c1)
+            |OR
+            |EXISTS (SELECT c1 FROM t3)
+            |OR
+            |c1 IN (SELECT c1 FROM t4)
+          """.stripMargin)
+      assert(getNumInMemoryRelations(cachedDs2) == 1)
+    }
+  }
 }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
index fab0d7fa84827..666548d1a490b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.CatalogRelation
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.hive._
@@ -203,9 +204,9 @@ case class HiveTableScanExec(
   override lazy val canonicalized: HiveTableScanExec = {
     val input: AttributeSeq = relation.output
     HiveTableScanExec(
-      requestedAttributes.map(normalizeExprId(_, input)),
+      requestedAttributes.map(QueryPlan.normalizeExprId(_, input)),
       relation.canonicalized.asInstanceOf[CatalogRelation],
-      partitionPruningPred.map(normalizeExprId(_, input)))(sparkSession)
+      partitionPruningPred.map(QueryPlan.normalizeExprId(_, input)))(sparkSession)
   }

   override def otherCopyArgs: Seq[AnyRef] = Seq(sparkSession)
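Note (not part of the patch): the snippet below is a minimal sketch of the behavior the new CachedTableSuite tests above verify, written against the public Dataset/SQL API. It assumes a local SparkSession and the same t1/t2 temp views as the tests; the object name SubqueryCachingDemo is purely illustrative.

// Illustrative sketch only: after this change, a query whose plan contains a
// subquery expression should hit the cache on re-execution, because
// ScalarSubquery/ListQuery/Exists are canonicalized (ExprIds normalized)
// before cached plans are compared.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec

object SubqueryCachingDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("SPARK-19993 demo")
      .getOrCreate()
    import spark.implicits._

    Seq(1).toDF("c1").createOrReplaceTempView("t1")
    Seq(2).toDF("c1").createOrReplaceTempView("t2")

    val query = "SELECT * FROM t1 WHERE NOT EXISTS (SELECT * FROM t2)"

    // Cache the result of a query containing an uncorrelated EXISTS subquery.
    spark.sql(query).cache()

    // A textually identical query should now reuse the cached data: its executed
    // plan contains an InMemoryTableScanExec instead of recomputing the result.
    val reused = spark.sql(query)
    val usesCache = reused.queryExecution.executedPlan.collect {
      case scan: InMemoryTableScanExec => scan
    }.nonEmpty
    assert(usesCache, "expected the second query to reuse the cached InMemoryRelation")

    spark.stop()
  }
}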