From c7156943a2a32ba57e67aa6d8fa7035a09847e07 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 19 Sep 2018 18:13:22 +0800 Subject: [PATCH] Add float type. --- .../expressions/mathExpressions.scala | 47 +++++++++----- .../spark/sql/catalyst/util/MathUtils.scala | 24 ++++++- .../expressions/MathExpressionsSuite.scala | 63 +++++++++++++------ .../org/apache/spark/sql/functions.scala | 22 ++++--- .../resources/sql-tests/inputs/operators.sql | 3 +- .../sql-tests/results/operators.sql.out | 10 ++- 6 files changed, 119 insertions(+), 50 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index a63f7abe095c3..32fa036747e9d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1267,43 +1267,58 @@ case class BRound(child: Expression, scale: Expression) case class Truncate(number: Expression, scale: Expression) extends BinaryExpression with ImplicitCastInputTypes { + def this(number: Expression) = this(number, Literal(0)) + override def left: Expression = number override def right: Expression = scale override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(DoubleType, DecimalType), IntegerType) + Seq(TypeCollection(DoubleType, FloatType, DecimalType), IntegerType) + + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case TypeCheckSuccess => + if (scale.foldable) { + TypeCheckSuccess + } else { + TypeCheckFailure("Only foldable Expression is allowed for scale arguments") + } + case f => f + } + } override def dataType: DataType = left.dataType + override def nullable: Boolean = true + override def prettyName: String = "truncate" - private lazy val foldableTruncScale: Int = scale.eval().asInstanceOf[Int] + private lazy val scaleV: Any = scale.eval(EmptyRow) + private lazy val _scale: Int = scaleV.asInstanceOf[Int] protected override def nullSafeEval(input1: Any, input2: Any): Any = { - val truncScale = if (scale.foldable) { - foldableTruncScale - } else { - scale.eval().asInstanceOf[Int] - } number.dataType match { - case DoubleType => MathUtils.trunc(input1.asInstanceOf[Double], truncScale) - case DecimalType.Fixed(_, _) => - MathUtils.trunc(input1.asInstanceOf[Decimal].toJavaBigDecimal, truncScale) + case DoubleType => MathUtils.trunc(input1.asInstanceOf[Double], _scale) + case FloatType => MathUtils.trunc(input1.asInstanceOf[Float], _scale) + case DecimalType.Fixed(_, _) => MathUtils.trunc(input1.asInstanceOf[Decimal], _scale) } } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val mu = MathUtils.getClass.getName.stripSuffix("$") - if (scale.foldable) { + + val javaType = CodeGenerator.javaType(dataType) + if (scaleV == null) { // if scale is null, no need to eval its child at all + ev.copy(code = code""" + boolean ${ev.isNull} = true; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""") + } else { val d = number.genCode(ctx) ev.copy(code = code""" ${d.code} boolean ${ev.isNull} = ${d.isNull}; - ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { - ${ev.value} = $mu.trunc(${d.value}, $foldableTruncScale); + ${ev.value} = $mu.trunc(${d.value}, ${_scale}); }""") - } else { - nullSafeCodeGen(ctx, ev, (doubleVal, truncParam) => - s"${ev.value} = $mu.trunc($doubleVal, $truncParam);") } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala index 912d99d957f8b..9f00af128666f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala @@ -18,19 +18,35 @@ package org.apache.spark.sql.catalyst.util import java.math.{BigDecimal => JBigDecimal} +import org.apache.spark.sql.types.Decimal + object MathUtils { /** * Returns double type input truncated to scale decimal places. */ def trunc(input: Double, scale: Int): Double = { - trunc(JBigDecimal.valueOf(input), scale).doubleValue() + trunc(JBigDecimal.valueOf(input), scale).toDouble + } + + /** + * Returns float type input truncated to scale decimal places. + */ + def trunc(input: Float, scale: Int): Float = { + trunc(JBigDecimal.valueOf(input), scale).toFloat + } + + /** + * Returns decimal type input truncated to scale decimal places. + */ + def trunc(input: Decimal, scale: Int): Decimal = { + trunc(input.toJavaBigDecimal, scale) } /** * Returns BigDecimal type input truncated to scale decimal places. */ - def trunc(input: JBigDecimal, scale: Int): JBigDecimal = { + def trunc(input: JBigDecimal, scale: Int): Decimal = { // Copy from (https://github.com/apache/hive/blob/release-2.3.0-rc0 // /ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFTrunc.java#L471-L487) val pow = if (scale >= 0) { @@ -39,7 +55,7 @@ object MathUtils { JBigDecimal.valueOf(Math.pow(10, Math.abs(scale))) } - if (scale > 0) { + val truncatedValue = if (scale > 0) { val longValue = input.multiply(pow).longValue() JBigDecimal.valueOf(longValue).divide(pow) } else if (scale == 0) { @@ -48,5 +64,7 @@ object MathUtils { val longValue = input.divide(pow).longValue() JBigDecimal.valueOf(longValue).multiply(pow) } + + Decimal(truncatedValue) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index d2f2bb8e65c40..320a2e821d8ae 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -646,29 +646,54 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("Truncate number") { - def testTruncate(input: Double, fmt: Int, expected: Double): Unit = { + assert(Truncate(Literal.create(123.123, DoubleType), + NonFoldableLiteral.create(1, IntegerType)).checkInputDataTypes().isFailure) + assert(Truncate(Literal.create(123.123, DoubleType), + Literal.create(1, IntegerType)).checkInputDataTypes().isSuccess) + + def testDouble(input: Any, scale: Any, expected: Any): Unit = { checkEvaluation(Truncate(Literal.create(input, DoubleType), - Literal.create(fmt, IntegerType)), + Literal.create(scale, IntegerType)), expected) - checkEvaluation(Truncate(Literal.create(input, DoubleType), - NonFoldableLiteral.create(fmt, IntegerType)), + } + + def testFloat(input: Any, scale: Any, expected: Any): Unit = { + checkEvaluation(Truncate(Literal.create(input, FloatType), + Literal.create(scale, IntegerType)), + expected) + } + + def testDecimal(input: Any, scale: Any, expected: Any): Unit = { + checkEvaluation(Truncate(Literal.create(input, DecimalType.DoubleDecimal), + Literal.create(scale, IntegerType)), expected) } - testTruncate(1234567891.1234567891, 4, 1234567891.1234) - testTruncate(1234567891.1234567891, -4, 1234560000) - testTruncate(1234567891.1234567891, 0, 1234567891) - testTruncate(0.123, -1, 0) - testTruncate(0.123, 0, 0) - - checkEvaluation(Truncate(Literal.create(1D, DoubleType), - NonFoldableLiteral.create(null, IntegerType)), - null) - checkEvaluation(Truncate(Literal.create(null, DoubleType), - NonFoldableLiteral.create(1, IntegerType)), - null) - checkEvaluation(Truncate(Literal.create(null, DoubleType), - NonFoldableLiteral.create(null, IntegerType)), - null) + testDouble(1234567891.1234567891D, 4, 1234567891.1234D) + testDouble(1234567891.1234567891D, -4, 1234560000D) + testDouble(1234567891.1234567891D, 0, 1234567891D) + testDouble(0.123D, -1, 0D) + testDouble(0.123D, 0, 0D) + testDouble(null, null, null) + testDouble(null, 0, null) + testDouble(1D, null, null) + + testFloat(1234567891.1234567891F, 4, 1234567891.1234F) + testFloat(1234567891.1234567891F, -4, 1234560000F) + testFloat(1234567891.1234567891F, 0, 1234567891F) + testFloat(0.123F, -1, 0F) + testFloat(0.123F, 0, 0F) + testFloat(null, null, null) + testFloat(null, 0, null) + testFloat(1D, null, null) + + testDecimal(Decimal(1234567891.1234567891), 4, Decimal(1234567891.1234)) + testDecimal(Decimal(1234567891.1234567891), -4, Decimal(1234560000)) + testDecimal(Decimal(1234567891.1234567891), 0, Decimal(1234567891)) + testDecimal(Decimal(0.123), -1, Decimal(0)) + testDecimal(Decimal(0.123), 0, Decimal(0)) + testDecimal(null, null, null) + testDecimal(null, 0, null) + testDecimal(1D, null, null) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 521d5710a9591..6e2df69d5d874 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2215,21 +2215,23 @@ object functions { def radians(columnName: String): Column = radians(Column(columnName)) /** - * Returns number truncated to the unit specified by the scale. + * Returns the value of the column `e` truncated to 0 places. * - * For example, `truncate(1234567891.1234567891, 4)` returns 1234567891.1234 - * - * @param number The number to be truncated - * @param scale: A scale used to truncate number + * @group math_funcs + * @since 2.4.0 + */ + def truncate(e: Column): Column = truncate(e, 0) + + /** + * Returns the value of column `e` truncated to the unit specified by the scale. + * If scale is omitted, then the value of column `e` is truncated to 0 places. + * Scale can be negative to truncate (make zero) scale digits left of the decimal point. * - * @return The number truncated to scale decimal places. - * If scale is omitted, then number is truncated to 0 places. - * scale can be negative to truncate (make zero) scale digits left of the decimal point. * @group math_funcs * @since 2.4.0 */ - def truncate(number: Column, scale: Int): Column = withExpr { - Truncate(number.expr, Literal(scale)) + def truncate(e: Column, scale: Int): Column = withExpr { + Truncate(e.expr, Literal(scale)) } ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql index 70b43c48c45d5..39efa356a5024 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -101,4 +101,5 @@ select pmod(cast(3.13 as decimal), cast(0 as decimal)), pmod(cast(2 as smallint) select truncate(1234567891.1234567891, -4), truncate(1234567891.1234567891, 0), truncate(1234567891.1234567891, 4); select truncate(cast(1234567891.1234567891 as decimal), -4), truncate(cast(1234567891.1234567891 as decimal), 0), truncate(cast(1234567891.1234567891 as decimal), 4); select truncate(cast(1234567891.1234567891 as long), -4), truncate(cast(1234567891.1234567891 as long), 0), truncate(cast(1234567891.1234567891 as long), 4); -select truncate(cast(1234567891.1234567891 as long), 9.03) +select truncate(cast(1234567891.1234567891 as long), 9.03); +select truncate(cast(1234567891.1234567891 as double)), truncate(cast(1234567891.1234567891 as float)), truncate(cast(1234567891.1234567891 as decimal)); diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index 3464d7e0f84a4..d275901933777 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 63 +-- Number of queries: 64 -- !query 0 @@ -516,3 +516,11 @@ select truncate(cast(1234567891.1234567891 as long), 9.03) struct -- !query 62 output 1.234567891E9 + + +-- !query 63 +select truncate(cast(1234567891.1234567891 as double)), truncate(cast(1234567891.1234567891 as float)), truncate(cast(1234567891.1234567891 as decimal)) +-- !query 63 schema +struct +-- !query 63 output +1.234567891E9 1.23456794E9 1234567891