Support "WindowGroupLimit" optimization on GPU [databricks] (#10500)
Fixes #8208.

This commit adds support for running `WindowGroupLimitExec` on the GPU. This optimization was added in Apache Spark 3.5 to reduce the number of rows that participate in shuffles, for queries that filter on the result of ranking functions. For example:

```sql
SELECT foo, bar FROM (
  SELECT foo, bar, 
         RANK() OVER (PARTITION BY foo ORDER BY bar) AS rnk
  FROM mytable )
WHERE rnk < 10
```

Such a query requires a shuffle to bring all the rows of a window group into the same task.
In Spark 3.5, an optimization was added in [SPARK-37099](https://issues.apache.org/jira/browse/SPARK-37099) to take advantage of the `rnk < 10` predicate to reduce shuffle load.
Specifically, since only the first 9 (i.e. 10-1) ranks can satisfy the predicate, only the rows belonging to those ranks need to be shuffled into the task, per input batch.  Pre-filtering the rows that cannot possibly satisfy the condition reduces the number of shuffled records.
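As a rough illustration (the sample DataFrame and its contents below are hypothetical, and exact operator names in the plan vary by Spark version and plugin configuration), the optimization can be observed by inspecting the physical plan of such a query:

```python
# Hypothetical sketch: requires Spark 3.5+ and an active SparkSession.
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("window-group-limit-demo").getOrCreate()

# Illustrative sample data; names match the SQL example above.
spark.createDataFrame(
    [("a", 1), ("a", 2), ("b", 3)], ["foo", "bar"]
).createOrReplaceTempView("mytable")

df = spark.sql("""
    SELECT foo, bar FROM (
      SELECT foo, bar,
             RANK() OVER (PARTITION BY foo ORDER BY bar) AS rnk
      FROM mytable )
    WHERE rnk < 10
""")

# On Spark 3.5+ the plan should include a WindowGroupLimit node below the
# exchange; with the RAPIDS plugin enabled, it should be replaced by
# GpuWindowGroupLimitExec.
df.explain()
```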

The GPU implementation (i.e. `GpuWindowGroupLimitExec`) differs slightly from the CPU implementation, because it operates on entire input columnar batches.  As a result, `GpuWindowGroupLimitExec` runs the rank scan on each input batch, and then filters out the rows whose ranks exceed the limit specified in the predicate (`rnk < 10`). After the shuffle, `RANK()` is calculated again by `GpuRunningWindowExec`, to produce the final result.
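The following is only a minimal Python sketch of that per-batch pre-filtering idea; the real operator is implemented in Scala over cuDF columnar batches, and the list-of-tuples "batch" and `prefilter_batch` helper here are illustrative stand-ins, not the plugin's API:

```python
from itertools import groupby
from operator import itemgetter

def prefilter_batch(batch, limit):
    """Keep only rows whose RANK() within their group is below `limit`.

    `batch` is a list of (partition_key, order_key) tuples, pre-sorted by
    (partition_key, order_key) -- a stand-in for one columnar batch.
    """
    kept = []
    for _, rows in groupby(batch, key=itemgetter(0)):
        rank, prev_order, seen = 0, None, 0
        for row in rows:
            seen += 1
            if row[1] != prev_order:  # RANK(): ties share a rank; the next
                rank = seen           # distinct value jumps past them.
                prev_order = row[1]
            if rank < limit:
                kept.append(row)
            # Rows in later ranks cannot satisfy `rnk < limit`, so they are
            # dropped before the shuffle; rank is recomputed afterwards.
    return kept

# Example: with limit=3, only ranks 1 and 2 survive in each group.
batch = [("a", 10), ("a", 10), ("a", 20), ("a", 30),
         ("b", 5), ("b", 6), ("b", 7)]
print(prefilter_batch(batch, limit=3))
# [('a', 10), ('a', 10), ('b', 5), ('b', 6)]
```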

The current implementation addresses the `RANK()` and `DENSE_RANK()` window functions.  Other ranking functions (like `ROW_NUMBER()`) can be added at a later date.

Signed-off-by: MithunR <[email protected]>
mythrocks authored Feb 29, 2024
1 parent 3e6840f commit f85d5ef
Showing 5 changed files with 422 additions and 2 deletions.
82 changes: 81 additions & 1 deletion integration_tests/src/main/python/window_function_test.py
@@ -21,7 +21,7 @@
from pyspark.sql.types import DateType, TimestampType, NumericType
from pyspark.sql.window import Window
import pyspark.sql.functions as f
-from spark_session import is_before_spark_320, is_databricks113_or_later, spark_version
+from spark_session import is_before_spark_320, is_before_spark_350, is_databricks113_or_later, spark_version, with_cpu_session
import warnings

_grpkey_longs_with_no_nulls = [
@@ -2042,6 +2042,86 @@ def assert_query_runs_on(exec, conf):
assert_query_runs_on(exec='GpuBatchedBoundedWindowExec', conf=conf_200)


@pytest.mark.skipif(condition=is_before_spark_350(),
reason="WindowGroupLimit not available for spark.version < 3.5")
@ignore_order(local=True)
@approximate_float
@pytest.mark.parametrize('batch_size', ['1k', '1g'], ids=idfn)
@pytest.mark.parametrize('data_gen', [_grpkey_longs_with_no_nulls,
_grpkey_longs_with_nulls,
_grpkey_longs_with_dates,
_grpkey_longs_with_nullable_dates,
_grpkey_longs_with_decimals,
_grpkey_longs_with_nullable_decimals,
pytest.param(_grpkey_longs_with_nullable_larger_decimals,
marks=pytest.mark.skipif(
condition=spark_bugs_in_decimal_sorting(),
reason='https://github.com/NVIDIA/spark-rapids/issues/7429'))
],
ids=idfn)
@pytest.mark.parametrize('rank_clause', [
'RANK() OVER (PARTITION BY a ORDER BY b) ',
'DENSE_RANK() OVER (PARTITION BY a ORDER BY b) ',
'RANK() OVER (ORDER BY a,b,c) ',
'DENSE_RANK() OVER (ORDER BY a,b,c) ',
])
def test_window_group_limits_for_ranking_functions(data_gen, batch_size, rank_clause):
"""
This test verifies that window group limits are applied for queries with ranking-function based
row filters.
It covers RANK() and DENSE_RANK(), for window functions with and without `PARTITION BY`
clauses.
"""
conf = {'spark.rapids.sql.batchSizeBytes': batch_size,
'spark.rapids.sql.castFloatToDecimal.enabled': True}

query = """
SELECT * FROM (
SELECT *, {} AS rnk
FROM window_agg_table
)
WHERE rnk < 3
""".format(rank_clause)

assert_gpu_and_cpu_are_equal_sql(
lambda spark: gen_df(spark, data_gen, length=4096),
"window_agg_table",
query,
conf=conf)


@allow_non_gpu('WindowGroupLimitExec')
@pytest.mark.skipif(condition=is_before_spark_350(),
reason="WindowGroupLimit not available for spark.version < 3.5")
@ignore_order(local=True)
@approximate_float
def test_window_group_limits_fallback_for_row_number():
"""
This test verifies that the window-group-limit optimization falls back to the CPU
(i.e. `WindowGroupLimitExec`) for `ROW_NUMBER()`, which is not yet supported by
`GpuWindowGroupLimitExec`.
"""
conf = {'spark.rapids.sql.batchSizeBytes': '1g',
'spark.rapids.sql.castFloatToDecimal.enabled': True}

data_gen = _grpkey_longs_with_no_nulls
query = """
SELECT * FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) AS rnk
FROM window_agg_table
)
WHERE rnk < 3
"""

assert_gpu_sql_fallback_collect(
lambda spark: gen_df(spark, data_gen, length=512),
cpu_fallback_class_name="WindowGroupLimitExec",
table_name="window_agg_table",
sql=query,
conf=conf)


def test_lru_cache_datagen():
# log cache info at the end of integration tests, not related to window functions
info = gen_df_help.cache_info()
