diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 1522bcae08d17..98c67084642e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -21,8 +21,19 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.util.TypeUtils -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types._ + +/** + * An expression that is evaluated to the first non-null input. + * + * {{{ + * coalesce(1, 2) => 1 + * coalesce(null, 1, 2) => 1 + * coalesce(null, null, 2) => 2 + * coalesce(null, null, null) => null + * }}} + */ case class Coalesce(children: Seq[Expression]) extends Expression { /** Coalesce is nullable if all of its children are nullable, or if it has no children. */ @@ -70,6 +81,62 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } } + +/** + * Evaluates to `true` if it's NaN or null + */ +case class IsNaN(child: Expression) extends UnaryExpression + with Predicate with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(DoubleType, FloatType)) + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = { + val value = child.eval(input) + if (value == null) { + true + } else { + child.dataType match { + case DoubleType => value.asInstanceOf[Double].isNaN + case FloatType => value.asInstanceOf[Float].isNaN + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval = child.gen(ctx) + child.dataType match { + case FloatType => + s""" + ${eval.code} + boolean ${ev.isNull} = false; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (${eval.isNull}) { + ${ev.primitive} = true; + } else { + ${ev.primitive} = Float.isNaN(${eval.primitive}); + } + """ + case DoubleType => + s""" + ${eval.code} + boolean ${ev.isNull} = false; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (${eval.isNull}) { + ${ev.primitive} = true; + } else { + ${ev.primitive} = Double.isNaN(${eval.primitive}); + } + """ + } + } +} + + +/** + * An expression that is evaluated to true if the input is null. + */ case class IsNull(child: Expression) extends UnaryExpression with Predicate { override def nullable: Boolean = false @@ -83,13 +150,14 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { ev.primitive = eval.isNull eval.code } - - override def toString: String = s"IS NULL $child" } + +/** + * An expression that is evaluated to true if the input is not null. + */ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { override def nullable: Boolean = false - override def toString: String = s"IS NOT NULL $child" override def eval(input: InternalRow): Any = { child.eval(input) != null @@ -103,12 +171,13 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { } } + /** - * A predicate that is evaluated to be true if there are at least `n` non-null values. + * A predicate that is evaluated to be true if there are at least `n` non-null and non-NaN values. */ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate { override def nullable: Boolean = false - override def foldable: Boolean = false + override def foldable: Boolean = children.forall(_.foldable) override def toString: String = s"AtLeastNNulls(n, ${children.mkString(",")})" private[this] val childrenArray = children.toArray diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 2751c8e75f357..bddd2a9eccfc0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -120,56 +119,6 @@ case class InSet(child: Expression, hset: Set[Any]) } } -/** - * Evaluates to `true` if it's NaN or null - */ -case class IsNaN(child: Expression) extends UnaryExpression - with Predicate with ImplicitCastInputTypes { - - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(DoubleType, FloatType)) - - override def nullable: Boolean = false - - override def eval(input: InternalRow): Any = { - val value = child.eval(input) - if (value == null) { - true - } else { - child.dataType match { - case DoubleType => value.asInstanceOf[Double].isNaN - case FloatType => value.asInstanceOf[Float].isNaN - } - } - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val eval = child.gen(ctx) - child.dataType match { - case FloatType => - s""" - ${eval.code} - boolean ${ev.isNull} = false; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (${eval.isNull}) { - ${ev.primitive} = true; - } else { - ${ev.primitive} = Float.isNaN(${eval.primitive}); - } - """ - case DoubleType => - s""" - ${eval.code} - boolean ${ev.isNull} = false; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (${eval.isNull}) { - ${ev.primitive} = true; - } else { - ${ev.primitive} = Double.isNaN(${eval.primitive}); - } - """ - } - } -} case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala index ccdada8b56f83..765cc7a969b5d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala @@ -18,48 +18,52 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{BooleanType, StringType, ShortType} +import org.apache.spark.sql.types._ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { - test("null checking") { - val row = create_row("^Ba*n", null, true, null) - val c1 = 'a.string.at(0) - val c2 = 'a.string.at(1) - val c3 = 'a.boolean.at(2) - val c4 = 'a.boolean.at(3) - - checkEvaluation(c1.isNull, false, row) - checkEvaluation(c1.isNotNull, true, row) - - checkEvaluation(c2.isNull, true, row) - checkEvaluation(c2.isNotNull, false, row) - - checkEvaluation(Literal.create(1, ShortType).isNull, false) - checkEvaluation(Literal.create(1, ShortType).isNotNull, true) - - checkEvaluation(Literal.create(null, ShortType).isNull, true) - checkEvaluation(Literal.create(null, ShortType).isNotNull, false) + def testAllTypes(testFunc: (Any, DataType) => Unit): Unit = { + testFunc(false, BooleanType) + testFunc(1.toByte, ByteType) + testFunc(1.toShort, ShortType) + testFunc(1, IntegerType) + testFunc(1L, LongType) + testFunc(1.0F, FloatType) + testFunc(1.0, DoubleType) + testFunc(Decimal(1.5), DecimalType.Unlimited) + testFunc(new java.sql.Date(10), DateType) + testFunc(new java.sql.Timestamp(10), TimestampType) + testFunc("abcd", StringType) + } - checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row) - checkEvaluation(Coalesce(Literal.create(null, StringType) :: Nil), null, row) - checkEvaluation(Coalesce(Literal.create(null, StringType) :: c1 :: c2 :: Nil), "^Ba*n", row) + test("isnull and isnotnull") { + testAllTypes { (value: Any, tpe: DataType) => + checkEvaluation(IsNull(Literal.create(value, tpe)), false) + checkEvaluation(IsNotNull(Literal.create(value, tpe)), true) + checkEvaluation(IsNull(Literal.create(null, tpe)), true) + checkEvaluation(IsNotNull(Literal.create(null, tpe)), false) + } + } - checkEvaluation( - If(c3, Literal.create("a", StringType), Literal.create("b", StringType)), "a", row) - checkEvaluation(If(c3, c1, c2), "^Ba*n", row) - checkEvaluation(If(c4, c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal.create(null, BooleanType), c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal.create(true, BooleanType), c1, c2), "^Ba*n", row) - checkEvaluation(If(Literal.create(false, BooleanType), c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal.create(false, BooleanType), - Literal.create("a", StringType), Literal.create("b", StringType)), "b", row) + test("IsNaN") { + checkEvaluation(IsNaN(Literal(Double.NaN)), true) + checkEvaluation(IsNaN(Literal(Float.NaN)), true) + checkEvaluation(IsNaN(Literal(math.log(-3))), true) + checkEvaluation(IsNaN(Literal.create(null, DoubleType)), true) + checkEvaluation(IsNaN(Literal(Double.PositiveInfinity)), false) + checkEvaluation(IsNaN(Literal(Float.MaxValue)), false) + checkEvaluation(IsNaN(Literal(5.5f)), false) + } - checkEvaluation(c1 in (c1, c2), true, row) - checkEvaluation( - Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType)), true, row) - checkEvaluation( - Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType), c2), true, row) + test("coalesce") { + testAllTypes { (value: Any, tpe: DataType) => + val lit = Literal.create(value, tpe) + val nullLit = Literal.create(null, tpe) + checkEvaluation(Coalesce(Seq(nullLit)), null) + checkEvaluation(Coalesce(Seq(lit)), value) + checkEvaluation(Coalesce(Seq(nullLit, lit)), value) + checkEvaluation(Coalesce(Seq(nullLit, lit, lit)), value) + checkEvaluation(Coalesce(Seq(nullLit, nullLit, lit)), value) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 052abc51af5fd..2173a0c25c645 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -114,16 +114,10 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), Literal(2)))), true) - } - test("IsNaN") { - checkEvaluation(IsNaN(Literal(Double.NaN)), true) - checkEvaluation(IsNaN(Literal(Float.NaN)), true) - checkEvaluation(IsNaN(Literal(math.log(-3))), true) - checkEvaluation(IsNaN(Literal.create(null, DoubleType)), true) - checkEvaluation(IsNaN(Literal(Double.PositiveInfinity)), false) - checkEvaluation(IsNaN(Literal(Float.MaxValue)), false) - checkEvaluation(IsNaN(Literal(5.5f)), false) + checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"))), true) + checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true) + checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false) } test("INSET") {