Skip to content

Commit

Permalink
[SPARK-22945][SQL] add java UDF APIs in the functions object
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

Currently Scala users can use UDF like
```
val foo = udf((i: Int) => Math.random() + i).asNondeterministic
df.select(foo('a))
```
Python users can also do it with similar APIs. However Java users can't do it, we should add Java UDF APIs in the functions object.

## How was this patch tested?

new tests

Author: Wenchen Fan <[email protected]>

Closes #20141 from cloud-fan/udf.
  • Loading branch information
cloud-fan committed Jan 4, 2018
1 parent 9fa703e commit d5861ab
Show file tree
Hide file tree
Showing 5 changed files with 315 additions and 112 deletions.
90 changes: 45 additions & 45 deletions sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql

import java.lang.reflect.{ParameterizedType, Type}
import java.lang.reflect.ParameterizedType

import scala.reflect.runtime.universe.TypeTag
import scala.util.Try
Expand Down Expand Up @@ -110,29 +110,29 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends

/* register 0-22 were generated by this script
(0 to 22).map { x =>
(0 to 22).foreach { x =>
val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"})
val typeTags = (1 to x).map(i => s"A${i}: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _)
val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _)
val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor[A$i].dataType :: $s"})
println(s"""
/**
* Registers a deterministic Scala closure of ${x} arguments as user-defined function (UDF).
* @tparam RT return type of UDF.
* @since 1.3.0
*/
def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
val inputTypes = Try($inputTypes).toOption
def builder(e: Seq[Expression]) = if (e.length == $x) {
ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true)
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: $x; Found: " + e.length)
}
functionRegistry.createOrReplaceTempFunction(name, builder)
val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name)
if (nullable) udf else udf.asNonNullable()
}""")
|/**
| * Registers a deterministic Scala closure of $x arguments as user-defined function (UDF).
| * @tparam RT return type of UDF.
| * @since 1.3.0
| */
|def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = {
| val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
| val inputTypes = Try($inputTypes).toOption
| def builder(e: Seq[Expression]) = if (e.length == $x) {
| ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true)
| } else {
| throw new AnalysisException("Invalid number of arguments for function " + name +
| ". Expected: $x; Found: " + e.length)
| }
| functionRegistry.createOrReplaceTempFunction(name, builder)
| val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name)
| if (nullable) udf else udf.asNonNullable()
|}""".stripMargin)
}
(0 to 22).foreach { i =>
Expand All @@ -144,7 +144,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
val funcCall = if (i == 0) "() => func" else "func"
println(s"""
|/**
| * Register a user-defined function with ${i} arguments.
| * Register a deterministic Java UDF$i instance as user-defined function (UDF).
| * @since $version
| */
|def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = {
Expand Down Expand Up @@ -689,7 +689,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}

/**
* Register a user-defined function with 0 arguments.
* Register a deterministic Java UDF0 instance as user-defined function (UDF).
* @since 2.3.0
*/
def register(name: String, f: UDF0[_], returnType: DataType): Unit = {
Expand All @@ -704,7 +704,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}

/**
* Register a user-defined function with 1 arguments.
* Register a deterministic Java UDF1 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = {
Expand All @@ -719,7 +719,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}

/**
* Register a user-defined function with 2 arguments.
* Register a deterministic Java UDF2 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = {
Expand All @@ -734,7 +734,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}

/**
* Register a user-defined function with 3 arguments.
* Register a deterministic Java UDF3 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = {
Expand All @@ -749,7 +749,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}

/**
* Register a user-defined function with 4 arguments.
* Register a deterministic Java UDF4 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = {
Expand All @@ -764,7 +764,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}

/**
* Register a user-defined function with 5 arguments.
* Register a deterministic Java UDF5 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = {
Expand All @@ -779,7 +779,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}

/**
* Register a user-defined function with 6 arguments.
* Register a deterministic Java UDF6 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = {
Expand All @@ -794,7 +794,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}

/**
* Register a user-defined function with 7 arguments.
* Register a deterministic Java UDF7 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = {
Expand All @@ -809,7 +809,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}

/**
* Register a user-defined function with 8 arguments.
* Register a deterministic Java UDF8 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
Expand All @@ -824,7 +824,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}

/**
* Register a user-defined function with 9 arguments.
* Register a deterministic Java UDF9 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
Expand All @@ -839,7 +839,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}

/**
* Register a user-defined function with 10 arguments.
* Register a deterministic Java UDF10 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
Expand All @@ -854,7 +854,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}

/**
* Register a user-defined function with 11 arguments.
* Register a deterministic Java UDF11 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
Expand All @@ -869,7 +869,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}

/**
* Register a user-defined function with 12 arguments.
* Register a deterministic Java UDF12 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
Expand All @@ -884,7 +884,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}

/**
* Register a user-defined function with 13 arguments.
* Register a deterministic Java UDF13 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
Expand All @@ -899,7 +899,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}

/**
* Register a user-defined function with 14 arguments.
* Register a deterministic Java UDF14 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
Expand All @@ -914,7 +914,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}

/**
* Register a user-defined function with 15 arguments.
* Register a deterministic Java UDF15 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
Expand All @@ -929,7 +929,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}

/**
* Register a user-defined function with 16 arguments.
* Register a deterministic Java UDF16 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
Expand All @@ -944,7 +944,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}

/**
* Register a user-defined function with 17 arguments.
* Register a deterministic Java UDF17 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
Expand All @@ -959,7 +959,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}

/**
* Register a user-defined function with 18 arguments.
* Register a deterministic Java UDF18 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
Expand All @@ -974,7 +974,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}

/**
* Register a user-defined function with 19 arguments.
* Register a deterministic Java UDF19 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
Expand All @@ -989,7 +989,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}

/**
* Register a user-defined function with 20 arguments.
* Register a deterministic Java UDF20 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
Expand All @@ -1004,7 +1004,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}

/**
* Register a user-defined function with 21 arguments.
* Register a deterministic Java UDF21 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
Expand All @@ -1019,7 +1019,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
}

/**
* Register a user-defined function with 22 arguments.
* Register a deterministic Java UDF22 instance as user-defined function (UDF).
* @since 1.3.0
*/
def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ case class UserDefinedFunction protected[sql] (
*
* @since 1.3.0
*/
@scala.annotation.varargs
def apply(exprs: Column*): Column = {
Column(ScalaUDF(
f,
Expand Down
Loading

0 comments on commit d5861ab

Please sign in to comment.