Support Float order-by columns for RANGE window functions [databricks] #8637

Merged: 3 commits, Jul 10, 2023
2 changes: 2 additions & 0 deletions docs/additional-functionality/advanced_configs.md
@@ -143,6 +143,8 @@ Name | Description | Default Value | Applicable at
<a name="sql.variableFloatAgg.enabled"></a>spark.rapids.sql.variableFloatAgg.enabled|Spark assumes that all operations produce the exact same result each time. This is not true for some floating point aggregations, which can produce slightly different results on the GPU as the aggregation is done in parallel. This can enable those operations if you know the query is only computing it once.|true|Runtime
<a name="sql.window.range.byte.enabled"></a>spark.rapids.sql.window.range.byte.enabled|When the order-by column of a range based window is byte type and the range boundary calculated for a value has overflow, CPU and GPU will get the different results. When set to false disables the range window acceleration for the byte type order-by column|false|Runtime
<a name="sql.window.range.decimal.enabled"></a>spark.rapids.sql.window.range.decimal.enabled|When set to false, this disables the range window acceleration for the DECIMAL type order-by column|true|Runtime
<a name="sql.window.range.double.enabled"></a>spark.rapids.sql.window.range.double.enabled|When set to false, this disables the range window acceleration for the double type order-by column|true|Runtime
<a name="sql.window.range.float.enabled"></a>spark.rapids.sql.window.range.float.enabled|When set to false, this disables the range window acceleration for the FLOAT type order-by column|true|Runtime
<a name="sql.window.range.int.enabled"></a>spark.rapids.sql.window.range.int.enabled|When the order-by column of a range based window is int type and the range boundary calculated for a value has overflow, CPU and GPU will get the different results. When set to false disables the range window acceleration for the int type order-by column|true|Runtime
<a name="sql.window.range.long.enabled"></a>spark.rapids.sql.window.range.long.enabled|When the order-by column of a range based window is long type and the range boundary calculated for a value has overflow, CPU and GPU will get the different results. When set to false disables the range window acceleration for the long type order-by column|true|Runtime
<a name="sql.window.range.short.enabled"></a>spark.rapids.sql.window.range.short.enabled|When the order-by column of a range based window is short type and the range boundary calculated for a value has overflow, CPU and GPU will get the different results. When set to false disables the range window acceleration for the short type order-by column|false|Runtime
8 changes: 4 additions & 4 deletions docs/supported_ops.md
@@ -12859,8 +12859,8 @@ are limited.
<td>S</td>
<td>S</td>
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td>S</td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
@@ -12880,8 +12880,8 @@
<td>S</td>
<td>S</td>
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td>S</td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
24 changes: 18 additions & 6 deletions integration_tests/src/main/python/window_function_test.py
@@ -21,7 +21,7 @@
from pyspark.sql.types import NumericType
from pyspark.sql.window import Window
import pyspark.sql.functions as f
from spark_session import is_before_spark_320, is_before_spark_340, is_databricks113_or_later
from spark_session import is_before_spark_320, is_databricks113_or_later
import warnings

_grpkey_longs_with_no_nulls = [
@@ -74,6 +74,16 @@
('b', DecimalGen(precision=38, scale=2, nullable=True)),
('c', DecimalGen(precision=38, scale=2, nullable=True))]

_grpkey_longs_with_nullable_floats = [
('a', RepeatSeqGen(LongGen(nullable=(True, 10.0)), length=20)),
('b', FloatGen(nullable=True)),
('c', IntegerGen(nullable=True))]

_grpkey_longs_with_nullable_doubles = [
('a', RepeatSeqGen(LongGen(nullable=(True, 10.0)), length=20)),
('b', DoubleGen(nullable=True)),
('c', IntegerGen(nullable=True))]

_grpkey_decimals_with_nulls = [
('a', RepeatSeqGen(LongGen(nullable=(True, 10.0)), length=20)),
('b', IntegerGen()),
@@ -879,15 +889,17 @@ def test_window_aggs_for_ranges_timestamps(data_gen):
pytest.param(_grpkey_longs_with_nullable_largest_decimals,
marks=pytest.mark.xfail(
condition=is_databricks113_or_later(),
reason='https://github.com/NVIDIA/spark-rapids/issues/7429'))
reason='https://github.com/NVIDIA/spark-rapids/issues/7429')),
_grpkey_longs_with_nullable_floats,
_grpkey_longs_with_nullable_doubles
], ids=idfn)
def test_window_aggregations_for_decimal_ranges(data_gen):
def test_window_aggregations_for_decimal_and_float_ranges(data_gen):
"""
Tests for range window aggregations, with DECIMAL order by columns.
Tests for range window aggregations, with DECIMAL/FLOATING POINT order by columns.
The table schema used:
a: Group By column
b: Order By column (decimal)
c: Aggregation column (incidentally, also decimal)
b: Order By column (decimals, floats, doubles)
c: Aggregation column (decimals or ints)

Since this test is for the order-by column type, and not for each specific windowing aggregation,
we use COUNT(1) throughout the test, for different window widths and ordering.
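The test itself is PySpark; purely as an illustrative sketch (not the test code), an equivalent window for a DOUBLE order-by column `b` might look like this in Scala, assuming a DataFrame `df` with the a/b/c schema described above:

```scala
import org.apache.spark.sql.DataFrame

// Illustrative only: COUNT(1) over a value-based RANGE frame, grouped by `a`
// and ordered by the floating-point column `b`, as exercised by the test above.
def countOverRange(df: DataFrame): DataFrame = {
  df.createOrReplaceTempView("window_test_table")
  df.sparkSession.sql(
    """
      |SELECT a, b, c,
      |       COUNT(1) OVER (PARTITION BY a
      |                      ORDER BY b ASC NULLS FIRST
      |                      RANGE BETWEEN CAST(2.5 AS DOUBLE) PRECEDING
      |                                AND CAST(2.5 AS DOUBLE) FOLLOWING) AS cnt
      |FROM window_test_table
      |""".stripMargin)
}
```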
@@ -943,10 +943,12 @@ object GpuOverrides extends Logging {
TypeSig.numericAndInterval,
Seq(
ParamCheck("lower",
TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral + TypeSig.DECIMAL_128,
TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral + TypeSig.DECIMAL_128 +
TypeSig.FLOAT + TypeSig.DOUBLE,
TypeSig.numericAndInterval),
ParamCheck("upper",
TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral + TypeSig.DECIMAL_128,
TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral + TypeSig.DECIMAL_128 +
TypeSig.FLOAT + TypeSig.DOUBLE,
TypeSig.numericAndInterval))),
(windowFrame, conf, p, r) => new GpuSpecifiedWindowFrameMeta(windowFrame, conf, p, r) ),
expr[WindowSpecDefinition](
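As a rough conceptual sketch only (not the plugin's actual `TypeSig` API), the effect of widening the `lower`/`upper` parameter checks above is that FLOAT and DOUBLE boundary literals are now accepted alongside the integral, DECIMAL, NULL and calendar-interval types:

```scala
// Hypothetical, simplified stand-in for the ParamCheck above, not the real TypeSig classes:
// the Spark types a range-frame boundary literal may have on the GPU after this change.
import org.apache.spark.sql.types._

val supportedBoundaryTypes: Set[DataType] =
  Set(NullType, ByteType, ShortType, IntegerType, LongType,
      FloatType, DoubleType, CalendarIntervalType)

def boundarySupported(dt: DataType): Boolean = dt match {
  case _: DecimalType => true                         // stands in for TypeSig.DECIMAL_128
  case other          => supportedBoundaryTypes.contains(other)
}
```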
@@ -636,7 +636,47 @@ case class BoundGpuWindowFunction(
val dataType: DataType = windowFunc.dataType
}

case class ParsedBoundary(isUnbounded: Boolean, value: Either[BigInt, Long])
/**
* Abstraction for possible range-boundary specifications.
*
* This provides type disjunction for Long, BigInt and Double,
* the three types that might represent a range boundary.
*/
abstract class RangeBoundaryValue {
def long: Long = RangeBoundaryValue.long(this)
def bigInt: BigInt = RangeBoundaryValue.bigInt(this)
def double: Double = RangeBoundaryValue.double(this)
}

case class LongRangeBoundaryValue(value: Long) extends RangeBoundaryValue
case class BigIntRangeBoundaryValue(value: BigInt) extends RangeBoundaryValue
case class DoubleRangeBoundaryValue(value: Double) extends RangeBoundaryValue

object RangeBoundaryValue {

def long(boundary: RangeBoundaryValue): Long = boundary match {
case LongRangeBoundaryValue(l) => l
case other => throw new NoSuchElementException(s"Cannot get `long` from $other")
}

def bigInt(boundary: RangeBoundaryValue): BigInt = boundary match {
case BigIntRangeBoundaryValue(b) => b
case other => throw new NoSuchElementException(s"Cannot get `bigInt` from $other")
}

def double(boundary: RangeBoundaryValue): Double = boundary match {
case DoubleRangeBoundaryValue(d) => d
case other => throw new NoSuchElementException(s"Cannot get `double` from $other")
}

def long(value: Long): LongRangeBoundaryValue = LongRangeBoundaryValue(value)

def bigInt(value: BigInt): BigIntRangeBoundaryValue = BigIntRangeBoundaryValue(value)

def double(value: Double): DoubleRangeBoundaryValue = DoubleRangeBoundaryValue(value)
}

case class ParsedBoundary(isUnbounded: Boolean, value: RangeBoundaryValue)
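
A minimal standalone usage sketch of the new disjunction (not taken from the PR): each accessor succeeds only for the matching concrete subtype and fails fast otherwise.

```scala
// Sketch: constructing boundary values and reading them back.
val longBound   = RangeBoundaryValue.long(100L)
val doubleBound = RangeBoundaryValue.double(12.5d)
val decBound    = RangeBoundaryValue.bigInt(BigInt("12345678901234567890"))

assert(longBound.long == 100L)
assert(doubleBound.double == 12.5d)
assert(decBound.bigInt == BigInt("12345678901234567890"))

// Asking a Double boundary for a Long is a programming error and throws NoSuchElementException:
// doubleBound.long
```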

object GroupedAggregations {
/**
@@ -754,22 +794,23 @@ object GroupedAggregations {
if (bound.isUnbounded) {
None
} else {
val valueLong = bound.value.right // Used for all cases except DECIMAL128.
val s = orderByType match {
case DType.INT8 => Scalar.fromByte(valueLong.get.toByte)
case DType.INT16 => Scalar.fromShort(valueLong.get.toShort)
case DType.INT32 => Scalar.fromInt(valueLong.get.toInt)
case DType.INT64 => Scalar.fromLong(valueLong.get)
case DType.INT8 => Scalar.fromByte(bound.value.long.toByte)
case DType.INT16 => Scalar.fromShort(bound.value.long.toShort)
case DType.INT32 => Scalar.fromInt(bound.value.long.toInt)
case DType.INT64 => Scalar.fromLong(bound.value.long)
case DType.FLOAT32 => Scalar.fromFloat(bound.value.double.toFloat)
case DType.FLOAT64 => Scalar.fromDouble(bound.value.double)
// Interval is not working for DateType
case DType.TIMESTAMP_DAYS => Scalar.durationFromLong(DType.DURATION_DAYS, valueLong.get)
case DType.TIMESTAMP_DAYS => Scalar.durationFromLong(DType.DURATION_DAYS, bound.value.long)
case DType.TIMESTAMP_MICROSECONDS =>
Scalar.durationFromLong(DType.DURATION_MICROSECONDS, valueLong.get)
Scalar.durationFromLong(DType.DURATION_MICROSECONDS, bound.value.long)
case x if x.getTypeId == DType.DTypeEnum.DECIMAL32 =>
Scalar.fromDecimal(x.getScale, valueLong.get.toInt)
Scalar.fromDecimal(x.getScale, bound.value.long.toInt)
case x if x.getTypeId == DType.DTypeEnum.DECIMAL64 =>
Scalar.fromDecimal(x.getScale, valueLong.get)
Scalar.fromDecimal(x.getScale, bound.value.long)
case x if x.getTypeId == DType.DTypeEnum.DECIMAL128 =>
Scalar.fromDecimal(x.getScale, bound.value.left.get.underlying())
Scalar.fromDecimal(x.getScale, bound.value.bigInt.underlying())
case x if x.getTypeId == DType.DTypeEnum.STRING =>
// Not UNBOUNDED. The only other supported boundary for String is CURRENT ROW, i.e. 0.
Scalar.fromString("")
@@ -782,36 +823,52 @@
private def getRangeBoundaryValue(boundary: Expression, orderByType: DType): ParsedBoundary =
boundary match {
case special: GpuSpecialFrameBoundary =>
val isUnBounded = special.isUnbounded
val isDecimal128 = orderByType.getTypeId == DType.DTypeEnum.DECIMAL128
ParsedBoundary(isUnBounded, if (isDecimal128) Left(special.value) else Right(special.value))
ParsedBoundary(
isUnbounded = special.isUnbounded,
value = orderByType.getTypeId match {
case DType.DTypeEnum.DECIMAL128 => RangeBoundaryValue.bigInt(special.value)
case DType.DTypeEnum.FLOAT32 | DType.DTypeEnum.FLOAT64 =>
RangeBoundaryValue.double(special.value)
case _ => RangeBoundaryValue.long(special.value)
}
)
case GpuLiteral(ci: CalendarInterval, CalendarIntervalType) =>
// Get the total microseconds for TIMESTAMP_MICROSECONDS
var x = TimeUnit.DAYS.toMicros(ci.days) + ci.microseconds
if (x == Long.MinValue) x = Long.MaxValue
ParsedBoundary(isUnbounded = false, Right(Math.abs(x)))
ParsedBoundary(isUnbounded = false, RangeBoundaryValue.long(Math.abs(x)))
case GpuLiteral(value, ByteType) =>
var x = value.asInstanceOf[Byte]
if (x == Byte.MinValue) x = Byte.MaxValue
ParsedBoundary(isUnbounded = false, Right(Math.abs(x)))
ParsedBoundary(isUnbounded = false, RangeBoundaryValue.long(Math.abs(x)))
case GpuLiteral(value, ShortType) =>
var x = value.asInstanceOf[Short]
if (x == Short.MinValue) x = Short.MaxValue
ParsedBoundary(isUnbounded = false, Right(Math.abs(x)))
ParsedBoundary(isUnbounded = false, RangeBoundaryValue.long(Math.abs(x)))
case GpuLiteral(value, IntegerType) =>
var x = value.asInstanceOf[Int]
if (x == Int.MinValue) x = Int.MaxValue
ParsedBoundary(isUnbounded = false, Right(Math.abs(x)))
ParsedBoundary(isUnbounded = false, RangeBoundaryValue.long(Math.abs(x)))
case GpuLiteral(value, LongType) =>
var x = value.asInstanceOf[Long]
if (x == Long.MinValue) x = Long.MaxValue
ParsedBoundary(isUnbounded = false, Right(Math.abs(x)))
ParsedBoundary(isUnbounded = false, RangeBoundaryValue.long(Math.abs(x)))
case GpuLiteral(value, FloatType) =>
var x = value.asInstanceOf[Float]
if (x == Float.MinValue) x = Float.MaxValue
ParsedBoundary(isUnbounded = false, RangeBoundaryValue.double(Math.abs(x)))
case GpuLiteral(value, DoubleType) =>
var x = value.asInstanceOf[Double]
if (x == Double.MinValue) x = Double.MaxValue
ParsedBoundary(isUnbounded = false, RangeBoundaryValue.double(Math.abs(x)))
case GpuLiteral(value: Decimal, DecimalType()) =>
orderByType.getTypeId match {
case DType.DTypeEnum.DECIMAL32 | DType.DTypeEnum.DECIMAL64 =>
ParsedBoundary(isUnbounded = false, Right(Math.abs(value.toUnscaledLong)))
ParsedBoundary(isUnbounded = false,
RangeBoundaryValue.long(Math.abs(value.toUnscaledLong)))
case DType.DTypeEnum.DECIMAL128 =>
ParsedBoundary(isUnbounded = false, Left(value.toJavaBigDecimal.unscaledValue().abs))
ParsedBoundary(isUnbounded = false,
RangeBoundaryValue.bigInt(value.toJavaBigDecimal.unscaledValue().abs))
case anythingElse =>
throw new UnsupportedOperationException(s"Unexpected Decimal type: $anythingElse")
}
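
A small standalone aside on the `MinValue` clamping in `getRangeBoundaryValue` (my reading, not stated in the PR): for two's-complement integers, `Math.abs(MinValue)` overflows and stays negative, so clamping to `MaxValue` keeps the boundary usable; for the new float/double cases there is no such overflow, since Scala's `Float.MinValue`/`Double.MinValue` are simply the negated maxima, so the clamp there mainly mirrors the existing pattern.

```scala
object BoundaryClampNote {
  def main(args: Array[String]): Unit = {
    // Integral overflow: |Long.MinValue| is not representable, so abs returns MinValue itself.
    println(Math.abs(Long.MinValue) == Long.MinValue)      // true
    // Floating point: Scala's Double.MinValue is -Double.MaxValue, so abs is already safe.
    println(Math.abs(Double.MinValue) == Double.MaxValue)  // true
    println(Math.abs(Float.MinValue) == Float.MaxValue)    // true
  }
}
```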
@@ -101,12 +101,12 @@ abstract class GpuWindowExpressionMetaBase(
val orderSpec = wrapped.windowSpec.orderSpec
if (orderSpec.length > 1) {
// We only support a single order by column
willNotWorkOnGpu("only a single date/time or integral (Boolean exclusive)" +
willNotWorkOnGpu("only a single date/time or numeric (Boolean exclusive) " +
"based column in window range functions is supported")
}
val orderByTypeSupported = orderSpec.forall { so =>
so.dataType match {
case ByteType | ShortType | IntegerType | LongType |
case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType |
DateType | TimestampType | StringType | DecimalType() => true
case _ => false
}
@@ -134,6 +134,12 @@
s"Range window frame is not 100% compatible when the order by type is " +
s"long and the range value calculated has overflow. " +
s"To enable it please set ${RapidsConf.ENABLE_RANGE_WINDOW_LONG} to true.")
case FloatType => if (!conf.isRangeWindowFloatEnabled) willNotWorkOnGpu(
s"Range window frame is currently disabled when the order by type is float. " +
s"To enable it please set ${RapidsConf.ENABLE_RANGE_WINDOW_FLOAT} to true.")
case DoubleType => if (!conf.isRangeWindowDoubleEnabled) willNotWorkOnGpu(
s"Range window frame is currently disabled when the order by type is double. " +
s"To enable it please set ${RapidsConf.ENABLE_RANGE_WINDOW_DOUBLE} to true.")
case DecimalType() => if (!conf.isRangeWindowDecimalEnabled) willNotWorkOnGpu(
s"To enable DECIMAL order by columns with Range window frames, " +
s"please set ${RapidsConf.ENABLE_RANGE_WINDOW_DECIMAL} to true.")
@@ -144,7 +150,7 @@ abstract class GpuSpecifiedWindowFrameMetaBase(
// check whether the boundaries are supported or not.
Seq(spec.lower, spec.upper).foreach {
case l @ Literal(_, ByteType | ShortType | IntegerType |
LongType | DecimalType()) =>
LongType | FloatType | DoubleType | DecimalType()) =>
checkRangeBoundaryConfig(l.dataType)
case Literal(ci: CalendarInterval, CalendarIntervalType) =>
// interval is only working for TimeStampType
@@ -356,7 +362,7 @@ abstract class GpuSpecifiedWindowFrameMetaBase(
* Tag RangeFrame for other types and get the value
*/
def getAndTagOtherTypesForRangeFrame(bounds : Expression, isLower : Boolean): Long = {
willNotWorkOnGpu(s"Bounds for Range-based window frames must be specified in Integral" +
willNotWorkOnGpu(s"Bounds for Range-based window frames must be specified in numeric" +
s" type (Boolean exclusive) or CalendarInterval. Found ${bounds.dataType}")
if (isLower) -1 else 1 // not check again
}
@@ -377,36 +383,58 @@
return None
}

val value: BigInt = bounds match {
case Literal(value, ByteType) => value.asInstanceOf[Byte].toLong
case Literal(value, ShortType) => value.asInstanceOf[Short].toLong
case Literal(value, IntegerType) => value.asInstanceOf[Int].toLong
case Literal(value, LongType) => value.asInstanceOf[Long]
case Literal(value: Decimal, DecimalType()) => value.toJavaBigDecimal.unscaledValue()
/**
* Check bounds value relative to current row:
* 1. lower-bound should not be ahead of the current row.
* 2. upper-bound should not be behind the current row.
*/
def checkBounds[T](boundsValue: T)
(implicit ev: Numeric[T]): Option[String] = {
if (isLower && ev.compare(boundsValue, ev.zero) > 0) {
Some(s"Lower-bounds ahead of current row is not supported. Found: $boundsValue")
}
else if (!isLower && ev.compare(boundsValue, ev.zero) < 0) {
Some(s"Upper-bounds behind current row is not supported. Found: $boundsValue")
}
else {
None
}
}

bounds match {
case Literal(value, ByteType) =>
checkBounds(value.asInstanceOf[Byte].toLong)
case Literal(value, ShortType) =>
checkBounds(value.asInstanceOf[Short].toLong)
case Literal(value, IntegerType) =>
checkBounds(value.asInstanceOf[Int].toLong)
case Literal(value, LongType) =>
checkBounds(value.asInstanceOf[Long])
case Literal(value, FloatType) =>
checkBounds(value.asInstanceOf[Float])
case Literal(value, DoubleType) =>
checkBounds(value.asInstanceOf[Double])
case Literal(value: Decimal, DecimalType()) =>
checkBounds(BigInt(value.toJavaBigDecimal.unscaledValue()))
case Literal(ci: CalendarInterval, CalendarIntervalType) =>
if (ci.months != 0) {
willNotWorkOnGpu("interval months isn't supported")
}
// return the total microseconds
try {
Math.addExact(
Math.multiplyExact(ci.days.toLong, TimeUnit.DAYS.toMicros(1)),
ci.microseconds)
checkBounds(
Math.addExact(
Math.multiplyExact(ci.days.toLong, TimeUnit.DAYS.toMicros(1)),
ci.microseconds))
} catch {
case _: ArithmeticException =>
willNotWorkOnGpu("windows over timestamps are converted to microseconds " +
s"and $ci is too large to fit")
if (isLower) -1 else 1 // not check again
None
}
case _ => getAndTagOtherTypesForRangeFrame(bounds, isLower)
}

if (isLower && value > 0) {
Some(s"Lower-bounds ahead of current row is not supported. Found: $value")
} else if (!isLower && value < 0) {
Some(s"Upper-bounds behind current row is not supported. Found: $value")
} else {
None
case _ =>
getAndTagOtherTypesForRangeFrame(bounds, isLower)
None
}
}
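
To make the generic `checkBounds` above concrete, here is a standalone simplified sketch of the same idea (illustrative signature only): a single `Numeric[T]`-based comparison covers Long, Float, Double and BigInt bounds alike.

```scala
object BoundsCheckSketch {
  // Simplified stand-in for checkBounds: a lower bound must sit at or before the current
  // row (<= 0) and an upper bound at or after it (>= 0); otherwise return a reason string.
  def check[T](bound: T, isLower: Boolean)(implicit ev: Numeric[T]): Option[String] = {
    if (isLower && ev.compare(bound, ev.zero) > 0) {
      Some(s"Lower-bound ahead of current row is not supported. Found: $bound")
    } else if (!isLower && ev.compare(bound, ev.zero) < 0) {
      Some(s"Upper-bound behind current row is not supported. Found: $bound")
    } else {
      None
    }
  }

  def main(args: Array[String]): Unit = {
    println(check(-2.5d, isLower = true))       // None: valid lower bound
    println(check(BigInt(3), isLower = false))  // None: valid upper bound
    println(check(1.0f, isLower = true))        // Some(...): lower bound ahead of current row
  }
}
```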
