[SPARK-25044][FOLLOW-UP] Change ScalaUDF constructor signature
## What changes were proposed in this pull request?

This is a follow-up PR for #22259. The extra field added to `ScalaUDF` by the original PR was declared optional, but it should indeed be required: otherwise callers of `ScalaUDF`'s constructor could ignore the new field and silently produce incorrect results. This PR makes the new field required and renames it to `inputsNullSafe`.

#22259 broke the previous null-handling behavior for primitive-type input parameters. For example, given `val f = udf((x: Int, y: Any) => x)`, `f(null, "str")` should return `null`, but after #22259 it returned `0`. In this PR, all UDF methods except `def udf(f: AnyRef, dataType: DataType): UserDefinedFunction` have been restored to the original behavior. The only exception is documented in the Spark SQL migration guide.
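A minimal sketch of the restored behavior, assuming an active `SparkSession` named `spark` (the names and output here are illustrative, not part of the patch):

```scala
import org.apache.spark.sql.functions.{lit, udf}

// `x` is a primitive Int, so it cannot hold a null directly. With the restored
// behavior, the analyzer adds a null check and the whole call evaluates to null.
val f = udf((x: Int, y: Any) => x)

spark.range(1).select(f(lit(null), lit("str"))).show()
// restored behavior: the single result cell is null
// with #22259 alone: the null was replaced by the Int default, so it showed 0
```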

In addition, now that this extra field indicates whether a null check should be applied to the corresponding input value, we can also use the flag to prevent the rule `HandleNullInputsForUDF` from being applied repeatedly.

## How was this patch tested?
Added a unit test in `UDFSuite`.

Passed affected existing unit tests:
- AnalysisSuite
- UDFSuite

Closes #22732 from maryannxue/spark-25044-followup.

Lead-authored-by: maryannxue <[email protected]>
Co-authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit e816776)
Signed-off-by: Wenchen Fan <[email protected]>
maryannxue and cloud-fan committed Oct 19, 2018
1 parent 9ed2e42 commit df60d9f
Showing 11 changed files with 257 additions and 182 deletions.
@@ -19,8 +19,11 @@ package org.apache.spark.sql.catalyst

import java.lang.reflect.Constructor

import scala.util.Properties

import org.apache.commons.lang3.reflect.ConstructorUtils

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
@@ -879,7 +882,7 @@ object ScalaReflection extends ScalaReflection {
* Support for generating catalyst schemas for scala objects. Note that unlike its companion
* object, this trait is able to work in both the runtime and the compile time (macro) universe.
*/
trait ScalaReflection {
trait ScalaReflection extends Logging {
/** The universe we work in (runtime or macro) */
val universe: scala.reflect.api.Universe

@@ -932,6 +935,23 @@ trait ScalaReflection {
tpe.dealias.erasure.typeSymbol.asClass.fullName
}

/**
* Returns the nullability of the input parameter types of the scala function object.
*
* Note that this only works with Scala 2.11, and the information returned may be inaccurate if
* used with a different Scala version.
*/
def getParameterTypeNullability(func: AnyRef): Seq[Boolean] = {
if (!Properties.versionString.contains("2.11")) {
logWarning(s"Scala ${Properties.versionString} cannot get type nullability correctly via " +
"reflection, thus Spark cannot add proper input null check for UDF. To avoid this " +
"problem, use the typed UDF interfaces instead.")
}
val methods = func.getClass.getMethods.filter(m => m.getName == "apply" && !m.isBridge)
assert(methods.length == 1)
methods.head.getParameterTypes.map(!_.isPrimitive)
}

/**
* Returns the parameter names and types for the primary constructor of this type.
*
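A hedged usage sketch of the new `getParameterTypeNullability` helper, under the Scala 2.11 assumption stated in its scaladoc (the closure is illustrative):

```scala
import org.apache.spark.sql.catalyst.ScalaReflection

// `Int` erases to a JVM primitive while `String` does not, so on Scala 2.11 the
// helper reports the Int parameter as unable to accept null and the String as able to.
val repeat: (Int, String) => String = (n, s) => s * n
ScalaReflection.getParameterTypeNullability(repeat)  // Seq(false, true) on Scala 2.11
```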
@@ -2151,36 +2151,27 @@ class Analyzer(

case p => p transformExpressionsUp {

case udf@ScalaUDF(func, _, inputs, _, _, _, _, nullableTypes) =>
if (nullableTypes.isEmpty) {
// If no nullability info is available, do nothing. No fields will be specially
// checked for null in the plan. If nullability info is incorrect, the results
// of the UDF could be wrong.
udf
} else {
// Otherwise, add special handling of null for fields that can't accept null.
// The result of operations like this, when passed null, is generally to return null.
assert(nullableTypes.length == inputs.length)

// TODO: skip null handling for not-nullable primitive inputs after we can completely
// trust the `nullable` information.
val needsNullCheck = (nullable: Boolean, expr: Expression) =>
nullable && !expr.isInstanceOf[KnownNotNull]
val inputsNullCheck = nullableTypes.zip(inputs)
.filter { case (nullableType, expr) => needsNullCheck(!nullableType, expr) }
.map { case (_, expr) => IsNull(expr) }
.reduceLeftOption[Expression]((e1, e2) => Or(e1, e2))
// Once we add an `If` check above the udf, it is safe to mark those checked inputs
// as not nullable (i.e., wrap them with `KnownNotNull`), because the null-returning
// branch of `If` will be called if any of these checked inputs is null. Thus we can
// prevent this rule from being applied repeatedly.
val newInputs = nullableTypes.zip(inputs).map { case (nullable, expr) =>
if (nullable) expr else KnownNotNull(expr)
}
inputsNullCheck
.map(If(_, Literal.create(null, udf.dataType), udf.copy(children = newInputs)))
.getOrElse(udf)
}
case udf @ ScalaUDF(_, _, inputs, inputsNullSafe, _, _, _, _)
if inputsNullSafe.contains(false) =>
// Otherwise, add special handling of null for fields that can't accept null.
// The result of operations like this, when passed null, is generally to return null.
assert(inputsNullSafe.length == inputs.length)

// TODO: skip null handling for not-nullable primitive inputs after we can completely
// trust the `nullable` information.
val inputsNullCheck = inputsNullSafe.zip(inputs)
.filter { case (nullSafe, _) => !nullSafe }
.map { case (_, expr) => IsNull(expr) }
.reduceLeftOption[Expression]((e1, e2) => Or(e1, e2))
// Once we add an `If` check above the udf, it is safe to mark those checked inputs
// as null-safe (i.e., set `inputsNullSafe` all `true`), because the null-returning
// branch of `If` will be called if any of these checked inputs is null. Thus we can
// prevent this rule from being applied repeatedly.
val newInputsNullSafe = inputsNullSafe.map(_ => true)
inputsNullCheck
.map(If(_, Literal.create(null, udf.dataType),
udf.copy(inputsNullSafe = newInputsNullSafe)))
.getOrElse(udf)
}
}
}
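Conceptually, the reworked `HandleNullInputsForUDF` now rewrites only UDFs whose `inputsNullSafe` contains `false`, and marks every input as safe in the copy it wraps, so it can never match the same UDF twice. A rough sketch using the Catalyst DSL (illustrative names, mirroring the updated `AnalysisSuite` below):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

// A UDF over a nullable Int attribute with a primitive parameter: not null-safe.
val a = 'a.int
val plusOne = ScalaUDF((i: Int) => i + 1, IntegerType, a :: Nil, false :: Nil)

// Roughly what the rule produces: a null check around a copy whose flags are all true.
val rewritten =
  If(IsNull(a), Literal.create(null, IntegerType), plusOne.copy(inputsNullSafe = true :: Nil))
// The inner UDF of `rewritten` no longer satisfies `inputsNullSafe.contains(false)`,
// so the rule cannot fire on it again.
```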
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.DataType
@@ -31,6 +31,9 @@ import org.apache.spark.sql.types.DataType
* null. Use a boxed type or [[Option]] if you want to do the null handling yourself.
* @param dataType Return type of function.
* @param children The input expressions of this UDF.
* @param inputsNullSafe Whether the inputs are of non-primitive types or not nullable. Null values
* of Scala primitive types will be converted to the type's default value and
* lead to wrong results, thus need special handling before calling the UDF.
* @param inputTypes The expected input types of this UDF, used to perform type coercion. If we do
* not want to perform coercion, simply use "Nil". Note that it would've been
* better to use Option of Seq[DataType] so we can use "None" as the case for no
@@ -39,17 +42,16 @@
* @param nullable True if the UDF can return null value.
* @param udfDeterministic True if the UDF is deterministic. Deterministic UDF returns same result
* each time it is invoked with a particular input.
* @param nullableTypes which of the inputTypes are nullable (i.e. not primitive)
*/
case class ScalaUDF(
function: AnyRef,
dataType: DataType,
children: Seq[Expression],
inputsNullSafe: Seq[Boolean],
inputTypes: Seq[DataType] = Nil,
udfName: Option[String] = None,
nullable: Boolean = true,
udfDeterministic: Boolean = true,
nullableTypes: Seq[Boolean] = Nil)
udfDeterministic: Boolean = true)
extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression {

// The constructor for SPARK 2.1 and 2.2
@@ -60,8 +62,8 @@ case class ScalaUDF(
inputTypes: Seq[DataType],
udfName: Option[String]) = {
this(
function, dataType, children, inputTypes, udfName, nullable = true,
udfDeterministic = true, nullableTypes = Nil)
function, dataType, children, ScalaReflection.getParameterTypeNullability(function),
inputTypes, udfName, nullable = true, udfDeterministic = true)
}

override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)
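Since the fourth argument is now required, direct callers must state the null-safety of each input; the Spark 2.1/2.2 compatibility constructor above derives it from the function itself. A hedged sketch of that pattern (illustrative values):

```scala
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.{Literal, ScalaUDF}
import org.apache.spark.sql.types.IntegerType

val func = (i: Int) => i + 1
// On Scala 2.11 this yields Seq(false): the lone parameter is a primitive Int.
val inputsNullSafe = ScalaReflection.getParameterTypeNullability(func)
val expr = ScalaUDF(func, IntegerType, Literal(1) :: Nil, inputsNullSafe)
```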
@@ -314,24 +314,24 @@ class AnalysisSuite extends AnalysisTest with Matchers {
}

// non-primitive parameters do not need special null handling
val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil)
val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil, true :: Nil)
val expected1 = udf1
checkUDF(udf1, expected1)

// only primitive parameter needs special null handling
val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil,
nullableTypes = true :: false :: Nil)
true :: false :: Nil)
val expected2 =
If(IsNull(double), nullResult, udf2.copy(children = string :: KnownNotNull(double) :: Nil))
If(IsNull(double), nullResult, udf2.copy(inputsNullSafe = true :: true :: Nil))
checkUDF(udf2, expected2)

// special null handling should apply to all primitive parameters
val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil,
nullableTypes = false :: false :: Nil)
false :: false :: Nil)
val expected3 = If(
IsNull(short) || IsNull(double),
nullResult,
udf3.copy(children = KnownNotNull(short) :: KnownNotNull(double) :: Nil))
udf3.copy(inputsNullSafe = true :: true :: Nil))
checkUDF(udf3, expected3)

// we can skip special null handling for primitive parameters that are not nullable
@@ -340,19 +340,19 @@
(s: Short, d: Double) => "x",
StringType,
short :: double.withNullability(false) :: Nil,
nullableTypes = false :: false :: Nil)
false :: false :: Nil)
val expected4 = If(
IsNull(short),
nullResult,
udf4.copy(children = KnownNotNull(short) :: double.withNullability(false) :: Nil))
udf4.copy(inputsNullSafe = true :: true :: Nil))
// checkUDF(udf4, expected4)
}

test("SPARK-24891 Fix HandleNullInputsForUDF rule") {
val a = testRelation.output(0)
val func = (x: Int, y: Int) => x + y
val udf1 = ScalaUDF(func, IntegerType, a :: a :: Nil, nullableTypes = false :: false :: Nil)
val udf2 = ScalaUDF(func, IntegerType, a :: udf1 :: Nil, nullableTypes = false :: false :: Nil)
val udf1 = ScalaUDF(func, IntegerType, a :: a :: Nil, false :: false :: Nil)
val udf2 = ScalaUDF(func, IntegerType, a :: udf1 :: Nil, false :: false :: Nil)
val plan = Project(Alias(udf2, "")() :: Nil, testRelation)
comparePlans(plan.analyze, plan.analyze.analyze)
}
@@ -26,18 +26,19 @@ import org.apache.spark.sql.types.{IntegerType, StringType}
class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {

test("basic") {
val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil)
val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil, true :: Nil)
checkEvaluation(intUdf, 2)

val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil)
val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, true :: Nil)
checkEvaluation(stringUdf, "ax")
}

test("better error message for NPE") {
val udf = ScalaUDF(
(s: String) => s.toLowerCase(Locale.ROOT),
StringType,
Literal.create(null, StringType) :: Nil)
Literal.create(null, StringType) :: Nil,
true :: Nil)

val e1 = intercept[SparkException](udf.eval())
assert(e1.getMessage.contains("Failed to execute user defined function"))
@@ -50,7 +51,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {

test("SPARK-22695: ScalaUDF should not use global variables") {
val ctx = new CodegenContext
ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil).genCode(ctx)
ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, true :: Nil).genCode(ctx)
assert(ctx.inlinedMutableStates.isEmpty)
}
}
@@ -564,7 +564,7 @@ class TreeNodeSuite extends SparkFunSuite {
}

test("toJSON should not throws java.lang.StackOverflowError") {
val udf = ScalaUDF(SelfReferenceUDF(), BooleanType, Seq("col1".attr))
val udf = ScalaUDF(SelfReferenceUDF(), BooleanType, Seq("col1".attr), true :: Nil)
// Should not throw java.lang.StackOverflowError
udf.toJSON
}
