Skip to content

Commit

Permalink
Support timestamp and string type.
Browse files Browse the repository at this point in the history
  • Loading branch information
wangyum committed Jun 27, 2017
1 parent 7fee61b commit f8b1f44
Show file tree
Hide file tree
Showing 12 changed files with 136 additions and 92 deletions.
4 changes: 2 additions & 2 deletions R/pkg/tests/fulltests/test_sparkSQL.R
Original file line number Diff line number Diff line change
Expand Up @@ -1382,8 +1382,8 @@ test_that("column functions", {
c20 <- to_timestamp(c) + to_timestamp(c, "yyyy") + to_date(c, "yyyy")
c21 <- posexplode_outer(c) + explode_outer(c)
c22 <- not(c)
c23 <- trunc(to_date(c), "year") + trunc(to_date(c), "yyyy") + trunc(to_date(c), "yy") +
trunc(to_date(c), "month") + trunc(to_date(c), "mon") + trunc(to_date(c), "mm")
c23 <- trunc(c, "year") + trunc(c, "yyyy") + trunc(c, "yy") +
trunc(c, "month") + trunc(c, "mon") + trunc(c, "mm")

# Test if base::is.nan() is exposed
expect_equal(is.nan(c("a", "b")), c(FALSE, FALSE))
Expand Down
15 changes: 8 additions & 7 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,17 +1028,18 @@ def to_timestamp(col, format=None):


@since(1.5)
def trunc(data, format):
def trunc(data, truncParam):
"""
Returns date truncated to the unit specified by the format or
number truncated by specified decimal places.
Returns date truncated to the unit specified by the truncParam or
numeric truncated by specified decimal places.
:param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm'
:param truncParam: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm' for date
and any int for numeric.
>>> df = spark.createDataFrame([('1997-02-28',)], ['d'])
>>> df.select(trunc(to_date(df.d), 'year').alias('year')).collect()
>>> df.select(trunc(df.d, 'year').alias('year')).collect()
[Row(year=datetime.date(1997, 1, 1))]
>>> df.select(trunc(to_date(df.d), 'mon').alias('month')).collect()
>>> df.select(trunc(df.d, 'mon').alias('month')).collect()
[Row(month=datetime.date(1997, 2, 1))]
>>> df = spark.createDataFrame([(1234567891.1234567891,)], ['d'])
>>> df.select(trunc(df.d, 4).alias('positive')).collect()
Expand All @@ -1049,7 +1050,7 @@ def trunc(data, format):
[Row(zero=1234567891.0)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.trunc(_to_java_column(data), format))
return Column(sc._jvm.functions.trunc(_to_java_column(data), truncParam))


@since(1.5)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.util.UUID

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.{BigDecimalUtils, DateTimeUtils}
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, MathUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

Expand Down Expand Up @@ -141,9 +141,9 @@ case class Uuid() extends LeafExpression {
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
_FUNC_(data[, fmt]) - Returns `data` truncated by the format model `fmt`.
If `data` is DateType, returns `data` with the time portion of the day truncated to the unit specified by the format model `fmt`.
If `data` is DecimalType/DoubleType, returns `data` truncated to `fmt` decimal places.
_FUNC_(data[, trunc_param]) - Returns `data` truncated by the format model `trunc_param`.
If `data` is date/timestamp/string type, returns `data` with the time portion of the day truncated to the unit specified by the format model `trunc_param`. If `trunc_param` is omitted, then the default `trunc_param` is 'MM'.
If `data` is decimal/double type, returns `data` truncated to `trunc_param` decimal places. If `trunc_param` is omitted, then the default `trunc_param` is 0.
""",
extended = """
Examples:
Expand All @@ -161,68 +161,87 @@ case class Uuid() extends LeafExpression {
1234567891
""")
// scalastyle:on line.size.limit
case class Trunc(data: Expression, format: Expression)
case class Trunc(data: Expression, truncExpr: Expression)
extends BinaryExpression with ExpectsInputTypes {

def this(data: Expression) = {
this(data, Literal(if (data.dataType.isInstanceOf[DateType]) "MM" else 0))
this(data, Literal(
if (data.dataType.isInstanceOf[DateType] ||
data.dataType.isInstanceOf[TimestampType] ||
data.dataType.isInstanceOf[StringType]) {
"MM"
} else {
0
})
)
}

override def left: Expression = data
override def right: Expression = format

override def dataType: DataType = data.dataType

override def inputTypes: Seq[AbstractDataType] = dataType match {
case NullType => Seq(dataType, TypeCollection(StringType, IntegerType))
case DateType => Seq(dataType, StringType)
case DoubleType | DecimalType.Fixed(_, _) => Seq(dataType, IntegerType)
case _ => Seq(TypeCollection(DateType, DoubleType, DecimalType),
TypeCollection(StringType, IntegerType))
override def right: Expression = truncExpr

private val isTruncNumber = truncExpr.dataType.isInstanceOf[IntegerType]
private val isTruncDate = truncExpr.dataType.isInstanceOf[StringType]

override def dataType: DataType = if (isTruncDate) DateType else data.dataType

override def inputTypes: Seq[AbstractDataType] = data.dataType match {
case NullType =>
Seq(dataType, TypeCollection(StringType, IntegerType))
case DateType | TimestampType | StringType =>
Seq(TypeCollection(DateType, TimestampType, StringType), StringType)
case DoubleType | DecimalType.Fixed(_, _) =>
Seq(TypeCollection(DoubleType, DecimalType), IntegerType)
case _ =>
Seq(TypeCollection(DateType, StringType, TimestampType, DoubleType, DecimalType),
TypeCollection(StringType, IntegerType))
}

override def nullable: Boolean = true

override def prettyName: String = "trunc"

private val isTruncNumber =
(dataType.isInstanceOf[DoubleType] || dataType.isInstanceOf[DecimalType]) &&
format.dataType.isInstanceOf[IntegerType]
private val isTruncDate =
dataType.isInstanceOf[DateType] && format.dataType.isInstanceOf[StringType]

private lazy val truncFormat: Int = if (isTruncNumber) {
format.eval().asInstanceOf[Int]
truncExpr.eval().asInstanceOf[Int]
} else if (isTruncDate) {
DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String])
DateTimeUtils.parseTruncLevel(truncExpr.eval().asInstanceOf[UTF8String])
} else {
0
}

override def eval(input: InternalRow): Any = {
val d = data.eval(input)
val form = format.eval()
if (null == d || null == form) {
val truncParam = truncExpr.eval()
if (null == d || null == truncParam) {
null
} else {
if (isTruncNumber) {
val scale = if (format.foldable) truncFormat else format.eval().asInstanceOf[Int]
val scale = if (truncExpr.foldable) truncFormat else truncExpr.eval().asInstanceOf[Int]
data.dataType match {
case DoubleType => BigDecimalUtils.trunc(d.asInstanceOf[Double], scale)
case DoubleType => MathUtils.trunc(d.asInstanceOf[Double], scale)
case DecimalType.Fixed(_, _) =>
BigDecimalUtils.trunc(d.asInstanceOf[Decimal].toJavaBigDecimal, scale)
MathUtils.trunc(d.asInstanceOf[Decimal].toJavaBigDecimal, scale)
}
} else if (isTruncDate) {
val level = if (format.foldable) {
val level = if (truncExpr.foldable) {
truncFormat
} else {
DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String])
DateTimeUtils.parseTruncLevel(truncExpr.eval().asInstanceOf[UTF8String])
}
if (level == -1) {
// unknown format
null
} else {
DateTimeUtils.truncDate(d.asInstanceOf[Int], level)
data.dataType match {
case DateType => DateTimeUtils.truncDate(d.asInstanceOf[Int], level)
case TimestampType =>
val ts = DateTimeUtils.timestampToString(d.asInstanceOf[Long])
val dt = DateTimeUtils.stringToDate(UTF8String.fromString(ts))
if (dt.isDefined) DateTimeUtils.truncDate(dt.get, level) else null
case StringType =>
val dt = DateTimeUtils.stringToDate(d.asInstanceOf[UTF8String])
if (dt.isDefined) DateTimeUtils.truncDate(dt.get, level) else null
}
}
} else {
null
Expand All @@ -233,9 +252,9 @@ case class Trunc(data: Expression, format: Expression)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {

if (isTruncNumber) {
val bdu = BigDecimalUtils.getClass.getName.stripSuffix("$")
val bdu = MathUtils.getClass.getName.stripSuffix("$")

if (format.foldable) {
if (truncExpr.foldable) {
val d = data.genCode(ctx)
ev.copy(code = s"""
${d.code}
Expand All @@ -245,12 +264,13 @@ case class Trunc(data: Expression, format: Expression)
${ev.value} = $bdu.trunc(${d.value}, $truncFormat);
}""")
} else {
nullSafeCodeGen(ctx, ev, (doubleVal, fmt) => s"${ev.value} = $bdu.trunc($doubleVal, $fmt);")
nullSafeCodeGen(ctx, ev, (doubleVal, truncParam) =>
s"${ev.value} = $bdu.trunc($doubleVal, $truncParam);")
}
} else if (isTruncDate) {
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")

if (format.foldable) {
if (truncExpr.foldable) {
if (truncFormat == -1) {
ev.copy(code = s"""
boolean ${ev.isNull} = true;
Expand All @@ -268,19 +288,19 @@ case class Trunc(data: Expression, format: Expression)
}
} else {
nullSafeCodeGen(ctx, ev, (dateVal, fmt) => {
val form = ctx.freshName("form")
val truncParam = ctx.freshName("truncParam")
s"""
int $form = $dtu.parseTruncLevel($fmt);
if ($form == -1) {
int $truncParam = $dtu.parseTruncLevel($fmt);
if ($truncParam == -1) {
${ev.isNull} = true;
} else {
${ev.value} = $dtu.truncDate($dateVal, $form);
${ev.value} = $dtu.truncDate($dateVal, $truncParam);
}
"""
})
}
} else {
nullSafeCodeGen(ctx, ev, (dateVal, fmt) => s"${ev.isNull} = true;")
nullSafeCodeGen(ctx, ev, (dataVal, fmt) => s"${ev.isNull} = true;")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import java.math.{BigDecimal => JBigDecimal}
/**
* Helper functions for BigDecimal.
*/
object BigDecimalUtils {
object MathUtils {

/**
* Returns double type input truncated to scale decimal places.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,8 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(evaluate(Uuid()) !== evaluate(Uuid()))
}

test("trunc") {
// numeric
def testTruncNumber(input: Double, fmt: Int, expected: Double): Unit = {
test("trunc numeric") {
def test(input: Double, fmt: Int, expected: Double): Unit = {
checkEvaluation(Trunc(Literal.create(input, DoubleType),
Literal.create(fmt, IntegerType)),
expected)
Expand All @@ -57,9 +56,11 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
expected)
}

testTruncNumber(1234567891.1234567891, 4, 1234567891.1234)
testTruncNumber(1234567891.1234567891, -4, 1234560000)
testTruncNumber(1234567891.1234567891, 0, 1234567891)
test(1234567891.1234567891, 4, 1234567891.1234)
test(1234567891.1234567891, -4, 1234560000)
test(1234567891.1234567891, 0, 1234567891)
test(0.123, -1, 0)
test(0.123, 0, 0)

checkEvaluation(Trunc(Literal.create(1D, DoubleType),
NonFoldableLiteral.create(null, IntegerType)),
Expand All @@ -70,9 +71,10 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Trunc(Literal.create(null, DoubleType),
NonFoldableLiteral.create(null, IntegerType)),
null)
}

// date
def testTruncDate(input: Date, fmt: String, expected: Date): Unit = {
test("trunc date") {
def test(input: Date, fmt: String, expected: Date): Unit = {
checkEvaluation(Trunc(Literal.create(input, DateType), Literal.create(fmt, StringType)),
expected)
checkEvaluation(
Expand All @@ -81,14 +83,14 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
val date = Date.valueOf("2015-07-22")
Seq("yyyy", "YYYY", "year", "YEAR", "yy", "YY").foreach { fmt =>
testTruncDate(date, fmt, Date.valueOf("2015-01-01"))
test(date, fmt, Date.valueOf("2015-01-01"))
}
Seq("month", "MONTH", "mon", "MON", "mm", "MM").foreach { fmt =>
testTruncDate(date, fmt, Date.valueOf("2015-07-01"))
test(date, fmt, Date.valueOf("2015-07-01"))
}
testTruncDate(date, "DD", null)
testTruncDate(date, null, null)
testTruncDate(null, "MON", null)
testTruncDate(null, null, null)
test(date, "DD", null)
test(date, null, null)
test(null, "MON", null)
test(null, null, null)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
package org.apache.spark.sql.catalyst.util

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.BigDecimalUtils._
import org.apache.spark.sql.catalyst.util.MathUtils._

class BigDecimalUtilsSuite extends SparkFunSuite {
class MathUtilsSuite extends SparkFunSuite {

test("trunc number") {
val bg = 1234567891.1234567891D
Expand Down
12 changes: 7 additions & 5 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2068,15 +2068,17 @@ object functions {
def radians(columnName: String): Column = radians(Column(columnName))

/**
* returns number truncated by specified decimal places.
*
* @param scale: 4. -4, 0
* Returns numeric truncated by specified decimal places.
* If scale is positive or 0, numeric is truncated to the absolute value of scale number
* of places to the right of the decimal point.
* If scale is negative, numeric is truncated to the absolute value of scale + 1 number
* of places to the left of the decimal point.
*
* @group math_funcs
* @since 2.3.0
*/
def trunc(db: Column, scale: Int = 0): Column = withExpr {
Trunc(db.expr, Literal(scale))
def trunc(numeric: Column, scale: Int): Column = withExpr {
Trunc(numeric.expr, Literal(scale))
}

//////////////////////////////////////////////////////////////////////////////////////////////
Expand Down
11 changes: 6 additions & 5 deletions sql/core/src/test/resources/sql-tests/inputs/datetime.sql
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@ select to_timestamp(null), to_timestamp('2016-12-31 00:12:00'), to_timestamp('20
select dayofweek('2007-02-03'), dayofweek('2009-07-30'), dayofweek('2017-05-27'), dayofweek(null), dayofweek('1582-10-15 13:10:15');

-- trunc date
select trunc(to_date('2015-07-22'), 'yyyy'), trunc(to_date('2015-07-22'), 'YYYY'),
trunc(to_date('2015-07-22'), 'year'), trunc(to_date('2015-07-22'), 'YEAR'),
select trunc('2015-07-22', 'yyyy'), trunc('2015-07-22', 'YYYY'),
trunc('2015-07-22', 'year'), trunc('2015-07-22', 'YEAR'),
trunc(to_date('2015-07-22'), 'yy'), trunc(to_date('2015-07-22'), 'YY');
select trunc(to_date('2015-07-22'), 'month'), trunc(to_date('2015-07-22'), 'MONTH'),
trunc(to_date('2015-07-22'), 'mon'), trunc(to_date('2015-07-22'), 'MON'),
select trunc('2015-07-22', 'month'), trunc('2015-07-22', 'MONTH'),
trunc('2015-07-22', 'mon'), trunc('2015-07-22', 'MON'),
trunc(to_date('2015-07-22'), 'mm'), trunc(to_date('2015-07-22'), 'MM');
select trunc(to_date('2015-07-22'), 'DD'), trunc(to_date('2015-07-22'), null);
select trunc('2015-07-22', 'DD'), trunc('2015-07-22', null);
select trunc('2015-07-2200', 'DD'), trunc('123', null);
select trunc(null, 'MON'), trunc(null, null);
1 change: 1 addition & 0 deletions sql/core/src/test/resources/sql-tests/inputs/operators.sql
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,4 @@ select trunc(1234567891.1234567891, null), trunc(null, 4), trunc(null, null);
select trunc(1234567891.1234567891, 'yyyy');
select trunc(to_date('2015-07-22'), 4);
select trunc('2015-07-22', 4);
select trunc(false, 4);
Loading

0 comments on commit f8b1f44

Please sign in to comment.