Skip to content

Commit

Permalink
Implements by BigDecimal.RoundingMode.DOWN
Browse files Browse the repository at this point in the history
  • Loading branch information
wangyum committed Sep 21, 2018
1 parent 87cea0b commit 479b31f
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 132 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{MathUtils, NumberConverter}
import org.apache.spark.sql.catalyst.util.NumberConverter
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

Expand Down Expand Up @@ -1264,61 +1264,8 @@ case class BRound(child: Expression, scale: Expression)
1234567891
""")
// scalastyle:on line.size.limit
case class Truncate(number: Expression, scale: Expression)
extends BinaryExpression with ImplicitCastInputTypes {

def this(number: Expression) = this(number, Literal(0))

override def left: Expression = number
override def right: Expression = scale

override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(DoubleType, FloatType, DecimalType), IntegerType)

override def checkInputDataTypes(): TypeCheckResult = {
super.checkInputDataTypes() match {
case TypeCheckSuccess =>
if (scale.foldable) {
TypeCheckSuccess
} else {
TypeCheckFailure("Only foldable Expression is allowed for scale arguments")
}
case f => f
}
}

override def dataType: DataType = left.dataType
override def nullable: Boolean = true
override def prettyName: String = "truncate"

private lazy val scaleV: Any = scale.eval(EmptyRow)
private lazy val _scale: Int = scaleV.asInstanceOf[Int]

protected override def nullSafeEval(input1: Any, input2: Any): Any = {
number.dataType match {
case DoubleType => MathUtils.trunc(input1.asInstanceOf[Double], _scale)
case FloatType => MathUtils.trunc(input1.asInstanceOf[Float], _scale)
case DecimalType.Fixed(_, _) => MathUtils.trunc(input1.asInstanceOf[Decimal], _scale)
}
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val mu = MathUtils.getClass.getName.stripSuffix("$")

val javaType = CodeGenerator.javaType(dataType)
if (scaleV == null) { // if scale is null, no need to eval its child at all
ev.copy(code = code"""
boolean ${ev.isNull} = true;
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""")
} else {
val d = number.genCode(ctx)
ev.copy(code = code"""
${d.code}
boolean ${ev.isNull} = ${d.isNull};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = $mu.trunc(${d.value}, ${_scale});
}""")
}
}
case class Truncate(child: Expression, scale: Expression)
extends RoundBase(child, scale, BigDecimal.RoundingMode.DOWN, "ROUND_DOWN")
with Serializable with ImplicitCastInputTypes {
def this(child: Expression) = this(child, Literal(0))
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ object Decimal {
val ROUND_HALF_EVEN = BigDecimal.RoundingMode.HALF_EVEN
val ROUND_CEILING = BigDecimal.RoundingMode.CEILING
val ROUND_FLOOR = BigDecimal.RoundingMode.FLOOR
val ROUND_DOWN = BigDecimal.RoundingMode.DOWN

/** Maximum number of decimal digits an Int can represent */
val MAX_INT_DIGITS = 9
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,7 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
testFloat(0.123F, 0, 0F)
testFloat(null, null, null)
testFloat(null, 0, null)
testFloat(1D, null, null)
testFloat(1F, null, null)

testDecimal(Decimal(1234567891.1234567891), 4, Decimal(1234567891.1234))
testDecimal(Decimal(1234567891.1234567891), -4, Decimal(1234560000))
Expand All @@ -694,6 +694,6 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
testDecimal(Decimal(0.123), 0, Decimal(0))
testDecimal(null, null, null)
testDecimal(null, 0, null)
testDecimal(1D, null, null)
testDecimal(Decimal(1), null, null)
}
}
3 changes: 2 additions & 1 deletion sql/core/src/test/resources/sql-tests/inputs/operators.sql
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ select pmod(-7, 2), pmod(0, 2), pmod(7, 0), pmod(7, null), pmod(null, 2), pmod(n
select pmod(cast(3.13 as decimal), cast(0 as decimal)), pmod(cast(2 as smallint), cast(0 as smallint));

-- truncate
select truncate(1234567891.1234567891, -4), truncate(1234567891.1234567891, 0), truncate(1234567891.1234567891, 4);
select truncate(cast(1234567891.1234567891 as double), -4), truncate(cast(1234567891.1234567891 as double), 0), truncate(cast(1234567891.1234567891 as double), 4);
select truncate(cast(1234567891.1234567891 as float), -4), truncate(cast(1234567891.1234567891 as float), 0), truncate(cast(1234567891.1234567891 as float), 4);
select truncate(cast(1234567891.1234567891 as decimal), -4), truncate(cast(1234567891.1234567891 as decimal), 0), truncate(cast(1234567891.1234567891 as decimal), 4);
select truncate(cast(1234567891.1234567891 as long), -4), truncate(cast(1234567891.1234567891 as long), 0), truncate(cast(1234567891.1234567891 as long), 4);
select truncate(cast(1234567891.1234567891 as long), 9.03);
Expand Down
50 changes: 49 additions & 1 deletion sql/core/src/test/resources/sql-tests/results/operators.sql.out
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 55
-- Number of queries: 61


-- !query 0
Expand Down Expand Up @@ -452,3 +452,51 @@ select pmod(cast(3.13 as decimal), cast(0 as decimal)), pmod(cast(2 as smallint)
struct<pmod(CAST(3.13 AS DECIMAL(10,0)), CAST(0 AS DECIMAL(10,0))):decimal(10,0),pmod(CAST(2 AS SMALLINT), CAST(0 AS SMALLINT)):smallint>
-- !query 54 output
NULL NULL


-- !query 55
select truncate(cast(1234567891.1234567891 as double), -4), truncate(cast(1234567891.1234567891 as double), 0), truncate(cast(1234567891.1234567891 as double), 4)
-- !query 55 schema
struct<truncate(CAST(1234567891.1234567891 AS DOUBLE), -4):double,truncate(CAST(1234567891.1234567891 AS DOUBLE), 0):double,truncate(CAST(1234567891.1234567891 AS DOUBLE), 4):double>
-- !query 55 output
1.23456E9 1.234567891E9 1.2345678911234E9


-- !query 56
select truncate(cast(1234567891.1234567891 as float), -4), truncate(cast(1234567891.1234567891 as float), 0), truncate(cast(1234567891.1234567891 as float), 4)
-- !query 56 schema
struct<truncate(CAST(1234567891.1234567891 AS FLOAT), -4):float,truncate(CAST(1234567891.1234567891 AS FLOAT), 0):float,truncate(CAST(1234567891.1234567891 AS FLOAT), 4):float>
-- !query 56 output
1.23456E9 1.23456794E9 1.23456794E9


-- !query 57
select truncate(cast(1234567891.1234567891 as decimal), -4), truncate(cast(1234567891.1234567891 as decimal), 0), truncate(cast(1234567891.1234567891 as decimal), 4)
-- !query 57 schema
struct<truncate(CAST(1234567891.1234567891 AS DECIMAL(10,0)), -4):decimal(10,-4),truncate(CAST(1234567891.1234567891 AS DECIMAL(10,0)), 0):decimal(10,0),truncate(CAST(1234567891.1234567891 AS DECIMAL(10,0)), 4):decimal(10,0)>
-- !query 57 output
1234560000 1234567891 1234567891


-- !query 58
select truncate(cast(1234567891.1234567891 as long), -4), truncate(cast(1234567891.1234567891 as long), 0), truncate(cast(1234567891.1234567891 as long), 4)
-- !query 58 schema
struct<truncate(CAST(1234567891.1234567891 AS BIGINT), -4):bigint,truncate(CAST(1234567891.1234567891 AS BIGINT), 0):bigint,truncate(CAST(1234567891.1234567891 AS BIGINT), 4):bigint>
-- !query 58 output
1234560000 1234567891 1234567891


-- !query 59
select truncate(cast(1234567891.1234567891 as long), 9.03)
-- !query 59 schema
struct<truncate(CAST(1234567891.1234567891 AS BIGINT), CAST(9.03 AS INT)):bigint>
-- !query 59 output
1234567891


-- !query 60
select truncate(cast(1234567891.1234567891 as double)), truncate(cast(1234567891.1234567891 as float)), truncate(cast(1234567891.1234567891 as decimal))
-- !query 60 schema
struct<truncate(CAST(1234567891.1234567891 AS DOUBLE), 0):double,truncate(CAST(1234567891.1234567891 AS FLOAT), 0):float,truncate(CAST(1234567891.1234567891 AS DECIMAL(10,0)), 0):decimal(10,0)>
-- !query 60 output
1.234567891E9 1.23456794E9 1234567891

0 comments on commit 479b31f

Please sign in to comment.