forked from apache/spark
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SPARK-9435][SQL] Reuse function in Java UDF to correctly support exp…
…ressions that require equality comparison between ScalaUDF ## What changes were proposed in this pull request? Currently, running the codes in Java ```java spark.udf().register("inc", new UDF1<Long, Long>() { Override public Long call(Long i) { return i + 1; } }, DataTypes.LongType); spark.range(10).toDF("x").createOrReplaceTempView("tmp"); Row result = spark.sql("SELECT inc(x) FROM tmp GROUP BY inc(x)").head(); Assert.assertEquals(7, result.getLong(0)); ``` fails as below: ``` org.apache.spark.sql.AnalysisException: expression 'tmp.`x`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.;; Aggregate [UDF(x#19L)], [UDF(x#19L) AS UDF(x)#23L] +- SubqueryAlias tmp, `tmp` +- Project [id#16L AS x#19L] +- Range (0, 10, step=1, splits=Some(8)) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$class.failAnalysis(CheckAnalysis.scala:40) at org.apache.spark.sql.catalyst.analysis.Analyzer.failAnalysis(Analyzer.scala:57) ``` The root cause is because we were creating the function every time when it needs to build as below: ```scala scala> def inc(i: Int) = i + 1 inc: (i: Int)Int scala> (inc(_: Int)).hashCode res15: Int = 1231799381 scala> (inc(_: Int)).hashCode res16: Int = 2109839984 scala> (inc(_: Int)) == (inc(_: Int)) res17: Boolean = false ``` This seems leading to the comparison failure between `ScalaUDF`s created from Java UDF API, for example, in `Expression.semanticEquals`. In case of Scala one, it seems already fine. Both can be tested easily as below if any reviewer is more comfortable with Scala: ```scala val df = Seq((1, 10), (2, 11), (3, 12)).toDF("x", "y") val javaUDF = new UDF1[Int, Int] { override def call(i: Int): Int = i + 1 } // spark.udf.register("inc", javaUDF, IntegerType) // Uncomment this for Java API // spark.udf.register("inc", (i: Int) => i + 1) // Uncomment this for Scala API df.createOrReplaceTempView("tmp") spark.sql("SELECT inc(y) FROM tmp GROUP BY inc(y)").show() ``` ## How was this patch tested? Unit test in `JavaUDFSuite.java` and `./dev/lint-java`. Author: hyukjinkwon <[email protected]> Closes apache#16553 from HyukjinKwon/SPARK-9435.
- Loading branch information
1 parent
8281d46
commit a1a7eae
Showing
2 changed files
with
68 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters