[SPARK-24371] [SQL] Added isInCollection in DataFrame API for Scala and Java. #21416

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
@@ -21,7 +21,6 @@ import scala.collection.immutable.HashSet
import scala.collection.mutable.{ArrayBuffer, Stack}

import org.apache.spark.sql.catalyst.analysis._
- import org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -219,7 +218,11 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] {
object OptimizeIn extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsDown {
- case In(v, list) if list.isEmpty && !v.nullable => FalseLiteral
+ case In(v, list) if list.isEmpty =>
Contributor:

This improvement looks reasonable, but can we move it to a separate PR? It's not related to adding isInCollection.

+ // When v is not nullable, the following expression will be optimized
+ // to FalseLiteral, which is tested in OptimizeInSuite.scala
+ If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType))
+ case In(v, list) if list.length == 1 => EqualTo(v, list.head)
Member:

@dbtsai, this will cause side effects on type casting. For example, see the following. Could you add test cases for this?

scala> sql("select '1.1' in (1), '1.1' = 1").collect()
res0: Array[org.apache.spark.sql.Row] = Array([false,true])

Member Author (@dbtsai, May 24, 2018):

I thought that once the optimizer reaches case In(v, list) if list.length == 1 => EqualTo(v, list.head), the types of v and list are already resolved and auto-casting has been applied, so v and the element in list have the same type. Thus we are safe to use EqualTo directly.

With this PR, I get the correct behavior.

sql("select '1.1' in (1), '1.1' = 1").explain(true)

== Analyzed Logical Plan ==
(CAST(1.1 AS STRING) IN (CAST(1 AS STRING))): boolean, (CAST(1.1 AS INT) = 1): boolean
Project [cast(1.1 as string) IN (cast(1 as string)) AS (CAST(1.1 AS STRING) IN (CAST(1 AS STRING)))#484, (cast(1.1 as int) = 1) AS (CAST(1.1 AS INT) = 1)#485]
+- OneRowRelation

== Optimized Logical Plan ==
Project [false AS (CAST(1.1 AS STRING) IN (CAST(1 AS STRING)))#484, true AS (CAST(1.1 AS INT) = 1)#485]
+- OneRowRelation

== Physical Plan ==
*(1) Project [false AS (CAST(1.1 AS STRING) IN (CAST(1 AS STRING)))#484, true AS (CAST(1.1 AS INT) = 1)#485]
+- Scan OneRowRelation[]

Member:
I got it. Thanks.

Member (@dongjoon-hyun, May 24, 2018):
Could you check the following case, too?

scala> spark.range(1).toDF("a").createOrReplaceTempView("t")
scala> sql("select * from t group by a having count(*) = (select count(*) from t)").explain
== Physical Plan ==
*(2) Project [a#2L]
+- *(2) Filter (count(1)#75L = Subquery subquery62)
   :  +- Subquery subquery62
   :     +- *(2) HashAggregate(keys=[], functions=[count(1)])
   :        +- Exchange SinglePartition
   :           +- *(1) HashAggregate(keys=[], functions=[partial_count(1)])
   :              +- *(1) Project
   :                 +- *(1) Range (0, 1, step=1, splits=8)
   +- *(2) HashAggregate(keys=[a#2L], functions=[count(1)])
      +- Exchange hashpartitioning(a#2L, 200)
         +- *(1) HashAggregate(keys=[a#2L], functions=[partial_count(1)])
            +- *(1) Project [id#0L AS a#2L]
               +- *(1) Range (0, 1, step=1, splits=8)

scala> sql("select * from t group by a having count(*) in (select count(*) from t)").explain
java.lang.StackOverflowError
  at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)

Member Author (@dbtsai, May 24, 2018):

In fact, I'm debugging this StackOverflowError, which shows up in a Hive test. Thanks for this example, which helps me reproduce it locally.

Member:

Yep. This is that one.

case expr @ In(v, list) if expr.inSetConvertible =>
val newList = ExpressionSet(list).toSeq
if (newList.size > SQLConf.get.optimizerInSetConversionThreshold) {
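
An editorial sketch of the null semantics the rewritten empty-list rule preserves (not part of the PR; the data and column name are illustrative, and it assumes a local SparkSession):

import org.apache.spark.sql.SparkSession

object EmptyInSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("empty-in").getOrCreate()
    import spark.implicits._

    // Option[Int] makes column `a` nullable in the inferred schema.
    val df = Seq(Some(1), None).toDF("a")

    // `a IN ()` is rewritten to If(IsNotNull(a), false, null):
    // false for a = 1, null for the NULL row. Folding straight to
    // FalseLiteral would be wrong when `a` is NULL.
    df.select($"a", $"a".isin().as("in_empty")).show()

    spark.stop()
  }
}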
@@ -149,7 +149,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
newPlan = dedupJoin(
Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And)))
exists
- case In(value, Seq(ListQuery(sub, conditions, _, _))) =>
+ case EqualTo(value, ListQuery(sub, conditions, _, _)) =>
Member Author (@dbtsai):

@dongjoon-hyun, this should fix the test. I'll add a test for this.

val exists = AttributeReference("exists", BooleanType, nullable = false)()
val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
val newConditions = (inConditions ++ conditions).reduceLeftOption(And)
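
A hedged sketch replaying the review scenario above (assumes a SparkSession named spark; the temp view t is the one from the earlier comment). Per the discussion, OptimizeIn now rewrites a single-element IN over a subquery into EqualTo(value, ListQuery(...)), so this rule must match EqualTo for the query to plan instead of failing:

spark.range(1).toDF("a").createOrReplaceTempView("t")
// With the EqualTo case above, this should produce a physical plan
// (an existence join over the aggregated subquery) rather than the
// StackOverflowError shown earlier.
spark.sql(
  "select * from t group by a having count(*) in (select count(*) from t)"
).explain()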
@@ -78,7 +78,7 @@ abstract class LogicalPlan
schema.map { field =>
resolve(field.name :: Nil, resolver).map {
case a: AttributeReference => a
- case other => sys.error(s"can not handle nested schema yet... plan $this")
+ case _ => sys.error(s"can not handle nested schema yet... plan $this")
}.getOrElse {
throw new AnalysisException(
s"Unable to resolve ${field.name} given [${output.map(_.name).mkString(", ")}]")
@@ -176,6 +176,21 @@ class OptimizeInSuite extends PlanTest {
}
}

test("OptimizedIn test: one element in list gets transformed to EqualTo.") {
val originalQuery =
testRelation
.where(In(UnresolvedAttribute("a"), Seq(UnresolvedAttribute("b"))))
.analyze

val optimized = Optimize.execute(originalQuery)
val correctAnswer =
testRelation
.where(EqualTo(UnresolvedAttribute("a"), UnresolvedAttribute("b")))
.analyze

comparePlans(optimized, correctAnswer)
}

test("OptimizedIn test: In empty list gets transformed to FalseLiteral " +
"when value is not nullable") {
val originalQuery =
@@ -191,4 +206,21 @@

comparePlans(optimized, correctAnswer)
}

test("OptimizedIn test: In empty list gets transformed to " +
"If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType)) when value is nullable") {
val originalQuery =
testRelation
.where(In(UnresolvedAttribute("a"), Nil))
.analyze

val optimized = Optimize.execute(originalQuery)
val correctAnswer =
testRelation
.where(If(IsNotNull(UnresolvedAttribute("a")),
Literal(false), Literal.create(null, BooleanType)))
.analyze

comparePlans(optimized, correctAnswer)
}
}
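
An end-to-end sketch (editorial, not part of this suite) of the type-casting check from the review discussion above, assuming a SparkSession named spark is in scope:

// Per the review: IN coerces both sides to string ('1.1' <> '1'),
// while '=' casts '1.1' to int, so the expected pair is (false, true).
val row = spark.sql("select '1.1' in (1), '1.1' = 1").collect().head
assert(!row.getBoolean(0))
assert(row.getBoolean(1))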
19 changes: 19 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -17,6 +17,7 @@

package org.apache.spark.sql

+ import scala.collection.JavaConverters._
import scala.language.implicitConversions

import org.apache.spark.annotation.InterfaceStability
@@ -786,6 +787,24 @@ class Column(val expr: Expression) extends Logging {
@scala.annotation.varargs
def isin(list: Any*): Column = withExpr { In(expr, list.map(lit(_).expr)) }
Contributor:

I know this is not your change, but I think (both here and below) the docstring/scaladoc/javadoc should say something about the automatic type casting that's going on, because it's a little surprising that this compares integers to strings and silently converts the types, even when none of the strings can be converted to integers. I'd also include that in the isInCollection docstring/scaladoc/javadoc below.

I'd also point out that the result of the conversion needs to be of the same type, and not a sequence of the type (although the error message we get is pretty clear, so your call).

Just a suggestion for improvement.

Member Author (@dbtsai):

+1. Let's do it in a follow-up PR. Thanks.

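To make the coercion being discussed concrete, a small sketch (the data is illustrative; the behavior matches the mixed-type tests further down):

val df = Seq((1, "x"), (2, "y")).toDF("a", "b")
// The Short and String values are silently coerced to a common type
// with column `a`, so this keeps both rows even though "2" is a String.
df.filter($"a".isin(1.toShort, "2")).show()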

/**
* A boolean expression that is evaluated to true if the value of this expression is contained
* by the provided Set.
*
* @group expr_ops
* @since 2.4.0
*/
def isinSet(values: scala.collection.Set[_]): Column = isin(values.toSeq: _*)
Contributor:

Shall we be more generic here and accept Iterable? Then Set, Seq, and Array can all be accepted.

Contributor:

The name can be isInCollection.

Member Author (@dbtsai):

Sounds good. How do we want to handle the naming? def isin uses a lowercase "i", so isInCollection would be slightly inconsistent with it.


/**
* A boolean expression that is evaluated to true if the value of this expression is contained
* by the provided Set.
*
* @group java_expr_ops
* @since 2.4.0
*/
def isinSet(values: java.util.Set[_]): Column = isinSet(values.asScala)

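A brief usage sketch of the two overloads above (editorial; the data is illustrative, and it assumes spark.implicits._ is imported):

import scala.collection.JavaConverters._

val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")

// Scala Set overload.
df.filter($"a".isinSet(Set(1, 2))).show()          // rows with a = 1 or 2

// java.util.Set overload, e.g. when calling from Java-facing code.
df.filter($"a".isinSet(Set(1, 2).asJava)).show()   // same result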
/**
* SQL like expression. Returns a boolean column based on a SQL LIKE match.
*
@@ -17,6 +17,10 @@

package org.apache.spark.sql

import java.util.Locale

import scala.collection.JavaConverters._

import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat}
import org.scalatest.Matchers._
@@ -392,9 +396,83 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {

val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b")

intercept[AnalysisException] {
val e = intercept[AnalysisException] {
df2.filter($"a".isin($"b"))
}
Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were")
.foreach { s =>
assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT)))
}
}

test("isinSet: Scala Set") {
val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
checkAnswer(df.filter($"a".isinSet(Set(1, 2))),
df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2))
checkAnswer(df.filter($"a".isinSet(Set(3, 2))),
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2))
checkAnswer(df.filter($"a".isinSet(Set(3, 1))),
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))

// Auto casting should work with mixture of different types in Set
checkAnswer(df.filter($"a".isinSet(Set(1.toShort, "2"))),
df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2))
checkAnswer(df.filter($"a".isinSet(Set("3", 2.toLong))),
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2))
checkAnswer(df.filter($"a".isinSet(Set(3, "1"))),
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))

checkAnswer(df.filter($"b".isinSet(Set("y", "x"))),
df.collect().toSeq.filter(r => r.getString(1) == "y" || r.getString(1) == "x"))
checkAnswer(df.filter($"b".isinSet(Set("z", "x"))),
df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x"))
checkAnswer(df.filter($"b".isinSet(Set("z", "y"))),
df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y"))

val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b")

val e = intercept[AnalysisException] {
df2.filter($"a".isinSet(Set($"b")))
}
Member:

Let's check the error message, to prevent future regressions such as a different AnalysisException being raised.

Member Author (@dbtsai):

Addressed.

Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were")
.foreach { s =>
assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT)))
}
}

test("isinSet: Java Set") {
val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
Contributor:

Same thing here; just run a single test case.

Member Author (@dbtsai):

Done.

checkAnswer(df.filter($"a".isinSet(Set(1, 2).asJava)),
df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2))
checkAnswer(df.filter($"a".isinSet(Set(3, 2).asJava)),
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2))
checkAnswer(df.filter($"a".isinSet(Set(3, 1).asJava)),
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))

// Auto casting should work with mixture of different types in Set
checkAnswer(df.filter($"a".isinSet(Set(1.toShort, "2").asJava)),
df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2))
checkAnswer(df.filter($"a".isinSet(Set("3", 2.toLong).asJava)),
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2))
checkAnswer(df.filter($"a".isinSet(Set(3, "1").asJava)),
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))

checkAnswer(df.filter($"b".isinSet(Set("y", "x").asJava)),
df.collect().toSeq.filter(r => r.getString(1) == "y" || r.getString(1) == "x"))
checkAnswer(df.filter($"b".isinSet(Set("z", "x").asJava)),
df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x"))
checkAnswer(df.filter($"b".isinSet(Set("z", "y").asJava)),
df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y"))

val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b")

val e = intercept[AnalysisException] {
df2.filter($"a".isinSet(Set($"b").asJava))
}
Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were")
.foreach { s =>
assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT)))
}
}

test("&&") {