Enable window-group-limit optimization on [databricks] (#10550)
* WindowGroupLimit support for [databricks].

Fixes #10531.

This is a follow-up to #10500, which added support to push down window-group-limit filters before the shuffle phase.

#10500 neglected to ensure that the optimization works on Databricks. (It turns out that window-group-limit was cherry-picked into Databricks 13.3, even though its nominal Spark version is `3.4.1`.)

This change ensures that the same optimization is available on Databricks 13.3 (and beyond).

---------

Signed-off-by: MithunR <[email protected]>
mythrocks authored Mar 11, 2024
1 parent 3d3ade2 commit 9cf2acb
Showing 4 changed files with 22 additions and 7 deletions.
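
For context, here is a minimal PySpark sketch (not taken from this commit; the DataFrame, table, and column names are illustrative) of the query shape that the window-group-limit optimization targets: a rank-based window function such as row_number(), rank(), or dense_rank(), followed by a filter that keeps only the top-k rows per partition. On Spark 3.5, and on Databricks 13.3 with this change, that rank filter can be applied per group before the shuffle.

from pyspark.sql import SparkSession, functions as f
from pyspark.sql.window import Window

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    [("a", 10), ("a", 7), ("a", 3), ("b", 9), ("b", 1)],
    ["grp", "val"])

# Rank rows within each group, then keep only the top two per group.
# This rank-then-filter pattern is what the window-group-limit
# optimization can prune before the shuffle on supported versions.
w = Window.partitionBy("grp").orderBy(f.col("val").desc())
top2 = (df.withColumn("rn", f.row_number().over(w))
          .where(f.col("rn") <= 2))

# On Spark 3.5+ (and Databricks 13.3+ after this change), the physical
# plan for a query like this should include a WindowGroupLimit node.
top2.explain()

On versions without the optimization, or when the plugin falls back to the CPU, the same query should still run with the same results; the per-group limit is simply not applied ahead of the shuffle.
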
3 changes: 3 additions & 0 deletions integration_tests/src/main/python/spark_session.py
@@ -260,6 +260,9 @@ def is_databricks113_or_later():
 def is_databricks122_or_later():
     return is_databricks_version_or_later(12, 2)

+def is_databricks133_or_later():
+    return is_databricks_version_or_later(13, 3)
+
 def supports_delta_lake_deletion_vectors():
     if is_databricks_runtime():
         return is_databricks122_or_later()
14 changes: 8 additions & 6 deletions integration_tests/src/main/python/window_function_test.py
@@ -21,7 +21,7 @@
 from pyspark.sql.types import DateType, TimestampType, NumericType
 from pyspark.sql.window import Window
 import pyspark.sql.functions as f
-from spark_session import is_before_spark_320, is_before_spark_350, is_databricks113_or_later, spark_version, with_cpu_session
+from spark_session import is_before_spark_320, is_databricks113_or_later, is_databricks133_or_later, is_spark_350_or_later, spark_version, with_cpu_session
 import warnings

 _grpkey_longs_with_no_nulls = [
@@ -2042,8 +2042,9 @@ def assert_query_runs_on(exec, conf):
     assert_query_runs_on(exec='GpuBatchedBoundedWindowExec', conf=conf_200)


-@pytest.mark.skipif(condition=is_before_spark_350(),
-                    reason="WindowGroupLimit not available for spark.version < 3.5")
+@pytest.mark.skipif(condition=not (is_spark_350_or_later() or is_databricks133_or_later()),
+                    reason="WindowGroupLimit not available for spark.version < 3.5 "
+                           "and Databricks version < 13.3")
 @ignore_order(local=True)
 @approximate_float
 @pytest.mark.parametrize('batch_size', ['1k', '1g'], ids=idfn)
@@ -2087,12 +2088,13 @@ def test_window_group_limits_for_ranking_functions(data_gen, batch_size, rank_cl
         lambda spark: gen_df(spark, data_gen, length=4096),
         "window_agg_table",
         query,
-        conf = conf)
+        conf=conf)


 @allow_non_gpu('WindowGroupLimitExec')
-@pytest.mark.skipif(condition=is_before_spark_350(),
-                    reason="WindowGroupLimit not available for spark.version < 3.5")
+@pytest.mark.skipif(condition=not (is_spark_350_or_later() or is_databricks133_or_later()),
+                    reason="WindowGroupLimit not available for spark.version < 3.5 "
+                           " and Databricks version < 13.3")
 @ignore_order(local=True)
 @approximate_float
 def test_window_group_limits_fallback_for_row_number():
@@ -15,6 +15,7 @@
  */

 /*** spark-rapids-shim-json-lines
+{"spark": "341db"}
 {"spark": "350"}
 {"spark": "351"}
 spark-rapids-shim-json-lines ***/
@@ -29,6 +29,7 @@ import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive._
 import org.apache.spark.sql.execution.exchange.ENSURE_REQUIREMENTS
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec}
+import org.apache.spark.sql.execution.window.WindowGroupLimitExec
 import org.apache.spark.sql.rapids.GpuV1WriteUtils.GpuEmpty2Null
 import org.apache.spark.sql.rapids.execution.python.GpuPythonUDAF
 import org.apache.spark.sql.types.StringType
@@ -167,7 +168,15 @@ trait Spark341PlusDBShims extends Spark332PlusDBShims {
      }
    ).disabledByDefault("Collect Limit replacement can be slower on the GPU, if huge number " +
      "of rows in a batch it could help by limiting the number of rows transferred from " +
-      "GPU to CPU")
+      "GPU to CPU"),
+    GpuOverrides.exec[WindowGroupLimitExec](
+      "Apply group-limits for row groups destined for rank-based window functions like " +
+        "row_number(), rank(), and dense_rank()",
+      ExecChecks( // Similar to WindowExec.
+        (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 +
+          TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP).nested(),
+        TypeSig.all),
+      (limit, conf, p, r) => new GpuWindowGroupLimitExecMeta(limit, conf, p, r))
   ).map(r => (r.getClassFor.asSubclass(classOf[SparkPlan]), r)).toMap

   override def getExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] =
