From df60d9f3469022866de2f41939a38e7e5d02dc1b Mon Sep 17 00:00:00 2001 From: maryannxue Date: Fri, 19 Oct 2018 21:03:59 +0800 Subject: [PATCH] [SPARK-25044][FOLLOW-UP] Change ScalaUDF constructor signature ## What changes were proposed in this pull request? This is a follow-up PR for #22259. The extra field added in `ScalaUDF` with the original PR was declared optional, but should be indeed required, otherwise callers of `ScalaUDF`'s constructor could ignore this new field and cause the result to be incorrect. This PR makes the new field required and changes its name to `handleNullForInputs`. #22259 breaks the previous behavior for null-handling of primitive-type input parameters. For example, for `val f = udf({(x: Int, y: Any) => x})`, `f(null, "str")` should return `null` but would return `0` after #22259. In this PR, all UDF methods except `def udf(f: AnyRef, dataType: DataType): UserDefinedFunction` have been restored with the original behavior. The only exception is documented in the Spark SQL migration guide. In addition, now that we have this extra field indicating if a null-test should be applied on the corresponding input value, we can also make use of this flag to avoid the rule `HandleNullInputsForUDF` being applied infinitely. ## How was this patch tested? Added UT in UDFSuite Passed affected existing UTs: AnalysisSuite UDFSuite Closes #22732 from maryannxue/spark-25044-followup. Lead-authored-by: maryannxue Co-authored-by: Wenchen Fan Signed-off-by: Wenchen Fan (cherry picked from commit e8167768cfebfdb11acd8e0a06fe34ca43c14648) Signed-off-by: Wenchen Fan --- .../spark/sql/catalyst/ScalaReflection.scala | 22 +- .../sql/catalyst/analysis/Analyzer.scala | 51 ++-- .../sql/catalyst/expressions/ScalaUDF.scala | 14 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 18 +- .../catalyst/expressions/ScalaUDFSuite.scala | 9 +- .../sql/catalyst/trees/TreeNodeSuite.scala | 2 +- .../apache/spark/sql/UDFRegistration.scala | 218 ++++++++++-------- .../datasources/FileFormatDataWriter.scala | 3 +- .../sql/expressions/UserDefinedFunction.scala | 24 +- .../org/apache/spark/sql/functions.scala | 54 ++--- .../scala/org/apache/spark/sql/UDFSuite.scala | 24 ++ 11 files changed, 257 insertions(+), 182 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 0238d57de2446..c27180e2a6b9b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -19,8 +19,11 @@ package org.apache.spark.sql.catalyst import java.lang.reflect.Constructor +import scala.util.Properties + import org.apache.commons.lang3.reflect.ConstructorUtils +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ @@ -879,7 +882,7 @@ object ScalaReflection extends ScalaReflection { * Support for generating catalyst schemas for scala objects. Note that unlike its companion * object, this trait able to work in both the runtime and the compile time (macro) universe. */ -trait ScalaReflection { +trait ScalaReflection extends Logging { /** The universe we work in (runtime or macro) */ val universe: scala.reflect.api.Universe @@ -932,6 +935,23 @@ trait ScalaReflection { tpe.dealias.erasure.typeSymbol.asClass.fullName } + /** + * Returns the nullability of the input parameter types of the scala function object. + * + * Note that this only works with Scala 2.11, and the information returned may be inaccurate if + * used with a different Scala version. + */ + def getParameterTypeNullability(func: AnyRef): Seq[Boolean] = { + if (!Properties.versionString.contains("2.11")) { + logWarning(s"Scala ${Properties.versionString} cannot get type nullability correctly via " + + "reflection, thus Spark cannot add proper input null check for UDF. To avoid this " + + "problem, use the typed UDF interfaces instead.") + } + val methods = func.getClass.getMethods.filter(m => m.getName == "apply" && !m.isBridge) + assert(methods.length == 1) + methods.head.getParameterTypes.map(!_.isPrimitive) + } + /** * Returns the parameter names and types for the primary constructor of this type. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 9c0975eecd443..4a83067bd8963 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2151,36 +2151,27 @@ class Analyzer( case p => p transformExpressionsUp { - case udf@ScalaUDF(func, _, inputs, _, _, _, _, nullableTypes) => - if (nullableTypes.isEmpty) { - // If no nullability info is available, do nothing. No fields will be specially - // checked for null in the plan. If nullability info is incorrect, the results - // of the UDF could be wrong. - udf - } else { - // Otherwise, add special handling of null for fields that can't accept null. - // The result of operations like this, when passed null, is generally to return null. - assert(nullableTypes.length == inputs.length) - - // TODO: skip null handling for not-nullable primitive inputs after we can completely - // trust the `nullable` information. - val needsNullCheck = (nullable: Boolean, expr: Expression) => - nullable && !expr.isInstanceOf[KnownNotNull] - val inputsNullCheck = nullableTypes.zip(inputs) - .filter { case (nullableType, expr) => needsNullCheck(!nullableType, expr) } - .map { case (_, expr) => IsNull(expr) } - .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2)) - // Once we add an `If` check above the udf, it is safe to mark those checked inputs - // as not nullable (i.e., wrap them with `KnownNotNull`), because the null-returning - // branch of `If` will be called if any of these checked inputs is null. Thus we can - // prevent this rule from being applied repeatedly. - val newInputs = nullableTypes.zip(inputs).map { case (nullable, expr) => - if (nullable) expr else KnownNotNull(expr) - } - inputsNullCheck - .map(If(_, Literal.create(null, udf.dataType), udf.copy(children = newInputs))) - .getOrElse(udf) - } + case udf @ ScalaUDF(_, _, inputs, inputsNullSafe, _, _, _, _) + if inputsNullSafe.contains(false) => + // Otherwise, add special handling of null for fields that can't accept null. + // The result of operations like this, when passed null, is generally to return null. + assert(inputsNullSafe.length == inputs.length) + + // TODO: skip null handling for not-nullable primitive inputs after we can completely + // trust the `nullable` information. + val inputsNullCheck = inputsNullSafe.zip(inputs) + .filter { case (nullSafe, _) => !nullSafe } + .map { case (_, expr) => IsNull(expr) } + .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2)) + // Once we add an `If` check above the udf, it is safe to mark those checked inputs + // as null-safe (i.e., set `inputsNullSafe` all `true`), because the null-returning + // branch of `If` will be called if any of these checked inputs is null. Thus we can + // prevent this rule from being applied repeatedly. + val newInputsNullSafe = inputsNullSafe.map(_ => true) + inputsNullCheck + .map(If(_, Literal.create(null, udf.dataType), + udf.copy(inputsNullSafe = newInputsNullSafe))) + .getOrElse(udf) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 8954fe8a58e6e..fae90caebf96c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkException -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.DataType @@ -31,6 +31,9 @@ import org.apache.spark.sql.types.DataType * null. Use boxed type or [[Option]] if you wanna do the null-handling yourself. * @param dataType Return type of function. * @param children The input expressions of this UDF. + * @param inputsNullSafe Whether the inputs are of non-primitive types or not nullable. Null values + * of Scala primitive types will be converted to the type's default value and + * lead to wrong results, thus need special handling before calling the UDF. * @param inputTypes The expected input types of this UDF, used to perform type coercion. If we do * not want to perform coercion, simply use "Nil". Note that it would've been * better to use Option of Seq[DataType] so we can use "None" as the case for no @@ -39,17 +42,16 @@ import org.apache.spark.sql.types.DataType * @param nullable True if the UDF can return null value. * @param udfDeterministic True if the UDF is deterministic. Deterministic UDF returns same result * each time it is invoked with a particular input. - * @param nullableTypes which of the inputTypes are nullable (i.e. not primitive) */ case class ScalaUDF( function: AnyRef, dataType: DataType, children: Seq[Expression], + inputsNullSafe: Seq[Boolean], inputTypes: Seq[DataType] = Nil, udfName: Option[String] = None, nullable: Boolean = true, - udfDeterministic: Boolean = true, - nullableTypes: Seq[Boolean] = Nil) + udfDeterministic: Boolean = true) extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression { // The constructor for SPARK 2.1 and 2.2 @@ -60,8 +62,8 @@ case class ScalaUDF( inputTypes: Seq[DataType], udfName: Option[String]) = { this( - function, dataType, children, inputTypes, udfName, nullable = true, - udfDeterministic = true, nullableTypes = Nil) + function, dataType, children, ScalaReflection.getParameterTypeNullability(function), + inputTypes, udfName, nullable = true, udfDeterministic = true) } override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index cf76c92b093b7..d8cb6f7caa99e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -314,24 +314,24 @@ class AnalysisSuite extends AnalysisTest with Matchers { } // non-primitive parameters do not need special null handling - val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil) + val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil, true :: Nil) val expected1 = udf1 checkUDF(udf1, expected1) // only primitive parameter needs special null handling val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil, - nullableTypes = true :: false :: Nil) + true :: false :: Nil) val expected2 = - If(IsNull(double), nullResult, udf2.copy(children = string :: KnownNotNull(double) :: Nil)) + If(IsNull(double), nullResult, udf2.copy(inputsNullSafe = true :: true :: Nil)) checkUDF(udf2, expected2) // special null handling should apply to all primitive parameters val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil, - nullableTypes = false :: false :: Nil) + false :: false :: Nil) val expected3 = If( IsNull(short) || IsNull(double), nullResult, - udf3.copy(children = KnownNotNull(short) :: KnownNotNull(double) :: Nil)) + udf3.copy(inputsNullSafe = true :: true :: Nil)) checkUDF(udf3, expected3) // we can skip special null handling for primitive parameters that are not nullable @@ -340,19 +340,19 @@ class AnalysisSuite extends AnalysisTest with Matchers { (s: Short, d: Double) => "x", StringType, short :: double.withNullability(false) :: Nil, - nullableTypes = false :: false :: Nil) + false :: false :: Nil) val expected4 = If( IsNull(short), nullResult, - udf4.copy(children = KnownNotNull(short) :: double.withNullability(false) :: Nil)) + udf4.copy(inputsNullSafe = true :: true :: Nil)) // checkUDF(udf4, expected4) } test("SPARK-24891 Fix HandleNullInputsForUDF rule") { val a = testRelation.output(0) val func = (x: Int, y: Int) => x + y - val udf1 = ScalaUDF(func, IntegerType, a :: a :: Nil, nullableTypes = false :: false :: Nil) - val udf2 = ScalaUDF(func, IntegerType, a :: udf1 :: Nil, nullableTypes = false :: false :: Nil) + val udf1 = ScalaUDF(func, IntegerType, a :: a :: Nil, false :: false :: Nil) + val udf2 = ScalaUDF(func, IntegerType, a :: udf1 :: Nil, false :: false :: Nil) val plan = Project(Alias(udf2, "")() :: Nil, testRelation) comparePlans(plan.analyze, plan.analyze.analyze) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala index e083ae0089244..467cfd5598ff1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala @@ -26,10 +26,10 @@ import org.apache.spark.sql.types.{IntegerType, StringType} class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { test("basic") { - val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil) + val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil, true :: Nil) checkEvaluation(intUdf, 2) - val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil) + val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, true :: Nil) checkEvaluation(stringUdf, "ax") } @@ -37,7 +37,8 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { val udf = ScalaUDF( (s: String) => s.toLowerCase(Locale.ROOT), StringType, - Literal.create(null, StringType) :: Nil) + Literal.create(null, StringType) :: Nil, + true :: Nil) val e1 = intercept[SparkException](udf.eval()) assert(e1.getMessage.contains("Failed to execute user defined function")) @@ -50,7 +51,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-22695: ScalaUDF should not use global variables") { val ctx = new CodegenContext - ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil).genCode(ctx) + ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, true :: Nil).genCode(ctx) assert(ctx.inlinedMutableStates.isEmpty) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index b7092f4c42d4c..64aa1ee39046d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -564,7 +564,7 @@ class TreeNodeSuite extends SparkFunSuite { } test("toJSON should not throws java.lang.StackOverflowError") { - val udf = ScalaUDF(SelfReferenceUDF(), BooleanType, Seq("col1".attr)) + val udf = ScalaUDF(SelfReferenceUDF(), BooleanType, Seq("col1".attr), true :: Nil) // Should not throw java.lang.StackOverflowError udf.toJSON } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index c37ba0c60c3d4..aa3a6c3bf122f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -113,7 +113,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends (0 to 22).foreach { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) - val inputSchemas = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor[A$i] :: $s"}) + val inputSchemas = (1 to x).foldRight("Nil")((i, s) => {s"Try(ScalaReflection.schemaFor[A$i]).toOption :: $s"}) println(s""" |/** | * Registers a deterministic Scala closure of $x arguments as user-defined function (UDF). @@ -122,10 +122,11 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends | */ |def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { | val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - | val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try($inputSchemas).toOption + | val inputSchemas: Seq[Option[ScalaReflection.Schema]] = $inputSchemas | def builder(e: Seq[Expression]) = if (e.length == $x) { - | ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - | udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) + | ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), + | if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), + | Some(name), nullable, udfDeterministic = true) | } else { | throw new AnalysisException("Invalid number of arguments for function " + name + | ". Expected: $x; Found: " + e.length) @@ -151,7 +152,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends |def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = { | val func = f$anyCast.call($anyParams) | def builder(e: Seq[Expression]) = if (e.length == $i) { - | ScalaUDF($funcCall, returnType, e, udfName = Some(name)) + | ScalaUDF($funcCall, returnType, e, e.map(_ => true), udfName = Some(name)) | } else { | throw new AnalysisException("Invalid number of arguments for function " + name + | ". Expected: $i; Found: " + e.length) @@ -168,10 +169,11 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(Nil).toOption + val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Nil def builder(e: Seq[Expression]) = if (e.length == 0) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), + if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), + Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 0; Found: " + e.length) @@ -188,10 +190,11 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: Nil).toOption + val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Nil def builder(e: Seq[Expression]) = if (e.length == 1) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), + if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), + Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 1; Found: " + e.length) @@ -208,10 +211,11 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: Nil).toOption + val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Nil def builder(e: Seq[Expression]) = if (e.length == 2) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), + if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), + Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 2; Found: " + e.length) @@ -228,10 +232,11 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: Nil).toOption + val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Nil def builder(e: Seq[Expression]) = if (e.length == 3) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), + if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), + Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 3; Found: " + e.length) @@ -248,10 +253,11 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: Nil).toOption + val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Nil def builder(e: Seq[Expression]) = if (e.length == 4) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), + if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), + Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 4; Found: " + e.length) @@ -268,10 +274,11 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: Nil).toOption + val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Nil def builder(e: Seq[Expression]) = if (e.length == 5) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), + if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), + Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 5; Found: " + e.length) @@ -288,10 +295,11 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: Nil).toOption + val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Nil def builder(e: Seq[Expression]) = if (e.length == 6) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), + if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), + Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 6; Found: " + e.length) @@ -308,10 +316,11 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: Nil).toOption + val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Nil def builder(e: Seq[Expression]) = if (e.length == 7) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), + if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), + Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 7; Found: " + e.length) @@ -328,10 +337,11 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: Nil).toOption + val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Nil def builder(e: Seq[Expression]) = if (e.length == 8) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), + if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), + Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 8; Found: " + e.length) @@ -348,10 +358,11 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: Nil).toOption + val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Nil def builder(e: Seq[Expression]) = if (e.length == 9) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), + if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), + Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 9; Found: " + e.length) @@ -368,10 +379,11 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: Nil).toOption + val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Nil def builder(e: Seq[Expression]) = if (e.length == 10) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), + if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), + Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 10; Found: " + e.length) @@ -388,10 +400,11 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: Nil).toOption + val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Nil def builder(e: Seq[Expression]) = if (e.length == 11) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), + if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), + Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 11; Found: " + e.length) @@ -408,10 +421,11 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: Nil).toOption + val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Nil def builder(e: Seq[Expression]) = if (e.length == 12) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), + if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), + Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 12; Found: " + e.length) @@ -428,10 +442,11 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: Nil).toOption + val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Nil def builder(e: Seq[Expression]) = if (e.length == 13) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), + if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), + Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 13; Found: " + e.length) @@ -448,10 +463,11 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: Nil).toOption + val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Nil def builder(e: Seq[Expression]) = if (e.length == 14) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), + if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), + Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 14; Found: " + e.length) @@ -468,10 +484,11 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: Nil).toOption + val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Nil def builder(e: Seq[Expression]) = if (e.length == 15) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), + if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), + Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 15; Found: " + e.length) @@ -488,10 +505,11 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: Nil).toOption + val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Nil def builder(e: Seq[Expression]) = if (e.length == 16) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), + if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), + Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 16; Found: " + e.length) @@ -508,10 +526,11 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: Nil).toOption + val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Nil def builder(e: Seq[Expression]) = if (e.length == 17) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), + if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), + Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 17; Found: " + e.length) @@ -528,10 +547,11 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: ScalaReflection.schemaFor[A18] :: Nil).toOption + val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Nil def builder(e: Seq[Expression]) = if (e.length == 18) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), + if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), + Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 18; Found: " + e.length) @@ -548,10 +568,11 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: ScalaReflection.schemaFor[A18] :: ScalaReflection.schemaFor[A19] :: Nil).toOption + val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Try(ScalaReflection.schemaFor[A19]).toOption :: Nil def builder(e: Seq[Expression]) = if (e.length == 19) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), + if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), + Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 19; Found: " + e.length) @@ -568,10 +589,11 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: ScalaReflection.schemaFor[A18] :: ScalaReflection.schemaFor[A19] :: ScalaReflection.schemaFor[A20] :: Nil).toOption + val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Try(ScalaReflection.schemaFor[A19]).toOption :: Try(ScalaReflection.schemaFor[A20]).toOption :: Nil def builder(e: Seq[Expression]) = if (e.length == 20) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), + if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), + Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 20; Found: " + e.length) @@ -588,10 +610,11 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: ScalaReflection.schemaFor[A18] :: ScalaReflection.schemaFor[A19] :: ScalaReflection.schemaFor[A20] :: ScalaReflection.schemaFor[A21] :: Nil).toOption + val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Try(ScalaReflection.schemaFor[A19]).toOption :: Try(ScalaReflection.schemaFor[A20]).toOption :: Try(ScalaReflection.schemaFor[A21]).toOption :: Nil def builder(e: Seq[Expression]) = if (e.length == 21) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), + if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), + Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 21; Found: " + e.length) @@ -608,10 +631,11 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas: Option[Seq[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1] :: ScalaReflection.schemaFor[A2] :: ScalaReflection.schemaFor[A3] :: ScalaReflection.schemaFor[A4] :: ScalaReflection.schemaFor[A5] :: ScalaReflection.schemaFor[A6] :: ScalaReflection.schemaFor[A7] :: ScalaReflection.schemaFor[A8] :: ScalaReflection.schemaFor[A9] :: ScalaReflection.schemaFor[A10] :: ScalaReflection.schemaFor[A11] :: ScalaReflection.schemaFor[A12] :: ScalaReflection.schemaFor[A13] :: ScalaReflection.schemaFor[A14] :: ScalaReflection.schemaFor[A15] :: ScalaReflection.schemaFor[A16] :: ScalaReflection.schemaFor[A17] :: ScalaReflection.schemaFor[A18] :: ScalaReflection.schemaFor[A19] :: ScalaReflection.schemaFor[A20] :: ScalaReflection.schemaFor[A21] :: ScalaReflection.schemaFor[A22] :: Nil).toOption + val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Try(ScalaReflection.schemaFor[A19]).toOption :: Try(ScalaReflection.schemaFor[A20]).toOption :: Try(ScalaReflection.schemaFor[A21]).toOption :: Try(ScalaReflection.schemaFor[A22]).toOption :: Nil def builder(e: Seq[Expression]) = if (e.length == 22) { - ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.dataType)).getOrElse(Nil), Some(name), nullable, - udfDeterministic = true, nullableTypes = inputSchemas.map(_.map(_.nullable)).getOrElse(Nil)) + ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)), + if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType), + Some(name), nullable, udfDeterministic = true) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 22; Found: " + e.length) @@ -719,7 +743,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF0[_], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF0[Any]].call() def builder(e: Seq[Expression]) = if (e.length == 0) { - ScalaUDF(() => func, returnType, e, udfName = Some(name)) + ScalaUDF(() => func, returnType, e, e.map(_ => true), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 0; Found: " + e.length) @@ -734,7 +758,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any) def builder(e: Seq[Expression]) = if (e.length == 1) { - ScalaUDF(func, returnType, e, udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 1; Found: " + e.length) @@ -749,7 +773,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 2) { - ScalaUDF(func, returnType, e, udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 2; Found: " + e.length) @@ -764,7 +788,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 3) { - ScalaUDF(func, returnType, e, udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 3; Found: " + e.length) @@ -779,7 +803,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 4) { - ScalaUDF(func, returnType, e, udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 4; Found: " + e.length) @@ -794,7 +818,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 5) { - ScalaUDF(func, returnType, e, udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 5; Found: " + e.length) @@ -809,7 +833,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 6) { - ScalaUDF(func, returnType, e, udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 6; Found: " + e.length) @@ -824,7 +848,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 7) { - ScalaUDF(func, returnType, e, udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 7; Found: " + e.length) @@ -839,7 +863,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 8) { - ScalaUDF(func, returnType, e, udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 8; Found: " + e.length) @@ -854,7 +878,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 9) { - ScalaUDF(func, returnType, e, udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 9; Found: " + e.length) @@ -869,7 +893,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 10) { - ScalaUDF(func, returnType, e, udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 10; Found: " + e.length) @@ -884,7 +908,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 11) { - ScalaUDF(func, returnType, e, udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 11; Found: " + e.length) @@ -899,7 +923,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 12) { - ScalaUDF(func, returnType, e, udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 12; Found: " + e.length) @@ -914,7 +938,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 13) { - ScalaUDF(func, returnType, e, udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 13; Found: " + e.length) @@ -929,7 +953,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 14) { - ScalaUDF(func, returnType, e, udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 14; Found: " + e.length) @@ -944,7 +968,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 15) { - ScalaUDF(func, returnType, e, udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 15; Found: " + e.length) @@ -959,7 +983,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 16) { - ScalaUDF(func, returnType, e, udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 16; Found: " + e.length) @@ -974,7 +998,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 17) { - ScalaUDF(func, returnType, e, udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 17; Found: " + e.length) @@ -989,7 +1013,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 18) { - ScalaUDF(func, returnType, e, udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 18; Found: " + e.length) @@ -1004,7 +1028,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 19) { - ScalaUDF(func, returnType, e, udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 19; Found: " + e.length) @@ -1019,7 +1043,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 20) { - ScalaUDF(func, returnType, e, udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 20; Found: " + e.length) @@ -1034,7 +1058,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 21) { - ScalaUDF(func, returnType, e, udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 21; Found: " + e.length) @@ -1049,7 +1073,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 22) { - ScalaUDF(func, returnType, e, udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 22; Found: " + e.length) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala index 6499328e89ce7..10733810b6416 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala @@ -179,7 +179,8 @@ class DynamicPartitionDataWriter( val partitionName = ScalaUDF( ExternalCatalogUtils.getPartitionPathString _, StringType, - Seq(Literal(c.name), Cast(c, StringType, Option(description.timeZoneId)))) + Seq(Literal(c.name), Cast(c, StringType, Option(description.timeZoneId))), + Seq(true, true)) if (i == 0) Seq(partitionName) else Seq(Literal(Path.SEPARATOR), partitionName) }) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 697757f8a73ce..eb956c4b3e888 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -73,19 +73,24 @@ case class UserDefinedFunction protected[sql] ( */ @scala.annotation.varargs def apply(exprs: Column*): Column = { - if (inputTypes.isDefined && nullableTypes.isDefined) { - require(inputTypes.get.length == nullableTypes.get.length) + // TODO: make sure this class is only instantiated through `SparkUserDefinedFunction.create()` + // and `nullableTypes` is always set. + if (nullableTypes.isEmpty) { + nullableTypes = Some(ScalaReflection.getParameterTypeNullability(f)) + } + if (inputTypes.isDefined) { + assert(inputTypes.get.length == nullableTypes.get.length) } Column(ScalaUDF( f, dataType, exprs.map(_.expr), + nullableTypes.get, inputTypes.getOrElse(Nil), udfName = _nameOption, nullable = _nullable, - udfDeterministic = _deterministic, - nullableTypes = nullableTypes.getOrElse(Nil))) + udfDeterministic = _deterministic)) } private def copyAll(): UserDefinedFunction = { @@ -146,9 +151,14 @@ private[sql] object SparkUserDefinedFunction { def create( f: AnyRef, dataType: DataType, - inputSchemas: Option[Seq[ScalaReflection.Schema]]): UserDefinedFunction = { - val udf = new UserDefinedFunction(f, dataType, inputSchemas.map(_.map(_.dataType))) - udf.nullableTypes = inputSchemas.map(_.map(_.nullable)) + inputSchemas: Seq[Option[ScalaReflection.Schema]]): UserDefinedFunction = { + val inputTypes = if (inputSchemas.contains(None)) { + None + } else { + Some(inputSchemas.map(_.get.dataType)) + } + val udf = new UserDefinedFunction(f, dataType, inputTypes) + udf.nullableTypes = Some(inputSchemas.map(_.map(_.nullable).getOrElse(true))) udf } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 10b67d7a1ca54..6a43ce160efec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3819,7 +3819,7 @@ object functions { (0 to 10).foreach { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) - val inputSchemas = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor(typeTag[A$i]) :: $s"}) + val inputSchemas = (1 to x).foldRight("Nil")((i, s) => {s"Try(ScalaReflection.schemaFor(typeTag[A$i])).toOption :: $s"}) println(s""" |/** | * Defines a Scala closure of $x arguments as user-defined function (UDF). @@ -3832,7 +3832,7 @@ object functions { | */ |def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { | val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - | val inputSchemas = Try($inputTypes).toOption + | val inputSchemas = $inputSchemas | val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) | if (nullable) udf else udf.asNonNullable() |}""".stripMargin) @@ -3856,7 +3856,7 @@ object functions { | */ |def udf(f: UDF$i[$extTypeArgs], returnType: DataType): UserDefinedFunction = { | val func = f$anyCast.call($anyParams) - | SparkUserDefinedFunction.create($funcCall, returnType, inputSchemas = None) + | SparkUserDefinedFunction.create($funcCall, returnType, inputSchemas = Seq.fill($i)(None)) |}""".stripMargin) } @@ -3877,7 +3877,7 @@ object functions { */ def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas = Try(Nil).toOption + val inputSchemas = Nil val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3893,7 +3893,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: Nil).toOption + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Nil val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3909,7 +3909,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: Nil).toOption + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Nil val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3925,7 +3925,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: Nil).toOption + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Nil val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3941,7 +3941,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: Nil).toOption + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Nil val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3957,7 +3957,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: Nil).toOption + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Nil val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3973,7 +3973,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: Nil).toOption + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Nil val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3989,7 +3989,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: ScalaReflection.schemaFor(typeTag[A7]) :: Nil).toOption + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Nil val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -4005,7 +4005,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: ScalaReflection.schemaFor(typeTag[A7]) :: ScalaReflection.schemaFor(typeTag[A8]) :: Nil).toOption + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Nil val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -4021,7 +4021,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: ScalaReflection.schemaFor(typeTag[A7]) :: ScalaReflection.schemaFor(typeTag[A8]) :: ScalaReflection.schemaFor(typeTag[A9]) :: Nil).toOption + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A9])).toOption :: Nil val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -4037,7 +4037,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: ScalaReflection.schemaFor(typeTag[A7]) :: ScalaReflection.schemaFor(typeTag[A8]) :: ScalaReflection.schemaFor(typeTag[A9]) :: ScalaReflection.schemaFor(typeTag[A10]) :: Nil).toOption + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A9])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A10])).toOption :: Nil val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -4057,7 +4057,7 @@ object functions { */ def udf(f: UDF0[_], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF0[Any]].call() - SparkUserDefinedFunction.create(() => func, returnType, inputSchemas = None) + SparkUserDefinedFunction.create(() => func, returnType, inputSchemas = Seq.fill(0)(None)) } /** @@ -4071,7 +4071,7 @@ object functions { */ def udf(f: UDF1[_, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any) - SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(1)(None)) } /** @@ -4085,7 +4085,7 @@ object functions { */ def udf(f: UDF2[_, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any) - SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(2)(None)) } /** @@ -4099,7 +4099,7 @@ object functions { */ def udf(f: UDF3[_, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any) - SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(3)(None)) } /** @@ -4113,7 +4113,7 @@ object functions { */ def udf(f: UDF4[_, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(4)(None)) } /** @@ -4127,7 +4127,7 @@ object functions { */ def udf(f: UDF5[_, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(5)(None)) } /** @@ -4141,7 +4141,7 @@ object functions { */ def udf(f: UDF6[_, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(6)(None)) } /** @@ -4155,7 +4155,7 @@ object functions { */ def udf(f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(7)(None)) } /** @@ -4169,7 +4169,7 @@ object functions { */ def udf(f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(8)(None)) } /** @@ -4183,7 +4183,7 @@ object functions { */ def udf(f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(9)(None)) } /** @@ -4197,7 +4197,7 @@ object functions { */ def udf(f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(10)(None)) } // scalastyle:on parameter.number @@ -4216,7 +4216,9 @@ object functions { * @since 2.0.0 */ def udf(f: AnyRef, dataType: DataType): UserDefinedFunction = { - SparkUserDefinedFunction.create(f, dataType, inputSchemas = None) + // TODO: should call SparkUserDefinedFunction.create() instead but inputSchemas is currently + // unavailable. We may need to create type-safe overloaded versions of udf() methods. + new UserDefinedFunction(f, dataType, inputTypes = None) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 30dca9497ddde..f8ed21bbf7c09 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -393,4 +393,28 @@ class UDFSuite extends QueryTest with SharedSQLContext { checkAnswer(df, Seq(Row("12"), Row("24"), Row("3null"), Row(null))) } } + + test("SPARK-25044 Verify null input handling for primitive types - with udf()") { + val udf1 = udf((x: Long, y: Any) => x * 2 + (if (y == null) 1 else 0)) + val df = spark.range(0, 3).toDF("a") + .withColumn("b", udf1($"a", lit(null))) + .withColumn("c", udf1(lit(null), $"a")) + + checkAnswer( + df, + Seq( + Row(0, 1, null), + Row(1, 3, null), + Row(2, 5, null))) + } + + test("SPARK-25044 Verify null input handling for primitive types - with udf.register") { + withTable("t") { + Seq((null, new Integer(1), "x"), ("M", null, "y"), ("N", new Integer(3), null)) + .toDF("a", "b", "c").write.format("json").saveAsTable("t") + spark.udf.register("f", (a: String, b: Int, c: Any) => a + b + c) + val df = spark.sql("SELECT f(a, b, c) FROM t") + checkAnswer(df, Seq(Row("null1x"), Row(null), Row("N3null"))) + } + } }