From 479b31fa046e8402f4f93cdbad5fe93ef1ea570f Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sat, 22 Sep 2018 01:08:02 +0800 Subject: [PATCH] Implements by BigDecimal.RoundingMode.DOWN --- .../expressions/mathExpressions.scala | 63 ++--------------- .../spark/sql/catalyst/util/MathUtils.scala | 70 ------------------- .../org/apache/spark/sql/types/Decimal.scala | 1 + .../expressions/MathExpressionsSuite.scala | 4 +- .../resources/sql-tests/inputs/operators.sql | 3 +- .../sql-tests/results/operators.sql.out | 50 ++++++++++++- 6 files changed, 59 insertions(+), 132 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 32fa036747e9d..942a4d0f99a51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.util.{MathUtils, NumberConverter} +import org.apache.spark.sql.catalyst.util.NumberConverter import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -1264,61 +1264,8 @@ case class BRound(child: Expression, scale: Expression) 1234567891 """) // scalastyle:on line.size.limit -case class Truncate(number: Expression, scale: Expression) - extends BinaryExpression with ImplicitCastInputTypes { - - def this(number: Expression) = this(number, Literal(0)) - - override def left: Expression = number - override def right: Expression = scale - - override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(DoubleType, FloatType, DecimalType), IntegerType) - - override def checkInputDataTypes(): TypeCheckResult = { - super.checkInputDataTypes() match { - case TypeCheckSuccess => - if (scale.foldable) { - TypeCheckSuccess - } else { - TypeCheckFailure("Only foldable Expression is allowed for scale arguments") - } - case f => f - } - } - - override def dataType: DataType = left.dataType - override def nullable: Boolean = true - override def prettyName: String = "truncate" - - private lazy val scaleV: Any = scale.eval(EmptyRow) - private lazy val _scale: Int = scaleV.asInstanceOf[Int] - - protected override def nullSafeEval(input1: Any, input2: Any): Any = { - number.dataType match { - case DoubleType => MathUtils.trunc(input1.asInstanceOf[Double], _scale) - case FloatType => MathUtils.trunc(input1.asInstanceOf[Float], _scale) - case DecimalType.Fixed(_, _) => MathUtils.trunc(input1.asInstanceOf[Decimal], _scale) - } - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val mu = MathUtils.getClass.getName.stripSuffix("$") - - val javaType = CodeGenerator.javaType(dataType) - if (scaleV == null) { // if scale is null, no need to eval its child at all - ev.copy(code = code""" - boolean ${ev.isNull} = true; - $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""") - } else { - val d = number.genCode(ctx) - ev.copy(code = code""" - ${d.code} - boolean ${ev.isNull} = ${d.isNull}; - $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.value} = $mu.trunc(${d.value}, ${_scale}); - }""") - } - } +case class Truncate(child: Expression, scale: Expression) + extends RoundBase(child, scale, BigDecimal.RoundingMode.DOWN, "ROUND_DOWN") + with Serializable with ImplicitCastInputTypes { + def this(child: Expression) = this(child, Literal(0)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala deleted file mode 100644 index 9f00af128666f..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.catalyst.util - -import java.math.{BigDecimal => JBigDecimal} - -import org.apache.spark.sql.types.Decimal - -object MathUtils { - - /** - * Returns double type input truncated to scale decimal places. - */ - def trunc(input: Double, scale: Int): Double = { - trunc(JBigDecimal.valueOf(input), scale).toDouble - } - - /** - * Returns float type input truncated to scale decimal places. - */ - def trunc(input: Float, scale: Int): Float = { - trunc(JBigDecimal.valueOf(input), scale).toFloat - } - - /** - * Returns decimal type input truncated to scale decimal places. - */ - def trunc(input: Decimal, scale: Int): Decimal = { - trunc(input.toJavaBigDecimal, scale) - } - - /** - * Returns BigDecimal type input truncated to scale decimal places. - */ - def trunc(input: JBigDecimal, scale: Int): Decimal = { - // Copy from (https://github.com/apache/hive/blob/release-2.3.0-rc0 - // /ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFTrunc.java#L471-L487) - val pow = if (scale >= 0) { - JBigDecimal.valueOf(Math.pow(10, scale)) - } else { - JBigDecimal.valueOf(Math.pow(10, Math.abs(scale))) - } - - val truncatedValue = if (scale > 0) { - val longValue = input.multiply(pow).longValue() - JBigDecimal.valueOf(longValue).divide(pow) - } else if (scale == 0) { - JBigDecimal.valueOf(input.longValue()) - } else { - val longValue = input.divide(pow).longValue() - JBigDecimal.valueOf(longValue).multiply(pow) - } - - Decimal(truncatedValue) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 9eed2eb202045..b0ffb816817f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -413,6 +413,7 @@ object Decimal { val ROUND_HALF_EVEN = BigDecimal.RoundingMode.HALF_EVEN val ROUND_CEILING = BigDecimal.RoundingMode.CEILING val ROUND_FLOOR = BigDecimal.RoundingMode.FLOOR + val ROUND_DOWN = BigDecimal.RoundingMode.DOWN /** Maximum number of decimal digits an Int can represent */ val MAX_INT_DIGITS = 9 diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index 320a2e821d8ae..459075cf81939 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -685,7 +685,7 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { testFloat(0.123F, 0, 0F) testFloat(null, null, null) testFloat(null, 0, null) - testFloat(1D, null, null) + testFloat(1F, null, null) testDecimal(Decimal(1234567891.1234567891), 4, Decimal(1234567891.1234)) testDecimal(Decimal(1234567891.1234567891), -4, Decimal(1234560000)) @@ -694,6 +694,6 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { testDecimal(Decimal(0.123), 0, Decimal(0)) testDecimal(null, null, null) testDecimal(null, 0, null) - testDecimal(1D, null, null) + testDecimal(Decimal(1), null, null) } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql index d14a7fdc9d2ce..3c8f30eaa8b91 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -94,7 +94,8 @@ select pmod(-7, 2), pmod(0, 2), pmod(7, 0), pmod(7, null), pmod(null, 2), pmod(n select pmod(cast(3.13 as decimal), cast(0 as decimal)), pmod(cast(2 as smallint), cast(0 as smallint)); -- truncate -select truncate(1234567891.1234567891, -4), truncate(1234567891.1234567891, 0), truncate(1234567891.1234567891, 4); +select truncate(cast(1234567891.1234567891 as double), -4), truncate(cast(1234567891.1234567891 as double), 0), truncate(cast(1234567891.1234567891 as double), 4); +select truncate(cast(1234567891.1234567891 as float), -4), truncate(cast(1234567891.1234567891 as float), 0), truncate(cast(1234567891.1234567891 as float), 4); select truncate(cast(1234567891.1234567891 as decimal), -4), truncate(cast(1234567891.1234567891 as decimal), 0), truncate(cast(1234567891.1234567891 as decimal), 4); select truncate(cast(1234567891.1234567891 as long), -4), truncate(cast(1234567891.1234567891 as long), 0), truncate(cast(1234567891.1234567891 as long), 4); select truncate(cast(1234567891.1234567891 as long), 9.03); diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index fd1d0db9e3f78..9fff062490e1d 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 55 +-- Number of queries: 61 -- !query 0 @@ -452,3 +452,51 @@ select pmod(cast(3.13 as decimal), cast(0 as decimal)), pmod(cast(2 as smallint) struct -- !query 54 output NULL NULL + + +-- !query 55 +select truncate(cast(1234567891.1234567891 as double), -4), truncate(cast(1234567891.1234567891 as double), 0), truncate(cast(1234567891.1234567891 as double), 4) +-- !query 55 schema +struct +-- !query 55 output +1.23456E9 1.234567891E9 1.2345678911234E9 + + +-- !query 56 +select truncate(cast(1234567891.1234567891 as float), -4), truncate(cast(1234567891.1234567891 as float), 0), truncate(cast(1234567891.1234567891 as float), 4) +-- !query 56 schema +struct +-- !query 56 output +1.23456E9 1.23456794E9 1.23456794E9 + + +-- !query 57 +select truncate(cast(1234567891.1234567891 as decimal), -4), truncate(cast(1234567891.1234567891 as decimal), 0), truncate(cast(1234567891.1234567891 as decimal), 4) +-- !query 57 schema +struct +-- !query 57 output +1234560000 1234567891 1234567891 + + +-- !query 58 +select truncate(cast(1234567891.1234567891 as long), -4), truncate(cast(1234567891.1234567891 as long), 0), truncate(cast(1234567891.1234567891 as long), 4) +-- !query 58 schema +struct +-- !query 58 output +1234560000 1234567891 1234567891 + + +-- !query 59 +select truncate(cast(1234567891.1234567891 as long), 9.03) +-- !query 59 schema +struct +-- !query 59 output +1234567891 + + +-- !query 60 +select truncate(cast(1234567891.1234567891 as double)), truncate(cast(1234567891.1234567891 as float)), truncate(cast(1234567891.1234567891 as decimal)) +-- !query 60 schema +struct +-- !query 60 output +1.234567891E9 1.23456794E9 1234567891