Skip to content

Commit

Permalink
[feature](Nereids): Optimize Query Plan by Pulling Up Join with Commo…
Browse files Browse the repository at this point in the history
…n Child from Union (apache#42033)

This PR adds the rewrite rule PullUpJoinFromUnion, which supports pulling up a join
from a union all as a separate rule. It separates this rule from apache#28682 and deletes the
original PullUpJoinFromUnionAll rule.
  • Loading branch information
feiniaofeiafei committed Nov 11, 2024
1 parent 91d95f6 commit 80c6c52
Show file tree
Hide file tree
Showing 75 changed files with 2,242 additions and 1,711 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -355,26 +355,10 @@ public class Rewriter extends AbstractBatchJobExecutor {
bottomUp(new EliminateJoinByFK()),
topDown(new EliminateJoinByUnique())
),

// this rule should be after topic "Column pruning and infer predicate"
topic("Join pull up",
topDown(
new EliminateFilter(),
new PushDownFilterThroughProject(),
new MergeProjects()
),
topDown(
new PullUpJoinFromUnionAll()
),
custom(RuleType.COLUMN_PRUNING, ColumnPruning::new),
bottomUp(RuleSet.PUSH_DOWN_FILTERS),
custom(RuleType.ELIMINATE_UNNECESSARY_PROJECT, EliminateUnnecessaryProject::new)
),

// this rule should be invoked after topic "Join pull up"
topic("eliminate Aggregate according to fd items",
topDown(new EliminateGroupByKey()),
topDown(new PushDownAggThroughJoinOnPkFk())
topDown(new PushDownAggThroughJoinOnPkFk()),
topDown(new PullUpJoinFromUnionAll())
),

topic("Limit optimization",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ public enum RuleType {

// split limit
SPLIT_LIMIT(RuleTypeClass.REWRITE),
PULL_UP_JOIN_FROM_UNIONALL(RuleTypeClass.REWRITE),
PULL_UP_JOIN_FROM_UNION_ALL(RuleTypeClass.REWRITE),
// limit push down
PUSH_LIMIT_THROUGH_JOIN(RuleTypeClass.REWRITE),
PUSH_LIMIT_THROUGH_PROJECT_JOIN(RuleTypeClass.REWRITE),
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.utframe.TestWithFeService;

import org.junit.jupiter.api.Test;

/**
 * Plan-shape tests for unions whose branches each join the same table ({@code t1}).
 *
 * <p>Every test analyzes and rewrites a {@code UNION ALL} query and then checks whether the
 * rewritten plan matches (or does not match) a pattern that has a logical join sitting above
 * a logical union.
 */
class PullUpJoinFromUnionTest extends TestWithFeService implements MemoPatternMatchSupported {

    /**
     * Creates the {@code test} database with tables {@code t1}, {@code t2} and {@code t3},
     * and disables the {@code PRUNE_EMPTY_PARTITION} rule for the session.
     */
    @Override
    protected void runBeforeAll() throws Exception {
        createDatabase("test");
        connectContext.setDatabase("default_cluster:test");

        // t1 and t2 declare a NOT NULL id column.
        String t1Ddl = "CREATE TABLE IF NOT EXISTS t1 (\n"
                + " id int not null,\n"
                + " name char\n"
                + ")\n"
                + "DUPLICATE KEY(id)\n"
                + "DISTRIBUTED BY HASH(id) BUCKETS 10\n"
                + "PROPERTIES (\"replication_num\" = \"1\")\n";
        String t2Ddl = "CREATE TABLE IF NOT EXISTS t2 (\n"
                + " id int not null,\n"
                + " name char\n"
                + ")\n"
                + "DUPLICATE KEY(id)\n"
                + "DISTRIBUTED BY HASH(id) BUCKETS 10\n"
                + "PROPERTIES (\"replication_num\" = \"1\")\n";
        // t3 differs from the other two tables: its id column is declared without NOT NULL.
        String t3Ddl = "CREATE TABLE IF NOT EXISTS t3 (\n"
                + " id int,\n"
                + " name char\n"
                + ")\n"
                + "DUPLICATE KEY(id)\n"
                + "DISTRIBUTED BY HASH(id) BUCKETS 10\n"
                + "PROPERTIES (\"replication_num\" = \"1\")\n";
        createTables(t1Ddl, t2Ddl, t3Ddl);

        connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION");
    }

    /**
     * Analyzes and rewrites {@code sql} using the shared connect context.
     *
     * @param sql the query to plan
     * @return the checker positioned after the rewrite stage, ready for pattern assertions
     */
    private PlanChecker rewritten(String sql) {
        return PlanChecker.from(connectContext)
                .analyze(sql)
                .rewrite();
    }

    /** Both branches select all columns and join t1 on id equality. */
    @Test
    void testSimple() {
        String unionSql = "select * from t1 join t2 on t1.id = t2.id "
                + "union all "
                + "select * from t1 join t3 on t1.id = t3.id;";
        rewritten(unionSql).matches(logicalJoin(logicalProject(logicalUnion()), any()));
    }

    /** Branches project a subset of columns: first only the joined table's id, then also t1.name. */
    @Test
    void testProject() {
        String singleColumnSql = "select t2.id from t1 join t2 on t1.id = t2.id "
                + "union all "
                + "select t3.id from t1 join t3 on t1.id = t3.id;";
        rewritten(singleColumnSql).matches(logicalJoin(logicalProject(logicalUnion()), any()));

        String twoColumnSql = "select t2.id, t1.name from t1 join t2 on t1.id = t2.id "
                + "union all "
                + "select t3.id, t1.name from t1 join t3 on t1.id = t3.id;";
        rewritten(twoColumnSql).matches(logicalJoin(logicalProject(logicalUnion()), any()));
    }

    /** Each branch additionally outputs a different constant column. */
    @Test
    void testConstant() {
        String unionSql = "select t2.id, t1.name, 1 as id1 from t1 join t2 on t1.id = t2.id "
                + "union all "
                + "select t3.id, t1.name, 2 as id2 from t1 join t3 on t1.id = t3.id;";
        rewritten(unionSql).matches(logicalJoin(logicalProject(logicalUnion()), any()));
    }

    /** Branches output expressions over the columns; the expected pattern has no project above the union. */
    @Test
    void testComplexProject() {
        String unionSql = "select t2.id + 1, t1.name + 1, 1 as id1 from t1 join t2 on t1.id = t2.id "
                + "union all "
                + "select t3.id + 1, t1.name + 1, 2 as id2 from t1 join t3 on t1.id = t3.id;";
        rewritten(unionSql).matches(logicalJoin(logicalUnion(), any()));
    }

    /** The select lists do not output any column of the joined tables t2/t3. */
    @Test
    void testMissJoinSlot() {
        String unionSql = "select t1.name + 1, 1 as id1 from t1 join t2 on t1.id = t2.id "
                + "union all "
                + "select t1.name + 1, 2 as id2 from t1 join t3 on t1.id = t3.id;";
        rewritten(unionSql).matches(logicalJoin(logicalUnion(), any()));
    }

    /** Both branches carry the same filter on t1.name, with and without an explicit projection. */
    @Test
    void testFilter() {
        String selectAllSql = "select * from t1 join t2 on t1.id = t2.id where t1.name = '' "
                + "union all "
                + "select * from t1 join t3 on t1.id = t3.id where t1.name = '' ;";
        rewritten(selectAllSql).matches(logicalJoin(logicalProject(logicalUnion()), any()));

        String projectedSql = "select t2.id from t1 join t2 on t1.id = t2.id where t1.name = '' "
                + "union all "
                + "select t3.id from t1 join t3 on t1.id = t3.id where t1.name = '' ;";
        rewritten(projectedSql).matches(logicalJoin(logicalProject(logicalUnion()), any()));
    }

    /** Each join has two equality conditions (id and name). */
    @Test
    void testMultipleJoinConditions() {
        String unionSql = "select * from t1 join t2 on t1.id = t2.id and t1.name = t2.name "
                + "union all "
                + "select * from t1 join t3 on t1.id = t3.id and t1.name = t3.name;";
        rewritten(unionSql).matches(logicalJoin(logicalProject(logicalUnion()), any()));
    }

    /** Joins on a non-equality condition must NOT produce the join-above-union shape. */
    @Test
    void testNonEqualityJoinConditions() {
        String unionSql = "select * from t1 join t2 on t1.id < t2.id "
                + "union all "
                + "select * from t1 join t3 on t1.id < t3.id;";
        rewritten(unionSql).nonMatch(logicalJoin(logicalProject(logicalUnion()), any()));
    }

    /** The non-common join side of each branch is a filtered subquery. */
    @Test
    void testSubqueries() {
        String unionSql = "select * from t1 join (select * from t2 where t2.id > 10) s2 on t1.id = s2.id "
                + "union all "
                + "select * from t1 join (select * from t3 where t3.id > 10) s3 on t1.id = s3.id;";
        rewritten(unionSql).matches(logicalJoin(logicalProject(logicalUnion()), any()));
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
PhysicalCteAnchor ( cteId=CTEId#0 )
--PhysicalCteProducer ( cteId=CTEId#0 )
----PhysicalProject
------hashJoin[INNER_JOIN shuffle] hashCondition=((PULL_UP_UNIFIED_OUTPUT_ALIAS = customer.c_customer_sk)) otherCondition=() build RFs:RF2 c_customer_sk->[ss_customer_sk,ws_bill_customer_sk]
------hashJoin[INNER_JOIN shuffle] hashCondition=((ss_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF2 c_customer_sk->[ss_customer_sk,ws_bill_customer_sk]
--------PhysicalProject
----------PhysicalUnion
------------PhysicalProject
Expand Down
62 changes: 26 additions & 36 deletions regression-test/data/nereids_hint_tpcds_p0/shape/query14.out
Original file line number Diff line number Diff line change
Expand Up @@ -55,31 +55,21 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
--------PhysicalDistribute[DistributionSpecGather]
----------hashAgg[LOCAL]
------------PhysicalProject
--------------PhysicalUnion
----------------PhysicalDistribute[DistributionSpecExecutionAny]
------------------PhysicalProject
--------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF9 d_date_sk->[ss_sold_date_sk]
--------------hashJoin[INNER_JOIN broadcast] hashCondition=((ss_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF9 d_date_sk->[cs_sold_date_sk,ss_sold_date_sk,ws_sold_date_sk]
----------------PhysicalProject
------------------PhysicalUnion
--------------------PhysicalDistribute[DistributionSpecExecutionAny]
----------------------PhysicalProject
------------------------PhysicalOlapScan[store_sales] apply RFs: RF9
--------------------PhysicalDistribute[DistributionSpecExecutionAny]
----------------------PhysicalProject
------------------------filter((date_dim.d_year <= 2001) and (date_dim.d_year >= 1999))
--------------------------PhysicalOlapScan[date_dim]
----------------PhysicalDistribute[DistributionSpecExecutionAny]
------------------PhysicalProject
--------------------hashJoin[INNER_JOIN broadcast] hashCondition=((catalog_sales.cs_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF10 d_date_sk->[cs_sold_date_sk]
----------------------PhysicalProject
------------------------PhysicalOlapScan[catalog_sales] apply RFs: RF10
----------------------PhysicalProject
------------------------filter((date_dim.d_year <= 2001) and (date_dim.d_year >= 1999))
--------------------------PhysicalOlapScan[date_dim]
----------------PhysicalDistribute[DistributionSpecExecutionAny]
------------------PhysicalProject
--------------------hashJoin[INNER_JOIN broadcast] hashCondition=((web_sales.ws_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF11 d_date_sk->[ws_sold_date_sk]
------------------------PhysicalOlapScan[catalog_sales] apply RFs: RF9
--------------------PhysicalDistribute[DistributionSpecExecutionAny]
----------------------PhysicalProject
------------------------PhysicalOlapScan[web_sales] apply RFs: RF11
----------------------PhysicalProject
------------------------filter((date_dim.d_year <= 2001) and (date_dim.d_year >= 1999))
--------------------------PhysicalOlapScan[date_dim]
------------------------PhysicalOlapScan[web_sales] apply RFs: RF9
----------------PhysicalProject
------------------filter((date_dim.d_year <= 2001) and (date_dim.d_year >= 1999))
--------------------PhysicalOlapScan[date_dim]
----PhysicalResultSink
------PhysicalTopN[MERGE_SORT]
--------PhysicalDistribute[DistributionSpecGather]
Expand All @@ -97,16 +87,16 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
--------------------------------PhysicalDistribute[DistributionSpecHash]
----------------------------------hashAgg[LOCAL]
------------------------------------PhysicalProject
--------------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_item_sk = item.i_item_sk)) otherCondition=() build RFs:RF14 i_item_sk->[ss_item_sk,ss_item_sk]
----------------------------------------hashJoin[LEFT_SEMI_JOIN broadcast] hashCondition=((store_sales.ss_item_sk = cross_items.ss_item_sk)) otherCondition=() build RFs:RF13 ss_item_sk->[ss_item_sk]
--------------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_item_sk = item.i_item_sk)) otherCondition=() build RFs:RF12 i_item_sk->[ss_item_sk,ss_item_sk]
----------------------------------------hashJoin[LEFT_SEMI_JOIN broadcast] hashCondition=((store_sales.ss_item_sk = cross_items.ss_item_sk)) otherCondition=() build RFs:RF11 ss_item_sk->[ss_item_sk]
------------------------------------------PhysicalProject
--------------------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF12 d_date_sk->[ss_sold_date_sk]
--------------------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF10 d_date_sk->[ss_sold_date_sk]
----------------------------------------------PhysicalProject
------------------------------------------------PhysicalOlapScan[store_sales] apply RFs: RF12 RF13 RF14
------------------------------------------------PhysicalOlapScan[store_sales] apply RFs: RF10 RF11 RF12
----------------------------------------------PhysicalProject
------------------------------------------------filter((date_dim.d_moy = 11) and (date_dim.d_year = 2001))
--------------------------------------------------PhysicalOlapScan[date_dim]
------------------------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF14
------------------------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF12
----------------------------------------PhysicalProject
------------------------------------------PhysicalOlapScan[item]
----------------------------PhysicalProject
Expand All @@ -120,16 +110,16 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
--------------------------------PhysicalDistribute[DistributionSpecHash]
----------------------------------hashAgg[LOCAL]
------------------------------------PhysicalProject
--------------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((catalog_sales.cs_item_sk = item.i_item_sk)) otherCondition=() build RFs:RF17 i_item_sk->[cs_item_sk,ss_item_sk]
----------------------------------------hashJoin[LEFT_SEMI_JOIN broadcast] hashCondition=((catalog_sales.cs_item_sk = cross_items.ss_item_sk)) otherCondition=() build RFs:RF16 ss_item_sk->[cs_item_sk]
--------------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((catalog_sales.cs_item_sk = item.i_item_sk)) otherCondition=() build RFs:RF15 i_item_sk->[cs_item_sk,ss_item_sk]
----------------------------------------hashJoin[LEFT_SEMI_JOIN broadcast] hashCondition=((catalog_sales.cs_item_sk = cross_items.ss_item_sk)) otherCondition=() build RFs:RF14 ss_item_sk->[cs_item_sk]
------------------------------------------PhysicalProject
--------------------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((catalog_sales.cs_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF15 d_date_sk->[cs_sold_date_sk]
--------------------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((catalog_sales.cs_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF13 d_date_sk->[cs_sold_date_sk]
----------------------------------------------PhysicalProject
------------------------------------------------PhysicalOlapScan[catalog_sales] apply RFs: RF15 RF16 RF17
------------------------------------------------PhysicalOlapScan[catalog_sales] apply RFs: RF13 RF14 RF15
----------------------------------------------PhysicalProject
------------------------------------------------filter((date_dim.d_moy = 11) and (date_dim.d_year = 2001))
--------------------------------------------------PhysicalOlapScan[date_dim]
------------------------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF17
------------------------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF15
----------------------------------------PhysicalProject
------------------------------------------PhysicalOlapScan[item]
----------------------------PhysicalProject
Expand All @@ -143,16 +133,16 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
--------------------------------PhysicalDistribute[DistributionSpecHash]
----------------------------------hashAgg[LOCAL]
------------------------------------PhysicalProject
--------------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((web_sales.ws_item_sk = item.i_item_sk)) otherCondition=() build RFs:RF20 i_item_sk->[ss_item_sk,ws_item_sk]
----------------------------------------hashJoin[LEFT_SEMI_JOIN broadcast] hashCondition=((web_sales.ws_item_sk = cross_items.ss_item_sk)) otherCondition=() build RFs:RF19 ss_item_sk->[ws_item_sk]
--------------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((web_sales.ws_item_sk = item.i_item_sk)) otherCondition=() build RFs:RF18 i_item_sk->[ss_item_sk,ws_item_sk]
----------------------------------------hashJoin[LEFT_SEMI_JOIN broadcast] hashCondition=((web_sales.ws_item_sk = cross_items.ss_item_sk)) otherCondition=() build RFs:RF17 ss_item_sk->[ws_item_sk]
------------------------------------------PhysicalProject
--------------------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((web_sales.ws_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF18 d_date_sk->[ws_sold_date_sk]
--------------------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((web_sales.ws_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF16 d_date_sk->[ws_sold_date_sk]
----------------------------------------------PhysicalProject
------------------------------------------------PhysicalOlapScan[web_sales] apply RFs: RF18 RF19 RF20
------------------------------------------------PhysicalOlapScan[web_sales] apply RFs: RF16 RF17 RF18
----------------------------------------------PhysicalProject
------------------------------------------------filter((date_dim.d_moy = 11) and (date_dim.d_year = 2001))
--------------------------------------------------PhysicalOlapScan[date_dim]
------------------------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF20
------------------------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF18
----------------------------------------PhysicalProject
------------------------------------------PhysicalOlapScan[item]
----------------------------PhysicalProject
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
PhysicalCteAnchor ( cteId=CTEId#0 )
--PhysicalCteProducer ( cteId=CTEId#0 )
----PhysicalProject
------hashJoin[INNER_JOIN shuffle] hashCondition=((PULL_UP_UNIFIED_OUTPUT_ALIAS = customer.c_customer_sk)) otherCondition=() build RFs:RF3 c_customer_sk->[cs_bill_customer_sk,ss_customer_sk,ws_bill_customer_sk]
------hashJoin[INNER_JOIN shuffle] hashCondition=((ss_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF3 c_customer_sk->[cs_bill_customer_sk,ss_customer_sk,ws_bill_customer_sk]
--------PhysicalProject
----------PhysicalUnion
------------PhysicalProject
Expand Down
Loading

0 comments on commit 80c6c52

Please sign in to comment.