Skip to content

Commit

Permalink
[SPARK-42045][SQL] ANSI SQL mode: Round/Bround should return an error…
Browse files Browse the repository at this point in the history
… on tiny/small/big integer overflow

### What changes were proposed in this pull request?

Similar to #39546, this PR is to change Round/Bround to return an error on tiny/small/big integer overflow.

### Why are the changes needed?

In ANSI SQL mode, integer overflow should cause error instead of returning an unreasonable result.
For example, round(127y, -1) should return error instead of returning -126

### Does this PR introduce _any_ user-facing change?

Yes, in ANSI SQL mode, SQL function Round and Bround will return an error on tiny/small/big integer overflow

### How was this patch tested?

UTs

Closes #39557 from gengliangwang/fixRoundOtherInt.

Authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
  • Loading branch information
gengliangwang committed Jan 14, 2023
1 parent 7f9c226 commit ba79d1a
Show file tree
Hide file tree
Showing 5 changed files with 970 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1537,8 +1537,16 @@ abstract class RoundBase(child: Expression, scale: Expression,
} else {
Decimal(decimal.toBigDecimal.setScale(_scale, mode), p, s)
}
case ByteType if ansiEnabled =>
MathUtils.withOverflow(
f = BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByteExact,
context = getContextOrNull)
case ByteType =>
BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte
case ShortType if ansiEnabled =>
MathUtils.withOverflow(
f = BigDecimal(input1.asInstanceOf[Short]).setScale(_scale, mode).toShortExact,
context = getContextOrNull)
case ShortType =>
BigDecimal(input1.asInstanceOf[Short]).setScale(_scale, mode).toShort
case IntegerType if ansiEnabled =>
Expand All @@ -1547,6 +1555,10 @@ abstract class RoundBase(child: Expression, scale: Expression,
context = getContextOrNull)
case IntegerType =>
BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, mode).toInt
case LongType if ansiEnabled =>
MathUtils.withOverflow(
f = BigDecimal(input1.asInstanceOf[Long]).setScale(_scale, mode).toLongExact,
context = getContextOrNull)
case LongType =>
BigDecimal(input1.asInstanceOf[Long]).setScale(_scale, mode).toLong
case FloatType =>
Expand All @@ -1569,6 +1581,26 @@ abstract class RoundBase(child: Expression, scale: Expression,
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val ce = child.genCode(ctx)

def codegenForIntegralType(dt: String): String = {
if (_scale < 0) {
if (ansiEnabled) {
val errorContext = getContextOrNullCode(ctx)
val evalCode = s"""
|${ev.value} = new java.math.BigDecimal(${ce.value}).
|setScale(${_scale}, java.math.BigDecimal.${modeStr}).${dt}ValueExact();
|""".stripMargin
MathUtils.withOverflowCode(evalCode, errorContext)
} else {
s"""
|${ev.value} = new java.math.BigDecimal(${ce.value}).
|setScale(${_scale}, java.math.BigDecimal.${modeStr}).${dt}Value();
|""".stripMargin
}
} else {
s"${ev.value} = ${ce.value};"
}
}

val evaluationCode = dataType match {
case DecimalType.Fixed(p, s) =>
if (_scale >= 0) {
Expand All @@ -1583,47 +1615,13 @@ abstract class RoundBase(child: Expression, scale: Expression,
${ev.isNull} = ${ev.value} == null;"""
}
case ByteType =>
if (_scale < 0) {
s"""
${ev.value} = new java.math.BigDecimal(${ce.value}).
setScale(${_scale}, java.math.BigDecimal.${modeStr}).byteValue();"""
} else {
s"${ev.value} = ${ce.value};"
}
codegenForIntegralType("byte")
case ShortType =>
if (_scale < 0) {
s"""
${ev.value} = new java.math.BigDecimal(${ce.value}).
setScale(${_scale}, java.math.BigDecimal.${modeStr}).shortValue();"""
} else {
s"${ev.value} = ${ce.value};"
}
codegenForIntegralType("short")
case IntegerType =>
if (_scale < 0) {
if (ansiEnabled) {
val errorContext = getContextOrNullCode(ctx)
val evalCode = s"""
|${ev.value} = new java.math.BigDecimal(${ce.value}).
|setScale(${_scale}, java.math.BigDecimal.${modeStr}).intValueExact();
|""".stripMargin
MathUtils.withOverflowCode(evalCode, errorContext)
} else {
s"""
|${ev.value} = new java.math.BigDecimal(${ce.value}).
|setScale(${_scale}, java.math.BigDecimal.${modeStr}).intValue();
|""".stripMargin
}
} else {
s"${ev.value} = ${ce.value};"
}
codegenForIntegralType("int")
case LongType =>
if (_scale < 0) {
s"""
${ev.value} = new java.math.BigDecimal(${ce.value}).
setScale(${_scale}, java.math.BigDecimal.${modeStr}).longValue();"""
} else {
s"${ev.value} = ${ce.value};"
}
codegenForIntegralType("long")
case FloatType => // if child eval to NaN or Infinity, just return it.
s"""
if (Float.isNaN(${ce.value}) || Float.isInfinite(${ce.value})) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -839,15 +839,20 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("SPARK-42045: integer overflow in round/bround") {
val input = 2147483647
val scale = -1
Seq(Round(input, scale, ansiEnabled = true),
BRound(input, scale, ansiEnabled = true)).foreach { expr =>
checkExceptionInExpression[SparkArithmeticException](expr, "Overflow")
}
Seq(Round(input, scale, ansiEnabled = false),
BRound(input, scale, ansiEnabled = false)).foreach { expr =>
checkEvaluation(expr, -2147483646)
Seq(
(Byte.MaxValue, ByteType, -1, -126.toByte),
(Short.MaxValue, ShortType, -1, -32766.toShort),
(Int.MaxValue, IntegerType, -1, -2147483646),
(Long.MaxValue, LongType, -1, -9223372036854775806L)
).foreach { case (input, dt, scale, expected) =>
Seq(Round(Literal(input, dt), scale, ansiEnabled = true),
BRound(Literal(input, dt), scale, ansiEnabled = true)).foreach { expr =>
checkExceptionInExpression[SparkArithmeticException](expr, "Overflow")
}
Seq(Round(Literal(input, dt), scale, ansiEnabled = false),
BRound(Literal(input, dt), scale, ansiEnabled = false)).foreach { expr =>
checkEvaluation(expr, expected)
}
}
}

Expand Down
58 changes: 56 additions & 2 deletions sql/core/src/test/resources/sql-tests/inputs/math.sql
Original file line number Diff line number Diff line change
@@ -1,3 +1,21 @@
-- Round with Byte input
SELECT round(25y, 1);
SELECT round(25y, 0);
SELECT round(25y, -1);
SELECT round(25y, -2);
SELECT round(25y, -3);
SELECT round(127y, -1);
SELECT round(-128y, -1);

-- Round with short integer input
SELECT round(525s, 1);
SELECT round(525s, 0);
SELECT round(525s, -1);
SELECT round(525s, -2);
SELECT round(525s, -3);
SELECT round(32767s, -1);
SELECT round(-32768s, -1);

-- Round with integer input
SELECT round(525, 1);
SELECT round(525, 0);
Expand All @@ -7,11 +25,47 @@ SELECT round(525, -3);
SELECT round(2147483647, -1);
SELECT round(-2147483647, -1);

-- BRound with integer input
-- Round with big integer input
SELECT round(525L, 1);
SELECT round(525L, 0);
SELECT round(525L, -1);
SELECT round(525L, -2);
SELECT round(525L, -3);
SELECT round(9223372036854775807L, -1);
SELECT round(-9223372036854775808L, -1);

-- Bround with byte input
SELECT bround(25y, 1);
SELECT bround(25y, 0);
SELECT bround(25y, -1);
SELECT bround(25y, -2);
SELECT bround(25y, -3);
SELECT bround(127y, -1);
SELECT bround(-128y, -1);

-- Bround with Short input
SELECT bround(525s, 1);
SELECT bround(525s, 0);
SELECT bround(525s, -1);
SELECT bround(525s, -2);
SELECT bround(525s, -3);
SELECT bround(32767s, -1);
SELECT bround(-32768s, -1);

-- Bround with integer input
SELECT bround(525, 1);
SELECT bround(525, 0);
SELECT bround(525, -1);
SELECT bround(525, -2);
SELECT bround(525, -3);
SELECT bround(2147483647, -1);
SELECT bround(-2147483647, -1);
SELECT bround(-2147483647, -1);

-- Bround with big integer input
SELECT bround(525L, 1);
SELECT bround(525L, 0);
SELECT bround(525L, -1);
SELECT bround(525L, -2);
SELECT bround(525L, -3);
SELECT bround(9223372036854775807L, -1);
SELECT bround(-9223372036854775808L, -1);
Loading

0 comments on commit ba79d1a

Please sign in to comment.