Skip to content

Commit

Permalink
[SPARK-9435][SQL] Reuse function in Java UDF to correctly support exp…
Browse files Browse the repository at this point in the history
…ressions that require equality comparison between ScalaUDF

## What changes were proposed in this pull request?

Currently, running the codes in Java

```java
spark.udf().register("inc", new UDF1<Long, Long>() {
  Override
  public Long call(Long i) {
    return i + 1;
  }
}, DataTypes.LongType);

spark.range(10).toDF("x").createOrReplaceTempView("tmp");
Row result = spark.sql("SELECT inc(x) FROM tmp GROUP BY inc(x)").head();
Assert.assertEquals(7, result.getLong(0));
```

fails as below:

```
org.apache.spark.sql.AnalysisException: expression 'tmp.`x`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.;;
Aggregate [UDF(x#19L)], [UDF(x#19L) AS UDF(x)#23L]
+- SubqueryAlias tmp, `tmp`
   +- Project [id#16L AS x#19L]
      +- Range (0, 10, step=1, splits=Some(8))

	at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$class.failAnalysis(CheckAnalysis.scala:40)
	at org.apache.spark.sql.catalyst.analysis.Analyzer.failAnalysis(Analyzer.scala:57)
```

The root cause is because we were creating the function every time when it needs to build as below:

```scala
scala> def inc(i: Int) = i + 1
inc: (i: Int)Int

scala> (inc(_: Int)).hashCode
res15: Int = 1231799381

scala> (inc(_: Int)).hashCode
res16: Int = 2109839984

scala> (inc(_: Int)) == (inc(_: Int))
res17: Boolean = false
```

This seems leading to the comparison failure between `ScalaUDF`s created from Java UDF API, for example, in `Expression.semanticEquals`.

In case of Scala one, it seems already fine.

Both can be tested easily as below if any reviewer is more comfortable with Scala:

```scala
val df = Seq((1, 10), (2, 11), (3, 12)).toDF("x", "y")
val javaUDF = new UDF1[Int, Int]  {
  override def call(i: Int): Int = i + 1
}
// spark.udf.register("inc", javaUDF, IntegerType) // Uncomment this for Java API
// spark.udf.register("inc", (i: Int) => i + 1)    // Uncomment this for Scala API
df.createOrReplaceTempView("tmp")
spark.sql("SELECT inc(y) FROM tmp GROUP BY inc(y)").show()
```

## How was this patch tested?

Unit test in `JavaUDFSuite.java` and `./dev/lint-java`.

Author: hyukjinkwon <[email protected]>

Closes #16553 from HyukjinKwon/SPARK-9435.
  • Loading branch information
HyukjinKwon authored and gatorsmile committed Jan 24, 2017
1 parent 3bdf3ee commit e576c1e
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 23 deletions.
69 changes: 46 additions & 23 deletions sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
| * @since 1.3.0
| */
|def register(name: String, f: UDF$i[$extTypeArgs, _], returnType: DataType): Unit = {
| val func = f$anyCast.call($anyParams)
| functionRegistry.registerFunction(
| name,
| (e: Seq[Expression]) => ScalaUDF(f$anyCast.call($anyParams), returnType, e))
| (e: Seq[Expression]) => ScalaUDF(func, returnType, e))
|}""".stripMargin)
}
*/
Expand Down Expand Up @@ -488,219 +489,241 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
* @since 1.3.0
*/
def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any)
functionRegistry.registerFunction(
name,
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF1[Any, Any]].call(_: Any), returnType, e))
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
}

/**
* Register a user-defined function with 2 arguments.
* @since 1.3.0
*/
def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any)
functionRegistry.registerFunction(
name,
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any), returnType, e))
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
}

/**
* Register a user-defined function with 3 arguments.
* @since 1.3.0
*/
def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any)
functionRegistry.registerFunction(
name,
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any), returnType, e))
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
}

/**
* Register a user-defined function with 4 arguments.
* @since 1.3.0
*/
def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any)
functionRegistry.registerFunction(
name,
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any), returnType, e))
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
}

/**
* Register a user-defined function with 5 arguments.
* @since 1.3.0
*/
def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any)
functionRegistry.registerFunction(
name,
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
}

/**
* Register a user-defined function with 6 arguments.
* @since 1.3.0
*/
def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
functionRegistry.registerFunction(
name,
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
}

/**
* Register a user-defined function with 7 arguments.
* @since 1.3.0
*/
def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
functionRegistry.registerFunction(
name,
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
}

/**
* Register a user-defined function with 8 arguments.
* @since 1.3.0
*/
def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
functionRegistry.registerFunction(
name,
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
}

/**
* Register a user-defined function with 9 arguments.
* @since 1.3.0
*/
def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
functionRegistry.registerFunction(
name,
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
}

/**
* Register a user-defined function with 10 arguments.
* @since 1.3.0
*/
def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
functionRegistry.registerFunction(
name,
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
}

/**
* Register a user-defined function with 11 arguments.
* @since 1.3.0
*/
def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
functionRegistry.registerFunction(
name,
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
}

/**
* Register a user-defined function with 12 arguments.
* @since 1.3.0
*/
def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
functionRegistry.registerFunction(
name,
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
}

/**
* Register a user-defined function with 13 arguments.
* @since 1.3.0
*/
def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
functionRegistry.registerFunction(
name,
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
}

/**
* Register a user-defined function with 14 arguments.
* @since 1.3.0
*/
def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
functionRegistry.registerFunction(
name,
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
}

/**
* Register a user-defined function with 15 arguments.
* @since 1.3.0
*/
def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
functionRegistry.registerFunction(
name,
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
}

/**
* Register a user-defined function with 16 arguments.
* @since 1.3.0
*/
def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
functionRegistry.registerFunction(
name,
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
}

/**
* Register a user-defined function with 17 arguments.
* @since 1.3.0
*/
def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
functionRegistry.registerFunction(
name,
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
}

/**
* Register a user-defined function with 18 arguments.
* @since 1.3.0
*/
def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
functionRegistry.registerFunction(
name,
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
}

/**
* Register a user-defined function with 19 arguments.
* @since 1.3.0
*/
def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
functionRegistry.registerFunction(
name,
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
}

/**
* Register a user-defined function with 20 arguments.
* @since 1.3.0
*/
def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
functionRegistry.registerFunction(
name,
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
}

/**
* Register a user-defined function with 21 arguments.
* @since 1.3.0
*/
def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
functionRegistry.registerFunction(
name,
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
}

/**
* Register a user-defined function with 22 arguments.
* @since 1.3.0
*/
def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
functionRegistry.registerFunction(
name,
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
}

// scalastyle:on line.size.limit
Expand Down
22 changes: 22 additions & 0 deletions sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package test.org.apache.spark.sql;

import java.io.Serializable;
import java.util.List;

import org.junit.After;
import org.junit.Assert;
Expand Down Expand Up @@ -108,4 +109,25 @@ public void udf3Test() {
result = spark.sql("SELECT stringLengthTest('test', 'test2')").head();
Assert.assertEquals(9, result.getInt(0));
}

@SuppressWarnings("unchecked")
@Test
public void udf4Test() {
spark.udf().register("inc", new UDF1<Long, Long>() {
@Override
public Long call(Long i) {
return i + 1;
}
}, DataTypes.LongType);

spark.range(10).toDF("x").createOrReplaceTempView("tmp");
// This tests when Java UDFs are required to be the semantically same (See SPARK-9435).
List<Row> results = spark.sql("SELECT inc(x) FROM tmp GROUP BY inc(x)").collectAsList();
Assert.assertEquals(10, results.size());
long sum = 0;
for (Row result : results) {
sum += result.getLong(0);
}
Assert.assertEquals(55, sum);
}
}

0 comments on commit e576c1e

Please sign in to comment.