From 62fa9cf99610d8fa67d123450f2721cac0b5899f Mon Sep 17 00:00:00 2001
From: maryannxue
Date: Mon, 23 Jul 2018 11:56:05 -0700
Subject: [PATCH 1/2] [SPARK-24891][SQL] Fix HandleNullInputsForUDF rule

---
 .../sql/catalyst/analysis/Analyzer.scala      | 22 +++++++++----
 .../sql/catalyst/analysis/AnalysisSuite.scala | 17 ++++++++--
 .../scala/org/apache/spark/sql/UDFSuite.scala | 31 ++++++++++++++++++-
 3 files changed, 60 insertions(+), 10 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 866396c42f9d8..f33e97366c48e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.encoders.OuterScopes
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.expressions.objects.{LambdaVariable, MapObjects, NewInstance, UnresolvedMapObjects}
+import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
@@ -2145,14 +2145,24 @@ class Analyzer(
           val parameterTypes = ScalaReflection.getParameterTypes(func)
           assert(parameterTypes.length == inputs.length)

+          // TODO: skip null handling for not-nullable primitive inputs after we can completely
+          // trust the `nullable` information.
+          // (cls, expr) => cls.isPrimitive && expr.nullable
+          val needsNullCheck = (cls: Class[_], expr: Expression) =>
+            cls.isPrimitive && !expr.isInstanceOf[AssertNotNull]
           val inputsNullCheck = parameterTypes.zip(inputs)
-            // TODO: skip null handling for not-nullable primitive inputs after we can completely
-            // trust the `nullable` information.
-            // .filter { case (cls, expr) => cls.isPrimitive && expr.nullable }
-            .filter { case (cls, _) => cls.isPrimitive }
+            .filter { case (cls, expr) => needsNullCheck(cls, expr) }
             .map { case (_, expr) => IsNull(expr) }
             .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2))
-          inputsNullCheck.map(If(_, Literal.create(null, udf.dataType), udf)).getOrElse(udf)
+          // Once we add an `If` check above the udf, it is safe to mark those checked inputs
+          // as not nullable (i.e., wrap them with `AssertNotNull`), because the null-returning
+          // branch of `If` will be called if any of these checked inputs is null. Thus we can
+          // prevent this rule from being applied repeatedly.
+          val newInputs = parameterTypes.zip(inputs).map{ case (cls, expr) =>
+            if (needsNullCheck(cls, expr)) AssertNotNull(expr) else expr }
+          inputsNullCheck
+            .map(If(_, Literal.create(null, udf.dataType), udf.copy(children = newInputs)))
+            .getOrElse(udf)
       }
     }
   }
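[Editor's note: a minimal, self-contained sketch of the rewrite this hunk performs, using toy stand-ins rather than Spark's real Catalyst classes (Expr, Udf, NotNullTag and friends are illustrative names only; the isPrimitive test is omitted). The rule builds one null check over the inputs that still need checking, guards the UDF with it, and tags every checked input so a later pass can recognize work already done:

sealed trait Expr
case class Input(name: String) extends Expr
case class IsNullExpr(child: Expr) extends Expr
case class OrExpr(left: Expr, right: Expr) extends Expr
case class IfExpr(pred: Expr, whenTrue: Expr, whenFalse: Expr) extends Expr
case object NullLiteral extends Expr
case class NotNullTag(child: Expr) extends Expr // plays the role of AssertNotNull here
case class Udf(children: Seq[Expr]) extends Expr

def handleNullInputs(udf: Udf): Expr = {
  // Inputs already tagged were handled by an earlier pass and are skipped.
  val needsCheck = (e: Expr) => !e.isInstanceOf[NotNullTag]
  val nullCheck = udf.children.filter(needsCheck)
    .map(c => IsNullExpr(c): Expr)
    .reduceLeftOption((e1, e2) => OrExpr(e1, e2))
  // Tag every checked input so the next pass finds nothing left to check.
  val newChildren = udf.children.map(c => if (needsCheck(c)) NotNullTag(c) else c)
  nullCheck.map(IfExpr(_, NullLiteral, Udf(newChildren))).getOrElse(udf)
}

For example, handleNullInputs(Udf(Seq(Input("a")))) returns IfExpr(IsNullExpr(Input("a")), NullLiteral, Udf(Seq(NotNullTag(Input("a"))))), mirroring the If(IsNull(...), null, udf) expression built above.]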
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 9e0db8dbf8f3a..51889da678d93 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
 import org.apache.spark.sql.catalyst.plans.{Cross, Inner}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning,
@@ -316,7 +317,8 @@ class AnalysisSuite extends AnalysisTest with Matchers {

     // only primitive parameter needs special null handling
     val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil)
-    val expected2 = If(IsNull(double), nullResult, udf2)
+    val expected2 =
+      If(IsNull(double), nullResult, udf2.copy(children = string :: AssertNotNull(double) :: Nil))
     checkUDF(udf2, expected2)

     // special null handling should apply to all primitive parameters
@@ -324,7 +326,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
     val expected3 = If(
       IsNull(short) || IsNull(double),
       nullResult,
-      udf3)
+      udf3.copy(children = AssertNotNull(short) :: AssertNotNull(double) :: Nil))
     checkUDF(udf3, expected3)

     // we can skip special null handling for primitive parameters that are not nullable
@@ -336,10 +338,19 @@ class AnalysisSuite extends AnalysisTest with Matchers {
     val expected4 = If(
       IsNull(short),
       nullResult,
-      udf4)
+      udf4.copy(children = AssertNotNull(short) :: double.withNullability(false) :: Nil))
     // checkUDF(udf4, expected4)
   }

+  test("SPARK-24891 Fix HandleNullInputsForUDF rule") {
+    val a = testRelation.output(0)
+    val func = (x: Int, y: Int) => x + y
+    val udf1 = ScalaUDF(func, IntegerType, a :: a :: Nil)
+    val udf2 = ScalaUDF(func, IntegerType, a :: udf1 :: Nil)
+    val plan = Project(Alias(udf2, "")() :: Nil, testRelation)
+    comparePlans(plan.analyze, plan.analyze.analyze)
+  }
+
   test("SPARK-11863 mixture of aliases and real columns in order by clause - tpcds 19,55,71") {
     val a = testRelation2.output(0)
     val c = testRelation2.output(2)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 21afdc7e2a33f..d8074571ffc65 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
 import org.apache.spark.sql.api.java._
 import org.apache.spark.sql.catalyst.plans.logical.Project
 import org.apache.spark.sql.execution.command.ExplainCommand
-import org.apache.spark.sql.functions.udf
+import org.apache.spark.sql.functions.{lit, udf}
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.test.SQLTestData._
 import org.apache.spark.sql.types.{DataTypes, DoubleType}
@@ -324,4 +324,33 @@ class UDFSuite extends QueryTest with SharedSQLContext {
       assert(outputStream.toString.contains("UDF:f(a._1 AS `_1`)"))
     }
   }
+
+  test("SPARK-24891 Fix HandleNullInputsForUDF rule") {
+    val udf1 = udf({(x: Int, y: Int) => x + y})
+    val df = spark.range(0, 3).toDF("a")
+      .withColumn("b", udf1($"a", udf1($"a", lit(10))))
+      .withColumn("c", udf1($"a", lit(null)))
+    val plan = spark.sessionState.executePlan(df.logicalPlan).analyzed
+
+    comparePlans(df.logicalPlan, plan)
+    checkAnswer(
+      df,
+      Seq(
+        Row(0, 10, null),
+        Row(1, 12, null),
+        Row(2, 14, null)))
+  }
+
+  test("SPARK-24891 Fix HandleNullInputsForUDF rule - with table") {
+    withTable("x") {
+      Seq((1, "2"), (2, "4")).toDF("a", "b").write.format("json").saveAsTable("x")
+      sql("insert into table x values(3, null)")
+      sql("insert into table x values(null, '4')")
+      spark.udf.register("f", (a: Int, b: String) => a + b)
+      val df = spark.sql("SELECT f(a, b) FROM x")
+      val plan = spark.sessionState.executePlan(df.logicalPlan).analyzed
+      comparePlans(df.logicalPlan, plan)
+      checkAnswer(df, Seq(Row("12"), Row("24"), Row("3null"), Row(null)))
+    }
+  }
 }
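[Editor's note: the tests above pin down idempotence. Before this patch, the rule did not recognize inputs it had already checked (the old filter looked at cls.isPrimitive alone), so every additional analysis pass wrapped the UDF in a fresh If(IsNull(...), null, ...) guard, and plan.analyze.analyze no longer matched plan.analyze. With the toy sketch from the earlier note, the fixed point is easy to see:

val inner = Udf(Seq(NotNullTag(Input("a"))))
assert(handleNullInputs(inner) == inner) // all children tagged, so the rule is a no-op

A second pass over the guarded expression finds only tagged children and leaves the tree unchanged, which is exactly what comparePlans(plan.analyze, plan.analyze.analyze) asserts.]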
From b499b9727a4cb9cc42149d05a4d54dba2de8bd9e Mon Sep 17 00:00:00 2001
From: maryannxue
Date: Mon, 23 Jul 2018 18:38:37 -0700
Subject: [PATCH 2/2] Change AssertNotNull to KnownNotNull

---
 .../sql/catalyst/analysis/Analyzer.scala      |  6 ++--
 .../expressions/constraintExpressions.scala   | 35 +++++++++++++++++++
 .../sql/catalyst/analysis/AnalysisSuite.scala |  7 ++--
 3 files changed, 41 insertions(+), 7 deletions(-)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index f33e97366c48e..4f474f4987dcf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2149,17 +2149,17 @@ class Analyzer(
           // trust the `nullable` information.
           // (cls, expr) => cls.isPrimitive && expr.nullable
           val needsNullCheck = (cls: Class[_], expr: Expression) =>
-            cls.isPrimitive && !expr.isInstanceOf[AssertNotNull]
+            cls.isPrimitive && !expr.isInstanceOf[KnownNotNull]
           val inputsNullCheck = parameterTypes.zip(inputs)
             .filter { case (cls, expr) => needsNullCheck(cls, expr) }
             .map { case (_, expr) => IsNull(expr) }
             .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2))
           // Once we add an `If` check above the udf, it is safe to mark those checked inputs
-          // as not nullable (i.e., wrap them with `AssertNotNull`), because the null-returning
+          // as not nullable (i.e., wrap them with `KnownNotNull`), because the null-returning
           // branch of `If` will be called if any of these checked inputs is null. Thus we can
           // prevent this rule from being applied repeatedly.
           val newInputs = parameterTypes.zip(inputs).map{ case (cls, expr) =>
-            if (needsNullCheck(cls, expr)) AssertNotNull(expr) else expr }
+            if (needsNullCheck(cls, expr)) KnownNotNull(expr) else expr }
           inputsNullCheck
             .map(If(_, Literal.create(null, udf.dataType), udf.copy(children = newInputs)))
             .getOrElse(udf)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala
new file mode 100644
index 0000000000000..53936aa914c8f
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral}
+import org.apache.spark.sql.types.DataType
+
+case class KnownNotNull(child: Expression) extends UnaryExpression {
+  override def nullable: Boolean = false
+  override def dataType: DataType = child.dataType
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    child.genCode(ctx).copy(isNull = FalseLiteral)
+  }
+
+  override def eval(input: InternalRow): Any = {
+    child.eval(input)
+  }
+}
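[Editor's note: the switch away from AssertNotNull is a semantic cleanup. AssertNotNull, from the objects package, raises an error at evaluation time when its child is null, whereas the marker needed here must be behaviorally inert: the If guard already routes null inputs to the null branch, so the wrapper only has to pin nullable to false and forward evaluation and codegen to the child, which is all KnownNotNull above does. Its contract, in toy notation rather than runnable Spark code:

// KnownNotNull(c).eval(row) == c.eval(row)   (values pass through untouched, even nulls)
// KnownNotNull(c).nullable  == false         (only the nullability metadata changes)]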
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 51889da678d93..31f703d018aed 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
 import org.apache.spark.sql.catalyst.plans.{Cross, Inner}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning,
@@ -318,7 +317,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
     // only primitive parameter needs special null handling
     val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil)
     val expected2 =
-      If(IsNull(double), nullResult, udf2.copy(children = string :: AssertNotNull(double) :: Nil))
+      If(IsNull(double), nullResult, udf2.copy(children = string :: KnownNotNull(double) :: Nil))
     checkUDF(udf2, expected2)

     // special null handling should apply to all primitive parameters
@@ -326,7 +325,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
     val expected3 = If(
       IsNull(short) || IsNull(double),
       nullResult,
-      udf3.copy(children = AssertNotNull(short) :: AssertNotNull(double) :: Nil))
+      udf3.copy(children = KnownNotNull(short) :: KnownNotNull(double) :: Nil))
     checkUDF(udf3, expected3)

     // we can skip special null handling for primitive parameters that are not nullable
@@ -338,7 +337,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
     val expected4 = If(
       IsNull(short),
       nullResult,
-      udf4.copy(children = AssertNotNull(short) :: double.withNullability(false) :: Nil))
+      udf4.copy(children = KnownNotNull(short) :: double.withNullability(false) :: Nil))
     // checkUDF(udf4, expected4)
   }
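[Editor's note: end to end, the fix shows up in analyzed plans. A sketch of what to expect in a Spark session with both patches applied; the plan text is abridged and approximate, including the cast the analyzer inserts for the Long-typed range column:

import org.apache.spark.sql.functions.udf
import spark.implicits._

val inc = udf((x: Int) => x + 1)
val df = spark.range(0, 3).toDF("a").withColumn("b", inc($"a"))
println(df.queryExecution.analyzed)
// ... if (isnull(cast(a as int))) null else UDF(knownnotnull(cast(a as int))) AS b ...

Re-analyzing this plan yields the same tree; without the fix, each pass would add one more if (isnull(...)) wrapper.]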