Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-37682][SQL]Apply 'merged column' and 'bit vector' in RewriteDistinctAggregates #34953

Closed

Conversation

Flyangz
Copy link

@Flyangz Flyangz commented Dec 20, 2021

What changes were proposed in this pull request?

Adjust the grouping rules of distinctAggGroups, specifically in RewriteDistinctAggregates.groupDistinctAggExpr, so that some 'distinct' can be grouped together, and conditions(eg. CaseWhen, If) involved in them will be stored in the 'if_vector' to avoid unnecessary expanding. The 'if_vector' and 'filter_vector' introduced here can reduce the number of columns in the expand. Besides, children in distinct aggregate function with same datatype will share same project column.
Here is a example comparing the difference between the original expand rewriting and the new with 'merged column' and 'bit vector' (in sql):

SELECT
  COUNT(DISTINCT cat1) FILTER (WHERE id > 1) as cat1_filter_cnt_dist,
  COUNT(DISTINCT cat2) FILTER (WHERE id > 2) as cat2_filter_cnt_dist,
  COUNT(DISTINCT IF(value > 5, cat1, null)) as cat1_if_cnt_dist,
  COUNT(DISTINCT id) as id_cnt_dist,
  SUM(DISTINCT value) as id_sum_dist
FROM data
GROUP BY key

Current rule will rewrite the above sql plan to the following (pseudo) logical plan:

Aggregate(
   key = ['key]
   functions = [
       count('cat1) FILTER (WHERE (('gid = 1) AND 'max(id > 1))),
       count('(IF((value > 5), cat1, null))) FILTER (WHERE ('gid = 5)),
       count('cat2) FILTER (WHERE (('gid = 3) AND 'max(id > 2))),
       count('id) FILTER (WHERE ('gid = 2)),
       sum('value) FILTER (WHERE ('gid = 4))
   ]
   output = ['key, 'cat1_filter_cnt_dist, 'cat2_filter_cnt_dist, 'cat1_if_cnt_dist,
             'id_cnt_dist, 'id_sum_dist])
  Aggregate(
     key = ['key, 'cat1, 'value, 'cat2, '(IF((value > 5), cat1, null)), 'id, 'gid]
     functions = [max('id > 1), max('id > 2)]
     output = ['key, 'cat1, 'value, 'cat2, '(IF((value > 5), cat1, null)), 'id, 'gid,
               'max(id > 1), 'max(id > 2)])
    Expand(
       projections = [
         ('key, 'cat1, null, null, null, null, 1, ('id > 1), null),
         ('key, null, null, null, null, 'id, 2, null, null),
         ('key, null, null, 'cat2, null, null, 3, null, ('id > 2)),
         ('key, null, 'value, null, null, null, 4, null, null),
         ('key, null, null, null, if (('value > 5)) 'cat1 else null, null, 5, null, null)
       ]
       output = ['key, 'cat1, 'value, 'cat2, '(IF((value > 5), cat1, null)), 'id,
                 'gid, '(id > 1), '(id > 2)])
      LocalTableScan [...]

After applying 'merged column' and 'bit vector' tricks, the logical plan will become:

Aggregate(
   key = ['key]
   functions = [
       count('merged_string_1) FILTER (WHERE (('gid = 1) AND NOT (('filter_vector_1 & 1) = 0))),
       count(if (NOT (('if_vector_1 & 1) = 0)) 'merged_string_1 else null) FILTER (WHERE ('gid = 1)),
       count('merged_string_1) FILTER (WHERE (('gid = 2) AND NOT (('filter_vector_1 & 1) = 0))),
       count('merged_integer_1) FILTER (WHERE ('gid = 3)),
       sum('merged_integer_1) FILTER (WHERE ('gid = 4))
   ]
   output = ['key, 'cat1_filter_cnt_dist, 'cat2_filter_cnt_dist, 'cat1_if_cnt_dist,
             'id_cnt_dist, 'id_sum_dist])
  Aggregate(
     key = ['key, 'merged_string_1, 'merged_integer_1, 'gid]
     functions = [bit_or('if_vector_1),bit_or('filter_vector_1)]
     output = ['key, 'merged_string_1, 'merged_integer_1, 'gid, 'bit_or(if_vector_1), 'bit_or(filter_vector_1)])
    Expand(
       projections = [
         ('key, 'cat1, null, 1, if ('value > 5) 1 else 0, if ('id > 1) 1 else 0),
         ('key, 'cat2, null, 2, null, if ('id > 2) 1 else 0),
         ('key, null, 'id, 3, null, null),
         ('key, null, 'value, 4, null, null)
       ]
       output = ['key, 'merged_string_1, 'merged_integer_1, 'gid, 'if_vector_1, 'filter_vector_1])
      LocalTableScan [...]

Why are the changes needed?

It can save mass memory and improve performance in some cases like:

SELECT 
  count(distinct case when cond1 then col1 end),
  count(distinct case when cond2 then col1 end),
  ...
FROM data

Does this PR introduce any user-facing change?

No

How was this patch tested?

Existing test and a new UT in DataFrameAggregateSuite to test 'Vector Size larger than 64'.
I have written some SQL locally to test the correctness of the distinct calculation, but it seems difficult to cover most of the cases. Perhaps spark's existing test set will be more comprehensive, so I didn't leave it in the code.

@github-actions github-actions bot added the SQL label Dec 20, 2021
@AmplabJenkins
Copy link

Can one of the admins verify this patch?

@Flyangz Flyangz force-pushed the improve-RewriteDistinctAggregates branch from b21f6eb to 031d2f6 Compare December 22, 2021 05:49
@Flyangz Flyangz force-pushed the improve-RewriteDistinctAggregates branch from 031d2f6 to edad706 Compare December 22, 2021 05:57
@SparksFyz
Copy link

We encountered a problem when execute SQL contains multiple count distinct expressions. EXPAND operator generates huge size of data lead to running out of disk space when shuffle,
especially combined with GROUPING SET(It can generator another EXPAND operator lead to more expansion, shuffle write data size exceed 100T in some cases). This PR contains two optimizations to reduce data expansion:

  1. Merge same data type columns into one column.
  2. Resolve conditions such as case when or filter and merge conditions into a Long type BitVector column, exceed 64 will create another one.

There are two cases to help us understand two optimizations by comparing the projection for expand operator:

Op1: Merge Column. Column c1 and c2 is same type, for example String

select
  dim
  ,sum(c1) as m1
  ,count(distinct c1) as m2
  ,count(distinct c2) as m3
from table
group by dim

image
PS: Merge Columns can reduce overhead of null values, it can reduce 5% - 10% data size from our test.

Op2: BitVector

select 
  dim
  ,sum(c1) as m1
  ,count(distinct case when c1 > 1 then c2 end) as m2
  ,count(distinct case when c1 > 2 then c2 end) as m3
from table
group by dim

image
PS: This Optimization can reduce both columns and rows. In addition, d_value and c2_value can project to null when bitVector equals 0. This OP usually reduces more than 50% data size in out test.

We have tested some typical spark jobs which contain multiple count distinct from prod environment. Job stats are mentioned below:

Case 1: Simple case for only merge columns
Before the PR:
image

After the PR:
image

Case2: A litter bit complex SQL which contains more dim and more count distinct metrics:
Before the PR:
image

After the PR:
image

@Flyangz Flyangz changed the title [SPARK-37682][SQL][WIP]Apply 'merged column' and 'bit vector' in RewriteDistinctAggregates [SPARK-37682][SQL]Apply 'merged column' and 'bit vector' in RewriteDistinctAggregates Dec 23, 2021
@Flyangz
Copy link
Author

Flyangz commented Dec 23, 2021

ping @cloud-fan @maropu

@github-actions
Copy link

github-actions bot commented Apr 3, 2022

We're closing this PR because it hasn't been updated in a while. This isn't a judgement on the merit of the PR in any way. It's just a way of keeping the PR queue manageable.
If you'd like to revive this PR, please reopen it and ask a committer to remove the Stale tag!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants