From 3d92a48a54c5bf0222c032fc5c205ea1792630a2 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 14 Jun 2017 23:17:38 +0800 Subject: [PATCH] Refactor code. --- python/pyspark/sql/functions.py | 8 +- .../spark/sql/catalyst/expressions/misc.scala | 132 ++++++++++-------- .../expressions/MiscExpressionsSuite.scala | 1 - .../resources/sql-tests/inputs/datetime.sql | 10 +- .../sql-tests/results/datetime.sql.out | 32 ++++- .../sql-tests/results/operators.sql.out | 19 ++- 6 files changed, 128 insertions(+), 74 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 58daa6204e979..7bb46f44848da 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1028,17 +1028,17 @@ def to_timestamp(col, format=None): @since(1.5) -def trunc(date, format=0): +def trunc(date, format): """ Returns date truncated to the unit specified by the format or - number truncated by specified decimal places.. + number truncated by specified decimal places. :param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm' >>> df = spark.createDataFrame([('1997-02-28',)], ['d']) - >>> df.select(trunc(to_date(df.d), 'year').alias('year')).collect() + >>> df.select(trunc(df.d, 'year').alias('year')).collect() [Row(year=datetime.date(1997, 1, 1))] - >>> df.select(trunc(to_date(df.d), 'mon').alias('month')).collect() + >>> df.select(trunc(df.d, 'mon').alias('month')).collect() [Row(month=datetime.date(1997, 2, 1))] >>> df = spark.createDataFrame([(1234567891.1234567891,)], ['d']) >>> df.select(trunc(df.d, 4).alias('positive')).collect() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 88152bc2bb5b5..ae4ffec971473 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -141,45 +141,57 @@ case class Uuid() extends LeafExpression { // scalastyle:off line.size.limit @ExpressionDescription( usage = """ - _FUNC_(data, fmt) - Returns `data` truncated by the format model `fmt`. - If `data` is DateType, returns `data` with the time portion of the day truncated to the unit specified by the format model `fmt`. - If `data` is DoubleType, returns `data` truncated to `fmt` decimal places. + _FUNC_(data[, fmt]) - Returns `data` truncated by the format model `fmt`. + If `data` is DateType/StringType, returns `data` with the time portion of the day truncated to the unit specified by the format model `fmt`. + If `data` is DecimalType/DoubleType, returns `data` truncated to `fmt` decimal places. """, extended = """ Examples: > SELECT _FUNC_('2009-02-12', 'MM'); - 2009-02-01 + 2009-02-01. > SELECT _FUNC_('2015-10-27', 'YEAR'); 2015-01-01 + > SELECT _FUNC_('2015-10-27'); + 2015-10-01 > SELECT _FUNC_(1234567891.1234567891, 4); 1234567891.1234 > SELECT _FUNC_(1234567891.1234567891, -4); 1234560000 - """) + > SELECT _FUNC_(1234567891.1234567891); + 1234567891 + """) // scalastyle:on line.size.limit -case class Trunc(data: Expression, format: Expression = Literal(0)) +case class Trunc(data: Expression, format: Expression) extends BinaryExpression with ImplicitCastInputTypes { - def this(numeric: Expression) = { - this(numeric, Literal(0)) + def this(data: Expression) = { + this(data, Literal( + if (data.dataType.isInstanceOf[DecimalType] || data.dataType.isInstanceOf[DoubleType]) { + 0 + } else { + "MM" + })) } override def left: Expression = data override def right: Expression = format + val isTruncNumber = format.dataType.isInstanceOf[IntegerType] + override def dataType: DataType = data.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(DoubleType, DateType), TypeCollection(StringType, IntegerType)) + Seq(TypeCollection(DateType, StringType, DoubleType, DecimalType), + TypeCollection(StringType, IntegerType)) override def nullable: Boolean = true + override def prettyName: String = "trunc" - private lazy val truncFormat: Int = dataType match { - case doubleType: DoubleType => - format.eval().asInstanceOf[Int] - case dateType: DateType => - DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) + private lazy val truncFormat: Int = if (isTruncNumber) { + format.eval().asInstanceOf[Int] + } else { + DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) } override def eval(input: InternalRow): Any = { @@ -188,73 +200,70 @@ case class Trunc(data: Expression, format: Expression = Literal(0)) if (null == d || null == form) { null } else { - dataType match { - case doubleType: DoubleType => - val scale = if (format.foldable) { - truncFormat - } else { - format.eval().asInstanceOf[Int] - } - BigDecimalUtils.trunc(d.asInstanceOf[Double], scale) - case dateType: DateType => - val level = if (format.foldable) { - truncFormat - } else { - DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) - } - if (level == -1) { - // unknown format - null - } else { - DateTimeUtils.truncDate(d.asInstanceOf[Int], level) - } + if (isTruncNumber) { + val scale = if (format.foldable) truncFormat else format.eval().asInstanceOf[Int] + data.dataType match { + case DoubleType => BigDecimalUtils.trunc(d.asInstanceOf[Double], scale) + case DecimalType.Fixed(_, _) => + BigDecimalUtils.trunc(d.asInstanceOf[Decimal].toJavaBigDecimal, scale) + } + } else { + val level = if (format.foldable) { + truncFormat + } else { + DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) + } + if (level == -1) { + // unknown format + null + } else { + DateTimeUtils.truncDate(d.asInstanceOf[Int], level) + } } } - } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - dataType match { - case doubleType: DoubleType => - val bdu = BigDecimalUtils.getClass.getName.stripSuffix("$") + if (isTruncNumber) { + val bdu = BigDecimalUtils.getClass.getName.stripSuffix("$") - if (format.foldable) { - val d = data.genCode(ctx) - ev.copy(code = s""" + if (format.foldable) { + val d = data.genCode(ctx) + ev.copy(code = s""" ${d.code} boolean ${ev.isNull} = ${d.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $bdu.trunc(${d.value}, $truncFormat); }""") - } else { - nullSafeCodeGen(ctx, ev, (doubleVal, fmt) => { - s"${ev.value} = $bdu.trunc($doubleVal, $fmt);" - }) - } - case dateType: DateType => - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + } else { + nullSafeCodeGen(ctx, ev, (doubleVal, fmt) => { + s"${ev.value} = $bdu.trunc($doubleVal, $fmt);" + }) + } + } else { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - if (format.foldable) { - if (truncFormat == -1) { - ev.copy(code = s""" + if (format.foldable) { + if (truncFormat == -1) { + ev.copy(code = s""" boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""") - } else { - val d = data.genCode(ctx) - ev.copy(code = s""" + } else { + val d = data.genCode(ctx) + ev.copy(code = s""" ${d.code} boolean ${ev.isNull} = ${d.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $dtu.truncDate(${d.value}, $truncFormat); }""") - } - } else { - nullSafeCodeGen(ctx, ev, (dateVal, fmt) => { - val form = ctx.freshName("form") - s""" + } + } else { + nullSafeCodeGen(ctx, ev, (dateVal, fmt) => { + val form = ctx.freshName("form") + s""" int $form = $dtu.parseTruncLevel($fmt); if ($form == -1) { ${ev.isNull} = true; @@ -262,9 +271,8 @@ case class Trunc(data: Expression, format: Expression = Literal(0)) ${ev.value} = $dtu.truncDate($dateVal, $form); } """ - }) - } + }) + } } - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala index 0797f8cbc8046..dcf58526b757a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala @@ -47,7 +47,6 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("trunc") { - // numeric def testTruncNumber(input: Double, fmt: Int, expected: Double): Unit = { checkEvaluation(Trunc(Literal.create(input, DoubleType), diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql index 1bc072720c0d9..f9f8351f08ded 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql @@ -10,10 +10,10 @@ select to_timestamp(null), to_timestamp('2016-12-31 00:12:00'), to_timestamp('20 select dayofweek('2007-02-03'), dayofweek('2009-07-30'), dayofweek('2017-05-27'), dayofweek(null), dayofweek('1582-10-15 13:10:15'); -- trunc date -select trunc(to_date('2015-07-22'), 'yyyy'), trunc(to_date('2015-07-22'), 'YYYY'), - trunc(to_date('2015-07-22'), 'year'), trunc(to_date('2015-07-22'), 'YEAR'), - trunc(to_date('2015-07-22'), 'yy'), trunc(to_date('2015-07-22'), 'YY'); -select trunc(to_date('2015-07-22'), 'month'), trunc(to_date('2015-07-22'), 'MONTH'), - trunc(to_date('2015-07-22'), 'mon'), trunc(to_date('2015-07-22'), 'MON'), +select trunc('2015-07-22', 'yyyy'), trunc('2015-07-22', 'YYYY'), + trunc('2015-07-22', 'year'), trunc('2015-07-22', 'YEAR'), + trunc(to_date('2015-07-22'), 'yy'), trunc('2015-07-22', 'YY'); +select trunc('2015-07-22', 'month'), trunc('2015-07-22', 'MONTH'), + trunc('2015-07-22', 'mon'), trunc('2015-07-22', 'MON'), trunc(to_date('2015-07-22'), 'mm'), trunc(to_date('2015-07-22'), 'MM'); select trunc('2015-07-22', 'DD'), trunc('2015-07-22', null), trunc(null, 'MON'), trunc(null, null); diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out index a28b91c77324b..9f0e3176ac6ae 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 4 +-- Number of queries: 7 -- !query 0 @@ -32,3 +32,33 @@ select dayofweek('2007-02-03'), dayofweek('2009-07-30'), dayofweek('2017-05-27') struct -- !query 3 output 7 5 7 NULL 6 + + +-- !query 4 +select trunc('2015-07-22', 'yyyy'), trunc('2015-07-22', 'YYYY'), + trunc('2015-07-22', 'year'), trunc('2015-07-22', 'YEAR'), + trunc(to_date('2015-07-22'), 'yy'), trunc('2015-07-22', 'YY') +-- !query 4 schema +struct<> +-- !query 4 output +java.lang.ClassCastException +org.apache.spark.unsafe.types.UTF8String cannot be cast to java.lang.Integer + + +-- !query 5 +select trunc('2015-07-22', 'month'), trunc('2015-07-22', 'MONTH'), + trunc('2015-07-22', 'mon'), trunc('2015-07-22', 'MON'), + trunc(to_date('2015-07-22'), 'mm'), trunc(to_date('2015-07-22'), 'MM') +-- !query 5 schema +struct<> +-- !query 5 output +java.lang.ClassCastException +org.apache.spark.unsafe.types.UTF8String cannot be cast to java.lang.Integer + + +-- !query 6 +select trunc('2015-07-22', 'DD'), trunc('2015-07-22', null), trunc(null, 'MON'), trunc(null, null) +-- !query 6 schema +struct +-- !query 6 output +NULL NULL NULL NULL diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index 51ccf764d952f..83dededf83764 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 51 +-- Number of queries: 53 -- !query 0 @@ -420,3 +420,20 @@ select mod(7, 2), mod(7, 0), mod(0, 2), mod(7, null), mod(null, 2), mod(null, nu struct<(7 % 2):int,(7 % 0):int,(0 % 2):int,(7 % CAST(NULL AS INT)):int,(CAST(NULL AS INT) % 2):int,(CAST(NULL AS DOUBLE) % CAST(NULL AS DOUBLE)):double> -- !query 50 output 1 NULL 0 NULL NULL NULL + + +-- !query 51 +select trunc(1234567891.1234567891, 4), trunc(1234567891.1234567891, -4), + trunc(1234567891.1234567891, 4), trunc(1234567891.1234567891, 0), trunc(1234567891.1234567891) +-- !query 51 schema +struct +-- !query 51 output +1234567891.1234 1234560000 1234567891.1234 1234567891 1234567891 + + +-- !query 52 +select trunc(1234567891.1234567891, null), trunc(null, 4), trunc(null, null) +-- !query 52 schema +struct +-- !query 52 output +NULL NULL NULL