Convert String to DecimalType without casting to FloatType [databricks] #4081

Merged: 9 commits, Nov 18, 2021
@@ -18,6 +18,8 @@ package com.nvidia.spark.rapids

import ai.rapids.cudf.{ColumnVector, ColumnView, DType, Scalar}

import org.apache.spark.sql.types.DecimalType

object FloatUtils extends Arm {

def nanToZero(cv: ColumnView): ColumnVector = {
@@ -40,8 +42,10 @@ object FloatUtils extends Arm {
def getNanScalar(dType: DType): Scalar = {
if (dType == DType.FLOAT64) {
Scalar.fromDouble(Double.NaN)
} else {
} else if (dType == DType.FLOAT32) {
Scalar.fromFloat(Float.NaN)
} else {
throw new IllegalArgumentException("NaNs are only supported for Float types")
}
}

156 changes: 119 additions & 37 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
@@ -79,6 +79,9 @@ class CastExprMeta[INPUT <: CastBase](
"converting floating point data types to strings and this can produce results that " +
"differ from the default behavior in Spark. To enable this operation on the GPU, set" +
s" ${RapidsConf.ENABLE_CAST_FLOAT_TO_STRING} to true.")
case (_: StringType, dt: DecimalType) if (dt.precision + 1 > Decimal.MAX_LONG_DIGITS) =>
willNotWorkOnGpu(s"Converting to $dt will result in a 128-bit temporary type that is " +
s"not supported on the GPU")
case (_: StringType, _: FloatType | _: DoubleType) if !conf.isCastStringToFloatEnabled =>
willNotWorkOnGpu("Currently hex values aren't supported on the GPU. Also note " +
"that casting from string to float types on the GPU returns incorrect results when " +
@@ -97,17 +100,6 @@ class CastExprMeta[INPUT <: CastBase](
YearParseUtil.tagParseStringAsDate(conf, this)
case (_: StringType, _: DateType) =>
YearParseUtil.tagParseStringAsDate(conf, this)
case (_: StringType, _: DecimalType) if !conf.isCastStringToDecimalEnabled =>
// FIXME: https://github.com/NVIDIA/spark-rapids/issues/2019
willNotWorkOnGpu("Currently string to decimal type on the GPU might produce " +
"results which slightly differed from the correct results when the string represents " +
"any number exceeding the max precision that CAST_STRING_TO_FLOAT can keep. For " +
"instance, the GPU returns 99999999999999987 when given the input string " +
"\"99999999999999999\". The cause of divergence is that we can not cast strings " +
"containing scientific notation to decimal directly. So, we have to cast strings " +
"to floats firstly. Then, cast floats to decimals. The first step may lead to " +
"precision loss. To enable this operation on the GPU, set " +
s" ${RapidsConf.ENABLE_CAST_STRING_TO_FLOAT} to true.")
case (structType: StructType, StringType) =>
structType.foreach { field =>
recursiveTagExprForGpuCheck(field.dataType, StringType, depth + 1)
@@ -162,7 +154,60 @@ object GpuCast extends Arm {

val INVALID_FLOAT_CAST_MSG: String = "At least one value is either null or is an invalid number"

def sanitizeStringToFloat(input: ColumnVector, ansiEnabled: Boolean): ColumnVector = {
def sanitizeStringToDecimal(
input: ColumnView,
ansiEnabled: Boolean): ColumnVector = {

// This regex gets applied to filter out known edge cases that would result in incorrect
Collaborator: What are the known edge cases? It would be very nice to know why we have to include this, as it is expensive. Looking at the code it appears that is_fixed_point cuts off early if it sees something that it does not expect, so it might be nice to have a follow-on issue to actually fix that, either in cuDF or in Spark-specific code.

Collaborator (author): Where is it cutting off early? Are you saying that if I pass c = ["", "1.2", "3", ""] and the boolean vector is initialized to true, then d = c.is_fixed_point() = [false, true, true, true], and basically everything after the first value in d is bogus?

Collaborator: Sorry, I was wrong. Reading through the code it looked like the check ignored anything after it saw something it didn't expect, but that is not true. It looks like "1.5ABC" will result in false being returned. If that is true, then I don't think we need the regular expression check at all any more. Is that what triggered this? Why do we need the regexp? What "edge cases" does it cover that are not covered by the existing type-check code?

razajafri (author), Nov 16, 2021: You are right, we don't need the regex check anymore, as cuDF is reporting everything we need. This check is still relevant in the case of a float because it needs to convert "infinity" => "inf".

// values. We further filter out invalid values using the cuDF isFixedPoint method.
val VALID_DEC_REGEX =
"^" + // start of line
"[+\\-]?" + // optional + or - at start of string
"(" +
"(" +
"(" +
"([0-9]+)|" + // digits, OR
"([0-9]*\\.[0-9]+)|" + // decimal with optional leading and mandatory trailing, OR
"([0-9]+\\.[0-9]*)" + // decimal with mandatory leading and optional trailing
")" +
"([eE][+\\-]?[0-9]+)?" + // exponent
")" +
")" +
"$" // end of line

withResource(input.strip()) { stripped =>
withResource(GpuScalar.from(null, DataTypes.StringType)) { nullString =>
// filter out strings containing breaking whitespace
val withoutWhitespace = withResource(ColumnVector.fromStrings("\r", "\n")) {
Collaborator: So in ANSI mode is this not an error? Does the regular expression not match this? It sure looks like the regexp would error out on anything that has any whitespace in it at all.

Collaborator (author): This is a very good point. ANSI doesn't like spaces and throws an ANSI exception. I will file an issue for floats as well.

Collaborator (author): This is actually an unnecessary check, as "\r" is checked as a string, which would be caught by the regex check.

verticalWhitespace =>
withResource(stripped.contains(verticalWhitespace)) {
_.ifElse(nullString, stripped)
}
}
// filter out any strings that are not valid fixed-point numbers according
// to the regex pattern
withResource(withoutWhitespace) { _ =>
withResource(withoutWhitespace.matchesRe(VALID_DEC_REGEX)) { isFixedPoint =>
if (ansiEnabled) {
withResource(isFixedPoint.all()) { allMatch =>
// Check that all non-null values are valid decimals.
if (allMatch.isValid && !allMatch.getBoolean) {
throw new NumberFormatException(GpuCast.INVALID_FLOAT_CAST_MSG)
}
withoutWhitespace.incRefCount()
}
} else {
isFixedPoint.ifElse(withoutWhitespace, nullString)
}
}
}
}
}
}
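For reference, a standalone sketch of what VALID_DEC_REGEX above accepts (plain JDK regex, hypothetical object name; not part of this diff), mirroring the edge cases discussed in the review thread:

```scala
import java.util.regex.Pattern

object ValidDecRegexSketch {
  // Same pattern as VALID_DEC_REGEX above: optional sign, digits with an
  // optional decimal point, and an optional exponent.
  private val validDec = Pattern.compile(
    "^[+\\-]?(((([0-9]+)|([0-9]*\\.[0-9]+)|([0-9]+\\.[0-9]*))([eE][+\\-]?[0-9]+)?))$")

  def main(args: Array[String]): Unit = {
    val samples = Seq("1.2", "+3.", "-.5", "1e10", "1.5ABC", "Infinity", "", " 7 ")
    samples.foreach { s =>
      // "1.5ABC", "Infinity" and the empty string do not match and would be
      // nulled out (or raise in ANSI mode); " 7 " only matches after strip().
      println(s"'$s' -> ${validDec.matcher(s).matches()}")
    }
  }
}
```

As the review discussion above notes, cuDF's isFixedPoint already rejects inputs such as "1.5ABC", so this pattern may be redundant for decimals, while the float path still needs its own check because "Infinity" must first be normalized to "Inf".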

def sanitizeStringToFloat(
input: ColumnVector,
ansiEnabled: Boolean): ColumnVector = {

// This regex gets applied after the transformation to normalize use of Inf and is
// just strict enough to filter out known edge cases that would result in incorrect
@@ -189,24 +234,24 @@ object GpuCast extends Arm {
withResource(GpuScalar.from(null, DataTypes.StringType)) { nullString =>
// filter out strings containing breaking whitespace
val withoutWhitespace = withResource(ColumnVector.fromStrings("\r", "\n")) {
verticalWhitespace =>
withResource(stripped.contains(verticalWhitespace)) {
_.ifElse(nullString, stripped)
}
verticalWhitespace =>
withResource(stripped.contains(verticalWhitespace)) {
_.ifElse(nullString, stripped)
}
}
// replace all possible versions of "Inf" and "Infinity" with "Inf"
val inf = withResource(withoutWhitespace) { _ =>
// replace all possible versions of "Inf" and "Infinity" with "Inf"
val inf = withResource(withoutWhitespace) { _ =>
withoutWhitespace.stringReplaceWithBackrefs(
"(?:[iI][nN][fF])" + "(?:[iI][nN][iI][tT][yY])?", "Inf")
}
// replace "+Inf" with "Inf" because cuDF only supports "Inf" and "-Inf"
val infWithoutPlus = withResource(inf) { _ =>
withResource(GpuScalar.from("+Inf", DataTypes.StringType)) { search =>
withResource(GpuScalar.from("Inf", DataTypes.StringType)) { replace =>
inf.stringReplace(search, replace)
"(?:[iI][nN][fF])" + "(?:[iI][nN][iI][tT][yY])?", "Inf")
}
// replace "+Inf" with "Inf" because cuDF only supports "Inf" and "-Inf"
val infWithoutPlus = withResource(inf) { _ =>
withResource(GpuScalar.from("+Inf", DataTypes.StringType)) { search =>
withResource(GpuScalar.from("Inf", DataTypes.StringType)) { replace =>
inf.stringReplace(search, replace)
}
}
}
}
// filter out any strings that are not valid floating point numbers according
// to the regex pattern
val floatOrNull = withResource(infWithoutPlus) { _ =>
@@ -502,13 +547,8 @@
}
}
case (StringType, dt: DecimalType) =>
// To apply HALF_UP rounding strategy during casting to decimal, we firstly cast
// string to fp64. Then, cast fp64 to target decimal type to enforce HALF_UP rounding.
withResource(input.strip()) { trimmed =>
withResource(castStringToFloats(trimmed, ansiMode, DType.FLOAT64)) { fp =>
castFloatsToDecimal(fp, dt, ansiMode)
}
}
castStringToDecimal(input, ansiMode, dt)

case (ByteType | ShortType | IntegerType | LongType, dt: DecimalType) =>
castIntegralsToDecimal(input, dt, ansiMode)

@@ -777,17 +817,59 @@
}
}

def castStringToDecimal(
input: ColumnView,
ansiEnabled: Boolean,
dt: DecimalType): ColumnVector = {
// 1. Sanitize strings to make sure all are fixed points
// 2. Identify all fixed point values
// 3. Cast String to newDt (newDt = dt.precision + 1, dt.scale + 1). Promote precision if
// needed. This step is required so we can round up if needed in the final step
// 4. Now cast newDt to dt (Decimal to Decimal)
def getInterimDecimalPromoteIfNeeded(dt: DecimalType): DecimalType = {
if (dt.precision + 1 > Decimal.MAX_LONG_DIGITS) {
// We don't support Decimal 128
throw new IllegalArgumentException("One or more values exceed the maximum supported " +
"Decimal precision during conversion")
}
DecimalType(dt.precision + 1, dt.scale + 1)
}

withResource(GpuCast.sanitizeStringToDecimal(input, ansiEnabled)) { sanitized =>
val interimSparkDt = getInterimDecimalPromoteIfNeeded(dt)
val interimDt = DecimalUtil.createCudfDecimal(interimSparkDt)
withResource(sanitized.isFixedPoint(interimDt)) { isFixedPoints =>
if (ansiEnabled) {
withResource(isFixedPoints.all()) { allFixedPoints =>
if (allFixedPoints.isValid && !allFixedPoints.getBoolean) {
throw new ArithmeticException(s"One or more values cannot be " +
s"represented as Decimal(${dt.precision}, ${dt.scale})")
}
}
}
// intermediate step needed so we can make sure we can round up
withResource(input.castTo(interimDt)) { interimDecimals =>
withResource(Scalar.fromNull(interimDt)) { nulls =>
withResource(isFixedPoints.ifElse(interimDecimals, nulls)) { decimals =>
// cast Decimal to the Decimal that's needed
castDecimalToDecimal(decimals, interimSparkDt, dt, ansiEnabled)
}
}
}
}
}
}
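A plain-JVM sketch (java.math.BigDecimal with a hypothetical object name; the real path goes through cudf decimal casts, not BigDecimal) of why the interim type carries one extra digit of precision and scale before the final Decimal-to-Decimal cast:

```scala
import java.math.{BigDecimal => JBigDecimal, RoundingMode}

object InterimDecimalSketch {
  def main(args: Array[String]): Unit = {
    // Target type Decimal(9, 2); interim type Decimal(10, 3), as produced by
    // getInterimDecimalPromoteIfNeeded above.
    val interim = new JBigDecimal("1.235")              // held exactly at the interim scale of 3
    println(interim.setScale(2, RoundingMode.HALF_UP))  // 1.24, the HALF_UP result Spark expects
    println(interim.setScale(2, RoundingMode.DOWN))     // 1.23, what truncating straight to scale 2 would give
  }
}
```

The extra digit of scale preserves the digit that decides the rounding direction, and the extra digit of precision keeps that wider value from overflowing before the final cast, which is also why targets near Decimal.MAX_LONG_DIGITS are rejected up front.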

def castStringToFloats(
input: ColumnVector,
ansiEnabled: Boolean,
dType: DType): ColumnVector = {

// 1. convert the different infinities to "Inf"/"-Inf" which is the only variation cudf
// understands
// 2. identify the nans
// 3. identify the floats. "nan", "null" and letters are not considered floats
// 4. if ansi is enabled we want to throw and exception if the string is neither float nor nan
// 5. convert everything thats not floats to null
// 4. if ansi is enabled we want to throw an exception if the string is neither float nor nan
// 5. convert everything that's not floats to null
// 6. set the indices where we originally had nans to Float.NaN
//
// NOTE Limitation: "1.7976931348623159E308" and "-1.7976931348623159E308" are not considered
@@ -1146,7 +1228,7 @@ object GpuCast extends Arm {
// We rely on containerDecimal to perform more precise rounding. So, we have to take the extra
// space cost of the container into consideration when we run the bound check.
val containerScaleBound = DType.DECIMAL64_MAX_PRECISION - (dt.scale + 1)
val bound = math.pow(10, (dt.precision - dt.scale) min containerScaleBound)
val bound = math.pow(10, (dt.precision - dt.scale).min(containerScaleBound))
if (ansiMode) {
assertValuesInRange(rounded,
minValue = Scalar.fromDouble(-bound),
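A small sketch of the bound computed in the hunk above for the float-to-decimal range check (hypothetical object name; assumes cudf's DECIMAL64_MAX_PRECISION is 18):

```scala
object FloatToDecimalBoundSketch {
  private val Decimal64MaxPrecision = 18 // assumed value of cudf DType.DECIMAL64_MAX_PRECISION

  def bound(precision: Int, scale: Int): Double = {
    // One digit of scale is reserved for the rounding container, so the
    // container can hold at most 18 - (scale + 1) integer digits.
    val containerScaleBound = Decimal64MaxPrecision - (scale + 1)
    math.pow(10, (precision - scale).min(containerScaleBound))
  }

  def main(args: Array[String]): Unit = {
    // For DecimalType(10, 2) the integer part holds at most 8 digits, so any
    // double with |x| >= 1e8 is out of range (nulled, or an error in ANSI mode).
    println(bound(10, 2)) // 1.0E8
  }
}
```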
26 changes: 17 additions & 9 deletions tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala
@@ -875,20 +875,28 @@ class CastOpSuite extends GpuExpressionTestSuite {
}

test("cast string to decimal") {
List(-18, -10, -3, 0, 1, 5, 15).foreach { scale =>
testCastToDecimal(DataTypes.StringType, scale,
List(-17, -10, -3, 0, 1, 5, 15).foreach { scale =>
testCastToDecimal(DataTypes.StringType, scale, precision = 17,
customRandGenerator = Some(new scala.util.Random(1234L)))
}
}

test("cast string to decimal (fail)") {
assertThrows[IllegalArgumentException](
List(-18, 18, 2, 32, 8).foreach { scale =>
testCastToDecimal(DataTypes.StringType, scale,
customRandGenerator = Some(new scala.util.Random(1234L)))
})
}

test("cast string to decimal (include NaN/INF/-INF)") {
def doubleStrings(ss: SparkSession): DataFrame = {
val df1 = floatsAsStrings(ss).selectExpr("cast(c0 as Double) as col")
val df2 = doublesAsStrings(ss).select(col("c0").as("col"))
df1.unionAll(df2)
}
List(-10, -1, 0, 1, 10).foreach { scale =>
testCastToDecimal(DataTypes.StringType, scale = scale,
testCastToDecimal(DataTypes.StringType, scale = scale, precision = 17,
customDataGenerator = Some(doubleStrings))
}
}
@@ -898,15 +906,15 @@ class CastOpSuite extends GpuExpressionTestSuite {
import ss.sqlContext.implicits._
column.toDF("col")
}
testCastToDecimal(DataTypes.StringType, scale = 7,
testCastToDecimal(DataTypes.StringType, scale = 7, precision = 17,
customDataGenerator = Some(specialGenerator(Seq("9999999999"))))
testCastToDecimal(DataTypes.StringType, scale = 2,
testCastToDecimal(DataTypes.StringType, scale = 2, precision = 17,
customDataGenerator = Some(specialGenerator(Seq("999999999999999"))))
testCastToDecimal(DataTypes.StringType, scale = 0,
testCastToDecimal(DataTypes.StringType, scale = 0, precision = 17,
customDataGenerator = Some(specialGenerator(Seq("99999999999999999"))))
testCastToDecimal(DataTypes.StringType, scale = -1,
testCastToDecimal(DataTypes.StringType, scale = -1, precision = 17,
customDataGenerator = Some(specialGenerator(Seq("99999999999999999"))))
testCastToDecimal(DataTypes.StringType, scale = -10,
testCastToDecimal(DataTypes.StringType, scale = -10, precision = 17,
customDataGenerator = Some(specialGenerator(Seq("99999999999999999"))))
}

@@ -915,7 +923,7 @@ class CastOpSuite extends GpuExpressionTestSuite {
exponentsAsStringsDf(ss).select(col("c0").as("col"))
}
List(-10, -1, 0, 1, 10).foreach { scale =>
testCastToDecimal(DataTypes.StringType, scale = scale,
testCastToDecimal(DataTypes.StringType, scale = scale, precision = 17,
customDataGenerator = Some(exponentsAsStrings),
ansiEnabled = true)
}