From 95d73fcf911bfb25b20fea798d1f7b3f4b319e26 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 13 Jan 2017 11:50:11 -0800 Subject: [PATCH 1/5] Fix Python UDF accessing attributes from both side of join --- python/pyspark/sql/tests.py | 9 +++++++++ .../spark/sql/catalyst/expressions/predicates.scala | 13 +++++++++++++ .../spark/sql/catalyst/optimizer/Optimizer.scala | 2 +- .../apache/spark/sql/catalyst/optimizer/joins.scala | 8 +++----- 4 files changed, 26 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a8250281dab35..5b9953a4a5ba7 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -342,6 +342,15 @@ def test_udf_in_filter_on_top_of_outer_join(self): df = df.withColumn('b', udf(lambda x: 'x')(df.a)) self.assertEqual(df.filter('b = "x"').collect(), [Row(a=1, b='x')]) + def test_udf_in_filter_on_top_of_inner(self): + from pyspark.sql.functions import udf + from pyspark.sql.types import BooleanType + left = self.spark.createDataFrame([Row(a=1)]) + right = self.spark.createDataFrame([Row(b=1)]) + f = udf(lambda a, b: a == b, BooleanType()) + df = left.crossJoin(right).filter(f("a","b")) + self.assertEqual(df.collect(), [Row(a=1, b=1)]) + def test_udf_without_arguments(self): self.spark.catalog.registerFunction("foo", lambda: "bar") [row] = self.spark.sql("SELECT foo()").collect() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 3fcbb05372d87..42dc3a8682b8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -86,6 +86,19 @@ trait PredicateHelper { */ protected def canEvaluate(expr: Expression, plan: LogicalPlan): Boolean = expr.references.subsetOf(plan.outputSet) + + /** + * Returns true iff `expr` could be evaluated as a condition within join. + */ + protected def canEvaluateWithinJoin(expr: Expression): Boolean = { + expr.find { + case e: SubqueryExpression => + // non-correlated subquery will be replaced as literal + e.children.nonEmpty + case e: Unevaluable => true + case _ => false + }.isEmpty + } } @ExpressionDescription( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 009c517ae4651..20b3898f8a6fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -893,7 +893,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { val newRight = rightFilterConditions. 
reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) val (newJoinConditions, others) = - commonFilterCondition.partition(e => !SubqueryExpression.hasCorrelatedSubquery(e)) + commonFilterCondition.partition(canEvaluateWithinJoin) val newJoinCond = (newJoinConditions ++ joinCondition).reduceLeftOption(And) val join = Join(newLeft, newRight, joinType, newJoinCond) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 180ad2e0ad1fa..65d0d7064a468 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -46,8 +46,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { : LogicalPlan = { assert(input.size >= 2) if (input.size == 2) { - val (joinConditions, others) = conditions.partition( - e => !SubqueryExpression.hasCorrelatedSubquery(e)) + val (joinConditions, others) = conditions.partition(canEvaluateWithinJoin) val ((left, leftJoinType), (right, rightJoinType)) = (input(0), input(1)) val innerJoinType = (leftJoinType, rightJoinType) match { case (Inner, Inner) => Inner @@ -75,7 +74,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { val joinedRefs = left.outputSet ++ right.outputSet val (joinConditions, others) = conditions.partition( - e => e.references.subsetOf(joinedRefs) && !SubqueryExpression.hasCorrelatedSubquery(e)) + e => e.references.subsetOf(joinedRefs) && canEvaluateWithinJoin(e)) val joined = Join(left, right, innerJoinType, joinConditions.reduceLeftOption(And)) // should not have reference to same logical plan @@ -108,11 +107,10 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { * Returns whether the expression returns null or false when all inputs are nulls. 
*/ private def canFilterOutNull(e: Expression): Boolean = { - if (!e.deterministic || SubqueryExpression.hasCorrelatedSubquery(e)) return false + if (!e.deterministic || e.find(_.isInstanceOf[Unevaluable]).isDefined) return false val attributes = e.references.toSeq val emptyRow = new GenericInternalRow(attributes.length) val boundE = BindReferences.bindReference(e, attributes) - if (boundE.find(_.isInstanceOf[Unevaluable]).isDefined) return false val v = boundE.eval(emptyRow) v == null || v == false } From c45126bae138c68ef74298ff444421e3305640cd Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 13 Jan 2017 13:13:42 -0800 Subject: [PATCH 2/5] fix style --- python/pyspark/sql/tests.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 5b9953a4a5ba7..9b8a5d6b605c4 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -342,13 +342,12 @@ def test_udf_in_filter_on_top_of_outer_join(self): df = df.withColumn('b', udf(lambda x: 'x')(df.a)) self.assertEqual(df.filter('b = "x"').collect(), [Row(a=1, b='x')]) - def test_udf_in_filter_on_top_of_inner(self): + def test_udf_in_filter_on_top_of_join(self): from pyspark.sql.functions import udf - from pyspark.sql.types import BooleanType left = self.spark.createDataFrame([Row(a=1)]) right = self.spark.createDataFrame([Row(b=1)]) f = udf(lambda a, b: a == b, BooleanType()) - df = left.crossJoin(right).filter(f("a","b")) + df = left.crossJoin(right).filter(f("a", "b")) self.assertEqual(df.collect(), [Row(a=1, b=1)]) def test_udf_without_arguments(self): From e4db8209843379fdd385dbf299baca7dea410075 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 17 Jan 2017 09:51:12 -0800 Subject: [PATCH 3/5] address comments --- python/pyspark/sql/tests.py | 1 + .../sql/catalyst/expressions/predicates.scala | 15 ++++++--------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 9b8a5d6b605c4..61d82e768ad84 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -343,6 +343,7 @@ def test_udf_in_filter_on_top_of_outer_join(self): self.assertEqual(df.filter('b = "x"').collect(), [Row(a=1, b='x')]) def test_udf_in_filter_on_top_of_join(self): + # regression test for SPARK-18589 from pyspark.sql.functions import udf left = self.spark.createDataFrame([Row(a=1)]) right = self.spark.createDataFrame([Row(b=1)]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 42dc3a8682b8d..1c18f4457f52c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils object InterpretedPredicate { @@ -90,14 +89,12 @@ trait PredicateHelper { /** * Returns true iff `expr` could be evaluated as a condition within join. 
*/ - protected def canEvaluateWithinJoin(expr: Expression): Boolean = { - expr.find { - case e: SubqueryExpression => - // non-correlated subquery will be replaced as literal - e.children.nonEmpty - case e: Unevaluable => true - case _ => false - }.isEmpty + protected def canEvaluateWithinJoin(expr: Expression): Boolean = expr match { + case e: SubqueryExpression => + // non-correlated subquery will be replaced as literal + e.children.isEmpty + case e: Unevaluable => false + case e => e.children.forall(canEvaluateWithinJoin) } } From d6bba37147fef5088513542098d14a8650d7339b Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 19 Jan 2017 10:23:50 -0800 Subject: [PATCH 4/5] bug fix --- .../org/apache/spark/sql/catalyst/expressions/predicates.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 1c18f4457f52c..ac56ff13fa5bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -93,6 +93,7 @@ trait PredicateHelper { case e: SubqueryExpression => // non-correlated subquery will be replaced as literal e.children.isEmpty + case a: AttributeReference => true case e: Unevaluable => false case e => e.children.forall(canEvaluateWithinJoin) } From f720c85713252e7d33ca1bdb1667149b8d1a8cd2 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 19 Jan 2017 16:22:52 -0800 Subject: [PATCH 5/5] rollback change, fix test --- .../spark/sql/catalyst/optimizer/joins.scala | 3 ++- .../python/BatchEvalPythonExecSuite.scala | 14 ++++++-------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 65d0d7064a468..bfe529e21e9ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -107,10 +107,11 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { * Returns whether the expression returns null or false when all inputs are nulls. 
*/ private def canFilterOutNull(e: Expression): Boolean = { - if (!e.deterministic || e.find(_.isInstanceOf[Unevaluable]).isDefined) return false + if (!e.deterministic || SubqueryExpression.hasCorrelatedSubquery(e)) return false val attributes = e.references.toSeq val emptyRow = new GenericInternalRow(attributes.length) val boundE = BindReferences.bindReference(e, attributes) + if (boundE.find(_.isInstanceOf[Unevaluable]).isDefined) return false val v = boundE.eval(emptyRow) v == null || v == false } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala index 81bea2fef8bd4..2a3d1cf0b298a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.apache.spark.api.python.PythonFunction -import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, GreaterThan, In} +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, Expression, GreaterThan, In} import org.apache.spark.sql.execution.{FilterExec, InputAdapter, SparkPlanTest, WholeStageCodegenExec} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.BooleanType @@ -86,13 +86,11 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext { test("Python UDF refers to the attributes from more than one child") { val df = Seq(("Hello", 4)).toDF("a", "b") val df2 = Seq(("Hello", 4)).toDF("c", "d") - val joinDF = df.join(df2).where("dummyPythonUDF(a, c) == dummyPythonUDF(d, c)") - - val e = intercept[RuntimeException] { - joinDF.queryExecution.executedPlan - }.getMessage - assert(Seq("Invalid PythonUDF dummyUDF", "requires attributes from more than one child") - .forall(e.contains)) + val joinDF = df.crossJoin(df2).where("dummyPythonUDF(a, c) == dummyPythonUDF(d, c)") + val qualifiedPlanNodes = joinDF.queryExecution.executedPlan.collect { + case b: BatchEvalPythonExec => b + } + assert(qualifiedPlanNodes.size == 1) } }
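
Taken together, the five patches above make the optimizer treat unevaluable expressions such as PythonUDF as conditions that cannot live inside a join: the new PredicateHelper.canEvaluateWithinJoin returns false for them, so PushPredicateThroughJoin and ReorderJoin leave such predicates in a Filter above the join, where BatchEvalPythonExec can evaluate them (the updated BatchEvalPythonExecSuite test asserts exactly one such node in the plan). The standalone script below is a minimal sketch of the fixed behavior, adapted from the test_udf_in_filter_on_top_of_join regression test above; the SparkSession setup and the application name are illustrative additions, not part of the patch series. Before this series, the same filter failed at planning time with the "Invalid PythonUDF ... requires attributes from more than one child" error quoted in the old BatchEvalPythonExecSuite test.

from pyspark.sql import Row, SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import BooleanType

# Illustrative setup; any session on a build with this series applied will do.
spark = SparkSession.builder.appName("spark-18589-example").getOrCreate()

left = spark.createDataFrame([Row(a=1)])
right = spark.createDataFrame([Row(b=1)])

# A Python UDF is Unevaluable on the JVM side, and this condition needs
# attributes from both sides of the join.
equal = udf(lambda a, b: a == b, BooleanType())

# canEvaluateWithinJoin returns false for the UDF, so the predicate stays in
# a Filter (evaluated via BatchEvalPythonExec) above the cross join instead
# of being pushed down into the join condition.
df = left.crossJoin(right).filter(equal("a", "b"))
assert df.collect() == [Row(a=1, b=1)]

spark.stop()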