diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 3d3d13b0ac9ee..98c67084642e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -86,7 +86,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { * Evaluates to `true` if it's NaN or null */ case class IsNaN(child: Expression) extends UnaryExpression -with Predicate with ImplicitCastInputTypes { + with Predicate with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(DoubleType, FloatType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index cbd624988dfb2..afa143bd5f331 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -68,12 +68,6 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper testIf(_.toString, StringType) } - test("in") { - checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"))), true) - checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true) - checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false) - } - test("case when") { val row = create_row(null, false, true, "a", "b", "c") val c1 = 'a.boolean.at(0) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala index 221dcba0d8b5d..765cc7a969b5d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala @@ -45,6 +45,16 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("IsNaN") { + checkEvaluation(IsNaN(Literal(Double.NaN)), true) + checkEvaluation(IsNaN(Literal(Float.NaN)), true) + checkEvaluation(IsNaN(Literal(math.log(-3))), true) + checkEvaluation(IsNaN(Literal.create(null, DoubleType)), true) + checkEvaluation(IsNaN(Literal(Double.PositiveInfinity)), false) + checkEvaluation(IsNaN(Literal(Float.MaxValue)), false) + checkEvaluation(IsNaN(Literal(5.5f)), false) + } + test("coalesce") { testAllTypes { (value: Any, tpe: DataType) => val lit = Literal.create(value, tpe) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 052abc51af5fd..2173a0c25c645 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -114,16 +114,10 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), Literal(2)))), true) - } - test("IsNaN") { - checkEvaluation(IsNaN(Literal(Double.NaN)), true) - checkEvaluation(IsNaN(Literal(Float.NaN)), true) - checkEvaluation(IsNaN(Literal(math.log(-3))), true) - checkEvaluation(IsNaN(Literal.create(null, DoubleType)), true) - checkEvaluation(IsNaN(Literal(Double.PositiveInfinity)), false) - checkEvaluation(IsNaN(Literal(Float.MaxValue)), false) - checkEvaluation(IsNaN(Literal(5.5f)), false) + checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"))), true) + checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true) + checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false) } test("INSET") {