-
Notifications
You must be signed in to change notification settings - Fork 28.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-23900][SQL] format_number support user specifed format as argument #21010
Changes from 2 commits
202fa3d
e273045
0bc77e8
09129bc
9ccb648
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2016,12 +2016,15 @@ case class Encode(value: Expression, charset: Expression) | |
usage = """ | ||
_FUNC_(expr1, expr2) - Formats the number `expr1` like '#,###,###.##', rounded to `expr2` | ||
decimal places. If `expr2` is 0, the result has no decimal point or fractional part. | ||
`expr2` also accept a user specified format. | ||
This is supposed to function like MySQL's FORMAT. | ||
""", | ||
examples = """ | ||
Examples: | ||
> SELECT _FUNC_(12332.123456, 4); | ||
12,332.1235 | ||
> SELECT _FUNC_(12332.123456, '##################.###'); | ||
12332.123 | ||
""") | ||
case class FormatNumber(x: Expression, d: Expression) | ||
extends BinaryExpression with ExpectsInputTypes { | ||
|
@@ -2030,14 +2033,20 @@ case class FormatNumber(x: Expression, d: Expression) | |
override def right: Expression = d | ||
override def dataType: DataType = StringType | ||
override def nullable: Boolean = true | ||
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType) | ||
override def inputTypes: Seq[AbstractDataType] = | ||
Seq(NumericType, TypeCollection(IntegerType, StringType)) | ||
|
||
private val defaultFormat = "#,###,###,###,###,###,##0" | ||
|
||
// Associated with the pattern, for the last d value, and we will update the | ||
// pattern (DecimalFormat) once the new coming d value differ with the last one. | ||
// This is an Option to distinguish between 0 (numberFormat is valid) and uninitialized after | ||
// serialization (numberFormat has not been updated for dValue = 0). | ||
@transient | ||
private var lastDValue: Option[Int] = None | ||
private var lastDIntValue: Option[Int] = None | ||
|
||
@transient | ||
private var lastDStringValue: Option[String] = None | ||
|
||
// A cached DecimalFormat, for performance concern, we will change it | ||
// only if the d value changed. | ||
|
@@ -2050,33 +2059,49 @@ case class FormatNumber(x: Expression, d: Expression) | |
private lazy val numberFormat = new DecimalFormat("", new DecimalFormatSymbols(Locale.US)) | ||
|
||
override protected def nullSafeEval(xObject: Any, dObject: Any): Any = { | ||
val dValue = dObject.asInstanceOf[Int] | ||
if (dValue < 0) { | ||
return null | ||
} | ||
|
||
lastDValue match { | ||
case Some(last) if last == dValue => | ||
// use the current pattern | ||
case _ => | ||
// construct a new DecimalFormat only if a new dValue | ||
pattern.delete(0, pattern.length) | ||
pattern.append("#,###,###,###,###,###,##0") | ||
|
||
// decimal place | ||
if (dValue > 0) { | ||
pattern.append(".") | ||
|
||
var i = 0 | ||
while (i < dValue) { | ||
i += 1 | ||
pattern.append("0") | ||
} | ||
right.dataType match { | ||
case IntegerType => | ||
val dValue = dObject.asInstanceOf[Int] | ||
if (dValue < 0) { | ||
return null | ||
} | ||
|
||
lastDValue = Some(dValue) | ||
lastDIntValue match { | ||
case Some(last) if last == dValue => | ||
// use the current pattern | ||
case _ => | ||
// construct a new DecimalFormat only if a new dValue | ||
pattern.delete(0, pattern.length) | ||
pattern.append(defaultFormat) | ||
|
||
// decimal place | ||
if (dValue > 0) { | ||
pattern.append(".") | ||
|
||
var i = 0 | ||
while (i < dValue) { | ||
i += 1 | ||
pattern.append("0") | ||
} | ||
} | ||
|
||
lastDIntValue = Some(dValue) | ||
|
||
numberFormat.applyLocalizedPattern(pattern.toString) | ||
numberFormat.applyLocalizedPattern(pattern.toString) | ||
} | ||
case StringType => | ||
val dValue = dObject.asInstanceOf[UTF8String].toString | ||
lastDStringValue match { | ||
case Some(last) if last == dValue => | ||
case _ => | ||
pattern.delete(0, pattern.length) | ||
lastDStringValue = Some(dValue) | ||
if (dValue.isEmpty) { | ||
numberFormat.applyLocalizedPattern(defaultFormat) | ||
} else { | ||
numberFormat.applyLocalizedPattern(dValue) | ||
} | ||
} | ||
} | ||
|
||
x.dataType match { | ||
|
@@ -2108,35 +2133,53 @@ case class FormatNumber(x: Expression, d: Expression) | |
// SPARK-13515: US Locale configures the DecimalFormat object to use a dot ('.') | ||
// as a decimal separator. | ||
val usLocale = "US" | ||
val i = ctx.freshName("i") | ||
val dFormat = ctx.freshName("dFormat") | ||
val lastDValue = | ||
ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;") | ||
val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();") | ||
val numberFormat = ctx.addMutableState(df, "numberFormat", | ||
v => s"""$v = new $df("", new $dfs($l.$usLocale));""") | ||
|
||
s""" | ||
if ($d >= 0) { | ||
$pattern.delete(0, $pattern.length()); | ||
if ($d != $lastDValue) { | ||
$pattern.append("#,###,###,###,###,###,##0"); | ||
|
||
if ($d > 0) { | ||
$pattern.append("."); | ||
for (int $i = 0; $i < $d; $i++) { | ||
$pattern.append("0"); | ||
right.dataType match { | ||
case IntegerType => | ||
val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();") | ||
val i = ctx.freshName("i") | ||
val lastDIntValue = | ||
ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;") | ||
s""" | ||
if ($d >= 0) { | ||
$pattern.delete(0, $pattern.length()); | ||
if ($d != $lastDIntValue) { | ||
$pattern.append("$defaultFormat"); | ||
|
||
if ($d > 0) { | ||
$pattern.append("."); | ||
for (int $i = 0; $i < $d; $i++) { | ||
$pattern.append("0"); | ||
} | ||
} | ||
$lastDIntValue = $d; | ||
$numberFormat.applyLocalizedPattern($pattern.toString()); | ||
} | ||
${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); | ||
} else { | ||
${ev.value} = null; | ||
${ev.isNull} = true; | ||
} | ||
$lastDValue = $d; | ||
$numberFormat.applyLocalizedPattern($pattern.toString()); | ||
} | ||
${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); | ||
} else { | ||
${ev.value} = null; | ||
${ev.isNull} = true; | ||
} | ||
""" | ||
""" | ||
case StringType => | ||
val lastDStringValue = | ||
ctx.addMutableState("String", "lastDValue", v => s"""$v = "$defaultFormat";""") | ||
val dValue = ctx.addMutableState("String", "dValue") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to make this mutable state? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to make this mutable state? |
||
s""" | ||
$dValue = $d.toString(); | ||
if (!$dValue.equals($lastDStringValue)) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What if the first |
||
$lastDStringValue = $dValue; | ||
if ($dValue.isEmpty()) { | ||
$numberFormat.applyLocalizedPattern("$defaultFormat"); | ||
} else { | ||
$numberFormat.applyLocalizedPattern($dValue); | ||
} | ||
} | ||
${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); | ||
""" | ||
} | ||
}) | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need to update
ExpressionDescription
too.