Skip to content

Commit

Permalink
null to nan in Math Expression
Browse files Browse the repository at this point in the history
  • Loading branch information
yjshen committed Jul 17, 2015
1 parent eba6a1a commit 188be51
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ object FunctionRegistry {
expression[Log]("ln"),
expression[Log10]("log10"),
expression[Log1p]("log1p"),
expression[Log2]("log2"),
expression[UnaryMinus]("negative"),
expression[Pi]("pi"),
expression[Log2]("log2"),
expression[Pow]("pow"),
expression[Pow]("power"),
expression[Pmod]("pmod"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,22 +66,14 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
override def toString: String = s"$name($child)"

protected override def nullSafeEval(input: Any): Any = {
val result = f(input.asInstanceOf[Double])
if (result.isNaN) null else result
f(input.asInstanceOf[Double])
}

// name of function in java.lang.Math
def funcName: String = name.toLowerCase

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
nullSafeCodeGen(ctx, ev, eval => {
s"""
${ev.primitive} = java.lang.Math.${funcName}($eval);
if (Double.valueOf(${ev.primitive}).isNaN()) {
${ev.isNull} = true;
}
"""
})
defineCodeGen(ctx, ev, c => s"java.lang.Math.${funcName}($c)")
}
}

Expand All @@ -101,8 +93,7 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
override def dataType: DataType = DoubleType

protected override def nullSafeEval(input1: Any, input2: Any): Any = {
val result = f(input1.asInstanceOf[Double], input2.asInstanceOf[Double])
if (result.isNaN) null else result
f(input1.asInstanceOf[Double], input2.asInstanceOf[Double])
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
Expand Down Expand Up @@ -404,14 +395,7 @@ case class Log(child: Expression) extends UnaryMathExpression(math.log, "LOG")
case class Log2(child: Expression)
extends UnaryMathExpression((x: Double) => math.log(x) / math.log(2), "LOG2") {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
nullSafeCodeGen(ctx, ev, eval => {
s"""
${ev.primitive} = java.lang.Math.log($eval) / java.lang.Math.log(2);
if (Double.valueOf(${ev.primitive}).isNaN()) {
${ev.isNull} = true;
}
"""
})
defineCodeGen(ctx, ev, c => s"java.lang.Math.log($c) / java.lang.Math.log(2)")
}
}

Expand Down Expand Up @@ -578,27 +562,18 @@ case class Atan2(left: Expression, right: Expression)

protected override def nullSafeEval(input1: Any, input2: Any): Any = {
// With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0
val result = math.atan2(input1.asInstanceOf[Double] + 0.0, input2.asInstanceOf[Double] + 0.0)
if (result.isNaN) null else result
math.atan2(input1.asInstanceOf[Double] + 0.0, input2.asInstanceOf[Double] + 0.0)
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + s"""
if (Double.valueOf(${ev.primitive}).isNaN()) {
${ev.isNull} = true;
}
"""
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)")
}
}

case class Pow(left: Expression, right: Expression)
extends BinaryMathExpression(math.pow, "POWER") {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s"""
if (Double.valueOf(${ev.primitive}).isNaN()) {
${ev.isNull} = true;
}
"""
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)")
}
}

Expand Down Expand Up @@ -701,16 +676,11 @@ case class Logarithm(left: Expression, right: Expression)
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val logCode = if (left.isInstanceOf[EulerNumber]) {
if (left.isInstanceOf[EulerNumber]) {
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2)")
} else {
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2) / java.lang.Math.log($c1)")
}
logCode + s"""
if (Double.isNaN(${ev.primitive})) {
${ev.isNull} = true;
}
"""
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ import com.google.common.math.LongMath

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
import org.apache.spark.sql.types._


Expand All @@ -46,19 +50,19 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
* @param c expression
* @param f The functions in scala.math or elsewhere used to generate expected results
* @param domain The set of values to run the function with
* @param expectNull Whether the given values should return null or not
* @param expectNaN Whether the given values should eval to NaN or not
* @tparam T Generic type for primitives
* @tparam U Generic type for the output of the given function `f`
*/
private def testUnary[T, U](
c: Expression => Expression,
f: T => U,
domain: Iterable[T] = (-20 to 20).map(_ * 0.1),
expectNull: Boolean = false,
expectNaN: Boolean = false,
evalType: DataType = DoubleType): Unit = {
if (expectNull) {
if (expectNaN) {
domain.foreach { value =>
checkEvaluation(c(Literal(value)), null, EmptyRow)
checkNaN(c(Literal(value)), EmptyRow)
}
} else {
domain.foreach { value =>
Expand All @@ -74,15 +78,16 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
* @param c The DataFrame function
* @param f The functions in scala.math
* @param domain The set of values to run the function with
* @param expectNaN Whether the given values should eval to NaN or not
*/
private def testBinary(
c: (Expression, Expression) => Expression,
f: (Double, Double) => Double,
domain: Iterable[(Double, Double)] = (-20 to 20).map(v => (v * 0.1, v * -0.1)),
expectNull: Boolean = false): Unit = {
if (expectNull) {
expectNaN: Boolean = false): Unit = {
if (expectNaN) {
domain.foreach { case (v1, v2) =>
checkEvaluation(c(Literal(v1), Literal(v2)), null, create_row(null))
checkNaN(c(Literal(v1), Literal(v2)), EmptyRow)
}
} else {
domain.foreach { case (v1, v2) =>
Expand Down Expand Up @@ -112,6 +117,62 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
Conv(Literal("11abc"), Literal(10), Literal(16)), "B")
}

/**
 * Asserts that `expression` evaluates to NaN by running three checks in turn:
 * `checkNaNWithoutCodegen`, `checkNaNWithGeneratedProjection` and `checkNaNWithOptimization`.
 *
 * @param expression expression under test
 * @param inputRow   row handed to each check; defaults to `EmptyRow`
 */
private def checkNaN(
    expression: Expression, inputRow: InternalRow = EmptyRow): Unit = {
  // checkNaNWithoutCodegen takes (expression, expected, inputRow). Passing only two positional
  // arguments would bind the row to `expected` and silently evaluate against EmptyRow, so the
  // expected value is supplied explicitly to keep `inputRow` in its own slot.
  checkNaNWithoutCodegen(expression, Double.NaN, inputRow)
  checkNaNWithGeneratedProjection(expression, inputRow)
  checkNaNWithOptimization(expression, inputRow)
}

/**
 * Evaluates `expression` against `inputRow` through `evaluate` and fails the test unless the
 * result, cast to Double, is NaN. An exception thrown during evaluation is reported as a
 * test failure.
 *
 * @param expression expression under test
 * @param expected   not read by this method (the expected value is always NaN); kept so
 *                   existing call sites keep compiling
 * @param inputRow   row the expression is evaluated against; defaults to `EmptyRow`
 */
private def checkNaNWithoutCodegen(
    expression: Expression,
    expected: Any,
    inputRow: InternalRow = EmptyRow): Unit = {
  val actual = try evaluate(expression, inputRow) catch {
    case e: Exception => fail(s"Exception evaluating $expression", e)
  }
  if (!actual.asInstanceOf[Double].isNaN) {
    // Mention the input row in the message only when one was actually supplied.
    val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
    fail(s"Incorrect evaluation (codegen off): $expression, " +
      s"actual: $actual, " +
      s"expected: NaN$input")
  }
}


/**
 * Builds a projection for `expression` with `GenerateMutableProjection`, applies it to
 * `inputRow`, and fails the test unless column 0 of the result, cast to Double, is NaN.
 *
 * If building the projection throws, the test fails with a message that contains the code
 * generated for the expression together with the throwable.
 *
 * @param expression expression under test
 * @param inputRow   row the projection is applied to; defaults to `EmptyRow`
 */
private def checkNaNWithGeneratedProjection(
    expression: Expression,
    inputRow: InternalRow = EmptyRow): Unit = {

  val plan = try {
    GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)()
  } catch {
    case e: Throwable =>
      // Regenerate the code for the expression so the failure message can show it.
      val ctx = GenerateProjection.newCodeGenContext()
      val evaluated = expression.gen(ctx)
      fail(
        s"""
          |Code generation of $expression failed:
          |${evaluated.code}
          |$e
        """.stripMargin)
  }

  val actual = plan(inputRow).apply(0)
  if (!actual.asInstanceOf[Double].isNaN) {
    // Mention the input row in the message only when one was actually supplied.
    val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
    fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: NaN$input")
  }
}

/**
 * Wraps `expression` in a `Project` over `OneRowRelation`, runs `DefaultOptimizer` on that
 * plan, and checks with `checkNaNWithoutCodegen` that the first expression of the optimized
 * plan evaluates to NaN.
 *
 * @param expression expression under test
 * @param inputRow   row the optimized expression is evaluated against; defaults to `EmptyRow`
 */
private def checkNaNWithOptimization(
    expression: Expression,
    inputRow: InternalRow = EmptyRow): Unit = {
  val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation)
  val optimizedPlan = DefaultOptimizer.execute(plan)
  // checkNaNWithoutCodegen takes (expression, expected, inputRow). The expected value is passed
  // explicitly so the row is not bound to `expected` and replaced by the EmptyRow default.
  checkNaNWithoutCodegen(optimizedPlan.expressions.head, Double.NaN, inputRow)
}

test("e") {
testLeaf(EulerNumber, math.E)
}
Expand All @@ -126,7 +187,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {

test("asin") {
testUnary(Asin, math.asin, (-10 to 10).map(_ * 0.1))
testUnary(Asin, math.asin, (11 to 20).map(_ * 0.1), expectNull = true)
testUnary(Asin, math.asin, (11 to 20).map(_ * 0.1), expectNaN = true)
}

test("sinh") {
Expand All @@ -139,7 +200,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {

test("acos") {
testUnary(Acos, math.acos, (-10 to 10).map(_ * 0.1))
testUnary(Acos, math.acos, (11 to 20).map(_ * 0.1), expectNull = true)
testUnary(Acos, math.acos, (11 to 20).map(_ * 0.1), expectNaN = true)
}

test("cosh") {
Expand Down Expand Up @@ -205,17 +266,17 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {

test("log") {
testUnary(Log, math.log, (0 to 20).map(_ * 0.1))
testUnary(Log, math.log, (-5 to -1).map(_ * 0.1), expectNull = true)
testUnary(Log, math.log, (-5 to -1).map(_ * 0.1), expectNaN = true)
}

test("log10") {
testUnary(Log10, math.log10, (0 to 20).map(_ * 0.1))
testUnary(Log10, math.log10, (-5 to -1).map(_ * 0.1), expectNull = true)
testUnary(Log10, math.log10, (-5 to -1).map(_ * 0.1), expectNaN = true)
}

test("log1p") {
testUnary(Log1p, math.log1p, (-1 to 20).map(_ * 0.1))
testUnary(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), expectNull = true)
testUnary(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), expectNaN = true)
}

test("bin") {
Expand All @@ -238,21 +299,21 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("log2") {
def f: (Double) => Double = (x: Double) => math.log(x) / math.log(2)
testUnary(Log2, f, (0 to 20).map(_ * 0.1))
testUnary(Log2, f, (-5 to -1).map(_ * 1.0), expectNull = true)
testUnary(Log2, f, (-5 to -1).map(_ * 1.0), expectNaN = true)
}

test("sqrt") {
testUnary(Sqrt, math.sqrt, (0 to 20).map(_ * 0.1))
testUnary(Sqrt, math.sqrt, (-5 to -1).map(_ * 1.0), expectNull = true)
testUnary(Sqrt, math.sqrt, (-5 to -1).map(_ * 1.0), expectNaN = true)

checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null))
checkEvaluation(Sqrt(Literal(-1.0)), null, EmptyRow)
checkEvaluation(Sqrt(Literal(-1.5)), null, EmptyRow)
checkNaN(Sqrt(Literal(-1.0)), EmptyRow)
checkNaN(Sqrt(Literal(-1.5)), EmptyRow)
}

test("pow") {
testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0)))
testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true)
testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNaN = true)
}

test("shift left") {
Expand Down Expand Up @@ -345,14 +406,12 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
null,
create_row(null))

// negative input should yield null output
checkEvaluation(
// negative input should yield NaN output
checkNaN(
Logarithm(Literal(-1.0), Literal(1.0)),
null,
create_row(null))
checkEvaluation(
checkNaN(
Logarithm(Literal(1.0), Literal(-1.0)),
null,
create_row(null))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,6 @@ class MathExpressionsSuite extends QueryTest {
nnDoubleData.select(c('b)),
(1 to 9).map(n => Row(f(n * -0.1))) :+ Row(Double.NegativeInfinity)
)
} else {
checkAnswer(
nnDoubleData.select(c('b)),
(1 to 10).map(n => Row(null))
)
}

checkAnswer(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,10 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
// Spark SQL use Long for TimestampType, lose the precision under 1us
"timestamp_1",
"timestamp_2",
"timestamp_udf"
"timestamp_udf",

// Unlike Hive, we support log bases in (0, 1.0], so this test is disabled
"udf7"
)

/**
Expand Down Expand Up @@ -816,19 +819,18 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"udf2",
"udf5",
"udf6",
// "udf7", turn this on after we figure out null vs nan vs infinity
"udf8",
"udf9",
"udf_10_trims",
"udf_E",
"udf_PI",
"udf_abs",
// "udf_acos", turn this on after we figure out null vs nan vs infinity
"udf_acos",
"udf_add",
"udf_array",
"udf_array_contains",
"udf_ascii",
// "udf_asin", turn this on after we figure out null vs nan vs infinity
"udf_asin",
"udf_atan",
"udf_avg",
"udf_bigint",
Expand Down Expand Up @@ -915,7 +917,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"udf_regexp_replace",
"udf_repeat",
"udf_rlike",
// "udf_round", turn this on after we figure out null vs nan vs infinity
"udf_round",
"udf_round_3",
"udf_rpad",
"udf_rtrim",
Expand Down

0 comments on commit 188be51

Please sign in to comment.