Skip to content

Commit

Permalink
Add float type.
Browse files Browse the repository at this point in the history
  • Loading branch information
wangyum committed Sep 19, 2018
1 parent bf7103a commit c715694
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1267,43 +1267,58 @@ case class BRound(child: Expression, scale: Expression)
case class Truncate(number: Expression, scale: Expression)
extends BinaryExpression with ImplicitCastInputTypes {

def this(number: Expression) = this(number, Literal(0))

override def left: Expression = number
override def right: Expression = scale

override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(DoubleType, DecimalType), IntegerType)
Seq(TypeCollection(DoubleType, FloatType, DecimalType), IntegerType)

override def checkInputDataTypes(): TypeCheckResult = {
super.checkInputDataTypes() match {
case TypeCheckSuccess =>
if (scale.foldable) {
TypeCheckSuccess
} else {
TypeCheckFailure("Only foldable Expression is allowed for scale arguments")
}
case f => f
}
}

override def dataType: DataType = left.dataType
override def nullable: Boolean = true
override def prettyName: String = "truncate"

private lazy val foldableTruncScale: Int = scale.eval().asInstanceOf[Int]
private lazy val scaleV: Any = scale.eval(EmptyRow)
private lazy val _scale: Int = scaleV.asInstanceOf[Int]

protected override def nullSafeEval(input1: Any, input2: Any): Any = {
val truncScale = if (scale.foldable) {
foldableTruncScale
} else {
scale.eval().asInstanceOf[Int]
}
number.dataType match {
case DoubleType => MathUtils.trunc(input1.asInstanceOf[Double], truncScale)
case DecimalType.Fixed(_, _) =>
MathUtils.trunc(input1.asInstanceOf[Decimal].toJavaBigDecimal, truncScale)
case DoubleType => MathUtils.trunc(input1.asInstanceOf[Double], _scale)
case FloatType => MathUtils.trunc(input1.asInstanceOf[Float], _scale)
case DecimalType.Fixed(_, _) => MathUtils.trunc(input1.asInstanceOf[Decimal], _scale)
}
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val mu = MathUtils.getClass.getName.stripSuffix("$")
if (scale.foldable) {

val javaType = CodeGenerator.javaType(dataType)
if (scaleV == null) { // if scale is null, no need to eval its child at all
ev.copy(code = code"""
boolean ${ev.isNull} = true;
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""")
} else {
val d = number.genCode(ctx)
ev.copy(code = code"""
${d.code}
boolean ${ev.isNull} = ${d.isNull};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = $mu.trunc(${d.value}, $foldableTruncScale);
${ev.value} = $mu.trunc(${d.value}, ${_scale});
}""")
} else {
nullSafeCodeGen(ctx, ev, (doubleVal, truncParam) =>
s"${ev.value} = $mu.trunc($doubleVal, $truncParam);")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,35 @@ package org.apache.spark.sql.catalyst.util

import java.math.{BigDecimal => JBigDecimal}

import org.apache.spark.sql.types.Decimal

object MathUtils {

/**
* Returns double type input truncated to scale decimal places.
*/
def trunc(input: Double, scale: Int): Double = {
trunc(JBigDecimal.valueOf(input), scale).doubleValue()
trunc(JBigDecimal.valueOf(input), scale).toDouble
}

/**
* Returns float type input truncated to scale decimal places.
*/
def trunc(input: Float, scale: Int): Float = {
trunc(JBigDecimal.valueOf(input), scale).toFloat
}

/**
* Returns decimal type input truncated to scale decimal places.
*/
def trunc(input: Decimal, scale: Int): Decimal = {
trunc(input.toJavaBigDecimal, scale)
}

/**
* Returns BigDecimal type input truncated to scale decimal places.
*/
def trunc(input: JBigDecimal, scale: Int): JBigDecimal = {
def trunc(input: JBigDecimal, scale: Int): Decimal = {
// Copy from (https://github.com/apache/hive/blob/release-2.3.0-rc0
// /ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFTrunc.java#L471-L487)
val pow = if (scale >= 0) {
Expand All @@ -39,7 +55,7 @@ object MathUtils {
JBigDecimal.valueOf(Math.pow(10, Math.abs(scale)))
}

if (scale > 0) {
val truncatedValue = if (scale > 0) {
val longValue = input.multiply(pow).longValue()
JBigDecimal.valueOf(longValue).divide(pow)
} else if (scale == 0) {
Expand All @@ -48,5 +64,7 @@ object MathUtils {
val longValue = input.divide(pow).longValue()
JBigDecimal.valueOf(longValue).multiply(pow)
}

Decimal(truncatedValue)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -646,29 +646,54 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("Truncate number") {
def testTruncate(input: Double, fmt: Int, expected: Double): Unit = {
assert(Truncate(Literal.create(123.123, DoubleType),
NonFoldableLiteral.create(1, IntegerType)).checkInputDataTypes().isFailure)
assert(Truncate(Literal.create(123.123, DoubleType),
Literal.create(1, IntegerType)).checkInputDataTypes().isSuccess)

def testDouble(input: Any, scale: Any, expected: Any): Unit = {
checkEvaluation(Truncate(Literal.create(input, DoubleType),
Literal.create(fmt, IntegerType)),
Literal.create(scale, IntegerType)),
expected)
checkEvaluation(Truncate(Literal.create(input, DoubleType),
NonFoldableLiteral.create(fmt, IntegerType)),
}

def testFloat(input: Any, scale: Any, expected: Any): Unit = {
checkEvaluation(Truncate(Literal.create(input, FloatType),
Literal.create(scale, IntegerType)),
expected)
}

def testDecimal(input: Any, scale: Any, expected: Any): Unit = {
checkEvaluation(Truncate(Literal.create(input, DecimalType.DoubleDecimal),
Literal.create(scale, IntegerType)),
expected)
}

testTruncate(1234567891.1234567891, 4, 1234567891.1234)
testTruncate(1234567891.1234567891, -4, 1234560000)
testTruncate(1234567891.1234567891, 0, 1234567891)
testTruncate(0.123, -1, 0)
testTruncate(0.123, 0, 0)

checkEvaluation(Truncate(Literal.create(1D, DoubleType),
NonFoldableLiteral.create(null, IntegerType)),
null)
checkEvaluation(Truncate(Literal.create(null, DoubleType),
NonFoldableLiteral.create(1, IntegerType)),
null)
checkEvaluation(Truncate(Literal.create(null, DoubleType),
NonFoldableLiteral.create(null, IntegerType)),
null)
testDouble(1234567891.1234567891D, 4, 1234567891.1234D)
testDouble(1234567891.1234567891D, -4, 1234560000D)
testDouble(1234567891.1234567891D, 0, 1234567891D)
testDouble(0.123D, -1, 0D)
testDouble(0.123D, 0, 0D)
testDouble(null, null, null)
testDouble(null, 0, null)
testDouble(1D, null, null)

testFloat(1234567891.1234567891F, 4, 1234567891.1234F)
testFloat(1234567891.1234567891F, -4, 1234560000F)
testFloat(1234567891.1234567891F, 0, 1234567891F)
testFloat(0.123F, -1, 0F)
testFloat(0.123F, 0, 0F)
testFloat(null, null, null)
testFloat(null, 0, null)
testFloat(1D, null, null)

testDecimal(Decimal(1234567891.1234567891), 4, Decimal(1234567891.1234))
testDecimal(Decimal(1234567891.1234567891), -4, Decimal(1234560000))
testDecimal(Decimal(1234567891.1234567891), 0, Decimal(1234567891))
testDecimal(Decimal(0.123), -1, Decimal(0))
testDecimal(Decimal(0.123), 0, Decimal(0))
testDecimal(null, null, null)
testDecimal(null, 0, null)
testDecimal(1D, null, null)
}
}
22 changes: 12 additions & 10 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2215,21 +2215,23 @@ object functions {
def radians(columnName: String): Column = radians(Column(columnName))

/**
* Returns number truncated to the unit specified by the scale.
* Returns the value of the column `e` truncated to 0 places.
*
* For example, `truncate(1234567891.1234567891, 4)` returns 1234567891.1234
*
* @param number The number to be truncated
* @param scale: A scale used to truncate number
* @group math_funcs
* @since 2.4.0
*/
def truncate(e: Column): Column = truncate(e, 0)

/**
* Returns the value of column `e` truncated to the unit specified by the scale.
* If scale is omitted, then the value of column `e` is truncated to 0 places.
* Scale can be negative to truncate (make zero) scale digits left of the decimal point.
*
* @return The number truncated to scale decimal places.
* If scale is omitted, then number is truncated to 0 places.
* scale can be negative to truncate (make zero) scale digits left of the decimal point.
* @group math_funcs
* @since 2.4.0
*/
def truncate(number: Column, scale: Int): Column = withExpr {
Truncate(number.expr, Literal(scale))
def truncate(e: Column, scale: Int): Column = withExpr {
Truncate(e.expr, Literal(scale))
}

//////////////////////////////////////////////////////////////////////////////////////////////
Expand Down
3 changes: 2 additions & 1 deletion sql/core/src/test/resources/sql-tests/inputs/operators.sql
Original file line number Diff line number Diff line change
Expand Up @@ -101,4 +101,5 @@ select pmod(cast(3.13 as decimal), cast(0 as decimal)), pmod(cast(2 as smallint)
select truncate(1234567891.1234567891, -4), truncate(1234567891.1234567891, 0), truncate(1234567891.1234567891, 4);
select truncate(cast(1234567891.1234567891 as decimal), -4), truncate(cast(1234567891.1234567891 as decimal), 0), truncate(cast(1234567891.1234567891 as decimal), 4);
select truncate(cast(1234567891.1234567891 as long), -4), truncate(cast(1234567891.1234567891 as long), 0), truncate(cast(1234567891.1234567891 as long), 4);
select truncate(cast(1234567891.1234567891 as long), 9.03)
select truncate(cast(1234567891.1234567891 as long), 9.03);
select truncate(cast(1234567891.1234567891 as double)), truncate(cast(1234567891.1234567891 as float)), truncate(cast(1234567891.1234567891 as decimal));
10 changes: 9 additions & 1 deletion sql/core/src/test/resources/sql-tests/results/operators.sql.out
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 63
-- Number of queries: 64


-- !query 0
Expand Down Expand Up @@ -516,3 +516,11 @@ select truncate(cast(1234567891.1234567891 as long), 9.03)
struct<truncate(CAST(CAST(1234567891.1234567891 AS BIGINT) AS DOUBLE), CAST(9.03 AS INT)):double>
-- !query 62 output
1.234567891E9


-- !query 63
select truncate(cast(1234567891.1234567891 as double)), truncate(cast(1234567891.1234567891 as float)), truncate(cast(1234567891.1234567891 as decimal))
-- !query 63 schema
struct<truncate(CAST(1234567891.1234567891 AS DOUBLE), 0):double,truncate(CAST(1234567891.1234567891 AS FLOAT), 0):float,truncate(CAST(1234567891.1234567891 AS DECIMAL(10,0)), 0):decimal(10,0)>
-- !query 63 output
1.234567891E9 1.23456794E9 1234567891

0 comments on commit c715694

Please sign in to comment.