[SPARK-42045][SQL] ANSI SQL mode: Round/Bround should return an error on integer overflow

### What changes were proposed in this pull request?

In ANSI SQL mode, Round/Bround should return an error on integer overflow.
Note that this PR covers the integer type only. Once it is merged, I will create a follow-up PR for the remaining integral types: byte, short, and long.
Also, the functions ceil and floor accept decimal-type input, so there is no need to change them.

### Why are the changes needed?

In ANSI SQL mode, integer overflow should cause an error instead of returning an unreasonable result.
For example, `round(2147483647, -1)` should raise an error instead of returning `-2147483646`.
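
As an illustration, here is a minimal Spark-shell style sketch of the intended behavior; the session setup below is assumed for the example and is not part of this patch:

```scala
// Illustrative sketch (assumed local session, not part of this patch):
// compares non-ANSI and ANSI behavior of round() on an overflowing integer.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").getOrCreate()

// Legacy (non-ANSI) mode: the result silently wraps around.
spark.conf.set("spark.sql.ansi.enabled", "false")
spark.sql("SELECT round(2147483647, -1)").show()  // shows -2147483646

// ANSI mode: the same query fails with SparkArithmeticException (ARITHMETIC_OVERFLOW).
spark.conf.set("spark.sql.ansi.enabled", "true")
spark.sql("SELECT round(2147483647, -1)").show()  // throws
```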

### Does this PR introduce _any_ user-facing change?

Yes. In ANSI SQL mode, the SQL functions `round` and `bround` will return an error on integer overflow.

### How was this patch tested?

Unit tests in `MathExpressionsSuite` and new SQL query tests in `math.sql` (run in both ANSI and non-ANSI modes).

Closes #39546 from gengliangwang/fixRound.

Authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
gengliangwang committed Jan 13, 2023
1 parent 785f1bb commit 4272112
Showing 8 changed files with 381 additions and 12 deletions.
@@ -26,8 +26,10 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch,
import org.apache.spark.sql.catalyst.expressions.Cast._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{NumberConverter, TypeUtils}
import org.apache.spark.sql.catalyst.trees.SQLQueryContext
import org.apache.spark.sql.catalyst.util.{MathUtils, NumberConverter, TypeUtils}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

@@ -1447,11 +1449,13 @@ case class Logarithm(left: Expression, right: Expression)
*/
abstract class RoundBase(child: Expression, scale: Expression,
mode: BigDecimal.RoundingMode.Value, modeStr: String)
extends BinaryExpression with Serializable with ImplicitCastInputTypes {
extends BinaryExpression with Serializable with ImplicitCastInputTypes with SupportQueryContext {

override def left: Expression = child
override def right: Expression = scale

protected def ansiEnabled: Boolean = false

// round of Decimal would eval to null if it fails to `changePrecision`
override def nullable: Boolean = true

Expand Down Expand Up @@ -1501,6 +1505,14 @@ abstract class RoundBase(child: Expression, scale: Expression,
private lazy val scaleV: Any = scale.eval(EmptyRow)
protected lazy val _scale: Int = scaleV.asInstanceOf[Int]

override def initQueryContext(): Option[SQLQueryContext] = {
if (ansiEnabled) {
Some(origin.context)
} else {
None
}
}

override def eval(input: InternalRow): Any = {
if (scaleV == null) { // if scale is null, no need to eval its child at all
null
@@ -1529,6 +1541,10 @@ abstract class RoundBase(child: Expression, scale: Expression,
BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte
case ShortType =>
BigDecimal(input1.asInstanceOf[Short]).setScale(_scale, mode).toShort
case IntegerType if ansiEnabled =>
MathUtils.withOverflow(
f = BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, mode).toIntExact,
context = getContextOrNull)
case IntegerType =>
BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, mode).toInt
case LongType =>
@@ -1584,9 +1600,19 @@ abstract class RoundBase(child: Expression, scale: Expression,
}
case IntegerType =>
if (_scale < 0) {
s"""
${ev.value} = new java.math.BigDecimal(${ce.value}).
setScale(${_scale}, java.math.BigDecimal.${modeStr}).intValue();"""
if (ansiEnabled) {
val errorContext = getContextOrNullCode(ctx)
val evalCode = s"""
|${ev.value} = new java.math.BigDecimal(${ce.value}).
|setScale(${_scale}, java.math.BigDecimal.${modeStr}).intValueExact();
|""".stripMargin
MathUtils.withOverflowCode(evalCode, errorContext)
} else {
s"""
|${ev.value} = new java.math.BigDecimal(${ce.value}).
|setScale(${_scale}, java.math.BigDecimal.${modeStr}).intValue();
|""".stripMargin
}
} else {
s"${ev.value} = ${ce.value};"
}
@@ -1648,9 +1674,17 @@ abstract class RoundBase(child: Expression, scale: Expression,
since = "1.5.0",
group = "math_funcs")
// scalastyle:on line.size.limit
case class Round(child: Expression, scale: Expression)
case class Round(
child: Expression,
scale: Expression,
override val ansiEnabled: Boolean = SQLConf.get.ansiEnabled)
extends RoundBase(child, scale, BigDecimal.RoundingMode.HALF_UP, "ROUND_HALF_UP") {
def this(child: Expression) = this(child, Literal(0))
def this(child: Expression) = this(child, Literal(0), SQLConf.get.ansiEnabled)

def this(child: Expression, scale: Expression) = this(child, scale, SQLConf.get.ansiEnabled)

override def flatArguments: Iterator[Any] = Iterator(child, scale)

override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Round =
copy(child = newLeft, scale = newRight)
}
@@ -1673,9 +1707,17 @@ case class Round(child: Expression, scale: Expression)
since = "2.0.0",
group = "math_funcs")
// scalastyle:on line.size.limit
case class BRound(child: Expression, scale: Expression)
case class BRound(
child: Expression,
scale: Expression,
override val ansiEnabled: Boolean = SQLConf.get.ansiEnabled)
extends RoundBase(child, scale, BigDecimal.RoundingMode.HALF_EVEN, "ROUND_HALF_EVEN") {
def this(child: Expression) = this(child, Literal(0))
def this(child: Expression) = this(child, Literal(0), SQLConf.get.ansiEnabled)

def this(child: Expression, scale: Expression) = this(child, scale, SQLConf.get.ansiEnabled)

override def flatArguments: Iterator[Any] = Iterator(child, scale)

override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): BRound = copy(child = newLeft, scale = newRight)
}
@@ -75,7 +75,7 @@ object MathUtils {

def floorMod(a: Long, b: Long): Long = withOverflow(Math.floorMod(a, b))

private def withOverflow[A](
def withOverflow[A](
f: => A,
hint: String = "",
context: SQLQueryContext = null): A = {
@@ -86,4 +86,14 @@
throw QueryExecutionErrors.arithmeticOverflowError(e.getMessage, hint, context)
}
}

def withOverflowCode(evalCode: String, context: String): String = {
s"""
|try {
| $evalCode
|} catch (ArithmeticException e) {
| throw QueryExecutionErrors.arithmeticOverflowError(e.getMessage(), "", $context);
|}
|""".stripMargin
}
}
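
As background (not part of the patch), a small standalone Scala sketch of why the ANSI code path switches from `toInt` to `toIntExact`: the latter throws an `ArithmeticException` on overflow, which `MathUtils.withOverflow` then rewraps as a Spark arithmetic overflow error carrying the query context.

```scala
// Standalone sketch (not from the patch): why the ANSI code path uses toIntExact.
import scala.math.BigDecimal.RoundingMode

// round(2147483647, -1) first rounds to the nearest ten: 2147483650,
// which no longer fits in an Int.
val rounded = BigDecimal(2147483647).setScale(-1, RoundingMode.HALF_UP)

// Non-ANSI path: toInt narrows silently and wraps to a surprising value.
println(rounded.toInt)       // -2147483646

// ANSI path: toIntExact throws java.lang.ArithmeticException on overflow;
// MathUtils.withOverflow rewraps it as a SparkArithmeticException
// (error class ARITHMETIC_OVERFLOW) that carries the SQL query context.
println(rounded.toIntExact)  // throws ArithmeticException
```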
@@ -23,7 +23,7 @@ import java.time.temporal.ChronoUnit

import com.google.common.math.LongMath

import org.apache.spark.SparkFunSuite
import org.apache.spark.{SparkArithmeticException, SparkFunSuite}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCoercion.implicitCast
import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -838,6 +838,19 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(checkDataTypeAndCast(RoundCeil(Literal(135.135), Literal(-2))), Decimal(200))
}

test("SPARK-42045: integer overflow in round/bround") {
val input = 2147483647
val scale = -1
Seq(Round(input, scale, ansiEnabled = true),
BRound(input, scale, ansiEnabled = true)).foreach { expr =>
checkExceptionInExpression[SparkArithmeticException](expr, "Overflow")
}
Seq(Round(input, scale, ansiEnabled = false),
BRound(input, scale, ansiEnabled = false)).foreach { expr =>
checkEvaluation(expr, -2147483646)
}
}

test("SPARK-36922: Support ANSI intervals for SIGN/SIGNUM") {
checkEvaluation(Signum(Literal(Period.ZERO)), 0.0)
checkEvaluation(Signum(Literal(Period.ofYears(10))), 1.0)
@@ -48,7 +48,7 @@ class PhysicalAggregationSuite extends PlanTest {

// Verify that Round's scale parameter is a Literal.
resultExpressions(1) match {
case Alias(Round(_, _: Literal), _) =>
case Alias(Round(_, _: Literal, _), _) =>
case other => fail("unexpected result expression: " + other)
}
}
1 change: 1 addition & 0 deletions sql/core/src/test/resources/sql-tests/inputs/ansi/math.sql
@@ -0,0 +1 @@
--IMPORT math.sql
17 changes: 17 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/math.sql
@@ -0,0 +1,17 @@
-- Round with integer input
SELECT round(525, 1);
SELECT round(525, 0);
SELECT round(525, -1);
SELECT round(525, -2);
SELECT round(525, -3);
SELECT round(2147483647, -1);
SELECT round(-2147483647, -1);

-- BRound with integer input
SELECT bround(525, 1);
SELECT bround(525, 0);
SELECT bround(525, -1);
SELECT bround(525, -2);
SELECT bround(525, -3);
SELECT bround(2147483647, -1);
SELECT bround(-2147483647, -1);
175 changes: 175 additions & 0 deletions sql/core/src/test/resources/sql-tests/results/ansi/math.sql.out
@@ -0,0 +1,175 @@
-- Automatically generated by SQLQueryTestSuite
-- !query
SELECT round(525, 1)
-- !query schema
struct<round(525, 1):int>
-- !query output
525


-- !query
SELECT round(525, 0)
-- !query schema
struct<round(525, 0):int>
-- !query output
525


-- !query
SELECT round(525, -1)
-- !query schema
struct<round(525, -1):int>
-- !query output
530


-- !query
SELECT round(525, -2)
-- !query schema
struct<round(525, -2):int>
-- !query output
500


-- !query
SELECT round(525, -3)
-- !query schema
struct<round(525, -3):int>
-- !query output
1000


-- !query
SELECT round(2147483647, -1)
-- !query schema
struct<>
-- !query output
org.apache.spark.SparkArithmeticException
{
"errorClass" : "ARITHMETIC_OVERFLOW",
"sqlState" : "22003",
"messageParameters" : {
"alternative" : "",
"config" : "\"spark.sql.ansi.enabled\"",
"message" : "Overflow"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 8,
"stopIndex" : 28,
"fragment" : "round(2147483647, -1)"
} ]
}


-- !query
SELECT round(-2147483647, -1)
-- !query schema
struct<>
-- !query output
org.apache.spark.SparkArithmeticException
{
"errorClass" : "ARITHMETIC_OVERFLOW",
"sqlState" : "22003",
"messageParameters" : {
"alternative" : "",
"config" : "\"spark.sql.ansi.enabled\"",
"message" : "Overflow"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 8,
"stopIndex" : 29,
"fragment" : "round(-2147483647, -1)"
} ]
}


-- !query
SELECT bround(525, 1)
-- !query schema
struct<bround(525, 1):int>
-- !query output
525


-- !query
SELECT bround(525, 0)
-- !query schema
struct<bround(525, 0):int>
-- !query output
525


-- !query
SELECT bround(525, -1)
-- !query schema
struct<bround(525, -1):int>
-- !query output
520


-- !query
SELECT bround(525, -2)
-- !query schema
struct<bround(525, -2):int>
-- !query output
500


-- !query
SELECT bround(525, -3)
-- !query schema
struct<bround(525, -3):int>
-- !query output
1000


-- !query
SELECT bround(2147483647, -1)
-- !query schema
struct<>
-- !query output
org.apache.spark.SparkArithmeticException
{
"errorClass" : "ARITHMETIC_OVERFLOW",
"sqlState" : "22003",
"messageParameters" : {
"alternative" : "",
"config" : "\"spark.sql.ansi.enabled\"",
"message" : "Overflow"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 8,
"stopIndex" : 29,
"fragment" : "bround(2147483647, -1)"
} ]
}


-- !query
SELECT bround(-2147483647, -1)
-- !query schema
struct<>
-- !query output
org.apache.spark.SparkArithmeticException
{
"errorClass" : "ARITHMETIC_OVERFLOW",
"sqlState" : "22003",
"messageParameters" : {
"alternative" : "",
"config" : "\"spark.sql.ansi.enabled\"",
"message" : "Overflow"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 8,
"stopIndex" : 30,
"fragment" : "bround(-2147483647, -1)"
} ]
}