From 3d40c366892303cd0de8259b31aebe7a748d89e6 Mon Sep 17 00:00:00 2001
From: Yuming Wang
Date: Wed, 2 Aug 2017 14:18:35 +0800
Subject: [PATCH 1/2] codegen support String and Timestamp type.

---
Review notes (text in this section is not part of the commit):
- In the branch that uses the precomputed `truncFormat`, the TimestampType
  case emitted a literal `dt.get()` in the generated Java instead of the
  fresh variable name; it now interpolates `$dt` like every other case.
- In that same branch, the TimestampType case now guards `.get()` with
  `$dt.isDefined()`, and both the TimestampType and StringType cases set
  `isNull = true` when the Option is empty, matching what the
  `nullSafeCodeGen` branch already does for an unparsable input.
- The fixes keep the number of added/removed lines per hunk unchanged, so
  the hunk headers and the diffstat below are still accurate.
- Line breaks and indentation were reconstructed from a whitespace-collapsed
  copy of this patch; verify context-line indentation before applying.

 .../spark/sql/catalyst/expressions/misc.scala | 70 +++++++++++++++----
 .../expressions/MiscExpressionsSuite.scala    | 52 +++++++++++---
 2 files changed, 98 insertions(+), 24 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 17311a837c9db..8a6caf4f7fa39 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -274,29 +274,69 @@ case class Trunc(data: Expression, truncExpr: Expression)
         if (truncFormat == -1) {
           ev.copy(code = s"""
             boolean ${ev.isNull} = true;
-            ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
-          """)
+            int ${ev.value} = ${ctx.defaultValue(DateType)};""")
         } else {
           val d = data.genCode(ctx)
-          ev.copy(code = s"""
+          val dt = ctx.freshName("dt")
+          val pre = s"""
             ${d.code}
             boolean ${ev.isNull} = ${d.isNull};
-            ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
-            if (!${ev.isNull}) {
-              ${ev.value} = $dtu.truncDate(${d.value}, $truncFormat);
-            }""")
+            int ${ev.value} = ${ctx.defaultValue(DateType)};"""
+          data.dataType match {
+            case DateType =>
+              ev.copy(code = pre + s"""
+                if (!${ev.isNull}) {
+                  ${ev.value} = $dtu.truncDate(${d.value}, $truncFormat);
+                }""")
+            case TimestampType =>
+              val ts = ctx.freshName("ts")
+              ev.copy(code = pre + s"""
+                String $ts = $dtu.timestampToString(${d.value});
+                scala.Option $dt = $dtu.stringToDate(UTF8String.fromString($ts));
+                if (!${ev.isNull} && $dt.isDefined()) {
+                  ${ev.value} = $dtu.truncDate((Integer)$dt.get(), $truncFormat);
+                } else { ${ev.isNull} = true; }""")
+            case StringType =>
+              ev.copy(code = pre + s"""
+                scala.Option $dt = $dtu.stringToDate(${d.value});
+                if (!${ev.isNull} && $dt.isDefined()) {
+                  ${ev.value} = $dtu.truncDate((Integer)$dt.get(), $truncFormat);
+                } else { ${ev.isNull} = true; }""")
+          }
         }
       } else {
         nullSafeCodeGen(ctx, ev, (dateVal, fmt) => {
           val truncParam = ctx.freshName("truncParam")
-          s"""
-            int $truncParam = $dtu.parseTruncLevel($fmt);
-            if ($truncParam == -1) {
-              ${ev.isNull} = true;
-            } else {
-              ${ev.value} = $dtu.truncDate($dateVal, $truncParam);
-            }
-          """
+          val dt = ctx.freshName("dt")
+          val pre = s"int $truncParam = $dtu.parseTruncLevel($fmt);"
+          data.dataType match {
+            case DateType =>
+              pre + s"""
+                if ($truncParam == -1) {
+                  ${ev.isNull} = true;
+                } else {
+                  ${ev.value} = $dtu.truncDate($dateVal, $truncParam);
+                }"""
+            case TimestampType =>
+              val ts = ctx.freshName("ts")
+              pre + s"""
+                String $ts = $dtu.timestampToString($dateVal);
+                scala.Option $dt = $dtu.stringToDate(UTF8String.fromString($ts));
+                if ($truncParam == -1 || $dt.isEmpty()) {
+                  ${ev.isNull} = true;
+                } else {
+                  ${ev.value} = $dtu.truncDate((Integer)$dt.get(), $truncParam);
+                }"""
+            case StringType =>
+              pre + s"""
+                scala.Option $dt = $dtu.stringToDate($dateVal);
+                ${ev.value} = ${ctx.defaultValue(DateType)};
+                if ($truncParam == -1 || $dt.isEmpty()) {
+                  ${ev.isNull} = true;
+                } else {
+                  ${ev.value} = $dtu.truncDate((Integer)$dt.get(), $truncParam);
+                }"""
+          }
         })
       }
     } else {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala
index 6af21ca05cd6c..c65bc72f67fc5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import java.sql.Date
+import java.sql.{Date, Timestamp}
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.types._
@@ -74,23 +74,57 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   test("trunc date") {
-    def test(input: Date, fmt: String, expected: Date): Unit = {
+    def testDate(input: Date, fmt: String, expected: Date): Unit = {
       checkEvaluation(Trunc(Literal.create(input, DateType), Literal.create(fmt, StringType)),
         expected)
       checkEvaluation(
         Trunc(Literal.create(input, DateType), NonFoldableLiteral.create(fmt, StringType)),
         expected)
     }
-    val date = Date.valueOf("2015-07-22")
+
+    def testString(input: String, fmt: String, expected: Date): Unit = {
+      checkEvaluation(Trunc(Literal.create(input, StringType), Literal.create(fmt, StringType)),
+        expected)
+      checkEvaluation(
+        Trunc(Literal.create(input, StringType), NonFoldableLiteral.create(fmt, StringType)),
+        expected)
+    }
+
+    def testTimestamp(input: Timestamp, fmt: String, expected: Date): Unit = {
+      checkEvaluation(Trunc(Literal.create(input, TimestampType), Literal.create(fmt, StringType)),
+        expected)
+      checkEvaluation(
+        Trunc(Literal.create(input, TimestampType), NonFoldableLiteral.create(fmt, StringType)),
+        expected)
+    }
+
+    val dateStr = "2015-07-22"
+    val date = Date.valueOf(dateStr)
+    val ts = new Timestamp(date.getTime)
+
     Seq("yyyy", "YYYY", "year", "YEAR", "yy", "YY").foreach { fmt =>
-      test(date, fmt, Date.valueOf("2015-01-01"))
+      testDate(date, fmt, Date.valueOf("2015-01-01"))
+      testString(dateStr, fmt, Date.valueOf("2015-01-01"))
+      testTimestamp(ts, fmt, Date.valueOf("2015-01-01"))
     }
     Seq("month", "MONTH", "mon", "MON", "mm", "MM").foreach { fmt =>
-      test(date, fmt, Date.valueOf("2015-07-01"))
+      testDate(date, fmt, Date.valueOf("2015-07-01"))
+      testString(dateStr, fmt, Date.valueOf("2015-07-01"))
+      testTimestamp(ts, fmt, Date.valueOf("2015-07-01"))
     }
-    test(date, "DD", null)
-    test(date, null, null)
-    test(null, "MON", null)
-    test(null, null, null)
+    testDate(date, "DD", null)
+    testDate(date, null, null)
+    testDate(null, "MON", null)
+    testDate(null, null, null)
+
+    testString(dateStr, "DD", null)
+    testString(dateStr, null, null)
+    testString(null, "MON", null)
+    testString(null, null, null)
+
+    testTimestamp(ts, "DD", null)
+    testTimestamp(ts, null, null)
+    testTimestamp(null, "MON", null)
+    testTimestamp(null, null, null)
   }
 }

From 931f07de787081cdd6822dbf396ec1b8d205f25e Mon Sep 17 00:00:00 2001
From: Yuming Wang
Date: Wed, 2 Aug 2017 20:42:48 +0800
Subject: [PATCH 2/2] Revert trunc(date, format).

---
 python/pyspark/sql/functions.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index ebe7d572f2b4c..51bb4557e8ee2 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1028,12 +1028,12 @@ def to_timestamp(col, format=None):
 
 
 @since(1.5)
-def trunc(data, truncParam):
+def trunc(date, format):
     """
-    Returns date truncated to the unit specified by the truncParam or
+    Returns date truncated to the unit specified by the format or
     numeric truncated by specified decimal places.
 
-    :param truncParam: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm' for date
+    :param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm' for date
         and any int for numeric.
 
     >>> df = spark.createDataFrame([('1997-02-28',)], ['d'])
@@ -1050,7 +1050,7 @@ def trunc(data, truncParam):
     [Row(zero=1234567891.0)]
     """
     sc = SparkContext._active_spark_context
-    return Column(sc._jvm.functions.trunc(_to_java_column(data), truncParam))
+    return Column(sc._jvm.functions.trunc(_to_java_column(date), format))
 
 
 @since(1.5)