Skip to content

Commit

Permalink
[SPARK-9169][SQL] Improve unit test coverage for null expressions.
Browse files Browse the repository at this point in the history
Author: Reynold Xin <[email protected]>

Closes #7490 from rxin/unit-test-null-funcs and squashes the following commits:

7b276f0 [Reynold Xin] Move isNaN.
8307287 [Reynold Xin] [SPARK-9169][SQL] Improve unit test coverage for null expressions.
  • Loading branch information
rxin committed Jul 18, 2015
1 parent b9ef7ac commit fba3f5b
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 103 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,19 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types._


/**
* An expression that is evaluated to the first non-null input.
*
* {{{
* coalesce(1, 2) => 1
* coalesce(null, 1, 2) => 1
* coalesce(null, null, 2) => 2
* coalesce(null, null, null) => null
* }}}
*/
case class Coalesce(children: Seq[Expression]) extends Expression {

/** Coalesce is nullable if all of its children are nullable, or if it has no children. */
Expand Down Expand Up @@ -70,6 +81,62 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
}
}


/**
* Evaluates to `true` if it's NaN or null
*/
case class IsNaN(child: Expression) extends UnaryExpression
with Predicate with ImplicitCastInputTypes {

override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(DoubleType, FloatType))

override def nullable: Boolean = false

override def eval(input: InternalRow): Any = {
val value = child.eval(input)
if (value == null) {
true
} else {
child.dataType match {
case DoubleType => value.asInstanceOf[Double].isNaN
case FloatType => value.asInstanceOf[Float].isNaN
}
}
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val eval = child.gen(ctx)
child.dataType match {
case FloatType =>
s"""
${eval.code}
boolean ${ev.isNull} = false;
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (${eval.isNull}) {
${ev.primitive} = true;
} else {
${ev.primitive} = Float.isNaN(${eval.primitive});
}
"""
case DoubleType =>
s"""
${eval.code}
boolean ${ev.isNull} = false;
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (${eval.isNull}) {
${ev.primitive} = true;
} else {
${ev.primitive} = Double.isNaN(${eval.primitive});
}
"""
}
}
}


/**
* An expression that is evaluated to true if the input is null.
*/
case class IsNull(child: Expression) extends UnaryExpression with Predicate {
override def nullable: Boolean = false

Expand All @@ -83,13 +150,14 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate {
ev.primitive = eval.isNull
eval.code
}

override def toString: String = s"IS NULL $child"
}


/**
* An expression that is evaluated to true if the input is not null.
*/
case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
override def nullable: Boolean = false
override def toString: String = s"IS NOT NULL $child"

override def eval(input: InternalRow): Any = {
child.eval(input) != null
Expand All @@ -103,12 +171,13 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
}
}


/**
* A predicate that is evaluated to be true if there are at least `n` non-null values.
* A predicate that is evaluated to be true if there are at least `n` non-null and non-NaN values.
*/
case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate {
override def nullable: Boolean = false
override def foldable: Boolean = false
override def foldable: Boolean = children.forall(_.foldable)
override def toString: String = s"AtLeastNNulls(n, ${children.mkString(",")})"

private[this] val childrenArray = children.toArray
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
Expand Down Expand Up @@ -120,56 +119,6 @@ case class InSet(child: Expression, hset: Set[Any])
}
}

/**
* Evaluates to `true` if it's NaN or null
*/
case class IsNaN(child: Expression) extends UnaryExpression
with Predicate with ImplicitCastInputTypes {

override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(DoubleType, FloatType))

override def nullable: Boolean = false

override def eval(input: InternalRow): Any = {
val value = child.eval(input)
if (value == null) {
true
} else {
child.dataType match {
case DoubleType => value.asInstanceOf[Double].isNaN
case FloatType => value.asInstanceOf[Float].isNaN
}
}
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val eval = child.gen(ctx)
child.dataType match {
case FloatType =>
s"""
${eval.code}
boolean ${ev.isNull} = false;
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (${eval.isNull}) {
${ev.primitive} = true;
} else {
${ev.primitive} = Float.isNaN(${eval.primitive});
}
"""
case DoubleType =>
s"""
${eval.code}
boolean ${ev.isNull} = false;
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (${eval.isNull}) {
${ev.primitive} = true;
} else {
${ev.primitive} = Double.isNaN(${eval.primitive});
}
"""
}
}
}

case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,48 +18,52 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types.{BooleanType, StringType, ShortType}
import org.apache.spark.sql.types._

class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {

test("null checking") {
val row = create_row("^Ba*n", null, true, null)
val c1 = 'a.string.at(0)
val c2 = 'a.string.at(1)
val c3 = 'a.boolean.at(2)
val c4 = 'a.boolean.at(3)

checkEvaluation(c1.isNull, false, row)
checkEvaluation(c1.isNotNull, true, row)

checkEvaluation(c2.isNull, true, row)
checkEvaluation(c2.isNotNull, false, row)

checkEvaluation(Literal.create(1, ShortType).isNull, false)
checkEvaluation(Literal.create(1, ShortType).isNotNull, true)

checkEvaluation(Literal.create(null, ShortType).isNull, true)
checkEvaluation(Literal.create(null, ShortType).isNotNull, false)
def testAllTypes(testFunc: (Any, DataType) => Unit): Unit = {
testFunc(false, BooleanType)
testFunc(1.toByte, ByteType)
testFunc(1.toShort, ShortType)
testFunc(1, IntegerType)
testFunc(1L, LongType)
testFunc(1.0F, FloatType)
testFunc(1.0, DoubleType)
testFunc(Decimal(1.5), DecimalType.Unlimited)
testFunc(new java.sql.Date(10), DateType)
testFunc(new java.sql.Timestamp(10), TimestampType)
testFunc("abcd", StringType)
}

checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row)
checkEvaluation(Coalesce(Literal.create(null, StringType) :: Nil), null, row)
checkEvaluation(Coalesce(Literal.create(null, StringType) :: c1 :: c2 :: Nil), "^Ba*n", row)
test("isnull and isnotnull") {
testAllTypes { (value: Any, tpe: DataType) =>
checkEvaluation(IsNull(Literal.create(value, tpe)), false)
checkEvaluation(IsNotNull(Literal.create(value, tpe)), true)
checkEvaluation(IsNull(Literal.create(null, tpe)), true)
checkEvaluation(IsNotNull(Literal.create(null, tpe)), false)
}
}

checkEvaluation(
If(c3, Literal.create("a", StringType), Literal.create("b", StringType)), "a", row)
checkEvaluation(If(c3, c1, c2), "^Ba*n", row)
checkEvaluation(If(c4, c2, c1), "^Ba*n", row)
checkEvaluation(If(Literal.create(null, BooleanType), c2, c1), "^Ba*n", row)
checkEvaluation(If(Literal.create(true, BooleanType), c1, c2), "^Ba*n", row)
checkEvaluation(If(Literal.create(false, BooleanType), c2, c1), "^Ba*n", row)
checkEvaluation(If(Literal.create(false, BooleanType),
Literal.create("a", StringType), Literal.create("b", StringType)), "b", row)
test("IsNaN") {
checkEvaluation(IsNaN(Literal(Double.NaN)), true)
checkEvaluation(IsNaN(Literal(Float.NaN)), true)
checkEvaluation(IsNaN(Literal(math.log(-3))), true)
checkEvaluation(IsNaN(Literal.create(null, DoubleType)), true)
checkEvaluation(IsNaN(Literal(Double.PositiveInfinity)), false)
checkEvaluation(IsNaN(Literal(Float.MaxValue)), false)
checkEvaluation(IsNaN(Literal(5.5f)), false)
}

checkEvaluation(c1 in (c1, c2), true, row)
checkEvaluation(
Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType)), true, row)
checkEvaluation(
Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType), c2), true, row)
test("coalesce") {
testAllTypes { (value: Any, tpe: DataType) =>
val lit = Literal.create(value, tpe)
val nullLit = Literal.create(null, tpe)
checkEvaluation(Coalesce(Seq(nullLit)), null)
checkEvaluation(Coalesce(Seq(lit)), value)
checkEvaluation(Coalesce(Seq(nullLit, lit)), value)
checkEvaluation(Coalesce(Seq(nullLit, lit, lit)), value)
checkEvaluation(Coalesce(Seq(nullLit, nullLit, lit)), value)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -114,16 +114,10 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(
And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), Literal(2)))),
true)
}

test("IsNaN") {
checkEvaluation(IsNaN(Literal(Double.NaN)), true)
checkEvaluation(IsNaN(Literal(Float.NaN)), true)
checkEvaluation(IsNaN(Literal(math.log(-3))), true)
checkEvaluation(IsNaN(Literal.create(null, DoubleType)), true)
checkEvaluation(IsNaN(Literal(Double.PositiveInfinity)), false)
checkEvaluation(IsNaN(Literal(Float.MaxValue)), false)
checkEvaluation(IsNaN(Literal(5.5f)), false)
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"))), true)
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true)
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false)
}

test("INSET") {
Expand Down

0 comments on commit fba3f5b

Please sign in to comment.