Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-23900][SQL] format_number support user specifed format as argument #21010

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2016,12 +2016,15 @@ case class Encode(value: Expression, charset: Expression)
usage = """
_FUNC_(expr1, expr2) - Formats the number `expr1` like '#,###,###.##', rounded to `expr2`
decimal places. If `expr2` is 0, the result has no decimal point or fractional part.
`expr2` also accept a user specified format.
This is supposed to function like MySQL's FORMAT.
""",
examples = """
Examples:
> SELECT _FUNC_(12332.123456, 4);
12,332.1235
> SELECT _FUNC_(12332.123456, '##################.###');
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to update ExpressionDescription too.

12332.123
""")
case class FormatNumber(x: Expression, d: Expression)
extends BinaryExpression with ExpectsInputTypes {
Expand All @@ -2030,14 +2033,20 @@ case class FormatNumber(x: Expression, d: Expression)
override def right: Expression = d
override def dataType: DataType = StringType
override def nullable: Boolean = true
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType)
override def inputTypes: Seq[AbstractDataType] =
Seq(NumericType, TypeCollection(IntegerType, StringType))

private val defaultFormat = "#,###,###,###,###,###,##0"

// Associated with the pattern, for the last d value, and we will update the
// pattern (DecimalFormat) once the new coming d value differ with the last one.
// This is an Option to distinguish between 0 (numberFormat is valid) and uninitialized after
// serialization (numberFormat has not been updated for dValue = 0).
@transient
private var lastDValue: Option[Int] = None
private var lastDIntValue: Option[Int] = None

@transient
private var lastDStringValue: Option[String] = None

// A cached DecimalFormat, for performance concern, we will change it
// only if the d value changed.
Expand All @@ -2050,33 +2059,49 @@ case class FormatNumber(x: Expression, d: Expression)
private lazy val numberFormat = new DecimalFormat("", new DecimalFormatSymbols(Locale.US))

override protected def nullSafeEval(xObject: Any, dObject: Any): Any = {
val dValue = dObject.asInstanceOf[Int]
if (dValue < 0) {
return null
}

lastDValue match {
case Some(last) if last == dValue =>
// use the current pattern
case _ =>
// construct a new DecimalFormat only if a new dValue
pattern.delete(0, pattern.length)
pattern.append("#,###,###,###,###,###,##0")

// decimal place
if (dValue > 0) {
pattern.append(".")

var i = 0
while (i < dValue) {
i += 1
pattern.append("0")
}
right.dataType match {
case IntegerType =>
val dValue = dObject.asInstanceOf[Int]
if (dValue < 0) {
return null
}

lastDValue = Some(dValue)
lastDIntValue match {
case Some(last) if last == dValue =>
// use the current pattern
case _ =>
// construct a new DecimalFormat only if a new dValue
pattern.delete(0, pattern.length)
pattern.append(defaultFormat)

// decimal place
if (dValue > 0) {
pattern.append(".")

var i = 0
while (i < dValue) {
i += 1
pattern.append("0")
}
}

lastDIntValue = Some(dValue)

numberFormat.applyLocalizedPattern(pattern.toString)
numberFormat.applyLocalizedPattern(pattern.toString)
}
case StringType =>
val dValue = dObject.asInstanceOf[UTF8String].toString
lastDStringValue match {
case Some(last) if last == dValue =>
case _ =>
pattern.delete(0, pattern.length)
lastDStringValue = Some(dValue)
if (dValue.isEmpty) {
numberFormat.applyLocalizedPattern(defaultFormat)
} else {
numberFormat.applyLocalizedPattern(dValue)
}
}
}

x.dataType match {
Expand Down Expand Up @@ -2108,35 +2133,53 @@ case class FormatNumber(x: Expression, d: Expression)
// SPARK-13515: US Locale configures the DecimalFormat object to use a dot ('.')
// as a decimal separator.
val usLocale = "US"
val i = ctx.freshName("i")
val dFormat = ctx.freshName("dFormat")
val lastDValue =
ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;")
val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();")
val numberFormat = ctx.addMutableState(df, "numberFormat",
v => s"""$v = new $df("", new $dfs($l.$usLocale));""")

s"""
if ($d >= 0) {
$pattern.delete(0, $pattern.length());
if ($d != $lastDValue) {
$pattern.append("#,###,###,###,###,###,##0");

if ($d > 0) {
$pattern.append(".");
for (int $i = 0; $i < $d; $i++) {
$pattern.append("0");
right.dataType match {
case IntegerType =>
val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();")
val i = ctx.freshName("i")
val lastDIntValue =
ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;")
s"""
if ($d >= 0) {
$pattern.delete(0, $pattern.length());
if ($d != $lastDIntValue) {
$pattern.append("$defaultFormat");

if ($d > 0) {
$pattern.append(".");
for (int $i = 0; $i < $d; $i++) {
$pattern.append("0");
}
}
$lastDIntValue = $d;
$numberFormat.applyLocalizedPattern($pattern.toString());
}
${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)}));
} else {
${ev.value} = null;
${ev.isNull} = true;
}
$lastDValue = $d;
$numberFormat.applyLocalizedPattern($pattern.toString());
}
${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)}));
} else {
${ev.value} = null;
${ev.isNull} = true;
}
"""
"""
case StringType =>
val lastDStringValue =
ctx.addMutableState("String", "lastDValue", v => s"""$v = "$defaultFormat";""")
val dValue = ctx.addMutableState("String", "dValue")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to make this mutable state?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to make this mutable state?

s"""
$dValue = $d.toString();
if (!$dValue.equals($lastDStringValue)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if the first dValue is the same as the default format? Can you add the test case?

$lastDStringValue = $dValue;
if ($dValue.isEmpty()) {
$numberFormat.applyLocalizedPattern("$defaultFormat");
} else {
$numberFormat.applyLocalizedPattern($dValue);
}
}
${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)}));
"""
}
})
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,23 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
"15,159,339,180,002,773.2778")
checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal(3)), null)
assert(FormatNumber(Literal.create(null, NullType), Literal(3)).resolved === false)

checkEvaluation(FormatNumber(Literal(12332.123456), Literal("##############.###")), "12332.123")
checkEvaluation(FormatNumber(Literal(12332.123456), Literal("##.###")), "12332.123")
checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Byte]), Literal("##.####")), "4")
checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Short]), Literal("##.####")), "4")
checkEvaluation(FormatNumber(Literal(4.0f), Literal("##.###")), "4")
checkEvaluation(FormatNumber(Literal(4), Literal("##.###")), "4")
checkEvaluation(FormatNumber(Literal(12831273.23481d),
Literal("###,###,###,###,###.###")), "12,831,273.235")
checkEvaluation(FormatNumber(Literal(12831273.83421d), Literal("")), "12,831,274")
checkEvaluation(FormatNumber(Literal(123123324123L), Literal("###,###,###,###,###.###")),
"123,123,324,123")
checkEvaluation(
FormatNumber(Literal(Decimal(123123324123L) * Decimal(123123.21234d)),
Literal("###,###,###,###,###.####")), "15,159,339,180,002,773.2778")
checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal("##.###")), null)
assert(FormatNumber(Literal.create(null, NullType), Literal("##.###")).resolved === false)
}

test("find in set") {
Expand Down