Skip to content

Commit

Permalink
Refactor code.
Browse files: browse the repository at this point in the history
  • Loading branch information
wangyum committed Jun 14, 2017
1 parent c1019c9 commit 3d92a48
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 74 deletions.
8 changes: 4 additions & 4 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,17 +1028,17 @@ def to_timestamp(col, format=None):


@since(1.5)
def trunc(date, format=0):
def trunc(date, format):
"""
Returns date truncated to the unit specified by the format or
number truncated by specified decimal places..
number truncated by specified decimal places.
:param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm'
>>> df = spark.createDataFrame([('1997-02-28',)], ['d'])
>>> df.select(trunc(to_date(df.d), 'year').alias('year')).collect()
>>> df.select(trunc(df.d, 'year').alias('year')).collect()
[Row(year=datetime.date(1997, 1, 1))]
>>> df.select(trunc(to_date(df.d), 'mon').alias('month')).collect()
>>> df.select(trunc(df.d, 'mon').alias('month')).collect()
[Row(month=datetime.date(1997, 2, 1))]
>>> df = spark.createDataFrame([(1234567891.1234567891,)], ['d'])
>>> df.select(trunc(df.d, 4).alias('positive')).collect()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,45 +141,57 @@ case class Uuid() extends LeafExpression {
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
_FUNC_(data, fmt) - Returns `data` truncated by the format model `fmt`.
If `data` is DateType, returns `data` with the time portion of the day truncated to the unit specified by the format model `fmt`.
If `data` is DoubleType, returns `data` truncated to `fmt` decimal places.
_FUNC_(data[, fmt]) - Returns `data` truncated by the format model `fmt`.
If `data` is DateType/StringType, returns `data` with the time portion of the day truncated to the unit specified by the format model `fmt`.
If `data` is DecimalType/DoubleType, returns `data` truncated to `fmt` decimal places.
""",
extended = """
Examples:
> SELECT _FUNC_('2009-02-12', 'MM');
2009-02-01
2009-02-01.
> SELECT _FUNC_('2015-10-27', 'YEAR');
2015-01-01
> SELECT _FUNC_('2015-10-27');
2015-10-01
> SELECT _FUNC_(1234567891.1234567891, 4);
1234567891.1234
> SELECT _FUNC_(1234567891.1234567891, -4);
1234560000
""")
> SELECT _FUNC_(1234567891.1234567891);
1234567891
""")
// scalastyle:on line.size.limit
case class Trunc(data: Expression, format: Expression = Literal(0))
case class Trunc(data: Expression, format: Expression)
extends BinaryExpression with ImplicitCastInputTypes {

def this(numeric: Expression) = {
this(numeric, Literal(0))
def this(data: Expression) = {
this(data, Literal(
if (data.dataType.isInstanceOf[DecimalType] || data.dataType.isInstanceOf[DoubleType]) {
0
} else {
"MM"
}))
}

override def left: Expression = data
override def right: Expression = format

val isTruncNumber = format.dataType.isInstanceOf[IntegerType]

override def dataType: DataType = data.dataType

override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(DoubleType, DateType), TypeCollection(StringType, IntegerType))
Seq(TypeCollection(DateType, StringType, DoubleType, DecimalType),
TypeCollection(StringType, IntegerType))

override def nullable: Boolean = true

override def prettyName: String = "trunc"

private lazy val truncFormat: Int = dataType match {
case doubleType: DoubleType =>
format.eval().asInstanceOf[Int]
case dateType: DateType =>
DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String])
private lazy val truncFormat: Int = if (isTruncNumber) {
format.eval().asInstanceOf[Int]
} else {
DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String])
}

override def eval(input: InternalRow): Any = {
Expand All @@ -188,83 +200,79 @@ case class Trunc(data: Expression, format: Expression = Literal(0))
if (null == d || null == form) {
null
} else {
dataType match {
case doubleType: DoubleType =>
val scale = if (format.foldable) {
truncFormat
} else {
format.eval().asInstanceOf[Int]
}
BigDecimalUtils.trunc(d.asInstanceOf[Double], scale)
case dateType: DateType =>
val level = if (format.foldable) {
truncFormat
} else {
DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String])
}
if (level == -1) {
// unknown format
null
} else {
DateTimeUtils.truncDate(d.asInstanceOf[Int], level)
}
if (isTruncNumber) {
val scale = if (format.foldable) truncFormat else format.eval().asInstanceOf[Int]
data.dataType match {
case DoubleType => BigDecimalUtils.trunc(d.asInstanceOf[Double], scale)
case DecimalType.Fixed(_, _) =>
BigDecimalUtils.trunc(d.asInstanceOf[Decimal].toJavaBigDecimal, scale)
}
} else {
val level = if (format.foldable) {
truncFormat
} else {
DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String])
}
if (level == -1) {
// unknown format
null
} else {
DateTimeUtils.truncDate(d.asInstanceOf[Int], level)
}
}
}

}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {

dataType match {
case doubleType: DoubleType =>
val bdu = BigDecimalUtils.getClass.getName.stripSuffix("$")
if (isTruncNumber) {
val bdu = BigDecimalUtils.getClass.getName.stripSuffix("$")

if (format.foldable) {
val d = data.genCode(ctx)
ev.copy(code = s"""
if (format.foldable) {
val d = data.genCode(ctx)
ev.copy(code = s"""
${d.code}
boolean ${ev.isNull} = ${d.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = $bdu.trunc(${d.value}, $truncFormat);
}""")
} else {
nullSafeCodeGen(ctx, ev, (doubleVal, fmt) => {
s"${ev.value} = $bdu.trunc($doubleVal, $fmt);"
})
}
case dateType: DateType =>
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
} else {
nullSafeCodeGen(ctx, ev, (doubleVal, fmt) => {
s"${ev.value} = $bdu.trunc($doubleVal, $fmt);"
})
}
} else {
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")

if (format.foldable) {
if (truncFormat == -1) {
ev.copy(code = s"""
if (format.foldable) {
if (truncFormat == -1) {
ev.copy(code = s"""
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""")
} else {
val d = data.genCode(ctx)
ev.copy(code = s"""
} else {
val d = data.genCode(ctx)
ev.copy(code = s"""
${d.code}
boolean ${ev.isNull} = ${d.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = $dtu.truncDate(${d.value}, $truncFormat);
}""")
}
} else {
nullSafeCodeGen(ctx, ev, (dateVal, fmt) => {
val form = ctx.freshName("form")
s"""
}
} else {
nullSafeCodeGen(ctx, ev, (dateVal, fmt) => {
val form = ctx.freshName("form")
s"""
int $form = $dtu.parseTruncLevel($fmt);
if ($form == -1) {
${ev.isNull} = true;
} else {
${ev.value} = $dtu.truncDate($dateVal, $form);
}
"""
})
}
})
}
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("trunc") {

// numeric
def testTruncNumber(input: Double, fmt: Int, expected: Double): Unit = {
checkEvaluation(Trunc(Literal.create(input, DoubleType),
Expand Down
10 changes: 5 additions & 5 deletions sql/core/src/test/resources/sql-tests/inputs/datetime.sql
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ select to_timestamp(null), to_timestamp('2016-12-31 00:12:00'), to_timestamp('20
select dayofweek('2007-02-03'), dayofweek('2009-07-30'), dayofweek('2017-05-27'), dayofweek(null), dayofweek('1582-10-15 13:10:15');

-- trunc date
select trunc(to_date('2015-07-22'), 'yyyy'), trunc(to_date('2015-07-22'), 'YYYY'),
trunc(to_date('2015-07-22'), 'year'), trunc(to_date('2015-07-22'), 'YEAR'),
trunc(to_date('2015-07-22'), 'yy'), trunc(to_date('2015-07-22'), 'YY');
select trunc(to_date('2015-07-22'), 'month'), trunc(to_date('2015-07-22'), 'MONTH'),
trunc(to_date('2015-07-22'), 'mon'), trunc(to_date('2015-07-22'), 'MON'),
select trunc('2015-07-22', 'yyyy'), trunc('2015-07-22', 'YYYY'),
trunc('2015-07-22', 'year'), trunc('2015-07-22', 'YEAR'),
trunc(to_date('2015-07-22'), 'yy'), trunc('2015-07-22', 'YY');
select trunc('2015-07-22', 'month'), trunc('2015-07-22', 'MONTH'),
trunc('2015-07-22', 'mon'), trunc('2015-07-22', 'MON'),
trunc(to_date('2015-07-22'), 'mm'), trunc(to_date('2015-07-22'), 'MM');
select trunc('2015-07-22', 'DD'), trunc('2015-07-22', null), trunc(null, 'MON'), trunc(null, null);
32 changes: 31 additions & 1 deletion sql/core/src/test/resources/sql-tests/results/datetime.sql.out
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 4
-- Number of queries: 7


-- !query 0
Expand Down Expand Up @@ -32,3 +32,33 @@ select dayofweek('2007-02-03'), dayofweek('2009-07-30'), dayofweek('2017-05-27')
struct<dayofweek(CAST(2007-02-03 AS DATE)):int,dayofweek(CAST(2009-07-30 AS DATE)):int,dayofweek(CAST(2017-05-27 AS DATE)):int,dayofweek(CAST(NULL AS DATE)):int,dayofweek(CAST(1582-10-15 13:10:15 AS DATE)):int>
-- !query 3 output
7 5 7 NULL 6


-- !query 4
select trunc('2015-07-22', 'yyyy'), trunc('2015-07-22', 'YYYY'),
trunc('2015-07-22', 'year'), trunc('2015-07-22', 'YEAR'),
trunc(to_date('2015-07-22'), 'yy'), trunc('2015-07-22', 'YY')
-- !query 4 schema
struct<>
-- !query 4 output
java.lang.ClassCastException
org.apache.spark.unsafe.types.UTF8String cannot be cast to java.lang.Integer


-- !query 5
select trunc('2015-07-22', 'month'), trunc('2015-07-22', 'MONTH'),
trunc('2015-07-22', 'mon'), trunc('2015-07-22', 'MON'),
trunc(to_date('2015-07-22'), 'mm'), trunc(to_date('2015-07-22'), 'MM')
-- !query 5 schema
struct<>
-- !query 5 output
java.lang.ClassCastException
org.apache.spark.unsafe.types.UTF8String cannot be cast to java.lang.Integer


-- !query 6
select trunc('2015-07-22', 'DD'), trunc('2015-07-22', null), trunc(null, 'MON'), trunc(null, null)
-- !query 6 schema
struct<trunc(2015-07-22, DD):string,trunc(2015-07-22, CAST(NULL AS STRING)):string,trunc(CAST(NULL AS DATE), MON):date,trunc(CAST(NULL AS DATE), CAST(NULL AS STRING)):date>
-- !query 6 output
NULL NULL NULL NULL
19 changes: 18 additions & 1 deletion sql/core/src/test/resources/sql-tests/results/operators.sql.out
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 51
-- Number of queries: 53


-- !query 0
Expand Down Expand Up @@ -420,3 +420,20 @@ select mod(7, 2), mod(7, 0), mod(0, 2), mod(7, null), mod(null, 2), mod(null, nu
struct<(7 % 2):int,(7 % 0):int,(0 % 2):int,(7 % CAST(NULL AS INT)):int,(CAST(NULL AS INT) % 2):int,(CAST(NULL AS DOUBLE) % CAST(NULL AS DOUBLE)):double>
-- !query 50 output
1 NULL 0 NULL NULL NULL


-- !query 51
select trunc(1234567891.1234567891, 4), trunc(1234567891.1234567891, -4),
trunc(1234567891.1234567891, 4), trunc(1234567891.1234567891, 0), trunc(1234567891.1234567891)
-- !query 51 schema
struct<trunc(1234567891.1234567891, 4):decimal(20,10),trunc(1234567891.1234567891, -4):decimal(20,10),trunc(1234567891.1234567891, 4):decimal(20,10),trunc(1234567891.1234567891, 0):decimal(20,10),trunc(1234567891.1234567891, 0):decimal(20,10)>
-- !query 51 output
1234567891.1234 1234560000 1234567891.1234 1234567891 1234567891


-- !query 52
select trunc(1234567891.1234567891, null), trunc(null, 4), trunc(null, null)
-- !query 52 schema
struct<trunc(1234567891.1234567891, CAST(NULL AS STRING)):decimal(20,10),trunc(CAST(NULL AS DATE), 4):date,trunc(CAST(NULL AS DATE), CAST(NULL AS STRING)):date>
-- !query 52 output
NULL NULL NULL

0 comments on commit 3d92a48

Please sign in to comment.