Add Std dev for windowing
Signed-off-by: Raza Jafri <[email protected]>
razajafri committed Oct 20, 2021
1 parent 40b35b2 commit f9ca4fb
Showing 2 changed files with 29 additions and 1 deletion.
17 changes: 17 additions & 0 deletions integration_tests/src/main/python/window_function_test.py
@@ -606,6 +606,7 @@ def do_it(spark):
.withColumn('inc_min_c', f.min('c').over(inclusiveWindowSpec)) \
.withColumn('rank_val', f.rank().over(baseWindowSpec)) \
.withColumn('dense_rank_val', f.dense_rank().over(baseWindowSpec)) \
.withColumn('stddev_val', f.stddev('c').over(baseWindowSpec)) \
.withColumn('row_num', f.row_number().over(baseWindowSpec))
assert_gpu_and_cpu_are_equal_collect(do_it, conf={'spark.rapids.sql.hasNans': 'false'})

@@ -905,3 +906,19 @@ def test_window_ride_along(ride_along):
' row_number() over (order by a) as row_num '
'from window_agg_table ',
conf = allow_negative_scale_of_decimal_conf)

def test_window_stddev():
    window_spec_agg = Window.partitionBy('_1')
    window_spec = Window.partitionBy('_1').orderBy('_2')

    def do_it(spark):
        data = [[1, 3], [1, 5], [2, 3], [2, 7], [9, 9]]
        schema = [StructField('_1', IntegerType(), True), StructField('_2', IntegerType(), True)]
        df = spark.createDataFrame(SparkContext.getOrCreate().parallelize(data), StructType(schema))
        return df.withColumn('row', f.row_number().over(window_spec)) \
            .withColumn('stddev', f.stddev('_2').over(window_spec_agg)).select('stddev')

    assert_gpu_and_cpu_are_equal_collect(do_it, conf={
        'spark.rapids.sql.decimalType.enabled': 'true',
        'spark.rapids.sql.castDecimalToFloat.enabled': 'true'})
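
For reference, stddev here is Spark's sample standard deviation (stddev_samp): sqrt(sum((v - mean)^2) / (n - 1)) over each '_1' partition. A minimal pure-Python sketch of the values the CPU and GPU runs must agree on (the stddev_samp helper below is illustrative, not part of the test file):

import math

def stddev_samp(values):
    # Sample standard deviation as Spark computes it: divide by n - 1.
    n = len(values)
    if n < 2:
        # Spark 3.1+ returns NULL for a single row; NaN when
        # spark.sql.legacy.statisticalAggregate is set to true.
        return None
    mean = sum(values) / n
    return math.sqrt(sum((v - mean) ** 2 for v in values) / (n - 1))

# '_2' values per '_1' partition from the test data above:
print(stddev_samp([3, 5]))  # partition 1 -> sqrt(2) ~= 1.4142
print(stddev_samp([3, 7]))  # partition 2 -> sqrt(8) ~= 2.8284
print(stddev_samp([9]))     # partition 9 -> None (single row)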

13 changes: 12 additions & 1 deletion sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -3128,9 +3128,20 @@ object GpuOverrides extends Logging {
}),
expr[StddevSamp](
"Aggregation computing sample standard deviation",
ExprChecksImpl(
  // Group-by aggregation support (pre-existing): DOUBLE input only.
  ExprChecks.groupByOnly(
    TypeSig.DOUBLE, TypeSig.DOUBLE,
    Seq(ParamCheck("input", TypeSig.DOUBLE,
      TypeSig.DOUBLE))).asInstanceOf[ExprChecksImpl].contexts
  ++
  // Window support (added here), with NaN caveats on FLOAT/DOUBLE input.
  ExprChecks.windowOnly(
    TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.NULL, TypeSig.orderable,
    Seq(ParamCheck("input",
      (TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.NULL)
        .withPsNote(TypeEnum.DOUBLE, nanAggPsNote)
        .withPsNote(TypeEnum.FLOAT, nanAggPsNote),
      TypeSig.orderable))
  ).asInstanceOf[ExprChecksImpl].contexts),
(a, conf, p, r) => new AggExprMeta[StddevSamp](a, conf, p, r) {
override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = {
val legacyStatisticalAggregate = ShimLoader.getSparkShims.getLegacyStatisticalAggregate
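
With the window context registered above, a windowed stddev becomes eligible to run on the GPU instead of forcing a CPU fallback. The legacyStatisticalAggregate value fetched here presumably mirrors Spark's spark.sql.legacy.statisticalAggregate setting, which controls whether a single-row partition yields NaN (legacy) or NULL. A usage sketch from the query side, assuming a spark-rapids-enabled session (the session and column names are illustrative):

from pyspark.sql import SparkSession, functions as f
from pyspark.sql.window import Window

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, 3.0), (1, 5.0), (2, 3.0), (2, 7.0)], ['k', 'v'])
# stddev over a partition-only window; with the plugin enabled this window
# aggregation can be offloaded under the type checks added above.
df.withColumn('sd', f.stddev('v').over(Window.partitionBy('k'))).show()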
