Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/SPARK-20754-trunc' into SPARK-20754
Browse files Browse the repository at this point in the history
-trunc
  • Loading branch information
wangyum committed Oct 28, 2017
2 parents ea72fe0 + 931f07d commit b59a2df
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 28 deletions.
8 changes: 4 additions & 4 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1081,12 +1081,12 @@ def to_timestamp(col, format=None):


@since(1.5)
def trunc(data, truncParam):
def trunc(date, format):
"""
Returns date truncated to the unit specified by the truncParam or
Returns date truncated to the unit specified by the format or
numeric truncated by specified decimal places.
:param truncParam: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm' for date
:param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm' for date
and any int for numeric.
>>> df = spark.createDataFrame([('1997-02-28',)], ['d'])
Expand All @@ -1103,7 +1103,7 @@ def trunc(data, truncParam):
[Row(zero=1234567891.0)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.trunc(_to_java_column(data), truncParam))
return Column(sc._jvm.functions.trunc(_to_java_column(date), format))


@since(1.5)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,29 +274,69 @@ case class Trunc(data: Expression, truncExpr: Expression)
if (truncFormat == -1) {
ev.copy(code = s"""
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
""")
int ${ev.value} = ${ctx.defaultValue(DateType)};""")
} else {
val d = data.genCode(ctx)
ev.copy(code = s"""
val dt = ctx.freshName("dt")
val pre = s"""
${d.code}
boolean ${ev.isNull} = ${d.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = $dtu.truncDate(${d.value}, $truncFormat);
}""")
int ${ev.value} = ${ctx.defaultValue(DateType)};"""
data.dataType match {
case DateType =>
ev.copy(code = pre + s"""
if (!${ev.isNull}) {
${ev.value} = $dtu.truncDate(${d.value}, $truncFormat);
}""")
case TimestampType =>
val ts = ctx.freshName("ts")
ev.copy(code = pre + s"""
String $ts = $dtu.timestampToString(${d.value});
scala.Option<SQLDate> $dt = $dtu.stringToDate(UTF8String.fromString($ts));
if (!${ev.isNull}) {
${ev.value} = $dtu.truncDate((Integer)dt.get(), $truncFormat);
}""")
case StringType =>
ev.copy(code = pre + s"""
scala.Option<SQLDate> $dt = $dtu.stringToDate(${d.value});
if (!${ev.isNull} && $dt.isDefined()) {
${ev.value} = $dtu.truncDate((Integer)$dt.get(), $truncFormat);
}""")
}
}
} else {
nullSafeCodeGen(ctx, ev, (dateVal, fmt) => {
val truncParam = ctx.freshName("truncParam")
s"""
int $truncParam = $dtu.parseTruncLevel($fmt);
if ($truncParam == -1) {
${ev.isNull} = true;
} else {
${ev.value} = $dtu.truncDate($dateVal, $truncParam);
}
"""
val dt = ctx.freshName("dt")
val pre = s"int $truncParam = $dtu.parseTruncLevel($fmt);"
data.dataType match {
case DateType =>
pre + s"""
if ($truncParam == -1) {
${ev.isNull} = true;
} else {
${ev.value} = $dtu.truncDate($dateVal, $truncParam);
}"""
case TimestampType =>
val ts = ctx.freshName("ts")
pre + s"""
String $ts = $dtu.timestampToString($dateVal);
scala.Option<SQLDate> $dt = $dtu.stringToDate(UTF8String.fromString($ts));
if ($truncParam == -1 || $dt.isEmpty()) {
${ev.isNull} = true;
} else {
${ev.value} = $dtu.truncDate((Integer)$dt.get(), $truncParam);
}"""
case StringType =>
pre + s"""
scala.Option<SQLDate> $dt = $dtu.stringToDate($dateVal);
${ev.value} = ${ctx.defaultValue(DateType)};
if ($truncParam == -1 || $dt.isEmpty()) {
${ev.isNull} = true;
} else {
${ev.value} = $dtu.truncDate((Integer)$dt.get(), $truncParam);
}"""
}
})
}
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.expressions

import java.sql.Date
import java.sql.{Date, Timestamp}

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -74,23 +74,57 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("trunc date") {
def test(input: Date, fmt: String, expected: Date): Unit = {
def testDate(input: Date, fmt: String, expected: Date): Unit = {
checkEvaluation(Trunc(Literal.create(input, DateType), Literal.create(fmt, StringType)),
expected)
checkEvaluation(
Trunc(Literal.create(input, DateType), NonFoldableLiteral.create(fmt, StringType)),
expected)
}
val date = Date.valueOf("2015-07-22")

def testString(input: String, fmt: String, expected: Date): Unit = {
checkEvaluation(Trunc(Literal.create(input, StringType), Literal.create(fmt, StringType)),
expected)
checkEvaluation(
Trunc(Literal.create(input, StringType), NonFoldableLiteral.create(fmt, StringType)),
expected)
}

def testTimestamp(input: Timestamp, fmt: String, expected: Date): Unit = {
checkEvaluation(Trunc(Literal.create(input, TimestampType), Literal.create(fmt, StringType)),
expected)
checkEvaluation(
Trunc(Literal.create(input, TimestampType), NonFoldableLiteral.create(fmt, StringType)),
expected)
}

val dateStr = "2015-07-22"
val date = Date.valueOf(dateStr)
val ts = new Timestamp(date.getTime)

Seq("yyyy", "YYYY", "year", "YEAR", "yy", "YY").foreach { fmt =>
test(date, fmt, Date.valueOf("2015-01-01"))
testDate(date, fmt, Date.valueOf("2015-01-01"))
testString(dateStr, fmt, Date.valueOf("2015-01-01"))
testTimestamp(ts, fmt, Date.valueOf("2015-01-01"))
}
Seq("month", "MONTH", "mon", "MON", "mm", "MM").foreach { fmt =>
test(date, fmt, Date.valueOf("2015-07-01"))
testDate(date, fmt, Date.valueOf("2015-07-01"))
testString(dateStr, fmt, Date.valueOf("2015-07-01"))
testTimestamp(ts, fmt, Date.valueOf("2015-07-01"))
}
test(date, "DD", null)
test(date, null, null)
test(null, "MON", null)
test(null, null, null)
testDate(date, "DD", null)
testDate(date, null, null)
testDate(null, "MON", null)
testDate(null, null, null)

testString(dateStr, "DD", null)
testString(dateStr, null, null)
testString(null, "MON", null)
testString(null, null, null)

testTimestamp(ts, "DD", null)
testTimestamp(ts, null, null)
testTimestamp(null, "MON", null)
testTimestamp(null, null, null)
}
}

0 comments on commit b59a2df

Please sign in to comment.