-
Notifications
You must be signed in to change notification settings - Fork 28.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-24371] [SQL] Added isInCollection in DataFrame API for Scala and Java. #21416
Changes from 1 commit
da10307
730b19b
6ff2806
286a468
1332406
fed2846
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,7 +21,6 @@ import scala.collection.immutable.HashSet | |
import scala.collection.mutable.{ArrayBuffer, Stack} | ||
|
||
import org.apache.spark.sql.catalyst.analysis._ | ||
import org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts | ||
import org.apache.spark.sql.catalyst.expressions._ | ||
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} | ||
import org.apache.spark.sql.catalyst.expressions.aggregate._ | ||
|
@@ -219,7 +218,11 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { | |
object OptimizeIn extends Rule[LogicalPlan] { | ||
def apply(plan: LogicalPlan): LogicalPlan = plan transform { | ||
case q: LogicalPlan => q transformExpressionsDown { | ||
case In(v, list) if list.isEmpty && !v.nullable => FalseLiteral | ||
case In(v, list) if list.isEmpty => | ||
// When v is not nullable, the following expression will be optimized | ||
// to FalseLiteral which is tested in OptimizeInSuite.scala | ||
If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType)) | ||
case In(v, list) if list.length == 1 => EqualTo(v, list.head) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ur, @dbtsai . This will cause side-effects on typecasting. For example, please see the following example. Could you add these kinds of test cases? scala> sql("select '1.1' in (1), '1.1' = 1").collect()
res0: Array[org.apache.spark.sql.Row] = Array([false,true]) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought once the optimizer reaches With this PR, I get the correct behavior. sql("select '1.1' in (1), '1.1' = 1").explain(true)
== Analyzed Logical Plan ==
(CAST(1.1 AS STRING) IN (CAST(1 AS STRING))): boolean, (CAST(1.1 AS INT) = 1): boolean
Project [cast(1.1 as string) IN (cast(1 as string)) AS (CAST(1.1 AS STRING) IN (CAST(1 AS STRING)))#484, (cast(1.1 as int) = 1) AS (CAST(1.1 AS INT) = 1)#485]
+- OneRowRelation
== Optimized Logical Plan ==
Project [false AS (CAST(1.1 AS STRING) IN (CAST(1 AS STRING)))#484, true AS (CAST(1.1 AS INT) = 1)#485]
+- OneRowRelation
== Physical Plan ==
*(1) Project [false AS (CAST(1.1 AS STRING) IN (CAST(1 AS STRING)))#484, true AS (CAST(1.1 AS INT) = 1)#485]
+- Scan OneRowRelation[] There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I got it. Thanks. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you check the following case, too? scala> spark.range(1).toDF("a").createOrReplaceTempView("t")
scala> sql("select * from t group by a having count(*) = (select count(*) from t)").explain
== Physical Plan ==
*(2) Project [a#2L]
+- *(2) Filter (count(1)#75L = Subquery subquery62)
: +- Subquery subquery62
: +- *(2) HashAggregate(keys=[], functions=[count(1)])
: +- Exchange SinglePartition
: +- *(1) HashAggregate(keys=[], functions=[partial_count(1)])
: +- *(1) Project
: +- *(1) Range (0, 1, step=1, splits=8)
+- *(2) HashAggregate(keys=[a#2L], functions=[count(1)])
+- Exchange hashpartitioning(a#2L, 200)
+- *(1) HashAggregate(keys=[a#2L], functions=[partial_count(1)])
+- *(1) Project [id#0L AS a#2L]
+- *(1) Range (0, 1, step=1, splits=8)
scala> sql("select * from t group by a having count(*) in (select count(*) from t)").explain
java.lang.StackOverflowError
at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In fact, I'm debugging this StackOverflowError issue shown in Hive test. Thanks for this which helps me to reproduce locally. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep. This is that one. |
||
case expr @ In(v, list) if expr.inSetConvertible => | ||
val newList = ExpressionSet(list).toSeq | ||
if (newList.size > SQLConf.get.optimizerInSetConversionThreshold) { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -149,7 +149,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { | |
newPlan = dedupJoin( | ||
Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))) | ||
exists | ||
case In(value, Seq(ListQuery(sub, conditions, _, _))) => | ||
case EqualTo(value, ListQuery(sub, conditions, _, _)) => | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @dongjoon-hyun this should fix the test. I'll add one test for this. |
||
val exists = AttributeReference("exists", BooleanType, nullable = false)() | ||
val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) | ||
val newConditions = (inConditions ++ conditions).reduceLeftOption(And) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
|
||
package org.apache.spark.sql | ||
|
||
import scala.collection.JavaConverters._ | ||
import scala.language.implicitConversions | ||
|
||
import org.apache.spark.annotation.InterfaceStability | ||
|
@@ -786,6 +787,24 @@ class Column(val expr: Expression) extends Logging { | |
@scala.annotation.varargs | ||
def isin(list: Any*): Column = withExpr { In(expr, list.map(lit(_).expr)) } | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I know this is not your change, but I think (both here and below) something about the automagical type casting that's going on should be in the docstring/scaladoc/javadoc because to me it's a little surprising how this will compare integers to strings and silently convert the types including if there are no strings which can be converted to integers. And I'd also include that in the isInCollection docstring/scaladoc/javadoc below. I'd also point out that the result of the conversion needs to be of the same type and not of a sequence of the type (although the error message we get is pretty clear so your call). Just a suggestion for improvement. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. +1 Let's do it in the followup PR. Thanks. |
||
|
||
/** | ||
* A boolean expression that is evaluated to true if the value of this expression is contained | ||
* by the provided Set. | ||
* | ||
* @group expr_ops | ||
* @since 2.4.0 | ||
*/ | ||
def isinSet(values: scala.collection.Set[_]): Column = isin(values.toSeq: _*) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. shall we be more generic here and accept There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. the name can be There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sounds good. How do we want to do the naming? |
||
|
||
/** | ||
* A boolean expression that is evaluated to true if the value of this expression is contained | ||
* by the provided Set. | ||
* | ||
* @group java_expr_ops | ||
* @since 2.4.0 | ||
*/ | ||
def isinSet(values: java.util.Set[_]): Column = isinSet(values.asScala) | ||
|
||
/** | ||
* SQL like expression. Returns a boolean column based on a SQL LIKE match. | ||
* | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,10 @@ | |
|
||
package org.apache.spark.sql | ||
|
||
import java.util.Locale | ||
|
||
import scala.collection.JavaConverters._ | ||
|
||
import org.apache.hadoop.io.{LongWritable, Text} | ||
import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} | ||
import org.scalatest.Matchers._ | ||
|
@@ -392,9 +396,83 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { | |
|
||
val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") | ||
|
||
intercept[AnalysisException] { | ||
val e = intercept[AnalysisException] { | ||
df2.filter($"a".isin($"b")) | ||
} | ||
Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") | ||
.foreach { s => | ||
assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) | ||
} | ||
} | ||
|
||
test("isinSet: Scala Set") { | ||
val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") | ||
checkAnswer(df.filter($"a".isinSet(Set(1, 2))), | ||
df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) | ||
checkAnswer(df.filter($"a".isinSet(Set(3, 2))), | ||
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) | ||
checkAnswer(df.filter($"a".isinSet(Set(3, 1))), | ||
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) | ||
|
||
// Auto casting should work with mixture of different types in Set | ||
checkAnswer(df.filter($"a".isinSet(Set(1.toShort, "2"))), | ||
df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) | ||
checkAnswer(df.filter($"a".isinSet(Set("3", 2.toLong))), | ||
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) | ||
checkAnswer(df.filter($"a".isinSet(Set(3, "1"))), | ||
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) | ||
|
||
checkAnswer(df.filter($"b".isinSet(Set("y", "x"))), | ||
df.collect().toSeq.filter(r => r.getString(1) == "y" || r.getString(1) == "x")) | ||
checkAnswer(df.filter($"b".isinSet(Set("z", "x"))), | ||
df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x")) | ||
checkAnswer(df.filter($"b".isinSet(Set("z", "y"))), | ||
df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y")) | ||
|
||
val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") | ||
|
||
val e = intercept[AnalysisException] { | ||
df2.filter($"a".isinSet(Set($"b"))) | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's check the error message to prevent the future regression like raising different AnalysisException. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Addressed |
||
Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") | ||
.foreach { s => | ||
assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) | ||
} | ||
} | ||
|
||
test("isinSet: Java Set") { | ||
val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same thing here. just run a single test case. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
checkAnswer(df.filter($"a".isinSet(Set(1, 2).asJava)), | ||
df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) | ||
checkAnswer(df.filter($"a".isinSet(Set(3, 2).asJava)), | ||
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) | ||
checkAnswer(df.filter($"a".isinSet(Set(3, 1).asJava)), | ||
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) | ||
|
||
// Auto casting should work with mixture of different types in Set | ||
checkAnswer(df.filter($"a".isinSet(Set(1.toShort, "2").asJava)), | ||
df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) | ||
checkAnswer(df.filter($"a".isinSet(Set("3", 2.toLong).asJava)), | ||
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) | ||
checkAnswer(df.filter($"a".isinSet(Set(3, "1").asJava)), | ||
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) | ||
|
||
checkAnswer(df.filter($"b".isinSet(Set("y", "x").asJava)), | ||
df.collect().toSeq.filter(r => r.getString(1) == "y" || r.getString(1) == "x")) | ||
checkAnswer(df.filter($"b".isinSet(Set("z", "x").asJava)), | ||
df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x")) | ||
checkAnswer(df.filter($"b".isinSet(Set("z", "y").asJava)), | ||
df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y")) | ||
|
||
val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") | ||
|
||
val e = intercept[AnalysisException] { | ||
df2.filter($"a".isinSet(Set($"b").asJava)) | ||
} | ||
Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") | ||
.foreach { s => | ||
assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) | ||
} | ||
} | ||
|
||
test("&&") { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this improvement looks reasonable, but can we move them to a separate PR? it's not related to adding
isInCollection
.