Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-23900][SQL] format_number support user specified format as argument #21010

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2022,6 +2022,8 @@ case class Encode(value: Expression, charset: Expression)
Examples:
> SELECT _FUNC_(12332.123456, 4);
12,332.1235
> SELECT _FUNC_(12332.123456, '##################.###');
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to update ExpressionDescription too.

12332.123
""")
case class FormatNumber(x: Expression, d: Expression)
extends BinaryExpression with ExpectsInputTypes {
Expand All @@ -2030,14 +2032,20 @@ case class FormatNumber(x: Expression, d: Expression)
override def right: Expression = d
override def dataType: DataType = StringType
override def nullable: Boolean = true
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType)
override def inputTypes: Seq[AbstractDataType] =
Seq(NumericType, TypeCollection(IntegerType, StringType))

private val defaultFormat = "#,###,###,###,###,###,##0"

// Tracks the last d value that the pattern was built for; the pattern (DecimalFormat) is
// updated only when an incoming d value differs from the last one.
// This is an Option to distinguish between 0 (numberFormat is valid) and uninitialized after
// serialization (numberFormat has not been updated for dValue = 0).
@transient
private var lastDValue: Option[Int] = None
private var lastDIntValue: Option[Int] = None

@transient
private var lastDStringValue: Option[String] = None

// A cached DecimalFormat; for performance reasons, we change it
// only if the d value has changed.
Expand All @@ -2050,33 +2058,49 @@ case class FormatNumber(x: Expression, d: Expression)
private lazy val numberFormat = new DecimalFormat("", new DecimalFormatSymbols(Locale.US))

override protected def nullSafeEval(xObject: Any, dObject: Any): Any = {
val dValue = dObject.asInstanceOf[Int]
if (dValue < 0) {
return null
}

lastDValue match {
case Some(last) if last == dValue =>
// use the current pattern
case _ =>
// construct a new DecimalFormat only if a new dValue
pattern.delete(0, pattern.length)
pattern.append("#,###,###,###,###,###,##0")

// decimal place
if (dValue > 0) {
pattern.append(".")

var i = 0
while (i < dValue) {
i += 1
pattern.append("0")
}
right.dataType match {
case IntegerType =>
val dValue = dObject.asInstanceOf[Int]
if (dValue < 0) {
return null
}

lastDValue = Some(dValue)
lastDIntValue match {
case Some(last) if last == dValue =>
// use the current pattern
case _ =>
// construct a new DecimalFormat only if a new dValue
pattern.delete(0, pattern.length)
pattern.append(defaultFormat)

// decimal place
if (dValue > 0) {
pattern.append(".")

var i = 0
while (i < dValue) {
i += 1
pattern.append("0")
}
}

lastDIntValue = Some(dValue)

numberFormat.applyLocalizedPattern(pattern.toString)
numberFormat.applyLocalizedPattern(pattern.toString)
}
case StringType =>
val dValue = dObject.asInstanceOf[UTF8String].toString
lastDStringValue match {
case Some(last) if last == dValue =>
case _ =>
pattern.delete(0, pattern.length)
lastDStringValue = Some(dValue)
if (dValue.toString.isEmpty) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dValue is already a string.

numberFormat.applyLocalizedPattern(defaultFormat)
} else {
numberFormat.applyLocalizedPattern(dValue)
}
}
}

x.dataType match {
Expand Down Expand Up @@ -2108,35 +2132,52 @@ case class FormatNumber(x: Expression, d: Expression)
// SPARK-13515: US Locale configures the DecimalFormat object to use a dot ('.')
// as a decimal separator.
val usLocale = "US"
val i = ctx.freshName("i")
val dFormat = ctx.freshName("dFormat")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dFormat is not used in any place. We can remove it.

val lastDValue =
ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;")
val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pattern is only needed for IntegerType case.

val numberFormat = ctx.addMutableState(df, "numberFormat",
v => s"""$v = new $df("", new $dfs($l.$usLocale));""")

s"""
if ($d >= 0) {
$pattern.delete(0, $pattern.length());
if ($d != $lastDValue) {
$pattern.append("#,###,###,###,###,###,##0");

if ($d > 0) {
$pattern.append(".");
for (int $i = 0; $i < $d; $i++) {
$pattern.append("0");
right.dataType match {
case IntegerType =>
val i = ctx.freshName("i")
val lastDIntValue =
ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;")
s"""
if ($d >= 0) {
$pattern.delete(0, $pattern.length());
if ($d != $lastDIntValue) {
$pattern.append("$defaultFormat");

if ($d > 0) {
$pattern.append(".");
for (int $i = 0; $i < $d; $i++) {
$pattern.append("0");
}
}
$lastDIntValue = $d;
$numberFormat.applyLocalizedPattern($pattern.toString());
}
${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)}));
} else {
${ev.value} = null;
${ev.isNull} = true;
}
$lastDValue = $d;
$numberFormat.applyLocalizedPattern($pattern.toString());
}
${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)}));
} else {
${ev.value} = null;
${ev.isNull} = true;
}
"""
"""
case StringType =>
val lastDStringValue =
ctx.addMutableState("String", "lastDValue", v => s"""$v = "$defaultFormat";""")
s"""
if (!$d.toString().equals($lastDStringValue)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should only call $d.toString() once.

$lastDStringValue = $d.toString();
if ($d.toString().isEmpty()) {
$numberFormat.applyLocalizedPattern("$defaultFormat");
} else {
$numberFormat.applyLocalizedPattern($d.toString());
}
}
${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)}));
"""
}
})
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,23 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
"15,159,339,180,002,773.2778")
checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal(3)), null)
assert(FormatNumber(Literal.create(null, NullType), Literal(3)).resolved === false)

checkEvaluation(FormatNumber(Literal(12332.123456), Literal("##############.###")), "12332.123")
checkEvaluation(FormatNumber(Literal(12332.123456), Literal("##.###")), "12332.123")
checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Byte]), Literal("##.####")), "4")
checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Short]), Literal("##.####")), "4")
checkEvaluation(FormatNumber(Literal(4.0f), Literal("##.###")), "4")
checkEvaluation(FormatNumber(Literal(4), Literal("##.###")), "4")
checkEvaluation(FormatNumber(Literal(12831273.23481d),
Literal("###,###,###,###,###.###")), "12,831,273.235")
checkEvaluation(FormatNumber(Literal(12831273.83421d), Literal("")), "12,831,274")
checkEvaluation(FormatNumber(Literal(123123324123L), Literal("###,###,###,###,###.###")),
"123,123,324,123")
checkEvaluation(
FormatNumber(Literal(Decimal(123123324123L) * Decimal(123123.21234d)),
Literal("###,###,###,###,###.####")), "15,159,339,180,002,773.2778")
checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal("##.###")), null)
assert(FormatNumber(Literal.create(null, NullType), Literal("##.###")).resolved === false)
}

test("find in set") {
Expand Down