Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SPARK-14664][SQL] Implement DecimalAggregates optimization for Windo…
…w queries ## What changes were proposed in this pull request? This PR aims to implement decimal aggregation optimization for window queries by improving existing `DecimalAggregates`. Historically, `DecimalAggregates` optimizer is designed to transform general `sum/avg(decimal)`, but it breaks recently added windows queries like the followings. The following queries work well without the current `DecimalAggregates` optimizer. **Sum** ```scala scala> sql("select sum(a) over () from (select explode(array(1.0,2.0)) a) t").head java.lang.RuntimeException: Unsupported window function: MakeDecimal((sum(UnscaledValue(a#31)),mode=Complete,isDistinct=false),12,1) scala> sql("select sum(a) over () from (select explode(array(1.0,2.0)) a) t").explain() == Physical Plan == WholeStageCodegen : +- Project [sum(a) OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#23] : +- INPUT +- Window [MakeDecimal((sum(UnscaledValue(a#21)),mode=Complete,isDistinct=false),12,1) windowspecdefinition(ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS sum(a) OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#23] +- Exchange SinglePartition, None +- Generate explode([1.0,2.0]), false, false, [a#21] +- Scan OneRowRelation[] ``` **Average** ```scala scala> sql("select avg(a) over () from (select explode(array(1.0,2.0)) a) t").head java.lang.RuntimeException: Unsupported window function: cast(((avg(UnscaledValue(a#40)),mode=Complete,isDistinct=false) / 10.0) as decimal(6,5)) scala> sql("select avg(a) over () from (select explode(array(1.0,2.0)) a) t").explain() == Physical Plan == WholeStageCodegen : +- Project [avg(a) OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#44] : +- INPUT +- Window [cast(((avg(UnscaledValue(a#42)),mode=Complete,isDistinct=false) / 10.0) as decimal(6,5)) windowspecdefinition(ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS avg(a) OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#44] +- Exchange SinglePartition, None +- Generate explode([1.0,2.0]), false, false, [a#42] +- Scan OneRowRelation[] ``` After this PR, those queries work fine and new optimized physical plans look like the followings. **Sum** ```scala scala> sql("select sum(a) over () from (select explode(array(1.0,2.0)) a) t").explain() == Physical Plan == WholeStageCodegen : +- Project [sum(a) OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#35] : +- INPUT +- Window [MakeDecimal((sum(UnscaledValue(a#33)),mode=Complete,isDistinct=false) windowspecdefinition(ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),12,1) AS sum(a) OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#35] +- Exchange SinglePartition, None +- Generate explode([1.0,2.0]), false, false, [a#33] +- Scan OneRowRelation[] ``` **Average** ```scala scala> sql("select avg(a) over () from (select explode(array(1.0,2.0)) a) t").explain() == Physical Plan == WholeStageCodegen : +- Project [avg(a) OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#47] : +- INPUT +- Window [cast(((avg(UnscaledValue(a#45)),mode=Complete,isDistinct=false) windowspecdefinition(ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) / 10.0) as decimal(6,5)) AS avg(a) OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#47] +- Exchange SinglePartition, None +- Generate explode([1.0,2.0]), false, false, [a#45] +- Scan OneRowRelation[] ``` In this PR, *SUM over window* pattern matching is based on the code of hvanhovell ; he should be credited for the work he did. ## How was this patch tested? Pass the Jenkins tests (with newly added testcases) Author: Dongjoon Hyun <[email protected]> Closes #12421 from dongjoon-hyun/SPARK-14664.
- Loading branch information