[SPARK-25044][SQL] (take 2) Address translation of LMF closure primitive args to Object in Scala 2.12 #22259

Status: Closed (wants to merge 1 commit)
project/MimaExcludes.scala (5 additions, 0 deletions)

@@ -36,6 +36,11 @@ object MimaExcludes {
 
   // Exclude rules for 2.4.x
   lazy val v24excludes = v23excludes ++ Seq(
+    // [SPARK-25044] Address translation of LMF closure primitive args to Object in Scala 2.12
+    ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.expressions.UserDefinedFunction$"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.apply"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy"),
+
     // [SPARK-24296][CORE] Replicate large blocks as a stream.
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockRpcServer.this"),
     // [SPARK-23528] Add numIter to ClusteringSummary
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

@@ -932,15 +932,6 @@ trait ScalaReflection {
     tpe.dealias.erasure.typeSymbol.asClass.fullName
   }
 
-  /**
-   * Returns classes of input parameters of scala function object.
-   */
-  def getParameterTypes(func: AnyRef): Seq[Class[_]] = {
-    val methods = func.getClass.getMethods.filter(m => m.getName == "apply" && !m.isBridge)
-    assert(methods.length == 1)
-    methods.head.getParameterTypes
-  }
-
   /**
    * Returns the parameter names and types for the primary constructor of this type.
    *
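
Context for this removal (not part of the diff): under Scala 2.12, closures are compiled through LambdaMetafactory (LMF), so the reflective trick above can no longer reliably recover primitive parameter types. A minimal, illustrative sketch of the failure mode, runnable standalone on both Scala versions:

// Illustrative sketch, not from the PR: the reflection-based approach the
// removed getParameterTypes relied on. On Scala 2.11 a closure compiles to an
// anonymous Function class whose non-bridge `apply` keeps primitive parameter
// types; on Scala 2.12 the closure is instantiated via LambdaMetafactory and
// the `apply` found here can report `Object` for primitive arguments.
object LmfErasureDemo {
  def parameterTypes(func: AnyRef): Seq[Class[_]] = {
    val applyMethods = func.getClass.getMethods
      .filter(m => m.getName == "apply" && !m.isBridge)
    applyMethods.head.getParameterTypes.toSeq
  }

  def main(args: Array[String]): Unit = {
    val f = (i: Int, j: Long) => "x"
    // Scala 2.11: List(int, long)
    // Scala 2.12: List(class java.lang.Object, class java.lang.Object)
    println(parameterTypes(f).toList)
  }
}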
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

@@ -2149,28 +2149,34 @@ class Analyzer(
 
       case p => p transformExpressionsUp {
 
-        case udf @ ScalaUDF(func, _, inputs, _, _, _, _) =>
-          val parameterTypes = ScalaReflection.getParameterTypes(func)
-          assert(parameterTypes.length == inputs.length)
-
-          // TODO: skip null handling for not-nullable primitive inputs after we can completely
-          // trust the `nullable` information.
-          // (cls, expr) => cls.isPrimitive && expr.nullable
-          val needsNullCheck = (cls: Class[_], expr: Expression) =>
-            cls.isPrimitive && !expr.isInstanceOf[KnownNotNull]
-          val inputsNullCheck = parameterTypes.zip(inputs)
-            .filter { case (cls, expr) => needsNullCheck(cls, expr) }
-            .map { case (_, expr) => IsNull(expr) }
-            .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2))
-          // Once we add an `If` check above the udf, it is safe to mark those checked inputs
-          // as not nullable (i.e., wrap them with `KnownNotNull`), because the null-returning
-          // branch of `If` will be called if any of these checked inputs is null. Thus we can
-          // prevent this rule from being applied repeatedly.
-          val newInputs = parameterTypes.zip(inputs).map{ case (cls, expr) =>
-            if (needsNullCheck(cls, expr)) KnownNotNull(expr) else expr }
-          inputsNullCheck
-            .map(If(_, Literal.create(null, udf.dataType), udf.copy(children = newInputs)))
-            .getOrElse(udf)
+        case udf@ScalaUDF(func, _, inputs, _, _, _, _, nullableTypes) =>
+          if (nullableTypes.isEmpty) {
+            // If no nullability info is available, do nothing. No fields will be specially
+            // checked for null in the plan. If nullability info is incorrect, the results
+            // of the UDF could be wrong.
+            udf
+          } else {
+            // Otherwise, add special handling of null for fields that can't accept null.
+            // The result of operations like this, when passed null, is generally to return null.
+            assert(nullableTypes.length == inputs.length)
+
+            // TODO: skip null handling for not-nullable primitive inputs after we can completely
+            // trust the `nullable` information.
+            val inputsNullCheck = nullableTypes.zip(inputs)
+              .filter { case (nullable, _) => !nullable }
+              .map { case (_, expr) => IsNull(expr) }
+              .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2))
+            // Once we add an `If` check above the udf, it is safe to mark those checked inputs
+            // as not nullable (i.e., wrap them with `KnownNotNull`), because the null-returning
+            // branch of `If` will be called if any of these checked inputs is null. Thus we can
+            // prevent this rule from being applied repeatedly.
+            val newInputs = nullableTypes.zip(inputs).map { case (nullable, expr) =>
+              if (nullable) expr else KnownNotNull(expr)
+            }
+            inputsNullCheck
+              .map(If(_, Literal.create(null, udf.dataType), udf.copy(children = newInputs)))
+              .getOrElse(udf)
+          }
       }
     }
   }

Review thread on the new `case udf@ScalaUDF(...)` line:

Contributor: nit: can we restore the spaces as in the original?

Contributor: ah missed it. We can clean it up later.

Member Author: I might be missing it - what is the space issue? There's an additional level of indent because of the if statement.

Contributor: it is in udf@ScalaUDF, which should have been udf @ ScalaUDF.

Member Author: Oh right. Yeah, I didn't mean to change that. It's minor enough to leave, I think (or else standardize across the code).
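
To make the rule's effect concrete, here is a small sketch adapted from the AnalysisSuite changes further down (the attribute references are assumed; run e.g. in a REPL with catalyst on the classpath):

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

// Assumed inputs, mirroring the test suite below.
val string = AttributeReference("s", StringType)()
val double = AttributeReference("d", DoubleType)()

val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil,
  nullableTypes = true :: false :: Nil)

// HandleNullInputsForUDF only guards the non-nullable (primitive) input `d`;
// the analyzed expression is equivalent to:
//   If(IsNull(double), Literal.create(null, StringType),
//      udf2.copy(children = string :: KnownNotNull(double) :: Nil))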
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala

@@ -39,6 +39,7 @@ import org.apache.spark.sql.types.DataType
  * @param nullable True if the UDF can return null value.
  * @param udfDeterministic True if the UDF is deterministic. Deterministic UDF returns same result
  *                         each time it is invoked with a particular input.
+ * @param nullableTypes which of the inputTypes are nullable (i.e. not primitive)
  */
 case class ScalaUDF(
     function: AnyRef,
@@ -47,7 +48,8 @@ case class ScalaUDF(
     inputTypes: Seq[DataType] = Nil,
     udfName: Option[String] = None,
     nullable: Boolean = true,
-    udfDeterministic: Boolean = true)
+    udfDeterministic: Boolean = true,
+    nullableTypes: Seq[Boolean] = Nil)
   extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression {
 
   // The constructor for SPARK 2.1 and 2.2
@@ -58,7 +60,8 @@ case class ScalaUDF(
       inputTypes: Seq[DataType],
       udfName: Option[String]) = {
     this(
-      function, dataType, children, inputTypes, udfName, nullable = true, udfDeterministic = true)
+      function, dataType, children, inputTypes, udfName, nullable = true,
+      udfDeterministic = true, nullableTypes = Nil)
   }
 
   override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)

Review thread on the new `nullableTypes` parameter:

Member: Using Nil as the default value is dangerous. We don't even have an assert to ensure it is set; we could easily miss setting it with the right values.

Contributor: We put the assert in the rule HandleNullInputsForUDF. Can we merge inputTypes and nullableTypes here so that we don't need to worry about it any more? cc @srowen

Member Author: The logic here was again that we wanted to avoid changing the binary signature. I know catalyst is effectively private to Spark, but this wasn't marked specifically private; I wondered if it would actually affect callers? If not, we can go back and merge it.

Nil is just an empty list; I don't think it's dangerous, and it is used above in inputTypes. It is not always set, because it's not always possible to infer the schema, let alone nullability.

Contributor: I think the problem is more about the way we handle nullableTypes when it is not specified, as in https://github.com/apache/spark/pull/22259/files#diff-57b3d87be744b7d79a9beacf8e5e5eb2R2157. The test failure in https://github.com/apache/spark/pull/21851/files#diff-e8dddba2915a147970654aa93bee30a7R344 would have been exposed if the nullableTypes had been updated in this PR. So I would say that logically this parameter is required, but right now it is declared optional. In this case, things went wrong when nullableTypes was left unspecified, and this could happen not only in tests but in "source" too. I suggest we move this parameter up, right after inputTypes, so it can get the attention it needs.

Member Author: I see, you are saying that some UDF needed to declare nullable types but didn't? I made the param optional to try to make 'migration' easier and avoid changing the signature much. But the test you point to, doesn't it pass? Are you saying it should not pass?

Member Author: Yeah, I could see an argument that it need not block release. The functionality works as intended, at least. Would you change it again in 2.4.1? If not, then we decide to just keep this behavior. Let's at least get this in if there is a new RC.

Contributor: I think it's mostly about maintainability. We should definitely update the code as @maryannxue said, so that ScalaUDF is easier to use and not that error-prone. I feel we don't need to backport it, as it's basically a code refactor.

Member: @maryannxue Please submit a PR to make the parameter nullableTypes required.

Member Author: Yeah, looks like we should just make these changes after all, and for 2.4, as we need another RC.
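
For callers constructing a ScalaUDF, nullability can be derived statically from the function's TypeTags instead of from runtime reflection over the closure. A hedged sketch (the helper name is ours; the PR wires the equivalent through UserDefinedFunction):

import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.sql.catalyst.ScalaReflection

// Illustrative helper, not the PR's exact code: schemaFor reports
// nullable = false for primitives such as Int and Double, and
// nullable = true for reference types like String.
def nullableOf[T: TypeTag]: Boolean = ScalaReflection.schemaFor[T].nullable

// e.g. for a (String, Double) => String UDF:
// nullableTypes = nullableOf[String] :: nullableOf[Double] :: Nil  // true :: false :: Nil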
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala

@@ -261,23 +261,6 @@ class ScalaReflectionSuite extends SparkFunSuite {
     }
   }
 
-  test("get parameter type from a function object") {
-    val primitiveFunc = (i: Int, j: Long) => "x"
-    val primitiveTypes = getParameterTypes(primitiveFunc)
-    assert(primitiveTypes.forall(_.isPrimitive))
-    assert(primitiveTypes === Seq(classOf[Int], classOf[Long]))
-
-    val boxedFunc = (i: java.lang.Integer, j: java.lang.Long) => "x"
-    val boxedTypes = getParameterTypes(boxedFunc)
-    assert(boxedTypes.forall(!_.isPrimitive))
-    assert(boxedTypes === Seq(classOf[java.lang.Integer], classOf[java.lang.Long]))
-
-    val anyFunc = (i: Any, j: AnyRef) => "x"
-    val anyTypes = getParameterTypes(anyFunc)
-    assert(anyTypes.forall(!_.isPrimitive))
-    assert(anyTypes === Seq(classOf[java.lang.Object], classOf[java.lang.Object]))
-  }
-
   test("SPARK-15062: Get correct serializer for List[_]") {
     val list = List(1, 2, 3)
     val serializer = serializerFor[List[Int]](BoundReference(
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

@@ -317,13 +317,15 @@ class AnalysisSuite extends AnalysisTest with Matchers {
     checkUDF(udf1, expected1)
 
     // only primitive parameter needs special null handling
-    val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil)
+    val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil,
+      nullableTypes = true :: false :: Nil)
     val expected2 =
       If(IsNull(double), nullResult, udf2.copy(children = string :: KnownNotNull(double) :: Nil))
     checkUDF(udf2, expected2)
 
     // special null handling should apply to all primitive parameters
-    val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil)
+    val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil,
+      nullableTypes = false :: false :: Nil)
     val expected3 = If(
       IsNull(short) || IsNull(double),
       nullResult,
@@ -335,7 +337,8 @@ class AnalysisSuite extends AnalysisTest with Matchers {
     val udf4 = ScalaUDF(
       (s: Short, d: Double) => "x",
       StringType,
-      short :: double.withNullability(false) :: Nil)
+      short :: double.withNullability(false) :: Nil,
+      nullableTypes = false :: false :: Nil)
     val expected4 = If(
       IsNull(short),
       nullResult,
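
Finally, a hedged end-to-end sketch of the user-visible behavior these tests protect (assumes a local SparkSession bound to `spark`; the exact output column name may differ):

import org.apache.spark.sql.functions.udf
import spark.implicits._

// A UDF over a primitive Double: with the null check in place, a null input
// short-circuits to null instead of being unboxed (which would yield 0.0).
val plusOne = udf((d: Double) => d + 1.0)
Seq(Some(1.0), None).toDF("d").select(plusOne($"d")).show()
// Expected (approximately):
// |   2.0|
// |  null|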