[SPARK-9169][SQL] Improve unit test coverage for null expressions. #7490

Closed · wants to merge 2 commits
@@ -21,8 +21,19 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types._


/**
* An expression that is evaluated to the first non-null input.
*
* {{{
* coalesce(1, 2) => 1
* coalesce(null, 1, 2) => 1
* coalesce(null, null, 2) => 2
* coalesce(null, null, null) => null
* }}}
*/
case class Coalesce(children: Seq[Expression]) extends Expression {

/** Coalesce is nullable if all of its children are nullable, or if it has no children. */
@@ -70,6 +81,62 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
}
}


/**
* Evaluates to `true` if the input is NaN or null.
*/
case class IsNaN(child: Expression) extends UnaryExpression
with Predicate with ImplicitCastInputTypes {

override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(DoubleType, FloatType))

override def nullable: Boolean = false

override def eval(input: InternalRow): Any = {
val value = child.eval(input)
if (value == null) {
true
} else {
child.dataType match {
case DoubleType => value.asInstanceOf[Double].isNaN
case FloatType => value.asInstanceOf[Float].isNaN
}
}
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val eval = child.gen(ctx)
child.dataType match {
case FloatType =>
s"""
${eval.code}
boolean ${ev.isNull} = false;
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (${eval.isNull}) {
${ev.primitive} = true;
} else {
${ev.primitive} = Float.isNaN(${eval.primitive});
}
"""
case DoubleType =>
s"""
${eval.code}
boolean ${ev.isNull} = false;
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (${eval.isNull}) {
${ev.primitive} = true;
} else {
${ev.primitive} = Double.isNaN(${eval.primitive});
}
"""
}
}
}


/**
* An expression that is evaluated to true if the input is null.
*/
case class IsNull(child: Expression) extends UnaryExpression with Predicate {
override def nullable: Boolean = false

@@ -83,13 +150,14 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate {
ev.primitive = eval.isNull
eval.code
}

override def toString: String = s"IS NULL $child"
}


/**
* An expression that is evaluated to true if the input is not null.
*/
case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
override def nullable: Boolean = false
override def toString: String = s"IS NOT NULL $child"

override def eval(input: InternalRow): Any = {
child.eval(input) != null
@@ -103,12 +171,13 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
}
}


/**
* A predicate that is evaluated to be true if there are at least `n` non-null values.
* A predicate that is evaluated to be true if there are at least `n` non-null and non-NaN values.
*/
case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate {
override def nullable: Boolean = false
override def foldable: Boolean = false
override def foldable: Boolean = children.forall(_.foldable)
override def toString: String = s"AtLeastNNulls(n, ${children.mkString(",")})"

private[this] val childrenArray = children.toArray
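Note: the IsNaN expression added above treats a null input the same as NaN. A minimal plain-Scala sketch of the semantics its eval and genCode branches implement (illustrative only, not the Catalyst code; the helper name is hypothetical):

    // Returns true when the evaluated child is null or a NaN double/float,
    // mirroring the null-first check in IsNaN#eval above.
    def isNaNOrNull(value: Any): Boolean = value match {
      case null      => true
      case d: Double => d.isNaN
      case f: Float  => f.isNaN
    }

    // e.g. isNaNOrNull(null) == true, isNaNOrNull(math.log(-3)) == true,
    //      isNaNOrNull(5.5f) == false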
@@ -18,7 +18,6 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -120,56 +119,6 @@ case class InSet(child: Expression, hset: Set[Any])
}
}

/**
* Evaluates to `true` if the input is NaN or null.
*/
case class IsNaN(child: Expression) extends UnaryExpression
with Predicate with ImplicitCastInputTypes {

override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(DoubleType, FloatType))

override def nullable: Boolean = false

override def eval(input: InternalRow): Any = {
val value = child.eval(input)
if (value == null) {
true
} else {
child.dataType match {
case DoubleType => value.asInstanceOf[Double].isNaN
case FloatType => value.asInstanceOf[Float].isNaN
}
}
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val eval = child.gen(ctx)
child.dataType match {
case FloatType =>
s"""
${eval.code}
boolean ${ev.isNull} = false;
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (${eval.isNull}) {
${ev.primitive} = true;
} else {
${ev.primitive} = Float.isNaN(${eval.primitive});
}
"""
case DoubleType =>
s"""
${eval.code}
boolean ${ev.isNull} = false;
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (${eval.isNull}) {
${ev.primitive} = true;
} else {
${ev.primitive} = Double.isNaN(${eval.primitive});
}
"""
}
}
}

case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate {

@@ -18,48 +18,52 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types.{BooleanType, StringType, ShortType}
import org.apache.spark.sql.types._

class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {

test("null checking") {
val row = create_row("^Ba*n", null, true, null)
val c1 = 'a.string.at(0)
val c2 = 'a.string.at(1)
val c3 = 'a.boolean.at(2)
val c4 = 'a.boolean.at(3)

checkEvaluation(c1.isNull, false, row)
checkEvaluation(c1.isNotNull, true, row)

checkEvaluation(c2.isNull, true, row)
checkEvaluation(c2.isNotNull, false, row)

checkEvaluation(Literal.create(1, ShortType).isNull, false)
checkEvaluation(Literal.create(1, ShortType).isNotNull, true)

checkEvaluation(Literal.create(null, ShortType).isNull, true)
checkEvaluation(Literal.create(null, ShortType).isNotNull, false)
def testAllTypes(testFunc: (Any, DataType) => Unit): Unit = {
testFunc(false, BooleanType)
testFunc(1.toByte, ByteType)
testFunc(1.toShort, ShortType)
testFunc(1, IntegerType)
testFunc(1L, LongType)
testFunc(1.0F, FloatType)
testFunc(1.0, DoubleType)
testFunc(Decimal(1.5), DecimalType.Unlimited)
testFunc(new java.sql.Date(10), DateType)
testFunc(new java.sql.Timestamp(10), TimestampType)
testFunc("abcd", StringType)
}

checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row)
checkEvaluation(Coalesce(Literal.create(null, StringType) :: Nil), null, row)
checkEvaluation(Coalesce(Literal.create(null, StringType) :: c1 :: c2 :: Nil), "^Ba*n", row)
test("isnull and isnotnull") {
testAllTypes { (value: Any, tpe: DataType) =>
checkEvaluation(IsNull(Literal.create(value, tpe)), false)
checkEvaluation(IsNotNull(Literal.create(value, tpe)), true)
checkEvaluation(IsNull(Literal.create(null, tpe)), true)
checkEvaluation(IsNotNull(Literal.create(null, tpe)), false)
}
}

checkEvaluation(
If(c3, Literal.create("a", StringType), Literal.create("b", StringType)), "a", row)
checkEvaluation(If(c3, c1, c2), "^Ba*n", row)
checkEvaluation(If(c4, c2, c1), "^Ba*n", row)
checkEvaluation(If(Literal.create(null, BooleanType), c2, c1), "^Ba*n", row)
checkEvaluation(If(Literal.create(true, BooleanType), c1, c2), "^Ba*n", row)
checkEvaluation(If(Literal.create(false, BooleanType), c2, c1), "^Ba*n", row)
checkEvaluation(If(Literal.create(false, BooleanType),
Literal.create("a", StringType), Literal.create("b", StringType)), "b", row)
test("IsNaN") {
checkEvaluation(IsNaN(Literal(Double.NaN)), true)
checkEvaluation(IsNaN(Literal(Float.NaN)), true)
checkEvaluation(IsNaN(Literal(math.log(-3))), true)
checkEvaluation(IsNaN(Literal.create(null, DoubleType)), true)
checkEvaluation(IsNaN(Literal(Double.PositiveInfinity)), false)
checkEvaluation(IsNaN(Literal(Float.MaxValue)), false)
checkEvaluation(IsNaN(Literal(5.5f)), false)
}

checkEvaluation(c1 in (c1, c2), true, row)
checkEvaluation(
Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType)), true, row)
checkEvaluation(
Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType), c2), true, row)
test("coalesce") {
testAllTypes { (value: Any, tpe: DataType) =>
val lit = Literal.create(value, tpe)
val nullLit = Literal.create(null, tpe)
checkEvaluation(Coalesce(Seq(nullLit)), null)
checkEvaluation(Coalesce(Seq(lit)), value)
checkEvaluation(Coalesce(Seq(nullLit, lit)), value)
checkEvaluation(Coalesce(Seq(nullLit, lit, lit)), value)
checkEvaluation(Coalesce(Seq(nullLit, nullLit, lit)), value)
}
}
}
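Note: the new coalesce test relies on the first-non-null semantics documented in the Coalesce scaladoc above. A minimal plain-Scala model of that behavior (illustrative only, not the Catalyst implementation; the helper name is hypothetical):

    // Returns the first non-null argument, or null if every argument is null,
    // matching the coalesce(...) examples in the Coalesce scaladoc.
    def coalesce(values: Any*): Any = values.find(_ != null).orNull

    assert(coalesce(null, 1, 2) == 1)
    assert(coalesce(null, null, 2) == 2)
    assert(coalesce(null, null, null) == null)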
@@ -114,16 +114,10 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(
And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), Literal(2)))),
true)
}

test("IsNaN") {
checkEvaluation(IsNaN(Literal(Double.NaN)), true)
checkEvaluation(IsNaN(Literal(Float.NaN)), true)
checkEvaluation(IsNaN(Literal(math.log(-3))), true)
checkEvaluation(IsNaN(Literal.create(null, DoubleType)), true)
checkEvaluation(IsNaN(Literal(Double.PositiveInfinity)), false)
checkEvaluation(IsNaN(Literal(Float.MaxValue)), false)
checkEvaluation(IsNaN(Literal(5.5f)), false)
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"))), true)
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true)
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false)
}

test("INSET") {
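Note: the In checks moved into this suite assert simple list membership over string literals. A minimal plain-Scala sketch of that membership semantics (illustrative only; it ignores the null handling of the real Catalyst In expression, and the helper name is hypothetical):

    // True when the value appears in the candidate list, as in the
    // relocated checkEvaluation(In(...)) cases above.
    def in(value: Any, list: Seq[Any]): Boolean = list.contains(value)

    assert(in("^Ba*n", Seq("^Ba*n")))
    assert(in("^Ba*n", Seq("aa", "^Ba*n")))
    assert(!in("^Ba*n", Seq("aa", "^n")))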